Chore Support base64 embedding format (#485)
* chore: support base64 embedding format * fix: add sizeOfFloat32 * chore: refactor base64 decoding * chore: add tests * fix linting * fix test * fix return error * fix: use smaller slice for tests * fix [skip ci] * chore: refactor test to consider CreateEmbeddings response * trigger build * chore: remove named returns * chore: refactor code to simplify the understanding * chore: tests have been refactored to match the encoding format passed by request * chore: fix tests * fix * fix
This commit is contained in:
116
embeddings.go
116
embeddings.go
@@ -2,6 +2,9 @@ package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"math"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@@ -129,15 +132,83 @@ type EmbeddingResponse struct {
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type base64String string
|
||||
|
||||
func (b base64String) Decode() ([]float32, error) {
|
||||
decodedData, err := base64.StdEncoding.DecodeString(string(b))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
const sizeOfFloat32 = 4
|
||||
floats := make([]float32, len(decodedData)/sizeOfFloat32)
|
||||
for i := 0; i < len(floats); i++ {
|
||||
floats[i] = math.Float32frombits(binary.LittleEndian.Uint32(decodedData[i*4 : (i+1)*4]))
|
||||
}
|
||||
|
||||
return floats, nil
|
||||
}
|
||||
|
||||
// Base64Embedding is a container for base64 encoded embeddings.
|
||||
type Base64Embedding struct {
|
||||
Object string `json:"object"`
|
||||
Embedding base64String `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
// EmbeddingResponseBase64 is the response from a Create embeddings request with base64 encoding format.
|
||||
type EmbeddingResponseBase64 struct {
|
||||
Object string `json:"object"`
|
||||
Data []Base64Embedding `json:"data"`
|
||||
Model EmbeddingModel `json:"model"`
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
// ToEmbeddingResponse converts an embeddingResponseBase64 to an EmbeddingResponse.
|
||||
func (r *EmbeddingResponseBase64) ToEmbeddingResponse() (EmbeddingResponse, error) {
|
||||
data := make([]Embedding, len(r.Data))
|
||||
|
||||
for i, base64Embedding := range r.Data {
|
||||
embedding, err := base64Embedding.Embedding.Decode()
|
||||
if err != nil {
|
||||
return EmbeddingResponse{}, err
|
||||
}
|
||||
|
||||
data[i] = Embedding{
|
||||
Object: base64Embedding.Object,
|
||||
Embedding: embedding,
|
||||
Index: base64Embedding.Index,
|
||||
}
|
||||
}
|
||||
|
||||
return EmbeddingResponse{
|
||||
Object: r.Object,
|
||||
Model: r.Model,
|
||||
Data: data,
|
||||
Usage: r.Usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type EmbeddingRequestConverter interface {
|
||||
// Needs to be of type EmbeddingRequestStrings or EmbeddingRequestTokens
|
||||
Convert() EmbeddingRequest
|
||||
}
|
||||
|
||||
// EmbeddingEncodingFormat is the format of the embeddings data.
|
||||
// Currently, only "float" and "base64" are supported, however, "base64" is not officially documented.
|
||||
// If not specified OpenAI will use "float".
|
||||
type EmbeddingEncodingFormat string
|
||||
|
||||
const (
|
||||
EmbeddingEncodingFormatFloat EmbeddingEncodingFormat = "float"
|
||||
EmbeddingEncodingFormatBase64 EmbeddingEncodingFormat = "base64"
|
||||
)
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Input any `json:"input"`
|
||||
Model EmbeddingModel `json:"model"`
|
||||
User string `json:"user"`
|
||||
Input any `json:"input"`
|
||||
Model EmbeddingModel `json:"model"`
|
||||
User string `json:"user"`
|
||||
EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"`
|
||||
}
|
||||
|
||||
func (r EmbeddingRequest) Convert() EmbeddingRequest {
|
||||
@@ -158,13 +229,18 @@ type EmbeddingRequestStrings struct {
|
||||
Model EmbeddingModel `json:"model"`
|
||||
// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
|
||||
User string `json:"user"`
|
||||
// EmbeddingEncodingFormat is the format of the embeddings data.
|
||||
// Currently, only "float" and "base64" are supported, however, "base64" is not officially documented.
|
||||
// If not specified OpenAI will use "float".
|
||||
EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"`
|
||||
}
|
||||
|
||||
func (r EmbeddingRequestStrings) Convert() EmbeddingRequest {
|
||||
return EmbeddingRequest{
|
||||
Input: r.Input,
|
||||
Model: r.Model,
|
||||
User: r.User,
|
||||
Input: r.Input,
|
||||
Model: r.Model,
|
||||
User: r.User,
|
||||
EncodingFormat: r.EncodingFormat,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,13 +257,18 @@ type EmbeddingRequestTokens struct {
|
||||
Model EmbeddingModel `json:"model"`
|
||||
// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
|
||||
User string `json:"user"`
|
||||
// EmbeddingEncodingFormat is the format of the embeddings data.
|
||||
// Currently, only "float" and "base64" are supported, however, "base64" is not officially documented.
|
||||
// If not specified OpenAI will use "float".
|
||||
EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"`
|
||||
}
|
||||
|
||||
func (r EmbeddingRequestTokens) Convert() EmbeddingRequest {
|
||||
return EmbeddingRequest{
|
||||
Input: r.Input,
|
||||
Model: r.Model,
|
||||
User: r.User,
|
||||
Input: r.Input,
|
||||
Model: r.Model,
|
||||
User: r.User,
|
||||
EncodingFormat: r.EncodingFormat,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -196,14 +277,27 @@ func (r EmbeddingRequestTokens) Convert() EmbeddingRequest {
|
||||
//
|
||||
// Body should be of type EmbeddingRequestStrings for embedding strings or EmbeddingRequestTokens
|
||||
// for embedding groups of text already converted to tokens.
|
||||
func (c *Client) CreateEmbeddings(ctx context.Context, conv EmbeddingRequestConverter) (res EmbeddingResponse, err error) { //nolint:lll
|
||||
func (c *Client) CreateEmbeddings(
|
||||
ctx context.Context,
|
||||
conv EmbeddingRequestConverter,
|
||||
) (res EmbeddingResponse, err error) {
|
||||
baseReq := conv.Convert()
|
||||
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model.String()), withBody(baseReq))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = c.sendRequest(req, &res)
|
||||
if baseReq.EncodingFormat != EmbeddingEncodingFormatBase64 {
|
||||
err = c.sendRequest(req, &res)
|
||||
return
|
||||
}
|
||||
|
||||
base64Response := &EmbeddingResponseBase64{}
|
||||
err = c.sendRequest(req, base64Response)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
res, err = base64Response.ToEmbeddingResponse()
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user