lint: fix linter warnings reported by golangci-lint (#522)

- Fix #519
This commit is contained in:
Simon Klee
2023-11-07 10:23:06 +01:00
committed by GitHub
parent 9e0232f941
commit 0664105387
23 changed files with 425 additions and 431 deletions

View File

@@ -9,7 +9,6 @@ import (
"os" "os"
"testing" "testing"
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/internal/test/checks"
"github.com/sashabaranov/go-openai/jsonschema" "github.com/sashabaranov/go-openai/jsonschema"
) )

View File

@@ -12,7 +12,7 @@ import (
"strings" "strings"
"testing" "testing"
. "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/internal/test/checks"
) )
@@ -26,7 +26,7 @@ func TestAudio(t *testing.T) {
testcases := []struct { testcases := []struct {
name string name string
createFn func(context.Context, AudioRequest) (AudioResponse, error) createFn func(context.Context, openai.AudioRequest) (openai.AudioResponse, error)
}{ }{
{ {
"transcribe", "transcribe",
@@ -48,7 +48,7 @@ func TestAudio(t *testing.T) {
path := filepath.Join(dir, "fake.mp3") path := filepath.Join(dir, "fake.mp3")
test.CreateTestFile(t, path) test.CreateTestFile(t, path)
req := AudioRequest{ req := openai.AudioRequest{
FilePath: path, FilePath: path,
Model: "whisper-3", Model: "whisper-3",
} }
@@ -57,7 +57,7 @@ func TestAudio(t *testing.T) {
}) })
t.Run(tc.name+" (with reader)", func(t *testing.T) { t.Run(tc.name+" (with reader)", func(t *testing.T) {
req := AudioRequest{ req := openai.AudioRequest{
FilePath: "fake.webm", FilePath: "fake.webm",
Reader: bytes.NewBuffer([]byte(`some webm binary data`)), Reader: bytes.NewBuffer([]byte(`some webm binary data`)),
Model: "whisper-3", Model: "whisper-3",
@@ -76,7 +76,7 @@ func TestAudioWithOptionalArgs(t *testing.T) {
testcases := []struct { testcases := []struct {
name string name string
createFn func(context.Context, AudioRequest) (AudioResponse, error) createFn func(context.Context, openai.AudioRequest) (openai.AudioResponse, error)
}{ }{
{ {
"transcribe", "transcribe",
@@ -98,13 +98,13 @@ func TestAudioWithOptionalArgs(t *testing.T) {
path := filepath.Join(dir, "fake.mp3") path := filepath.Join(dir, "fake.mp3")
test.CreateTestFile(t, path) test.CreateTestFile(t, path)
req := AudioRequest{ req := openai.AudioRequest{
FilePath: path, FilePath: path,
Model: "whisper-3", Model: "whisper-3",
Prompt: "用简体中文", Prompt: "用简体中文",
Temperature: 0.5, Temperature: 0.5,
Language: "zh", Language: "zh",
Format: AudioResponseFormatSRT, Format: openai.AudioResponseFormatSRT,
} }
_, err := tc.createFn(ctx, req) _, err := tc.createFn(ctx, req)
checks.NoError(t, err, "audio API error") checks.NoError(t, err, "audio API error")

View File

@@ -40,7 +40,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
} }
var failForField string var failForField string
mockBuilder.mockWriteField = func(fieldname, value string) error { mockBuilder.mockWriteField = func(fieldname, _ string) error {
if fieldname == failForField { if fieldname == failForField {
return mockFailedErr return mockFailedErr
} }

View File

@@ -10,28 +10,28 @@ import (
"strconv" "strconv"
"testing" "testing"
. "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/internal/test/checks"
) )
func TestChatCompletionsStreamWrongModel(t *testing.T) { func TestChatCompletionsStreamWrongModel(t *testing.T) {
config := DefaultConfig("whatever") config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1" config.BaseURL = "http://localhost/v1"
client := NewClientWithConfig(config) client := openai.NewClientWithConfig(config)
ctx := context.Background() ctx := context.Background()
req := ChatCompletionRequest{ req := openai.ChatCompletionRequest{
MaxTokens: 5, MaxTokens: 5,
Model: "ada", Model: "ada",
Messages: []ChatCompletionMessage{ Messages: []openai.ChatCompletionMessage{
{ {
Role: ChatMessageRoleUser, Role: openai.ChatMessageRoleUser,
Content: "Hello!", Content: "Hello!",
}, },
}, },
} }
_, err := client.CreateChatCompletionStream(ctx, req) _, err := client.CreateChatCompletionStream(ctx, req)
if !errors.Is(err, ErrChatCompletionInvalidModel) { if !errors.Is(err, openai.ErrChatCompletionInvalidModel) {
t.Fatalf("CreateChatCompletion should return ErrChatCompletionInvalidModel, but returned: %v", err) t.Fatalf("CreateChatCompletion should return ErrChatCompletionInvalidModel, but returned: %v", err)
} }
} }
@@ -39,7 +39,7 @@ func TestChatCompletionsStreamWrongModel(t *testing.T) {
func TestCreateChatCompletionStream(t *testing.T) { func TestCreateChatCompletionStream(t *testing.T) {
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
defer teardown() defer teardown()
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Content-Type", "text/event-stream")
// Send test responses // Send test responses
@@ -61,12 +61,12 @@ func TestCreateChatCompletionStream(t *testing.T) {
checks.NoError(t, err, "Write error") checks.NoError(t, err, "Write error")
}) })
stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 5, MaxTokens: 5,
Model: GPT3Dot5Turbo, Model: openai.GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{ Messages: []openai.ChatCompletionMessage{
{ {
Role: ChatMessageRoleUser, Role: openai.ChatMessageRoleUser,
Content: "Hello!", Content: "Hello!",
}, },
}, },
@@ -75,15 +75,15 @@ func TestCreateChatCompletionStream(t *testing.T) {
checks.NoError(t, err, "CreateCompletionStream returned error") checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close() defer stream.Close()
expectedResponses := []ChatCompletionStreamResponse{ expectedResponses := []openai.ChatCompletionStreamResponse{
{ {
ID: "1", ID: "1",
Object: "completion", Object: "completion",
Created: 1598069254, Created: 1598069254,
Model: GPT3Dot5Turbo, Model: openai.GPT3Dot5Turbo,
Choices: []ChatCompletionStreamChoice{ Choices: []openai.ChatCompletionStreamChoice{
{ {
Delta: ChatCompletionStreamChoiceDelta{ Delta: openai.ChatCompletionStreamChoiceDelta{
Content: "response1", Content: "response1",
}, },
FinishReason: "max_tokens", FinishReason: "max_tokens",
@@ -94,10 +94,10 @@ func TestCreateChatCompletionStream(t *testing.T) {
ID: "2", ID: "2",
Object: "completion", Object: "completion",
Created: 1598069255, Created: 1598069255,
Model: GPT3Dot5Turbo, Model: openai.GPT3Dot5Turbo,
Choices: []ChatCompletionStreamChoice{ Choices: []openai.ChatCompletionStreamChoice{
{ {
Delta: ChatCompletionStreamChoiceDelta{ Delta: openai.ChatCompletionStreamChoiceDelta{
Content: "response2", Content: "response2",
}, },
FinishReason: "max_tokens", FinishReason: "max_tokens",
@@ -133,7 +133,7 @@ func TestCreateChatCompletionStream(t *testing.T) {
func TestCreateChatCompletionStreamError(t *testing.T) { func TestCreateChatCompletionStreamError(t *testing.T) {
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
defer teardown() defer teardown()
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Content-Type", "text/event-stream")
// Send test responses // Send test responses
@@ -156,12 +156,12 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
checks.NoError(t, err, "Write error") checks.NoError(t, err, "Write error")
}) })
stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 5, MaxTokens: 5,
Model: GPT3Dot5Turbo, Model: openai.GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{ Messages: []openai.ChatCompletionMessage{
{ {
Role: ChatMessageRoleUser, Role: openai.ChatMessageRoleUser,
Content: "Hello!", Content: "Hello!",
}, },
}, },
@@ -173,7 +173,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
_, streamErr := stream.Recv() _, streamErr := stream.Recv()
checks.HasError(t, streamErr, "stream.Recv() did not return error") checks.HasError(t, streamErr, "stream.Recv() did not return error")
var apiErr *APIError var apiErr *openai.APIError
if !errors.As(streamErr, &apiErr) { if !errors.As(streamErr, &apiErr) {
t.Errorf("stream.Recv() did not return APIError") t.Errorf("stream.Recv() did not return APIError")
} }
@@ -183,7 +183,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { func TestCreateChatCompletionStreamWithHeaders(t *testing.T) {
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
defer teardown() defer teardown()
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set(xCustomHeader, xCustomHeaderValue) w.Header().Set(xCustomHeader, xCustomHeaderValue)
@@ -196,12 +196,12 @@ func TestCreateChatCompletionStreamWithHeaders(t *testing.T) {
checks.NoError(t, err, "Write error") checks.NoError(t, err, "Write error")
}) })
stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 5, MaxTokens: 5,
Model: GPT3Dot5Turbo, Model: openai.GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{ Messages: []openai.ChatCompletionMessage{
{ {
Role: ChatMessageRoleUser, Role: openai.ChatMessageRoleUser,
Content: "Hello!", Content: "Hello!",
}, },
}, },
@@ -219,7 +219,7 @@ func TestCreateChatCompletionStreamWithHeaders(t *testing.T) {
func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) {
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
defer teardown() defer teardown()
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Content-Type", "text/event-stream")
for k, v := range rateLimitHeaders { for k, v := range rateLimitHeaders {
switch val := v.(type) { switch val := v.(type) {
@@ -239,12 +239,12 @@ func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) {
checks.NoError(t, err, "Write error") checks.NoError(t, err, "Write error")
}) })
stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 5, MaxTokens: 5,
Model: GPT3Dot5Turbo, Model: openai.GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{ Messages: []openai.ChatCompletionMessage{
{ {
Role: ChatMessageRoleUser, Role: openai.ChatMessageRoleUser,
Content: "Hello!", Content: "Hello!",
}, },
}, },
@@ -264,7 +264,7 @@ func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) {
func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) {
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
defer teardown() defer teardown()
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Content-Type", "text/event-stream")
// Send test responses // Send test responses
@@ -276,12 +276,12 @@ func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) {
checks.NoError(t, err, "Write error") checks.NoError(t, err, "Write error")
}) })
stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 5, MaxTokens: 5,
Model: GPT3Dot5Turbo, Model: openai.GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{ Messages: []openai.ChatCompletionMessage{
{ {
Role: ChatMessageRoleUser, Role: openai.ChatMessageRoleUser,
Content: "Hello!", Content: "Hello!",
}, },
}, },
@@ -293,7 +293,7 @@ func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) {
_, streamErr := stream.Recv() _, streamErr := stream.Recv()
checks.HasError(t, streamErr, "stream.Recv() did not return error") checks.HasError(t, streamErr, "stream.Recv() did not return error")
var apiErr *APIError var apiErr *openai.APIError
if !errors.As(streamErr, &apiErr) { if !errors.As(streamErr, &apiErr) {
t.Errorf("stream.Recv() did not return APIError") t.Errorf("stream.Recv() did not return APIError")
} }
@@ -303,7 +303,7 @@ func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) {
func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
defer teardown() defer teardown()
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(429) w.WriteHeader(429)
@@ -317,18 +317,18 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
_, err := w.Write(dataBytes) _, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error") checks.NoError(t, err, "Write error")
}) })
_, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ _, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 5, MaxTokens: 5,
Model: GPT3Dot5Turbo, Model: openai.GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{ Messages: []openai.ChatCompletionMessage{
{ {
Role: ChatMessageRoleUser, Role: openai.ChatMessageRoleUser,
Content: "Hello!", Content: "Hello!",
}, },
}, },
Stream: true, Stream: true,
}) })
var apiErr *APIError var apiErr *openai.APIError
if !errors.As(err, &apiErr) { if !errors.As(err, &apiErr) {
t.Errorf("TestCreateChatCompletionStreamRateLimitError did not return APIError") t.Errorf("TestCreateChatCompletionStreamRateLimitError did not return APIError")
} }
@@ -345,7 +345,7 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) {
client, server, teardown := setupAzureTestServer() client, server, teardown := setupAzureTestServer()
defer teardown() defer teardown()
server.RegisterHandler("/openai/deployments/gpt-35-turbo/chat/completions", server.RegisterHandler("/openai/deployments/gpt-35-turbo/chat/completions",
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests) w.WriteHeader(http.StatusTooManyRequests)
// Send test responses // Send test responses
@@ -355,13 +355,13 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) {
checks.NoError(t, err, "Write error") checks.NoError(t, err, "Write error")
}) })
apiErr := &APIError{} apiErr := &openai.APIError{}
_, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ _, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 5, MaxTokens: 5,
Model: GPT3Dot5Turbo, Model: openai.GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{ Messages: []openai.ChatCompletionMessage{
{ {
Role: ChatMessageRoleUser, Role: openai.ChatMessageRoleUser,
Content: "Hello!", Content: "Hello!",
}, },
}, },
@@ -387,7 +387,7 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) {
} }
// Helper funcs. // Helper funcs.
func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool {
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
return false return false
} }
@@ -402,7 +402,7 @@ func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
return true return true
} }
func compareChatStreamResponseChoices(c1, c2 ChatCompletionStreamChoice) bool { func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool {
if c1.Index != c2.Index { if c1.Index != c2.Index {
return false return false
} }

View File

@@ -11,7 +11,7 @@ import (
"testing" "testing"
"time" "time"
. "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/internal/test/checks"
"github.com/sashabaranov/go-openai/jsonschema" "github.com/sashabaranov/go-openai/jsonschema"
) )
@@ -21,49 +21,47 @@ const (
xCustomHeaderValue = "test" xCustomHeaderValue = "test"
) )
var ( var rateLimitHeaders = map[string]any{
rateLimitHeaders = map[string]any{ "x-ratelimit-limit-requests": 60,
"x-ratelimit-limit-requests": 60, "x-ratelimit-limit-tokens": 150000,
"x-ratelimit-limit-tokens": 150000, "x-ratelimit-remaining-requests": 59,
"x-ratelimit-remaining-requests": 59, "x-ratelimit-remaining-tokens": 149984,
"x-ratelimit-remaining-tokens": 149984, "x-ratelimit-reset-requests": "1s",
"x-ratelimit-reset-requests": "1s", "x-ratelimit-reset-tokens": "6m0s",
"x-ratelimit-reset-tokens": "6m0s", }
}
)
func TestChatCompletionsWrongModel(t *testing.T) { func TestChatCompletionsWrongModel(t *testing.T) {
config := DefaultConfig("whatever") config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1" config.BaseURL = "http://localhost/v1"
client := NewClientWithConfig(config) client := openai.NewClientWithConfig(config)
ctx := context.Background() ctx := context.Background()
req := ChatCompletionRequest{ req := openai.ChatCompletionRequest{
MaxTokens: 5, MaxTokens: 5,
Model: "ada", Model: "ada",
Messages: []ChatCompletionMessage{ Messages: []openai.ChatCompletionMessage{
{ {
Role: ChatMessageRoleUser, Role: openai.ChatMessageRoleUser,
Content: "Hello!", Content: "Hello!",
}, },
}, },
} }
_, err := client.CreateChatCompletion(ctx, req) _, err := client.CreateChatCompletion(ctx, req)
msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err)
checks.ErrorIs(t, err, ErrChatCompletionInvalidModel, msg) checks.ErrorIs(t, err, openai.ErrChatCompletionInvalidModel, msg)
} }
func TestChatCompletionsWithStream(t *testing.T) { func TestChatCompletionsWithStream(t *testing.T) {
config := DefaultConfig("whatever") config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1" config.BaseURL = "http://localhost/v1"
client := NewClientWithConfig(config) client := openai.NewClientWithConfig(config)
ctx := context.Background() ctx := context.Background()
req := ChatCompletionRequest{ req := openai.ChatCompletionRequest{
Stream: true, Stream: true,
} }
_, err := client.CreateChatCompletion(ctx, req) _, err := client.CreateChatCompletion(ctx, req)
checks.ErrorIs(t, err, ErrChatCompletionStreamNotSupported, "unexpected error") checks.ErrorIs(t, err, openai.ErrChatCompletionStreamNotSupported, "unexpected error")
} }
// TestCompletions Tests the completions endpoint of the API using the mocked server. // TestCompletions Tests the completions endpoint of the API using the mocked server.
@@ -71,12 +69,12 @@ func TestChatCompletions(t *testing.T) {
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
defer teardown() defer teardown()
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 5, MaxTokens: 5,
Model: GPT3Dot5Turbo, Model: openai.GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{ Messages: []openai.ChatCompletionMessage{
{ {
Role: ChatMessageRoleUser, Role: openai.ChatMessageRoleUser,
Content: "Hello!", Content: "Hello!",
}, },
}, },
@@ -89,12 +87,12 @@ func TestChatCompletionsWithHeaders(t *testing.T) {
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
defer teardown() defer teardown()
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 5, MaxTokens: 5,
Model: GPT3Dot5Turbo, Model: openai.GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{ Messages: []openai.ChatCompletionMessage{
{ {
Role: ChatMessageRoleUser, Role: openai.ChatMessageRoleUser,
Content: "Hello!", Content: "Hello!",
}, },
}, },
@@ -113,12 +111,12 @@ func TestChatCompletionsWithRateLimitHeaders(t *testing.T) {
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
defer teardown() defer teardown()
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 5, MaxTokens: 5,
Model: GPT3Dot5Turbo, Model: openai.GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{ Messages: []openai.ChatCompletionMessage{
{ {
Role: ChatMessageRoleUser, Role: openai.ChatMessageRoleUser,
Content: "Hello!", Content: "Hello!",
}, },
}, },
@@ -150,16 +148,16 @@ func TestChatCompletionsFunctions(t *testing.T) {
t.Run("bytes", func(t *testing.T) { t.Run("bytes", func(t *testing.T) {
//nolint:lll //nolint:lll
msg := json.RawMessage(`{"properties":{"count":{"type":"integer","description":"total number of words in sentence"},"words":{"items":{"type":"string"},"type":"array","description":"list of words in sentence"}},"type":"object","required":["count","words"]}`) msg := json.RawMessage(`{"properties":{"count":{"type":"integer","description":"total number of words in sentence"},"words":{"items":{"type":"string"},"type":"array","description":"list of words in sentence"}},"type":"object","required":["count","words"]}`)
_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 5, MaxTokens: 5,
Model: GPT3Dot5Turbo0613, Model: openai.GPT3Dot5Turbo0613,
Messages: []ChatCompletionMessage{ Messages: []openai.ChatCompletionMessage{
{ {
Role: ChatMessageRoleUser, Role: openai.ChatMessageRoleUser,
Content: "Hello!", Content: "Hello!",
}, },
}, },
Functions: []FunctionDefinition{{ Functions: []openai.FunctionDefinition{{
Name: "test", Name: "test",
Parameters: &msg, Parameters: &msg,
}}, }},
@@ -175,16 +173,16 @@ func TestChatCompletionsFunctions(t *testing.T) {
Count: 2, Count: 2,
Words: []string{"hello", "world"}, Words: []string{"hello", "world"},
} }
_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 5, MaxTokens: 5,
Model: GPT3Dot5Turbo0613, Model: openai.GPT3Dot5Turbo0613,
Messages: []ChatCompletionMessage{ Messages: []openai.ChatCompletionMessage{
{ {
Role: ChatMessageRoleUser, Role: openai.ChatMessageRoleUser,
Content: "Hello!", Content: "Hello!",
}, },
}, },
Functions: []FunctionDefinition{{ Functions: []openai.FunctionDefinition{{
Name: "test", Name: "test",
Parameters: &msg, Parameters: &msg,
}}, }},
@@ -192,16 +190,16 @@ func TestChatCompletionsFunctions(t *testing.T) {
checks.NoError(t, err, "CreateChatCompletion with functions error") checks.NoError(t, err, "CreateChatCompletion with functions error")
}) })
t.Run("JSONSchemaDefinition", func(t *testing.T) { t.Run("JSONSchemaDefinition", func(t *testing.T) {
_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 5, MaxTokens: 5,
Model: GPT3Dot5Turbo0613, Model: openai.GPT3Dot5Turbo0613,
Messages: []ChatCompletionMessage{ Messages: []openai.ChatCompletionMessage{
{ {
Role: ChatMessageRoleUser, Role: openai.ChatMessageRoleUser,
Content: "Hello!", Content: "Hello!",
}, },
}, },
Functions: []FunctionDefinition{{ Functions: []openai.FunctionDefinition{{
Name: "test", Name: "test",
Parameters: &jsonschema.Definition{ Parameters: &jsonschema.Definition{
Type: jsonschema.Object, Type: jsonschema.Object,
@@ -229,16 +227,16 @@ func TestChatCompletionsFunctions(t *testing.T) {
}) })
t.Run("JSONSchemaDefinitionWithFunctionDefine", func(t *testing.T) { t.Run("JSONSchemaDefinitionWithFunctionDefine", func(t *testing.T) {
// this is a compatibility check // this is a compatibility check
_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 5, MaxTokens: 5,
Model: GPT3Dot5Turbo0613, Model: openai.GPT3Dot5Turbo0613,
Messages: []ChatCompletionMessage{ Messages: []openai.ChatCompletionMessage{
{ {
Role: ChatMessageRoleUser, Role: openai.ChatMessageRoleUser,
Content: "Hello!", Content: "Hello!",
}, },
}, },
Functions: []FunctionDefine{{ Functions: []openai.FunctionDefine{{
Name: "test", Name: "test",
Parameters: &jsonschema.Definition{ Parameters: &jsonschema.Definition{
Type: jsonschema.Object, Type: jsonschema.Object,
@@ -271,12 +269,12 @@ func TestAzureChatCompletions(t *testing.T) {
defer teardown() defer teardown()
server.RegisterHandler("/openai/deployments/*", handleChatCompletionEndpoint) server.RegisterHandler("/openai/deployments/*", handleChatCompletionEndpoint)
_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 5, MaxTokens: 5,
Model: GPT3Dot5Turbo, Model: openai.GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{ Messages: []openai.ChatCompletionMessage{
{ {
Role: ChatMessageRoleUser, Role: openai.ChatMessageRoleUser,
Content: "Hello!", Content: "Hello!",
}, },
}, },
@@ -293,12 +291,12 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" { if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
} }
var completionReq ChatCompletionRequest var completionReq openai.ChatCompletionRequest
if completionReq, err = getChatCompletionBody(r); err != nil { if completionReq, err = getChatCompletionBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError) http.Error(w, "could not read request", http.StatusInternalServerError)
return return
} }
res := ChatCompletionResponse{ res := openai.ChatCompletionResponse{
ID: strconv.Itoa(int(time.Now().Unix())), ID: strconv.Itoa(int(time.Now().Unix())),
Object: "test-object", Object: "test-object",
Created: time.Now().Unix(), Created: time.Now().Unix(),
@@ -323,11 +321,11 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
return return
} }
res.Choices = append(res.Choices, ChatCompletionChoice{ res.Choices = append(res.Choices, openai.ChatCompletionChoice{
Message: ChatCompletionMessage{ Message: openai.ChatCompletionMessage{
Role: ChatMessageRoleFunction, Role: openai.ChatMessageRoleFunction,
// this is valid json so it should be fine // this is valid json so it should be fine
FunctionCall: &FunctionCall{ FunctionCall: &openai.FunctionCall{
Name: completionReq.Functions[0].Name, Name: completionReq.Functions[0].Name,
Arguments: string(fcb), Arguments: string(fcb),
}, },
@@ -339,9 +337,9 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
// generate a random string of length completionReq.Length // generate a random string of length completionReq.Length
completionStr := strings.Repeat("a", completionReq.MaxTokens) completionStr := strings.Repeat("a", completionReq.MaxTokens)
res.Choices = append(res.Choices, ChatCompletionChoice{ res.Choices = append(res.Choices, openai.ChatCompletionChoice{
Message: ChatCompletionMessage{ Message: openai.ChatCompletionMessage{
Role: ChatMessageRoleAssistant, Role: openai.ChatMessageRoleAssistant,
Content: completionStr, Content: completionStr,
}, },
Index: i, Index: i,
@@ -349,7 +347,7 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
} }
inputTokens := numTokens(completionReq.Messages[0].Content) * n inputTokens := numTokens(completionReq.Messages[0].Content) * n
completionTokens := completionReq.MaxTokens * n completionTokens := completionReq.MaxTokens * n
res.Usage = Usage{ res.Usage = openai.Usage{
PromptTokens: inputTokens, PromptTokens: inputTokens,
CompletionTokens: completionTokens, CompletionTokens: completionTokens,
TotalTokens: inputTokens + completionTokens, TotalTokens: inputTokens + completionTokens,
@@ -368,23 +366,23 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
} }
// getChatCompletionBody Returns the body of the request to create a completion. // getChatCompletionBody Returns the body of the request to create a completion.
func getChatCompletionBody(r *http.Request) (ChatCompletionRequest, error) { func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) {
completion := ChatCompletionRequest{} completion := openai.ChatCompletionRequest{}
// read the request body // read the request body
reqBody, err := io.ReadAll(r.Body) reqBody, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
return ChatCompletionRequest{}, err return openai.ChatCompletionRequest{}, err
} }
err = json.Unmarshal(reqBody, &completion) err = json.Unmarshal(reqBody, &completion)
if err != nil { if err != nil {
return ChatCompletionRequest{}, err return openai.ChatCompletionRequest{}, err
} }
return completion, nil return completion, nil
} }
func TestFinishReason(t *testing.T) { func TestFinishReason(t *testing.T) {
c := &ChatCompletionChoice{ c := &openai.ChatCompletionChoice{
FinishReason: FinishReasonNull, FinishReason: openai.FinishReasonNull,
} }
resBytes, _ := json.Marshal(c) resBytes, _ := json.Marshal(c)
if !strings.Contains(string(resBytes), `"finish_reason":null`) { if !strings.Contains(string(resBytes), `"finish_reason":null`) {
@@ -398,11 +396,11 @@ func TestFinishReason(t *testing.T) {
t.Error("null should not be quoted") t.Error("null should not be quoted")
} }
otherReasons := []FinishReason{ otherReasons := []openai.FinishReason{
FinishReasonStop, openai.FinishReasonStop,
FinishReasonLength, openai.FinishReasonLength,
FinishReasonFunctionCall, openai.FinishReasonFunctionCall,
FinishReasonContentFilter, openai.FinishReasonContentFilter,
} }
for _, r := range otherReasons { for _, r := range otherReasons {
c.FinishReason = r c.FinishReason = r

View File

@@ -1,9 +1,6 @@
package openai_test package openai_test
import ( import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
@@ -14,33 +11,36 @@ import (
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
) )
func TestCompletionsWrongModel(t *testing.T) { func TestCompletionsWrongModel(t *testing.T) {
config := DefaultConfig("whatever") config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1" config.BaseURL = "http://localhost/v1"
client := NewClientWithConfig(config) client := openai.NewClientWithConfig(config)
_, err := client.CreateCompletion( _, err := client.CreateCompletion(
context.Background(), context.Background(),
CompletionRequest{ openai.CompletionRequest{
MaxTokens: 5, MaxTokens: 5,
Model: GPT3Dot5Turbo, Model: openai.GPT3Dot5Turbo,
}, },
) )
if !errors.Is(err, ErrCompletionUnsupportedModel) { if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err) t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err)
} }
} }
func TestCompletionWithStream(t *testing.T) { func TestCompletionWithStream(t *testing.T) {
config := DefaultConfig("whatever") config := openai.DefaultConfig("whatever")
client := NewClientWithConfig(config) client := openai.NewClientWithConfig(config)
ctx := context.Background() ctx := context.Background()
req := CompletionRequest{Stream: true} req := openai.CompletionRequest{Stream: true}
_, err := client.CreateCompletion(ctx, req) _, err := client.CreateCompletion(ctx, req)
if !errors.Is(err, ErrCompletionStreamNotSupported) { if !errors.Is(err, openai.ErrCompletionStreamNotSupported) {
t.Fatalf("CreateCompletion didn't return ErrCompletionStreamNotSupported") t.Fatalf("CreateCompletion didn't return ErrCompletionStreamNotSupported")
} }
} }
@@ -50,7 +50,7 @@ func TestCompletions(t *testing.T) {
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
defer teardown() defer teardown()
server.RegisterHandler("/v1/completions", handleCompletionEndpoint) server.RegisterHandler("/v1/completions", handleCompletionEndpoint)
req := CompletionRequest{ req := openai.CompletionRequest{
MaxTokens: 5, MaxTokens: 5,
Model: "ada", Model: "ada",
Prompt: "Lorem ipsum", Prompt: "Lorem ipsum",
@@ -68,12 +68,12 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" { if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
} }
var completionReq CompletionRequest var completionReq openai.CompletionRequest
if completionReq, err = getCompletionBody(r); err != nil { if completionReq, err = getCompletionBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError) http.Error(w, "could not read request", http.StatusInternalServerError)
return return
} }
res := CompletionResponse{ res := openai.CompletionResponse{
ID: strconv.Itoa(int(time.Now().Unix())), ID: strconv.Itoa(int(time.Now().Unix())),
Object: "test-object", Object: "test-object",
Created: time.Now().Unix(), Created: time.Now().Unix(),
@@ -93,14 +93,14 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
if completionReq.Echo { if completionReq.Echo {
completionStr = completionReq.Prompt.(string) + completionStr completionStr = completionReq.Prompt.(string) + completionStr
} }
res.Choices = append(res.Choices, CompletionChoice{ res.Choices = append(res.Choices, openai.CompletionChoice{
Text: completionStr, Text: completionStr,
Index: i, Index: i,
}) })
} }
inputTokens := numTokens(completionReq.Prompt.(string)) * n inputTokens := numTokens(completionReq.Prompt.(string)) * n
completionTokens := completionReq.MaxTokens * n completionTokens := completionReq.MaxTokens * n
res.Usage = Usage{ res.Usage = openai.Usage{
PromptTokens: inputTokens, PromptTokens: inputTokens,
CompletionTokens: completionTokens, CompletionTokens: completionTokens,
TotalTokens: inputTokens + completionTokens, TotalTokens: inputTokens + completionTokens,
@@ -110,16 +110,16 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
} }
// getCompletionBody Returns the body of the request to create a completion. // getCompletionBody Returns the body of the request to create a completion.
func getCompletionBody(r *http.Request) (CompletionRequest, error) { func getCompletionBody(r *http.Request) (openai.CompletionRequest, error) {
completion := CompletionRequest{} completion := openai.CompletionRequest{}
// read the request body // read the request body
reqBody, err := io.ReadAll(r.Body) reqBody, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
return CompletionRequest{}, err return openai.CompletionRequest{}, err
} }
err = json.Unmarshal(reqBody, &completion) err = json.Unmarshal(reqBody, &completion)
if err != nil { if err != nil {
return CompletionRequest{}, err return openai.CompletionRequest{}, err
} }
return completion, nil return completion, nil
} }

View File

@@ -3,7 +3,7 @@ package openai_test
import ( import (
"testing" "testing"
. "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
) )
func TestGetAzureDeploymentByModel(t *testing.T) { func TestGetAzureDeploymentByModel(t *testing.T) {
@@ -49,7 +49,7 @@ func TestGetAzureDeploymentByModel(t *testing.T) {
for _, c := range cases { for _, c := range cases {
t.Run(c.Model, func(t *testing.T) { t.Run(c.Model, func(t *testing.T) {
conf := DefaultAzureConfig("", "https://test.openai.azure.com/") conf := openai.DefaultAzureConfig("", "https://test.openai.azure.com/")
if c.AzureModelMapperFunc != nil { if c.AzureModelMapperFunc != nil {
conf.AzureModelMapperFunc = c.AzureModelMapperFunc conf.AzureModelMapperFunc = c.AzureModelMapperFunc
} }

View File

@@ -1,9 +1,6 @@
package openai_test package openai_test
import ( import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
@@ -11,6 +8,9 @@ import (
"net/http" "net/http"
"testing" "testing"
"time" "time"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
) )
// TestEdits Tests the edits endpoint of the API using the mocked server. // TestEdits Tests the edits endpoint of the API using the mocked server.
@@ -20,7 +20,7 @@ func TestEdits(t *testing.T) {
server.RegisterHandler("/v1/edits", handleEditEndpoint) server.RegisterHandler("/v1/edits", handleEditEndpoint)
// create an edit request // create an edit request
model := "ada" model := "ada"
editReq := EditsRequest{ editReq := openai.EditsRequest{
Model: &model, Model: &model,
Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " + Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " +
"sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" + "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" +
@@ -45,14 +45,14 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" { if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
} }
var editReq EditsRequest var editReq openai.EditsRequest
editReq, err = getEditBody(r) editReq, err = getEditBody(r)
if err != nil { if err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError) http.Error(w, "could not read request", http.StatusInternalServerError)
return return
} }
// create a response // create a response
res := EditsResponse{ res := openai.EditsResponse{
Object: "test-object", Object: "test-object",
Created: time.Now().Unix(), Created: time.Now().Unix(),
} }
@@ -62,12 +62,12 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) {
completionTokens := int(float32(len(editString))/4) * editReq.N completionTokens := int(float32(len(editString))/4) * editReq.N
for i := 0; i < editReq.N; i++ { for i := 0; i < editReq.N; i++ {
// instruction will be hidden and only seen by OpenAI // instruction will be hidden and only seen by OpenAI
res.Choices = append(res.Choices, EditsChoice{ res.Choices = append(res.Choices, openai.EditsChoice{
Text: editReq.Input + editString, Text: editReq.Input + editString,
Index: i, Index: i,
}) })
} }
res.Usage = Usage{ res.Usage = openai.Usage{
PromptTokens: inputTokens, PromptTokens: inputTokens,
CompletionTokens: completionTokens, CompletionTokens: completionTokens,
TotalTokens: inputTokens + completionTokens, TotalTokens: inputTokens + completionTokens,
@@ -77,16 +77,16 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) {
} }
// getEditBody Returns the body of the request to create an edit. // getEditBody Returns the body of the request to create an edit.
func getEditBody(r *http.Request) (EditsRequest, error) { func getEditBody(r *http.Request) (openai.EditsRequest, error) {
edit := EditsRequest{} edit := openai.EditsRequest{}
// read the request body // read the request body
reqBody, err := io.ReadAll(r.Body) reqBody, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
return EditsRequest{}, err return openai.EditsRequest{}, err
} }
err = json.Unmarshal(reqBody, &edit) err = json.Unmarshal(reqBody, &edit)
if err != nil { if err != nil {
return EditsRequest{}, err return openai.EditsRequest{}, err
} }
return edit, nil return edit, nil
} }

View File

@@ -11,32 +11,32 @@ import (
"reflect" "reflect"
"testing" "testing"
. "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/internal/test/checks"
) )
func TestEmbedding(t *testing.T) { func TestEmbedding(t *testing.T) {
embeddedModels := []EmbeddingModel{ embeddedModels := []openai.EmbeddingModel{
AdaSimilarity, openai.AdaSimilarity,
BabbageSimilarity, openai.BabbageSimilarity,
CurieSimilarity, openai.CurieSimilarity,
DavinciSimilarity, openai.DavinciSimilarity,
AdaSearchDocument, openai.AdaSearchDocument,
AdaSearchQuery, openai.AdaSearchQuery,
BabbageSearchDocument, openai.BabbageSearchDocument,
BabbageSearchQuery, openai.BabbageSearchQuery,
CurieSearchDocument, openai.CurieSearchDocument,
CurieSearchQuery, openai.CurieSearchQuery,
DavinciSearchDocument, openai.DavinciSearchDocument,
DavinciSearchQuery, openai.DavinciSearchQuery,
AdaCodeSearchCode, openai.AdaCodeSearchCode,
AdaCodeSearchText, openai.AdaCodeSearchText,
BabbageCodeSearchCode, openai.BabbageCodeSearchCode,
BabbageCodeSearchText, openai.BabbageCodeSearchText,
} }
for _, model := range embeddedModels { for _, model := range embeddedModels {
// test embedding request with strings (simple embedding request) // test embedding request with strings (simple embedding request)
embeddingReq := EmbeddingRequest{ embeddingReq := openai.EmbeddingRequest{
Input: []string{ Input: []string{
"The food was delicious and the waiter", "The food was delicious and the waiter",
"Other examples of embedding request", "Other examples of embedding request",
@@ -52,7 +52,7 @@ func TestEmbedding(t *testing.T) {
} }
// test embedding request with strings // test embedding request with strings
embeddingReqStrings := EmbeddingRequestStrings{ embeddingReqStrings := openai.EmbeddingRequestStrings{
Input: []string{ Input: []string{
"The food was delicious and the waiter", "The food was delicious and the waiter",
"Other examples of embedding request", "Other examples of embedding request",
@@ -66,7 +66,7 @@ func TestEmbedding(t *testing.T) {
} }
// test embedding request with tokens // test embedding request with tokens
embeddingReqTokens := EmbeddingRequestTokens{ embeddingReqTokens := openai.EmbeddingRequestTokens{
Input: [][]int{ Input: [][]int{
{464, 2057, 373, 12625, 290, 262, 46612}, {464, 2057, 373, 12625, 290, 262, 46612},
{6395, 6096, 286, 11525, 12083, 2581}, {6395, 6096, 286, 11525, 12083, 2581},
@@ -82,17 +82,17 @@ func TestEmbedding(t *testing.T) {
} }
func TestEmbeddingModel(t *testing.T) { func TestEmbeddingModel(t *testing.T) {
var em EmbeddingModel var em openai.EmbeddingModel
err := em.UnmarshalText([]byte("text-similarity-ada-001")) err := em.UnmarshalText([]byte("text-similarity-ada-001"))
checks.NoError(t, err, "Could not marshal embedding model") checks.NoError(t, err, "Could not marshal embedding model")
if em != AdaSimilarity { if em != openai.AdaSimilarity {
t.Errorf("Model is not equal to AdaSimilarity") t.Errorf("Model is not equal to AdaSimilarity")
} }
err = em.UnmarshalText([]byte("some-non-existent-model")) err = em.UnmarshalText([]byte("some-non-existent-model"))
checks.NoError(t, err, "Could not marshal embedding model") checks.NoError(t, err, "Could not marshal embedding model")
if em != Unknown { if em != openai.Unknown {
t.Errorf("Model is not equal to Unknown") t.Errorf("Model is not equal to Unknown")
} }
} }
@@ -101,12 +101,12 @@ func TestEmbeddingEndpoint(t *testing.T) {
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
defer teardown() defer teardown()
sampleEmbeddings := []Embedding{ sampleEmbeddings := []openai.Embedding{
{Embedding: []float32{1.23, 4.56, 7.89}}, {Embedding: []float32{1.23, 4.56, 7.89}},
{Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}},
} }
sampleBase64Embeddings := []Base64Embedding{ sampleBase64Embeddings := []openai.Base64Embedding{
{Embedding: "pHCdP4XrkUDhevxA"}, {Embedding: "pHCdP4XrkUDhevxA"},
{Embedding: "/1jku0G/rLvA/EI8"}, {Embedding: "/1jku0G/rLvA/EI8"},
} }
@@ -115,8 +115,8 @@ func TestEmbeddingEndpoint(t *testing.T) {
"/v1/embeddings", "/v1/embeddings",
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
var req struct { var req struct {
EncodingFormat EmbeddingEncodingFormat `json:"encoding_format"` EncodingFormat openai.EmbeddingEncodingFormat `json:"encoding_format"`
User string `json:"user"` User string `json:"user"`
} }
_ = json.NewDecoder(r.Body).Decode(&req) _ = json.NewDecoder(r.Body).Decode(&req)
@@ -125,16 +125,16 @@ func TestEmbeddingEndpoint(t *testing.T) {
case req.User == "invalid": case req.User == "invalid":
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
return return
case req.EncodingFormat == EmbeddingEncodingFormatBase64: case req.EncodingFormat == openai.EmbeddingEncodingFormatBase64:
resBytes, _ = json.Marshal(EmbeddingResponseBase64{Data: sampleBase64Embeddings}) resBytes, _ = json.Marshal(openai.EmbeddingResponseBase64{Data: sampleBase64Embeddings})
default: default:
resBytes, _ = json.Marshal(EmbeddingResponse{Data: sampleEmbeddings}) resBytes, _ = json.Marshal(openai.EmbeddingResponse{Data: sampleEmbeddings})
} }
fmt.Fprintln(w, string(resBytes)) fmt.Fprintln(w, string(resBytes))
}, },
) )
// test create embeddings with strings (simple embedding request) // test create embeddings with strings (simple embedding request)
res, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{}) res, err := client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{})
checks.NoError(t, err, "CreateEmbeddings error") checks.NoError(t, err, "CreateEmbeddings error")
if !reflect.DeepEqual(res.Data, sampleEmbeddings) { if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
@@ -143,8 +143,8 @@ func TestEmbeddingEndpoint(t *testing.T) {
// test create embeddings with strings (simple embedding request) // test create embeddings with strings (simple embedding request)
res, err = client.CreateEmbeddings( res, err = client.CreateEmbeddings(
context.Background(), context.Background(),
EmbeddingRequest{ openai.EmbeddingRequest{
EncodingFormat: EmbeddingEncodingFormatBase64, EncodingFormat: openai.EmbeddingEncodingFormatBase64,
}, },
) )
checks.NoError(t, err, "CreateEmbeddings error") checks.NoError(t, err, "CreateEmbeddings error")
@@ -153,23 +153,23 @@ func TestEmbeddingEndpoint(t *testing.T) {
} }
// test create embeddings with strings // test create embeddings with strings
res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{}) res, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequestStrings{})
checks.NoError(t, err, "CreateEmbeddings strings error") checks.NoError(t, err, "CreateEmbeddings strings error")
if !reflect.DeepEqual(res.Data, sampleEmbeddings) { if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
} }
// test create embeddings with tokens // test create embeddings with tokens
res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{}) res, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequestTokens{})
checks.NoError(t, err, "CreateEmbeddings tokens error") checks.NoError(t, err, "CreateEmbeddings tokens error")
if !reflect.DeepEqual(res.Data, sampleEmbeddings) { if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
} }
// test failed sendRequest // test failed sendRequest
_, err = client.CreateEmbeddings(context.Background(), EmbeddingRequest{ _, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{
User: "invalid", User: "invalid",
EncodingFormat: EmbeddingEncodingFormatBase64, EncodingFormat: openai.EmbeddingEncodingFormatBase64,
}) })
checks.HasError(t, err, "CreateEmbeddings error") checks.HasError(t, err, "CreateEmbeddings error")
} }
@@ -177,26 +177,26 @@ func TestEmbeddingEndpoint(t *testing.T) {
func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) {
type fields struct { type fields struct {
Object string Object string
Data []Base64Embedding Data []openai.Base64Embedding
Model EmbeddingModel Model openai.EmbeddingModel
Usage Usage Usage openai.Usage
} }
tests := []struct { tests := []struct {
name string name string
fields fields fields fields
want EmbeddingResponse want openai.EmbeddingResponse
wantErr bool wantErr bool
}{ }{
{ {
name: "test embedding response base64 to embedding response", name: "test embedding response base64 to embedding response",
fields: fields{ fields: fields{
Data: []Base64Embedding{ Data: []openai.Base64Embedding{
{Embedding: "pHCdP4XrkUDhevxA"}, {Embedding: "pHCdP4XrkUDhevxA"},
{Embedding: "/1jku0G/rLvA/EI8"}, {Embedding: "/1jku0G/rLvA/EI8"},
}, },
}, },
want: EmbeddingResponse{ want: openai.EmbeddingResponse{
Data: []Embedding{ Data: []openai.Embedding{
{Embedding: []float32{1.23, 4.56, 7.89}}, {Embedding: []float32{1.23, 4.56, 7.89}},
{Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}},
}, },
@@ -206,19 +206,19 @@ func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) {
{ {
name: "Invalid embedding", name: "Invalid embedding",
fields: fields{ fields: fields{
Data: []Base64Embedding{ Data: []openai.Base64Embedding{
{ {
Embedding: "----", Embedding: "----",
}, },
}, },
}, },
want: EmbeddingResponse{}, want: openai.EmbeddingResponse{},
wantErr: true, wantErr: true,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
r := &EmbeddingResponseBase64{ r := &openai.EmbeddingResponseBase64{
Object: tt.fields.Object, Object: tt.fields.Object,
Data: tt.fields.Data, Data: tt.fields.Data,
Model: tt.fields.Model, Model: tt.fields.Model,
@@ -237,8 +237,8 @@ func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) {
} }
func TestDotProduct(t *testing.T) { func TestDotProduct(t *testing.T) {
v1 := &Embedding{Embedding: []float32{1, 2, 3}} v1 := &openai.Embedding{Embedding: []float32{1, 2, 3}}
v2 := &Embedding{Embedding: []float32{2, 4, 6}} v2 := &openai.Embedding{Embedding: []float32{2, 4, 6}}
expected := float32(28.0) expected := float32(28.0)
result, err := v1.DotProduct(v2) result, err := v1.DotProduct(v2)
@@ -250,8 +250,8 @@ func TestDotProduct(t *testing.T) {
t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result) t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result)
} }
v1 = &Embedding{Embedding: []float32{1, 0, 0}} v1 = &openai.Embedding{Embedding: []float32{1, 0, 0}}
v2 = &Embedding{Embedding: []float32{0, 1, 0}} v2 = &openai.Embedding{Embedding: []float32{0, 1, 0}}
expected = float32(0.0) expected = float32(0.0)
result, err = v1.DotProduct(v2) result, err = v1.DotProduct(v2)
@@ -264,10 +264,10 @@ func TestDotProduct(t *testing.T) {
} }
// Test for VectorLengthMismatchError // Test for VectorLengthMismatchError
v1 = &Embedding{Embedding: []float32{1, 0, 0}} v1 = &openai.Embedding{Embedding: []float32{1, 0, 0}}
v2 = &Embedding{Embedding: []float32{0, 1}} v2 = &openai.Embedding{Embedding: []float32{0, 1}}
_, err = v1.DotProduct(v2) _, err = v1.DotProduct(v2)
if !errors.Is(err, ErrVectorLengthMismatch) { if !errors.Is(err, openai.ErrVectorLengthMismatch) {
t.Errorf("Expected Vector Length Mismatch Error, but got: %v", err) t.Errorf("Expected Vector Length Mismatch Error, but got: %v", err)
} }
} }

View File

@@ -7,7 +7,7 @@ import (
"net/http" "net/http"
"testing" "testing"
. "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/internal/test/checks"
) )
@@ -15,8 +15,8 @@ import (
func TestGetEngine(t *testing.T) { func TestGetEngine(t *testing.T) {
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
defer teardown() defer teardown()
server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, r *http.Request) { server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, _ *http.Request) {
resBytes, _ := json.Marshal(Engine{}) resBytes, _ := json.Marshal(openai.Engine{})
fmt.Fprintln(w, string(resBytes)) fmt.Fprintln(w, string(resBytes))
}) })
_, err := client.GetEngine(context.Background(), "text-davinci-003") _, err := client.GetEngine(context.Background(), "text-davinci-003")
@@ -27,8 +27,8 @@ func TestGetEngine(t *testing.T) {
func TestListEngines(t *testing.T) { func TestListEngines(t *testing.T) {
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
defer teardown() defer teardown()
server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) { server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, _ *http.Request) {
resBytes, _ := json.Marshal(EnginesList{}) resBytes, _ := json.Marshal(openai.EnginesList{})
fmt.Fprintln(w, string(resBytes)) fmt.Fprintln(w, string(resBytes))
}) })
_, err := client.ListEngines(context.Background()) _, err := client.ListEngines(context.Background())
@@ -38,7 +38,7 @@ func TestListEngines(t *testing.T) {
func TestListEnginesReturnError(t *testing.T) { func TestListEnginesReturnError(t *testing.T) {
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
defer teardown() defer teardown()
server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) { server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusTeapot) w.WriteHeader(http.StatusTeapot)
}) })

View File

@@ -6,7 +6,7 @@ import (
"reflect" "reflect"
"testing" "testing"
. "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
) )
func TestAPIErrorUnmarshalJSON(t *testing.T) { func TestAPIErrorUnmarshalJSON(t *testing.T) {
@@ -14,7 +14,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
name string name string
response string response string
hasError bool hasError bool
checkFunc func(t *testing.T, apiErr APIError) checkFunc func(t *testing.T, apiErr openai.APIError)
} }
testCases := []testCase{ testCases := []testCase{
// testcase for message field // testcase for message field
@@ -22,7 +22,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
name: "parse succeeds when the message is string", name: "parse succeeds when the message is string",
response: `{"message":"foo","type":"invalid_request_error","param":null,"code":null}`, response: `{"message":"foo","type":"invalid_request_error","param":null,"code":null}`,
hasError: false, hasError: false,
checkFunc: func(t *testing.T, apiErr APIError) { checkFunc: func(t *testing.T, apiErr openai.APIError) {
assertAPIErrorMessage(t, apiErr, "foo") assertAPIErrorMessage(t, apiErr, "foo")
}, },
}, },
@@ -30,7 +30,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
name: "parse succeeds when the message is array with single item", name: "parse succeeds when the message is array with single item",
response: `{"message":["foo"],"type":"invalid_request_error","param":null,"code":null}`, response: `{"message":["foo"],"type":"invalid_request_error","param":null,"code":null}`,
hasError: false, hasError: false,
checkFunc: func(t *testing.T, apiErr APIError) { checkFunc: func(t *testing.T, apiErr openai.APIError) {
assertAPIErrorMessage(t, apiErr, "foo") assertAPIErrorMessage(t, apiErr, "foo")
}, },
}, },
@@ -38,7 +38,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
name: "parse succeeds when the message is array with multiple items", name: "parse succeeds when the message is array with multiple items",
response: `{"message":["foo", "bar", "baz"],"type":"invalid_request_error","param":null,"code":null}`, response: `{"message":["foo", "bar", "baz"],"type":"invalid_request_error","param":null,"code":null}`,
hasError: false, hasError: false,
checkFunc: func(t *testing.T, apiErr APIError) { checkFunc: func(t *testing.T, apiErr openai.APIError) {
assertAPIErrorMessage(t, apiErr, "foo, bar, baz") assertAPIErrorMessage(t, apiErr, "foo, bar, baz")
}, },
}, },
@@ -46,7 +46,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
name: "parse succeeds when the message is empty array", name: "parse succeeds when the message is empty array",
response: `{"message":[],"type":"invalid_request_error","param":null,"code":null}`, response: `{"message":[],"type":"invalid_request_error","param":null,"code":null}`,
hasError: false, hasError: false,
checkFunc: func(t *testing.T, apiErr APIError) { checkFunc: func(t *testing.T, apiErr openai.APIError) {
assertAPIErrorMessage(t, apiErr, "") assertAPIErrorMessage(t, apiErr, "")
}, },
}, },
@@ -54,7 +54,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
name: "parse succeeds when the message is null", name: "parse succeeds when the message is null",
response: `{"message":null,"type":"invalid_request_error","param":null,"code":null}`, response: `{"message":null,"type":"invalid_request_error","param":null,"code":null}`,
hasError: false, hasError: false,
checkFunc: func(t *testing.T, apiErr APIError) { checkFunc: func(t *testing.T, apiErr openai.APIError) {
assertAPIErrorMessage(t, apiErr, "") assertAPIErrorMessage(t, apiErr, "")
}, },
}, },
@@ -89,23 +89,23 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
} }
}`, }`,
hasError: false, hasError: false,
checkFunc: func(t *testing.T, apiErr APIError) { checkFunc: func(t *testing.T, apiErr openai.APIError) {
assertAPIErrorInnerError(t, apiErr, &InnerError{ assertAPIErrorInnerError(t, apiErr, &openai.InnerError{
Code: "ResponsibleAIPolicyViolation", Code: "ResponsibleAIPolicyViolation",
ContentFilterResults: ContentFilterResults{ ContentFilterResults: openai.ContentFilterResults{
Hate: Hate{ Hate: openai.Hate{
Filtered: false, Filtered: false,
Severity: "safe", Severity: "safe",
}, },
SelfHarm: SelfHarm{ SelfHarm: openai.SelfHarm{
Filtered: false, Filtered: false,
Severity: "safe", Severity: "safe",
}, },
Sexual: Sexual{ Sexual: openai.Sexual{
Filtered: true, Filtered: true,
Severity: "medium", Severity: "medium",
}, },
Violence: Violence{ Violence: openai.Violence{
Filtered: false, Filtered: false,
Severity: "safe", Severity: "safe",
}, },
@@ -117,16 +117,16 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
name: "parse succeeds when the innerError is empty (Azure Openai)", name: "parse succeeds when the innerError is empty (Azure Openai)",
response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": {}}`, response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": {}}`,
hasError: false, hasError: false,
checkFunc: func(t *testing.T, apiErr APIError) { checkFunc: func(t *testing.T, apiErr openai.APIError) {
assertAPIErrorInnerError(t, apiErr, &InnerError{}) assertAPIErrorInnerError(t, apiErr, &openai.InnerError{})
}, },
}, },
{ {
name: "parse succeeds when the innerError is not InnerError struct (Azure Openai)", name: "parse succeeds when the innerError is not InnerError struct (Azure Openai)",
response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": "test"}`, response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": "test"}`,
hasError: true, hasError: true,
checkFunc: func(t *testing.T, apiErr APIError) { checkFunc: func(t *testing.T, apiErr openai.APIError) {
assertAPIErrorInnerError(t, apiErr, &InnerError{}) assertAPIErrorInnerError(t, apiErr, &openai.InnerError{})
}, },
}, },
{ {
@@ -159,7 +159,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
name: "parse succeeds when the code is int", name: "parse succeeds when the code is int",
response: `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, response: `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`,
hasError: false, hasError: false,
checkFunc: func(t *testing.T, apiErr APIError) { checkFunc: func(t *testing.T, apiErr openai.APIError) {
assertAPIErrorCode(t, apiErr, 418) assertAPIErrorCode(t, apiErr, 418)
}, },
}, },
@@ -167,7 +167,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
name: "parse succeeds when the code is string", name: "parse succeeds when the code is string",
response: `{"code":"teapot","message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, response: `{"code":"teapot","message":"I'm a teapot","param":"prompt","type":"teapot_error"}`,
hasError: false, hasError: false,
checkFunc: func(t *testing.T, apiErr APIError) { checkFunc: func(t *testing.T, apiErr openai.APIError) {
assertAPIErrorCode(t, apiErr, "teapot") assertAPIErrorCode(t, apiErr, "teapot")
}, },
}, },
@@ -175,7 +175,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
name: "parse succeeds when the code is not exists", name: "parse succeeds when the code is not exists",
response: `{"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, response: `{"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`,
hasError: false, hasError: false,
checkFunc: func(t *testing.T, apiErr APIError) { checkFunc: func(t *testing.T, apiErr openai.APIError) {
assertAPIErrorCode(t, apiErr, nil) assertAPIErrorCode(t, apiErr, nil)
}, },
}, },
@@ -196,7 +196,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
name: "parse failed when the response is invalid json", name: "parse failed when the response is invalid json",
response: `--- {"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, response: `--- {"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`,
hasError: true, hasError: true,
checkFunc: func(t *testing.T, apiErr APIError) { checkFunc: func(t *testing.T, apiErr openai.APIError) {
assertAPIErrorCode(t, apiErr, nil) assertAPIErrorCode(t, apiErr, nil)
assertAPIErrorMessage(t, apiErr, "") assertAPIErrorMessage(t, apiErr, "")
assertAPIErrorParam(t, apiErr, nil) assertAPIErrorParam(t, apiErr, nil)
@@ -206,7 +206,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
var apiErr APIError var apiErr openai.APIError
err := apiErr.UnmarshalJSON([]byte(tc.response)) err := apiErr.UnmarshalJSON([]byte(tc.response))
if (err != nil) != tc.hasError { if (err != nil) != tc.hasError {
t.Errorf("Unexpected error: %v", err) t.Errorf("Unexpected error: %v", err)
@@ -218,19 +218,19 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) {
} }
} }
func assertAPIErrorMessage(t *testing.T, apiErr APIError, expected string) { func assertAPIErrorMessage(t *testing.T, apiErr openai.APIError, expected string) {
if apiErr.Message != expected { if apiErr.Message != expected {
t.Errorf("Unexpected APIError message: %v; expected: %s", apiErr, expected) t.Errorf("Unexpected APIError message: %v; expected: %s", apiErr, expected)
} }
} }
func assertAPIErrorInnerError(t *testing.T, apiErr APIError, expected interface{}) { func assertAPIErrorInnerError(t *testing.T, apiErr openai.APIError, expected interface{}) {
if !reflect.DeepEqual(apiErr.InnerError, expected) { if !reflect.DeepEqual(apiErr.InnerError, expected) {
t.Errorf("Unexpected APIError InnerError: %v; expected: %v; ", apiErr, expected) t.Errorf("Unexpected APIError InnerError: %v; expected: %v; ", apiErr, expected)
} }
} }
func assertAPIErrorCode(t *testing.T, apiErr APIError, expected interface{}) { func assertAPIErrorCode(t *testing.T, apiErr openai.APIError, expected interface{}) {
switch v := apiErr.Code.(type) { switch v := apiErr.Code.(type) {
case int: case int:
if v != expected { if v != expected {
@@ -246,25 +246,25 @@ func assertAPIErrorCode(t *testing.T, apiErr APIError, expected interface{}) {
} }
} }
func assertAPIErrorParam(t *testing.T, apiErr APIError, expected *string) { func assertAPIErrorParam(t *testing.T, apiErr openai.APIError, expected *string) {
if apiErr.Param != expected { if apiErr.Param != expected {
t.Errorf("Unexpected APIError param: %v; expected: %s", apiErr, *expected) t.Errorf("Unexpected APIError param: %v; expected: %s", apiErr, *expected)
} }
} }
func assertAPIErrorType(t *testing.T, apiErr APIError, typ string) { func assertAPIErrorType(t *testing.T, apiErr openai.APIError, typ string) {
if apiErr.Type != typ { if apiErr.Type != typ {
t.Errorf("Unexpected API type: %v; expected: %s", apiErr, typ) t.Errorf("Unexpected API type: %v; expected: %s", apiErr, typ)
} }
} }
func TestRequestError(t *testing.T) { func TestRequestError(t *testing.T) {
var err error = &RequestError{ var err error = &openai.RequestError{
HTTPStatusCode: http.StatusTeapot, HTTPStatusCode: http.StatusTeapot,
Err: errors.New("i am a teapot"), Err: errors.New("i am a teapot"),
} }
var reqErr *RequestError var reqErr *openai.RequestError
if !errors.As(err, &reqErr) { if !errors.As(err, &reqErr) {
t.Fatalf("Error is not a RequestError: %+v", err) t.Fatalf("Error is not a RequestError: %+v", err)
} }

View File

@@ -28,7 +28,6 @@ func Example() {
}, },
}, },
) )
if err != nil { if err != nil {
fmt.Printf("ChatCompletion error: %v\n", err) fmt.Printf("ChatCompletion error: %v\n", err)
return return
@@ -319,7 +318,6 @@ func ExampleDefaultAzureConfig() {
}, },
}, },
) )
if err != nil { if err != nil {
fmt.Printf("ChatCompletion error: %v\n", err) fmt.Printf("ChatCompletion error: %v\n", err)
return return

View File

@@ -12,7 +12,7 @@ import (
"testing" "testing"
"time" "time"
. "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/internal/test/checks"
) )
@@ -20,7 +20,7 @@ func TestFileUpload(t *testing.T) {
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
defer teardown() defer teardown()
server.RegisterHandler("/v1/files", handleCreateFile) server.RegisterHandler("/v1/files", handleCreateFile)
req := FileRequest{ req := openai.FileRequest{
FileName: "test.go", FileName: "test.go",
FilePath: "client.go", FilePath: "client.go",
Purpose: "fine-tune", Purpose: "fine-tune",
@@ -57,7 +57,7 @@ func handleCreateFile(w http.ResponseWriter, r *http.Request) {
} }
defer file.Close() defer file.Close()
var fileReq = File{ fileReq := openai.File{
Bytes: int(header.Size), Bytes: int(header.Size),
ID: strconv.Itoa(int(time.Now().Unix())), ID: strconv.Itoa(int(time.Now().Unix())),
FileName: header.Filename, FileName: header.Filename,
@@ -82,7 +82,7 @@ func TestListFile(t *testing.T) {
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
defer teardown() defer teardown()
server.RegisterHandler("/v1/files", func(w http.ResponseWriter, r *http.Request) { server.RegisterHandler("/v1/files", func(w http.ResponseWriter, r *http.Request) {
resBytes, _ := json.Marshal(FilesList{}) resBytes, _ := json.Marshal(openai.FilesList{})
fmt.Fprintln(w, string(resBytes)) fmt.Fprintln(w, string(resBytes))
}) })
_, err := client.ListFiles(context.Background()) _, err := client.ListFiles(context.Background())
@@ -93,7 +93,7 @@ func TestGetFile(t *testing.T) {
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
defer teardown() defer teardown()
server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) { server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) {
resBytes, _ := json.Marshal(File{}) resBytes, _ := json.Marshal(openai.File{})
fmt.Fprintln(w, string(resBytes)) fmt.Fprintln(w, string(resBytes))
}) })
_, err := client.GetFile(context.Background(), "deadbeef") _, err := client.GetFile(context.Background(), "deadbeef")
@@ -148,7 +148,7 @@ func TestGetFileContentReturnError(t *testing.T) {
t.Fatal("Did not return error") t.Fatal("Did not return error")
} }
apiErr := &APIError{} apiErr := &openai.APIError{}
if !errors.As(err, &apiErr) { if !errors.As(err, &apiErr) {
t.Fatalf("Did not return APIError: %+v\n", apiErr) t.Fatalf("Did not return APIError: %+v\n", apiErr)
} }

View File

@@ -1,14 +1,14 @@
package openai //nolint:testpackage // testing private field package openai //nolint:testpackage // testing private field
import ( import (
utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test/checks"
"context" "context"
"fmt" "fmt"
"io" "io"
"os" "os"
"testing" "testing"
utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test/checks"
) )
func TestFileUploadWithFailingFormBuilder(t *testing.T) { func TestFileUploadWithFailingFormBuilder(t *testing.T) {

View File

@@ -115,6 +115,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r
// This API will be officially deprecated on January 4th, 2024. // This API will be officially deprecated on January 4th, 2024.
// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. // OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go.
func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
//nolint:goconst // Decreases readability
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"))
if err != nil { if err != nil {
return return

View File

@@ -1,14 +1,14 @@
package openai_test package openai_test
import ( import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"testing" "testing"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
) )
const testFineTuneID = "fine-tune-id" const testFineTuneID = "fine-tune-id"
@@ -22,9 +22,9 @@ func TestFineTunes(t *testing.T) {
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
var resBytes []byte var resBytes []byte
if r.Method == http.MethodGet { if r.Method == http.MethodGet {
resBytes, _ = json.Marshal(FineTuneList{}) resBytes, _ = json.Marshal(openai.FineTuneList{})
} else { } else {
resBytes, _ = json.Marshal(FineTune{}) resBytes, _ = json.Marshal(openai.FineTune{})
} }
fmt.Fprintln(w, string(resBytes)) fmt.Fprintln(w, string(resBytes))
}, },
@@ -32,8 +32,8 @@ func TestFineTunes(t *testing.T) {
server.RegisterHandler( server.RegisterHandler(
"/v1/fine-tunes/"+testFineTuneID+"/cancel", "/v1/fine-tunes/"+testFineTuneID+"/cancel",
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, _ *http.Request) {
resBytes, _ := json.Marshal(FineTune{}) resBytes, _ := json.Marshal(openai.FineTune{})
fmt.Fprintln(w, string(resBytes)) fmt.Fprintln(w, string(resBytes))
}, },
) )
@@ -43,9 +43,9 @@ func TestFineTunes(t *testing.T) {
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
var resBytes []byte var resBytes []byte
if r.Method == http.MethodDelete { if r.Method == http.MethodDelete {
resBytes, _ = json.Marshal(FineTuneDeleteResponse{}) resBytes, _ = json.Marshal(openai.FineTuneDeleteResponse{})
} else { } else {
resBytes, _ = json.Marshal(FineTune{}) resBytes, _ = json.Marshal(openai.FineTune{})
} }
fmt.Fprintln(w, string(resBytes)) fmt.Fprintln(w, string(resBytes))
}, },
@@ -53,8 +53,8 @@ func TestFineTunes(t *testing.T) {
server.RegisterHandler( server.RegisterHandler(
"/v1/fine-tunes/"+testFineTuneID+"/events", "/v1/fine-tunes/"+testFineTuneID+"/events",
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, _ *http.Request) {
resBytes, _ := json.Marshal(FineTuneEventList{}) resBytes, _ := json.Marshal(openai.FineTuneEventList{})
fmt.Fprintln(w, string(resBytes)) fmt.Fprintln(w, string(resBytes))
}, },
) )
@@ -64,7 +64,7 @@ func TestFineTunes(t *testing.T) {
_, err := client.ListFineTunes(ctx) _, err := client.ListFineTunes(ctx)
checks.NoError(t, err, "ListFineTunes error") checks.NoError(t, err, "ListFineTunes error")
_, err = client.CreateFineTune(ctx, FineTuneRequest{}) _, err = client.CreateFineTune(ctx, openai.FineTuneRequest{})
checks.NoError(t, err, "CreateFineTune error") checks.NoError(t, err, "CreateFineTune error")
_, err = client.CancelFineTune(ctx, testFineTuneID) _, err = client.CancelFineTune(ctx, testFineTuneID)

View File

@@ -2,14 +2,13 @@ package openai_test
import ( import (
"context" "context"
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"testing" "testing"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
) )
const testFineTuninigJobID = "fine-tuning-job-id" const testFineTuninigJobID = "fine-tuning-job-id"
@@ -20,8 +19,8 @@ func TestFineTuningJob(t *testing.T) {
defer teardown() defer teardown()
server.RegisterHandler( server.RegisterHandler(
"/v1/fine_tuning/jobs", "/v1/fine_tuning/jobs",
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, _ *http.Request) {
resBytes, _ := json.Marshal(FineTuningJob{ resBytes, _ := json.Marshal(openai.FineTuningJob{
Object: "fine_tuning.job", Object: "fine_tuning.job",
ID: testFineTuninigJobID, ID: testFineTuninigJobID,
Model: "davinci-002", Model: "davinci-002",
@@ -33,7 +32,7 @@ func TestFineTuningJob(t *testing.T) {
Status: "succeeded", Status: "succeeded",
ValidationFile: "", ValidationFile: "",
TrainingFile: "file-abc123", TrainingFile: "file-abc123",
Hyperparameters: Hyperparameters{ Hyperparameters: openai.Hyperparameters{
Epochs: "auto", Epochs: "auto",
}, },
TrainedTokens: 5768, TrainedTokens: 5768,
@@ -44,32 +43,32 @@ func TestFineTuningJob(t *testing.T) {
server.RegisterHandler( server.RegisterHandler(
"/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel", "/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel",
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, _ *http.Request) {
resBytes, _ := json.Marshal(FineTuningJob{}) resBytes, _ := json.Marshal(openai.FineTuningJob{})
fmt.Fprintln(w, string(resBytes)) fmt.Fprintln(w, string(resBytes))
}, },
) )
server.RegisterHandler( server.RegisterHandler(
"/v1/fine_tuning/jobs/"+testFineTuninigJobID, "/v1/fine_tuning/jobs/"+testFineTuninigJobID,
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, _ *http.Request) {
var resBytes []byte var resBytes []byte
resBytes, _ = json.Marshal(FineTuningJob{}) resBytes, _ = json.Marshal(openai.FineTuningJob{})
fmt.Fprintln(w, string(resBytes)) fmt.Fprintln(w, string(resBytes))
}, },
) )
server.RegisterHandler( server.RegisterHandler(
"/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/events", "/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/events",
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, _ *http.Request) {
resBytes, _ := json.Marshal(FineTuningJobEventList{}) resBytes, _ := json.Marshal(openai.FineTuningJobEventList{})
fmt.Fprintln(w, string(resBytes)) fmt.Fprintln(w, string(resBytes))
}, },
) )
ctx := context.Background() ctx := context.Background()
_, err := client.CreateFineTuningJob(ctx, FineTuningJobRequest{}) _, err := client.CreateFineTuningJob(ctx, openai.FineTuningJobRequest{})
checks.NoError(t, err, "CreateFineTuningJob error") checks.NoError(t, err, "CreateFineTuningJob error")
_, err = client.CancelFineTuningJob(ctx, testFineTuninigJobID) _, err = client.CancelFineTuningJob(ctx, testFineTuninigJobID)
@@ -84,22 +83,22 @@ func TestFineTuningJob(t *testing.T) {
_, err = client.ListFineTuningJobEvents( _, err = client.ListFineTuningJobEvents(
ctx, ctx,
testFineTuninigJobID, testFineTuninigJobID,
ListFineTuningJobEventsWithAfter("last-event-id"), openai.ListFineTuningJobEventsWithAfter("last-event-id"),
) )
checks.NoError(t, err, "ListFineTuningJobEvents error") checks.NoError(t, err, "ListFineTuningJobEvents error")
_, err = client.ListFineTuningJobEvents( _, err = client.ListFineTuningJobEvents(
ctx, ctx,
testFineTuninigJobID, testFineTuninigJobID,
ListFineTuningJobEventsWithLimit(10), openai.ListFineTuningJobEventsWithLimit(10),
) )
checks.NoError(t, err, "ListFineTuningJobEvents error") checks.NoError(t, err, "ListFineTuningJobEvents error")
_, err = client.ListFineTuningJobEvents( _, err = client.ListFineTuningJobEvents(
ctx, ctx,
testFineTuninigJobID, testFineTuninigJobID,
ListFineTuningJobEventsWithAfter("last-event-id"), openai.ListFineTuningJobEventsWithAfter("last-event-id"),
ListFineTuningJobEventsWithLimit(10), openai.ListFineTuningJobEventsWithLimit(10),
) )
checks.NoError(t, err, "ListFineTuningJobEvents error") checks.NoError(t, err, "ListFineTuningJobEvents error")
} }

View File

@@ -1,9 +1,6 @@
package openai_test package openai_test
import ( import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
@@ -12,13 +9,16 @@ import (
"os" "os"
"testing" "testing"
"time" "time"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
) )
func TestImages(t *testing.T) { func TestImages(t *testing.T) {
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
defer teardown() defer teardown()
server.RegisterHandler("/v1/images/generations", handleImageEndpoint) server.RegisterHandler("/v1/images/generations", handleImageEndpoint)
_, err := client.CreateImage(context.Background(), ImageRequest{ _, err := client.CreateImage(context.Background(), openai.ImageRequest{
Prompt: "Lorem ipsum", Prompt: "Lorem ipsum",
}) })
checks.NoError(t, err, "CreateImage error") checks.NoError(t, err, "CreateImage error")
@@ -33,20 +33,20 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" { if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
} }
var imageReq ImageRequest var imageReq openai.ImageRequest
if imageReq, err = getImageBody(r); err != nil { if imageReq, err = getImageBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError) http.Error(w, "could not read request", http.StatusInternalServerError)
return return
} }
res := ImageResponse{ res := openai.ImageResponse{
Created: time.Now().Unix(), Created: time.Now().Unix(),
} }
for i := 0; i < imageReq.N; i++ { for i := 0; i < imageReq.N; i++ {
imageData := ImageResponseDataInner{} imageData := openai.ImageResponseDataInner{}
switch imageReq.ResponseFormat { switch imageReq.ResponseFormat {
case CreateImageResponseFormatURL, "": case openai.CreateImageResponseFormatURL, "":
imageData.URL = "https://example.com/image.png" imageData.URL = "https://example.com/image.png"
case CreateImageResponseFormatB64JSON: case openai.CreateImageResponseFormatB64JSON:
// This decodes to "{}" in base64. // This decodes to "{}" in base64.
imageData.B64JSON = "e30K" imageData.B64JSON = "e30K"
default: default:
@@ -60,16 +60,16 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
} }
// getImageBody Returns the body of the request to create a image. // getImageBody Returns the body of the request to create a image.
func getImageBody(r *http.Request) (ImageRequest, error) { func getImageBody(r *http.Request) (openai.ImageRequest, error) {
image := ImageRequest{} image := openai.ImageRequest{}
// read the request body // read the request body
reqBody, err := io.ReadAll(r.Body) reqBody, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
return ImageRequest{}, err return openai.ImageRequest{}, err
} }
err = json.Unmarshal(reqBody, &image) err = json.Unmarshal(reqBody, &image)
if err != nil { if err != nil {
return ImageRequest{}, err return openai.ImageRequest{}, err
} }
return image, nil return image, nil
} }
@@ -98,13 +98,13 @@ func TestImageEdit(t *testing.T) {
os.Remove("image.png") os.Remove("image.png")
}() }()
_, err = client.CreateEditImage(context.Background(), ImageEditRequest{ _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{
Image: origin, Image: origin,
Mask: mask, Mask: mask,
Prompt: "There is a turtle in the pool", Prompt: "There is a turtle in the pool",
N: 3, N: 3,
Size: CreateImageSize1024x1024, Size: openai.CreateImageSize1024x1024,
ResponseFormat: CreateImageResponseFormatURL, ResponseFormat: openai.CreateImageResponseFormatURL,
}) })
checks.NoError(t, err, "CreateImage error") checks.NoError(t, err, "CreateImage error")
} }
@@ -125,12 +125,12 @@ func TestImageEditWithoutMask(t *testing.T) {
os.Remove("image.png") os.Remove("image.png")
}() }()
_, err = client.CreateEditImage(context.Background(), ImageEditRequest{ _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{
Image: origin, Image: origin,
Prompt: "There is a turtle in the pool", Prompt: "There is a turtle in the pool",
N: 3, N: 3,
Size: CreateImageSize1024x1024, Size: openai.CreateImageSize1024x1024,
ResponseFormat: CreateImageResponseFormatURL, ResponseFormat: openai.CreateImageResponseFormatURL,
}) })
checks.NoError(t, err, "CreateImage error") checks.NoError(t, err, "CreateImage error")
} }
@@ -144,9 +144,9 @@ func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
} }
responses := ImageResponse{ responses := openai.ImageResponse{
Created: time.Now().Unix(), Created: time.Now().Unix(),
Data: []ImageResponseDataInner{ Data: []openai.ImageResponseDataInner{
{ {
URL: "test-url1", URL: "test-url1",
B64JSON: "", B64JSON: "",
@@ -182,11 +182,11 @@ func TestImageVariation(t *testing.T) {
os.Remove("image.png") os.Remove("image.png")
}() }()
_, err = client.CreateVariImage(context.Background(), ImageVariRequest{ _, err = client.CreateVariImage(context.Background(), openai.ImageVariRequest{
Image: origin, Image: origin,
N: 3, N: 3,
Size: CreateImageSize1024x1024, Size: openai.CreateImageSize1024x1024,
ResponseFormat: CreateImageResponseFormatURL, ResponseFormat: openai.CreateImageResponseFormatURL,
}) })
checks.NoError(t, err, "CreateImage error") checks.NoError(t, err, "CreateImage error")
} }
@@ -200,9 +200,9 @@ func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
} }
responses := ImageResponse{ responses := openai.ImageResponse{
Created: time.Now().Unix(), Created: time.Now().Unix(),
Data: []ImageResponseDataInner{ Data: []openai.ImageResponseDataInner{
{ {
URL: "test-url1", URL: "test-url1",
B64JSON: "", B64JSON: "",

View File

@@ -5,28 +5,28 @@ import (
"reflect" "reflect"
"testing" "testing"
. "github.com/sashabaranov/go-openai/jsonschema" "github.com/sashabaranov/go-openai/jsonschema"
) )
func TestDefinition_MarshalJSON(t *testing.T) { func TestDefinition_MarshalJSON(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
def Definition def jsonschema.Definition
want string want string
}{ }{
{ {
name: "Test with empty Definition", name: "Test with empty Definition",
def: Definition{}, def: jsonschema.Definition{},
want: `{"properties":{}}`, want: `{"properties":{}}`,
}, },
{ {
name: "Test with Definition properties set", name: "Test with Definition properties set",
def: Definition{ def: jsonschema.Definition{
Type: String, Type: jsonschema.String,
Description: "A string type", Description: "A string type",
Properties: map[string]Definition{ Properties: map[string]jsonschema.Definition{
"name": { "name": {
Type: String, Type: jsonschema.String,
}, },
}, },
}, },
@@ -43,17 +43,17 @@ func TestDefinition_MarshalJSON(t *testing.T) {
}, },
{ {
name: "Test with nested Definition properties", name: "Test with nested Definition properties",
def: Definition{ def: jsonschema.Definition{
Type: Object, Type: jsonschema.Object,
Properties: map[string]Definition{ Properties: map[string]jsonschema.Definition{
"user": { "user": {
Type: Object, Type: jsonschema.Object,
Properties: map[string]Definition{ Properties: map[string]jsonschema.Definition{
"name": { "name": {
Type: String, Type: jsonschema.String,
}, },
"age": { "age": {
Type: Integer, Type: jsonschema.Integer,
}, },
}, },
}, },
@@ -80,26 +80,26 @@ func TestDefinition_MarshalJSON(t *testing.T) {
}, },
{ {
name: "Test with complex nested Definition", name: "Test with complex nested Definition",
def: Definition{ def: jsonschema.Definition{
Type: Object, Type: jsonschema.Object,
Properties: map[string]Definition{ Properties: map[string]jsonschema.Definition{
"user": { "user": {
Type: Object, Type: jsonschema.Object,
Properties: map[string]Definition{ Properties: map[string]jsonschema.Definition{
"name": { "name": {
Type: String, Type: jsonschema.String,
}, },
"age": { "age": {
Type: Integer, Type: jsonschema.Integer,
}, },
"address": { "address": {
Type: Object, Type: jsonschema.Object,
Properties: map[string]Definition{ Properties: map[string]jsonschema.Definition{
"city": { "city": {
Type: String, Type: jsonschema.String,
}, },
"country": { "country": {
Type: String, Type: jsonschema.String,
}, },
}, },
}, },
@@ -141,14 +141,14 @@ func TestDefinition_MarshalJSON(t *testing.T) {
}, },
{ {
name: "Test with Array type Definition", name: "Test with Array type Definition",
def: Definition{ def: jsonschema.Definition{
Type: Array, Type: jsonschema.Array,
Items: &Definition{ Items: &jsonschema.Definition{
Type: String, Type: jsonschema.String,
}, },
Properties: map[string]Definition{ Properties: map[string]jsonschema.Definition{
"name": { "name": {
Type: String, Type: jsonschema.String,
}, },
}, },
}, },

View File

@@ -1,17 +1,16 @@
package openai_test package openai_test
import ( import (
"os"
"time"
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"os"
"testing" "testing"
"time"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
) )
const testFineTuneModelID = "fine-tune-model-id" const testFineTuneModelID = "fine-tune-model-id"
@@ -35,7 +34,7 @@ func TestAzureListModels(t *testing.T) {
// handleListModelsEndpoint Handles the list models endpoint by the test server. // handleListModelsEndpoint Handles the list models endpoint by the test server.
func handleListModelsEndpoint(w http.ResponseWriter, _ *http.Request) { func handleListModelsEndpoint(w http.ResponseWriter, _ *http.Request) {
resBytes, _ := json.Marshal(ModelsList{}) resBytes, _ := json.Marshal(openai.ModelsList{})
fmt.Fprintln(w, string(resBytes)) fmt.Fprintln(w, string(resBytes))
} }
@@ -58,7 +57,7 @@ func TestAzureGetModel(t *testing.T) {
// handleGetModelsEndpoint Handles the get model endpoint by the test server. // handleGetModelsEndpoint Handles the get model endpoint by the test server.
func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) { func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) {
resBytes, _ := json.Marshal(Model{}) resBytes, _ := json.Marshal(openai.Model{})
fmt.Fprintln(w, string(resBytes)) fmt.Fprintln(w, string(resBytes))
} }
@@ -90,6 +89,6 @@ func TestDeleteFineTuneModel(t *testing.T) {
} }
func handleDeleteFineTuneModelEndpoint(w http.ResponseWriter, _ *http.Request) { func handleDeleteFineTuneModelEndpoint(w http.ResponseWriter, _ *http.Request) {
resBytes, _ := json.Marshal(FineTuneModelDeleteResponse{}) resBytes, _ := json.Marshal(openai.FineTuneModelDeleteResponse{})
fmt.Fprintln(w, string(resBytes)) fmt.Fprintln(w, string(resBytes))
} }

View File

@@ -1,9 +1,6 @@
package openai_test package openai_test
import ( import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
@@ -13,6 +10,9 @@ import (
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
) )
// TestModeration Tests the moderations endpoint of the API using the mocked server. // TestModeration Tests the moderations endpoint of the API using the mocked server.
@@ -20,8 +20,8 @@ func TestModerations(t *testing.T) {
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
defer teardown() defer teardown()
server.RegisterHandler("/v1/moderations", handleModerationEndpoint) server.RegisterHandler("/v1/moderations", handleModerationEndpoint)
_, err := client.Moderations(context.Background(), ModerationRequest{ _, err := client.Moderations(context.Background(), openai.ModerationRequest{
Model: ModerationTextStable, Model: openai.ModerationTextStable,
Input: "I want to kill them.", Input: "I want to kill them.",
}) })
checks.NoError(t, err, "Moderation error") checks.NoError(t, err, "Moderation error")
@@ -34,16 +34,16 @@ func TestModerationsWithDifferentModelOptions(t *testing.T) {
expect error expect error
} }
modelOptions = append(modelOptions, modelOptions = append(modelOptions,
getModerationModelTestOption(GPT3Dot5Turbo, ErrModerationInvalidModel), getModerationModelTestOption(openai.GPT3Dot5Turbo, openai.ErrModerationInvalidModel),
getModerationModelTestOption(ModerationTextStable, nil), getModerationModelTestOption(openai.ModerationTextStable, nil),
getModerationModelTestOption(ModerationTextLatest, nil), getModerationModelTestOption(openai.ModerationTextLatest, nil),
getModerationModelTestOption("", nil), getModerationModelTestOption("", nil),
) )
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
defer teardown() defer teardown()
server.RegisterHandler("/v1/moderations", handleModerationEndpoint) server.RegisterHandler("/v1/moderations", handleModerationEndpoint)
for _, modelTest := range modelOptions { for _, modelTest := range modelOptions {
_, err := client.Moderations(context.Background(), ModerationRequest{ _, err := client.Moderations(context.Background(), openai.ModerationRequest{
Model: modelTest.model, Model: modelTest.model,
Input: "I want to kill them.", Input: "I want to kill them.",
}) })
@@ -71,32 +71,32 @@ func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" { if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
} }
var moderationReq ModerationRequest var moderationReq openai.ModerationRequest
if moderationReq, err = getModerationBody(r); err != nil { if moderationReq, err = getModerationBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError) http.Error(w, "could not read request", http.StatusInternalServerError)
return return
} }
resCat := ResultCategories{} resCat := openai.ResultCategories{}
resCatScore := ResultCategoryScores{} resCatScore := openai.ResultCategoryScores{}
switch { switch {
case strings.Contains(moderationReq.Input, "kill"): case strings.Contains(moderationReq.Input, "kill"):
resCat = ResultCategories{Violence: true} resCat = openai.ResultCategories{Violence: true}
resCatScore = ResultCategoryScores{Violence: 1} resCatScore = openai.ResultCategoryScores{Violence: 1}
case strings.Contains(moderationReq.Input, "hate"): case strings.Contains(moderationReq.Input, "hate"):
resCat = ResultCategories{Hate: true} resCat = openai.ResultCategories{Hate: true}
resCatScore = ResultCategoryScores{Hate: 1} resCatScore = openai.ResultCategoryScores{Hate: 1}
case strings.Contains(moderationReq.Input, "suicide"): case strings.Contains(moderationReq.Input, "suicide"):
resCat = ResultCategories{SelfHarm: true} resCat = openai.ResultCategories{SelfHarm: true}
resCatScore = ResultCategoryScores{SelfHarm: 1} resCatScore = openai.ResultCategoryScores{SelfHarm: 1}
case strings.Contains(moderationReq.Input, "porn"): case strings.Contains(moderationReq.Input, "porn"):
resCat = ResultCategories{Sexual: true} resCat = openai.ResultCategories{Sexual: true}
resCatScore = ResultCategoryScores{Sexual: 1} resCatScore = openai.ResultCategoryScores{Sexual: 1}
} }
result := Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} result := openai.Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true}
res := ModerationResponse{ res := openai.ModerationResponse{
ID: strconv.Itoa(int(time.Now().Unix())), ID: strconv.Itoa(int(time.Now().Unix())),
Model: moderationReq.Model, Model: moderationReq.Model,
} }
@@ -107,16 +107,16 @@ func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) {
} }
// getModerationBody Returns the body of the request to do a moderation. // getModerationBody Returns the body of the request to do a moderation.
func getModerationBody(r *http.Request) (ModerationRequest, error) { func getModerationBody(r *http.Request) (openai.ModerationRequest, error) {
moderation := ModerationRequest{} moderation := openai.ModerationRequest{}
// read the request body // read the request body
reqBody, err := io.ReadAll(r.Body) reqBody, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
return ModerationRequest{}, err return openai.ModerationRequest{}, err
} }
err = json.Unmarshal(reqBody, &moderation) err = json.Unmarshal(reqBody, &moderation)
if err != nil { if err != nil {
return ModerationRequest{}, err return openai.ModerationRequest{}, err
} }
return moderation, nil return moderation, nil
} }

View File

@@ -1,29 +1,29 @@
package openai_test package openai_test
import ( import (
. "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test"
) )
func setupOpenAITestServer() (client *Client, server *test.ServerTest, teardown func()) { func setupOpenAITestServer() (client *openai.Client, server *test.ServerTest, teardown func()) {
server = test.NewTestServer() server = test.NewTestServer()
ts := server.OpenAITestServer() ts := server.OpenAITestServer()
ts.Start() ts.Start()
teardown = ts.Close teardown = ts.Close
config := DefaultConfig(test.GetTestToken()) config := openai.DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1" config.BaseURL = ts.URL + "/v1"
client = NewClientWithConfig(config) client = openai.NewClientWithConfig(config)
return return
} }
func setupAzureTestServer() (client *Client, server *test.ServerTest, teardown func()) { func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, teardown func()) {
server = test.NewTestServer() server = test.NewTestServer()
ts := server.OpenAITestServer() ts := server.OpenAITestServer()
ts.Start() ts.Start()
teardown = ts.Close teardown = ts.Close
config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/") config := openai.DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/")
config.BaseURL = ts.URL config.BaseURL = ts.URL
client = NewClientWithConfig(config) client = openai.NewClientWithConfig(config)
return return
} }

View File

@@ -10,23 +10,23 @@ import (
"testing" "testing"
"time" "time"
. "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/internal/test/checks"
) )
func TestCompletionsStreamWrongModel(t *testing.T) { func TestCompletionsStreamWrongModel(t *testing.T) {
config := DefaultConfig("whatever") config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1" config.BaseURL = "http://localhost/v1"
client := NewClientWithConfig(config) client := openai.NewClientWithConfig(config)
_, err := client.CreateCompletionStream( _, err := client.CreateCompletionStream(
context.Background(), context.Background(),
CompletionRequest{ openai.CompletionRequest{
MaxTokens: 5, MaxTokens: 5,
Model: GPT3Dot5Turbo, Model: openai.GPT3Dot5Turbo,
}, },
) )
if !errors.Is(err, ErrCompletionUnsupportedModel) { if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err) t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err)
} }
} }
@@ -56,7 +56,7 @@ func TestCreateCompletionStream(t *testing.T) {
checks.NoError(t, err, "Write error") checks.NoError(t, err, "Write error")
}) })
stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{
Prompt: "Ex falso quodlibet", Prompt: "Ex falso quodlibet",
Model: "text-davinci-002", Model: "text-davinci-002",
MaxTokens: 10, MaxTokens: 10,
@@ -65,20 +65,20 @@ func TestCreateCompletionStream(t *testing.T) {
checks.NoError(t, err, "CreateCompletionStream returned error") checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close() defer stream.Close()
expectedResponses := []CompletionResponse{ expectedResponses := []openai.CompletionResponse{
{ {
ID: "1", ID: "1",
Object: "completion", Object: "completion",
Created: 1598069254, Created: 1598069254,
Model: "text-davinci-002", Model: "text-davinci-002",
Choices: []CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}}, Choices: []openai.CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}},
}, },
{ {
ID: "2", ID: "2",
Object: "completion", Object: "completion",
Created: 1598069255, Created: 1598069255,
Model: "text-davinci-002", Model: "text-davinci-002",
Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}}, Choices: []openai.CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}},
}, },
} }
@@ -129,9 +129,9 @@ func TestCreateCompletionStreamError(t *testing.T) {
checks.NoError(t, err, "Write error") checks.NoError(t, err, "Write error")
}) })
stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{
MaxTokens: 5, MaxTokens: 5,
Model: GPT3TextDavinci003, Model: openai.GPT3TextDavinci003,
Prompt: "Hello!", Prompt: "Hello!",
Stream: true, Stream: true,
}) })
@@ -141,7 +141,7 @@ func TestCreateCompletionStreamError(t *testing.T) {
_, streamErr := stream.Recv() _, streamErr := stream.Recv()
checks.HasError(t, streamErr, "stream.Recv() did not return error") checks.HasError(t, streamErr, "stream.Recv() did not return error")
var apiErr *APIError var apiErr *openai.APIError
if !errors.As(streamErr, &apiErr) { if !errors.As(streamErr, &apiErr) {
t.Errorf("stream.Recv() did not return APIError") t.Errorf("stream.Recv() did not return APIError")
} }
@@ -166,10 +166,10 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) {
checks.NoError(t, err, "Write error") checks.NoError(t, err, "Write error")
}) })
var apiErr *APIError var apiErr *openai.APIError
_, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ _, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{
MaxTokens: 5, MaxTokens: 5,
Model: GPT3Ada, Model: openai.GPT3Ada,
Prompt: "Hello!", Prompt: "Hello!",
Stream: true, Stream: true,
}) })
@@ -209,7 +209,7 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) {
checks.NoError(t, err, "Write error") checks.NoError(t, err, "Write error")
}) })
stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{
Prompt: "Ex falso quodlibet", Prompt: "Ex falso quodlibet",
Model: "text-davinci-002", Model: "text-davinci-002",
MaxTokens: 10, MaxTokens: 10,
@@ -220,7 +220,7 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) {
_, _ = stream.Recv() _, _ = stream.Recv()
_, streamErr := stream.Recv() _, streamErr := stream.Recv()
if !errors.Is(streamErr, ErrTooManyEmptyStreamMessages) { if !errors.Is(streamErr, openai.ErrTooManyEmptyStreamMessages) {
t.Errorf("TestCreateCompletionStreamTooManyEmptyStreamMessagesError did not return ErrTooManyEmptyStreamMessages") t.Errorf("TestCreateCompletionStreamTooManyEmptyStreamMessagesError did not return ErrTooManyEmptyStreamMessages")
} }
} }
@@ -244,7 +244,7 @@ func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) {
checks.NoError(t, err, "Write error") checks.NoError(t, err, "Write error")
}) })
stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{
Prompt: "Ex falso quodlibet", Prompt: "Ex falso quodlibet",
Model: "text-davinci-002", Model: "text-davinci-002",
MaxTokens: 10, MaxTokens: 10,
@@ -285,7 +285,7 @@ func TestCreateCompletionStreamBrokenJSONError(t *testing.T) {
checks.NoError(t, err, "Write error") checks.NoError(t, err, "Write error")
}) })
stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{
Prompt: "Ex falso quodlibet", Prompt: "Ex falso quodlibet",
Model: "text-davinci-002", Model: "text-davinci-002",
MaxTokens: 10, MaxTokens: 10,
@@ -312,7 +312,7 @@ func TestCreateCompletionStreamReturnTimeoutError(t *testing.T) {
ctx, cancel := context.WithTimeout(ctx, time.Nanosecond) ctx, cancel := context.WithTimeout(ctx, time.Nanosecond)
defer cancel() defer cancel()
_, err := client.CreateCompletionStream(ctx, CompletionRequest{ _, err := client.CreateCompletionStream(ctx, openai.CompletionRequest{
Prompt: "Ex falso quodlibet", Prompt: "Ex falso quodlibet",
Model: "text-davinci-002", Model: "text-davinci-002",
MaxTokens: 10, MaxTokens: 10,
@@ -327,7 +327,7 @@ func TestCreateCompletionStreamReturnTimeoutError(t *testing.T) {
} }
// Helper funcs. // Helper funcs.
func compareResponses(r1, r2 CompletionResponse) bool { func compareResponses(r1, r2 openai.CompletionResponse) bool {
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
return false return false
} }
@@ -342,7 +342,7 @@ func compareResponses(r1, r2 CompletionResponse) bool {
return true return true
} }
func compareResponseChoices(c1, c2 CompletionChoice) bool { func compareResponseChoices(c1, c2 openai.CompletionChoice) bool {
if c1.Text != c2.Text || c1.FinishReason != c2.FinishReason { if c1.Text != c2.Text || c1.FinishReason != c2.FinishReason {
return false return false
} }