Chore Support base64 embedding format (#485)
* chore: support base64 embedding format * fix: add sizeOfFloat32 * chore: refactor base64 decoding * chore: add tests * fix linting * fix test * fix return error * fix: use smaller slice for tests * fix [skip ci] * chore: refactor test to consider CreateEmbeddings response * trigger build * chore: remove named returns * chore: refactor code to simplify the understanding * chore: tests have been refactored to match the encoding format passed by request * chore: fix tests * fix * fix
This commit is contained in:
116
embeddings.go
116
embeddings.go
@@ -2,6 +2,9 @@ package openai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/binary"
|
||||||
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -129,15 +132,83 @@ type EmbeddingResponse struct {
|
|||||||
Usage Usage `json:"usage"`
|
Usage Usage `json:"usage"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// base64String is an embedding vector in encoded form: the standard base64
// encoding of consecutive little-endian float32 bit patterns.
type base64String string

// invalidEmbeddingLengthError is returned by base64String.Decode when the
// decoded payload cannot be split into whole float32 values.
type invalidEmbeddingLengthError struct{}

// Error implements the error interface.
func (invalidEmbeddingLengthError) Error() string {
	return "base64 embedding length is not a multiple of 4 bytes"
}

// Decode converts the base64 payload into a slice of float32 values.
//
// It returns the error reported by base64.StdEncoding when b is not valid
// standard base64, and an invalidEmbeddingLengthError when the decoded byte
// count is not a multiple of four. An empty string decodes to an empty,
// non-nil slice.
func (b base64String) Decode() ([]float32, error) {
	decodedData, err := base64.StdEncoding.DecodeString(string(b))
	if err != nil {
		return nil, err
	}

	const sizeOfFloat32 = 4

	// A payload with leftover bytes is corrupt; reject it instead of silently
	// dropping the tail and handing back a truncated vector.
	if len(decodedData)%sizeOfFloat32 != 0 {
		return nil, invalidEmbeddingLengthError{}
	}

	floats := make([]float32, len(decodedData)/sizeOfFloat32)
	for i := range floats {
		chunk := decodedData[i*sizeOfFloat32 : (i+1)*sizeOfFloat32]
		floats[i] = math.Float32frombits(binary.LittleEndian.Uint32(chunk))
	}

	return floats, nil
}
|
||||||
|
|
||||||
|
// Base64Embedding is a container for base64 encoded embeddings.
|
||||||
|
type Base64Embedding struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Embedding base64String `json:"embedding"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbeddingResponseBase64 is the response from a Create embeddings request with base64 encoding format.
|
||||||
|
type EmbeddingResponseBase64 struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Data []Base64Embedding `json:"data"`
|
||||||
|
Model EmbeddingModel `json:"model"`
|
||||||
|
Usage Usage `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToEmbeddingResponse converts an embeddingResponseBase64 to an EmbeddingResponse.
|
||||||
|
func (r *EmbeddingResponseBase64) ToEmbeddingResponse() (EmbeddingResponse, error) {
|
||||||
|
data := make([]Embedding, len(r.Data))
|
||||||
|
|
||||||
|
for i, base64Embedding := range r.Data {
|
||||||
|
embedding, err := base64Embedding.Embedding.Decode()
|
||||||
|
if err != nil {
|
||||||
|
return EmbeddingResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
data[i] = Embedding{
|
||||||
|
Object: base64Embedding.Object,
|
||||||
|
Embedding: embedding,
|
||||||
|
Index: base64Embedding.Index,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return EmbeddingResponse{
|
||||||
|
Object: r.Object,
|
||||||
|
Model: r.Model,
|
||||||
|
Data: data,
|
||||||
|
Usage: r.Usage,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
type EmbeddingRequestConverter interface {
|
type EmbeddingRequestConverter interface {
|
||||||
// Needs to be of type EmbeddingRequestStrings or EmbeddingRequestTokens
|
// Needs to be of type EmbeddingRequestStrings or EmbeddingRequestTokens
|
||||||
Convert() EmbeddingRequest
|
Convert() EmbeddingRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EmbeddingEncodingFormat selects how the embeddings data is encoded in the
// response. Only "float" and "base64" are currently supported; "base64" is not
// officially documented. When left unspecified, OpenAI uses "float".
type EmbeddingEncodingFormat string

const (
	// EmbeddingEncodingFormatFloat is the "float" encoding format.
	EmbeddingEncodingFormatFloat = EmbeddingEncodingFormat("float")
	// EmbeddingEncodingFormatBase64 is the "base64" encoding format.
	EmbeddingEncodingFormatBase64 = EmbeddingEncodingFormat("base64")
)
|
||||||
|
|
||||||
type EmbeddingRequest struct {
|
type EmbeddingRequest struct {
|
||||||
Input any `json:"input"`
|
Input any `json:"input"`
|
||||||
Model EmbeddingModel `json:"model"`
|
Model EmbeddingModel `json:"model"`
|
||||||
User string `json:"user"`
|
User string `json:"user"`
|
||||||
|
EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r EmbeddingRequest) Convert() EmbeddingRequest {
|
func (r EmbeddingRequest) Convert() EmbeddingRequest {
|
||||||
@@ -158,13 +229,18 @@ type EmbeddingRequestStrings struct {
|
|||||||
Model EmbeddingModel `json:"model"`
|
Model EmbeddingModel `json:"model"`
|
||||||
// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
|
// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
|
||||||
User string `json:"user"`
|
User string `json:"user"`
|
||||||
|
// EncodingFormat is the format of the embeddings data.
|
||||||
|
// Currently, only "float" and "base64" are supported, however, "base64" is not officially documented.
|
||||||
|
// If not specified OpenAI will use "float".
|
||||||
|
EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r EmbeddingRequestStrings) Convert() EmbeddingRequest {
|
func (r EmbeddingRequestStrings) Convert() EmbeddingRequest {
|
||||||
return EmbeddingRequest{
|
return EmbeddingRequest{
|
||||||
Input: r.Input,
|
Input: r.Input,
|
||||||
Model: r.Model,
|
Model: r.Model,
|
||||||
User: r.User,
|
User: r.User,
|
||||||
|
EncodingFormat: r.EncodingFormat,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -181,13 +257,18 @@ type EmbeddingRequestTokens struct {
|
|||||||
Model EmbeddingModel `json:"model"`
|
Model EmbeddingModel `json:"model"`
|
||||||
// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
|
// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
|
||||||
User string `json:"user"`
|
User string `json:"user"`
|
||||||
|
// EncodingFormat is the format of the embeddings data.
|
||||||
|
// Currently, only "float" and "base64" are supported, however, "base64" is not officially documented.
|
||||||
|
// If not specified OpenAI will use "float".
|
||||||
|
EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r EmbeddingRequestTokens) Convert() EmbeddingRequest {
|
func (r EmbeddingRequestTokens) Convert() EmbeddingRequest {
|
||||||
return EmbeddingRequest{
|
return EmbeddingRequest{
|
||||||
Input: r.Input,
|
Input: r.Input,
|
||||||
Model: r.Model,
|
Model: r.Model,
|
||||||
User: r.User,
|
User: r.User,
|
||||||
|
EncodingFormat: r.EncodingFormat,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -196,14 +277,27 @@ func (r EmbeddingRequestTokens) Convert() EmbeddingRequest {
|
|||||||
//
|
//
|
||||||
// Body should be of type EmbeddingRequestStrings for embedding strings or EmbeddingRequestTokens
|
// Body should be of type EmbeddingRequestStrings for embedding strings or EmbeddingRequestTokens
|
||||||
// for embedding groups of text already converted to tokens.
|
// for embedding groups of text already converted to tokens.
|
||||||
func (c *Client) CreateEmbeddings(ctx context.Context, conv EmbeddingRequestConverter) (res EmbeddingResponse, err error) { //nolint:lll
|
func (c *Client) CreateEmbeddings(
|
||||||
|
ctx context.Context,
|
||||||
|
conv EmbeddingRequestConverter,
|
||||||
|
) (res EmbeddingResponse, err error) {
|
||||||
baseReq := conv.Convert()
|
baseReq := conv.Convert()
|
||||||
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model.String()), withBody(baseReq))
|
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model.String()), withBody(baseReq))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = c.sendRequest(req, &res)
|
if baseReq.EncodingFormat != EmbeddingEncodingFormatBase64 {
|
||||||
|
err = c.sendRequest(req, &res)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
base64Response := &EmbeddingResponseBase64{}
|
||||||
|
err = c.sendRequest(req, base64Response)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err = base64Response.ToEmbeddingResponse()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,15 +1,16 @@
|
|||||||
package openai_test
|
package openai_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
|
||||||
|
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/sashabaranov/go-openai"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestEmbedding(t *testing.T) {
|
func TestEmbedding(t *testing.T) {
|
||||||
@@ -97,22 +98,138 @@ func TestEmbeddingModel(t *testing.T) {
|
|||||||
func TestEmbeddingEndpoint(t *testing.T) {
|
func TestEmbeddingEndpoint(t *testing.T) {
|
||||||
client, server, teardown := setupOpenAITestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
defer teardown()
|
defer teardown()
|
||||||
|
|
||||||
|
sampleEmbeddings := []Embedding{
|
||||||
|
{Embedding: []float32{1.23, 4.56, 7.89}},
|
||||||
|
{Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}},
|
||||||
|
}
|
||||||
|
|
||||||
|
sampleBase64Embeddings := []Base64Embedding{
|
||||||
|
{Embedding: "pHCdP4XrkUDhevxA"},
|
||||||
|
{Embedding: "/1jku0G/rLvA/EI8"},
|
||||||
|
}
|
||||||
|
|
||||||
server.RegisterHandler(
|
server.RegisterHandler(
|
||||||
"/v1/embeddings",
|
"/v1/embeddings",
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
resBytes, _ := json.Marshal(EmbeddingResponse{})
|
var req struct {
|
||||||
|
EncodingFormat EmbeddingEncodingFormat `json:"encoding_format"`
|
||||||
|
User string `json:"user"`
|
||||||
|
}
|
||||||
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
|
||||||
|
var resBytes []byte
|
||||||
|
switch {
|
||||||
|
case req.User == "invalid":
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
case req.EncodingFormat == EmbeddingEncodingFormatBase64:
|
||||||
|
resBytes, _ = json.Marshal(EmbeddingResponseBase64{Data: sampleBase64Embeddings})
|
||||||
|
default:
|
||||||
|
resBytes, _ = json.Marshal(EmbeddingResponse{Data: sampleEmbeddings})
|
||||||
|
}
|
||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
// test create embeddings with strings (simple embedding request)
|
// test create embeddings with strings (simple embedding request)
|
||||||
_, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{})
|
res, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{})
|
||||||
checks.NoError(t, err, "CreateEmbeddings error")
|
checks.NoError(t, err, "CreateEmbeddings error")
|
||||||
|
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
|
||||||
|
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// test create embeddings with strings (simple embedding request)
|
||||||
|
res, err = client.CreateEmbeddings(
|
||||||
|
context.Background(),
|
||||||
|
EmbeddingRequest{
|
||||||
|
EncodingFormat: EmbeddingEncodingFormatBase64,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
checks.NoError(t, err, "CreateEmbeddings error")
|
||||||
|
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
|
||||||
|
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
|
||||||
|
}
|
||||||
|
|
||||||
// test create embeddings with strings
|
// test create embeddings with strings
|
||||||
_, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{})
|
res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{})
|
||||||
checks.NoError(t, err, "CreateEmbeddings strings error")
|
checks.NoError(t, err, "CreateEmbeddings strings error")
|
||||||
|
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
|
||||||
|
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
|
||||||
|
}
|
||||||
|
|
||||||
// test create embeddings with tokens
|
// test create embeddings with tokens
|
||||||
_, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{})
|
res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{})
|
||||||
checks.NoError(t, err, "CreateEmbeddings tokens error")
|
checks.NoError(t, err, "CreateEmbeddings tokens error")
|
||||||
|
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
|
||||||
|
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// test failed sendRequest
|
||||||
|
_, err = client.CreateEmbeddings(context.Background(), EmbeddingRequest{
|
||||||
|
User: "invalid",
|
||||||
|
EncodingFormat: EmbeddingEncodingFormatBase64,
|
||||||
|
})
|
||||||
|
checks.HasError(t, err, "CreateEmbeddings error")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
Object string
|
||||||
|
Data []Base64Embedding
|
||||||
|
Model EmbeddingModel
|
||||||
|
Usage Usage
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
want EmbeddingResponse
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "test embedding response base64 to embedding response",
|
||||||
|
fields: fields{
|
||||||
|
Data: []Base64Embedding{
|
||||||
|
{Embedding: "pHCdP4XrkUDhevxA"},
|
||||||
|
{Embedding: "/1jku0G/rLvA/EI8"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: EmbeddingResponse{
|
||||||
|
Data: []Embedding{
|
||||||
|
{Embedding: []float32{1.23, 4.56, 7.89}},
|
||||||
|
{Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid embedding",
|
||||||
|
fields: fields{
|
||||||
|
Data: []Base64Embedding{
|
||||||
|
{
|
||||||
|
Embedding: "----",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: EmbeddingResponse{},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r := &EmbeddingResponseBase64{
|
||||||
|
Object: tt.fields.Object,
|
||||||
|
Data: tt.fields.Data,
|
||||||
|
Model: tt.fields.Model,
|
||||||
|
Usage: tt.fields.Usage,
|
||||||
|
}
|
||||||
|
got, err := r.ToEmbeddingResponse()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("EmbeddingResponseBase64.ToEmbeddingResponse() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("EmbeddingResponseBase64.ToEmbeddingResponse() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user