From abffeceb712708ad5e496d6d1b4936c37d244c8e Mon Sep 17 00:00:00 2001 From: sashabaranov <677093+sashabaranov@users.noreply.github.com> Date: Thu, 16 Mar 2023 10:43:41 +0400 Subject: [PATCH] Check for GPT-4 models (#169) * add chat gpt4 model support (#158) * remove check for gpt4 for CreateCompletion * test for model check --------- Co-authored-by: aeieli --- chat.go | 4 +++- completion_test.go | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/chat.go b/chat.go index 14be6f4..99edfe8 100644 --- a/chat.go +++ b/chat.go @@ -66,7 +66,9 @@ func (c *Client) CreateChatCompletion( request ChatCompletionRequest, ) (response ChatCompletionResponse, err error) { model := request.Model - if model != GPT3Dot5Turbo0301 && model != GPT3Dot5Turbo { + switch model { + case GPT3Dot5Turbo0301, GPT3Dot5Turbo, GPT4, GPT40314, GPT432K0314, GPT432K: + default: err = ErrChatCompletionInvalidModel return } diff --git a/completion_test.go b/completion_test.go index 74e29ff..9868eb2 100644 --- a/completion_test.go +++ b/completion_test.go @@ -6,6 +6,7 @@ import ( "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -15,6 +16,23 @@ import ( "time" ) +func TestCompletionsWrongModel(t *testing.T) { + config := DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := NewClientWithConfig(config) + + _, err := client.CreateCompletion( + context.Background(), + CompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + }, + ) + if !errors.Is(err, ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err) + } +} + // TestCompletions Tests the completions endpoint of the API using the mocked server. func TestCompletions(t *testing.T) { server := test.NewTestServer()