Check for Stream parameter usage (#174)
* check for stream:true usage * lint
This commit is contained in:
11
chat.go
11
chat.go
@@ -14,7 +14,8 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrChatCompletionInvalidModel = errors.New("currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported")
|
||||
ErrChatCompletionInvalidModel = errors.New("currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported") //nolint:lll
|
||||
ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll
|
||||
)
|
||||
|
||||
type ChatCompletionMessage struct {
|
||||
@@ -65,8 +66,12 @@ func (c *Client) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
request ChatCompletionRequest,
|
||||
) (response ChatCompletionResponse, err error) {
|
||||
model := request.Model
|
||||
switch model {
|
||||
if request.Stream {
|
||||
err = ErrChatCompletionStreamNotSupported
|
||||
return
|
||||
}
|
||||
|
||||
switch request.Model {
|
||||
case GPT3Dot5Turbo0301, GPT3Dot5Turbo, GPT4, GPT40314, GPT432K0314, GPT432K:
|
||||
default:
|
||||
err = ErrChatCompletionInvalidModel
|
||||
|
||||
15
chat_test.go
15
chat_test.go
@@ -38,6 +38,21 @@ func TestChatCompletionsWrongModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsWithStream(t *testing.T) {
|
||||
config := DefaultConfig("whatever")
|
||||
config.BaseURL = "http://localhost/v1"
|
||||
client := NewClientWithConfig(config)
|
||||
ctx := context.Background()
|
||||
|
||||
req := ChatCompletionRequest{
|
||||
Stream: true,
|
||||
}
|
||||
_, err := client.CreateChatCompletion(ctx, req)
|
||||
if !errors.Is(err, ErrChatCompletionStreamNotSupported) {
|
||||
t.Fatalf("CreateChatCompletion didn't return ErrChatCompletionStreamNotSupported error")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompletions Tests the completions endpoint of the API using the mocked server.
|
||||
func TestChatCompletions(t *testing.T) {
|
||||
server := test.NewTestServer()
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
var (
|
||||
ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
|
||||
ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll
|
||||
)
|
||||
|
||||
// GPT3 Defines the models provided by OpenAI to use when generating
|
||||
@@ -99,6 +100,11 @@ func (c *Client) CreateCompletion(
|
||||
ctx context.Context,
|
||||
request CompletionRequest,
|
||||
) (response CompletionResponse, err error) {
|
||||
if request.Stream {
|
||||
err = ErrCompletionStreamNotSupported
|
||||
return
|
||||
}
|
||||
|
||||
if request.Model == GPT3Dot5Turbo0301 || request.Model == GPT3Dot5Turbo {
|
||||
err = ErrCompletionUnsupportedModel
|
||||
return
|
||||
|
||||
@@ -33,6 +33,18 @@ func TestCompletionsWrongModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletionWithStream(t *testing.T) {
|
||||
config := DefaultConfig("whatever")
|
||||
client := NewClientWithConfig(config)
|
||||
|
||||
ctx := context.Background()
|
||||
req := CompletionRequest{Stream: true}
|
||||
_, err := client.CreateCompletion(ctx, req)
|
||||
if !errors.Is(err, ErrCompletionStreamNotSupported) {
|
||||
t.Fatalf("CreateCompletion didn't return ErrCompletionStreamNotSupported")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompletions Tests the completions endpoint of the API using the mocked server.
|
||||
func TestCompletions(t *testing.T) {
|
||||
server := test.NewTestServer()
|
||||
|
||||
@@ -33,7 +33,7 @@ func TestListModels(t *testing.T) {
|
||||
}
|
||||
|
||||
// handleModelsEndpoint Handles the models endpoint by the test server.
|
||||
func handleModelsEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
func handleModelsEndpoint(w http.ResponseWriter, _ *http.Request) {
|
||||
resBytes, _ := json.Marshal(ModelsList{})
|
||||
fmt.Fprintln(w, string(resBytes))
|
||||
}
|
||||
|
||||
@@ -19,11 +19,11 @@ type (
|
||||
failingMarshaller struct{}
|
||||
)
|
||||
|
||||
func (*failingMarshaller) marshal(value any) ([]byte, error) {
|
||||
func (*failingMarshaller) marshal(_ any) ([]byte, error) {
|
||||
return []byte{}, errTestMarshallerFailed
|
||||
}
|
||||
|
||||
func (*failingRequestBuilder) build(ctx context.Context, method, url string, requset any) (*http.Request, error) {
|
||||
func (*failingRequestBuilder) build(_ context.Context, _, _ string, _ any) (*http.Request, error) {
|
||||
return nil, errTestRequestBuilderFailed
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user