add testable json marshaller (#161)

This commit is contained in:
sashabaranov
2023-03-15 12:16:47 +04:00
committed by GitHub
parent ba77a6476e
commit 53d195cf5a
10 changed files with 102 additions and 18 deletions

13
api.go
View File

@@ -11,17 +11,22 @@ import (
// Client is OpenAI GPT-3 API client. // Client is OpenAI GPT-3 API client.
type Client struct { type Client struct {
config ClientConfig config ClientConfig
marshaller marshaller
} }
// NewClient creates new OpenAI API client. // NewClient creates new OpenAI API client.
func NewClient(authToken string) *Client { func NewClient(authToken string) *Client {
config := DefaultConfig(authToken) config := DefaultConfig(authToken)
return &Client{config} return NewClientWithConfig(config)
} }
// NewClientWithConfig creates new OpenAI API client for specified config. // NewClientWithConfig creates new OpenAI API client for specified config.
func NewClientWithConfig(config ClientConfig) *Client { func NewClientWithConfig(config ClientConfig) *Client {
return &Client{config} return &Client{
config: config,
marshaller: &jsonMarshaller{},
}
} }
// NewOrgClient creates new OpenAI API client for specified Organization ID. // NewOrgClient creates new OpenAI API client for specified Organization ID.
@@ -30,7 +35,7 @@ func NewClientWithConfig(config ClientConfig) *Client {
func NewOrgClient(authToken, org string) *Client { func NewOrgClient(authToken, org string) *Client {
config := DefaultConfig(authToken) config := DefaultConfig(authToken)
config.OrgID = org config.OrgID = org
return &Client{config} return NewClientWithConfig(config)
} }
func (c *Client) sendRequest(req *http.Request, v interface{}) error { func (c *Client) sendRequest(req *http.Request, v interface{}) error {
@@ -90,7 +95,7 @@ func (c *Client) newStreamRequest(
var reqBody []byte var reqBody []byte
if body != nil { if body != nil {
var err error var err error
reqBody, err = json.Marshal(body) reqBody, err = c.marshaller.marshal(body)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -3,7 +3,6 @@ package openai
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"errors" "errors"
"net/http" "net/http"
) )
@@ -74,7 +73,7 @@ func (c *Client) CreateChatCompletion(
} }
var reqBytes []byte var reqBytes []byte
reqBytes, err = json.Marshal(request) reqBytes, err = c.marshaller.marshal(request)
if err != nil { if err != nil {
return return
} }

View File

@@ -3,7 +3,6 @@ package openai
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"errors" "errors"
"net/http" "net/http"
) )
@@ -107,7 +106,7 @@ func (c *Client) CreateCompletion(
} }
var reqBytes []byte var reqBytes []byte
reqBytes, err = json.Marshal(request) reqBytes, err = c.marshaller.marshal(request)
if err != nil { if err != nil {
return return
} }

View File

@@ -3,7 +3,6 @@ package openai
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"net/http" "net/http"
) )
@@ -34,7 +33,7 @@ type EditsResponse struct {
// Perform an API call to the Edits endpoint. // Perform an API call to the Edits endpoint.
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
var reqBytes []byte var reqBytes []byte
reqBytes, err = json.Marshal(request) reqBytes, err = c.marshaller.marshal(request)
if err != nil { if err != nil {
return return
} }

View File

@@ -3,7 +3,6 @@ package openai
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"net/http" "net/http"
) )
@@ -135,7 +134,7 @@ type EmbeddingRequest struct {
// https://beta.openai.com/docs/api-reference/embeddings/create // https://beta.openai.com/docs/api-reference/embeddings/create
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) { func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
var reqBytes []byte var reqBytes []byte
reqBytes, err = json.Marshal(request) reqBytes, err = c.marshaller.marshal(request)
if err != nil { if err != nil {
return return
} }

View File

@@ -3,7 +3,6 @@ package openai
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
) )
@@ -70,7 +69,7 @@ type FineTuneDeleteResponse struct {
func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) { func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) {
var reqBytes []byte var reqBytes []byte
reqBytes, err = json.Marshal(request) reqBytes, err = c.marshaller.marshal(request)
if err != nil { if err != nil {
return return
} }

View File

@@ -3,7 +3,6 @@ package openai
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"io" "io"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
@@ -47,7 +46,7 @@ type ImageResponseDataInner struct {
// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
var reqBytes []byte var reqBytes []byte
reqBytes, err = json.Marshal(request) reqBytes, err = c.marshaller.marshal(request)
if err != nil { if err != nil {
return return
} }

15
marshaller.go Normal file
View File

@@ -0,0 +1,15 @@
package openai
import (
"encoding/json"
)
type marshaller interface {
marshal(value any) ([]byte, error)
}
type jsonMarshaller struct{}
func (jm *jsonMarshaller) marshal(value any) ([]byte, error) {
return json.Marshal(value)
}

71
marshaller_test.go Normal file
View File

@@ -0,0 +1,71 @@
package openai //nolint:testpackage // testing private field
import (
"github.com/sashabaranov/go-openai/internal/test"
"context"
"errors"
"testing"
)
type failingMarshaller struct{}
var errTestMarshallerFailed = errors.New("test marshaller failed")
func (jm *failingMarshaller) marshal(value any) ([]byte, error) {
return []byte{}, errTestMarshallerFailed
}
func TestClientReturnMarshallerErrors(t *testing.T) {
var err error
ts := test.NewTestServer().OpenAITestServer()
ts.Start()
defer ts.Close()
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
client.marshaller = &failingMarshaller{}
ctx := context.Background()
_, err = client.CreateCompletion(ctx, CompletionRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
_, err = client.CreateFineTune(ctx, FineTuneRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
_, err = client.Moderations(ctx, ModerationRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
_, err = client.Edits(ctx, EditsRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
_, err = client.CreateImage(ctx, ImageRequest{})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
}

View File

@@ -3,7 +3,6 @@ package openai
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"net/http" "net/http"
) )
@@ -53,7 +52,7 @@ type ModerationResponse struct {
// Input can be an array or slice but a string will reduce the complexity. // Input can be an array or slice but a string will reduce the complexity.
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) { func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
var reqBytes []byte var reqBytes []byte
reqBytes, err = json.Marshal(request) reqBytes, err = c.marshaller.marshal(request)
if err != nil { if err != nil {
return return
} }