Implement optional io.Reader in AudioRequest (#303) (#265) (#331)

* Implement optional io.Reader in AudioRequest (#303) (#265)

* Fix err shadowing

* Add test to cover AudioRequest io.Reader usage

* Add additional test cases to cover AudioRequest io.Reader usage

* Add test to cover opening the file specified in an AudioRequest
This commit is contained in:
Mariano Darc
2023-06-05 08:07:13 +02:00
committed by GitHub
parent 61ba5f3369
commit fa694c61c2
4 changed files with 124 additions and 18 deletions

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"io"
"net/http" "net/http"
"os" "os"
@@ -27,8 +28,14 @@ const (
// AudioRequest represents a request structure for audio API. // AudioRequest represents a request structure for audio API.
// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient. // ResponseFormat is not supported for now. We only return JSON text, which may be sufficient.
type AudioRequest struct { type AudioRequest struct {
Model string Model string
FilePath string
// FilePath is either an existing file in your filesystem or a filename representing the contents of Reader.
FilePath string
// Reader is an optional io.Reader when you do not want to use an existing file.
Reader io.Reader
Prompt string // For translation, it should be in English Prompt string // For translation, it should be in English
Temperature float32 Temperature float32
Language string // For translation, just do not use it. It seems "en" works, not confirmed... Language string // For translation, just do not use it. It seems "en" works, not confirmed...
@@ -95,15 +102,9 @@ func (r AudioRequest) HasJSONResponse() bool {
// audioMultipartForm creates a form with audio file contents and the name of the model to use for // audioMultipartForm creates a form with audio file contents and the name of the model to use for
// audio processing. // audio processing.
func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error { func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error {
f, err := os.Open(request.FilePath) err := createFileField(request, b)
if err != nil { if err != nil {
return fmt.Errorf("opening audio file: %w", err) return err
}
defer f.Close()
err = b.CreateFormFile("file", f)
if err != nil {
return fmt.Errorf("creating form file: %w", err)
} }
err = b.WriteField("model", request.Model) err = b.WriteField("model", request.Model)
@@ -146,3 +147,27 @@ func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error {
// Close the multipart writer // Close the multipart writer
return b.Close() return b.Close()
} }
// createFileField creates the "file" form field from either an existing file or by using the reader.
func createFileField(request AudioRequest, b utils.FormBuilder) error {
if request.Reader != nil {
err := b.CreateFormFileReader("file", request.Reader, request.FilePath)
if err != nil {
return fmt.Errorf("creating form using reader: %w", err)
}
return nil
}
f, err := os.Open(request.FilePath)
if err != nil {
return fmt.Errorf("opening audio file: %w", err)
}
defer f.Close()
err = b.CreateFormFile("file", f)
if err != nil {
return fmt.Errorf("creating form file: %w", err)
}
return nil
}

View File

@@ -2,6 +2,7 @@ package openai //nolint:testpackage // testing private field
import ( import (
"bytes" "bytes"
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -11,12 +12,10 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"testing"
"github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/internal/test/checks"
"context"
"testing"
) )
// TestAudio Tests the transcription and translation endpoints of the API using the mocked server. // TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
@@ -65,6 +64,16 @@ func TestAudio(t *testing.T) {
_, err = tc.createFn(ctx, req) _, err = tc.createFn(ctx, req)
checks.NoError(t, err, "audio API error") checks.NoError(t, err, "audio API error")
}) })
t.Run(tc.name+" (with reader)", func(t *testing.T) {
req := AudioRequest{
FilePath: "fake.webm",
Reader: bytes.NewBuffer([]byte(`some webm binary data`)),
Model: "whisper-3",
}
_, err = tc.createFn(ctx, req)
checks.NoError(t, err, "audio API error")
})
} }
} }
@@ -213,3 +222,54 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
checks.ErrorIs(t, err, mockFailedErr, "audioMultipartForm should return error if form builder fails") checks.ErrorIs(t, err, mockFailedErr, "audioMultipartForm should return error if form builder fails")
} }
} }
func TestCreateFileField(t *testing.T) {
t.Run("createFileField failing file", func(t *testing.T) {
dir, cleanup := test.CreateTestDirectory(t)
defer cleanup()
path := filepath.Join(dir, "fake.mp3")
test.CreateTestFile(t, path)
req := AudioRequest{
FilePath: path,
}
mockFailedErr := fmt.Errorf("mock form builder fail")
mockBuilder := &mockFormBuilder{
mockCreateFormFile: func(string, *os.File) error {
return mockFailedErr
},
}
err := createFileField(req, mockBuilder)
checks.ErrorIs(t, err, mockFailedErr, "createFileField using a file should return error if form builder fails")
})
t.Run("createFileField failing reader", func(t *testing.T) {
req := AudioRequest{
FilePath: "test.wav",
Reader: bytes.NewBuffer([]byte(`wav test contents`)),
}
mockFailedErr := fmt.Errorf("mock form builder fail")
mockBuilder := &mockFormBuilder{
mockCreateFormFileReader: func(string, io.Reader, string) error {
return mockFailedErr
},
}
err := createFileField(req, mockBuilder)
checks.ErrorIs(t, err, mockFailedErr, "createFileField using a reader should return error if form builder fails")
})
t.Run("createFileField failing open", func(t *testing.T) {
req := AudioRequest{
FilePath: "non_existing_file.wav",
}
mockBuilder := &mockFormBuilder{}
err := createFileField(req, mockBuilder)
checks.HasError(t, err, "createFileField using file should return error when open file fails")
})
}

View File

@@ -264,15 +264,20 @@ func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) {
} }
type mockFormBuilder struct { type mockFormBuilder struct {
mockCreateFormFile func(string, *os.File) error mockCreateFormFile func(string, *os.File) error
mockWriteField func(string, string) error mockCreateFormFileReader func(string, io.Reader, string) error
mockClose func() error mockWriteField func(string, string) error
mockClose func() error
} }
func (fb *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error { func (fb *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
return fb.mockCreateFormFile(fieldname, file) return fb.mockCreateFormFile(fieldname, file)
} }
func (fb *mockFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
return fb.mockCreateFormFileReader(fieldname, r, filename)
}
func (fb *mockFormBuilder) WriteField(fieldname, value string) error { func (fb *mockFormBuilder) WriteField(fieldname, value string) error {
return fb.mockWriteField(fieldname, value) return fb.mockWriteField(fieldname, value)
} }

View File

@@ -1,13 +1,16 @@
package openai package openai
import ( import (
"fmt"
"io" "io"
"mime/multipart" "mime/multipart"
"os" "os"
"path"
) )
type FormBuilder interface { type FormBuilder interface {
CreateFormFile(fieldname string, file *os.File) error CreateFormFile(fieldname string, file *os.File) error
CreateFormFileReader(fieldname string, r io.Reader, filename string) error
WriteField(fieldname, value string) error WriteField(fieldname, value string) error
Close() error Close() error
FormDataContentType() string FormDataContentType() string
@@ -24,15 +27,28 @@ func NewFormBuilder(body io.Writer) *DefaultFormBuilder {
} }
func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error { func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name()) return fb.createFormFile(fieldname, file, file.Name())
}
func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
return fb.createFormFile(fieldname, r, path.Base(filename))
}
func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error {
if filename == "" {
return fmt.Errorf("filename cannot be empty")
}
fieldWriter, err := fb.writer.CreateFormFile(fieldname, filename)
if err != nil { if err != nil {
return err return err
} }
_, err = io.Copy(fieldWriter, file) _, err = io.Copy(fieldWriter, r)
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }