convert EmbeddingModel to string type (#629)
This gives the user the ability to pass in models for embeddings that are not already defined in the library. It also more closely matches how the completions API works.
This commit is contained in:
120
embeddings.go
120
embeddings.go
@@ -13,108 +13,30 @@ var ErrVectorLengthMismatch = errors.New("vector length mismatch")
|
|||||||
|
|
||||||
// EmbeddingModel enumerates the models which can be used
|
// EmbeddingModel enumerates the models which can be used
|
||||||
// to generate Embedding vectors.
|
// to generate Embedding vectors.
|
||||||
type EmbeddingModel int
|
type EmbeddingModel string
|
||||||
|
|
||||||
// String implements the fmt.Stringer interface.
|
|
||||||
func (e EmbeddingModel) String() string {
|
|
||||||
return enumToString[e]
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarshalText implements the encoding.TextMarshaler interface.
|
|
||||||
func (e EmbeddingModel) MarshalText() ([]byte, error) {
|
|
||||||
return []byte(e.String()), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnmarshalText implements the encoding.TextUnmarshaler interface.
|
|
||||||
// On unrecognized value, it sets |e| to Unknown.
|
|
||||||
func (e *EmbeddingModel) UnmarshalText(b []byte) error {
|
|
||||||
if val, ok := stringToEnum[(string(b))]; ok {
|
|
||||||
*e = val
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
*e = Unknown
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
Unknown EmbeddingModel = iota
|
// Deprecated: The following block will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
|
||||||
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
|
AdaSimilarity EmbeddingModel = "text-similarity-ada-001"
|
||||||
AdaSimilarity
|
BabbageSimilarity EmbeddingModel = "text-similarity-babbage-001"
|
||||||
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
|
CurieSimilarity EmbeddingModel = "text-similarity-curie-001"
|
||||||
BabbageSimilarity
|
DavinciSimilarity EmbeddingModel = "text-similarity-davinci-001"
|
||||||
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
|
AdaSearchDocument EmbeddingModel = "text-search-ada-doc-001"
|
||||||
CurieSimilarity
|
AdaSearchQuery EmbeddingModel = "text-search-ada-query-001"
|
||||||
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
|
BabbageSearchDocument EmbeddingModel = "text-search-babbage-doc-001"
|
||||||
DavinciSimilarity
|
BabbageSearchQuery EmbeddingModel = "text-search-babbage-query-001"
|
||||||
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
|
CurieSearchDocument EmbeddingModel = "text-search-curie-doc-001"
|
||||||
AdaSearchDocument
|
CurieSearchQuery EmbeddingModel = "text-search-curie-query-001"
|
||||||
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
|
DavinciSearchDocument EmbeddingModel = "text-search-davinci-doc-001"
|
||||||
AdaSearchQuery
|
DavinciSearchQuery EmbeddingModel = "text-search-davinci-query-001"
|
||||||
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
|
AdaCodeSearchCode EmbeddingModel = "code-search-ada-code-001"
|
||||||
BabbageSearchDocument
|
AdaCodeSearchText EmbeddingModel = "code-search-ada-text-001"
|
||||||
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
|
BabbageCodeSearchCode EmbeddingModel = "code-search-babbage-code-001"
|
||||||
BabbageSearchQuery
|
BabbageCodeSearchText EmbeddingModel = "code-search-babbage-text-001"
|
||||||
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
|
|
||||||
CurieSearchDocument
|
AdaEmbeddingV2 EmbeddingModel = "text-embedding-ada-002"
|
||||||
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
|
|
||||||
CurieSearchQuery
|
|
||||||
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
|
|
||||||
DavinciSearchDocument
|
|
||||||
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
|
|
||||||
DavinciSearchQuery
|
|
||||||
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
|
|
||||||
AdaCodeSearchCode
|
|
||||||
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
|
|
||||||
AdaCodeSearchText
|
|
||||||
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
|
|
||||||
BabbageCodeSearchCode
|
|
||||||
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
|
|
||||||
BabbageCodeSearchText
|
|
||||||
AdaEmbeddingV2
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var enumToString = map[EmbeddingModel]string{
|
|
||||||
AdaSimilarity: "text-similarity-ada-001",
|
|
||||||
BabbageSimilarity: "text-similarity-babbage-001",
|
|
||||||
CurieSimilarity: "text-similarity-curie-001",
|
|
||||||
DavinciSimilarity: "text-similarity-davinci-001",
|
|
||||||
AdaSearchDocument: "text-search-ada-doc-001",
|
|
||||||
AdaSearchQuery: "text-search-ada-query-001",
|
|
||||||
BabbageSearchDocument: "text-search-babbage-doc-001",
|
|
||||||
BabbageSearchQuery: "text-search-babbage-query-001",
|
|
||||||
CurieSearchDocument: "text-search-curie-doc-001",
|
|
||||||
CurieSearchQuery: "text-search-curie-query-001",
|
|
||||||
DavinciSearchDocument: "text-search-davinci-doc-001",
|
|
||||||
DavinciSearchQuery: "text-search-davinci-query-001",
|
|
||||||
AdaCodeSearchCode: "code-search-ada-code-001",
|
|
||||||
AdaCodeSearchText: "code-search-ada-text-001",
|
|
||||||
BabbageCodeSearchCode: "code-search-babbage-code-001",
|
|
||||||
BabbageCodeSearchText: "code-search-babbage-text-001",
|
|
||||||
AdaEmbeddingV2: "text-embedding-ada-002",
|
|
||||||
}
|
|
||||||
|
|
||||||
var stringToEnum = map[string]EmbeddingModel{
|
|
||||||
"text-similarity-ada-001": AdaSimilarity,
|
|
||||||
"text-similarity-babbage-001": BabbageSimilarity,
|
|
||||||
"text-similarity-curie-001": CurieSimilarity,
|
|
||||||
"text-similarity-davinci-001": DavinciSimilarity,
|
|
||||||
"text-search-ada-doc-001": AdaSearchDocument,
|
|
||||||
"text-search-ada-query-001": AdaSearchQuery,
|
|
||||||
"text-search-babbage-doc-001": BabbageSearchDocument,
|
|
||||||
"text-search-babbage-query-001": BabbageSearchQuery,
|
|
||||||
"text-search-curie-doc-001": CurieSearchDocument,
|
|
||||||
"text-search-curie-query-001": CurieSearchQuery,
|
|
||||||
"text-search-davinci-doc-001": DavinciSearchDocument,
|
|
||||||
"text-search-davinci-query-001": DavinciSearchQuery,
|
|
||||||
"code-search-ada-code-001": AdaCodeSearchCode,
|
|
||||||
"code-search-ada-text-001": AdaCodeSearchText,
|
|
||||||
"code-search-babbage-code-001": BabbageCodeSearchCode,
|
|
||||||
"code-search-babbage-text-001": BabbageCodeSearchText,
|
|
||||||
"text-embedding-ada-002": AdaEmbeddingV2,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Embedding is a special format of data representation that can be easily utilized by machine
|
// Embedding is a special format of data representation that can be easily utilized by machine
|
||||||
// learning models and algorithms. The embedding is an information dense representation of the
|
// learning models and algorithms. The embedding is an information dense representation of the
|
||||||
// semantic meaning of a piece of text. Each embedding is a vector of floating point numbers,
|
// semantic meaning of a piece of text. Each embedding is a vector of floating point numbers,
|
||||||
@@ -306,7 +228,7 @@ func (c *Client) CreateEmbeddings(
|
|||||||
conv EmbeddingRequestConverter,
|
conv EmbeddingRequestConverter,
|
||||||
) (res EmbeddingResponse, err error) {
|
) (res EmbeddingResponse, err error) {
|
||||||
baseReq := conv.Convert()
|
baseReq := conv.Convert()
|
||||||
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model.String()), withBody(baseReq))
|
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model), withBody(baseReq))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ func TestEmbedding(t *testing.T) {
|
|||||||
// the AdaSearchQuery type
|
// the AdaSearchQuery type
|
||||||
marshaled, err := json.Marshal(embeddingReq)
|
marshaled, err := json.Marshal(embeddingReq)
|
||||||
checks.NoError(t, err, "Could not marshal embedding request")
|
checks.NoError(t, err, "Could not marshal embedding request")
|
||||||
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
|
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
|
||||||
t.Fatalf("Expected embedding request to contain model field")
|
t.Fatalf("Expected embedding request to contain model field")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -61,7 +61,7 @@ func TestEmbedding(t *testing.T) {
|
|||||||
}
|
}
|
||||||
marshaled, err = json.Marshal(embeddingReqStrings)
|
marshaled, err = json.Marshal(embeddingReqStrings)
|
||||||
checks.NoError(t, err, "Could not marshal embedding request")
|
checks.NoError(t, err, "Could not marshal embedding request")
|
||||||
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
|
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
|
||||||
t.Fatalf("Expected embedding request to contain model field")
|
t.Fatalf("Expected embedding request to contain model field")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -75,28 +75,12 @@ func TestEmbedding(t *testing.T) {
|
|||||||
}
|
}
|
||||||
marshaled, err = json.Marshal(embeddingReqTokens)
|
marshaled, err = json.Marshal(embeddingReqTokens)
|
||||||
checks.NoError(t, err, "Could not marshal embedding request")
|
checks.NoError(t, err, "Could not marshal embedding request")
|
||||||
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
|
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
|
||||||
t.Fatalf("Expected embedding request to contain model field")
|
t.Fatalf("Expected embedding request to contain model field")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEmbeddingModel(t *testing.T) {
|
|
||||||
var em openai.EmbeddingModel
|
|
||||||
err := em.UnmarshalText([]byte("text-similarity-ada-001"))
|
|
||||||
checks.NoError(t, err, "Could not marshal embedding model")
|
|
||||||
|
|
||||||
if em != openai.AdaSimilarity {
|
|
||||||
t.Errorf("Model is not equal to AdaSimilarity")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = em.UnmarshalText([]byte("some-non-existent-model"))
|
|
||||||
checks.NoError(t, err, "Could not marshal embedding model")
|
|
||||||
if em != openai.Unknown {
|
|
||||||
t.Errorf("Model is not equal to Unknown")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEmbeddingEndpoint(t *testing.T) {
|
func TestEmbeddingEndpoint(t *testing.T) {
|
||||||
client, server, teardown := setupOpenAITestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
defer teardown()
|
defer teardown()
|
||||||
|
|||||||
Reference in New Issue
Block a user