From 8c65b35c57ad4e9ba408def9bf9ff97817aab932 Mon Sep 17 00:00:00 2001 From: Axb12 <67110563+Axb12@users.noreply.github.com> Date: Tue, 20 May 2025 21:45:40 +0800 Subject: [PATCH] update image api *os.File to io.Reader (#994) * update image api *os.File to io.Reader * update code style * add reader test * supplementary reader test * update the reader in the form builder test * add commnet * update comment * update code style --- image.go | 43 ++++++++++++++++++----------------- image_test.go | 8 +++---- internal/form_builder.go | 35 ++++++++++++++++++++++++++-- internal/form_builder_test.go | 29 +++++++++++++++++++++++ 4 files changed, 88 insertions(+), 27 deletions(-) diff --git a/image.go b/image.go index d62622a..72077ce 100644 --- a/image.go +++ b/image.go @@ -3,8 +3,8 @@ package openai import ( "bytes" "context" + "io" "net/http" - "os" "strconv" ) @@ -134,15 +134,15 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons // ImageEditRequest represents the request structure for the image API. type ImageEditRequest struct { - Image *os.File `json:"image,omitempty"` - Mask *os.File `json:"mask,omitempty"` - Prompt string `json:"prompt,omitempty"` - Model string `json:"model,omitempty"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - Quality string `json:"quality,omitempty"` - User string `json:"user,omitempty"` + Image io.Reader `json:"image,omitempty"` + Mask io.Reader `json:"mask,omitempty"` + Prompt string `json:"prompt,omitempty"` + Model string `json:"model,omitempty"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + Quality string `json:"quality,omitempty"` + User string `json:"user,omitempty"` } // CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API. @@ -150,15 +150,16 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) body := &bytes.Buffer{} builder := c.createFormBuilder(body) - // image - err = builder.CreateFormFile("image", request.Image) + // image, filename is not required + err = builder.CreateFormFileReader("image", request.Image, "") if err != nil { return } // mask, it is optional if request.Mask != nil { - err = builder.CreateFormFile("mask", request.Mask) + // mask, filename is not required + err = builder.CreateFormFileReader("mask", request.Mask, "") if err != nil { return } @@ -206,12 +207,12 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) // ImageVariRequest represents the request structure for the image API. type ImageVariRequest struct { - Image *os.File `json:"image,omitempty"` - Model string `json:"model,omitempty"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - User string `json:"user,omitempty"` + Image io.Reader `json:"image,omitempty"` + Model string `json:"model,omitempty"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + User string `json:"user,omitempty"` } // CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API. @@ -220,8 +221,8 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) body := &bytes.Buffer{} builder := c.createFormBuilder(body) - // image - err = builder.CreateFormFile("image", request.Image) + // image, filename is not required + err = builder.CreateFormFileReader("image", request.Image, "") if err != nil { return } diff --git a/image_test.go b/image_test.go index 9332dd5..6440055 100644 --- a/image_test.go +++ b/image_test.go @@ -54,13 +54,13 @@ func TestImageFormBuilderFailures(t *testing.T) { } mockFailedErr := fmt.Errorf("mock form builder fail") - mockBuilder.mockCreateFormFile = func(string, *os.File) error { + mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { return mockFailedErr } _, err := client.CreateEditImage(ctx, req) checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(name string, _ *os.File) error { + mockBuilder.mockCreateFormFileReader = func(name string, _ io.Reader, _ string) error { if name == "mask" { return mockFailedErr } @@ -119,13 +119,13 @@ func TestVariImageFormBuilderFailures(t *testing.T) { req := ImageVariRequest{} mockFailedErr := fmt.Errorf("mock form builder fail") - mockBuilder.mockCreateFormFile = func(string, *os.File) error { + mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { return mockFailedErr } _, err := client.CreateVariImage(ctx, req) checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(string, *os.File) error { + mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } diff --git a/internal/form_builder.go b/internal/form_builder.go index 2224fad..1c2513d 100644 --- a/internal/form_builder.go +++ b/internal/form_builder.go @@ -4,8 +4,10 @@ import ( "fmt" "io" "mime/multipart" + "net/textproto" "os" - "path" + "path/filepath" + "strings" ) type FormBuilder interface { @@ -30,8 +32,37 @@ func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) er return fb.createFormFile(fieldname, file, file.Name()) } +var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") + +func escapeQuotes(s string) string { + return quoteEscaper.Replace(s) +} + +// CreateFormFileReader creates a form field with a file reader. +// The filename in parameters can be an empty string. +// The filename in Content-Disposition is required, But it can be an empty string. func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error { - return fb.createFormFile(fieldname, r, path.Base(filename)) + h := make(textproto.MIMEHeader) + h.Set( + "Content-Disposition", + fmt.Sprintf( + `form-data; name="%s"; filename="%s"`, + escapeQuotes(fieldname), + escapeQuotes(filepath.Base(filename)), + ), + ) + + fieldWriter, err := fb.writer.CreatePart(h) + if err != nil { + return err + } + + _, err = io.Copy(fieldWriter, r) + if err != nil { + return err + } + + return nil } func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error { diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go index 8df989e..76922c1 100644 --- a/internal/form_builder_test.go +++ b/internal/form_builder_test.go @@ -43,3 +43,32 @@ func TestFormBuilderWithClosedFile(t *testing.T) { checks.HasError(t, err, "formbuilder should return error if file is closed") checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed") } + +type failingReader struct { +} + +var errMockFailingReaderError = errors.New("mock reader failed") + +func (*failingReader) Read([]byte) (int, error) { + return 0, errMockFailingReaderError +} + +func TestFormBuilderWithReader(t *testing.T) { + file, err := os.CreateTemp(t.TempDir(), "") + if err != nil { + t.Fatalf("Error creating tmp file: %v", err) + } + defer file.Close() + builder := NewFormBuilder(&failingWriter{}) + err = builder.CreateFormFileReader("file", file, file.Name()) + checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails") + + builder = NewFormBuilder(&bytes.Buffer{}) + reader := &failingReader{} + err = builder.CreateFormFileReader("file", reader, "") + checks.ErrorIs(t, err, errMockFailingReaderError, "formbuilder should return error if copy reader fails") + + successReader := &bytes.Buffer{} + err = builder.CreateFormFileReader("file", successReader, "") + checks.NoError(t, err, "formbuilder should not return error") +}