Chore Support base64 embedding format (#485)

* chore: support base64 embedding format

* fix: add sizeOfFloat32

* chore: refactor base64 decoding

* chore: add tests

* fix linting

* fix test

* fix return error

* fix: use smaller slice for tests

* fix [skip ci]

* chore: refactor test to consider CreateEmbeddings response

* trigger build

* chore: remove named returns

* chore: refactor code to simplify the understanding

* chore: tests have been refactored to match the encoding format passed by request

* chore: fix tests

* fix

* fix
This commit is contained in:
Simone Vellei
2023-09-11 15:44:46 +02:00
committed by GitHub
parent 3589837b22
commit 8e4b7963a3
2 changed files with 229 additions and 18 deletions

View File

@@ -2,6 +2,9 @@ package openai
import ( import (
"context" "context"
"encoding/base64"
"encoding/binary"
"math"
"net/http" "net/http"
) )
@@ -129,15 +132,83 @@ type EmbeddingResponse struct {
Usage Usage `json:"usage"` Usage Usage `json:"usage"`
} }
// base64String is a standard-alphabet base64 string whose decoded bytes are
// read as a packed sequence of little-endian float32 values.
type base64String string

// Decode base64-decodes b and converts every consecutive 4-byte group of the
// result into a float32 (little-endian bit pattern). Trailing bytes that do not
// fill a complete group are ignored. An error from the base64 decoder is
// returned unchanged, with a nil slice.
func (b base64String) Decode() ([]float32, error) {
	const sizeOfFloat32 = 4

	raw, err := base64.StdEncoding.DecodeString(string(b))
	if err != nil {
		return nil, err
	}

	values := make([]float32, len(raw)/sizeOfFloat32)
	for i := range values {
		// Uint32 only reads the first four bytes of the slice it is given.
		group := raw[i*sizeOfFloat32:]
		values[i] = math.Float32frombits(binary.LittleEndian.Uint32(group))
	}
	return values, nil
}
// Base64Embedding is one embedding entry whose vector is still in its base64
// text form. The Embedding field's Decode method turns it into []float32.
type Base64Embedding struct {
	Object    string       `json:"object"`
	Embedding base64String `json:"embedding"`
	Index     int          `json:"index"`
}
// EmbeddingResponseBase64 is the response body of a Create embeddings request
// made with the base64 encoding format. Its Data entries hold base64 text
// instead of float vectors; ToEmbeddingResponse performs the conversion.
type EmbeddingResponseBase64 struct {
	Object string            `json:"object"`
	Data   []Base64Embedding `json:"data"`
	Model  EmbeddingModel    `json:"model"`
	Usage  Usage             `json:"usage"`
}
// ToEmbeddingResponse converts an EmbeddingResponseBase64 to an EmbeddingResponse
// by decoding every base64 embedding into its []float32 form. Object, Model and
// Usage are copied over unchanged, as are each entry's Object and Index.
// If any entry fails to decode, the zero EmbeddingResponse is returned together
// with that decoding error.
func (r *EmbeddingResponseBase64) ToEmbeddingResponse() (EmbeddingResponse, error) {
	decoded := make([]Embedding, len(r.Data))
	for i := range r.Data {
		entry := r.Data[i]

		vector, err := entry.Embedding.Decode()
		if err != nil {
			return EmbeddingResponse{}, err
		}

		decoded[i] = Embedding{
			Object:    entry.Object,
			Embedding: vector,
			Index:     entry.Index,
		}
	}

	resp := EmbeddingResponse{
		Object: r.Object,
		Data:   decoded,
		Model:  r.Model,
		Usage:  r.Usage,
	}
	return resp, nil
}
type EmbeddingRequestConverter interface { type EmbeddingRequestConverter interface {
// Needs to be of type EmbeddingRequestStrings or EmbeddingRequestTokens // Needs to be of type EmbeddingRequestStrings or EmbeddingRequestTokens
Convert() EmbeddingRequest Convert() EmbeddingRequest
} }
// EmbeddingEncodingFormat is the format in which the embeddings data is returned.
// Only "float" and "base64" are currently supported; "base64" is not officially documented.
// If not specified, OpenAI will use "float".
type EmbeddingEncodingFormat string

// Supported EmbeddingEncodingFormat values.
const (
	EmbeddingEncodingFormatFloat  = EmbeddingEncodingFormat("float")
	EmbeddingEncodingFormatBase64 = EmbeddingEncodingFormat("base64")
)
type EmbeddingRequest struct { type EmbeddingRequest struct {
Input any `json:"input"` Input any `json:"input"`
Model EmbeddingModel `json:"model"` Model EmbeddingModel `json:"model"`
User string `json:"user"` User string `json:"user"`
EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"`
} }
func (r EmbeddingRequest) Convert() EmbeddingRequest { func (r EmbeddingRequest) Convert() EmbeddingRequest {
@@ -158,13 +229,18 @@ type EmbeddingRequestStrings struct {
Model EmbeddingModel `json:"model"` Model EmbeddingModel `json:"model"`
// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. // A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
User string `json:"user"` User string `json:"user"`
// EncodingFormat is the format of the embeddings data.
// Currently, only "float" and "base64" are supported; however, "base64" is not officially documented.
// If not specified, OpenAI will use "float".
EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"`
} }
func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { func (r EmbeddingRequestStrings) Convert() EmbeddingRequest {
return EmbeddingRequest{ return EmbeddingRequest{
Input: r.Input, Input: r.Input,
Model: r.Model, Model: r.Model,
User: r.User, User: r.User,
EncodingFormat: r.EncodingFormat,
} }
} }
@@ -181,13 +257,18 @@ type EmbeddingRequestTokens struct {
Model EmbeddingModel `json:"model"` Model EmbeddingModel `json:"model"`
// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. // A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
User string `json:"user"` User string `json:"user"`
// EncodingFormat is the format of the embeddings data.
// Currently, only "float" and "base64" are supported; however, "base64" is not officially documented.
// If not specified, OpenAI will use "float".
EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"`
} }
func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { func (r EmbeddingRequestTokens) Convert() EmbeddingRequest {
return EmbeddingRequest{ return EmbeddingRequest{
Input: r.Input, Input: r.Input,
Model: r.Model, Model: r.Model,
User: r.User, User: r.User,
EncodingFormat: r.EncodingFormat,
} }
} }
@@ -196,14 +277,27 @@ func (r EmbeddingRequestTokens) Convert() EmbeddingRequest {
// //
// Body should be of type EmbeddingRequestStrings for embedding strings or EmbeddingRequestTokens // Body should be of type EmbeddingRequestStrings for embedding strings or EmbeddingRequestTokens
// for embedding groups of text already converted to tokens. // for embedding groups of text already converted to tokens.
func (c *Client) CreateEmbeddings(ctx context.Context, conv EmbeddingRequestConverter) (res EmbeddingResponse, err error) { //nolint:lll func (c *Client) CreateEmbeddings(
ctx context.Context,
conv EmbeddingRequestConverter,
) (res EmbeddingResponse, err error) {
baseReq := conv.Convert() baseReq := conv.Convert()
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model.String()), withBody(baseReq)) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model.String()), withBody(baseReq))
if err != nil { if err != nil {
return return
} }
err = c.sendRequest(req, &res) if baseReq.EncodingFormat != EmbeddingEncodingFormatBase64 {
err = c.sendRequest(req, &res)
return
}
base64Response := &EmbeddingResponseBase64{}
err = c.sendRequest(req, base64Response)
if err != nil {
return
}
res, err = base64Response.ToEmbeddingResponse()
return return
} }

View File

@@ -1,15 +1,16 @@
package openai_test package openai_test
import ( import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"reflect"
"testing" "testing"
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
) )
func TestEmbedding(t *testing.T) { func TestEmbedding(t *testing.T) {
@@ -97,22 +98,138 @@ func TestEmbeddingModel(t *testing.T) {
func TestEmbeddingEndpoint(t *testing.T) { func TestEmbeddingEndpoint(t *testing.T) {
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
defer teardown() defer teardown()
sampleEmbeddings := []Embedding{
{Embedding: []float32{1.23, 4.56, 7.89}},
{Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}},
}
sampleBase64Embeddings := []Base64Embedding{
{Embedding: "pHCdP4XrkUDhevxA"},
{Embedding: "/1jku0G/rLvA/EI8"},
}
server.RegisterHandler( server.RegisterHandler(
"/v1/embeddings", "/v1/embeddings",
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
resBytes, _ := json.Marshal(EmbeddingResponse{}) var req struct {
EncodingFormat EmbeddingEncodingFormat `json:"encoding_format"`
User string `json:"user"`
}
_ = json.NewDecoder(r.Body).Decode(&req)
var resBytes []byte
switch {
case req.User == "invalid":
w.WriteHeader(http.StatusBadRequest)
return
case req.EncodingFormat == EmbeddingEncodingFormatBase64:
resBytes, _ = json.Marshal(EmbeddingResponseBase64{Data: sampleBase64Embeddings})
default:
resBytes, _ = json.Marshal(EmbeddingResponse{Data: sampleEmbeddings})
}
fmt.Fprintln(w, string(resBytes)) fmt.Fprintln(w, string(resBytes))
}, },
) )
// test create embeddings with strings (simple embedding request) // test create embeddings with strings (simple embedding request)
_, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{}) res, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{})
checks.NoError(t, err, "CreateEmbeddings error") checks.NoError(t, err, "CreateEmbeddings error")
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
}
// test create embeddings with strings (simple embedding request)
res, err = client.CreateEmbeddings(
context.Background(),
EmbeddingRequest{
EncodingFormat: EmbeddingEncodingFormatBase64,
},
)
checks.NoError(t, err, "CreateEmbeddings error")
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
}
// test create embeddings with strings // test create embeddings with strings
_, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{}) res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{})
checks.NoError(t, err, "CreateEmbeddings strings error") checks.NoError(t, err, "CreateEmbeddings strings error")
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
}
// test create embeddings with tokens // test create embeddings with tokens
_, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{}) res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{})
checks.NoError(t, err, "CreateEmbeddings tokens error") checks.NoError(t, err, "CreateEmbeddings tokens error")
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
}
// test failed sendRequest
_, err = client.CreateEmbeddings(context.Background(), EmbeddingRequest{
User: "invalid",
EncodingFormat: EmbeddingEncodingFormatBase64,
})
checks.HasError(t, err, "CreateEmbeddings error")
}
func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) {
type fields struct {
Object string
Data []Base64Embedding
Model EmbeddingModel
Usage Usage
}
tests := []struct {
name string
fields fields
want EmbeddingResponse
wantErr bool
}{
{
name: "test embedding response base64 to embedding response",
fields: fields{
Data: []Base64Embedding{
{Embedding: "pHCdP4XrkUDhevxA"},
{Embedding: "/1jku0G/rLvA/EI8"},
},
},
want: EmbeddingResponse{
Data: []Embedding{
{Embedding: []float32{1.23, 4.56, 7.89}},
{Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}},
},
},
wantErr: false,
},
{
name: "Invalid embedding",
fields: fields{
Data: []Base64Embedding{
{
Embedding: "----",
},
},
},
want: EmbeddingResponse{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &EmbeddingResponseBase64{
Object: tt.fields.Object,
Data: tt.fields.Data,
Model: tt.fields.Model,
Usage: tt.fields.Usage,
}
got, err := r.ToEmbeddingResponse()
if (err != nil) != tt.wantErr {
t.Errorf("EmbeddingResponseBase64.ToEmbeddingResponse() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("EmbeddingResponseBase64.ToEmbeddingResponse() = %v, want %v", got, tt.want)
}
})
}
} }