Refactor/internal testing (#194)
* added NoError check * corrected NoError * has error checks * replace more checks * Used checks test helper * Used checks test helper * remove duplicate import * fixed lint issues regarding length of messages --------- Co-authored-by: Rex Posadas <rposadas@redwoodlogistics.com>
This commit is contained in:
42
api_test.go
42
api_test.go
@@ -2,6 +2,7 @@ package openai_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
@@ -20,25 +21,17 @@ func TestAPI(t *testing.T) {
|
|||||||
c := NewClient(apiToken)
|
c := NewClient(apiToken)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
_, err = c.ListEngines(ctx)
|
_, err = c.ListEngines(ctx)
|
||||||
if err != nil {
|
checks.NoError(t, err, "ListEngines error")
|
||||||
t.Fatalf("ListEngines error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = c.GetEngine(ctx, "davinci")
|
_, err = c.GetEngine(ctx, "davinci")
|
||||||
if err != nil {
|
checks.NoError(t, err, "GetEngine error")
|
||||||
t.Fatalf("GetEngine error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fileRes, err := c.ListFiles(ctx)
|
fileRes, err := c.ListFiles(ctx)
|
||||||
if err != nil {
|
checks.NoError(t, err, "ListFiles error")
|
||||||
t.Fatalf("ListFiles error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(fileRes.Files) > 0 {
|
if len(fileRes.Files) > 0 {
|
||||||
_, err = c.GetFile(ctx, fileRes.Files[0].ID)
|
_, err = c.GetFile(ctx, fileRes.Files[0].ID)
|
||||||
if err != nil {
|
checks.NoError(t, err, "GetFile error")
|
||||||
t.Fatalf("GetFile error: %v", err)
|
|
||||||
}
|
|
||||||
} // else skip
|
} // else skip
|
||||||
|
|
||||||
embeddingReq := EmbeddingRequest{
|
embeddingReq := EmbeddingRequest{
|
||||||
@@ -49,9 +42,7 @@ func TestAPI(t *testing.T) {
|
|||||||
Model: AdaSearchQuery,
|
Model: AdaSearchQuery,
|
||||||
}
|
}
|
||||||
_, err = c.CreateEmbeddings(ctx, embeddingReq)
|
_, err = c.CreateEmbeddings(ctx, embeddingReq)
|
||||||
if err != nil {
|
checks.NoError(t, err, "Embedding error")
|
||||||
t.Fatalf("Embedding error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = c.CreateChatCompletion(
|
_, err = c.CreateChatCompletion(
|
||||||
ctx,
|
ctx,
|
||||||
@@ -66,9 +57,7 @@ func TestAPI(t *testing.T) {
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
checks.NoError(t, err, "CreateChatCompletion (without name) returned error")
|
||||||
t.Errorf("CreateChatCompletion (without name) returned error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = c.CreateChatCompletion(
|
_, err = c.CreateChatCompletion(
|
||||||
ctx,
|
ctx,
|
||||||
@@ -83,10 +72,7 @@ func TestAPI(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
checks.NoError(t, err, "CreateChatCompletion (with name) returned error")
|
||||||
if err != nil {
|
|
||||||
t.Errorf("CreateChatCompletion (with name) returned error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
stream, err := c.CreateCompletionStream(ctx, CompletionRequest{
|
stream, err := c.CreateCompletionStream(ctx, CompletionRequest{
|
||||||
Prompt: "Ex falso quodlibet",
|
Prompt: "Ex falso quodlibet",
|
||||||
@@ -94,9 +80,7 @@ func TestAPI(t *testing.T) {
|
|||||||
MaxTokens: 5,
|
MaxTokens: 5,
|
||||||
Stream: true,
|
Stream: true,
|
||||||
})
|
})
|
||||||
if err != nil {
|
checks.NoError(t, err, "CreateCompletionStream returned error")
|
||||||
t.Errorf("CreateCompletionStream returned error: %v", err)
|
|
||||||
}
|
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
|
|
||||||
counter := 0
|
counter := 0
|
||||||
@@ -126,9 +110,7 @@ func TestAPIError(t *testing.T) {
|
|||||||
c := NewClient(apiToken + "_invalid")
|
c := NewClient(apiToken + "_invalid")
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
_, err = c.ListEngines(ctx)
|
_, err = c.ListEngines(ctx)
|
||||||
if err == nil {
|
checks.NoError(t, err, "ListEngines did not fail")
|
||||||
t.Fatal("ListEngines did not fail")
|
|
||||||
}
|
|
||||||
|
|
||||||
var apiErr *APIError
|
var apiErr *APIError
|
||||||
if !errors.As(err, &apiErr) {
|
if !errors.As(err, &apiErr) {
|
||||||
@@ -154,9 +136,7 @@ func TestRequestError(t *testing.T) {
|
|||||||
c := NewClientWithConfig(config)
|
c := NewClientWithConfig(config)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
_, err = c.ListEngines(ctx)
|
_, err = c.ListEngines(ctx)
|
||||||
if err == nil {
|
checks.HasError(t, err, "ListEngines did not fail")
|
||||||
t.Fatal("ListEngines request did not fail")
|
|
||||||
}
|
|
||||||
|
|
||||||
var reqErr *RequestError
|
var reqErr *RequestError
|
||||||
if !errors.As(err, &reqErr) {
|
if !errors.As(err, &reqErr) {
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -62,9 +63,7 @@ func TestAudio(t *testing.T) {
|
|||||||
Model: "whisper-3",
|
Model: "whisper-3",
|
||||||
}
|
}
|
||||||
_, err = tc.createFn(ctx, req)
|
_, err = tc.createFn(ctx, req)
|
||||||
if err != nil {
|
checks.NoError(t, err, "audio API error")
|
||||||
t.Fatalf("audio API error: %v", err)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -115,9 +114,7 @@ func TestAudioWithOptionalArgs(t *testing.T) {
|
|||||||
Language: "zh",
|
Language: "zh",
|
||||||
}
|
}
|
||||||
_, err = tc.createFn(ctx, req)
|
_, err = tc.createFn(ctx, req)
|
||||||
if err != nil {
|
checks.NoError(t, err, "audio API error")
|
||||||
t.Fatalf("audio API error: %v", err)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -125,9 +122,8 @@ func TestAudioWithOptionalArgs(t *testing.T) {
|
|||||||
// createTestFile creates a fake file with "hello" as the content.
|
// createTestFile creates a fake file with "hello" as the content.
|
||||||
func createTestFile(t *testing.T, path string) {
|
func createTestFile(t *testing.T, path string) {
|
||||||
file, err := os.Create(path)
|
file, err := os.Create(path)
|
||||||
if err != nil {
|
checks.NoError(t, err, "failed to create file")
|
||||||
t.Fatalf("failed to create file %v", err)
|
|
||||||
}
|
|
||||||
if _, err = file.WriteString("hello"); err != nil {
|
if _, err = file.WriteString("hello"); err != nil {
|
||||||
t.Fatalf("failed to write to file %v", err)
|
t.Fatalf("failed to write to file %v", err)
|
||||||
}
|
}
|
||||||
@@ -139,9 +135,7 @@ func createTestDirectory(t *testing.T) (path string, cleanup func()) {
|
|||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
path, err := os.MkdirTemp(os.TempDir(), "")
|
path, err := os.MkdirTemp(os.TempDir(), "")
|
||||||
if err != nil {
|
checks.NoError(t, err)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return path, func() { os.RemoveAll(path) }
|
return path, func() { os.RemoveAll(path) }
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package openai_test
|
|||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -55,9 +56,7 @@ func TestCreateChatCompletionStream(t *testing.T) {
|
|||||||
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
|
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
|
||||||
|
|
||||||
_, err := w.Write(dataBytes)
|
_, err := w.Write(dataBytes)
|
||||||
if err != nil {
|
checks.NoError(t, err, "Write error")
|
||||||
t.Errorf("Write error: %s", err)
|
|
||||||
}
|
|
||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
@@ -85,9 +84,7 @@ func TestCreateChatCompletionStream(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
stream, err := client.CreateChatCompletionStream(ctx, request)
|
stream, err := client.CreateChatCompletionStream(ctx, request)
|
||||||
if err != nil {
|
checks.NoError(t, err, "CreateCompletionStream returned error")
|
||||||
t.Errorf("CreateCompletionStream returned error: %v", err)
|
|
||||||
}
|
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
|
|
||||||
expectedResponses := []ChatCompletionStreamResponse{
|
expectedResponses := []ChatCompletionStreamResponse{
|
||||||
@@ -126,9 +123,7 @@ func TestCreateChatCompletionStream(t *testing.T) {
|
|||||||
t.Logf("%d: %s", ix, string(b))
|
t.Logf("%d: %s", ix, string(b))
|
||||||
|
|
||||||
receivedResponse, streamErr := stream.Recv()
|
receivedResponse, streamErr := stream.Recv()
|
||||||
if streamErr != nil {
|
checks.NoError(t, streamErr, "stream.Recv() failed")
|
||||||
t.Errorf("stream.Recv() failed: %v", streamErr)
|
|
||||||
}
|
|
||||||
if !compareChatResponses(expectedResponse, receivedResponse) {
|
if !compareChatResponses(expectedResponse, receivedResponse) {
|
||||||
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
|
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
|
||||||
}
|
}
|
||||||
@@ -140,6 +135,8 @@ func TestCreateChatCompletionStream(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, streamErr = stream.Recv()
|
_, streamErr = stream.Recv()
|
||||||
|
|
||||||
|
checks.ErrorIs(t, streamErr, io.EOF, "stream.Recv() did not return EOF when the stream is finished")
|
||||||
if !errors.Is(streamErr, io.EOF) {
|
if !errors.Is(streamErr, io.EOF) {
|
||||||
t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr)
|
t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr)
|
||||||
}
|
}
|
||||||
@@ -166,9 +163,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, err := w.Write(dataBytes)
|
_, err := w.Write(dataBytes)
|
||||||
if err != nil {
|
checks.NoError(t, err, "Write error")
|
||||||
t.Errorf("Write error: %s", err)
|
|
||||||
}
|
|
||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
@@ -196,15 +191,12 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
stream, err := client.CreateChatCompletionStream(ctx, request)
|
stream, err := client.CreateChatCompletionStream(ctx, request)
|
||||||
if err != nil {
|
checks.NoError(t, err, "CreateCompletionStream returned error")
|
||||||
t.Errorf("CreateCompletionStream returned error: %v", err)
|
|
||||||
}
|
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
|
|
||||||
_, streamErr := stream.Recv()
|
_, streamErr := stream.Recv()
|
||||||
if streamErr == nil {
|
checks.HasError(t, streamErr, "stream.Recv() did not return error")
|
||||||
t.Errorf("stream.Recv() did not return error")
|
|
||||||
}
|
|
||||||
var apiErr *APIError
|
var apiErr *APIError
|
||||||
if !errors.As(streamErr, &apiErr) {
|
if !errors.As(streamErr, &apiErr) {
|
||||||
t.Errorf("stream.Recv() did not return APIError")
|
t.Errorf("stream.Recv() did not return APIError")
|
||||||
|
|||||||
15
chat_test.go
15
chat_test.go
@@ -3,10 +3,10 @@ package openai_test
|
|||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -33,9 +33,8 @@ func TestChatCompletionsWrongModel(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
_, err := client.CreateChatCompletion(ctx, req)
|
_, err := client.CreateChatCompletion(ctx, req)
|
||||||
if !errors.Is(err, ErrChatCompletionInvalidModel) {
|
msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err)
|
||||||
t.Fatalf("CreateChatCompletion should return ErrChatCompletionInvalidModel, but returned: %v", err)
|
checks.ErrorIs(t, err, ErrChatCompletionInvalidModel, msg)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestChatCompletionsWithStream(t *testing.T) {
|
func TestChatCompletionsWithStream(t *testing.T) {
|
||||||
@@ -48,9 +47,7 @@ func TestChatCompletionsWithStream(t *testing.T) {
|
|||||||
Stream: true,
|
Stream: true,
|
||||||
}
|
}
|
||||||
_, err := client.CreateChatCompletion(ctx, req)
|
_, err := client.CreateChatCompletion(ctx, req)
|
||||||
if !errors.Is(err, ErrChatCompletionStreamNotSupported) {
|
checks.ErrorIs(t, err, ErrChatCompletionStreamNotSupported, "unexpected error")
|
||||||
t.Fatalf("CreateChatCompletion didn't return ErrChatCompletionStreamNotSupported error")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestCompletions Tests the completions endpoint of the API using the mocked server.
|
// TestCompletions Tests the completions endpoint of the API using the mocked server.
|
||||||
@@ -79,9 +76,7 @@ func TestChatCompletions(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
_, err = client.CreateChatCompletion(ctx, req)
|
_, err = client.CreateChatCompletion(ctx, req)
|
||||||
if err != nil {
|
checks.NoError(t, err, "CreateChatCompletion error")
|
||||||
t.Fatalf("CreateChatCompletion error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server.
|
// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server.
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package openai_test
|
|||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -66,9 +67,7 @@ func TestCompletions(t *testing.T) {
|
|||||||
}
|
}
|
||||||
req.Prompt = "Lorem ipsum"
|
req.Prompt = "Lorem ipsum"
|
||||||
_, err = client.CreateCompletion(ctx, req)
|
_, err = client.CreateCompletion(ctx, req)
|
||||||
if err != nil {
|
checks.NoError(t, err, "CreateCompletion error")
|
||||||
t.Fatalf("CreateCompletion error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleCompletionEndpoint Handles the completion endpoint by the test server.
|
// handleCompletionEndpoint Handles the completion endpoint by the test server.
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package openai_test
|
|||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -40,9 +41,7 @@ func TestEdits(t *testing.T) {
|
|||||||
N: 3,
|
N: 3,
|
||||||
}
|
}
|
||||||
response, err := client.Edits(ctx, editReq)
|
response, err := client.Edits(ctx, editReq)
|
||||||
if err != nil {
|
checks.NoError(t, err, "Edits error")
|
||||||
t.Fatalf("Edits error: %v", err)
|
|
||||||
}
|
|
||||||
if len(response.Choices) != editReq.N {
|
if len(response.Choices) != editReq.N {
|
||||||
t.Fatalf("edits does not properly return the correct number of choices")
|
t.Fatalf("edits does not properly return the correct number of choices")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package openai_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -38,9 +39,7 @@ func TestEmbedding(t *testing.T) {
|
|||||||
// marshal embeddingReq to JSON and confirm that the model field equals
|
// marshal embeddingReq to JSON and confirm that the model field equals
|
||||||
// the AdaSearchQuery type
|
// the AdaSearchQuery type
|
||||||
marshaled, err := json.Marshal(embeddingReq)
|
marshaled, err := json.Marshal(embeddingReq)
|
||||||
if err != nil {
|
checks.NoError(t, err, "Could not marshal embedding request")
|
||||||
t.Fatalf("Could not marshal embedding request: %v", err)
|
|
||||||
}
|
|
||||||
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
|
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
|
||||||
t.Fatalf("Expected embedding request to contain model field")
|
t.Fatalf("Expected embedding request to contain model field")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -81,16 +82,13 @@ func TestErrorAccumulatorWriteErrors(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
|
stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
|
||||||
if err != nil {
|
checks.NoError(t, err)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
stream.errAccumulator = &defaultErrorAccumulator{
|
stream.errAccumulator = &defaultErrorAccumulator{
|
||||||
buffer: &failingErrorBuffer{},
|
buffer: &failingErrorBuffer{},
|
||||||
unmarshaler: &jsonUnmarshaler{},
|
unmarshaler: &jsonUnmarshaler{},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = stream.Recv()
|
_, err = stream.Recv()
|
||||||
if !errors.Is(err, errTestErrorAccumulatorWriteFailed) {
|
checks.ErrorIs(t, err, errTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
|
||||||
t.Fatalf("Did not return error when write failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package openai_test
|
|||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -33,9 +34,7 @@ func TestFileUpload(t *testing.T) {
|
|||||||
Purpose: "fine-tune",
|
Purpose: "fine-tune",
|
||||||
}
|
}
|
||||||
_, err = client.CreateFile(ctx, req)
|
_, err = client.CreateFile(ctx, req)
|
||||||
if err != nil {
|
checks.NoError(t, err, "CreateFile erro")
|
||||||
t.Fatalf("CreateFile error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleCreateFile Handles the images endpoint by the test server.
|
// handleCreateFile Handles the images endpoint by the test server.
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package openai_test
|
|||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -70,32 +71,20 @@ func TestFineTunes(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
_, err = client.ListFineTunes(ctx)
|
_, err = client.ListFineTunes(ctx)
|
||||||
if err != nil {
|
checks.NoError(t, err, "ListFineTunes error")
|
||||||
t.Fatalf("ListFineTunes error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.CreateFineTune(ctx, FineTuneRequest{})
|
_, err = client.CreateFineTune(ctx, FineTuneRequest{})
|
||||||
if err != nil {
|
checks.NoError(t, err, "CreateFineTune error")
|
||||||
t.Fatalf("CreateFineTune error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.CancelFineTune(ctx, testFineTuneID)
|
_, err = client.CancelFineTune(ctx, testFineTuneID)
|
||||||
if err != nil {
|
checks.NoError(t, err, "CancelFineTune error")
|
||||||
t.Fatalf("CancelFineTune error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.GetFineTune(ctx, testFineTuneID)
|
_, err = client.GetFineTune(ctx, testFineTuneID)
|
||||||
if err != nil {
|
checks.NoError(t, err, "GetFineTune error")
|
||||||
t.Fatalf("GetFineTune error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.DeleteFineTune(ctx, testFineTuneID)
|
_, err = client.DeleteFineTune(ctx, testFineTuneID)
|
||||||
if err != nil {
|
checks.NoError(t, err, "DeleteFineTune error")
|
||||||
t.Fatalf("DeleteFineTune error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.ListFineTuneEvents(ctx, testFineTuneID)
|
_, err = client.ListFineTuneEvents(ctx, testFineTuneID)
|
||||||
if err != nil {
|
checks.NoError(t, err, "ListFineTuneEvents error")
|
||||||
t.Fatalf("ListFineTuneEvents error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package openai_test
|
|||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -31,9 +32,7 @@ func TestImages(t *testing.T) {
|
|||||||
req := ImageRequest{}
|
req := ImageRequest{}
|
||||||
req.Prompt = "Lorem ipsum"
|
req.Prompt = "Lorem ipsum"
|
||||||
_, err = client.CreateImage(ctx, req)
|
_, err = client.CreateImage(ctx, req)
|
||||||
if err != nil {
|
checks.NoError(t, err, "CreateImage error")
|
||||||
t.Fatalf("CreateImage error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleImageEndpoint Handles the images endpoint by the test server.
|
// handleImageEndpoint Handles the images endpoint by the test server.
|
||||||
@@ -127,9 +126,7 @@ func TestImageEdit(t *testing.T) {
|
|||||||
Size: CreateImageSize1024x1024,
|
Size: CreateImageSize1024x1024,
|
||||||
}
|
}
|
||||||
_, err = client.CreateEditImage(ctx, req)
|
_, err = client.CreateEditImage(ctx, req)
|
||||||
if err != nil {
|
checks.NoError(t, err, "CreateImage error")
|
||||||
t.Fatalf("CreateImage error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestImageEditWithoutMask(t *testing.T) {
|
func TestImageEditWithoutMask(t *testing.T) {
|
||||||
@@ -164,9 +161,7 @@ func TestImageEditWithoutMask(t *testing.T) {
|
|||||||
Size: CreateImageSize1024x1024,
|
Size: CreateImageSize1024x1024,
|
||||||
}
|
}
|
||||||
_, err = client.CreateEditImage(ctx, req)
|
_, err = client.CreateEditImage(ctx, req)
|
||||||
if err != nil {
|
checks.NoError(t, err, "CreateImage error")
|
||||||
t.Fatalf("CreateImage error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleEditImageEndpoint Handles the images endpoint by the test server.
|
// handleEditImageEndpoint Handles the images endpoint by the test server.
|
||||||
@@ -231,9 +226,7 @@ func TestImageVariation(t *testing.T) {
|
|||||||
Size: CreateImageSize1024x1024,
|
Size: CreateImageSize1024x1024,
|
||||||
}
|
}
|
||||||
_, err = client.CreateVariImage(ctx, req)
|
_, err = client.CreateVariImage(ctx, req)
|
||||||
if err != nil {
|
checks.NoError(t, err, "CreateImage error")
|
||||||
t.Fatalf("CreateImage error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleVariateImageEndpoint Handles the images endpoint by the test server.
|
// handleVariateImageEndpoint Handles the images endpoint by the test server.
|
||||||
|
|||||||
48
internal/test/checks/checks.go
Normal file
48
internal/test/checks/checks.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package checks
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NoError(t *testing.T, err error, message ...string) {
|
||||||
|
t.Helper()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func HasError(t *testing.T, err error, message ...string) {
|
||||||
|
t.Helper()
|
||||||
|
if err == nil {
|
||||||
|
t.Error(err, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ErrorIs(t *testing.T, err, target error, msg ...string) {
|
||||||
|
t.Helper()
|
||||||
|
if !errors.Is(err, target) {
|
||||||
|
t.Fatal(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ErrorIsF(t *testing.T, err, target error, format string, msg ...string) {
|
||||||
|
t.Helper()
|
||||||
|
if !errors.Is(err, target) {
|
||||||
|
t.Fatalf(format, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ErrorIsNot(t *testing.T, err, target error, msg ...string) {
|
||||||
|
t.Helper()
|
||||||
|
if errors.Is(err, target) {
|
||||||
|
t.Fatal(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ErrorIsNotf(t *testing.T, err, target error, format string, msg ...string) {
|
||||||
|
t.Helper()
|
||||||
|
if errors.Is(err, target) {
|
||||||
|
t.Fatalf(format, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ package openai_test
|
|||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -27,9 +28,7 @@ func TestListModels(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
_, err = client.ListModels(ctx)
|
_, err = client.ListModels(ctx)
|
||||||
if err != nil {
|
checks.NoError(t, err, "ListModels error")
|
||||||
t.Fatalf("ListModels error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleModelsEndpoint Handles the models endpoint by the test server.
|
// handleModelsEndpoint Handles the models endpoint by the test server.
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package openai_test
|
|||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -37,9 +38,7 @@ func TestModerations(t *testing.T) {
|
|||||||
Input: "I want to kill them.",
|
Input: "I want to kill them.",
|
||||||
}
|
}
|
||||||
_, err = client.Moderations(ctx, moderationReq)
|
_, err = client.Moderations(ctx, moderationReq)
|
||||||
if err != nil {
|
checks.NoError(t, err, "Moderation error")
|
||||||
t.Fatalf("Moderation error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleModerationEndpoint Handles the moderation endpoint by the test server.
|
// handleModerationEndpoint Handles the moderation endpoint by the test server.
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package openai_test
|
|||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
@@ -49,9 +50,7 @@ func TestCreateCompletionStream(t *testing.T) {
|
|||||||
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
|
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
|
||||||
|
|
||||||
_, err := w.Write(dataBytes)
|
_, err := w.Write(dataBytes)
|
||||||
if err != nil {
|
checks.NoError(t, err, "Write error")
|
||||||
t.Errorf("Write error: %s", err)
|
|
||||||
}
|
|
||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
@@ -74,9 +73,7 @@ func TestCreateCompletionStream(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
stream, err := client.CreateCompletionStream(ctx, request)
|
stream, err := client.CreateCompletionStream(ctx, request)
|
||||||
if err != nil {
|
checks.NoError(t, err, "CreateCompletionStream returned error")
|
||||||
t.Errorf("CreateCompletionStream returned error: %v", err)
|
|
||||||
}
|
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
|
|
||||||
expectedResponses := []CompletionResponse{
|
expectedResponses := []CompletionResponse{
|
||||||
@@ -138,9 +135,7 @@ func TestCreateCompletionStreamError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, err := w.Write(dataBytes)
|
_, err := w.Write(dataBytes)
|
||||||
if err != nil {
|
checks.NoError(t, err, "Write error")
|
||||||
t.Errorf("Write error: %s", err)
|
|
||||||
}
|
|
||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
@@ -163,15 +158,12 @@ func TestCreateCompletionStreamError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
stream, err := client.CreateCompletionStream(ctx, request)
|
stream, err := client.CreateCompletionStream(ctx, request)
|
||||||
if err != nil {
|
checks.NoError(t, err, "CreateCompletionStream returned error")
|
||||||
t.Errorf("CreateCompletionStream returned error: %v", err)
|
|
||||||
}
|
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
|
|
||||||
_, streamErr := stream.Recv()
|
_, streamErr := stream.Recv()
|
||||||
if streamErr == nil {
|
checks.HasError(t, streamErr, "stream.Recv() did not return error")
|
||||||
t.Errorf("stream.Recv() did not return error")
|
|
||||||
}
|
|
||||||
var apiErr *APIError
|
var apiErr *APIError
|
||||||
if !errors.As(streamErr, &apiErr) {
|
if !errors.As(streamErr, &apiErr) {
|
||||||
t.Errorf("stream.Recv() did not return APIError")
|
t.Errorf("stream.Recv() did not return APIError")
|
||||||
|
|||||||
Reference in New Issue
Block a user