add testable json marshaller (#161)
This commit is contained in:
13
api.go
13
api.go
@@ -11,17 +11,22 @@ import (
|
|||||||
// Client is OpenAI GPT-3 API client.
|
// Client is OpenAI GPT-3 API client.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
config ClientConfig
|
config ClientConfig
|
||||||
|
|
||||||
|
marshaller marshaller
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient creates new OpenAI API client.
|
// NewClient creates new OpenAI API client.
|
||||||
func NewClient(authToken string) *Client {
|
func NewClient(authToken string) *Client {
|
||||||
config := DefaultConfig(authToken)
|
config := DefaultConfig(authToken)
|
||||||
return &Client{config}
|
return NewClientWithConfig(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClientWithConfig creates new OpenAI API client for specified config.
|
// NewClientWithConfig creates new OpenAI API client for specified config.
|
||||||
func NewClientWithConfig(config ClientConfig) *Client {
|
func NewClientWithConfig(config ClientConfig) *Client {
|
||||||
return &Client{config}
|
return &Client{
|
||||||
|
config: config,
|
||||||
|
marshaller: &jsonMarshaller{},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOrgClient creates new OpenAI API client for specified Organization ID.
|
// NewOrgClient creates new OpenAI API client for specified Organization ID.
|
||||||
@@ -30,7 +35,7 @@ func NewClientWithConfig(config ClientConfig) *Client {
|
|||||||
func NewOrgClient(authToken, org string) *Client {
|
func NewOrgClient(authToken, org string) *Client {
|
||||||
config := DefaultConfig(authToken)
|
config := DefaultConfig(authToken)
|
||||||
config.OrgID = org
|
config.OrgID = org
|
||||||
return &Client{config}
|
return NewClientWithConfig(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) sendRequest(req *http.Request, v interface{}) error {
|
func (c *Client) sendRequest(req *http.Request, v interface{}) error {
|
||||||
@@ -90,7 +95,7 @@ func (c *Client) newStreamRequest(
|
|||||||
var reqBody []byte
|
var reqBody []byte
|
||||||
if body != nil {
|
if body != nil {
|
||||||
var err error
|
var err error
|
||||||
reqBody, err = json.Marshal(body)
|
reqBody, err = c.marshaller.marshal(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
3
chat.go
3
chat.go
@@ -3,7 +3,6 @@ package openai
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
@@ -74,7 +73,7 @@ func (c *Client) CreateChatCompletion(
|
|||||||
}
|
}
|
||||||
|
|
||||||
var reqBytes []byte
|
var reqBytes []byte
|
||||||
reqBytes, err = json.Marshal(request)
|
reqBytes, err = c.marshaller.marshal(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package openai
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
@@ -107,7 +106,7 @@ func (c *Client) CreateCompletion(
|
|||||||
}
|
}
|
||||||
|
|
||||||
var reqBytes []byte
|
var reqBytes []byte
|
||||||
reqBytes, err = json.Marshal(request)
|
reqBytes, err = c.marshaller.marshal(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
3
edits.go
3
edits.go
@@ -3,7 +3,6 @@ package openai
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -34,7 +33,7 @@ type EditsResponse struct {
|
|||||||
// Perform an API call to the Edits endpoint.
|
// Perform an API call to the Edits endpoint.
|
||||||
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
|
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
|
||||||
var reqBytes []byte
|
var reqBytes []byte
|
||||||
reqBytes, err = json.Marshal(request)
|
reqBytes, err = c.marshaller.marshal(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package openai
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -135,7 +134,7 @@ type EmbeddingRequest struct {
|
|||||||
// https://beta.openai.com/docs/api-reference/embeddings/create
|
// https://beta.openai.com/docs/api-reference/embeddings/create
|
||||||
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
|
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
|
||||||
var reqBytes []byte
|
var reqBytes []byte
|
||||||
reqBytes, err = json.Marshal(request)
|
reqBytes, err = c.marshaller.marshal(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package openai
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
@@ -70,7 +69,7 @@ type FineTuneDeleteResponse struct {
|
|||||||
|
|
||||||
func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) {
|
func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) {
|
||||||
var reqBytes []byte
|
var reqBytes []byte
|
||||||
reqBytes, err = json.Marshal(request)
|
reqBytes, err = c.marshaller.marshal(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
3
image.go
3
image.go
@@ -3,7 +3,6 @@ package openai
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -47,7 +46,7 @@ type ImageResponseDataInner struct {
|
|||||||
// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
|
// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
|
||||||
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
|
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
|
||||||
var reqBytes []byte
|
var reqBytes []byte
|
||||||
reqBytes, err = json.Marshal(request)
|
reqBytes, err = c.marshaller.marshal(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
15
marshaller.go
Normal file
15
marshaller.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type marshaller interface {
|
||||||
|
marshal(value any) ([]byte, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type jsonMarshaller struct{}
|
||||||
|
|
||||||
|
func (jm *jsonMarshaller) marshal(value any) ([]byte, error) {
|
||||||
|
return json.Marshal(value)
|
||||||
|
}
|
||||||
71
marshaller_test.go
Normal file
71
marshaller_test.go
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
package openai //nolint:testpackage // testing private field
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
|
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type failingMarshaller struct{}
|
||||||
|
|
||||||
|
var errTestMarshallerFailed = errors.New("test marshaller failed")
|
||||||
|
|
||||||
|
func (jm *failingMarshaller) marshal(value any) ([]byte, error) {
|
||||||
|
return []byte{}, errTestMarshallerFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientReturnMarshallerErrors(t *testing.T) {
|
||||||
|
var err error
|
||||||
|
ts := test.NewTestServer().OpenAITestServer()
|
||||||
|
ts.Start()
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
config := DefaultConfig(test.GetTestToken())
|
||||||
|
config.BaseURL = ts.URL + "/v1"
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
|
client.marshaller = &failingMarshaller{}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
_, err = client.CreateCompletion(ctx, CompletionRequest{})
|
||||||
|
if !errors.Is(err, errTestMarshallerFailed) {
|
||||||
|
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
|
||||||
|
if !errors.Is(err, errTestMarshallerFailed) {
|
||||||
|
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
|
||||||
|
if !errors.Is(err, errTestMarshallerFailed) {
|
||||||
|
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.CreateFineTune(ctx, FineTuneRequest{})
|
||||||
|
if !errors.Is(err, errTestMarshallerFailed) {
|
||||||
|
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.Moderations(ctx, ModerationRequest{})
|
||||||
|
if !errors.Is(err, errTestMarshallerFailed) {
|
||||||
|
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.Edits(ctx, EditsRequest{})
|
||||||
|
if !errors.Is(err, errTestMarshallerFailed) {
|
||||||
|
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
|
||||||
|
if !errors.Is(err, errTestMarshallerFailed) {
|
||||||
|
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = client.CreateImage(ctx, ImageRequest{})
|
||||||
|
if !errors.Is(err, errTestMarshallerFailed) {
|
||||||
|
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,7 +3,6 @@ package openai
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -53,7 +52,7 @@ type ModerationResponse struct {
|
|||||||
// Input can be an array or slice but a string will reduce the complexity.
|
// Input can be an array or slice but a string will reduce the complexity.
|
||||||
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
|
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
|
||||||
var reqBytes []byte
|
var reqBytes []byte
|
||||||
reqBytes, err = json.Marshal(request)
|
reqBytes, err = c.marshaller.marshal(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user