@@ -11,32 +11,32 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
. "github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
func TestEmbedding(t *testing.T) {
|
||||
embeddedModels := []EmbeddingModel{
|
||||
AdaSimilarity,
|
||||
BabbageSimilarity,
|
||||
CurieSimilarity,
|
||||
DavinciSimilarity,
|
||||
AdaSearchDocument,
|
||||
AdaSearchQuery,
|
||||
BabbageSearchDocument,
|
||||
BabbageSearchQuery,
|
||||
CurieSearchDocument,
|
||||
CurieSearchQuery,
|
||||
DavinciSearchDocument,
|
||||
DavinciSearchQuery,
|
||||
AdaCodeSearchCode,
|
||||
AdaCodeSearchText,
|
||||
BabbageCodeSearchCode,
|
||||
BabbageCodeSearchText,
|
||||
embeddedModels := []openai.EmbeddingModel{
|
||||
openai.AdaSimilarity,
|
||||
openai.BabbageSimilarity,
|
||||
openai.CurieSimilarity,
|
||||
openai.DavinciSimilarity,
|
||||
openai.AdaSearchDocument,
|
||||
openai.AdaSearchQuery,
|
||||
openai.BabbageSearchDocument,
|
||||
openai.BabbageSearchQuery,
|
||||
openai.CurieSearchDocument,
|
||||
openai.CurieSearchQuery,
|
||||
openai.DavinciSearchDocument,
|
||||
openai.DavinciSearchQuery,
|
||||
openai.AdaCodeSearchCode,
|
||||
openai.AdaCodeSearchText,
|
||||
openai.BabbageCodeSearchCode,
|
||||
openai.BabbageCodeSearchText,
|
||||
}
|
||||
for _, model := range embeddedModels {
|
||||
// test embedding request with strings (simple embedding request)
|
||||
embeddingReq := EmbeddingRequest{
|
||||
embeddingReq := openai.EmbeddingRequest{
|
||||
Input: []string{
|
||||
"The food was delicious and the waiter",
|
||||
"Other examples of embedding request",
|
||||
@@ -52,7 +52,7 @@ func TestEmbedding(t *testing.T) {
|
||||
}
|
||||
|
||||
// test embedding request with strings
|
||||
embeddingReqStrings := EmbeddingRequestStrings{
|
||||
embeddingReqStrings := openai.EmbeddingRequestStrings{
|
||||
Input: []string{
|
||||
"The food was delicious and the waiter",
|
||||
"Other examples of embedding request",
|
||||
@@ -66,7 +66,7 @@ func TestEmbedding(t *testing.T) {
|
||||
}
|
||||
|
||||
// test embedding request with tokens
|
||||
embeddingReqTokens := EmbeddingRequestTokens{
|
||||
embeddingReqTokens := openai.EmbeddingRequestTokens{
|
||||
Input: [][]int{
|
||||
{464, 2057, 373, 12625, 290, 262, 46612},
|
||||
{6395, 6096, 286, 11525, 12083, 2581},
|
||||
@@ -82,17 +82,17 @@ func TestEmbedding(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEmbeddingModel(t *testing.T) {
|
||||
var em EmbeddingModel
|
||||
var em openai.EmbeddingModel
|
||||
err := em.UnmarshalText([]byte("text-similarity-ada-001"))
|
||||
checks.NoError(t, err, "Could not marshal embedding model")
|
||||
|
||||
if em != AdaSimilarity {
|
||||
if em != openai.AdaSimilarity {
|
||||
t.Errorf("Model is not equal to AdaSimilarity")
|
||||
}
|
||||
|
||||
err = em.UnmarshalText([]byte("some-non-existent-model"))
|
||||
checks.NoError(t, err, "Could not marshal embedding model")
|
||||
if em != Unknown {
|
||||
if em != openai.Unknown {
|
||||
t.Errorf("Model is not equal to Unknown")
|
||||
}
|
||||
}
|
||||
@@ -101,12 +101,12 @@ func TestEmbeddingEndpoint(t *testing.T) {
|
||||
client, server, teardown := setupOpenAITestServer()
|
||||
defer teardown()
|
||||
|
||||
sampleEmbeddings := []Embedding{
|
||||
sampleEmbeddings := []openai.Embedding{
|
||||
{Embedding: []float32{1.23, 4.56, 7.89}},
|
||||
{Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}},
|
||||
}
|
||||
|
||||
sampleBase64Embeddings := []Base64Embedding{
|
||||
sampleBase64Embeddings := []openai.Base64Embedding{
|
||||
{Embedding: "pHCdP4XrkUDhevxA"},
|
||||
{Embedding: "/1jku0G/rLvA/EI8"},
|
||||
}
|
||||
@@ -115,8 +115,8 @@ func TestEmbeddingEndpoint(t *testing.T) {
|
||||
"/v1/embeddings",
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
EncodingFormat EmbeddingEncodingFormat `json:"encoding_format"`
|
||||
User string `json:"user"`
|
||||
EncodingFormat openai.EmbeddingEncodingFormat `json:"encoding_format"`
|
||||
User string `json:"user"`
|
||||
}
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
@@ -125,16 +125,16 @@ func TestEmbeddingEndpoint(t *testing.T) {
|
||||
case req.User == "invalid":
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
case req.EncodingFormat == EmbeddingEncodingFormatBase64:
|
||||
resBytes, _ = json.Marshal(EmbeddingResponseBase64{Data: sampleBase64Embeddings})
|
||||
case req.EncodingFormat == openai.EmbeddingEncodingFormatBase64:
|
||||
resBytes, _ = json.Marshal(openai.EmbeddingResponseBase64{Data: sampleBase64Embeddings})
|
||||
default:
|
||||
resBytes, _ = json.Marshal(EmbeddingResponse{Data: sampleEmbeddings})
|
||||
resBytes, _ = json.Marshal(openai.EmbeddingResponse{Data: sampleEmbeddings})
|
||||
}
|
||||
fmt.Fprintln(w, string(resBytes))
|
||||
},
|
||||
)
|
||||
// test create embeddings with strings (simple embedding request)
|
||||
res, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{})
|
||||
res, err := client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{})
|
||||
checks.NoError(t, err, "CreateEmbeddings error")
|
||||
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
|
||||
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
|
||||
@@ -143,8 +143,8 @@ func TestEmbeddingEndpoint(t *testing.T) {
|
||||
// test create embeddings with strings (simple embedding request)
|
||||
res, err = client.CreateEmbeddings(
|
||||
context.Background(),
|
||||
EmbeddingRequest{
|
||||
EncodingFormat: EmbeddingEncodingFormatBase64,
|
||||
openai.EmbeddingRequest{
|
||||
EncodingFormat: openai.EmbeddingEncodingFormatBase64,
|
||||
},
|
||||
)
|
||||
checks.NoError(t, err, "CreateEmbeddings error")
|
||||
@@ -153,23 +153,23 @@ func TestEmbeddingEndpoint(t *testing.T) {
|
||||
}
|
||||
|
||||
// test create embeddings with strings
|
||||
res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{})
|
||||
res, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequestStrings{})
|
||||
checks.NoError(t, err, "CreateEmbeddings strings error")
|
||||
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
|
||||
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
|
||||
}
|
||||
|
||||
// test create embeddings with tokens
|
||||
res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{})
|
||||
res, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequestTokens{})
|
||||
checks.NoError(t, err, "CreateEmbeddings tokens error")
|
||||
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
|
||||
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
|
||||
}
|
||||
|
||||
// test failed sendRequest
|
||||
_, err = client.CreateEmbeddings(context.Background(), EmbeddingRequest{
|
||||
_, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{
|
||||
User: "invalid",
|
||||
EncodingFormat: EmbeddingEncodingFormatBase64,
|
||||
EncodingFormat: openai.EmbeddingEncodingFormatBase64,
|
||||
})
|
||||
checks.HasError(t, err, "CreateEmbeddings error")
|
||||
}
|
||||
@@ -177,26 +177,26 @@ func TestEmbeddingEndpoint(t *testing.T) {
|
||||
func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) {
|
||||
type fields struct {
|
||||
Object string
|
||||
Data []Base64Embedding
|
||||
Model EmbeddingModel
|
||||
Usage Usage
|
||||
Data []openai.Base64Embedding
|
||||
Model openai.EmbeddingModel
|
||||
Usage openai.Usage
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want EmbeddingResponse
|
||||
want openai.EmbeddingResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "test embedding response base64 to embedding response",
|
||||
fields: fields{
|
||||
Data: []Base64Embedding{
|
||||
Data: []openai.Base64Embedding{
|
||||
{Embedding: "pHCdP4XrkUDhevxA"},
|
||||
{Embedding: "/1jku0G/rLvA/EI8"},
|
||||
},
|
||||
},
|
||||
want: EmbeddingResponse{
|
||||
Data: []Embedding{
|
||||
want: openai.EmbeddingResponse{
|
||||
Data: []openai.Embedding{
|
||||
{Embedding: []float32{1.23, 4.56, 7.89}},
|
||||
{Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}},
|
||||
},
|
||||
@@ -206,19 +206,19 @@ func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) {
|
||||
{
|
||||
name: "Invalid embedding",
|
||||
fields: fields{
|
||||
Data: []Base64Embedding{
|
||||
Data: []openai.Base64Embedding{
|
||||
{
|
||||
Embedding: "----",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: EmbeddingResponse{},
|
||||
want: openai.EmbeddingResponse{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := &EmbeddingResponseBase64{
|
||||
r := &openai.EmbeddingResponseBase64{
|
||||
Object: tt.fields.Object,
|
||||
Data: tt.fields.Data,
|
||||
Model: tt.fields.Model,
|
||||
@@ -237,8 +237,8 @@ func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDotProduct(t *testing.T) {
|
||||
v1 := &Embedding{Embedding: []float32{1, 2, 3}}
|
||||
v2 := &Embedding{Embedding: []float32{2, 4, 6}}
|
||||
v1 := &openai.Embedding{Embedding: []float32{1, 2, 3}}
|
||||
v2 := &openai.Embedding{Embedding: []float32{2, 4, 6}}
|
||||
expected := float32(28.0)
|
||||
|
||||
result, err := v1.DotProduct(v2)
|
||||
@@ -250,8 +250,8 @@ func TestDotProduct(t *testing.T) {
|
||||
t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result)
|
||||
}
|
||||
|
||||
v1 = &Embedding{Embedding: []float32{1, 0, 0}}
|
||||
v2 = &Embedding{Embedding: []float32{0, 1, 0}}
|
||||
v1 = &openai.Embedding{Embedding: []float32{1, 0, 0}}
|
||||
v2 = &openai.Embedding{Embedding: []float32{0, 1, 0}}
|
||||
expected = float32(0.0)
|
||||
|
||||
result, err = v1.DotProduct(v2)
|
||||
@@ -264,10 +264,10 @@ func TestDotProduct(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test for VectorLengthMismatchError
|
||||
v1 = &Embedding{Embedding: []float32{1, 0, 0}}
|
||||
v2 = &Embedding{Embedding: []float32{0, 1}}
|
||||
v1 = &openai.Embedding{Embedding: []float32{1, 0, 0}}
|
||||
v2 = &openai.Embedding{Embedding: []float32{0, 1}}
|
||||
_, err = v1.DotProduct(v2)
|
||||
if !errors.Is(err, ErrVectorLengthMismatch) {
|
||||
if !errors.Is(err, openai.ErrVectorLengthMismatch) {
|
||||
t.Errorf("Expected Vector Length Mismatch Error, but got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user