add embeddings tests (#237)
This commit is contained in:
@@ -2,10 +2,14 @@ package openai_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -45,3 +49,43 @@ func TestEmbedding(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEmbeddingModel(t *testing.T) {
|
||||||
|
var em EmbeddingModel
|
||||||
|
err := em.UnmarshalText([]byte("text-similarity-ada-001"))
|
||||||
|
checks.NoError(t, err, "Could not marshal embedding model")
|
||||||
|
|
||||||
|
if em != AdaSimilarity {
|
||||||
|
t.Errorf("Model is not equal to AdaSimilarity")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = em.UnmarshalText([]byte("some-non-existent-model"))
|
||||||
|
checks.NoError(t, err, "Could not marshal embedding model")
|
||||||
|
if em != Unknown {
|
||||||
|
t.Errorf("Model is not equal to Unknown")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmbeddingEndpoint(t *testing.T) {
|
||||||
|
server := test.NewTestServer()
|
||||||
|
server.RegisterHandler(
|
||||||
|
"/v1/embeddings",
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
resBytes, _ := json.Marshal(EmbeddingResponse{})
|
||||||
|
fmt.Fprintln(w, string(resBytes))
|
||||||
|
},
|
||||||
|
)
|
||||||
|
// create the test server
|
||||||
|
var err error
|
||||||
|
ts := server.OpenAITestServer()
|
||||||
|
ts.Start()
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
config := DefaultConfig(test.GetTestToken())
|
||||||
|
config.BaseURL = ts.URL + "/v1"
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
|
||||||
|
checks.NoError(t, err, "CreateEmbeddings error")
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user