add testable json marshaller (#161)

This commit is contained in:
sashabaranov
2023-03-15 12:16:47 +04:00
committed by GitHub
parent ba77a6476e
commit 53d195cf5a
10 changed files with 102 additions and 18 deletions

13
api.go
View File

@@ -11,17 +11,22 @@ import (
// Client is OpenAI GPT-3 API client.
type Client struct {
config ClientConfig
marshaller marshaller
}
// NewClient creates new OpenAI API client.
func NewClient(authToken string) *Client {
config := DefaultConfig(authToken)
return &Client{config}
return NewClientWithConfig(config)
}
// NewClientWithConfig creates new OpenAI API client for specified config.
func NewClientWithConfig(config ClientConfig) *Client {
return &Client{config}
return &Client{
config: config,
marshaller: &jsonMarshaller{},
}
}
// NewOrgClient creates new OpenAI API client for specified Organization ID.
@@ -30,7 +35,7 @@ func NewClientWithConfig(config ClientConfig) *Client {
func NewOrgClient(authToken, org string) *Client {
config := DefaultConfig(authToken)
config.OrgID = org
return &Client{config}
return NewClientWithConfig(config)
}
func (c *Client) sendRequest(req *http.Request, v interface{}) error {
@@ -90,7 +95,7 @@ func (c *Client) newStreamRequest(
var reqBody []byte
if body != nil {
var err error
reqBody, err = json.Marshal(body)
reqBody, err = c.marshaller.marshal(body)
if err != nil {
return nil, err
}

View File

@@ -3,7 +3,6 @@ package openai
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
)
@@ -74,7 +73,7 @@ func (c *Client) CreateChatCompletion(
}
var reqBytes []byte
reqBytes, err = json.Marshal(request)
reqBytes, err = c.marshaller.marshal(request)
if err != nil {
return
}

View File

@@ -3,7 +3,6 @@ package openai
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
)
@@ -107,7 +106,7 @@ func (c *Client) CreateCompletion(
}
var reqBytes []byte
reqBytes, err = json.Marshal(request)
reqBytes, err = c.marshaller.marshal(request)
if err != nil {
return
}

View File

@@ -3,7 +3,6 @@ package openai
import (
"bytes"
"context"
"encoding/json"
"net/http"
)
@@ -34,7 +33,7 @@ type EditsResponse struct {
// Perform an API call to the Edits endpoint.
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
var reqBytes []byte
reqBytes, err = json.Marshal(request)
reqBytes, err = c.marshaller.marshal(request)
if err != nil {
return
}

View File

@@ -3,7 +3,6 @@ package openai
import (
"bytes"
"context"
"encoding/json"
"net/http"
)
@@ -135,7 +134,7 @@ type EmbeddingRequest struct {
// https://beta.openai.com/docs/api-reference/embeddings/create
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
var reqBytes []byte
reqBytes, err = json.Marshal(request)
reqBytes, err = c.marshaller.marshal(request)
if err != nil {
return
}

View File

@@ -3,7 +3,6 @@ package openai
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
)
@@ -70,7 +69,7 @@ type FineTuneDeleteResponse struct {
func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) {
var reqBytes []byte
reqBytes, err = json.Marshal(request)
reqBytes, err = c.marshaller.marshal(request)
if err != nil {
return
}

View File

@@ -3,7 +3,6 @@ package openai
import (
"bytes"
"context"
"encoding/json"
"io"
"mime/multipart"
"net/http"
@@ -47,7 +46,7 @@ type ImageResponseDataInner struct {
// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
var reqBytes []byte
reqBytes, err = json.Marshal(request)
reqBytes, err = c.marshaller.marshal(request)
if err != nil {
return
}

15
marshaller.go Normal file
View File

@@ -0,0 +1,15 @@
package openai
import (
"encoding/json"
)
type marshaller interface {
marshal(value any) ([]byte, error)
}
type jsonMarshaller struct{}
func (jm *jsonMarshaller) marshal(value any) ([]byte, error) {
return json.Marshal(value)
}

71
marshaller_test.go Normal file
View File

@@ -0,0 +1,71 @@
package openai //nolint:testpackage // testing private field
import (
"github.com/sashabaranov/go-openai/internal/test"
"context"
"errors"
"testing"
)
type failingMarshaller struct{}
var errTestMarshallerFailed = errors.New("test marshaller failed")
func (jm *failingMarshaller) marshal(value any) ([]byte, error) {
return []byte{}, errTestMarshallerFailed
}
func TestClientReturnMarshallerErrors(t *testing.T) {
var err error
ts := test.NewTestServer().OpenAITestServer()
ts.Start()
defer ts.Close()
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
client.marshaller = &failingMarshaller{}
ctx := context.Background()
_, err = client.CreateCompletion(ctx, CompletionRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
_, err = client.CreateFineTune(ctx, FineTuneRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
_, err = client.Moderations(ctx, ModerationRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
_, err = client.Edits(ctx, EditsRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
_, err = client.CreateImage(ctx, ImageRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
}

View File

@@ -3,7 +3,6 @@ package openai
import (
"bytes"
"context"
"encoding/json"
"net/http"
)
@@ -53,7 +52,7 @@ type ModerationResponse struct {
// Input can be an array or slice but a string will reduce the complexity.
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
var reqBytes []byte
reqBytes, err = json.Marshal(request)
reqBytes, err = c.marshaller.marshal(request)
if err != nil {
return
}