add testable json marshaller (#161)
This commit is contained in:
13
api.go
13
api.go
@@ -11,17 +11,22 @@ import (
|
||||
// Client is OpenAI GPT-3 API client.
|
||||
type Client struct {
|
||||
config ClientConfig
|
||||
|
||||
marshaller marshaller
|
||||
}
|
||||
|
||||
// NewClient creates new OpenAI API client.
|
||||
func NewClient(authToken string) *Client {
|
||||
config := DefaultConfig(authToken)
|
||||
return &Client{config}
|
||||
return NewClientWithConfig(config)
|
||||
}
|
||||
|
||||
// NewClientWithConfig creates new OpenAI API client for specified config.
|
||||
func NewClientWithConfig(config ClientConfig) *Client {
|
||||
return &Client{config}
|
||||
return &Client{
|
||||
config: config,
|
||||
marshaller: &jsonMarshaller{},
|
||||
}
|
||||
}
|
||||
|
||||
// NewOrgClient creates new OpenAI API client for specified Organization ID.
|
||||
@@ -30,7 +35,7 @@ func NewClientWithConfig(config ClientConfig) *Client {
|
||||
func NewOrgClient(authToken, org string) *Client {
|
||||
config := DefaultConfig(authToken)
|
||||
config.OrgID = org
|
||||
return &Client{config}
|
||||
return NewClientWithConfig(config)
|
||||
}
|
||||
|
||||
func (c *Client) sendRequest(req *http.Request, v interface{}) error {
|
||||
@@ -90,7 +95,7 @@ func (c *Client) newStreamRequest(
|
||||
var reqBody []byte
|
||||
if body != nil {
|
||||
var err error
|
||||
reqBody, err = json.Marshal(body)
|
||||
reqBody, err = c.marshaller.marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
3
chat.go
3
chat.go
@@ -3,7 +3,6 @@ package openai
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
@@ -74,7 +73,7 @@ func (c *Client) CreateChatCompletion(
|
||||
}
|
||||
|
||||
var reqBytes []byte
|
||||
reqBytes, err = json.Marshal(request)
|
||||
reqBytes, err = c.marshaller.marshal(request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package openai
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
@@ -107,7 +106,7 @@ func (c *Client) CreateCompletion(
|
||||
}
|
||||
|
||||
var reqBytes []byte
|
||||
reqBytes, err = json.Marshal(request)
|
||||
reqBytes, err = c.marshaller.marshal(request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
3
edits.go
3
edits.go
@@ -3,7 +3,6 @@ package openai
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@@ -34,7 +33,7 @@ type EditsResponse struct {
|
||||
// Perform an API call to the Edits endpoint.
|
||||
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
|
||||
var reqBytes []byte
|
||||
reqBytes, err = json.Marshal(request)
|
||||
reqBytes, err = c.marshaller.marshal(request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package openai
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@@ -135,7 +134,7 @@ type EmbeddingRequest struct {
|
||||
// https://beta.openai.com/docs/api-reference/embeddings/create
|
||||
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
|
||||
var reqBytes []byte
|
||||
reqBytes, err = json.Marshal(request)
|
||||
reqBytes, err = c.marshaller.marshal(request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package openai
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
@@ -70,7 +69,7 @@ type FineTuneDeleteResponse struct {
|
||||
|
||||
func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) {
|
||||
var reqBytes []byte
|
||||
reqBytes, err = json.Marshal(request)
|
||||
reqBytes, err = c.marshaller.marshal(request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
3
image.go
3
image.go
@@ -3,7 +3,6 @@ package openai
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
@@ -47,7 +46,7 @@ type ImageResponseDataInner struct {
|
||||
// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
|
||||
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
|
||||
var reqBytes []byte
|
||||
reqBytes, err = json.Marshal(request)
|
||||
reqBytes, err = c.marshaller.marshal(request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
15
marshaller.go
Normal file
15
marshaller.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type marshaller interface {
|
||||
marshal(value any) ([]byte, error)
|
||||
}
|
||||
|
||||
type jsonMarshaller struct{}
|
||||
|
||||
func (jm *jsonMarshaller) marshal(value any) ([]byte, error) {
|
||||
return json.Marshal(value)
|
||||
}
|
||||
71
marshaller_test.go
Normal file
71
marshaller_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package openai //nolint:testpackage // testing private field
|
||||
|
||||
import (
|
||||
"github.com/sashabaranov/go-openai/internal/test"
|
||||
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type failingMarshaller struct{}
|
||||
|
||||
var errTestMarshallerFailed = errors.New("test marshaller failed")
|
||||
|
||||
func (jm *failingMarshaller) marshal(value any) ([]byte, error) {
|
||||
return []byte{}, errTestMarshallerFailed
|
||||
}
|
||||
|
||||
func TestClientReturnMarshallerErrors(t *testing.T) {
|
||||
var err error
|
||||
ts := test.NewTestServer().OpenAITestServer()
|
||||
ts.Start()
|
||||
defer ts.Close()
|
||||
|
||||
config := DefaultConfig(test.GetTestToken())
|
||||
config.BaseURL = ts.URL + "/v1"
|
||||
client := NewClientWithConfig(config)
|
||||
client.marshaller = &failingMarshaller{}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
_, err = client.CreateCompletion(ctx, CompletionRequest{})
|
||||
if !errors.Is(err, errTestMarshallerFailed) {
|
||||
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
|
||||
if !errors.Is(err, errTestMarshallerFailed) {
|
||||
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
|
||||
if !errors.Is(err, errTestMarshallerFailed) {
|
||||
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CreateFineTune(ctx, FineTuneRequest{})
|
||||
if !errors.Is(err, errTestMarshallerFailed) {
|
||||
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.Moderations(ctx, ModerationRequest{})
|
||||
if !errors.Is(err, errTestMarshallerFailed) {
|
||||
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.Edits(ctx, EditsRequest{})
|
||||
if !errors.Is(err, errTestMarshallerFailed) {
|
||||
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
|
||||
if !errors.Is(err, errTestMarshallerFailed) {
|
||||
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = client.CreateImage(ctx, ImageRequest{})
|
||||
if !errors.Is(err, errTestMarshallerFailed) {
|
||||
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package openai
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@@ -53,7 +52,7 @@ type ModerationResponse struct {
|
||||
// Input can be an array or slice but a string will reduce the complexity.
|
||||
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
|
||||
var reqBytes []byte
|
||||
reqBytes, err = json.Marshal(request)
|
||||
reqBytes, err = c.marshaller.marshal(request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user