diff --git a/audio.go b/audio.go index 9cc5c51..9db9298 100644 --- a/audio.go +++ b/audio.go @@ -4,8 +4,6 @@ import ( "bytes" "context" "fmt" - "io" - "mime/multipart" "net/http" "os" ) @@ -55,9 +53,9 @@ func (c *Client) callAudioAPI( endpointSuffix string, ) (response AudioResponse, err error) { var formBody bytes.Buffer - w := multipart.NewWriter(&formBody) + builder := c.createFormBuilder(&formBody) - if err = audioMultipartForm(request, w); err != nil { + if err = audioMultipartForm(request, builder); err != nil { return } @@ -66,7 +64,7 @@ func (c *Client) callAudioAPI( if err != nil { return } - req.Header.Add("Content-Type", w.FormDataContentType()) + req.Header.Add("Content-Type", builder.formDataContentType()) err = c.sendRequest(req, &response) return @@ -74,73 +72,47 @@ func (c *Client) callAudioAPI( // audioMultipartForm creates a form with audio file contents and the name of the model to use for // audio processing. -func audioMultipartForm(request AudioRequest, w *multipart.Writer) error { +func audioMultipartForm(request AudioRequest, b formBuilder) error { f, err := os.Open(request.FilePath) if err != nil { return fmt.Errorf("opening audio file: %w", err) } defer f.Close() - fw, err := w.CreateFormFile("file", f.Name()) + err = b.createFormFile("file", f) if err != nil { return fmt.Errorf("creating form file: %w", err) } - if _, err = io.Copy(fw, f); err != nil { - return fmt.Errorf("reading from opened audio file: %w", err) - } - - fw, err = w.CreateFormField("model") + err = b.writeField("model", request.Model) if err != nil { - return fmt.Errorf("creating form field: %w", err) - } - - modelName := bytes.NewReader([]byte(request.Model)) - if _, err = io.Copy(fw, modelName); err != nil { return fmt.Errorf("writing model name: %w", err) } // Create a form field for the prompt (if provided) if request.Prompt != "" { - fw, err = w.CreateFormField("prompt") + err = b.writeField("prompt", request.Prompt) if err != nil { - return fmt.Errorf("creating form field: %w", err) - } - - prompt := bytes.NewReader([]byte(request.Prompt)) - if _, err = io.Copy(fw, prompt); err != nil { return fmt.Errorf("writing prompt: %w", err) } } // Create a form field for the temperature (if provided) if request.Temperature != 0 { - fw, err = w.CreateFormField("temperature") + err = b.writeField("temperature", fmt.Sprintf("%.2f", request.Temperature)) if err != nil { - return fmt.Errorf("creating form field: %w", err) - } - - temperature := bytes.NewReader([]byte(fmt.Sprintf("%.2f", request.Temperature))) - if _, err = io.Copy(fw, temperature); err != nil { return fmt.Errorf("writing temperature: %w", err) } } // Create a form field for the language (if provided) if request.Language != "" { - fw, err = w.CreateFormField("language") + err = b.writeField("language", request.Language) if err != nil { - return fmt.Errorf("creating form field: %w", err) - } - - language := bytes.NewReader([]byte(request.Language)) - if _, err = io.Copy(fw, language); err != nil { return fmt.Errorf("writing language: %w", err) } } // Close the multipart writer - w.Close() - - return nil + return b.close() } diff --git a/audio_test.go b/audio_test.go index 0870848..a6dac21 100644 --- a/audio_test.go +++ b/audio_test.go @@ -1,8 +1,9 @@ -package openai_test +package openai //nolint:testpackage // testing private field import ( "bytes" "errors" + "fmt" "io" "mime" "mime/multipart" @@ -11,7 +12,6 @@ import ( "path/filepath" "strings" - . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" @@ -188,3 +188,47 @@ func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) { return } } + +func TestAudioWithFailingFormBuilder(t *testing.T) { + dir, cleanup := createTestDirectory(t) + defer cleanup() + path := filepath.Join(dir, "fake.mp3") + createTestFile(t, path) + + req := AudioRequest{ + FilePath: path, + Prompt: "test", + Temperature: 0.5, + Language: "en", + } + + mockFailedErr := fmt.Errorf("mock form builder fail") + mockBuilder := &mockFormBuilder{} + + mockBuilder.mockCreateFormFile = func(string, *os.File) error { + return mockFailedErr + } + err := audioMultipartForm(req, mockBuilder) + checks.ErrorIs(t, err, mockFailedErr, "audioMultipartForm should return error if form builder fails") + + mockBuilder.mockCreateFormFile = func(string, *os.File) error { + return nil + } + + var failForField string + mockBuilder.mockWriteField = func(fieldname, value string) error { + if fieldname == failForField { + return mockFailedErr + } + return nil + } + + failOn := []string{"model", "prompt", "temperature", "language"} + for _, failingField := range failOn { + failForField = failingField + mockFailedErr = fmt.Errorf("mock form builder fail on field %s", failingField) + + err = audioMultipartForm(req, mockBuilder) + checks.ErrorIs(t, err, mockFailedErr, "audioMultipartForm should return error if form builder fails") + } +} diff --git a/api.go b/client.go similarity index 95% rename from api.go rename to client.go index 2c978bc..c1f76d7 100644 --- a/api.go +++ b/client.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" "strings" ) @@ -12,7 +13,8 @@ import ( type Client struct { config ClientConfig - requestBuilder requestBuilder + requestBuilder requestBuilder + createFormBuilder func(io.Writer) formBuilder } // NewClient creates new OpenAI API client. @@ -26,6 +28,9 @@ func NewClientWithConfig(config ClientConfig) *Client { return &Client{ config: config, requestBuilder: newRequestBuilder(), + createFormBuilder: func(body io.Writer) formBuilder { + return newFormBuilder(body) + }, } } diff --git a/config.go b/config.go index 52e1efc..c800df1 100644 --- a/config.go +++ b/config.go @@ -64,3 +64,7 @@ func DefaultAzureConfig(apiKey, baseURL, engine string) ClientConfig { EmptyMessagesLimit: defaultEmptyMessagesLimit, } } + +func (ClientConfig) String() string { + return "" +} diff --git a/files_test.go b/files_test.go index 3e8dfc4..fbfe11c 100644 --- a/files_test.go +++ b/files_test.go @@ -30,7 +30,7 @@ func TestFileUpload(t *testing.T) { req := FileRequest{ FileName: "test.go", - FilePath: "api.go", + FilePath: "client.go", Purpose: "fine-tune", } _, err = client.CreateFile(ctx, req) diff --git a/form_builder.go b/form_builder.go new file mode 100644 index 0000000..7fbb164 --- /dev/null +++ b/form_builder.go @@ -0,0 +1,49 @@ +package openai + +import ( + "io" + "mime/multipart" + "os" +) + +type formBuilder interface { + createFormFile(fieldname string, file *os.File) error + writeField(fieldname, value string) error + close() error + formDataContentType() string +} + +type defaultFormBuilder struct { + writer *multipart.Writer +} + +func newFormBuilder(body io.Writer) *defaultFormBuilder { + return &defaultFormBuilder{ + writer: multipart.NewWriter(body), + } +} + +func (fb *defaultFormBuilder) createFormFile(fieldname string, file *os.File) error { + fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name()) + if err != nil { + return err + } + + _, err = io.Copy(fieldWriter, file) + if err != nil { + return err + } + return nil +} + +func (fb *defaultFormBuilder) writeField(fieldname, value string) error { + return fb.writer.WriteField(fieldname, value) +} + +func (fb *defaultFormBuilder) close() error { + return fb.writer.Close() +} + +func (fb *defaultFormBuilder) formDataContentType() string { + return fb.writer.FormDataContentType() +} diff --git a/image.go b/image.go index d10fde1..bda23f7 100644 --- a/image.go +++ b/image.go @@ -3,8 +3,6 @@ package openai import ( "bytes" "context" - "io" - "mime/multipart" "net/http" "os" "strconv" @@ -67,50 +65,46 @@ type ImageEditRequest struct { // CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API. func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) (response ImageResponse, err error) { body := &bytes.Buffer{} - writer := multipart.NewWriter(body) + builder := c.createFormBuilder(body) // image - image, err := writer.CreateFormFile("image", request.Image.Name()) - if err != nil { - return - } - _, err = io.Copy(image, request.Image) + err = builder.createFormFile("image", request.Image) if err != nil { return } // mask, it is optional if request.Mask != nil { - mask, err2 := writer.CreateFormFile("mask", request.Mask.Name()) - if err2 != nil { - return - } - _, err = io.Copy(mask, request.Mask) + err = builder.createFormFile("mask", request.Mask) if err != nil { return } } - err = writer.WriteField("prompt", request.Prompt) + err = builder.writeField("prompt", request.Prompt) if err != nil { return } - err = writer.WriteField("n", strconv.Itoa(request.N)) + err = builder.writeField("n", strconv.Itoa(request.N)) if err != nil { return } - err = writer.WriteField("size", request.Size) + err = builder.writeField("size", request.Size) if err != nil { return } - writer.Close() + err = builder.close() + if err != nil { + return + } + urlSuffix := "/images/edits" req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body) if err != nil { return } - req.Header.Set("Content-Type", writer.FormDataContentType()) + req.Header.Set("Content-Type", builder.formDataContentType()) err = c.sendRequest(req, &response) return } @@ -126,27 +120,27 @@ type ImageVariRequest struct { // Use abbreviations(vari for variation) because ci-lint has a single-line length limit ... func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) (response ImageResponse, err error) { body := &bytes.Buffer{} - writer := multipart.NewWriter(body) + builder := c.createFormBuilder(body) // image - image, err := writer.CreateFormFile("image", request.Image.Name()) - if err != nil { - return - } - _, err = io.Copy(image, request.Image) + err = builder.createFormFile("image", request.Image) if err != nil { return } - err = writer.WriteField("n", strconv.Itoa(request.N)) + err = builder.writeField("n", strconv.Itoa(request.N)) if err != nil { return } - err = writer.WriteField("size", request.Size) + err = builder.writeField("size", request.Size) if err != nil { return } - writer.Close() + err = builder.close() + if err != nil { + return + } + //https://platform.openai.com/docs/api-reference/images/create-variation urlSuffix := "/images/variations" req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body) @@ -154,7 +148,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) return } - req.Header.Set("Content-Type", writer.FormDataContentType()) + req.Header.Set("Content-Type", builder.formDataContentType()) err = c.sendRequest(req, &response) return } diff --git a/image_test.go b/image_test.go index 9917b78..34367b8 100644 --- a/image_test.go +++ b/image_test.go @@ -1,7 +1,6 @@ -package openai_test +package openai //nolint:testpackage // testing private field import ( - . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" @@ -259,3 +258,136 @@ func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { resBytes, _ = json.Marshal(responses) fmt.Fprintln(w, string(resBytes)) } + +type mockFormBuilder struct { + mockCreateFormFile func(string, *os.File) error + mockWriteField func(string, string) error + mockClose func() error +} + +func (fb *mockFormBuilder) createFormFile(fieldname string, file *os.File) error { + return fb.mockCreateFormFile(fieldname, file) +} + +func (fb *mockFormBuilder) writeField(fieldname, value string) error { + return fb.mockWriteField(fieldname, value) +} + +func (fb *mockFormBuilder) close() error { + return fb.mockClose() +} + +func (fb *mockFormBuilder) formDataContentType() string { + return "" +} + +func TestImageFormBuilderFailures(t *testing.T) { + config := DefaultConfig("") + config.BaseURL = "" + client := NewClientWithConfig(config) + + mockBuilder := &mockFormBuilder{} + client.createFormBuilder = func(io.Writer) formBuilder { + return mockBuilder + } + ctx := context.Background() + + req := ImageEditRequest{ + Mask: &os.File{}, + } + + mockFailedErr := fmt.Errorf("mock form builder fail") + mockBuilder.mockCreateFormFile = func(string, *os.File) error { + return mockFailedErr + } + _, err := client.CreateEditImage(ctx, req) + checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") + + mockBuilder.mockCreateFormFile = func(name string, file *os.File) error { + if name == "mask" { + return mockFailedErr + } + return nil + } + _, err = client.CreateEditImage(ctx, req) + checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") + + mockBuilder.mockCreateFormFile = func(name string, file *os.File) error { + return nil + } + + var failForField string + mockBuilder.mockWriteField = func(fieldname, value string) error { + if fieldname == failForField { + return mockFailedErr + } + return nil + } + + failForField = "prompt" + _, err = client.CreateEditImage(ctx, req) + checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") + + failForField = "n" + _, err = client.CreateEditImage(ctx, req) + checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") + + failForField = "size" + _, err = client.CreateEditImage(ctx, req) + checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") + + failForField = "" + mockBuilder.mockClose = func() error { + return mockFailedErr + } + _, err = client.CreateEditImage(ctx, req) + checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") +} + +func TestVariImageFormBuilderFailures(t *testing.T) { + config := DefaultConfig("") + config.BaseURL = "" + client := NewClientWithConfig(config) + + mockBuilder := &mockFormBuilder{} + client.createFormBuilder = func(io.Writer) formBuilder { + return mockBuilder + } + ctx := context.Background() + + req := ImageVariRequest{} + + mockFailedErr := fmt.Errorf("mock form builder fail") + mockBuilder.mockCreateFormFile = func(string, *os.File) error { + return mockFailedErr + } + _, err := client.CreateVariImage(ctx, req) + checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") + + mockBuilder.mockCreateFormFile = func(name string, file *os.File) error { + return nil + } + + var failForField string + mockBuilder.mockWriteField = func(fieldname, value string) error { + if fieldname == failForField { + return mockFailedErr + } + return nil + } + + failForField = "n" + _, err = client.CreateVariImage(ctx, req) + checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") + + failForField = "size" + _, err = client.CreateVariImage(ctx, req) + checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") + + failForField = "" + mockBuilder.mockClose = func() error { + return mockFailedErr + } + _, err = client.CreateVariImage(ctx, req) + checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") +}