Refactor/internal testing (#194)

* added NoError check

* corrected NoError

* has error checks

* replace more checks

* Used checks test helper

* Used checks test helper

* remove duplicate import

* fixed lint issues regarding length of messages

---------

Co-authored-by: Rex Posadas <rposadas@redwoodlogistics.com>
This commit is contained in:
rex posadas
2023-03-25 01:55:25 +08:00
committed by GitHub
parent 479dab3b69
commit 8e3a04664e
15 changed files with 115 additions and 140 deletions

View File

@@ -2,6 +2,7 @@ package openai_test
import ( import (
. "github.com/sashabaranov/go-openai" . "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"context" "context"
"errors" "errors"
@@ -20,25 +21,17 @@ func TestAPI(t *testing.T) {
c := NewClient(apiToken) c := NewClient(apiToken)
ctx := context.Background() ctx := context.Background()
_, err = c.ListEngines(ctx) _, err = c.ListEngines(ctx)
if err != nil { checks.NoError(t, err, "ListEngines error")
t.Fatalf("ListEngines error: %v", err)
}
_, err = c.GetEngine(ctx, "davinci") _, err = c.GetEngine(ctx, "davinci")
if err != nil { checks.NoError(t, err, "GetEngine error")
t.Fatalf("GetEngine error: %v", err)
}
fileRes, err := c.ListFiles(ctx) fileRes, err := c.ListFiles(ctx)
if err != nil { checks.NoError(t, err, "ListFiles error")
t.Fatalf("ListFiles error: %v", err)
}
if len(fileRes.Files) > 0 { if len(fileRes.Files) > 0 {
_, err = c.GetFile(ctx, fileRes.Files[0].ID) _, err = c.GetFile(ctx, fileRes.Files[0].ID)
if err != nil { checks.NoError(t, err, "GetFile error")
t.Fatalf("GetFile error: %v", err)
}
} // else skip } // else skip
embeddingReq := EmbeddingRequest{ embeddingReq := EmbeddingRequest{
@@ -49,9 +42,7 @@ func TestAPI(t *testing.T) {
Model: AdaSearchQuery, Model: AdaSearchQuery,
} }
_, err = c.CreateEmbeddings(ctx, embeddingReq) _, err = c.CreateEmbeddings(ctx, embeddingReq)
if err != nil { checks.NoError(t, err, "Embedding error")
t.Fatalf("Embedding error: %v", err)
}
_, err = c.CreateChatCompletion( _, err = c.CreateChatCompletion(
ctx, ctx,
@@ -66,9 +57,7 @@ func TestAPI(t *testing.T) {
}, },
) )
if err != nil { checks.NoError(t, err, "CreateChatCompletion (without name) returned error")
t.Errorf("CreateChatCompletion (without name) returned error: %v", err)
}
_, err = c.CreateChatCompletion( _, err = c.CreateChatCompletion(
ctx, ctx,
@@ -83,10 +72,7 @@ func TestAPI(t *testing.T) {
}, },
}, },
) )
checks.NoError(t, err, "CreateChatCompletion (with name) returned error")
if err != nil {
t.Errorf("CreateChatCompletion (with name) returned error: %v", err)
}
stream, err := c.CreateCompletionStream(ctx, CompletionRequest{ stream, err := c.CreateCompletionStream(ctx, CompletionRequest{
Prompt: "Ex falso quodlibet", Prompt: "Ex falso quodlibet",
@@ -94,9 +80,7 @@ func TestAPI(t *testing.T) {
MaxTokens: 5, MaxTokens: 5,
Stream: true, Stream: true,
}) })
if err != nil { checks.NoError(t, err, "CreateCompletionStream returned error")
t.Errorf("CreateCompletionStream returned error: %v", err)
}
defer stream.Close() defer stream.Close()
counter := 0 counter := 0
@@ -126,9 +110,7 @@ func TestAPIError(t *testing.T) {
c := NewClient(apiToken + "_invalid") c := NewClient(apiToken + "_invalid")
ctx := context.Background() ctx := context.Background()
_, err = c.ListEngines(ctx) _, err = c.ListEngines(ctx)
if err == nil { checks.NoError(t, err, "ListEngines did not fail")
t.Fatal("ListEngines did not fail")
}
var apiErr *APIError var apiErr *APIError
if !errors.As(err, &apiErr) { if !errors.As(err, &apiErr) {
@@ -154,9 +136,7 @@ func TestRequestError(t *testing.T) {
c := NewClientWithConfig(config) c := NewClientWithConfig(config)
ctx := context.Background() ctx := context.Background()
_, err = c.ListEngines(ctx) _, err = c.ListEngines(ctx)
if err == nil { checks.HasError(t, err, "ListEngines did not fail")
t.Fatal("ListEngines request did not fail")
}
var reqErr *RequestError var reqErr *RequestError
if !errors.As(err, &reqErr) { if !errors.As(err, &reqErr) {

View File

@@ -13,6 +13,7 @@ import (
. "github.com/sashabaranov/go-openai" . "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
"context" "context"
"testing" "testing"
@@ -62,9 +63,7 @@ func TestAudio(t *testing.T) {
Model: "whisper-3", Model: "whisper-3",
} }
_, err = tc.createFn(ctx, req) _, err = tc.createFn(ctx, req)
if err != nil { checks.NoError(t, err, "audio API error")
t.Fatalf("audio API error: %v", err)
}
}) })
} }
} }
@@ -115,9 +114,7 @@ func TestAudioWithOptionalArgs(t *testing.T) {
Language: "zh", Language: "zh",
} }
_, err = tc.createFn(ctx, req) _, err = tc.createFn(ctx, req)
if err != nil { checks.NoError(t, err, "audio API error")
t.Fatalf("audio API error: %v", err)
}
}) })
} }
} }
@@ -125,9 +122,8 @@ func TestAudioWithOptionalArgs(t *testing.T) {
// createTestFile creates a fake file with "hello" as the content. // createTestFile creates a fake file with "hello" as the content.
func createTestFile(t *testing.T, path string) { func createTestFile(t *testing.T, path string) {
file, err := os.Create(path) file, err := os.Create(path)
if err != nil { checks.NoError(t, err, "failed to create file")
t.Fatalf("failed to create file %v", err)
}
if _, err = file.WriteString("hello"); err != nil { if _, err = file.WriteString("hello"); err != nil {
t.Fatalf("failed to write to file %v", err) t.Fatalf("failed to write to file %v", err)
} }
@@ -139,9 +135,7 @@ func createTestDirectory(t *testing.T) (path string, cleanup func()) {
t.Helper() t.Helper()
path, err := os.MkdirTemp(os.TempDir(), "") path, err := os.MkdirTemp(os.TempDir(), "")
if err != nil { checks.NoError(t, err)
t.Fatal(err)
}
return path, func() { os.RemoveAll(path) } return path, func() { os.RemoveAll(path) }
} }

View File

@@ -3,6 +3,7 @@ package openai_test
import ( import (
. "github.com/sashabaranov/go-openai" . "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
"context" "context"
"encoding/json" "encoding/json"
@@ -55,9 +56,7 @@ func TestCreateChatCompletionStream(t *testing.T) {
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
_, err := w.Write(dataBytes) _, err := w.Write(dataBytes)
if err != nil { checks.NoError(t, err, "Write error")
t.Errorf("Write error: %s", err)
}
})) }))
defer server.Close() defer server.Close()
@@ -85,9 +84,7 @@ func TestCreateChatCompletionStream(t *testing.T) {
} }
stream, err := client.CreateChatCompletionStream(ctx, request) stream, err := client.CreateChatCompletionStream(ctx, request)
if err != nil { checks.NoError(t, err, "CreateCompletionStream returned error")
t.Errorf("CreateCompletionStream returned error: %v", err)
}
defer stream.Close() defer stream.Close()
expectedResponses := []ChatCompletionStreamResponse{ expectedResponses := []ChatCompletionStreamResponse{
@@ -126,9 +123,7 @@ func TestCreateChatCompletionStream(t *testing.T) {
t.Logf("%d: %s", ix, string(b)) t.Logf("%d: %s", ix, string(b))
receivedResponse, streamErr := stream.Recv() receivedResponse, streamErr := stream.Recv()
if streamErr != nil { checks.NoError(t, streamErr, "stream.Recv() failed")
t.Errorf("stream.Recv() failed: %v", streamErr)
}
if !compareChatResponses(expectedResponse, receivedResponse) { if !compareChatResponses(expectedResponse, receivedResponse) {
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
} }
@@ -140,6 +135,8 @@ func TestCreateChatCompletionStream(t *testing.T) {
} }
_, streamErr = stream.Recv() _, streamErr = stream.Recv()
checks.ErrorIs(t, streamErr, io.EOF, "stream.Recv() did not return EOF when the stream is finished")
if !errors.Is(streamErr, io.EOF) { if !errors.Is(streamErr, io.EOF) {
t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr) t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr)
} }
@@ -166,9 +163,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
} }
_, err := w.Write(dataBytes) _, err := w.Write(dataBytes)
if err != nil { checks.NoError(t, err, "Write error")
t.Errorf("Write error: %s", err)
}
})) }))
defer server.Close() defer server.Close()
@@ -196,15 +191,12 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
} }
stream, err := client.CreateChatCompletionStream(ctx, request) stream, err := client.CreateChatCompletionStream(ctx, request)
if err != nil { checks.NoError(t, err, "CreateCompletionStream returned error")
t.Errorf("CreateCompletionStream returned error: %v", err)
}
defer stream.Close() defer stream.Close()
_, streamErr := stream.Recv() _, streamErr := stream.Recv()
if streamErr == nil { checks.HasError(t, streamErr, "stream.Recv() did not return error")
t.Errorf("stream.Recv() did not return error")
}
var apiErr *APIError var apiErr *APIError
if !errors.As(streamErr, &apiErr) { if !errors.As(streamErr, &apiErr) {
t.Errorf("stream.Recv() did not return APIError") t.Errorf("stream.Recv() did not return APIError")

View File

@@ -3,10 +3,10 @@ package openai_test
import ( import (
. "github.com/sashabaranov/go-openai" . "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@@ -33,9 +33,8 @@ func TestChatCompletionsWrongModel(t *testing.T) {
}, },
} }
_, err := client.CreateChatCompletion(ctx, req) _, err := client.CreateChatCompletion(ctx, req)
if !errors.Is(err, ErrChatCompletionInvalidModel) { msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err)
t.Fatalf("CreateChatCompletion should return ErrChatCompletionInvalidModel, but returned: %v", err) checks.ErrorIs(t, err, ErrChatCompletionInvalidModel, msg)
}
} }
func TestChatCompletionsWithStream(t *testing.T) { func TestChatCompletionsWithStream(t *testing.T) {
@@ -48,9 +47,7 @@ func TestChatCompletionsWithStream(t *testing.T) {
Stream: true, Stream: true,
} }
_, err := client.CreateChatCompletion(ctx, req) _, err := client.CreateChatCompletion(ctx, req)
if !errors.Is(err, ErrChatCompletionStreamNotSupported) { checks.ErrorIs(t, err, ErrChatCompletionStreamNotSupported, "unexpected error")
t.Fatalf("CreateChatCompletion didn't return ErrChatCompletionStreamNotSupported error")
}
} }
// TestCompletions Tests the completions endpoint of the API using the mocked server. // TestCompletions Tests the completions endpoint of the API using the mocked server.
@@ -79,9 +76,7 @@ func TestChatCompletions(t *testing.T) {
}, },
} }
_, err = client.CreateChatCompletion(ctx, req) _, err = client.CreateChatCompletion(ctx, req)
if err != nil { checks.NoError(t, err, "CreateChatCompletion error")
t.Fatalf("CreateChatCompletion error: %v", err)
}
} }
// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server. // handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server.

View File

@@ -3,6 +3,7 @@ package openai_test
import ( import (
. "github.com/sashabaranov/go-openai" . "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
"context" "context"
"encoding/json" "encoding/json"
@@ -66,9 +67,7 @@ func TestCompletions(t *testing.T) {
} }
req.Prompt = "Lorem ipsum" req.Prompt = "Lorem ipsum"
_, err = client.CreateCompletion(ctx, req) _, err = client.CreateCompletion(ctx, req)
if err != nil { checks.NoError(t, err, "CreateCompletion error")
t.Fatalf("CreateCompletion error: %v", err)
}
} }
// handleCompletionEndpoint Handles the completion endpoint by the test server. // handleCompletionEndpoint Handles the completion endpoint by the test server.

View File

@@ -3,6 +3,7 @@ package openai_test
import ( import (
. "github.com/sashabaranov/go-openai" . "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
"context" "context"
"encoding/json" "encoding/json"
@@ -40,9 +41,7 @@ func TestEdits(t *testing.T) {
N: 3, N: 3,
} }
response, err := client.Edits(ctx, editReq) response, err := client.Edits(ctx, editReq)
if err != nil { checks.NoError(t, err, "Edits error")
t.Fatalf("Edits error: %v", err)
}
if len(response.Choices) != editReq.N { if len(response.Choices) != editReq.N {
t.Fatalf("edits does not properly return the correct number of choices") t.Fatalf("edits does not properly return the correct number of choices")
} }

View File

@@ -2,6 +2,7 @@ package openai_test
import ( import (
. "github.com/sashabaranov/go-openai" . "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"bytes" "bytes"
"encoding/json" "encoding/json"
@@ -38,9 +39,7 @@ func TestEmbedding(t *testing.T) {
// marshal embeddingReq to JSON and confirm that the model field equals // marshal embeddingReq to JSON and confirm that the model field equals
// the AdaSearchQuery type // the AdaSearchQuery type
marshaled, err := json.Marshal(embeddingReq) marshaled, err := json.Marshal(embeddingReq)
if err != nil { checks.NoError(t, err, "Could not marshal embedding request")
t.Fatalf("Could not marshal embedding request: %v", err)
}
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
t.Fatalf("Expected embedding request to contain model field") t.Fatalf("Expected embedding request to contain model field")
} }

View File

@@ -7,6 +7,7 @@ import (
"testing" "testing"
"github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
) )
var ( var (
@@ -81,16 +82,13 @@ func TestErrorAccumulatorWriteErrors(t *testing.T) {
ctx := context.Background() ctx := context.Background()
stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{}) stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
if err != nil { checks.NoError(t, err)
t.Fatal(err)
}
stream.errAccumulator = &defaultErrorAccumulator{ stream.errAccumulator = &defaultErrorAccumulator{
buffer: &failingErrorBuffer{}, buffer: &failingErrorBuffer{},
unmarshaler: &jsonUnmarshaler{}, unmarshaler: &jsonUnmarshaler{},
} }
_, err = stream.Recv() _, err = stream.Recv()
if !errors.Is(err, errTestErrorAccumulatorWriteFailed) { checks.ErrorIs(t, err, errTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
t.Fatalf("Did not return error when write failed: %v", err)
}
} }

View File

@@ -3,6 +3,7 @@ package openai_test
import ( import (
. "github.com/sashabaranov/go-openai" . "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
"context" "context"
"encoding/json" "encoding/json"
@@ -33,9 +34,7 @@ func TestFileUpload(t *testing.T) {
Purpose: "fine-tune", Purpose: "fine-tune",
} }
_, err = client.CreateFile(ctx, req) _, err = client.CreateFile(ctx, req)
if err != nil { checks.NoError(t, err, "CreateFile erro")
t.Fatalf("CreateFile error: %v", err)
}
} }
// handleCreateFile Handles the images endpoint by the test server. // handleCreateFile Handles the images endpoint by the test server.

View File

@@ -3,6 +3,7 @@ package openai_test
import ( import (
. "github.com/sashabaranov/go-openai" . "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
"context" "context"
"encoding/json" "encoding/json"
@@ -70,32 +71,20 @@ func TestFineTunes(t *testing.T) {
ctx := context.Background() ctx := context.Background()
_, err = client.ListFineTunes(ctx) _, err = client.ListFineTunes(ctx)
if err != nil { checks.NoError(t, err, "ListFineTunes error")
t.Fatalf("ListFineTunes error: %v", err)
}
_, err = client.CreateFineTune(ctx, FineTuneRequest{}) _, err = client.CreateFineTune(ctx, FineTuneRequest{})
if err != nil { checks.NoError(t, err, "CreateFineTune error")
t.Fatalf("CreateFineTune error: %v", err)
}
_, err = client.CancelFineTune(ctx, testFineTuneID) _, err = client.CancelFineTune(ctx, testFineTuneID)
if err != nil { checks.NoError(t, err, "CancelFineTune error")
t.Fatalf("CancelFineTune error: %v", err)
}
_, err = client.GetFineTune(ctx, testFineTuneID) _, err = client.GetFineTune(ctx, testFineTuneID)
if err != nil { checks.NoError(t, err, "GetFineTune error")
t.Fatalf("GetFineTune error: %v", err)
}
_, err = client.DeleteFineTune(ctx, testFineTuneID) _, err = client.DeleteFineTune(ctx, testFineTuneID)
if err != nil { checks.NoError(t, err, "DeleteFineTune error")
t.Fatalf("DeleteFineTune error: %v", err)
}
_, err = client.ListFineTuneEvents(ctx, testFineTuneID) _, err = client.ListFineTuneEvents(ctx, testFineTuneID)
if err != nil { checks.NoError(t, err, "ListFineTuneEvents error")
t.Fatalf("ListFineTuneEvents error: %v", err)
}
} }

View File

@@ -3,6 +3,7 @@ package openai_test
import ( import (
. "github.com/sashabaranov/go-openai" . "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
"context" "context"
"encoding/json" "encoding/json"
@@ -31,9 +32,7 @@ func TestImages(t *testing.T) {
req := ImageRequest{} req := ImageRequest{}
req.Prompt = "Lorem ipsum" req.Prompt = "Lorem ipsum"
_, err = client.CreateImage(ctx, req) _, err = client.CreateImage(ctx, req)
if err != nil { checks.NoError(t, err, "CreateImage error")
t.Fatalf("CreateImage error: %v", err)
}
} }
// handleImageEndpoint Handles the images endpoint by the test server. // handleImageEndpoint Handles the images endpoint by the test server.
@@ -127,9 +126,7 @@ func TestImageEdit(t *testing.T) {
Size: CreateImageSize1024x1024, Size: CreateImageSize1024x1024,
} }
_, err = client.CreateEditImage(ctx, req) _, err = client.CreateEditImage(ctx, req)
if err != nil { checks.NoError(t, err, "CreateImage error")
t.Fatalf("CreateImage error: %v", err)
}
} }
func TestImageEditWithoutMask(t *testing.T) { func TestImageEditWithoutMask(t *testing.T) {
@@ -164,9 +161,7 @@ func TestImageEditWithoutMask(t *testing.T) {
Size: CreateImageSize1024x1024, Size: CreateImageSize1024x1024,
} }
_, err = client.CreateEditImage(ctx, req) _, err = client.CreateEditImage(ctx, req)
if err != nil { checks.NoError(t, err, "CreateImage error")
t.Fatalf("CreateImage error: %v", err)
}
} }
// handleEditImageEndpoint Handles the images endpoint by the test server. // handleEditImageEndpoint Handles the images endpoint by the test server.
@@ -231,9 +226,7 @@ func TestImageVariation(t *testing.T) {
Size: CreateImageSize1024x1024, Size: CreateImageSize1024x1024,
} }
_, err = client.CreateVariImage(ctx, req) _, err = client.CreateVariImage(ctx, req)
if err != nil { checks.NoError(t, err, "CreateImage error")
t.Fatalf("CreateImage error: %v", err)
}
} }
// handleVariateImageEndpoint Handles the images endpoint by the test server. // handleVariateImageEndpoint Handles the images endpoint by the test server.

View File

@@ -0,0 +1,48 @@
package checks
import (
"errors"
"testing"
)
func NoError(t *testing.T, err error, message ...string) {
t.Helper()
if err != nil {
t.Error(err, message)
}
}
func HasError(t *testing.T, err error, message ...string) {
t.Helper()
if err == nil {
t.Error(err, message)
}
}
func ErrorIs(t *testing.T, err, target error, msg ...string) {
t.Helper()
if !errors.Is(err, target) {
t.Fatal(msg)
}
}
func ErrorIsF(t *testing.T, err, target error, format string, msg ...string) {
t.Helper()
if !errors.Is(err, target) {
t.Fatalf(format, msg)
}
}
func ErrorIsNot(t *testing.T, err, target error, msg ...string) {
t.Helper()
if errors.Is(err, target) {
t.Fatal(msg)
}
}
func ErrorIsNotf(t *testing.T, err, target error, format string, msg ...string) {
t.Helper()
if errors.Is(err, target) {
t.Fatalf(format, msg)
}
}

View File

@@ -3,6 +3,7 @@ package openai_test
import ( import (
. "github.com/sashabaranov/go-openai" . "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
"context" "context"
"encoding/json" "encoding/json"
@@ -27,9 +28,7 @@ func TestListModels(t *testing.T) {
ctx := context.Background() ctx := context.Background()
_, err = client.ListModels(ctx) _, err = client.ListModels(ctx)
if err != nil { checks.NoError(t, err, "ListModels error")
t.Fatalf("ListModels error: %v", err)
}
} }
// handleModelsEndpoint Handles the models endpoint by the test server. // handleModelsEndpoint Handles the models endpoint by the test server.

View File

@@ -3,6 +3,7 @@ package openai_test
import ( import (
. "github.com/sashabaranov/go-openai" . "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
"context" "context"
"encoding/json" "encoding/json"
@@ -37,9 +38,7 @@ func TestModerations(t *testing.T) {
Input: "I want to kill them.", Input: "I want to kill them.",
} }
_, err = client.Moderations(ctx, moderationReq) _, err = client.Moderations(ctx, moderationReq)
if err != nil { checks.NoError(t, err, "Moderation error")
t.Fatalf("Moderation error: %v", err)
}
} }
// handleModerationEndpoint Handles the moderation endpoint by the test server. // handleModerationEndpoint Handles the moderation endpoint by the test server.

View File

@@ -3,6 +3,7 @@ package openai_test
import ( import (
. "github.com/sashabaranov/go-openai" . "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
"context" "context"
"errors" "errors"
@@ -49,9 +50,7 @@ func TestCreateCompletionStream(t *testing.T) {
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
_, err := w.Write(dataBytes) _, err := w.Write(dataBytes)
if err != nil { checks.NoError(t, err, "Write error")
t.Errorf("Write error: %s", err)
}
})) }))
defer server.Close() defer server.Close()
@@ -74,9 +73,7 @@ func TestCreateCompletionStream(t *testing.T) {
} }
stream, err := client.CreateCompletionStream(ctx, request) stream, err := client.CreateCompletionStream(ctx, request)
if err != nil { checks.NoError(t, err, "CreateCompletionStream returned error")
t.Errorf("CreateCompletionStream returned error: %v", err)
}
defer stream.Close() defer stream.Close()
expectedResponses := []CompletionResponse{ expectedResponses := []CompletionResponse{
@@ -138,9 +135,7 @@ func TestCreateCompletionStreamError(t *testing.T) {
} }
_, err := w.Write(dataBytes) _, err := w.Write(dataBytes)
if err != nil { checks.NoError(t, err, "Write error")
t.Errorf("Write error: %s", err)
}
})) }))
defer server.Close() defer server.Close()
@@ -163,15 +158,12 @@ func TestCreateCompletionStreamError(t *testing.T) {
} }
stream, err := client.CreateCompletionStream(ctx, request) stream, err := client.CreateCompletionStream(ctx, request)
if err != nil { checks.NoError(t, err, "CreateCompletionStream returned error")
t.Errorf("CreateCompletionStream returned error: %v", err)
}
defer stream.Close() defer stream.Close()
_, streamErr := stream.Recv() _, streamErr := stream.Recv()
if streamErr == nil { checks.HasError(t, streamErr, "stream.Recv() did not return error")
t.Errorf("stream.Recv() did not return error")
}
var apiErr *APIError var apiErr *APIError
if !errors.As(streamErr, &apiErr) { if !errors.As(streamErr, &apiErr) {
t.Errorf("stream.Recv() did not return APIError") t.Errorf("stream.Recv() did not return APIError")