* move request_builder into internal pkg (#304) * add some test for internal.RequestBuilder * add a test for openai.GetEngine
This commit is contained in:
committed by
GitHub
parent
62eb4beed2
commit
61ba5f3369
2
chat.go
2
chat.go
@@ -77,7 +77,7 @@ func (c *Client) CreateChatCompletion(
|
||||
return
|
||||
}
|
||||
|
||||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
|
||||
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
type Client struct {
|
||||
config ClientConfig
|
||||
|
||||
requestBuilder requestBuilder
|
||||
requestBuilder utils.RequestBuilder
|
||||
createFormBuilder func(io.Writer) utils.FormBuilder
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ func NewClient(authToken string) *Client {
|
||||
func NewClientWithConfig(config ClientConfig) *Client {
|
||||
return &Client{
|
||||
config: config,
|
||||
requestBuilder: newRequestBuilder(),
|
||||
requestBuilder: utils.NewRequestBuilder(),
|
||||
createFormBuilder: func(body io.Writer) utils.FormBuilder {
|
||||
return utils.NewFormBuilder(body)
|
||||
},
|
||||
@@ -135,7 +135,7 @@ func (c *Client) newStreamRequest(
|
||||
urlSuffix string,
|
||||
body any,
|
||||
model string) (*http.Request, error) {
|
||||
req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix, model), body)
|
||||
req, err := c.requestBuilder.Build(ctx, method, c.fullURL(urlSuffix, model), body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
149
client_test.go
149
client_test.go
@@ -2,13 +2,24 @@ package openai //nolint:testpackage // testing private field
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/sashabaranov/go-openai/internal/test"
|
||||
)
|
||||
|
||||
var errTestRequestBuilderFailed = errors.New("test request builder failed")
|
||||
|
||||
type failingRequestBuilder struct{}
|
||||
|
||||
func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any) (*http.Request, error) {
|
||||
return nil, errTestRequestBuilderFailed
|
||||
}
|
||||
|
||||
func TestClient(t *testing.T) {
|
||||
const mockToken = "mock token"
|
||||
client := NewClient(mockToken)
|
||||
@@ -145,3 +156,141 @@ func TestHandleErrorResp(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientReturnsRequestBuilderErrors(t *testing.T) {
|
||||
var err error
|
||||
ts := test.NewTestServer().OpenAITestServer()
|
||||
ts.Start()
|
||||
defer ts.Close()
|
||||
|
||||
config := DefaultConfig(test.GetTestToken())
|
||||
config.BaseURL = ts.URL + "/v1"
|
||||
client := NewClientWithConfig(config)
|
||||
client.requestBuilder = &failingRequestBuilder{}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
_, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CreateFineTune(ctx, FineTuneRequest{})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.ListFineTunes(ctx)
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CancelFineTune(ctx, "")
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.GetFineTune(ctx, "")
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.DeleteFineTune(ctx, "")
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.ListFineTuneEvents(ctx, "")
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.Moderations(ctx, ModerationRequest{})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.Edits(ctx, EditsRequest{})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CreateImage(ctx, ImageRequest{})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
err = client.DeleteFile(ctx, "")
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.GetFile(ctx, "")
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.ListFiles(ctx)
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.ListEngines(ctx)
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.GetEngine(ctx, "")
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.ListModels(ctx)
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientReturnsRequestBuilderErrorsAddtion(t *testing.T) {
|
||||
var err error
|
||||
ts := test.NewTestServer().OpenAITestServer()
|
||||
ts.Start()
|
||||
defer ts.Close()
|
||||
|
||||
config := DefaultConfig(test.GetTestToken())
|
||||
config.BaseURL = ts.URL + "/v1"
|
||||
client := NewClientWithConfig(config)
|
||||
client.requestBuilder = &failingRequestBuilder{}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
_, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: 1})
|
||||
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: 1})
|
||||
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -155,7 +155,7 @@ func (c *Client) CreateCompletion(
|
||||
return
|
||||
}
|
||||
|
||||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
|
||||
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
2
edits.go
2
edits.go
@@ -32,7 +32,7 @@ type EditsResponse struct {
|
||||
|
||||
// Perform an API call to the Edits endpoint.
|
||||
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
|
||||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request)
|
||||
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -132,7 +132,7 @@ type EmbeddingRequest struct {
|
||||
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
|
||||
// https://beta.openai.com/docs/api-reference/embeddings/create
|
||||
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
|
||||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request)
|
||||
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ type EnginesList struct {
|
||||
// ListEngines Lists the currently available engines, and provides basic
|
||||
// information about each option such as the owner and availability.
|
||||
func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) {
|
||||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/engines"), nil)
|
||||
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/engines"), nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -38,7 +38,7 @@ func (c *Client) GetEngine(
|
||||
engineID string,
|
||||
) (engine Engine, err error) {
|
||||
urlSuffix := fmt.Sprintf("/engines/%s", engineID)
|
||||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
||||
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
34
engines_test.go
Normal file
34
engines_test.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package openai_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
. "github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
// TestGetEngine Tests the retrieve engine endpoint of the API using the mocked server.
|
||||
func TestGetEngine(t *testing.T) {
|
||||
server := test.NewTestServer()
|
||||
server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, r *http.Request) {
|
||||
resBytes, _ := json.Marshal(Engine{})
|
||||
fmt.Fprintln(w, string(resBytes))
|
||||
})
|
||||
// create the test server
|
||||
ts := server.OpenAITestServer()
|
||||
ts.Start()
|
||||
defer ts.Close()
|
||||
|
||||
config := DefaultConfig(test.GetTestToken())
|
||||
config.BaseURL = ts.URL + "/v1"
|
||||
client := NewClientWithConfig(config)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := client.GetEngine(ctx, "text-davinci-003")
|
||||
checks.NoError(t, err, "GetEngine error")
|
||||
}
|
||||
6
files.go
6
files.go
@@ -70,7 +70,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
|
||||
|
||||
// DeleteFile deletes an existing file.
|
||||
func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
|
||||
req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil)
|
||||
req, err := c.requestBuilder.Build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -82,7 +82,7 @@ func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
|
||||
// ListFiles Lists the currently available files,
|
||||
// and provides basic information about each file such as the file name and purpose.
|
||||
func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
|
||||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/files"), nil)
|
||||
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/files"), nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -95,7 +95,7 @@ func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
|
||||
// such as the file name and purpose.
|
||||
func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) {
|
||||
urlSuffix := fmt.Sprintf("/files/%s", fileID)
|
||||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
||||
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ type FineTuneDeleteResponse struct {
|
||||
|
||||
func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) {
|
||||
urlSuffix := "/fine-tunes"
|
||||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
|
||||
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -79,7 +79,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r
|
||||
|
||||
// CancelFineTune cancel a fine-tune job.
|
||||
func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
|
||||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil)
|
||||
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -89,7 +89,7 @@ func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (respons
|
||||
}
|
||||
|
||||
func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) {
|
||||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil)
|
||||
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -100,7 +100,7 @@ func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err
|
||||
|
||||
func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
|
||||
urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID)
|
||||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
||||
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -110,7 +110,7 @@ func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response F
|
||||
}
|
||||
|
||||
func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) {
|
||||
req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil)
|
||||
req, err := c.requestBuilder.Build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -120,7 +120,7 @@ func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (respons
|
||||
}
|
||||
|
||||
func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) {
|
||||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil)
|
||||
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
2
image.go
2
image.go
@@ -44,7 +44,7 @@ type ImageResponseDataInner struct {
|
||||
// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
|
||||
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
|
||||
urlSuffix := "/images/generations"
|
||||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
|
||||
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,25 +4,23 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
utils "github.com/sashabaranov/go-openai/internal"
|
||||
)
|
||||
|
||||
type requestBuilder interface {
|
||||
build(ctx context.Context, method, url string, request any) (*http.Request, error)
|
||||
type RequestBuilder interface {
|
||||
Build(ctx context.Context, method, url string, request any) (*http.Request, error)
|
||||
}
|
||||
|
||||
type httpRequestBuilder struct {
|
||||
marshaller utils.Marshaller
|
||||
type HTTPRequestBuilder struct {
|
||||
marshaller Marshaller
|
||||
}
|
||||
|
||||
func newRequestBuilder() *httpRequestBuilder {
|
||||
return &httpRequestBuilder{
|
||||
marshaller: &utils.JSONMarshaller{},
|
||||
func NewRequestBuilder() *HTTPRequestBuilder {
|
||||
return &HTTPRequestBuilder{
|
||||
marshaller: &JSONMarshaller{},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *httpRequestBuilder) build(ctx context.Context, method, url string, request any) (*http.Request, error) {
|
||||
func (b *HTTPRequestBuilder) Build(ctx context.Context, method, url string, request any) (*http.Request, error) {
|
||||
if request == nil {
|
||||
return http.NewRequestWithContext(ctx, method, url, nil)
|
||||
}
|
||||
61
internal/request_builder_test.go
Normal file
61
internal/request_builder_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package openai //nolint:testpackage // testing private field
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var errTestMarshallerFailed = errors.New("test marshaller failed")
|
||||
|
||||
type failingMarshaller struct{}
|
||||
|
||||
func (*failingMarshaller) Marshal(_ any) ([]byte, error) {
|
||||
return []byte{}, errTestMarshallerFailed
|
||||
}
|
||||
|
||||
func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) {
|
||||
builder := HTTPRequestBuilder{
|
||||
marshaller: &failingMarshaller{},
|
||||
}
|
||||
|
||||
_, err := builder.Build(context.Background(), "", "", struct{}{})
|
||||
if !errors.Is(err, errTestMarshallerFailed) {
|
||||
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestBuilderReturnsRequest(t *testing.T) {
|
||||
b := NewRequestBuilder()
|
||||
var (
|
||||
ctx = context.Background()
|
||||
method = http.MethodPost
|
||||
url = "/foo"
|
||||
request = map[string]string{"foo": "bar"}
|
||||
reqBytes, _ = b.marshaller.Marshal(request)
|
||||
want, _ = http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(reqBytes))
|
||||
)
|
||||
got, _ := b.Build(ctx, method, url, request)
|
||||
if !reflect.DeepEqual(got.Body, want.Body) ||
|
||||
!reflect.DeepEqual(got.URL, want.URL) ||
|
||||
!reflect.DeepEqual(got.Method, want.Method) {
|
||||
t.Errorf("Build() got = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestBuilderReturnsRequestWhenRequestOfArgsIsNil(t *testing.T) {
|
||||
var (
|
||||
ctx = context.Background()
|
||||
method = http.MethodGet
|
||||
url = "/foo"
|
||||
want, _ = http.NewRequestWithContext(ctx, method, url, nil)
|
||||
)
|
||||
b := NewRequestBuilder()
|
||||
got, _ := b.Build(ctx, method, url, nil)
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("Build() got = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
@@ -40,7 +40,7 @@ type ModelsList struct {
|
||||
// ListModels Lists the currently available models,
|
||||
// and provides basic information about each model such as the model id and parent.
|
||||
func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) {
|
||||
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/models"), nil)
|
||||
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/models"), nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -63,7 +63,7 @@ type ModerationResponse struct {
|
||||
// Moderations — perform a moderation api call over a string.
|
||||
// Input can be an array or slice but a string will reduce the complexity.
|
||||
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
|
||||
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), request)
|
||||
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,177 +0,0 @@
|
||||
package openai //nolint:testpackage // testing private field
|
||||
|
||||
import (
|
||||
"github.com/sashabaranov/go-openai/internal/test"
|
||||
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
errTestMarshallerFailed = errors.New("test marshaller failed")
|
||||
errTestRequestBuilderFailed = errors.New("test request builder failed")
|
||||
)
|
||||
|
||||
type (
|
||||
failingRequestBuilder struct{}
|
||||
failingMarshaller struct{}
|
||||
)
|
||||
|
||||
func (*failingMarshaller) Marshal(_ any) ([]byte, error) {
|
||||
return []byte{}, errTestMarshallerFailed
|
||||
}
|
||||
|
||||
func (*failingRequestBuilder) build(_ context.Context, _, _ string, _ any) (*http.Request, error) {
|
||||
return nil, errTestRequestBuilderFailed
|
||||
}
|
||||
|
||||
func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) {
|
||||
builder := httpRequestBuilder{
|
||||
marshaller: &failingMarshaller{},
|
||||
}
|
||||
|
||||
_, err := builder.build(context.Background(), "", "", struct{}{})
|
||||
if !errors.Is(err, errTestMarshallerFailed) {
|
||||
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientReturnsRequestBuilderErrors(t *testing.T) {
|
||||
var err error
|
||||
ts := test.NewTestServer().OpenAITestServer()
|
||||
ts.Start()
|
||||
defer ts.Close()
|
||||
|
||||
config := DefaultConfig(test.GetTestToken())
|
||||
config.BaseURL = ts.URL + "/v1"
|
||||
client := NewClientWithConfig(config)
|
||||
client.requestBuilder = &failingRequestBuilder{}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
_, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CreateFineTune(ctx, FineTuneRequest{})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.ListFineTunes(ctx)
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CancelFineTune(ctx, "")
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.GetFineTune(ctx, "")
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.DeleteFineTune(ctx, "")
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.ListFineTuneEvents(ctx, "")
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.Moderations(ctx, ModerationRequest{})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.Edits(ctx, EditsRequest{})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CreateImage(ctx, ImageRequest{})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
err = client.DeleteFile(ctx, "")
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.GetFile(ctx, "")
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.ListFiles(ctx)
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.ListEngines(ctx)
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.GetEngine(ctx, "")
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.ListModels(ctx)
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""})
|
||||
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReturnsRequestBuilderErrorsAddtion(t *testing.T) {
|
||||
var err error
|
||||
ts := test.NewTestServer().OpenAITestServer()
|
||||
ts.Start()
|
||||
defer ts.Close()
|
||||
|
||||
config := DefaultConfig(test.GetTestToken())
|
||||
config.BaseURL = ts.URL + "/v1"
|
||||
client := NewClientWithConfig(config)
|
||||
client.requestBuilder = &failingRequestBuilder{}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
_, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: 1})
|
||||
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: 1})
|
||||
if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) {
|
||||
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user