Check for Stream parameter usage (#174)
* check for stream:true usage * lint
This commit is contained in:
11
chat.go
11
chat.go
@@ -14,7 +14,8 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrChatCompletionInvalidModel = errors.New("currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported")
|
ErrChatCompletionInvalidModel = errors.New("currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported") //nolint:lll
|
||||||
|
ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll
|
||||||
)
|
)
|
||||||
|
|
||||||
type ChatCompletionMessage struct {
|
type ChatCompletionMessage struct {
|
||||||
@@ -65,8 +66,12 @@ func (c *Client) CreateChatCompletion(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request ChatCompletionRequest,
|
request ChatCompletionRequest,
|
||||||
) (response ChatCompletionResponse, err error) {
|
) (response ChatCompletionResponse, err error) {
|
||||||
model := request.Model
|
if request.Stream {
|
||||||
switch model {
|
err = ErrChatCompletionStreamNotSupported
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch request.Model {
|
||||||
case GPT3Dot5Turbo0301, GPT3Dot5Turbo, GPT4, GPT40314, GPT432K0314, GPT432K:
|
case GPT3Dot5Turbo0301, GPT3Dot5Turbo, GPT4, GPT40314, GPT432K0314, GPT432K:
|
||||||
default:
|
default:
|
||||||
err = ErrChatCompletionInvalidModel
|
err = ErrChatCompletionInvalidModel
|
||||||
|
|||||||
15
chat_test.go
15
chat_test.go
@@ -38,6 +38,21 @@ func TestChatCompletionsWrongModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionsWithStream(t *testing.T) {
|
||||||
|
config := DefaultConfig("whatever")
|
||||||
|
config.BaseURL = "http://localhost/v1"
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
req := ChatCompletionRequest{
|
||||||
|
Stream: true,
|
||||||
|
}
|
||||||
|
_, err := client.CreateChatCompletion(ctx, req)
|
||||||
|
if !errors.Is(err, ErrChatCompletionStreamNotSupported) {
|
||||||
|
t.Fatalf("CreateChatCompletion didn't return ErrChatCompletionStreamNotSupported error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestCompletions Tests the completions endpoint of the API using the mocked server.
|
// TestCompletions Tests the completions endpoint of the API using the mocked server.
|
||||||
func TestChatCompletions(t *testing.T) {
|
func TestChatCompletions(t *testing.T) {
|
||||||
server := test.NewTestServer()
|
server := test.NewTestServer()
|
||||||
|
|||||||
@@ -7,7 +7,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
|
ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
|
||||||
|
ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll
|
||||||
)
|
)
|
||||||
|
|
||||||
// GPT3 Defines the models provided by OpenAI to use when generating
|
// GPT3 Defines the models provided by OpenAI to use when generating
|
||||||
@@ -99,6 +100,11 @@ func (c *Client) CreateCompletion(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request CompletionRequest,
|
request CompletionRequest,
|
||||||
) (response CompletionResponse, err error) {
|
) (response CompletionResponse, err error) {
|
||||||
|
if request.Stream {
|
||||||
|
err = ErrCompletionStreamNotSupported
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if request.Model == GPT3Dot5Turbo0301 || request.Model == GPT3Dot5Turbo {
|
if request.Model == GPT3Dot5Turbo0301 || request.Model == GPT3Dot5Turbo {
|
||||||
err = ErrCompletionUnsupportedModel
|
err = ErrCompletionUnsupportedModel
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -33,6 +33,18 @@ func TestCompletionsWrongModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCompletionWithStream(t *testing.T) {
|
||||||
|
config := DefaultConfig("whatever")
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
req := CompletionRequest{Stream: true}
|
||||||
|
_, err := client.CreateCompletion(ctx, req)
|
||||||
|
if !errors.Is(err, ErrCompletionStreamNotSupported) {
|
||||||
|
t.Fatalf("CreateCompletion didn't return ErrCompletionStreamNotSupported")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestCompletions Tests the completions endpoint of the API using the mocked server.
|
// TestCompletions Tests the completions endpoint of the API using the mocked server.
|
||||||
func TestCompletions(t *testing.T) {
|
func TestCompletions(t *testing.T) {
|
||||||
server := test.NewTestServer()
|
server := test.NewTestServer()
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ func TestListModels(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleModelsEndpoint Handles the models endpoint by the test server.
|
// handleModelsEndpoint Handles the models endpoint by the test server.
|
||||||
func handleModelsEndpoint(w http.ResponseWriter, r *http.Request) {
|
func handleModelsEndpoint(w http.ResponseWriter, _ *http.Request) {
|
||||||
resBytes, _ := json.Marshal(ModelsList{})
|
resBytes, _ := json.Marshal(ModelsList{})
|
||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,11 +19,11 @@ type (
|
|||||||
failingMarshaller struct{}
|
failingMarshaller struct{}
|
||||||
)
|
)
|
||||||
|
|
||||||
func (*failingMarshaller) marshal(value any) ([]byte, error) {
|
func (*failingMarshaller) marshal(_ any) ([]byte, error) {
|
||||||
return []byte{}, errTestMarshallerFailed
|
return []byte{}, errTestMarshallerFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*failingRequestBuilder) build(ctx context.Context, method, url string, requset any) (*http.Request, error) {
|
func (*failingRequestBuilder) build(_ context.Context, _, _ string, _ any) (*http.Request, error) {
|
||||||
return nil, errTestRequestBuilderFailed
|
return nil, errTestRequestBuilderFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user