diff --git a/README.md b/README.md
index 898465c..7526ea3 100644
--- a/README.md
+++ b/README.md
@@ -223,6 +223,47 @@ func main() {
```
+
+Audio Captions
+
+```go
+package main
+
+import (
+ "context"
+ "fmt"
+ "os"
+
+ openai "github.com/sashabaranov/go-openai"
+)
+
+func main() {
+ c := openai.NewClient(os.Getenv("OPENAI_KEY"))
+
+ req := openai.AudioRequest{
+ Model: openai.Whisper1,
+ FilePath: os.Args[1],
+ Format: openai.AudioResponseFormatSRT,
+ }
+ resp, err := c.CreateTranscription(context.Background(), req)
+ if err != nil {
+ fmt.Printf("Transcription error: %v\n", err)
+ return
+ }
+ f, err := os.Create(os.Args[1] + ".srt")
+ if err != nil {
+ fmt.Printf("Could not open file: %v\n", err)
+ return
+ }
+ defer f.Close()
+ if _, err := f.WriteString(resp.Text); err != nil {
+ fmt.Printf("Error writing to file: %v\n", err)
+ return
+ }
+}
+```
+
+
DALL-E 2 image generation
@@ -420,4 +461,4 @@ func main() {
fmt.Println(resp.Choices[0].Message.Content)
}
```
-
\ No newline at end of file
+
diff --git a/audio.go b/audio.go
index 9db9298..46c3711 100644
--- a/audio.go
+++ b/audio.go
@@ -13,6 +13,15 @@ const (
Whisper1 = "whisper-1"
)
+// AudioResponseFormat is the response format of the audio API; Whisper uses AudioResponseFormatJSON by default.
+type AudioResponseFormat string
+
+const (
+ AudioResponseFormatJSON AudioResponseFormat = "json"
+ AudioResponseFormatSRT AudioResponseFormat = "srt"
+ AudioResponseFormatVTT AudioResponseFormat = "vtt"
+)
+
// AudioRequest represents a request structure for audio API.
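-// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient.
+// Format selects the response format (JSON, SRT or VTT); when it is empty, the response is handled as JSON.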
type AudioRequest struct {
@@ -21,6 +30,7 @@ type AudioRequest struct {
Prompt string // For translation, it should be in English
Temperature float32
Language string // For translation, just do not use it. It seems "en" works, not confirmed...
+ Format AudioResponseFormat
}
// AudioResponse represents a response structure for audio API.
@@ -66,10 +76,19 @@ func (c *Client) callAudioAPI(
}
req.Header.Add("Content-Type", builder.formDataContentType())
- err = c.sendRequest(req, &response)
+ if request.HasJSONResponse() {
+ err = c.sendRequest(req, &response)
+ } else {
+ err = c.sendRequest(req, &response.Text)
+ }
return
}
+// HasJSONResponse reports whether the response is JSON, i.e. Format is empty or AudioResponseFormatJSON.
+func (r AudioRequest) HasJSONResponse() bool {
+ return r.Format == "" || r.Format == AudioResponseFormatJSON
+}
+
// audioMultipartForm creates a form with audio file contents and the name of the model to use for
// audio processing.
func audioMultipartForm(request AudioRequest, b formBuilder) error {
@@ -97,6 +116,14 @@ func audioMultipartForm(request AudioRequest, b formBuilder) error {
}
}
+ // Create a form field for the format (if provided)
+ if request.Format != "" {
+ err = b.writeField("response_format", string(request.Format))
+ if err != nil {
+ return fmt.Errorf("writing format: %w", err)
+ }
+ }
+
// Create a form field for the temperature (if provided)
if request.Temperature != 0 {
err = b.writeField("temperature", fmt.Sprintf("%.2f", request.Temperature))
diff --git a/audio_test.go b/audio_test.go
index 9d2abfc..daf51f2 100644
--- a/audio_test.go
+++ b/audio_test.go
@@ -112,6 +112,7 @@ func TestAudioWithOptionalArgs(t *testing.T) {
Prompt: "用简体中文",
Temperature: 0.5,
Language: "zh",
+ Format: AudioResponseFormatSRT,
}
_, err = tc.createFn(ctx, req)
checks.NoError(t, err, "audio API error")
@@ -179,6 +180,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
Prompt: "test",
Temperature: 0.5,
Language: "en",
+ Format: AudioResponseFormatSRT,
}
mockFailedErr := fmt.Errorf("mock form builder fail")
@@ -202,7 +204,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
return nil
}
- failOn := []string{"model", "prompt", "temperature", "language"}
+ failOn := []string{"model", "prompt", "temperature", "language", "response_format"}
for _, failingField := range failOn {
failForField = failingField
mockFailedErr = fmt.Errorf("mock form builder fail on field %s", failingField)
diff --git a/client.go b/client.go
index b15a18a..e17ded2 100644
--- a/client.go
+++ b/client.go
@@ -43,7 +43,7 @@ func NewOrgClient(authToken, org string) *Client {
return NewClientWithConfig(config)
}
-func (c *Client) sendRequest(req *http.Request, v interface{}) error {
+func (c *Client) sendRequest(req *http.Request, v any) error {
req.Header.Set("Accept", "application/json; charset=utf-8")
// Azure API Key authentication
if c.config.APIType == APITypeAzure {
@@ -75,12 +75,26 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
return c.handleErrorResp(res)
}
- if v != nil {
- if err = json.NewDecoder(res.Body).Decode(v); err != nil {
- return err
- }
+ return decodeResponse(res.Body, v)
+}
+
+func decodeResponse(body io.Reader, v any) error {
+ if v == nil {
+ return nil
}
+ if result, ok := v.(*string); ok {
+ return decodeString(body, result)
+ }
+ return json.NewDecoder(body).Decode(v)
+}
+
+func decodeString(body io.Reader, output *string) error {
+ b, err := io.ReadAll(body)
+ if err != nil {
+ return err
+ }
+ *output = string(b)
return nil
}
diff --git a/client_test.go b/client_test.go
index 1c15985..7bea6dd 100644
--- a/client_test.go
+++ b/client_test.go
@@ -1,6 +1,8 @@
package openai //nolint:testpackage // testing private field
import (
+ "bytes"
+ "io"
"testing"
)
@@ -20,3 +22,38 @@ func TestClient(t *testing.T) {
t.Errorf("Client does not contain proper orgID")
}
}
+
+func TestDecodeResponse(t *testing.T) {
+ stringInput := ""
+
+ testCases := []struct {
+ name string
+ value interface{}
+ body io.Reader
+ }{
+ {
+ name: "nil input",
+ value: nil,
+ body: bytes.NewReader([]byte("")),
+ },
+ {
+ name: "string input",
+ value: &stringInput,
+ body: bytes.NewReader([]byte("test")),
+ },
+ {
+ name: "map input",
+ value: &map[string]interface{}{},
+ body: bytes.NewReader([]byte(`{"test": "test"}`)),
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ err := decodeResponse(tc.body, tc.value)
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ })
+ }
+}