Move form_builder into internal pkg. (#311)

* Move form_uilder into internal pkg.

* Fix import of audio.go

* Reorganize.

* Fix import.

* Fix

---------

Co-authored-by: JoyShi <joy.shi@sap.com>
This commit is contained in:
JoyShi
2023-05-17 04:38:09 +08:00
committed by GitHub
parent 83d03fca52
commit 21eef5bc8d
9 changed files with 96 additions and 90 deletions

View File

@@ -6,6 +6,8 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"os" "os"
utils "github.com/sashabaranov/go-openai/internal"
) )
// Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI. // Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI.
@@ -72,7 +74,7 @@ func (c *Client) callAudioAPI(
if err != nil { if err != nil {
return AudioResponse{}, err return AudioResponse{}, err
} }
req.Header.Add("Content-Type", builder.formDataContentType()) req.Header.Add("Content-Type", builder.FormDataContentType())
if request.HasJSONResponse() { if request.HasJSONResponse() {
err = c.sendRequest(req, &response) err = c.sendRequest(req, &response)
@@ -92,26 +94,26 @@ func (r AudioRequest) HasJSONResponse() bool {
// audioMultipartForm creates a form with audio file contents and the name of the model to use for // audioMultipartForm creates a form with audio file contents and the name of the model to use for
// audio processing. // audio processing.
func audioMultipartForm(request AudioRequest, b formBuilder) error { func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error {
f, err := os.Open(request.FilePath) f, err := os.Open(request.FilePath)
if err != nil { if err != nil {
return fmt.Errorf("opening audio file: %w", err) return fmt.Errorf("opening audio file: %w", err)
} }
defer f.Close() defer f.Close()
err = b.createFormFile("file", f) err = b.CreateFormFile("file", f)
if err != nil { if err != nil {
return fmt.Errorf("creating form file: %w", err) return fmt.Errorf("creating form file: %w", err)
} }
err = b.writeField("model", request.Model) err = b.WriteField("model", request.Model)
if err != nil { if err != nil {
return fmt.Errorf("writing model name: %w", err) return fmt.Errorf("writing model name: %w", err)
} }
// Create a form field for the prompt (if provided) // Create a form field for the prompt (if provided)
if request.Prompt != "" { if request.Prompt != "" {
err = b.writeField("prompt", request.Prompt) err = b.WriteField("prompt", request.Prompt)
if err != nil { if err != nil {
return fmt.Errorf("writing prompt: %w", err) return fmt.Errorf("writing prompt: %w", err)
} }
@@ -119,7 +121,7 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error {
// Create a form field for the format (if provided) // Create a form field for the format (if provided)
if request.Format != "" { if request.Format != "" {
err = b.writeField("response_format", string(request.Format)) err = b.WriteField("response_format", string(request.Format))
if err != nil { if err != nil {
return fmt.Errorf("writing format: %w", err) return fmt.Errorf("writing format: %w", err)
} }
@@ -127,7 +129,7 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error {
// Create a form field for the temperature (if provided) // Create a form field for the temperature (if provided)
if request.Temperature != 0 { if request.Temperature != 0 {
err = b.writeField("temperature", fmt.Sprintf("%.2f", request.Temperature)) err = b.WriteField("temperature", fmt.Sprintf("%.2f", request.Temperature))
if err != nil { if err != nil {
return fmt.Errorf("writing temperature: %w", err) return fmt.Errorf("writing temperature: %w", err)
} }
@@ -135,12 +137,12 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error {
// Create a form field for the language (if provided) // Create a form field for the language (if provided)
if request.Language != "" { if request.Language != "" {
err = b.writeField("language", request.Language) err = b.WriteField("language", request.Language)
if err != nil { if err != nil {
return fmt.Errorf("writing language: %w", err) return fmt.Errorf("writing language: %w", err)
} }
} }
// Close the multipart writer // Close the multipart writer
return b.close() return b.Close()
} }

View File

@@ -7,6 +7,8 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
utils "github.com/sashabaranov/go-openai/internal"
) )
// Client is OpenAI GPT-3 API client. // Client is OpenAI GPT-3 API client.
@@ -14,7 +16,7 @@ type Client struct {
config ClientConfig config ClientConfig
requestBuilder requestBuilder requestBuilder requestBuilder
createFormBuilder func(io.Writer) formBuilder createFormBuilder func(io.Writer) utils.FormBuilder
} }
// NewClient creates new OpenAI API client. // NewClient creates new OpenAI API client.
@@ -28,8 +30,8 @@ func NewClientWithConfig(config ClientConfig) *Client {
return &Client{ return &Client{
config: config, config: config,
requestBuilder: newRequestBuilder(), requestBuilder: newRequestBuilder(),
createFormBuilder: func(body io.Writer) formBuilder { createFormBuilder: func(body io.Writer) utils.FormBuilder {
return newFormBuilder(body) return utils.NewFormBuilder(body)
}, },
} }
} }

View File

@@ -36,7 +36,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
var b bytes.Buffer var b bytes.Buffer
builder := c.createFormBuilder(&b) builder := c.createFormBuilder(&b)
err = builder.writeField("purpose", request.Purpose) err = builder.WriteField("purpose", request.Purpose)
if err != nil { if err != nil {
return return
} }
@@ -46,12 +46,12 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
return return
} }
err = builder.createFormFile("file", fileData) err = builder.CreateFormFile("file", fileData)
if err != nil { if err != nil {
return return
} }
err = builder.close() err = builder.Close()
if err != nil { if err != nil {
return return
} }
@@ -61,7 +61,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
return return
} }
req.Header.Set("Content-Type", builder.formDataContentType()) req.Header.Set("Content-Type", builder.FormDataContentType())
err = c.sendRequest(req, &file) err = c.sendRequest(req, &file)

View File

@@ -1,6 +1,7 @@
package openai //nolint:testpackage // testing private field package openai //nolint:testpackage // testing private field
import ( import (
. "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/internal/test/checks"
@@ -85,7 +86,7 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) {
config.BaseURL = "" config.BaseURL = ""
client := NewClientWithConfig(config) client := NewClientWithConfig(config)
mockBuilder := &mockFormBuilder{} mockBuilder := &mockFormBuilder{}
client.createFormBuilder = func(io.Writer) formBuilder { client.createFormBuilder = func(io.Writer) FormBuilder {
return mockBuilder return mockBuilder
} }

View File

@@ -1,49 +0,0 @@
package openai
import (
"io"
"mime/multipart"
"os"
)
type formBuilder interface {
createFormFile(fieldname string, file *os.File) error
writeField(fieldname, value string) error
close() error
formDataContentType() string
}
type defaultFormBuilder struct {
writer *multipart.Writer
}
func newFormBuilder(body io.Writer) *defaultFormBuilder {
return &defaultFormBuilder{
writer: multipart.NewWriter(body),
}
}
func (fb *defaultFormBuilder) createFormFile(fieldname string, file *os.File) error {
fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name())
if err != nil {
return err
}
_, err = io.Copy(fieldWriter, file)
if err != nil {
return err
}
return nil
}
func (fb *defaultFormBuilder) writeField(fieldname, value string) error {
return fb.writer.WriteField(fieldname, value)
}
func (fb *defaultFormBuilder) close() error {
return fb.writer.Close()
}
func (fb *defaultFormBuilder) formDataContentType() string {
return fb.writer.FormDataContentType()
}

View File

@@ -69,40 +69,40 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
builder := c.createFormBuilder(body) builder := c.createFormBuilder(body)
// image // image
err = builder.createFormFile("image", request.Image) err = builder.CreateFormFile("image", request.Image)
if err != nil { if err != nil {
return return
} }
// mask, it is optional // mask, it is optional
if request.Mask != nil { if request.Mask != nil {
err = builder.createFormFile("mask", request.Mask) err = builder.CreateFormFile("mask", request.Mask)
if err != nil { if err != nil {
return return
} }
} }
err = builder.writeField("prompt", request.Prompt) err = builder.WriteField("prompt", request.Prompt)
if err != nil { if err != nil {
return return
} }
err = builder.writeField("n", strconv.Itoa(request.N)) err = builder.WriteField("n", strconv.Itoa(request.N))
if err != nil { if err != nil {
return return
} }
err = builder.writeField("size", request.Size) err = builder.WriteField("size", request.Size)
if err != nil { if err != nil {
return return
} }
err = builder.writeField("response_format", request.ResponseFormat) err = builder.WriteField("response_format", request.ResponseFormat)
if err != nil { if err != nil {
return return
} }
err = builder.close() err = builder.Close()
if err != nil { if err != nil {
return return
} }
@@ -113,7 +113,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
return return
} }
req.Header.Set("Content-Type", builder.formDataContentType()) req.Header.Set("Content-Type", builder.FormDataContentType())
err = c.sendRequest(req, &response) err = c.sendRequest(req, &response)
return return
} }
@@ -133,27 +133,27 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
builder := c.createFormBuilder(body) builder := c.createFormBuilder(body)
// image // image
err = builder.createFormFile("image", request.Image) err = builder.CreateFormFile("image", request.Image)
if err != nil { if err != nil {
return return
} }
err = builder.writeField("n", strconv.Itoa(request.N)) err = builder.WriteField("n", strconv.Itoa(request.N))
if err != nil { if err != nil {
return return
} }
err = builder.writeField("size", request.Size) err = builder.WriteField("size", request.Size)
if err != nil { if err != nil {
return return
} }
err = builder.writeField("response_format", request.ResponseFormat) err = builder.WriteField("response_format", request.ResponseFormat)
if err != nil { if err != nil {
return return
} }
err = builder.close() err = builder.Close()
if err != nil { if err != nil {
return return
} }
@@ -165,7 +165,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
return return
} }
req.Header.Set("Content-Type", builder.formDataContentType()) req.Header.Set("Content-Type", builder.FormDataContentType())
err = c.sendRequest(req, &response) err = c.sendRequest(req, &response)
return return
} }

View File

@@ -1,6 +1,7 @@
package openai //nolint:testpackage // testing private field package openai //nolint:testpackage // testing private field
import ( import (
utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/internal/test/checks"
@@ -268,19 +269,19 @@ type mockFormBuilder struct {
mockClose func() error mockClose func() error
} }
func (fb *mockFormBuilder) createFormFile(fieldname string, file *os.File) error { func (fb *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
return fb.mockCreateFormFile(fieldname, file) return fb.mockCreateFormFile(fieldname, file)
} }
func (fb *mockFormBuilder) writeField(fieldname, value string) error { func (fb *mockFormBuilder) WriteField(fieldname, value string) error {
return fb.mockWriteField(fieldname, value) return fb.mockWriteField(fieldname, value)
} }
func (fb *mockFormBuilder) close() error { func (fb *mockFormBuilder) Close() error {
return fb.mockClose() return fb.mockClose()
} }
func (fb *mockFormBuilder) formDataContentType() string { func (fb *mockFormBuilder) FormDataContentType() string {
return "" return ""
} }
@@ -290,7 +291,7 @@ func TestImageFormBuilderFailures(t *testing.T) {
client := NewClientWithConfig(config) client := NewClientWithConfig(config)
mockBuilder := &mockFormBuilder{} mockBuilder := &mockFormBuilder{}
client.createFormBuilder = func(io.Writer) formBuilder { client.createFormBuilder = func(io.Writer) utils.FormBuilder {
return mockBuilder return mockBuilder
} }
ctx := context.Background() ctx := context.Background()
@@ -357,7 +358,7 @@ func TestVariImageFormBuilderFailures(t *testing.T) {
client := NewClientWithConfig(config) client := NewClientWithConfig(config)
mockBuilder := &mockFormBuilder{} mockBuilder := &mockFormBuilder{}
client.createFormBuilder = func(io.Writer) formBuilder { client.createFormBuilder = func(io.Writer) utils.FormBuilder {
return mockBuilder return mockBuilder
} }
ctx := context.Background() ctx := context.Background()

49
internal/form_builder.go Normal file
View File

@@ -0,0 +1,49 @@
package openai
import (
"io"
"mime/multipart"
"os"
)
type FormBuilder interface {
CreateFormFile(fieldname string, file *os.File) error
WriteField(fieldname, value string) error
Close() error
FormDataContentType() string
}
type DefaultFormBuilder struct {
writer *multipart.Writer
}
func NewFormBuilder(body io.Writer) *DefaultFormBuilder {
return &DefaultFormBuilder{
writer: multipart.NewWriter(body),
}
}
func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name())
if err != nil {
return err
}
_, err = io.Copy(fieldWriter, file)
if err != nil {
return err
}
return nil
}
func (fb *DefaultFormBuilder) WriteField(fieldname, value string) error {
return fb.writer.WriteField(fieldname, value)
}
func (fb *DefaultFormBuilder) Close() error {
return fb.writer.Close()
}
func (fb *DefaultFormBuilder) FormDataContentType() string {
return fb.writer.FormDataContentType()
}

View File

@@ -30,8 +30,8 @@ func TestFormBuilderWithFailingWriter(t *testing.T) {
defer file.Close() defer file.Close()
defer os.Remove(file.Name()) defer os.Remove(file.Name())
builder := newFormBuilder(&failingWriter{}) builder := NewFormBuilder(&failingWriter{})
err = builder.createFormFile("file", file) err = builder.CreateFormFile("file", file)
checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails") checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails")
} }
@@ -47,8 +47,8 @@ func TestFormBuilderWithClosedFile(t *testing.T) {
defer os.Remove(file.Name()) defer os.Remove(file.Name())
body := &bytes.Buffer{} body := &bytes.Buffer{}
builder := newFormBuilder(body) builder := NewFormBuilder(body)
err = builder.createFormFile("file", file) err = builder.CreateFormFile("file", file)
checks.HasError(t, err, "formbuilder should return error if file is closed") checks.HasError(t, err, "formbuilder should return error if file is closed")
checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed") checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed")
} }