* Add support for multi-part chat messages

  OpenAI has recently introduced the gpt-4-vision-preview model, which accepts images as input. Its chat completion endpoint accepts multi-part chat messages, where the content can be an array of structs in addition to the usual string format. This commit introduces new structures and constants to represent the different types of content parts, and implements the json.Marshaler and json.Unmarshaler interfaces on ChatCompletionMessage.

* Add ImageURLDetail and ChatMessagePartType types
* Optimize ChatCompletionMessage deserialization
* Add ErrContentFieldsMisused error
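For illustration, a minimal sketch of the new message shape and the misuse guard, using the types exercised by the tests below (the URL and prompt text are placeholder values):

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"

	"github.com/sashabaranov/go-openai"
)

func main() {
	// A user message mixing a text part and an image part.
	msg := openai.ChatCompletionMessage{
		Role: openai.ChatMessageRoleUser,
		MultiContent: []openai.ChatMessagePart{
			{Type: openai.ChatMessagePartTypeText, Text: "What is in this image?"},
			{
				Type: openai.ChatMessagePartTypeImageURL,
				ImageURL: &openai.ChatMessageImageURL{
					URL:    "https://example.com/image.png", // placeholder URL
					Detail: openai.ImageURLDetailHigh,
				},
			},
		},
	}
	data, _ := json.Marshal(msg)
	fmt.Println(string(data))
	// {"role":"user","content":[{"type":"text","text":"What is in this image?"},
	//  {"type":"image_url","image_url":{"url":"https://example.com/image.png","detail":"high"}}]}

	// A message may set Content or MultiContent, but not both:
	bad := openai.ChatCompletionMessage{Role: "user", Content: "hi", MultiContent: msg.MultiContent}
	if _, err := json.Marshal(bad); errors.Is(err, openai.ErrContentFieldsMisused) {
		fmt.Println("both content fields set:", err)
	}
}
```
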
package openai_test

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/sashabaranov/go-openai"
	"github.com/sashabaranov/go-openai/internal/test/checks"
	"github.com/sashabaranov/go-openai/jsonschema"
)

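// xCustomHeader is a header the mock server sets so tests can verify that raw
// response headers are surfaced to the caller.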
const (
	xCustomHeader      = "X-CUSTOM-HEADER"
	xCustomHeaderValue = "test"
)

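// rateLimitHeaders mirrors the rate limit headers the mock server sets and the
// values the tests assert against.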
var rateLimitHeaders = map[string]any{
	"x-ratelimit-limit-requests":     60,
	"x-ratelimit-limit-tokens":       150000,
	"x-ratelimit-remaining-requests": 59,
	"x-ratelimit-remaining-tokens":   149984,
	"x-ratelimit-reset-requests":     "1s",
	"x-ratelimit-reset-tokens":       "6m0s",
}

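// TestChatCompletionsWrongModel verifies that requesting a chat completion
// with a non-chat model fails with ErrChatCompletionInvalidModel.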
func TestChatCompletionsWrongModel(t *testing.T) {
	config := openai.DefaultConfig("whatever")
	config.BaseURL = "http://localhost/v1"
	client := openai.NewClientWithConfig(config)
	ctx := context.Background()

	req := openai.ChatCompletionRequest{
		MaxTokens: 5,
		Model:     "ada",
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleUser,
				Content: "Hello!",
			},
		},
	}
	_, err := client.CreateChatCompletion(ctx, req)
	msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err)
	checks.ErrorIs(t, err, openai.ErrChatCompletionInvalidModel, msg)
}

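// TestChatRequestOmitEmpty verifies that optional request fields are omitted
// from the marshaled JSON when left empty.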
func TestChatRequestOmitEmpty(t *testing.T) {
	data, err := json.Marshal(openai.ChatCompletionRequest{
		// We set model b/c it's required, so omitempty doesn't make sense
		Model: "gpt-4",
	})
	checks.NoError(t, err)

	// messages is also required so isn't omitted
	const expected = `{"model":"gpt-4","messages":null}`
	if string(data) != expected {
		t.Errorf("expected JSON with all empty fields to be %v but was %v", expected, string(data))
	}
}

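// TestChatCompletionsWithStream verifies that CreateChatCompletion rejects
// requests with Stream set; streaming goes through CreateChatCompletionStream.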
func TestChatCompletionsWithStream(t *testing.T) {
	config := openai.DefaultConfig("whatever")
	config.BaseURL = "http://localhost/v1"
	client := openai.NewClientWithConfig(config)
	ctx := context.Background()

	req := openai.ChatCompletionRequest{
		Stream: true,
	}
	_, err := client.CreateChatCompletion(ctx, req)
	checks.ErrorIs(t, err, openai.ErrChatCompletionStreamNotSupported, "unexpected error")
}

// TestChatCompletions Tests the chat completions endpoint of the API using the mocked server.
func TestChatCompletions(t *testing.T) {
	client, server, teardown := setupOpenAITestServer()
	defer teardown()
	server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
	_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		MaxTokens: 5,
		Model:     openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleUser,
				Content: "Hello!",
			},
		},
	})
	checks.NoError(t, err, "CreateChatCompletion error")
}

// TestChatCompletionsWithHeaders Tests that custom response headers set by the mocked server are exposed on the response.
func TestChatCompletionsWithHeaders(t *testing.T) {
	client, server, teardown := setupOpenAITestServer()
	defer teardown()
	server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		MaxTokens: 5,
		Model:     openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleUser,
				Content: "Hello!",
			},
		},
	})
	checks.NoError(t, err, "CreateChatCompletion error")

	if resp.Header().Get(xCustomHeader) != xCustomHeaderValue {
		t.Errorf("expected header %s to be %s", xCustomHeader, xCustomHeaderValue)
	}
}

// TestChatCompletionsWithRateLimitHeaders Tests that rate limit headers from the mocked server are parsed correctly.
func TestChatCompletionsWithRateLimitHeaders(t *testing.T) {
	client, server, teardown := setupOpenAITestServer()
	defer teardown()
	server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		MaxTokens: 5,
		Model:     openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleUser,
				Content: "Hello!",
			},
		},
	})
	checks.NoError(t, err, "CreateChatCompletion error")

	headers := resp.GetRateLimitHeaders()
	resetRequests := headers.ResetRequests.String()
	if resetRequests != rateLimitHeaders["x-ratelimit-reset-requests"] {
		t.Errorf("expected resetRequests %s to be %s", resetRequests, rateLimitHeaders["x-ratelimit-reset-requests"])
	}
	resetRequestsTime := headers.ResetRequests.Time()
	if resetRequestsTime.Before(time.Now()) {
		t.Errorf("unexpected reset requests: %v", resetRequestsTime)
	}

	bs1, _ := json.Marshal(headers)
	bs2, _ := json.Marshal(rateLimitHeaders)
	if string(bs1) != string(bs2) {
		t.Errorf("expected rate limit headers %s, got %s", bs2, bs1)
	}
}

// TestChatCompletionsFunctions tests including a function call.
func TestChatCompletionsFunctions(t *testing.T) {
	client, server, teardown := setupOpenAITestServer()
	defer teardown()
	server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
	t.Run("bytes", func(t *testing.T) {
		//nolint:lll
		msg := json.RawMessage(`{"properties":{"count":{"type":"integer","description":"total number of words in sentence"},"words":{"items":{"type":"string"},"type":"array","description":"list of words in sentence"}},"type":"object","required":["count","words"]}`)
		_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
			MaxTokens: 5,
			Model:     openai.GPT3Dot5Turbo0613,
			Messages: []openai.ChatCompletionMessage{
				{
					Role:    openai.ChatMessageRoleUser,
					Content: "Hello!",
				},
			},
			Functions: []openai.FunctionDefinition{{
				Name:       "test",
				Parameters: &msg,
			}},
		})
		checks.NoError(t, err, "CreateChatCompletion with functions error")
	})
	t.Run("struct", func(t *testing.T) {
		type testMessage struct {
			Count int      `json:"count"`
			Words []string `json:"words"`
		}
		msg := testMessage{
			Count: 2,
			Words: []string{"hello", "world"},
		}
		_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
			MaxTokens: 5,
			Model:     openai.GPT3Dot5Turbo0613,
			Messages: []openai.ChatCompletionMessage{
				{
					Role:    openai.ChatMessageRoleUser,
					Content: "Hello!",
				},
			},
			Functions: []openai.FunctionDefinition{{
				Name:       "test",
				Parameters: &msg,
			}},
		})
		checks.NoError(t, err, "CreateChatCompletion with functions error")
	})
	t.Run("JSONSchemaDefinition", func(t *testing.T) {
		_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
			MaxTokens: 5,
			Model:     openai.GPT3Dot5Turbo0613,
			Messages: []openai.ChatCompletionMessage{
				{
					Role:    openai.ChatMessageRoleUser,
					Content: "Hello!",
				},
			},
			Functions: []openai.FunctionDefinition{{
				Name: "test",
				Parameters: &jsonschema.Definition{
					Type: jsonschema.Object,
					Properties: map[string]jsonschema.Definition{
						"count": {
							Type:        jsonschema.Number,
							Description: "total number of words in sentence",
						},
						"words": {
							Type:        jsonschema.Array,
							Description: "list of words in sentence",
							Items: &jsonschema.Definition{
								Type: jsonschema.String,
							},
						},
						"enumTest": {
							Type: jsonschema.String,
							Enum: []string{"hello", "world"},
						},
					},
				},
			}},
		})
		checks.NoError(t, err, "CreateChatCompletion with functions error")
	})
	t.Run("JSONSchemaDefinitionWithFunctionDefine", func(t *testing.T) {
		// this is a compatibility check for the deprecated FunctionDefine alias
		_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
			MaxTokens: 5,
			Model:     openai.GPT3Dot5Turbo0613,
			Messages: []openai.ChatCompletionMessage{
				{
					Role:    openai.ChatMessageRoleUser,
					Content: "Hello!",
				},
			},
			Functions: []openai.FunctionDefine{{
				Name: "test",
				Parameters: &jsonschema.Definition{
					Type: jsonschema.Object,
					Properties: map[string]jsonschema.Definition{
						"count": {
							Type:        jsonschema.Number,
							Description: "total number of words in sentence",
						},
						"words": {
							Type:        jsonschema.Array,
							Description: "list of words in sentence",
							Items: &jsonschema.Definition{
								Type: jsonschema.String,
							},
						},
						"enumTest": {
							Type: jsonschema.String,
							Enum: []string{"hello", "world"},
						},
					},
				},
			}},
		})
		checks.NoError(t, err, "CreateChatCompletion with functions error")
	})
}

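// TestAzureChatCompletions Tests the chat completions endpoint against the mocked Azure OpenAI server.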
func TestAzureChatCompletions(t *testing.T) {
	client, server, teardown := setupAzureTestServer()
	defer teardown()
	server.RegisterHandler("/openai/deployments/*", handleChatCompletionEndpoint)

	_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		MaxTokens: 5,
		Model:     openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleUser,
				Content: "Hello!",
			},
		},
	})
	checks.NoError(t, err, "CreateAzureChatCompletion error")
}

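// TestMultipartChatCompletions Tests sending a multi-part user message (text plus image URL) to the mocked server.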
func TestMultipartChatCompletions(t *testing.T) {
	client, server, teardown := setupAzureTestServer()
	defer teardown()
	server.RegisterHandler("/openai/deployments/*", handleChatCompletionEndpoint)

	_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		MaxTokens: 5,
		Model:     openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{
				Role: openai.ChatMessageRoleUser,
				MultiContent: []openai.ChatMessagePart{
					{
						Type: openai.ChatMessagePartTypeText,
						Text: "Hello!",
					},
					{
						Type: openai.ChatMessagePartTypeImageURL,
						ImageURL: &openai.ChatMessageImageURL{
							URL:    "URL",
							Detail: openai.ImageURLDetailLow,
						},
					},
				},
			},
		},
	})
	checks.NoError(t, err, "CreateAzureChatCompletion error")
}

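// TestMultipartChatMessageSerialization Tests that ChatCompletionMessage round-trips through JSON in both the
// plain-string and multi-part content forms, and that misuse of the content fields is rejected.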
func TestMultipartChatMessageSerialization(t *testing.T) {
	jsonText := `[{"role":"system","content":"system-message"},` +
		`{"role":"user","content":[{"type":"text","text":"nice-text"},` +
		`{"type":"image_url","image_url":{"url":"URL","detail":"high"}}]}]`

	var msgs []openai.ChatCompletionMessage
	err := json.Unmarshal([]byte(jsonText), &msgs)
	if err != nil {
		t.Fatalf("Expected no error: %s", err)
	}
	if len(msgs) != 2 {
		t.Errorf("unexpected number of messages")
	}
	if msgs[0].Role != "system" || msgs[0].Content != "system-message" || msgs[0].MultiContent != nil {
		t.Errorf("invalid system message: %v", msgs[0])
	}
	if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 2 {
		t.Errorf("invalid user message")
	}
	parts := msgs[1].MultiContent
	if parts[0].Type != "text" || parts[0].Text != "nice-text" {
		t.Errorf("invalid text part: %v", parts[0])
	}
	if parts[1].Type != "image_url" || parts[1].ImageURL.URL != "URL" || parts[1].ImageURL.Detail != "high" {
		t.Errorf("invalid image_url part")
	}

	s, err := json.Marshal(msgs)
	if err != nil {
		t.Fatalf("Expected no error: %s", err)
	}
	res := strings.ReplaceAll(string(s), " ", "")
	if res != jsonText {
		t.Fatalf("invalid message: %s", string(s))
	}

	invalidMsg := []openai.ChatCompletionMessage{
		{
			Role:    "user",
			Content: "some-text",
			MultiContent: []openai.ChatMessagePart{
				{
					Type: "text",
					Text: "nice-text",
				},
			},
		},
	}
	_, err = json.Marshal(invalidMsg)
	if !errors.Is(err, openai.ErrContentFieldsMisused) {
		t.Fatalf("Expected error: %s", err)
	}

	err = json.Unmarshal([]byte(`["not-a-message"]`), &msgs)
	if err == nil {
		t.Fatalf("Expected error")
	}

	emptyMultiContentMsg := openai.ChatCompletionMessage{
		Role:         "user",
		MultiContent: []openai.ChatMessagePart{},
	}
	s, err = json.Marshal(emptyMultiContentMsg)
	if err != nil {
		t.Fatalf("Unexpected error")
	}
	res = strings.ReplaceAll(string(s), " ", "")
	if res != `{"role":"user","content":""}` {
		t.Fatalf("invalid message: %s", string(s))
	}
}

// handleChatCompletionEndpoint Handles the chat completion endpoint for the test server.
func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
	var err error
	var resBytes []byte

	// completions only accepts POST requests
	if r.Method != "POST" {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var completionReq openai.ChatCompletionRequest
	if completionReq, err = getChatCompletionBody(r); err != nil {
		http.Error(w, "could not read request", http.StatusInternalServerError)
		return
	}
	res := openai.ChatCompletionResponse{
		ID:      strconv.Itoa(int(time.Now().Unix())),
		Object:  "test-object",
		Created: time.Now().Unix(),
		// would be nice to validate Model during testing, but
		// this may not be possible with how much upkeep
		// would be required / wouldn't make much sense
		Model: completionReq.Model,
	}
	// create completions
	n := completionReq.N
	if n == 0 {
		n = 1
	}
	for i := 0; i < n; i++ {
		// if there are functions, include them
		if len(completionReq.Functions) > 0 {
			var fcb []byte
			b := completionReq.Functions[0].Parameters
			fcb, err = json.Marshal(b)
			if err != nil {
				http.Error(w, "could not marshal function parameters", http.StatusInternalServerError)
				return
			}

			res.Choices = append(res.Choices, openai.ChatCompletionChoice{
				Message: openai.ChatCompletionMessage{
					Role: openai.ChatMessageRoleFunction,
					// this is valid json so it should be fine
					FunctionCall: &openai.FunctionCall{
						Name:      completionReq.Functions[0].Name,
						Arguments: string(fcb),
					},
				},
				Index: i,
			})
			continue
		}
		// generate a dummy completion of length completionReq.MaxTokens
		completionStr := strings.Repeat("a", completionReq.MaxTokens)

		res.Choices = append(res.Choices, openai.ChatCompletionChoice{
			Message: openai.ChatCompletionMessage{
				Role:    openai.ChatMessageRoleAssistant,
				Content: completionStr,
			},
			Index: i,
		})
	}
	inputTokens := numTokens(completionReq.Messages[0].Content) * n
	completionTokens := completionReq.MaxTokens * n
	res.Usage = openai.Usage{
		PromptTokens:     inputTokens,
		CompletionTokens: completionTokens,
		TotalTokens:      inputTokens + completionTokens,
	}
	resBytes, _ = json.Marshal(res)
	w.Header().Set(xCustomHeader, xCustomHeaderValue)
	for k, v := range rateLimitHeaders {
		switch val := v.(type) {
		case int:
			w.Header().Set(k, strconv.Itoa(val))
		default:
			w.Header().Set(k, fmt.Sprintf("%s", v))
		}
	}
	fmt.Fprintln(w, string(resBytes))
}

// getChatCompletionBody Returns the decoded chat completion request body.
func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) {
	completion := openai.ChatCompletionRequest{}
	// read the request body
	reqBody, err := io.ReadAll(r.Body)
	if err != nil {
		return openai.ChatCompletionRequest{}, err
	}
	err = json.Unmarshal(reqBody, &completion)
	if err != nil {
		return openai.ChatCompletionRequest{}, err
	}
	return completion, nil
}

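// TestFinishReason Tests that FinishReasonNull and the empty string marshal to an unquoted JSON null,
// while all other finish reasons marshal to quoted strings.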
func TestFinishReason(t *testing.T) {
	c := &openai.ChatCompletionChoice{
		FinishReason: openai.FinishReasonNull,
	}
	resBytes, _ := json.Marshal(c)
	if !strings.Contains(string(resBytes), `"finish_reason":null`) {
		t.Error("null should not be quoted")
	}

	c.FinishReason = ""

	resBytes, _ = json.Marshal(c)
	if !strings.Contains(string(resBytes), `"finish_reason":null`) {
		t.Error("null should not be quoted")
	}

	otherReasons := []openai.FinishReason{
		openai.FinishReasonStop,
		openai.FinishReasonLength,
		openai.FinishReasonFunctionCall,
		openai.FinishReasonContentFilter,
	}
	for _, r := range otherReasons {
		c.FinishReason = r
		resBytes, _ = json.Marshal(c)
		if !strings.Contains(string(resBytes), fmt.Sprintf(`"finish_reason":"%s"`, r)) {
			t.Errorf("%s should be quoted", r)
		}
	}
}