convert EmbeddingModel to string type (#629)
This gives the user the ability to pass in models for embeddings that are not already defined in the library. Also more closely matches how the completions API works.
This commit is contained in:
@@ -47,7 +47,7 @@ func TestEmbedding(t *testing.T) {
|
||||
// the AdaSearchQuery type
|
||||
marshaled, err := json.Marshal(embeddingReq)
|
||||
checks.NoError(t, err, "Could not marshal embedding request")
|
||||
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
|
||||
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
|
||||
t.Fatalf("Expected embedding request to contain model field")
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ func TestEmbedding(t *testing.T) {
|
||||
}
|
||||
marshaled, err = json.Marshal(embeddingReqStrings)
|
||||
checks.NoError(t, err, "Could not marshal embedding request")
|
||||
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
|
||||
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
|
||||
t.Fatalf("Expected embedding request to contain model field")
|
||||
}
|
||||
|
||||
@@ -75,28 +75,12 @@ func TestEmbedding(t *testing.T) {
|
||||
}
|
||||
marshaled, err = json.Marshal(embeddingReqTokens)
|
||||
checks.NoError(t, err, "Could not marshal embedding request")
|
||||
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
|
||||
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
|
||||
t.Fatalf("Expected embedding request to contain model field")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddingModel(t *testing.T) {
|
||||
var em openai.EmbeddingModel
|
||||
err := em.UnmarshalText([]byte("text-similarity-ada-001"))
|
||||
checks.NoError(t, err, "Could not marshal embedding model")
|
||||
|
||||
if em != openai.AdaSimilarity {
|
||||
t.Errorf("Model is not equal to AdaSimilarity")
|
||||
}
|
||||
|
||||
err = em.UnmarshalText([]byte("some-non-existent-model"))
|
||||
checks.NoError(t, err, "Could not marshal embedding model")
|
||||
if em != openai.Unknown {
|
||||
t.Errorf("Model is not equal to Unknown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddingEndpoint(t *testing.T) {
|
||||
client, server, teardown := setupOpenAITestServer()
|
||||
defer teardown()
|
||||
|
||||
Reference in New Issue
Block a user