allow custom voice and speech models (#691)
This commit is contained in:
31
speech.go
31
speech.go
@@ -2,7 +2,6 @@ package openai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -36,11 +35,6 @@ const (
|
|||||||
SpeechResponseFormatPcm SpeechResponseFormat = "pcm"
|
SpeechResponseFormatPcm SpeechResponseFormat = "pcm"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
ErrInvalidSpeechModel = errors.New("invalid speech model")
|
|
||||||
ErrInvalidVoice = errors.New("invalid voice")
|
|
||||||
)
|
|
||||||
|
|
||||||
type CreateSpeechRequest struct {
|
type CreateSpeechRequest struct {
|
||||||
Model SpeechModel `json:"model"`
|
Model SpeechModel `json:"model"`
|
||||||
Input string `json:"input"`
|
Input string `json:"input"`
|
||||||
@@ -49,32 +43,7 @@ type CreateSpeechRequest struct {
|
|||||||
Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0
|
Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0
|
||||||
}
|
}
|
||||||
|
|
||||||
func contains[T comparable](s []T, e T) bool {
|
|
||||||
for _, v := range s {
|
|
||||||
if v == e {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func isValidSpeechModel(model SpeechModel) bool {
|
|
||||||
return contains([]SpeechModel{TTSModel1, TTSModel1HD, TTSModelCanary}, model)
|
|
||||||
}
|
|
||||||
|
|
||||||
func isValidVoice(voice SpeechVoice) bool {
|
|
||||||
return contains([]SpeechVoice{VoiceAlloy, VoiceEcho, VoiceFable, VoiceOnyx, VoiceNova, VoiceShimmer}, voice)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) {
|
func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) {
|
||||||
if !isValidSpeechModel(request.Model) {
|
|
||||||
err = ErrInvalidSpeechModel
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !isValidVoice(request.Voice) {
|
|
||||||
err = ErrInvalidVoice
|
|
||||||
return
|
|
||||||
}
|
|
||||||
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", string(request.Model)),
|
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", string(request.Model)),
|
||||||
withBody(request),
|
withBody(request),
|
||||||
withContentType("application/json"),
|
withContentType("application/json"),
|
||||||
|
|||||||
@@ -95,21 +95,4 @@ func TestSpeechIntegration(t *testing.T) {
|
|||||||
err = os.WriteFile("test.mp3", buf, 0644)
|
err = os.WriteFile("test.mp3", buf, 0644)
|
||||||
checks.NoError(t, err, "Create error")
|
checks.NoError(t, err, "Create error")
|
||||||
})
|
})
|
||||||
t.Run("invalid model", func(t *testing.T) {
|
|
||||||
_, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{
|
|
||||||
Model: "invalid_model",
|
|
||||||
Input: "Hello!",
|
|
||||||
Voice: openai.VoiceAlloy,
|
|
||||||
})
|
|
||||||
checks.ErrorIs(t, err, openai.ErrInvalidSpeechModel, "CreateSpeech error")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("invalid voice", func(t *testing.T) {
|
|
||||||
_, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{
|
|
||||||
Model: openai.TTSModel1,
|
|
||||||
Input: "Hello!",
|
|
||||||
Voice: "invalid_voice",
|
|
||||||
})
|
|
||||||
checks.ErrorIs(t, err, openai.ErrInvalidVoice, "CreateSpeech error")
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user