Add testable request builder (#162)
* Add testable request builder * improve tests
This commit is contained in:
20
api.go
20
api.go
@@ -1,7 +1,6 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -12,7 +11,7 @@ import (
|
|||||||
type Client struct {
|
type Client struct {
|
||||||
config ClientConfig
|
config ClientConfig
|
||||||
|
|
||||||
marshaller marshaller
|
requestBuilder requestBuilder
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient creates new OpenAI API client.
|
// NewClient creates new OpenAI API client.
|
||||||
@@ -24,8 +23,8 @@ func NewClient(authToken string) *Client {
|
|||||||
// NewClientWithConfig creates new OpenAI API client for specified config.
|
// NewClientWithConfig creates new OpenAI API client for specified config.
|
||||||
func NewClientWithConfig(config ClientConfig) *Client {
|
func NewClientWithConfig(config ClientConfig) *Client {
|
||||||
return &Client{
|
return &Client{
|
||||||
config: config,
|
config: config,
|
||||||
marshaller: &jsonMarshaller{},
|
requestBuilder: newRequestBuilder(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -91,17 +90,8 @@ func (c *Client) newStreamRequest(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
method string,
|
method string,
|
||||||
urlSuffix string,
|
urlSuffix string,
|
||||||
body interface{}) (*http.Request, error) {
|
body any) (*http.Request, error) {
|
||||||
var reqBody []byte
|
req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix), body)
|
||||||
if body != nil {
|
|
||||||
var err error
|
|
||||||
reqBody, err = c.marshaller.marshal(body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, method, c.fullURL(urlSuffix), bytes.NewBuffer(reqBody))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
9
chat.go
9
chat.go
@@ -1,7 +1,6 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -72,14 +71,8 @@ func (c *Client) CreateChatCompletion(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var reqBytes []byte
|
|
||||||
reqBytes, err = c.marshaller.marshal(request)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
urlSuffix := "/chat/completions"
|
urlSuffix := "/chat/completions"
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
|
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -105,14 +104,8 @@ func (c *Client) CreateCompletion(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var reqBytes []byte
|
|
||||||
reqBytes, err = c.marshaller.marshal(request)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
urlSuffix := "/completions"
|
urlSuffix := "/completions"
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
|
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
9
edits.go
9
edits.go
@@ -1,7 +1,6 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
@@ -32,13 +31,7 @@ type EditsResponse struct {
|
|||||||
|
|
||||||
// Perform an API call to the Edits endpoint.
|
// Perform an API call to the Edits endpoint.
|
||||||
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
|
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
|
||||||
var reqBytes []byte
|
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits"), request)
|
||||||
reqBytes, err = c.marshaller.marshal(request)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/edits"), bytes.NewBuffer(reqBytes))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
@@ -133,14 +132,7 @@ type EmbeddingRequest struct {
|
|||||||
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
|
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
|
||||||
// https://beta.openai.com/docs/api-reference/embeddings/create
|
// https://beta.openai.com/docs/api-reference/embeddings/create
|
||||||
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
|
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
|
||||||
var reqBytes []byte
|
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings"), request)
|
||||||
reqBytes, err = c.marshaller.marshal(request)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
urlSuffix := "/embeddings"
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ type EnginesList struct {
|
|||||||
// ListEngines Lists the currently available engines, and provides basic
|
// ListEngines Lists the currently available engines, and provides basic
|
||||||
// information about each option such as the owner and availability.
|
// information about each option such as the owner and availability.
|
||||||
func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) {
|
func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) {
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/engines"), nil)
|
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/engines"), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -38,7 +38,7 @@ func (c *Client) GetEngine(
|
|||||||
engineID string,
|
engineID string,
|
||||||
) (engine Engine, err error) {
|
) (engine Engine, err error) {
|
||||||
urlSuffix := fmt.Sprintf("/engines/%s", engineID)
|
urlSuffix := fmt.Sprintf("/engines/%s", engineID)
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
6
files.go
6
files.go
@@ -112,7 +112,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
|
|||||||
|
|
||||||
// DeleteFile deletes an existing file.
|
// DeleteFile deletes an existing file.
|
||||||
func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
|
func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil)
|
req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -124,7 +124,7 @@ func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
|
|||||||
// ListFiles Lists the currently available files,
|
// ListFiles Lists the currently available files,
|
||||||
// and provides basic information about each file such as the file name and purpose.
|
// and provides basic information about each file such as the file name and purpose.
|
||||||
func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
|
func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/files"), nil)
|
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/files"), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -137,7 +137,7 @@ func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
|
|||||||
// such as the file name and purpose.
|
// such as the file name and purpose.
|
||||||
func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) {
|
func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) {
|
||||||
urlSuffix := fmt.Sprintf("/files/%s", fileID)
|
urlSuffix := fmt.Sprintf("/files/%s", fileID)
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -68,14 +67,8 @@ type FineTuneDeleteResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) {
|
func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) {
|
||||||
var reqBytes []byte
|
|
||||||
reqBytes, err = c.marshaller.marshal(request)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
urlSuffix := "/fine-tunes"
|
urlSuffix := "/fine-tunes"
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
|
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -86,7 +79,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r
|
|||||||
|
|
||||||
// CancelFineTune cancel a fine-tune job.
|
// CancelFineTune cancel a fine-tune job.
|
||||||
func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
|
func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil)
|
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -96,7 +89,7 @@ func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (respons
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) {
|
func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) {
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil)
|
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -107,7 +100,7 @@ func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err
|
|||||||
|
|
||||||
func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
|
func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
|
||||||
urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID)
|
urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID)
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -117,7 +110,7 @@ func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response F
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) {
|
func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) {
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil)
|
req, err := c.requestBuilder.build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -127,7 +120,7 @@ func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (respons
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) {
|
func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) {
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil)
|
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
8
image.go
8
image.go
@@ -45,14 +45,8 @@ type ImageResponseDataInner struct {
|
|||||||
|
|
||||||
// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
|
// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
|
||||||
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
|
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
|
||||||
var reqBytes []byte
|
|
||||||
reqBytes, err = c.marshaller.marshal(request)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
urlSuffix := "/images/generations"
|
urlSuffix := "/images/generations"
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
|
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,71 +0,0 @@
|
|||||||
package openai //nolint:testpackage // testing private field
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
|
||||||
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
type failingMarshaller struct{}
|
|
||||||
|
|
||||||
var errTestMarshallerFailed = errors.New("test marshaller failed")
|
|
||||||
|
|
||||||
func (jm *failingMarshaller) marshal(value any) ([]byte, error) {
|
|
||||||
return []byte{}, errTestMarshallerFailed
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClientReturnMarshallerErrors(t *testing.T) {
|
|
||||||
var err error
|
|
||||||
ts := test.NewTestServer().OpenAITestServer()
|
|
||||||
ts.Start()
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
config := DefaultConfig(test.GetTestToken())
|
|
||||||
config.BaseURL = ts.URL + "/v1"
|
|
||||||
client := NewClientWithConfig(config)
|
|
||||||
client.marshaller = &failingMarshaller{}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
_, err = client.CreateCompletion(ctx, CompletionRequest{})
|
|
||||||
if !errors.Is(err, errTestMarshallerFailed) {
|
|
||||||
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
|
|
||||||
if !errors.Is(err, errTestMarshallerFailed) {
|
|
||||||
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
|
|
||||||
if !errors.Is(err, errTestMarshallerFailed) {
|
|
||||||
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.CreateFineTune(ctx, FineTuneRequest{})
|
|
||||||
if !errors.Is(err, errTestMarshallerFailed) {
|
|
||||||
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.Moderations(ctx, ModerationRequest{})
|
|
||||||
if !errors.Is(err, errTestMarshallerFailed) {
|
|
||||||
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.Edits(ctx, EditsRequest{})
|
|
||||||
if !errors.Is(err, errTestMarshallerFailed) {
|
|
||||||
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
|
|
||||||
if !errors.Is(err, errTestMarshallerFailed) {
|
|
||||||
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = client.CreateImage(ctx, ImageRequest{})
|
|
||||||
if !errors.Is(err, errTestMarshallerFailed) {
|
|
||||||
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
@@ -51,13 +50,7 @@ type ModerationResponse struct {
|
|||||||
// Moderations — perform a moderation api call over a string.
|
// Moderations — perform a moderation api call over a string.
|
||||||
// Input can be an array or slice but a string will reduce the complexity.
|
// Input can be an array or slice but a string will reduce the complexity.
|
||||||
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
|
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
|
||||||
var reqBytes []byte
|
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations"), request)
|
||||||
reqBytes, err = c.marshaller.marshal(request)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/moderations"), bytes.NewBuffer(reqBytes))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
40
request_builder.go
Normal file
40
request_builder.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type requestBuilder interface {
|
||||||
|
build(ctx context.Context, method, url string, request any) (*http.Request, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type httpRequestBuilder struct {
|
||||||
|
marshaller marshaller
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRequestBuilder() *httpRequestBuilder {
|
||||||
|
return &httpRequestBuilder{
|
||||||
|
marshaller: &jsonMarshaller{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *httpRequestBuilder) build(ctx context.Context, method, url string, request any) (*http.Request, error) {
|
||||||
|
if request == nil {
|
||||||
|
return http.NewRequestWithContext(ctx, method, url, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
var reqBytes []byte
|
||||||
|
reqBytes, err := b.marshaller.marshal(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return http.NewRequestWithContext(
|
||||||
|
ctx,
|
||||||
|
method,
|
||||||
|
url,
|
||||||
|
bytes.NewBuffer(reqBytes),
|
||||||
|
)
|
||||||
|
}
|
||||||
143
request_builder_test.go
Normal file
143
request_builder_test.go
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
package openai //nolint:testpackage // testing private field
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
|
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errTestMarshallerFailed = errors.New("test marshaller failed")
|
||||||
|
errTestRequestBuilderFailed = errors.New("test request builder failed")
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
failingRequestBuilder struct{}
|
||||||
|
failingMarshaller struct{}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (*failingMarshaller) marshal(value any) ([]byte, error) {
|
||||||
|
return []byte{}, errTestMarshallerFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*failingRequestBuilder) build(ctx context.Context, method, url string, requset any) (*http.Request, error) {
|
||||||
|
return nil, errTestRequestBuilderFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) {
|
||||||
|
builder := httpRequestBuilder{
|
||||||
|
marshaller: &failingMarshaller{},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := builder.build(context.Background(), "", "", struct{}{})
|
||||||
|
if !errors.Is(err, errTestMarshallerFailed) {
|
||||||
|
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientReturnsRequestBuilderErrors(t *testing.T) {
|
||||||
|
var err error
|
||||||
|
ts := test.NewTestServer().OpenAITestServer()
|
||||||
|
ts.Start()
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
config := DefaultConfig(test.GetTestToken())
|
||||||
|
config.BaseURL = ts.URL + "/v1"
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
|
client.requestBuilder = &failingRequestBuilder{}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
_, err = client.CreateCompletion(ctx, CompletionRequest{})
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.CreateFineTune(ctx, FineTuneRequest{})
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.ListFineTunes(ctx)
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.CancelFineTune(ctx, "")
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.GetFineTune(ctx, "")
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.DeleteFineTune(ctx, "")
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.ListFineTuneEvents(ctx, "")
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.Moderations(ctx, ModerationRequest{})
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.Edits(ctx, EditsRequest{})
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.CreateImage(ctx, ImageRequest{})
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = client.DeleteFile(ctx, "")
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.GetFile(ctx, "")
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.ListFiles(ctx)
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.ListEngines(ctx)
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.GetEngine(ctx, "")
|
||||||
|
if !errors.Is(err, errTestRequestBuilderFailed) {
|
||||||
|
t.Fatalf("Did not return error when request builder failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user