Add support for multi part chat messages (and gpt-4-vision-preview model) (#580)
* Add support for multi-part chat messages

  OpenAI recently introduced the gpt-4-vision-preview model, which accepts images as input. For this model, the chat completion endpoint accepts multi-part chat messages, where the content can be an array of structs in addition to the usual string format. This commit introduces new structures and constants to represent the different types of content parts, and implements the json.Marshaler and json.Unmarshaler interfaces on ChatCompletionMessage.

* Add ImageURLDetail and ChatMessagePartType types
* Optimize ChatCompletionMessage deserialization
* Add ErrContentFieldsMisused error
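For orientation, here is a minimal sketch of the message shape this commit describes, pieced together from the identifiers exercised in the test diff below (ChatCompletionMessage, ChatMessagePart, ChatMessageImageURL, ChatMessagePartType, ImageURLDetail, ErrContentFieldsMisused). The exact field layout, the error text, the constant names not used in the tests, and the marshal/unmarshal logic are assumptions for illustration; the commit's actual chat.go changes are not shown on this page.

// Sketch only: approximate shape of the multi-part message support,
// using names taken from the tests in this commit; details are assumed.
package openai

import (
	"encoding/json"
	"errors"
)

// Error text is illustrative, not necessarily the library's wording.
var ErrContentFieldsMisused = errors.New("can't use both Content and MultiContent fields simultaneously")

type ChatMessagePartType string

const (
	ChatMessagePartTypeText     ChatMessagePartType = "text"
	ChatMessagePartTypeImageURL ChatMessagePartType = "image_url"
)

type ImageURLDetail string

const (
	ImageURLDetailLow  ImageURLDetail = "low"
	ImageURLDetailHigh ImageURLDetail = "high" // name extrapolated from the "high" detail used in the tests
)

type ChatMessageImageURL struct {
	URL    string         `json:"url,omitempty"`
	Detail ImageURLDetail `json:"detail,omitempty"`
}

type ChatMessagePart struct {
	Type     ChatMessagePartType  `json:"type,omitempty"`
	Text     string               `json:"text,omitempty"`
	ImageURL *ChatMessageImageURL `json:"image_url,omitempty"`
}

type ChatCompletionMessage struct {
	Role         string
	Content      string
	MultiContent []ChatMessagePart
}

// MarshalJSON sends "content" as a plain string unless MultiContent is set,
// in which case the parts array is sent instead; setting both fields is an error.
func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
	if m.Content != "" && m.MultiContent != nil {
		return nil, ErrContentFieldsMisused
	}
	if len(m.MultiContent) > 0 {
		return json.Marshal(struct {
			Role    string            `json:"role"`
			Content []ChatMessagePart `json:"content"`
		}{Role: m.Role, Content: m.MultiContent})
	}
	return json.Marshal(struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}{Role: m.Role, Content: m.Content})
}

// UnmarshalJSON first tries the plain-string form of "content" and, if that
// fails, falls back to the multi-part array form.
func (m *ChatCompletionMessage) UnmarshalJSON(data []byte) error {
	var plain struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}
	if err := json.Unmarshal(data, &plain); err == nil {
		m.Role, m.Content, m.MultiContent = plain.Role, plain.Content, nil
		return nil
	}
	var multi struct {
		Role    string            `json:"role"`
		Content []ChatMessagePart `json:"content"`
	}
	if err := json.Unmarshal(data, &multi); err != nil {
		return err
	}
	m.Role, m.Content, m.MultiContent = multi.Role, "", multi.Content
	return nil
}

Against this shape, the test diff below exercises exactly these paths: a chat completion request carrying a text part and an image_url part, a round trip of both content forms through encoding/json, the ErrContentFieldsMisused guard when Content and MultiContent are both set, and the empty-MultiContent edge case.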
chat_test.go | 103
@@ -3,6 +3,7 @@ package openai_test
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"

@@ -296,6 +297,108 @@ func TestAzureChatCompletions(t *testing.T) {
	checks.NoError(t, err, "CreateAzureChatCompletion error")
}

func TestMultipartChatCompletions(t *testing.T) {
	client, server, teardown := setupAzureTestServer()
	defer teardown()
	server.RegisterHandler("/openai/deployments/*", handleChatCompletionEndpoint)

	_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		MaxTokens: 5,
		Model:     openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{
				Role: openai.ChatMessageRoleUser,
				MultiContent: []openai.ChatMessagePart{
					{
						Type: openai.ChatMessagePartTypeText,
						Text: "Hello!",
					},
					{
						Type: openai.ChatMessagePartTypeImageURL,
						ImageURL: &openai.ChatMessageImageURL{
							URL:    "URL",
							Detail: openai.ImageURLDetailLow,
						},
					},
				},
			},
		},
	})
	checks.NoError(t, err, "CreateAzureChatCompletion error")
}

func TestMultipartChatMessageSerialization(t *testing.T) {
	jsonText := `[{"role":"system","content":"system-message"},` +
		`{"role":"user","content":[{"type":"text","text":"nice-text"},` +
		`{"type":"image_url","image_url":{"url":"URL","detail":"high"}}]}]`

	var msgs []openai.ChatCompletionMessage
	err := json.Unmarshal([]byte(jsonText), &msgs)
	if err != nil {
		t.Fatalf("Expected no error: %s", err)
	}
	if len(msgs) != 2 {
		t.Errorf("unexpected number of messages")
	}
	if msgs[0].Role != "system" || msgs[0].Content != "system-message" || msgs[0].MultiContent != nil {
		t.Errorf("invalid system message: %v", msgs[0])
	}
	if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 2 {
		t.Errorf("invalid user message")
	}
	parts := msgs[1].MultiContent
	if parts[0].Type != "text" || parts[0].Text != "nice-text" {
		t.Errorf("invalid text part: %v", parts[0])
	}
	if parts[1].Type != "image_url" || parts[1].ImageURL.URL != "URL" || parts[1].ImageURL.Detail != "high" {
		t.Errorf("invalid image_url part")
	}

	s, err := json.Marshal(msgs)
	if err != nil {
		t.Fatalf("Expected no error: %s", err)
	}
	res := strings.ReplaceAll(string(s), " ", "")
	if res != jsonText {
		t.Fatalf("invalid message: %s", string(s))
	}

	invalidMsg := []openai.ChatCompletionMessage{
		{
			Role:    "user",
			Content: "some-text",
			MultiContent: []openai.ChatMessagePart{
				{
					Type: "text",
					Text: "nice-text",
				},
			},
		},
	}
	_, err = json.Marshal(invalidMsg)
	if !errors.Is(err, openai.ErrContentFieldsMisused) {
		t.Fatalf("Expected error: %s", err)
	}

	err = json.Unmarshal([]byte(`["not-a-message"]`), &msgs)
	if err == nil {
		t.Fatalf("Expected error")
	}

	emptyMultiContentMsg := openai.ChatCompletionMessage{
		Role:         "user",
		MultiContent: []openai.ChatMessagePart{},
	}
	s, err = json.Marshal(emptyMultiContentMsg)
	if err != nil {
		t.Fatalf("Unexpected error")
	}
	res = strings.ReplaceAll(string(s), " ", "")
	if res != `{"role":"user","content":""}` {
		t.Fatalf("invalid message: %s", string(s))
	}
}

// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server.
func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
	var err error