* Implement optional io.Reader in AudioRequest (#303) (#265) * Fix err shadowing * Add test to cover AudioRequest io.Reader usage * Add additional test cases to cover AudioRequest io.Reader usage * Add test to cover opening the file specified in an AudioRequest
This commit is contained in:
41
audio.go
41
audio.go
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
@@ -28,7 +29,13 @@ const (
|
||||
// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient.
|
||||
type AudioRequest struct {
|
||||
Model string
|
||||
|
||||
// FilePath is either an existing file in your filesystem or a filename representing the contents of Reader.
|
||||
FilePath string
|
||||
|
||||
// Reader is an optional io.Reader when you do not want to use an existing file.
|
||||
Reader io.Reader
|
||||
|
||||
Prompt string // For translation, it should be in English
|
||||
Temperature float32
|
||||
Language string // For translation, just do not use it. It seems "en" works, not confirmed...
|
||||
@@ -95,15 +102,9 @@ func (r AudioRequest) HasJSONResponse() bool {
|
||||
// audioMultipartForm creates a form with audio file contents and the name of the model to use for
|
||||
// audio processing.
|
||||
func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error {
|
||||
f, err := os.Open(request.FilePath)
|
||||
err := createFileField(request, b)
|
||||
if err != nil {
|
||||
return fmt.Errorf("opening audio file: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
err = b.CreateFormFile("file", f)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating form file: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = b.WriteField("model", request.Model)
|
||||
@@ -146,3 +147,27 @@ func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error {
|
||||
// Close the multipart writer
|
||||
return b.Close()
|
||||
}
|
||||
|
||||
// createFileField creates the "file" form field from either an existing file or by using the reader.
|
||||
func createFileField(request AudioRequest, b utils.FormBuilder) error {
|
||||
if request.Reader != nil {
|
||||
err := b.CreateFormFileReader("file", request.Reader, request.FilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating form using reader: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
f, err := os.Open(request.FilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("opening audio file: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
err = b.CreateFormFile("file", f)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating form file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package openai //nolint:testpackage // testing private field
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -11,12 +12,10 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sashabaranov/go-openai/internal/test"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
|
||||
@@ -65,6 +64,16 @@ func TestAudio(t *testing.T) {
|
||||
_, err = tc.createFn(ctx, req)
|
||||
checks.NoError(t, err, "audio API error")
|
||||
})
|
||||
|
||||
t.Run(tc.name+" (with reader)", func(t *testing.T) {
|
||||
req := AudioRequest{
|
||||
FilePath: "fake.webm",
|
||||
Reader: bytes.NewBuffer([]byte(`some webm binary data`)),
|
||||
Model: "whisper-3",
|
||||
}
|
||||
_, err = tc.createFn(ctx, req)
|
||||
checks.NoError(t, err, "audio API error")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -213,3 +222,54 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
|
||||
checks.ErrorIs(t, err, mockFailedErr, "audioMultipartForm should return error if form builder fails")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateFileField(t *testing.T) {
|
||||
t.Run("createFileField failing file", func(t *testing.T) {
|
||||
dir, cleanup := test.CreateTestDirectory(t)
|
||||
defer cleanup()
|
||||
path := filepath.Join(dir, "fake.mp3")
|
||||
test.CreateTestFile(t, path)
|
||||
|
||||
req := AudioRequest{
|
||||
FilePath: path,
|
||||
}
|
||||
|
||||
mockFailedErr := fmt.Errorf("mock form builder fail")
|
||||
mockBuilder := &mockFormBuilder{
|
||||
mockCreateFormFile: func(string, *os.File) error {
|
||||
return mockFailedErr
|
||||
},
|
||||
}
|
||||
|
||||
err := createFileField(req, mockBuilder)
|
||||
checks.ErrorIs(t, err, mockFailedErr, "createFileField using a file should return error if form builder fails")
|
||||
})
|
||||
|
||||
t.Run("createFileField failing reader", func(t *testing.T) {
|
||||
req := AudioRequest{
|
||||
FilePath: "test.wav",
|
||||
Reader: bytes.NewBuffer([]byte(`wav test contents`)),
|
||||
}
|
||||
|
||||
mockFailedErr := fmt.Errorf("mock form builder fail")
|
||||
mockBuilder := &mockFormBuilder{
|
||||
mockCreateFormFileReader: func(string, io.Reader, string) error {
|
||||
return mockFailedErr
|
||||
},
|
||||
}
|
||||
|
||||
err := createFileField(req, mockBuilder)
|
||||
checks.ErrorIs(t, err, mockFailedErr, "createFileField using a reader should return error if form builder fails")
|
||||
})
|
||||
|
||||
t.Run("createFileField failing open", func(t *testing.T) {
|
||||
req := AudioRequest{
|
||||
FilePath: "non_existing_file.wav",
|
||||
}
|
||||
|
||||
mockBuilder := &mockFormBuilder{}
|
||||
|
||||
err := createFileField(req, mockBuilder)
|
||||
checks.HasError(t, err, "createFileField using file should return error when open file fails")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -265,6 +265,7 @@ func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
type mockFormBuilder struct {
|
||||
mockCreateFormFile func(string, *os.File) error
|
||||
mockCreateFormFileReader func(string, io.Reader, string) error
|
||||
mockWriteField func(string, string) error
|
||||
mockClose func() error
|
||||
}
|
||||
@@ -273,6 +274,10 @@ func (fb *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error
|
||||
return fb.mockCreateFormFile(fieldname, file)
|
||||
}
|
||||
|
||||
func (fb *mockFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
|
||||
return fb.mockCreateFormFileReader(fieldname, r, filename)
|
||||
}
|
||||
|
||||
func (fb *mockFormBuilder) WriteField(fieldname, value string) error {
|
||||
return fb.mockWriteField(fieldname, value)
|
||||
}
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"path"
|
||||
)
|
||||
|
||||
type FormBuilder interface {
|
||||
CreateFormFile(fieldname string, file *os.File) error
|
||||
CreateFormFileReader(fieldname string, r io.Reader, filename string) error
|
||||
WriteField(fieldname, value string) error
|
||||
Close() error
|
||||
FormDataContentType() string
|
||||
@@ -24,15 +27,28 @@ func NewFormBuilder(body io.Writer) *DefaultFormBuilder {
|
||||
}
|
||||
|
||||
func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
|
||||
fieldWriter, err := fb.writer.CreateFormFile(fieldname, file.Name())
|
||||
return fb.createFormFile(fieldname, file, file.Name())
|
||||
}
|
||||
|
||||
func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
|
||||
return fb.createFormFile(fieldname, r, path.Base(filename))
|
||||
}
|
||||
|
||||
func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error {
|
||||
if filename == "" {
|
||||
return fmt.Errorf("filename cannot be empty")
|
||||
}
|
||||
|
||||
fieldWriter, err := fb.writer.CreateFormFile(fieldname, filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = io.Copy(fieldWriter, file)
|
||||
_, err = io.Copy(fieldWriter, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user