add embeddings tests (#237)

This commit is contained in:
sashabaranov
2023-04-08 19:49:27 +04:00
committed by GitHub
parent 89219e31b2
commit 4dc1edac38

View File

@@ -2,10 +2,14 @@ package openai_test
import ( import (
. "github.com/sashabaranov/go-openai" . "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/internal/test/checks"
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt"
"net/http"
"testing" "testing"
) )
@@ -45,3 +49,43 @@ func TestEmbedding(t *testing.T) {
} }
} }
} }
func TestEmbeddingModel(t *testing.T) {
var em EmbeddingModel
err := em.UnmarshalText([]byte("text-similarity-ada-001"))
checks.NoError(t, err, "Could not marshal embedding model")
if em != AdaSimilarity {
t.Errorf("Model is not equal to AdaSimilarity")
}
err = em.UnmarshalText([]byte("some-non-existent-model"))
checks.NoError(t, err, "Could not marshal embedding model")
if em != Unknown {
t.Errorf("Model is not equal to Unknown")
}
}
func TestEmbeddingEndpoint(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler(
"/v1/embeddings",
func(w http.ResponseWriter, r *http.Request) {
resBytes, _ := json.Marshal(EmbeddingResponse{})
fmt.Fprintln(w, string(resBytes))
},
)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
checks.NoError(t, err, "CreateEmbeddings error")
}