refactoring tests with mock servers (#30) (#356)

Author: 渡邉祐一 / Yuichi Watanabe
Date: 2023-06-12 22:40:26 +09:00
Committed by: GitHub
Parent: a243e7331f
Commit: b616090e69

20 changed files with 732 additions and 1061 deletions
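The commit replaces the per-test server boilerplate (test.NewTestServer, ts.Start, DefaultConfig, NewClientWithConfig) with two shared helpers, setupOpenAITestServer and setupAzureTestServer, defined in the new openai_test.go at the bottom of this diff. A minimal sketch of the test shape the commit converges on; the /v1/models route and the canned ModelsList body are taken from the tests below, the test name is illustrative only:

package openai_test

import (
    "context"
    "encoding/json"
    "fmt"
    "net/http"
    "testing"

    . "github.com/sashabaranov/go-openai"
    "github.com/sashabaranov/go-openai/internal/test/checks"
)

func TestExampleEndpoint(t *testing.T) {
    client, server, teardown := setupOpenAITestServer()
    defer teardown() // closes the shared httptest server
    // Register a canned handler; the mock server presumably dispatches by request path.
    server.RegisterHandler("/v1/models", func(w http.ResponseWriter, r *http.Request) {
        resBytes, _ := json.Marshal(ModelsList{})
        fmt.Fprintln(w, string(resBytes))
    })
    _, err := client.ListModels(context.Background())
    checks.NoError(t, err, "ListModels error")
}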


@@ -6,7 +6,6 @@ import (
     "errors"
     "io"
     "net/http"
-    "net/http/httptest"
     "os"
     "testing"
@@ -226,18 +225,13 @@ func TestAPIErrorUnmarshalJSONInvalidMessage(t *testing.T) {
 }

 func TestRequestError(t *testing.T) {
-    var err error
-
-    ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+    client, server, teardown := setupOpenAITestServer()
+    defer teardown()
+    server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) {
         w.WriteHeader(http.StatusTeapot)
-    }))
-    defer ts.Close()
-
-    config := DefaultConfig("dummy")
-    config.BaseURL = ts.URL
-    c := NewClientWithConfig(config)
-    ctx := context.Background()
-    _, err = c.ListEngines(ctx)
+    })
+
+    _, err := client.ListEngines(context.Background())
     checks.HasError(t, err, "ListEngines did not fail")

     var reqErr *RequestError
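The hunk ends at the reqErr declaration; the assertions that follow are elided from this diff. As a hedged sketch of how such a check typically continues, assuming RequestError exposes an HTTPStatusCode field that the teapot status would surface:

// Hypothetical continuation of TestRequestError (not shown in the diff):
var reqErr *RequestError
if !errors.As(err, &reqErr) {
    t.Fatalf("expected a *RequestError, got %T", err)
}
if reqErr.HTTPStatusCode != http.StatusTeapot {
    t.Fatalf("unexpected HTTP status code: %d", reqErr.HTTPStatusCode)
}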

audio_api_test.go (new file, +162 lines)

@@ -0,0 +1,162 @@
package openai_test

import (
    "bytes"
    "context"
    "errors"
    "io"
    "mime"
    "mime/multipart"
    "net/http"
    "path/filepath"
    "strings"
    "testing"

    . "github.com/sashabaranov/go-openai"
    "github.com/sashabaranov/go-openai/internal/test"
    "github.com/sashabaranov/go-openai/internal/test/checks"
)

// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
func TestAudio(t *testing.T) {
    client, server, teardown := setupOpenAITestServer()
    defer teardown()
    server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
    server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)

    testcases := []struct {
        name     string
        createFn func(context.Context, AudioRequest) (AudioResponse, error)
    }{
        {
            "transcribe",
            client.CreateTranscription,
        },
        {
            "translate",
            client.CreateTranslation,
        },
    }

    ctx := context.Background()

    dir, cleanup := test.CreateTestDirectory(t)
    defer cleanup()

    for _, tc := range testcases {
        t.Run(tc.name, func(t *testing.T) {
            path := filepath.Join(dir, "fake.mp3")
            test.CreateTestFile(t, path)

            req := AudioRequest{
                FilePath: path,
                Model:    "whisper-3",
            }
            _, err := tc.createFn(ctx, req)
            checks.NoError(t, err, "audio API error")
        })

        t.Run(tc.name+" (with reader)", func(t *testing.T) {
            req := AudioRequest{
                FilePath: "fake.webm",
                Reader:   bytes.NewBuffer([]byte(`some webm binary data`)),
                Model:    "whisper-3",
            }
            _, err := tc.createFn(ctx, req)
            checks.NoError(t, err, "audio API error")
        })
    }
}

func TestAudioWithOptionalArgs(t *testing.T) {
    client, server, teardown := setupOpenAITestServer()
    defer teardown()
    server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
    server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)

    testcases := []struct {
        name     string
        createFn func(context.Context, AudioRequest) (AudioResponse, error)
    }{
        {
            "transcribe",
            client.CreateTranscription,
        },
        {
            "translate",
            client.CreateTranslation,
        },
    }

    ctx := context.Background()

    dir, cleanup := test.CreateTestDirectory(t)
    defer cleanup()

    for _, tc := range testcases {
        t.Run(tc.name, func(t *testing.T) {
            path := filepath.Join(dir, "fake.mp3")
            test.CreateTestFile(t, path)

            req := AudioRequest{
                FilePath:    path,
                Model:       "whisper-3",
                Prompt:      "用简体中文",
                Temperature: 0.5,
                Language:    "zh",
                Format:      AudioResponseFormatSRT,
            }
            _, err := tc.createFn(ctx, req)
            checks.NoError(t, err, "audio API error")
        })
    }
}

// handleAudioEndpoint Handles the completion endpoint by the test server.
func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) {
    var err error

    // audio endpoints only accept POST requests
    if r.Method != "POST" {
        http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
    }

    mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
    if err != nil {
        http.Error(w, "failed to parse media type", http.StatusBadRequest)
        return
    }

    if !strings.HasPrefix(mediaType, "multipart") {
        http.Error(w, "request is not multipart", http.StatusBadRequest)
    }

    boundary, ok := params["boundary"]
    if !ok {
        http.Error(w, "no boundary in params", http.StatusBadRequest)
        return
    }

    fileData := &bytes.Buffer{}
    mr := multipart.NewReader(r.Body, boundary)
    part, err := mr.NextPart()
    if err != nil && errors.Is(err, io.EOF) {
        http.Error(w, "error accessing file", http.StatusBadRequest)
        return
    }
    if _, err = io.Copy(fileData, part); err != nil {
        http.Error(w, "failed to copy file", http.StatusInternalServerError)
        return
    }

    if len(fileData.Bytes()) == 0 {
        w.WriteHeader(http.StatusInternalServerError)
        http.Error(w, "received empty file data", http.StatusBadRequest)
        return
    }

    if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil {
        http.Error(w, "failed to write body", http.StatusInternalServerError)
        return
    }
}
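handleAudioEndpoint parses the multipart body by hand (mime.ParseMediaType for the boundary, then multipart.NewReader) instead of calling r.ParseMultipartForm, so it can reject empty uploads. An illustrative sketch of the kind of request it expects; the "file" field name and the endpoint path mirror the tests above, while the exact form layout the client library produces is an assumption:

var body bytes.Buffer
mw := multipart.NewWriter(&body)
fw, _ := mw.CreateFormFile("file", "fake.mp3") // first part: the audio payload
_, _ = fw.Write([]byte("fake audio bytes"))
mw.Close()

req, _ := http.NewRequest(http.MethodPost, "http://mock-server/v1/audio/transcriptions", &body)
// FormDataContentType carries the boundary the handler later reads from params["boundary"].
req.Header.Set("Content-Type", mw.FormDataContentType())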


@@ -2,182 +2,16 @@ package openai //nolint:testpackage // testing private field
 import (
     "bytes"
-    "context"
-    "errors"
     "fmt"
     "io"
-    "mime"
-    "mime/multipart"
-    "net/http"
     "os"
     "path/filepath"
-    "strings"
     "testing"

     "github.com/sashabaranov/go-openai/internal/test"
     "github.com/sashabaranov/go-openai/internal/test/checks"
 )

-// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
-func TestAudio(t *testing.T) {
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
-    server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-
-    testcases := []struct {
-        name     string
-        createFn func(context.Context, AudioRequest) (AudioResponse, error)
-    }{
-        {
-            "transcribe",
-            client.CreateTranscription,
-        },
-        {
-            "translate",
-            client.CreateTranslation,
-        },
-    }
-
-    ctx := context.Background()
-
-    dir, cleanup := test.CreateTestDirectory(t)
-    defer cleanup()
-
-    for _, tc := range testcases {
-        t.Run(tc.name, func(t *testing.T) {
-            path := filepath.Join(dir, "fake.mp3")
-            test.CreateTestFile(t, path)
-
-            req := AudioRequest{
-                FilePath: path,
-                Model:    "whisper-3",
-            }
-            _, err = tc.createFn(ctx, req)
-            checks.NoError(t, err, "audio API error")
-        })
-
-        t.Run(tc.name+" (with reader)", func(t *testing.T) {
-            req := AudioRequest{
-                FilePath: "fake.webm",
-                Reader:   bytes.NewBuffer([]byte(`some webm binary data`)),
-                Model:    "whisper-3",
-            }
-            _, err = tc.createFn(ctx, req)
-            checks.NoError(t, err, "audio API error")
-        })
-    }
-}
-
-func TestAudioWithOptionalArgs(t *testing.T) {
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
-    server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-
-    testcases := []struct {
-        name     string
-        createFn func(context.Context, AudioRequest) (AudioResponse, error)
-    }{
-        {
-            "transcribe",
-            client.CreateTranscription,
-        },
-        {
-            "translate",
-            client.CreateTranslation,
-        },
-    }
-
-    ctx := context.Background()
-
-    dir, cleanup := test.CreateTestDirectory(t)
-    defer cleanup()
-
-    for _, tc := range testcases {
-        t.Run(tc.name, func(t *testing.T) {
-            path := filepath.Join(dir, "fake.mp3")
-            test.CreateTestFile(t, path)
-
-            req := AudioRequest{
-                FilePath:    path,
-                Model:       "whisper-3",
-                Prompt:      "用简体中文",
-                Temperature: 0.5,
-                Language:    "zh",
-                Format:      AudioResponseFormatSRT,
-            }
-            _, err = tc.createFn(ctx, req)
-            checks.NoError(t, err, "audio API error")
-        })
-    }
-}
-
-// handleAudioEndpoint Handles the completion endpoint by the test server.
-func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) {
-    var err error
-
-    // audio endpoints only accept POST requests
-    if r.Method != "POST" {
-        http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
-    }
-
-    mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
-    if err != nil {
-        http.Error(w, "failed to parse media type", http.StatusBadRequest)
-        return
-    }
-
-    if !strings.HasPrefix(mediaType, "multipart") {
-        http.Error(w, "request is not multipart", http.StatusBadRequest)
-    }
-
-    boundary, ok := params["boundary"]
-    if !ok {
-        http.Error(w, "no boundary in params", http.StatusBadRequest)
-        return
-    }
-
-    fileData := &bytes.Buffer{}
-    mr := multipart.NewReader(r.Body, boundary)
-    part, err := mr.NextPart()
-    if err != nil && errors.Is(err, io.EOF) {
-        http.Error(w, "error accessing file", http.StatusBadRequest)
-        return
-    }
-    if _, err = io.Copy(fileData, part); err != nil {
-        http.Error(w, "failed to copy file", http.StatusInternalServerError)
-        return
-    }
-
-    if len(fileData.Bytes()) == 0 {
-        w.WriteHeader(http.StatusInternalServerError)
-        http.Error(w, "received empty file data", http.StatusBadRequest)
-        return
-    }
-
-    if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil {
-        http.Error(w, "failed to write body", http.StatusInternalServerError)
-        return
-    }
-}
-
 func TestAudioWithFailingFormBuilder(t *testing.T) {
     dir, cleanup := test.CreateTestDirectory(t)
     defer cleanup()


@@ -1,8 +1,7 @@
-package openai //nolint:testpackage // testing private field
+package openai_test

 import (
-    utils "github.com/sashabaranov/go-openai/internal"
-    "github.com/sashabaranov/go-openai/internal/test"
+    . "github.com/sashabaranov/go-openai"
     "github.com/sashabaranov/go-openai/internal/test/checks"

     "context"
@@ -10,7 +9,6 @@ import (
     "errors"
     "io"
     "net/http"
-    "net/http/httptest"
     "testing"
 )
@@ -37,7 +35,9 @@ func TestChatCompletionsStreamWrongModel(t *testing.T) {
 }

 func TestCreateChatCompletionStream(t *testing.T) {
-    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+    client, server, teardown := setupOpenAITestServer()
+    defer teardown()
+    server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
         w.Header().Set("Content-Type", "text/event-stream")

         // Send test responses
@@ -57,21 +57,9 @@ func TestCreateChatCompletionStream(t *testing.T) {
         _, err := w.Write(dataBytes)
         checks.NoError(t, err, "Write error")
-    }))
-    defer server.Close()
-
-    // Client portion of the test
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = server.URL + "/v1"
-    config.HTTPClient.Transport = &test.TokenRoundTripper{
-        Token:    test.GetTestToken(),
-        Fallback: http.DefaultTransport,
-    }
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
-    request := ChatCompletionRequest{
+    })
+
+    stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
         MaxTokens: 5,
         Model:     GPT3Dot5Turbo,
         Messages: []ChatCompletionMessage{
@@ -81,9 +69,7 @@ func TestCreateChatCompletionStream(t *testing.T) {
             },
         },
         Stream: true,
-    }
-
-    stream, err := client.CreateChatCompletionStream(ctx, request)
+    })
     checks.NoError(t, err, "CreateCompletionStream returned error")
     defer stream.Close()
@@ -143,7 +129,9 @@ func TestCreateChatCompletionStream(t *testing.T) {
 }

 func TestCreateChatCompletionStreamError(t *testing.T) {
-    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+    client, server, teardown := setupOpenAITestServer()
+    defer teardown()
+    server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
         w.Header().Set("Content-Type", "text/event-stream")

         // Send test responses
@@ -164,21 +152,9 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
         _, err := w.Write(dataBytes)
         checks.NoError(t, err, "Write error")
-    }))
-    defer server.Close()
-
-    // Client portion of the test
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = server.URL + "/v1"
-    config.HTTPClient.Transport = &test.TokenRoundTripper{
-        Token:    test.GetTestToken(),
-        Fallback: http.DefaultTransport,
-    }
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
-    request := ChatCompletionRequest{
+    })
+
+    stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
         MaxTokens: 5,
         Model:     GPT3Dot5Turbo,
         Messages: []ChatCompletionMessage{
@@ -188,9 +164,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
             },
         },
         Stream: true,
-    }
-
-    stream, err := client.CreateChatCompletionStream(ctx, request)
+    })
     checks.NoError(t, err, "CreateCompletionStream returned error")
     defer stream.Close()
@@ -205,7 +179,8 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
 }

 func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
-    server := test.NewTestServer()
+    client, server, teardown := setupOpenAITestServer()
+    defer teardown()
     server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
         w.Header().Set("Content-Type", "application/json")
         w.WriteHeader(429)
@@ -220,22 +195,7 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
         _, err := w.Write(dataBytes)
         checks.NoError(t, err, "Write error")
     })
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    // Client portion of the test
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    config.HTTPClient.Transport = &test.TokenRoundTripper{
-        Token:    test.GetTestToken(),
-        Fallback: http.DefaultTransport,
-    }
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
-    request := ChatCompletionRequest{
+    _, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
         MaxTokens: 5,
         Model:     GPT3Dot5Turbo,
         Messages: []ChatCompletionMessage{
@@ -245,10 +205,8 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
             },
         },
         Stream: true,
-    }
+    })
     var apiErr *APIError
-    _, err := client.CreateChatCompletionStream(ctx, request)
     if !errors.As(err, &apiErr) {
         t.Errorf("TestCreateChatCompletionStreamRateLimitError did not return APIError")
     }
@@ -262,7 +220,8 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) {
         "Please retry after 20 seconds. " +
         "Please go here: https://aka.ms/oai/quotaincrease if you would like to further increase the default rate limit."
-    server := test.NewTestServer()
+    client, server, teardown := setupAzureTestServer()
+    defer teardown()
     server.RegisterHandler("/openai/deployments/gpt-35-turbo/chat/completions",
         func(w http.ResponseWriter, r *http.Request) {
             w.Header().Set("Content-Type", "application/json")
@@ -273,17 +232,9 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) {
             checks.NoError(t, err, "Write error")
         })

-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultAzureConfig(test.GetTestToken(), ts.URL)
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
-    request := ChatCompletionRequest{
+    apiErr := &APIError{}
+    _, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
         MaxTokens: 5,
         Model:     GPT3Dot5Turbo,
         Messages: []ChatCompletionMessage{
@@ -293,10 +244,7 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) {
             },
         },
         Stream: true,
-    }
-
-    apiErr := &APIError{}
-    _, err = client.CreateChatCompletionStream(ctx, request)
+    })
     if !errors.As(err, &apiErr) {
         t.Errorf("Did not return APIError: %+v\n", apiErr)
         return
@@ -316,33 +264,6 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) {
     }
 }

-func TestCreateChatCompletionStreamErrorAccumulatorWriteErrors(t *testing.T) {
-    var err error
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
-        http.Error(w, "error", 200)
-    })
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-
-    ctx := context.Background()
-
-    stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
-    checks.NoError(t, err)
-
-    stream.errAccumulator = &utils.DefaultErrorAccumulator{
-        Buffer: &test.FailingErrorBuffer{},
-    }
-
-    _, err = stream.Recv()
-    checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when Write failed", err.Error())
-}
-
 // Helper funcs.
 func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
     if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
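Both stream tests keep writing the same dataBytes payload as before; its construction sits in the elided middle of each hunk. As an assumption for readers (the exact frames are not shown in this diff), a mock text/event-stream body is typically built from data: frames terminated by a [DONE] sentinel:

// Assumed shape of the elided dataBytes construction:
dataBytes := []byte("data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\"}\n\n")
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
_, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error")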


@@ -2,7 +2,6 @@ package openai_test
 import (
     . "github.com/sashabaranov/go-openai"
-    "github.com/sashabaranov/go-openai/internal/test"
     "github.com/sashabaranov/go-openai/internal/test/checks"

     "context"
@@ -52,20 +51,10 @@ func TestChatCompletionsWithStream(t *testing.T) {
 // TestCompletions Tests the completions endpoint of the API using the mocked server.
 func TestChatCompletions(t *testing.T) {
-    server := test.NewTestServer()
+    client, server, teardown := setupOpenAITestServer()
+    defer teardown()
     server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
-    req := ChatCompletionRequest{
+    _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
         MaxTokens: 5,
         Model:     GPT3Dot5Turbo,
         Messages: []ChatCompletionMessage{
@@ -74,8 +63,7 @@ func TestChatCompletions(t *testing.T) {
                 Content: "Hello!",
             },
         },
-    }
-    _, err = client.CreateChatCompletion(ctx, req)
+    })
     checks.NoError(t, err, "CreateChatCompletion error")
 }


@@ -167,16 +167,9 @@ func TestHandleErrorResp(t *testing.T) {
 }

 func TestClientReturnsRequestBuilderErrors(t *testing.T) {
-    var err error
-    ts := test.NewTestServer().OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
     config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
     client := NewClientWithConfig(config)
     client.requestBuilder = &failingRequestBuilder{}

     ctx := context.Background()

     type TestCase struct {
@@ -254,7 +247,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
     }

     for _, testCase := range testCases {
-        _, err = testCase.TestFunc()
+        _, err := testCase.TestFunc()
         if !errors.Is(err, errTestRequestBuilderFailed) {
             t.Fatalf("%s did not return error when request builder failed: %v", testCase.Name, err)
         }
@@ -262,23 +255,14 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
 }

 func TestClientReturnsRequestBuilderErrorsAddtion(t *testing.T) {
-    var err error
-    ts := test.NewTestServer().OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
     config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
     client := NewClientWithConfig(config)
     client.requestBuilder = &failingRequestBuilder{}

     ctx := context.Background()

-    _, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: 1})
+    _, err := client.CreateCompletion(ctx, CompletionRequest{Prompt: 1})
     if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
         t.Fatalf("Did not return error when request builder failed: %v", err)
     }

     _, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: 1})
     if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
         t.Fatalf("Did not return error when request builder failed: %v", err)


@@ -2,7 +2,6 @@ package openai_test
 import (
     . "github.com/sashabaranov/go-openai"
-    "github.com/sashabaranov/go-openai/internal/test"
     "github.com/sashabaranov/go-openai/internal/test/checks"

     "context"
@@ -48,25 +47,15 @@ func TestCompletionWithStream(t *testing.T) {
 // TestCompletions Tests the completions endpoint of the API using the mocked server.
 func TestCompletions(t *testing.T) {
-    server := test.NewTestServer()
+    client, server, teardown := setupOpenAITestServer()
+    defer teardown()
     server.RegisterHandler("/v1/completions", handleCompletionEndpoint)
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
     req := CompletionRequest{
         MaxTokens: 5,
         Model:     "ada",
+        Prompt:    "Lorem ipsum",
     }
-    req.Prompt = "Lorem ipsum"
-    _, err = client.CreateCompletion(ctx, req)
+    _, err := client.CreateCompletion(context.Background(), req)
     checks.NoError(t, err, "CreateCompletion error")
 }


@@ -2,7 +2,6 @@ package openai_test
 import (
     . "github.com/sashabaranov/go-openai"
-    "github.com/sashabaranov/go-openai/internal/test"
     "github.com/sashabaranov/go-openai/internal/test/checks"

     "context"
@@ -16,19 +15,9 @@ import (
 // TestEdits Tests the edits endpoint of the API using the mocked server.
 func TestEdits(t *testing.T) {
-    server := test.NewTestServer()
+    client, server, teardown := setupOpenAITestServer()
+    defer teardown()
     server.RegisterHandler("/v1/edits", handleEditEndpoint)
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
     // create an edit request
     model := "ada"
     editReq := EditsRequest{
@@ -40,7 +29,7 @@ func TestEdits(t *testing.T) {
         Instruction: "test instruction",
         N:           3,
     }
-    response, err := client.Edits(ctx, editReq)
+    response, err := client.Edits(context.Background(), editReq)
     checks.NoError(t, err, "Edits error")
     if len(response.Choices) != editReq.N {
         t.Fatalf("edits does not properly return the correct number of choices")


@@ -2,7 +2,6 @@ package openai_test
 import (
     . "github.com/sashabaranov/go-openai"
-    "github.com/sashabaranov/go-openai/internal/test"
     "github.com/sashabaranov/go-openai/internal/test/checks"

     "bytes"
@@ -67,7 +66,8 @@ func TestEmbeddingModel(t *testing.T) {
 }

 func TestEmbeddingEndpoint(t *testing.T) {
-    server := test.NewTestServer()
+    client, server, teardown := setupOpenAITestServer()
+    defer teardown()
     server.RegisterHandler(
         "/v1/embeddings",
         func(w http.ResponseWriter, r *http.Request) {
@@ -75,17 +75,6 @@ func TestEmbeddingEndpoint(t *testing.T) {
             fmt.Fprintln(w, string(resBytes))
         },
     )
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
-    _, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
+    _, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{})
     checks.NoError(t, err, "CreateEmbeddings error")
 }


@@ -8,27 +8,29 @@ import (
     "testing"

     . "github.com/sashabaranov/go-openai"
-    "github.com/sashabaranov/go-openai/internal/test"
     "github.com/sashabaranov/go-openai/internal/test/checks"
 )

 // TestGetEngine Tests the retrieve engine endpoint of the API using the mocked server.
 func TestGetEngine(t *testing.T) {
-    server := test.NewTestServer()
+    client, server, teardown := setupOpenAITestServer()
+    defer teardown()
     server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, r *http.Request) {
         resBytes, _ := json.Marshal(Engine{})
         fmt.Fprintln(w, string(resBytes))
     })
-    // create the test server
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-    _, err := client.GetEngine(ctx, "text-davinci-003")
+    _, err := client.GetEngine(context.Background(), "text-davinci-003")
     checks.NoError(t, err, "GetEngine error")
 }
+
+// TestListEngines Tests the list engines endpoint of the API using the mocked server.
+func TestListEngines(t *testing.T) {
+    client, server, teardown := setupOpenAITestServer()
+    defer teardown()
+    server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) {
+        resBytes, _ := json.Marshal(EnginesList{})
+        fmt.Fprintln(w, string(resBytes))
+    })
+    _, err := client.ListEngines(context.Background())
+    checks.NoError(t, err, "ListEngines error")
+}

files_api_test.go (new file, +183 lines)

@@ -0,0 +1,183 @@
package openai_test

import (
    "context"
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "net/http"
    "os"
    "strconv"
    "testing"
    "time"

    . "github.com/sashabaranov/go-openai"
    "github.com/sashabaranov/go-openai/internal/test/checks"
)

func TestFileUpload(t *testing.T) {
    client, server, teardown := setupOpenAITestServer()
    defer teardown()
    server.RegisterHandler("/v1/files", handleCreateFile)
    req := FileRequest{
        FileName: "test.go",
        FilePath: "client.go",
        Purpose:  "fine-tune",
    }
    _, err := client.CreateFile(context.Background(), req)
    checks.NoError(t, err, "CreateFile error")
}

// handleCreateFile Handles the images endpoint by the test server.
func handleCreateFile(w http.ResponseWriter, r *http.Request) {
    var err error
    var resBytes []byte

    // edits only accepts POST requests
    if r.Method != "POST" {
        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
    }

    err = r.ParseMultipartForm(1024 * 1024 * 1024)
    if err != nil {
        http.Error(w, "file is more than 1GB", http.StatusInternalServerError)
        return
    }

    values := r.Form
    var purpose string
    for key, value := range values {
        if key == "purpose" {
            purpose = value[0]
        }
    }

    file, header, err := r.FormFile("file")
    if err != nil {
        return
    }
    defer file.Close()

    var fileReq = File{
        Bytes:     int(header.Size),
        ID:        strconv.Itoa(int(time.Now().Unix())),
        FileName:  header.Filename,
        Purpose:   purpose,
        CreatedAt: time.Now().Unix(),
        Object:    "test-objecct",
        Owner:     "test-owner",
    }

    resBytes, _ = json.Marshal(fileReq)
    fmt.Fprint(w, string(resBytes))
}

func TestDeleteFile(t *testing.T) {
    client, server, teardown := setupOpenAITestServer()
    defer teardown()
    server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) {})
    err := client.DeleteFile(context.Background(), "deadbeef")
    checks.NoError(t, err, "DeleteFile error")
}

func TestListFile(t *testing.T) {
    client, server, teardown := setupOpenAITestServer()
    defer teardown()
    server.RegisterHandler("/v1/files", func(w http.ResponseWriter, r *http.Request) {
        resBytes, _ := json.Marshal(FilesList{})
        fmt.Fprintln(w, string(resBytes))
    })
    _, err := client.ListFiles(context.Background())
    checks.NoError(t, err, "ListFiles error")
}

func TestGetFile(t *testing.T) {
    client, server, teardown := setupOpenAITestServer()
    defer teardown()
    server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) {
        resBytes, _ := json.Marshal(File{})
        fmt.Fprintln(w, string(resBytes))
    })
    _, err := client.GetFile(context.Background(), "deadbeef")
    checks.NoError(t, err, "GetFile error")
}

func TestGetFileContent(t *testing.T) {
    wantRespJsonl := `{"prompt": "foo", "completion": "foo"}
{"prompt": "bar", "completion": "bar"}
{"prompt": "baz", "completion": "baz"}
`
    client, server, teardown := setupOpenAITestServer()
    defer teardown()
    server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) {
        // edits only accepts GET requests
        if r.Method != http.MethodGet {
            http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
        }
        fmt.Fprint(w, wantRespJsonl)
    })

    content, err := client.GetFileContent(context.Background(), "deadbeef")
    checks.NoError(t, err, "GetFileContent error")
    defer content.Close()

    actual, _ := io.ReadAll(content)
    if string(actual) != wantRespJsonl {
        t.Errorf("Expected %s, got %s", wantRespJsonl, string(actual))
    }
}

func TestGetFileContentReturnError(t *testing.T) {
    wantMessage := "To help mitigate abuse, downloading of fine-tune training files is disabled for free accounts."
    wantType := "invalid_request_error"
    wantErrorResp := `{
    "error": {
        "message": "` + wantMessage + `",
        "type": "` + wantType + `",
        "param": null,
        "code": null
    }
}`
    client, server, teardown := setupOpenAITestServer()
    defer teardown()
    server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) {
        w.WriteHeader(http.StatusBadRequest)
        fmt.Fprint(w, wantErrorResp)
    })

    _, err := client.GetFileContent(context.Background(), "deadbeef")
    if err == nil {
        t.Fatal("Did not return error")
    }

    apiErr := &APIError{}
    if !errors.As(err, &apiErr) {
        t.Fatalf("Did not return APIError: %+v\n", apiErr)
    }
    if apiErr.Message != wantMessage {
        t.Fatalf("Expected %s Message, got = %s\n", wantMessage, apiErr.Message)
        return
    }
    if apiErr.Type != wantType {
        t.Fatalf("Expected %s Type, got = %s\n", wantType, apiErr.Type)
        return
    }
}

func TestGetFileContentReturnTimeoutError(t *testing.T) {
    client, server, teardown := setupOpenAITestServer()
    defer teardown()
    server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) {
        time.Sleep(10 * time.Nanosecond)
    })

    ctx := context.Background()
    ctx, cancel := context.WithTimeout(ctx, time.Nanosecond)
    defer cancel()

    _, err := client.GetFileContent(ctx, "deadbeef")
    if err == nil {
        t.Fatal("Did not return error")
    }
    if !os.IsTimeout(err) {
        t.Fatal("Did not return timeout error")
    }
}
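TestGetFileContentReturnTimeoutError relies on the 1ns context deadline expiring before the handler's 10ns sleep returns, and os.IsTimeout matching the deadline error the HTTP client wraps. A hedged variant of the same assertion, shown only for clarity and not part of the commit:

// Hypothetical stricter check: context deadline errors also satisfy errors.Is.
if !errors.Is(err, context.DeadlineExceeded) {
    t.Fatal("expected a context deadline error")
}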


@@ -2,86 +2,15 @@ package openai //nolint:testpackage // testing private field
 import (
     utils "github.com/sashabaranov/go-openai/internal"
-    "github.com/sashabaranov/go-openai/internal/test"
     "github.com/sashabaranov/go-openai/internal/test/checks"

     "context"
-    "encoding/json"
-    "errors"
     "fmt"
     "io"
-    "net/http"
     "os"
-    "strconv"
     "testing"
-    "time"
 )

-func TestFileUpload(t *testing.T) {
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/files", handleCreateFile)
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
-    req := FileRequest{
-        FileName: "test.go",
-        FilePath: "client.go",
-        Purpose:  "fine-tune",
-    }
-    _, err = client.CreateFile(ctx, req)
-    checks.NoError(t, err, "CreateFile error")
-}
-
-// handleCreateFile Handles the images endpoint by the test server.
-func handleCreateFile(w http.ResponseWriter, r *http.Request) {
-    var err error
-    var resBytes []byte
-
-    // edits only accepts POST requests
-    if r.Method != "POST" {
-        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
-    }
-
-    err = r.ParseMultipartForm(1024 * 1024 * 1024)
-    if err != nil {
-        http.Error(w, "file is more than 1GB", http.StatusInternalServerError)
-        return
-    }
-
-    values := r.Form
-    var purpose string
-    for key, value := range values {
-        if key == "purpose" {
-            purpose = value[0]
-        }
-    }
-
-    file, header, err := r.FormFile("file")
-    if err != nil {
-        return
-    }
-    defer file.Close()
-
-    var fileReq = File{
-        Bytes:     int(header.Size),
-        ID:        strconv.Itoa(int(time.Now().Unix())),
-        FileName:  header.Filename,
-        Purpose:   purpose,
-        CreatedAt: time.Now().Unix(),
-        Object:    "test-objecct",
-        Owner:     "test-owner",
-    }
-
-    resBytes, _ = json.Marshal(fileReq)
-    fmt.Fprint(w, string(resBytes))
-}
-
 func TestFileUploadWithFailingFormBuilder(t *testing.T) {
     config := DefaultConfig("")
     config.BaseURL = ""
@@ -142,168 +71,3 @@ func TestFileUploadWithNonExistentPath(t *testing.T) {
     _, err := client.CreateFile(ctx, req)
     checks.ErrorIs(t, err, os.ErrNotExist, "CreateFile should return error if file does not exist")
 }
-
-func TestDeleteFile(t *testing.T) {
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) {
-    })
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-    err = client.DeleteFile(ctx, "deadbeef")
-    checks.NoError(t, err, "DeleteFile error")
-}
-
-func TestListFile(t *testing.T) {
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/files", func(w http.ResponseWriter, r *http.Request) {
-        fmt.Fprint(w, "{}")
-    })
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-    _, err = client.ListFiles(ctx)
-    checks.NoError(t, err, "ListFiles error")
-}
-
-func TestGetFile(t *testing.T) {
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) {
-        fmt.Fprint(w, "{}")
-    })
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-    _, err = client.GetFile(ctx, "deadbeef")
-    checks.NoError(t, err, "GetFile error")
-}
-
-func TestGetFileContent(t *testing.T) {
-    wantRespJsonl := `{"prompt": "foo", "completion": "foo"}
-{"prompt": "bar", "completion": "bar"}
-{"prompt": "baz", "completion": "baz"}
-`
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) {
-        // edits only accepts GET requests
-        if r.Method != http.MethodGet {
-            http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
-        }
-        fmt.Fprint(w, wantRespJsonl)
-    })
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-    content, err := client.GetFileContent(ctx, "deadbeef")
-    checks.NoError(t, err, "GetFileContent error")
-    defer content.Close()
-
-    actual, _ := io.ReadAll(content)
-    if string(actual) != wantRespJsonl {
-        t.Errorf("Expected %s, got %s", wantRespJsonl, string(actual))
-    }
-}
-
-func TestGetFileContentReturnError(t *testing.T) {
-    wantMessage := "To help mitigate abuse, downloading of fine-tune training files is disabled for free accounts."
-    wantType := "invalid_request_error"
-    wantErrorResp := `{
-    "error": {
-        "message": "` + wantMessage + `",
-        "type": "` + wantType + `",
-        "param": null,
-        "code": null
-    }
-}`
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) {
-        w.WriteHeader(http.StatusBadRequest)
-        fmt.Fprint(w, wantErrorResp)
-    })
-    // create the test server
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-    _, err := client.GetFileContent(ctx, "deadbeef")
-    if err == nil {
-        t.Fatal("Did not return error")
-    }
-
-    apiErr := &APIError{}
-    if !errors.As(err, &apiErr) {
-        t.Fatalf("Did not return APIError: %+v\n", apiErr)
-    }
-    if apiErr.Message != wantMessage {
-        t.Fatalf("Expected %s Message, got = %s\n", wantMessage, apiErr.Message)
-        return
-    }
-    if apiErr.Type != wantType {
-        t.Fatalf("Expected %s Type, got = %s\n", wantType, apiErr.Type)
-        return
-    }
-}
-
-func TestGetFileContentReturnTimeoutError(t *testing.T) {
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) {
-        time.Sleep(10 * time.Nanosecond)
-    })
-    // create the test server
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-    ctx, cancel := context.WithTimeout(ctx, time.Nanosecond)
-    defer cancel()
-
-    _, err := client.GetFileContent(ctx, "deadbeef")
-    if err == nil {
-        t.Fatal("Did not return error")
-    }
-    if !os.IsTimeout(err) {
-        t.Fatal("Did not return timeout error")
-    }
-}


@@ -2,7 +2,6 @@ package openai_test
 import (
     . "github.com/sashabaranov/go-openai"
-    "github.com/sashabaranov/go-openai/internal/test"
     "github.com/sashabaranov/go-openai/internal/test/checks"

     "context"
@@ -16,7 +15,8 @@ const testFineTuneID = "fine-tune-id"
 // TestFineTunes Tests the fine tunes endpoint of the API using the mocked server.
 func TestFineTunes(t *testing.T) {
-    server := test.NewTestServer()
+    client, server, teardown := setupOpenAITestServer()
+    defer teardown()
     server.RegisterHandler(
         "/v1/fine-tunes",
         func(w http.ResponseWriter, r *http.Request) {
@@ -59,18 +59,9 @@ func TestFineTunes(t *testing.T) {
         },
     )

-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-
     ctx := context.Background()
-    _, err = client.ListFineTunes(ctx)
+    _, err := client.ListFineTunes(ctx)
     checks.NoError(t, err, "ListFineTunes error")

     _, err = client.CreateFineTune(ctx, FineTuneRequest{})

image_api_test.go (new file, +223 lines)

@@ -0,0 +1,223 @@
package openai_test

import (
    . "github.com/sashabaranov/go-openai"
    "github.com/sashabaranov/go-openai/internal/test/checks"

    "context"
    "encoding/json"
    "fmt"
    "io"
    "net/http"
    "os"
    "testing"
    "time"
)

func TestImages(t *testing.T) {
    client, server, teardown := setupOpenAITestServer()
    defer teardown()
    server.RegisterHandler("/v1/images/generations", handleImageEndpoint)
    _, err := client.CreateImage(context.Background(), ImageRequest{
        Prompt: "Lorem ipsum",
    })
    checks.NoError(t, err, "CreateImage error")
}

// handleImageEndpoint Handles the images endpoint by the test server.
func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
    var err error
    var resBytes []byte

    // imagess only accepts POST requests
    if r.Method != "POST" {
        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
    }

    var imageReq ImageRequest
    if imageReq, err = getImageBody(r); err != nil {
        http.Error(w, "could not read request", http.StatusInternalServerError)
        return
    }
    res := ImageResponse{
        Created: time.Now().Unix(),
    }
    for i := 0; i < imageReq.N; i++ {
        imageData := ImageResponseDataInner{}
        switch imageReq.ResponseFormat {
        case CreateImageResponseFormatURL, "":
            imageData.URL = "https://example.com/image.png"
        case CreateImageResponseFormatB64JSON:
            // This decodes to "{}" in base64.
            imageData.B64JSON = "e30K"
        default:
            http.Error(w, "invalid response format", http.StatusBadRequest)
            return
        }
        res.Data = append(res.Data, imageData)
    }
    resBytes, _ = json.Marshal(res)
    fmt.Fprintln(w, string(resBytes))
}

// getImageBody Returns the body of the request to create a image.
func getImageBody(r *http.Request) (ImageRequest, error) {
    image := ImageRequest{}
    // read the request body
    reqBody, err := io.ReadAll(r.Body)
    if err != nil {
        return ImageRequest{}, err
    }
    err = json.Unmarshal(reqBody, &image)
    if err != nil {
        return ImageRequest{}, err
    }
    return image, nil
}

func TestImageEdit(t *testing.T) {
    client, server, teardown := setupOpenAITestServer()
    defer teardown()
    server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)

    origin, err := os.Create("image.png")
    if err != nil {
        t.Error("open origin file error")
        return
    }
    mask, err := os.Create("mask.png")
    if err != nil {
        t.Error("open mask file error")
        return
    }
    defer func() {
        mask.Close()
        origin.Close()
        os.Remove("mask.png")
        os.Remove("image.png")
    }()

    _, err = client.CreateEditImage(context.Background(), ImageEditRequest{
        Image:          origin,
        Mask:           mask,
        Prompt:         "There is a turtle in the pool",
        N:              3,
        Size:           CreateImageSize1024x1024,
        ResponseFormat: CreateImageResponseFormatURL,
    })
    checks.NoError(t, err, "CreateImage error")
}

func TestImageEditWithoutMask(t *testing.T) {
    client, server, teardown := setupOpenAITestServer()
    defer teardown()
    server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)

    origin, err := os.Create("image.png")
    if err != nil {
        t.Error("open origin file error")
        return
    }
    defer func() {
        origin.Close()
        os.Remove("image.png")
    }()

    _, err = client.CreateEditImage(context.Background(), ImageEditRequest{
        Image:          origin,
        Prompt:         "There is a turtle in the pool",
        N:              3,
        Size:           CreateImageSize1024x1024,
        ResponseFormat: CreateImageResponseFormatURL,
    })
    checks.NoError(t, err, "CreateImage error")
}

// handleEditImageEndpoint Handles the images endpoint by the test server.
func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) {
    var resBytes []byte

    // imagess only accepts POST requests
    if r.Method != "POST" {
        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
    }

    responses := ImageResponse{
        Created: time.Now().Unix(),
        Data: []ImageResponseDataInner{
            {
                URL:     "test-url1",
                B64JSON: "",
            },
            {
                URL:     "test-url2",
                B64JSON: "",
            },
            {
                URL:     "test-url3",
                B64JSON: "",
            },
        },
    }

    resBytes, _ = json.Marshal(responses)
    fmt.Fprintln(w, string(resBytes))
}

func TestImageVariation(t *testing.T) {
    client, server, teardown := setupOpenAITestServer()
    defer teardown()
    server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint)

    origin, err := os.Create("image.png")
    if err != nil {
        t.Error("open origin file error")
        return
    }
    defer func() {
        origin.Close()
        os.Remove("image.png")
    }()

    _, err = client.CreateVariImage(context.Background(), ImageVariRequest{
        Image:          origin,
        N:              3,
        Size:           CreateImageSize1024x1024,
        ResponseFormat: CreateImageResponseFormatURL,
    })
    checks.NoError(t, err, "CreateImage error")
}

// handleVariateImageEndpoint Handles the images endpoint by the test server.
func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) {
    var resBytes []byte

    // imagess only accepts POST requests
    if r.Method != "POST" {
        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
    }

    responses := ImageResponse{
        Created: time.Now().Unix(),
        Data: []ImageResponseDataInner{
            {
                URL:     "test-url1",
                B64JSON: "",
            },
            {
                URL:     "test-url2",
                B64JSON: "",
            },
            {
                URL:     "test-url3",
                B64JSON: "",
            },
        },
    }

    resBytes, _ = json.Marshal(responses)
    fmt.Fprintln(w, string(resBytes))
}


@@ -2,267 +2,15 @@ package openai //nolint:testpackage // testing private field
 import (
     utils "github.com/sashabaranov/go-openai/internal"
-    "github.com/sashabaranov/go-openai/internal/test"
     "github.com/sashabaranov/go-openai/internal/test/checks"

     "context"
-    "encoding/json"
     "fmt"
     "io"
-    "net/http"
     "os"
     "testing"
-    "time"
 )

-func TestImages(t *testing.T) {
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/images/generations", handleImageEndpoint)
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
-    req := ImageRequest{}
-    req.Prompt = "Lorem ipsum"
-    _, err = client.CreateImage(ctx, req)
-    checks.NoError(t, err, "CreateImage error")
-}
-
-// handleImageEndpoint Handles the images endpoint by the test server.
-func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
-    var err error
-    var resBytes []byte
-
-    // imagess only accepts POST requests
-    if r.Method != "POST" {
-        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
-    }
-
-    var imageReq ImageRequest
-    if imageReq, err = getImageBody(r); err != nil {
-        http.Error(w, "could not read request", http.StatusInternalServerError)
-        return
-    }
-    res := ImageResponse{
-        Created: time.Now().Unix(),
-    }
-    for i := 0; i < imageReq.N; i++ {
-        imageData := ImageResponseDataInner{}
-        switch imageReq.ResponseFormat {
-        case CreateImageResponseFormatURL, "":
-            imageData.URL = "https://example.com/image.png"
-        case CreateImageResponseFormatB64JSON:
-            // This decodes to "{}" in base64.
-            imageData.B64JSON = "e30K"
-        default:
-            http.Error(w, "invalid response format", http.StatusBadRequest)
-            return
-        }
-        res.Data = append(res.Data, imageData)
-    }
-    resBytes, _ = json.Marshal(res)
-    fmt.Fprintln(w, string(resBytes))
-}
-
-// getImageBody Returns the body of the request to create a image.
-func getImageBody(r *http.Request) (ImageRequest, error) {
-    image := ImageRequest{}
-    // read the request body
-    reqBody, err := io.ReadAll(r.Body)
-    if err != nil {
-        return ImageRequest{}, err
-    }
-    err = json.Unmarshal(reqBody, &image)
-    if err != nil {
-        return ImageRequest{}, err
-    }
-    return image, nil
-}
-
-func TestImageEdit(t *testing.T) {
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
-    origin, err := os.Create("image.png")
-    if err != nil {
-        t.Error("open origin file error")
-        return
-    }
-    mask, err := os.Create("mask.png")
-    if err != nil {
-        t.Error("open mask file error")
-        return
-    }
-    defer func() {
-        mask.Close()
-        origin.Close()
-        os.Remove("mask.png")
-        os.Remove("image.png")
-    }()
-
-    req := ImageEditRequest{
-        Image:          origin,
-        Mask:           mask,
-        Prompt:         "There is a turtle in the pool",
-        N:              3,
-        Size:           CreateImageSize1024x1024,
-        ResponseFormat: CreateImageResponseFormatURL,
-    }
-    _, err = client.CreateEditImage(ctx, req)
-    checks.NoError(t, err, "CreateImage error")
-}
-
-func TestImageEditWithoutMask(t *testing.T) {
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
-    origin, err := os.Create("image.png")
-    if err != nil {
-        t.Error("open origin file error")
-        return
-    }
-    defer func() {
-        origin.Close()
-        os.Remove("image.png")
-    }()
-
-    req := ImageEditRequest{
-        Image:          origin,
-        Prompt:         "There is a turtle in the pool",
-        N:              3,
-        Size:           CreateImageSize1024x1024,
-        ResponseFormat: CreateImageResponseFormatURL,
-    }
-    _, err = client.CreateEditImage(ctx, req)
-    checks.NoError(t, err, "CreateImage error")
-}
-
-// handleEditImageEndpoint Handles the images endpoint by the test server.
-func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) {
-    var resBytes []byte
-
-    // imagess only accepts POST requests
-    if r.Method != "POST" {
-        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
-    }
-
-    responses := ImageResponse{
-        Created: time.Now().Unix(),
-        Data: []ImageResponseDataInner{
-            {
-                URL:     "test-url1",
-                B64JSON: "",
-            },
-            {
-                URL:     "test-url2",
-                B64JSON: "",
-            },
-            {
-                URL:     "test-url3",
-                B64JSON: "",
-            },
-        },
-    }
-
-    resBytes, _ = json.Marshal(responses)
-    fmt.Fprintln(w, string(resBytes))
-}
-
-func TestImageVariation(t *testing.T) {
-    server := test.NewTestServer()
-    server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint)
-    // create the test server
-    var err error
-    ts := server.OpenAITestServer()
-    ts.Start()
-    defer ts.Close()
-
-    config := DefaultConfig(test.GetTestToken())
-    config.BaseURL = ts.URL + "/v1"
-    client := NewClientWithConfig(config)
-    ctx := context.Background()
-
-    origin, err := os.Create("image.png")
-    if err != nil {
-        t.Error("open origin file error")
-        return
-    }
-    defer func() {
-        origin.Close()
-        os.Remove("image.png")
-    }()
-
-    req := ImageVariRequest{
-        Image:          origin,
-        N:              3,
-        Size:           CreateImageSize1024x1024,
-        ResponseFormat: CreateImageResponseFormatURL,
-    }
-    _, err = client.CreateVariImage(ctx, req)
-    checks.NoError(t, err, "CreateImage error")
-}
-
-// handleVariateImageEndpoint Handles the images endpoint by the test server.
-func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) {
-    var resBytes []byte
-
-    // imagess only accepts POST requests
-    if r.Method != "POST" {
-        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
-    }
-
-    responses := ImageResponse{
-        Created: time.Now().Unix(),
-        Data: []ImageResponseDataInner{
-            {
-                URL:     "test-url1",
-                B64JSON: "",
-            },
-            {
-                URL:     "test-url2",
-                B64JSON: "",
-            },
-            {
-                URL:     "test-url3",
-                B64JSON: "",
-            },
-        },
-    }
-
-    resBytes, _ = json.Marshal(responses)
-    fmt.Fprintln(w, string(resBytes))
-}
-
 type mockFormBuilder struct {
     mockCreateFormFile       func(string, *os.File) error
     mockCreateFormFileReader func(string, io.Reader, string) error


@@ -2,7 +2,6 @@ package openai_test

 import (
 	. "github.com/sashabaranov/go-openai"
-	"github.com/sashabaranov/go-openai/internal/test"
 	"github.com/sashabaranov/go-openai/internal/test/checks"

 	"context"
@@ -12,85 +11,47 @@ import (
 	"testing"
 )

-// TestListModels Tests the models endpoint of the API using the mocked server.
+// TestListModels Tests the list models endpoint of the API using the mocked server.
 func TestListModels(t *testing.T) {
-	server := test.NewTestServer()
-	server.RegisterHandler("/v1/models", handleModelsEndpoint)
-	// create the test server
-	var err error
-	ts := server.OpenAITestServer()
-	ts.Start()
-	defer ts.Close()
-
-	config := DefaultConfig(test.GetTestToken())
-	config.BaseURL = ts.URL + "/v1"
-	client := NewClientWithConfig(config)
-	ctx := context.Background()
-
-	_, err = client.ListModels(ctx)
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/models", handleListModelsEndpoint)
+	_, err := client.ListModels(context.Background())
 	checks.NoError(t, err, "ListModels error")
 }

 func TestAzureListModels(t *testing.T) {
-	server := test.NewTestServer()
-	server.RegisterHandler("/openai/models", handleModelsEndpoint)
-	// create the test server
-	var err error
-	ts := server.OpenAITestServer()
-	ts.Start()
-	defer ts.Close()
-
-	config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/")
-	config.BaseURL = ts.URL
-	client := NewClientWithConfig(config)
-	ctx := context.Background()
-
-	_, err = client.ListModels(ctx)
+	client, server, teardown := setupAzureTestServer()
+	defer teardown()
+	server.RegisterHandler("/openai/models", handleListModelsEndpoint)
+	_, err := client.ListModels(context.Background())
 	checks.NoError(t, err, "ListModels error")
 }

-// handleModelsEndpoint Handles the models endpoint by the test server.
-func handleModelsEndpoint(w http.ResponseWriter, _ *http.Request) {
+// handleListModelsEndpoint Handles the list models endpoint by the test server.
+func handleListModelsEndpoint(w http.ResponseWriter, _ *http.Request) {
 	resBytes, _ := json.Marshal(ModelsList{})
 	fmt.Fprintln(w, string(resBytes))
 }

 // TestGetModel Tests the retrieve model endpoint of the API using the mocked server.
 func TestGetModel(t *testing.T) {
-	server := test.NewTestServer()
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
 	server.RegisterHandler("/v1/models/text-davinci-003", handleGetModelEndpoint)
-	// create the test server
-	ts := server.OpenAITestServer()
-	ts.Start()
-	defer ts.Close()
-
-	config := DefaultConfig(test.GetTestToken())
-	config.BaseURL = ts.URL + "/v1"
-	client := NewClientWithConfig(config)
-	ctx := context.Background()
-
-	_, err := client.GetModel(ctx, "text-davinci-003")
+	_, err := client.GetModel(context.Background(), "text-davinci-003")
 	checks.NoError(t, err, "GetModel error")
 }

 func TestAzureGetModel(t *testing.T) {
-	server := test.NewTestServer()
-	server.RegisterHandler("/openai/models/text-davinci-003", handleModelsEndpoint)
-	// create the test server
-	ts := server.OpenAITestServer()
-	ts.Start()
-	defer ts.Close()
-
-	config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/")
-	config.BaseURL = ts.URL
-	client := NewClientWithConfig(config)
-	ctx := context.Background()
-
-	_, err := client.GetModel(ctx, "text-davinci-003")
+	client, server, teardown := setupAzureTestServer()
+	defer teardown()
+	server.RegisterHandler("/openai/models/text-davinci-003", handleGetModelEndpoint)
+	_, err := client.GetModel(context.Background(), "text-davinci-003")
 	checks.NoError(t, err, "GetModel error")
 }

-// handleModelsEndpoint Handles the models endpoint by the test server.
+// handleGetModelEndpoint Handles the get model endpoint by the test server.
 func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) {
 	resBytes, _ := json.Marshal(Model{})
 	fmt.Fprintln(w, string(resBytes))

View File

@@ -2,7 +2,6 @@ package openai_test

 import (
 	. "github.com/sashabaranov/go-openai"
-	"github.com/sashabaranov/go-openai/internal/test"
 	"github.com/sashabaranov/go-openai/internal/test/checks"

 	"context"
@@ -18,26 +17,13 @@ import (

 // TestModeration Tests the moderations endpoint of the API using the mocked server.
 func TestModerations(t *testing.T) {
-	server := test.NewTestServer()
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
 	server.RegisterHandler("/v1/moderations", handleModerationEndpoint)
-	// create the test server
-	var err error
-	ts := server.OpenAITestServer()
-	ts.Start()
-	defer ts.Close()
-
-	config := DefaultConfig(test.GetTestToken())
-	config.BaseURL = ts.URL + "/v1"
-	client := NewClientWithConfig(config)
-	ctx := context.Background()
-
-	// create an edit request
-	model := "text-moderation-stable"
-	moderationReq := ModerationRequest{
-		Model: model,
+	_, err := client.Moderations(context.Background(), ModerationRequest{
+		Model: ModerationTextStable,
 		Input: "I want to kill them.",
-	}
-	_, err = client.Moderations(ctx, moderationReq)
+	})
 	checks.NoError(t, err, "Moderation error")
 }
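
Note: handleModerationEndpoint returns a canned body without reading the request. When a test needs to verify what the client actually sent, the same hook can decode the payload first; an illustrative variant (not part of this commit), written as a handler registration inside such a test:

server.RegisterHandler("/v1/moderations", func(w http.ResponseWriter, r *http.Request) {
	var req ModerationRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest) // reject malformed payloads
		return
	}
	// assertions on req could go here
	resBytes, _ := json.Marshal(ModerationResponse{})
	fmt.Fprintln(w, string(resBytes))
})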

28
openai_test.go Normal file
View File

@@ -0,0 +1,28 @@
+package openai_test
+
+import (
+	. "github.com/sashabaranov/go-openai"
+	"github.com/sashabaranov/go-openai/internal/test"
+)
+
+func setupOpenAITestServer() (client *Client, server *test.ServerTest, teardown func()) {
+	server = test.NewTestServer()
+	ts := server.OpenAITestServer()
+	ts.Start()
+	teardown = ts.Close
+	config := DefaultConfig(test.GetTestToken())
+	config.BaseURL = ts.URL + "/v1"
+	client = NewClientWithConfig(config)
+	return
+}
+
+func setupAzureTestServer() (client *Client, server *test.ServerTest, teardown func()) {
+	server = test.NewTestServer()
+	ts := server.OpenAITestServer()
+	ts.Start()
+	teardown = ts.Close
+	config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/")
+	config.BaseURL = ts.URL
+	client = NewClientWithConfig(config)
+	return
+}
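
Note: taken together, these two helpers reduce every test in this commit to the same four steps: set up, register a handler, call the client, assert. A usage sketch mirroring TestListModels above (TestExample is a made-up name; imports match the test files in this commit):

func TestExample(t *testing.T) {
	client, server, teardown := setupOpenAITestServer()
	defer teardown() // closes the underlying httptest server
	server.RegisterHandler("/v1/models", func(w http.ResponseWriter, r *http.Request) {
		resBytes, _ := json.Marshal(ModelsList{})
		fmt.Fprintln(w, string(resBytes))
	})
	_, err := client.ListModels(context.Background())
	checks.NoError(t, err, "ListModels error")
}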

View File

@@ -7,6 +7,8 @@ import (
 	"testing"

 	utils "github.com/sashabaranov/go-openai/internal"
+	"github.com/sashabaranov/go-openai/internal/test"
+	"github.com/sashabaranov/go-openai/internal/test/checks"
 )

 var errTestUnmarshalerFailed = errors.New("test unmarshaler failed")
@@ -47,7 +49,17 @@ func TestStreamReaderReturnsErrTooManyEmptyStreamMessages(t *testing.T) {
 		unmarshaler:    &utils.JSONUnmarshaler{},
 	}
 	_, err := stream.Recv()
-	if !errors.Is(err, ErrTooManyEmptyStreamMessages) {
-		t.Fatalf("Did not return error when recv failed: %v", err)
-	}
+	checks.ErrorIs(t, err, ErrTooManyEmptyStreamMessages, "Did not return error when recv failed", err.Error())
+}
+
+func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) {
+	stream := &streamReader[ChatCompletionStreamResponse]{
+		reader: bufio.NewReader(bytes.NewReader([]byte("\n"))),
+		errAccumulator: &utils.DefaultErrorAccumulator{
+			Buffer: &test.FailingErrorBuffer{},
+		},
+		unmarshaler: &utils.JSONUnmarshaler{},
+	}
+	_, err := stream.Recv()
+	checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
 }

View File

@@ -6,11 +6,9 @@ import (
 	"errors"
 	"io"
 	"net/http"
-	"net/http/httptest"
 	"testing"

 	. "github.com/sashabaranov/go-openai"
-	"github.com/sashabaranov/go-openai/internal/test"
 	"github.com/sashabaranov/go-openai/internal/test/checks"
 )
@@ -32,7 +30,9 @@ func TestCompletionsStreamWrongModel(t *testing.T) {
 }

 func TestCreateCompletionStream(t *testing.T) {
-	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
 		w.Header().Set("Content-Type", "text/event-stream")

 		// Send test responses
@@ -52,28 +52,14 @@ func TestCreateCompletionStream(t *testing.T) {
 		_, err := w.Write(dataBytes)
 		checks.NoError(t, err, "Write error")
-	}))
-	defer server.Close()
-
-	// Client portion of the test
-	config := DefaultConfig(test.GetTestToken())
-	config.BaseURL = server.URL + "/v1"
-	config.HTTPClient.Transport = &test.TokenRoundTripper{
-		Token:    test.GetTestToken(),
-		Fallback: http.DefaultTransport,
-	}
-	client := NewClientWithConfig(config)
-	ctx := context.Background()
-
-	request := CompletionRequest{
+	})
+
+	stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
 		Prompt:    "Ex falso quodlibet",
 		Model:     "text-davinci-002",
 		MaxTokens: 10,
 		Stream:    true,
-	}
-
-	stream, err := client.CreateCompletionStream(ctx, request)
+	})
 	checks.NoError(t, err, "CreateCompletionStream returned error")
 	defer stream.Close()
@@ -116,7 +102,9 @@
 }

 func TestCreateCompletionStreamError(t *testing.T) {
-	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
 		w.Header().Set("Content-Type", "text/event-stream")

 		// Send test responses
@@ -137,28 +125,14 @@ func TestCreateCompletionStreamError(t *testing.T) {
 		_, err := w.Write(dataBytes)
 		checks.NoError(t, err, "Write error")
-	}))
-	defer server.Close()
-
-	// Client portion of the test
-	config := DefaultConfig(test.GetTestToken())
-	config.BaseURL = server.URL + "/v1"
-	config.HTTPClient.Transport = &test.TokenRoundTripper{
-		Token:    test.GetTestToken(),
-		Fallback: http.DefaultTransport,
-	}
-	client := NewClientWithConfig(config)
-	ctx := context.Background()
-
-	request := CompletionRequest{
+	})
+
+	stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
 		MaxTokens: 5,
 		Model:     GPT3TextDavinci003,
 		Prompt:    "Hello!",
 		Stream:    true,
-	}
-
-	stream, err := client.CreateCompletionStream(ctx, request)
+	})
 	checks.NoError(t, err, "CreateCompletionStream returned error")
 	defer stream.Close()
@@ -173,7 +147,8 @@
 }

 func TestCreateCompletionStreamRateLimitError(t *testing.T) {
-	server := test.NewTestServer()
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
 	server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
 		w.Header().Set("Content-Type", "application/json")
 		w.WriteHeader(429)
@@ -188,30 +163,14 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) {
 		_, err := w.Write(dataBytes)
 		checks.NoError(t, err, "Write error")
 	})
-	ts := server.OpenAITestServer()
-	ts.Start()
-	defer ts.Close()
-
-	// Client portion of the test
-	config := DefaultConfig(test.GetTestToken())
-	config.BaseURL = ts.URL + "/v1"
-	config.HTTPClient.Transport = &test.TokenRoundTripper{
-		Token:    test.GetTestToken(),
-		Fallback: http.DefaultTransport,
-	}
-	client := NewClientWithConfig(config)
-	ctx := context.Background()
-
-	request := CompletionRequest{
+
+	var apiErr *APIError
+	_, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
 		MaxTokens: 5,
 		Model:     GPT3Ada,
 		Prompt:    "Hello!",
 		Stream:    true,
-	}
-
-	var apiErr *APIError
-	_, err := client.CreateCompletionStream(ctx, request)
+	})
 	if !errors.As(err, &apiErr) {
 		t.Errorf("TestCreateCompletionStreamRateLimitError did not return APIError")
 	}
@@ -219,7 +178,9 @@
 }

 func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) {
-	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
 		w.Header().Set("Content-Type", "text/event-stream")

 		// Send test responses
@@ -244,28 +205,14 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) {
 		_, err := w.Write(dataBytes)
 		checks.NoError(t, err, "Write error")
-	}))
-	defer server.Close()
-
-	// Client portion of the test
-	config := DefaultConfig(test.GetTestToken())
-	config.BaseURL = server.URL + "/v1"
-	config.HTTPClient.Transport = &test.TokenRoundTripper{
-		Token:    test.GetTestToken(),
-		Fallback: http.DefaultTransport,
-	}
-	client := NewClientWithConfig(config)
-	ctx := context.Background()
-
-	request := CompletionRequest{
+	})
+
+	stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
 		Prompt:    "Ex falso quodlibet",
 		Model:     "text-davinci-002",
 		MaxTokens: 10,
 		Stream:    true,
-	}
-
-	stream, err := client.CreateCompletionStream(ctx, request)
+	})
 	checks.NoError(t, err, "CreateCompletionStream returned error")
 	defer stream.Close()
@@ -277,7 +224,9 @@
 }

 func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) {
-	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
 		w.Header().Set("Content-Type", "text/event-stream")

 		// Send test responses
@@ -291,28 +240,14 @@ func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) {
 		_, err := w.Write(dataBytes)
 		checks.NoError(t, err, "Write error")
-	}))
-	defer server.Close()
-
-	// Client portion of the test
-	config := DefaultConfig(test.GetTestToken())
-	config.BaseURL = server.URL + "/v1"
-	config.HTTPClient.Transport = &test.TokenRoundTripper{
-		Token:    test.GetTestToken(),
-		Fallback: http.DefaultTransport,
-	}
-	client := NewClientWithConfig(config)
-	ctx := context.Background()
-
-	request := CompletionRequest{
+	})
+
+	stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
 		Prompt:    "Ex falso quodlibet",
 		Model:     "text-davinci-002",
 		MaxTokens: 10,
 		Stream:    true,
-	}
-
-	stream, err := client.CreateCompletionStream(ctx, request)
+	})
 	checks.NoError(t, err, "CreateCompletionStream returned error")
 	defer stream.Close()
@@ -324,7 +259,9 @@
 }

 func TestCreateCompletionStreamBrokenJSONError(t *testing.T) {
-	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
 		w.Header().Set("Content-Type", "text/event-stream")

 		// Send test responses
@@ -344,28 +281,14 @@ func TestCreateCompletionStreamBrokenJSONError(t *testing.T) {
 		_, err := w.Write(dataBytes)
 		checks.NoError(t, err, "Write error")
-	}))
-	defer server.Close()
-
-	// Client portion of the test
-	config := DefaultConfig(test.GetTestToken())
-	config.BaseURL = server.URL + "/v1"
-	config.HTTPClient.Transport = &test.TokenRoundTripper{
-		Token:    test.GetTestToken(),
-		Fallback: http.DefaultTransport,
-	}
-	client := NewClientWithConfig(config)
-	ctx := context.Background()
-
-	request := CompletionRequest{
+	})
+
+	stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{
 		Prompt:    "Ex falso quodlibet",
 		Model:     "text-davinci-002",
 		MaxTokens: 10,
 		Stream:    true,
-	}
-
-	stream, err := client.CreateCompletionStream(ctx, request)
+	})
 	checks.NoError(t, err, "CreateCompletionStream returned error")
 	defer stream.Close()
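
Note: all of the streaming tests above share one mocking idiom: the registered handler hand-writes server-sent events, one "data:" line per chunk followed by a blank line, and the client's stream reader consumes them. Condensed to its core as an illustrative handler registration (not a line from this commit; the payload is a stand-in):

server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "text/event-stream")
	// Each SSE event is "data: <payload>" terminated by a blank line;
	// "[DONE]" tells the stream reader to stop.
	_, _ = w.Write([]byte(`data: {"id":"1","object":"completion"}` + "\n\n"))
	_, _ = w.Write([]byte("data: [DONE]\n\n"))
})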