convert EmbeddingModel to string type (#629)

This gives the user the ability to pass in models for embeddings that are not
already defined in the library. Also more closely matches how the completions
API works.
This commit is contained in:
Matthew Jaffee
2024-01-15 03:33:02 -06:00
committed by GitHub
parent 682b7adb0b
commit e01a2d7231
2 changed files with 24 additions and 118 deletions

View File

@@ -13,108 +13,30 @@ var ErrVectorLengthMismatch = errors.New("vector length mismatch")
// EmbeddingModel enumerates the models which can be used // EmbeddingModel enumerates the models which can be used
// to generate Embedding vectors. // to generate Embedding vectors.
type EmbeddingModel int type EmbeddingModel string
// String implements the fmt.Stringer interface.
func (e EmbeddingModel) String() string {
return enumToString[e]
}
// MarshalText implements the encoding.TextMarshaler interface.
func (e EmbeddingModel) MarshalText() ([]byte, error) {
return []byte(e.String()), nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
// On unrecognized value, it sets |e| to Unknown.
func (e *EmbeddingModel) UnmarshalText(b []byte) error {
if val, ok := stringToEnum[(string(b))]; ok {
*e = val
return nil
}
*e = Unknown
return nil
}
const ( const (
Unknown EmbeddingModel = iota // Deprecated: The following block will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. AdaSimilarity EmbeddingModel = "text-similarity-ada-001"
AdaSimilarity BabbageSimilarity EmbeddingModel = "text-similarity-babbage-001"
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. CurieSimilarity EmbeddingModel = "text-similarity-curie-001"
BabbageSimilarity DavinciSimilarity EmbeddingModel = "text-similarity-davinci-001"
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. AdaSearchDocument EmbeddingModel = "text-search-ada-doc-001"
CurieSimilarity AdaSearchQuery EmbeddingModel = "text-search-ada-query-001"
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. BabbageSearchDocument EmbeddingModel = "text-search-babbage-doc-001"
DavinciSimilarity BabbageSearchQuery EmbeddingModel = "text-search-babbage-query-001"
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. CurieSearchDocument EmbeddingModel = "text-search-curie-doc-001"
AdaSearchDocument CurieSearchQuery EmbeddingModel = "text-search-curie-query-001"
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. DavinciSearchDocument EmbeddingModel = "text-search-davinci-doc-001"
AdaSearchQuery DavinciSearchQuery EmbeddingModel = "text-search-davinci-query-001"
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. AdaCodeSearchCode EmbeddingModel = "code-search-ada-code-001"
BabbageSearchDocument AdaCodeSearchText EmbeddingModel = "code-search-ada-text-001"
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. BabbageCodeSearchCode EmbeddingModel = "code-search-babbage-code-001"
BabbageSearchQuery BabbageCodeSearchText EmbeddingModel = "code-search-babbage-text-001"
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
CurieSearchDocument AdaEmbeddingV2 EmbeddingModel = "text-embedding-ada-002"
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
CurieSearchQuery
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
DavinciSearchDocument
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
DavinciSearchQuery
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
AdaCodeSearchCode
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
AdaCodeSearchText
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
BabbageCodeSearchCode
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
BabbageCodeSearchText
AdaEmbeddingV2
) )
var enumToString = map[EmbeddingModel]string{
AdaSimilarity: "text-similarity-ada-001",
BabbageSimilarity: "text-similarity-babbage-001",
CurieSimilarity: "text-similarity-curie-001",
DavinciSimilarity: "text-similarity-davinci-001",
AdaSearchDocument: "text-search-ada-doc-001",
AdaSearchQuery: "text-search-ada-query-001",
BabbageSearchDocument: "text-search-babbage-doc-001",
BabbageSearchQuery: "text-search-babbage-query-001",
CurieSearchDocument: "text-search-curie-doc-001",
CurieSearchQuery: "text-search-curie-query-001",
DavinciSearchDocument: "text-search-davinci-doc-001",
DavinciSearchQuery: "text-search-davinci-query-001",
AdaCodeSearchCode: "code-search-ada-code-001",
AdaCodeSearchText: "code-search-ada-text-001",
BabbageCodeSearchCode: "code-search-babbage-code-001",
BabbageCodeSearchText: "code-search-babbage-text-001",
AdaEmbeddingV2: "text-embedding-ada-002",
}
var stringToEnum = map[string]EmbeddingModel{
"text-similarity-ada-001": AdaSimilarity,
"text-similarity-babbage-001": BabbageSimilarity,
"text-similarity-curie-001": CurieSimilarity,
"text-similarity-davinci-001": DavinciSimilarity,
"text-search-ada-doc-001": AdaSearchDocument,
"text-search-ada-query-001": AdaSearchQuery,
"text-search-babbage-doc-001": BabbageSearchDocument,
"text-search-babbage-query-001": BabbageSearchQuery,
"text-search-curie-doc-001": CurieSearchDocument,
"text-search-curie-query-001": CurieSearchQuery,
"text-search-davinci-doc-001": DavinciSearchDocument,
"text-search-davinci-query-001": DavinciSearchQuery,
"code-search-ada-code-001": AdaCodeSearchCode,
"code-search-ada-text-001": AdaCodeSearchText,
"code-search-babbage-code-001": BabbageCodeSearchCode,
"code-search-babbage-text-001": BabbageCodeSearchText,
"text-embedding-ada-002": AdaEmbeddingV2,
}
// Embedding is a special format of data representation that can be easily utilized by machine // Embedding is a special format of data representation that can be easily utilized by machine
// learning models and algorithms. The embedding is an information dense representation of the // learning models and algorithms. The embedding is an information dense representation of the
// semantic meaning of a piece of text. Each embedding is a vector of floating point numbers, // semantic meaning of a piece of text. Each embedding is a vector of floating point numbers,
@@ -306,7 +228,7 @@ func (c *Client) CreateEmbeddings(
conv EmbeddingRequestConverter, conv EmbeddingRequestConverter,
) (res EmbeddingResponse, err error) { ) (res EmbeddingResponse, err error) {
baseReq := conv.Convert() baseReq := conv.Convert()
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model.String()), withBody(baseReq)) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model), withBody(baseReq))
if err != nil { if err != nil {
return return
} }

View File

@@ -47,7 +47,7 @@ func TestEmbedding(t *testing.T) {
// the AdaSearchQuery type // the AdaSearchQuery type
marshaled, err := json.Marshal(embeddingReq) marshaled, err := json.Marshal(embeddingReq)
checks.NoError(t, err, "Could not marshal embedding request") checks.NoError(t, err, "Could not marshal embedding request")
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
t.Fatalf("Expected embedding request to contain model field") t.Fatalf("Expected embedding request to contain model field")
} }
@@ -61,7 +61,7 @@ func TestEmbedding(t *testing.T) {
} }
marshaled, err = json.Marshal(embeddingReqStrings) marshaled, err = json.Marshal(embeddingReqStrings)
checks.NoError(t, err, "Could not marshal embedding request") checks.NoError(t, err, "Could not marshal embedding request")
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
t.Fatalf("Expected embedding request to contain model field") t.Fatalf("Expected embedding request to contain model field")
} }
@@ -75,28 +75,12 @@ func TestEmbedding(t *testing.T) {
} }
marshaled, err = json.Marshal(embeddingReqTokens) marshaled, err = json.Marshal(embeddingReqTokens)
checks.NoError(t, err, "Could not marshal embedding request") checks.NoError(t, err, "Could not marshal embedding request")
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
t.Fatalf("Expected embedding request to contain model field") t.Fatalf("Expected embedding request to contain model field")
} }
} }
} }
func TestEmbeddingModel(t *testing.T) {
var em openai.EmbeddingModel
err := em.UnmarshalText([]byte("text-similarity-ada-001"))
checks.NoError(t, err, "Could not marshal embedding model")
if em != openai.AdaSimilarity {
t.Errorf("Model is not equal to AdaSimilarity")
}
err = em.UnmarshalText([]byte("some-non-existent-model"))
checks.NoError(t, err, "Could not marshal embedding model")
if em != openai.Unknown {
t.Errorf("Model is not equal to Unknown")
}
}
func TestEmbeddingEndpoint(t *testing.T) { func TestEmbeddingEndpoint(t *testing.T) {
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
defer teardown() defer teardown()