Check for Stream parameter usage (#174)

* check for stream:true usage

* lint
This commit is contained in:
sashabaranov
2023-03-18 19:31:54 +04:00
committed by GitHub
parent 34f3a118df
commit a6b35c3ab5
6 changed files with 45 additions and 7 deletions

11
chat.go
View File

@@ -14,7 +14,8 @@ const (
) )
var ( var (
ErrChatCompletionInvalidModel = errors.New("currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported") ErrChatCompletionInvalidModel = errors.New("currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported") //nolint:lll
ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll
) )
type ChatCompletionMessage struct { type ChatCompletionMessage struct {
@@ -65,8 +66,12 @@ func (c *Client) CreateChatCompletion(
ctx context.Context, ctx context.Context,
request ChatCompletionRequest, request ChatCompletionRequest,
) (response ChatCompletionResponse, err error) { ) (response ChatCompletionResponse, err error) {
model := request.Model if request.Stream {
switch model { err = ErrChatCompletionStreamNotSupported
return
}
switch request.Model {
case GPT3Dot5Turbo0301, GPT3Dot5Turbo, GPT4, GPT40314, GPT432K0314, GPT432K: case GPT3Dot5Turbo0301, GPT3Dot5Turbo, GPT4, GPT40314, GPT432K0314, GPT432K:
default: default:
err = ErrChatCompletionInvalidModel err = ErrChatCompletionInvalidModel

View File

@@ -38,6 +38,21 @@ func TestChatCompletionsWrongModel(t *testing.T) {
} }
} }
// TestChatCompletionsWithStream verifies that CreateChatCompletion rejects a
// request with Stream set, directing callers to CreateChatCompletionStream.
func TestChatCompletionsWithStream(t *testing.T) {
	cfg := DefaultConfig("whatever")
	cfg.BaseURL = "http://localhost/v1"
	cli := NewClientWithConfig(cfg)

	// Stream:true must short-circuit before any HTTP request is made.
	_, err := cli.CreateChatCompletion(context.Background(), ChatCompletionRequest{Stream: true})
	if !errors.Is(err, ErrChatCompletionStreamNotSupported) {
		t.Fatalf("CreateChatCompletion didn't return ErrChatCompletionStreamNotSupported error")
	}
}
// TestCompletions Tests the completions endpoint of the API using the mocked server. // TestCompletions Tests the completions endpoint of the API using the mocked server.
func TestChatCompletions(t *testing.T) { func TestChatCompletions(t *testing.T) {
server := test.NewTestServer() server := test.NewTestServer()

View File

@@ -8,6 +8,7 @@ import (
var ( var (
ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll
) )
// GPT3 Defines the models provided by OpenAI to use when generating // GPT3 Defines the models provided by OpenAI to use when generating
@@ -99,6 +100,11 @@ func (c *Client) CreateCompletion(
ctx context.Context, ctx context.Context,
request CompletionRequest, request CompletionRequest,
) (response CompletionResponse, err error) { ) (response CompletionResponse, err error) {
if request.Stream {
err = ErrCompletionStreamNotSupported
return
}
if request.Model == GPT3Dot5Turbo0301 || request.Model == GPT3Dot5Turbo { if request.Model == GPT3Dot5Turbo0301 || request.Model == GPT3Dot5Turbo {
err = ErrCompletionUnsupportedModel err = ErrCompletionUnsupportedModel
return return

View File

@@ -33,6 +33,18 @@ func TestCompletionsWrongModel(t *testing.T) {
} }
} }
func TestCompletionWithStream(t *testing.T) {
config := DefaultConfig("whatever")
client := NewClientWithConfig(config)
ctx := context.Background()
req := CompletionRequest{Stream: true}
_, err := client.CreateCompletion(ctx, req)
if !errors.Is(err, ErrCompletionStreamNotSupported) {
t.Fatalf("CreateCompletion didn't return ErrCompletionStreamNotSupported")
}
}
// TestCompletions Tests the completions endpoint of the API using the mocked server. // TestCompletions Tests the completions endpoint of the API using the mocked server.
func TestCompletions(t *testing.T) { func TestCompletions(t *testing.T) {
server := test.NewTestServer() server := test.NewTestServer()

View File

@@ -33,7 +33,7 @@ func TestListModels(t *testing.T) {
} }
// handleModelsEndpoint Handles the models endpoint by the test server. // handleModelsEndpoint Handles the models endpoint by the test server.
func handleModelsEndpoint(w http.ResponseWriter, r *http.Request) { func handleModelsEndpoint(w http.ResponseWriter, _ *http.Request) {
resBytes, _ := json.Marshal(ModelsList{}) resBytes, _ := json.Marshal(ModelsList{})
fmt.Fprintln(w, string(resBytes)) fmt.Fprintln(w, string(resBytes))
} }

View File

@@ -19,11 +19,11 @@ type (
failingMarshaller struct{} failingMarshaller struct{}
) )
func (*failingMarshaller) marshal(value any) ([]byte, error) { func (*failingMarshaller) marshal(_ any) ([]byte, error) {
return []byte{}, errTestMarshallerFailed return []byte{}, errTestMarshallerFailed
} }
func (*failingRequestBuilder) build(ctx context.Context, method, url string, requset any) (*http.Request, error) { func (*failingRequestBuilder) build(_ context.Context, _, _ string, _ any) (*http.Request, error) {
return nil, errTestRequestBuilderFailed return nil, errTestRequestBuilderFailed
} }