Add form builder (#235)

* add form builder

* cover VariImage

* test for closing errors

* simplify tests

* add audio api test coverage

* don't leak authToken when printed

* rename api->client

* fix test
This commit is contained in:
sashabaranov
2023-04-08 19:26:26 +04:00
committed by GitHub
parent 2f3700f4c5
commit 226ff328e2
8 changed files with 272 additions and 72 deletions

View File

@@ -4,8 +4,6 @@ import (
"bytes"
"context"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
)
@@ -55,9 +53,9 @@ func (c *Client) callAudioAPI(
endpointSuffix string,
) (response AudioResponse, err error) {
var formBody bytes.Buffer
w := multipart.NewWriter(&formBody)
builder := c.createFormBuilder(&formBody)
if err = audioMultipartForm(request, w); err != nil {
if err = audioMultipartForm(request, builder); err != nil {
return
}
@@ -66,7 +64,7 @@ func (c *Client) callAudioAPI(
if err != nil {
return
}
req.Header.Add("Content-Type", w.FormDataContentType())
req.Header.Add("Content-Type", builder.formDataContentType())
err = c.sendRequest(req, &response)
return
@@ -74,73 +72,47 @@ func (c *Client) callAudioAPI(
// audioMultipartForm creates a form with audio file contents and the name of the model to use for
// audio processing.
func audioMultipartForm(request AudioRequest, w *multipart.Writer) error {
func audioMultipartForm(request AudioRequest, b formBuilder) error {
f, err := os.Open(request.FilePath)
if err != nil {
return fmt.Errorf("opening audio file: %w", err)
}
defer f.Close()
fw, err := w.CreateFormFile("file", f.Name())
err = b.createFormFile("file", f)
if err != nil {
return fmt.Errorf("creating form file: %w", err)
}
if _, err = io.Copy(fw, f); err != nil {
return fmt.Errorf("reading from opened audio file: %w", err)
}
fw, err = w.CreateFormField("model")
err = b.writeField("model", request.Model)
if err != nil {
return fmt.Errorf("creating form field: %w", err)
}
modelName := bytes.NewReader([]byte(request.Model))
if _, err = io.Copy(fw, modelName); err != nil {
return fmt.Errorf("writing model name: %w", err)
}
// Create a form field for the prompt (if provided)
if request.Prompt != "" {
fw, err = w.CreateFormField("prompt")
err = b.writeField("prompt", request.Prompt)
if err != nil {
return fmt.Errorf("creating form field: %w", err)
}
prompt := bytes.NewReader([]byte(request.Prompt))
if _, err = io.Copy(fw, prompt); err != nil {
return fmt.Errorf("writing prompt: %w", err)
}
}
// Create a form field for the temperature (if provided)
if request.Temperature != 0 {
fw, err = w.CreateFormField("temperature")
err = b.writeField("temperature", fmt.Sprintf("%.2f", request.Temperature))
if err != nil {
return fmt.Errorf("creating form field: %w", err)
}
temperature := bytes.NewReader([]byte(fmt.Sprintf("%.2f", request.Temperature)))
if _, err = io.Copy(fw, temperature); err != nil {
return fmt.Errorf("writing temperature: %w", err)
}
}
// Create a form field for the language (if provided)
if request.Language != "" {
fw, err = w.CreateFormField("language")
err = b.writeField("language", request.Language)
if err != nil {
return fmt.Errorf("creating form field: %w", err)
}
language := bytes.NewReader([]byte(request.Language))
if _, err = io.Copy(fw, language); err != nil {
return fmt.Errorf("writing language: %w", err)
}
}
// Close the multipart writer
w.Close()
return nil
return b.close()
}