Feat/messages api (#546)
* fix test server setup: - go map access is not deterministic - this can lead to a route: /foo/bar/1 matching /foo/bar before matching /foo/bar/1 if the map iteration go through /foo/bar first since the regex match wasn't bound to start and end anchors - registering handlers now converts * in routes to .* for proper regex matching - test server route handling now tries to fully match the handler route * add missing /v1 prefix to fine-tuning job cancel test server handler * add create message call * add messages list call * add get message call * add modify message call, fix return types for other message calls * add message file retrieve call * add list message files call * code style fixes * add test for list messages with pagination options * add beta header to msg calls now that #545 is merged * Update messages.go Co-authored-by: Simone Vellei <henomis@gmail.com> * Update messages.go Co-authored-by: Simone Vellei <henomis@gmail.com> * add missing object details for message, fix tests * fix merge formatting * minor style fixes --------- Co-authored-by: Simone Vellei <henomis@gmail.com>
This commit is contained in:
committed by
GitHub
parent
9fefd50e12
commit
b7cac703ac
@@ -301,6 +301,24 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
|
|||||||
{"DeleteAssistantFile", func() (any, error) {
|
{"DeleteAssistantFile", func() (any, error) {
|
||||||
return nil, client.DeleteAssistantFile(ctx, "", "")
|
return nil, client.DeleteAssistantFile(ctx, "", "")
|
||||||
}},
|
}},
|
||||||
|
{"CreateMessage", func() (any, error) {
|
||||||
|
return client.CreateMessage(ctx, "", MessageRequest{})
|
||||||
|
}},
|
||||||
|
{"ListMessage", func() (any, error) {
|
||||||
|
return client.ListMessage(ctx, "", nil, nil, nil, nil)
|
||||||
|
}},
|
||||||
|
{"RetrieveMessage", func() (any, error) {
|
||||||
|
return client.RetrieveMessage(ctx, "", "")
|
||||||
|
}},
|
||||||
|
{"ModifyMessage", func() (any, error) {
|
||||||
|
return client.ModifyMessage(ctx, "", "", nil)
|
||||||
|
}},
|
||||||
|
{"RetrieveMessageFile", func() (any, error) {
|
||||||
|
return client.RetrieveMessageFile(ctx, "", "", "")
|
||||||
|
}},
|
||||||
|
{"ListMessageFiles", func() (any, error) {
|
||||||
|
return client.ListMessageFiles(ctx, "", "")
|
||||||
|
}},
|
||||||
{"CreateThread", func() (any, error) {
|
{"CreateThread", func() (any, error) {
|
||||||
return client.CreateThread(ctx, ThreadRequest{})
|
return client.CreateThread(ctx, ThreadRequest{})
|
||||||
}},
|
}},
|
||||||
|
|||||||
178
messages.go
Normal file
178
messages.go
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
messagesSuffix = "messages"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
CreatedAt int `json:"created_at"`
|
||||||
|
ThreadID string `json:"thread_id"`
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content []MessageContent `json:"content"`
|
||||||
|
FileIds []string `json:"file_ids"`
|
||||||
|
AssistantID *string `json:"assistant_id,omitempty"`
|
||||||
|
RunID *string `json:"run_id,omitempty"`
|
||||||
|
Metadata map[string]any `json:"metadata"`
|
||||||
|
|
||||||
|
httpHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
type MessagesList struct {
|
||||||
|
Messages []Message `json:"data"`
|
||||||
|
|
||||||
|
httpHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
type MessageContent struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Text *MessageText `json:"text,omitempty"`
|
||||||
|
ImageFile *ImageFile `json:"image_file,omitempty"`
|
||||||
|
}
|
||||||
|
type MessageText struct {
|
||||||
|
Value string `json:"value"`
|
||||||
|
Annotations []any `json:"annotations"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImageFile struct {
|
||||||
|
FileID string `json:"file_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MessageRequest struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
FileIds []string `json:"file_ids,omitempty"`
|
||||||
|
Metadata map[string]any `json:"metadata,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MessageFile struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
CreatedAt int `json:"created_at"`
|
||||||
|
MessageID string `json:"message_id"`
|
||||||
|
|
||||||
|
httpHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
type MessageFilesList struct {
|
||||||
|
MessageFiles []MessageFile `json:"data"`
|
||||||
|
|
||||||
|
httpHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateMessage creates a new message.
|
||||||
|
func (c *Client) CreateMessage(ctx context.Context, threadID string, request MessageRequest) (msg Message, err error) {
|
||||||
|
urlSuffix := fmt.Sprintf("/threads/%s/%s", threadID, messagesSuffix)
|
||||||
|
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.sendRequest(req, &msg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListMessage fetches all messages in the thread.
|
||||||
|
func (c *Client) ListMessage(ctx context.Context, threadID string,
|
||||||
|
limit *int,
|
||||||
|
order *string,
|
||||||
|
after *string,
|
||||||
|
before *string,
|
||||||
|
) (messages MessagesList, err error) {
|
||||||
|
urlValues := url.Values{}
|
||||||
|
if limit != nil {
|
||||||
|
urlValues.Add("limit", fmt.Sprintf("%d", *limit))
|
||||||
|
}
|
||||||
|
if order != nil {
|
||||||
|
urlValues.Add("order", *order)
|
||||||
|
}
|
||||||
|
if after != nil {
|
||||||
|
urlValues.Add("after", *after)
|
||||||
|
}
|
||||||
|
if before != nil {
|
||||||
|
urlValues.Add("before", *before)
|
||||||
|
}
|
||||||
|
encodedValues := ""
|
||||||
|
if len(urlValues) > 0 {
|
||||||
|
encodedValues = "?" + urlValues.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
urlSuffix := fmt.Sprintf("/threads/%s/%s%s", threadID, messagesSuffix, encodedValues)
|
||||||
|
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1())
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.sendRequest(req, &messages)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetrieveMessage retrieves a Message.
|
||||||
|
func (c *Client) RetrieveMessage(
|
||||||
|
ctx context.Context,
|
||||||
|
threadID, messageID string,
|
||||||
|
) (msg Message, err error) {
|
||||||
|
urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID)
|
||||||
|
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1())
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.sendRequest(req, &msg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModifyMessage modifies a message.
|
||||||
|
func (c *Client) ModifyMessage(
|
||||||
|
ctx context.Context,
|
||||||
|
threadID, messageID string,
|
||||||
|
metadata map[string]any,
|
||||||
|
) (msg Message, err error) {
|
||||||
|
urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID)
|
||||||
|
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix),
|
||||||
|
withBody(metadata), withBetaAssistantV1())
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.sendRequest(req, &msg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetrieveMessageFile fetches a message file.
|
||||||
|
func (c *Client) RetrieveMessageFile(
|
||||||
|
ctx context.Context,
|
||||||
|
threadID, messageID, fileID string,
|
||||||
|
) (file MessageFile, err error) {
|
||||||
|
urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files/%s", threadID, messagesSuffix, messageID, fileID)
|
||||||
|
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1())
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.sendRequest(req, &file)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListMessageFiles fetches all files attached to a message.
|
||||||
|
func (c *Client) ListMessageFiles(
|
||||||
|
ctx context.Context,
|
||||||
|
threadID, messageID string,
|
||||||
|
) (files MessageFilesList, err error) {
|
||||||
|
urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files", threadID, messagesSuffix, messageID)
|
||||||
|
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1())
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.sendRequest(req, &files)
|
||||||
|
return
|
||||||
|
}
|
||||||
235
messages_test.go
Normal file
235
messages_test.go
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
package openai_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sashabaranov/go-openai"
|
||||||
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
)
|
||||||
|
|
||||||
|
var emptyStr = ""
|
||||||
|
|
||||||
|
// TestMessages Tests the messages endpoint of the API using the mocked server.
|
||||||
|
func TestMessages(t *testing.T) {
|
||||||
|
threadID := "thread_abc123"
|
||||||
|
messageID := "msg_abc123"
|
||||||
|
fileID := "file_abc123"
|
||||||
|
|
||||||
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
|
||||||
|
server.RegisterHandler(
|
||||||
|
"/v1/threads/"+threadID+"/messages/"+messageID+"/files/"+fileID,
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
resBytes, _ := json.Marshal(
|
||||||
|
openai.MessageFile{
|
||||||
|
ID: fileID,
|
||||||
|
Object: "thread.message.file",
|
||||||
|
CreatedAt: 1699061776,
|
||||||
|
MessageID: messageID,
|
||||||
|
})
|
||||||
|
fmt.Fprintln(w, string(resBytes))
|
||||||
|
default:
|
||||||
|
t.Fatalf("unsupported messages http method: %s", r.Method)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
server.RegisterHandler(
|
||||||
|
"/v1/threads/"+threadID+"/messages/"+messageID+"/files",
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
resBytes, _ := json.Marshal(
|
||||||
|
openai.MessageFilesList{MessageFiles: []openai.MessageFile{{
|
||||||
|
ID: fileID,
|
||||||
|
Object: "thread.message.file",
|
||||||
|
CreatedAt: 0,
|
||||||
|
MessageID: messageID,
|
||||||
|
}}})
|
||||||
|
fmt.Fprintln(w, string(resBytes))
|
||||||
|
default:
|
||||||
|
t.Fatalf("unsupported messages http method: %s", r.Method)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
server.RegisterHandler(
|
||||||
|
"/v1/threads/"+threadID+"/messages/"+messageID,
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodPost:
|
||||||
|
metadata := map[string]any{}
|
||||||
|
err := json.NewDecoder(r.Body).Decode(&metadata)
|
||||||
|
checks.NoError(t, err, "unable to decode metadata in modify message call")
|
||||||
|
|
||||||
|
resBytes, _ := json.Marshal(
|
||||||
|
openai.Message{
|
||||||
|
ID: messageID,
|
||||||
|
Object: "thread.message",
|
||||||
|
CreatedAt: 1234567890,
|
||||||
|
ThreadID: threadID,
|
||||||
|
Role: "user",
|
||||||
|
Content: []openai.MessageContent{{
|
||||||
|
Type: "text",
|
||||||
|
Text: &openai.MessageText{
|
||||||
|
Value: "How does AI work?",
|
||||||
|
Annotations: nil,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
FileIds: nil,
|
||||||
|
AssistantID: &emptyStr,
|
||||||
|
RunID: &emptyStr,
|
||||||
|
Metadata: metadata,
|
||||||
|
})
|
||||||
|
fmt.Fprintln(w, string(resBytes))
|
||||||
|
case http.MethodGet:
|
||||||
|
resBytes, _ := json.Marshal(
|
||||||
|
openai.Message{
|
||||||
|
ID: messageID,
|
||||||
|
Object: "thread.message",
|
||||||
|
CreatedAt: 1234567890,
|
||||||
|
ThreadID: threadID,
|
||||||
|
Role: "user",
|
||||||
|
Content: []openai.MessageContent{{
|
||||||
|
Type: "text",
|
||||||
|
Text: &openai.MessageText{
|
||||||
|
Value: "How does AI work?",
|
||||||
|
Annotations: nil,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
FileIds: nil,
|
||||||
|
AssistantID: &emptyStr,
|
||||||
|
RunID: &emptyStr,
|
||||||
|
Metadata: nil,
|
||||||
|
})
|
||||||
|
fmt.Fprintln(w, string(resBytes))
|
||||||
|
default:
|
||||||
|
t.Fatalf("unsupported messages http method: %s", r.Method)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
server.RegisterHandler(
|
||||||
|
"/v1/threads/"+threadID+"/messages",
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodPost:
|
||||||
|
resBytes, _ := json.Marshal(openai.Message{
|
||||||
|
ID: messageID,
|
||||||
|
Object: "thread.message",
|
||||||
|
CreatedAt: 1234567890,
|
||||||
|
ThreadID: threadID,
|
||||||
|
Role: "user",
|
||||||
|
Content: []openai.MessageContent{{
|
||||||
|
Type: "text",
|
||||||
|
Text: &openai.MessageText{
|
||||||
|
Value: "How does AI work?",
|
||||||
|
Annotations: nil,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
FileIds: nil,
|
||||||
|
AssistantID: &emptyStr,
|
||||||
|
RunID: &emptyStr,
|
||||||
|
Metadata: nil,
|
||||||
|
})
|
||||||
|
fmt.Fprintln(w, string(resBytes))
|
||||||
|
case http.MethodGet:
|
||||||
|
resBytes, _ := json.Marshal(openai.MessagesList{
|
||||||
|
Messages: []openai.Message{{
|
||||||
|
ID: messageID,
|
||||||
|
Object: "thread.message",
|
||||||
|
CreatedAt: 1234567890,
|
||||||
|
ThreadID: threadID,
|
||||||
|
Role: "user",
|
||||||
|
Content: []openai.MessageContent{{
|
||||||
|
Type: "text",
|
||||||
|
Text: &openai.MessageText{
|
||||||
|
Value: "How does AI work?",
|
||||||
|
Annotations: nil,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
FileIds: nil,
|
||||||
|
AssistantID: &emptyStr,
|
||||||
|
RunID: &emptyStr,
|
||||||
|
Metadata: nil,
|
||||||
|
}}})
|
||||||
|
fmt.Fprintln(w, string(resBytes))
|
||||||
|
default:
|
||||||
|
t.Fatalf("unsupported messages http method: %s", r.Method)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// static assertion of return type
|
||||||
|
var msg openai.Message
|
||||||
|
msg, err := client.CreateMessage(ctx, threadID, openai.MessageRequest{
|
||||||
|
Role: "user",
|
||||||
|
Content: "How does AI work?",
|
||||||
|
FileIds: nil,
|
||||||
|
Metadata: nil,
|
||||||
|
})
|
||||||
|
checks.NoError(t, err, "CreateMessage error")
|
||||||
|
if msg.ID != messageID {
|
||||||
|
t.Fatalf("unexpected message id: '%s'", msg.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
var msgs openai.MessagesList
|
||||||
|
msgs, err = client.ListMessage(ctx, threadID, nil, nil, nil, nil)
|
||||||
|
checks.NoError(t, err, "ListMessages error")
|
||||||
|
if len(msgs.Messages) != 1 {
|
||||||
|
t.Fatalf("unexpected length of fetched messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
// with pagination options set
|
||||||
|
limit := 1
|
||||||
|
order := "desc"
|
||||||
|
after := "obj_foo"
|
||||||
|
before := "obj_bar"
|
||||||
|
msgs, err = client.ListMessage(ctx, threadID, &limit, &order, &after, &before)
|
||||||
|
checks.NoError(t, err, "ListMessages error")
|
||||||
|
if len(msgs.Messages) != 1 {
|
||||||
|
t.Fatalf("unexpected length of fetched messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err = client.RetrieveMessage(ctx, threadID, messageID)
|
||||||
|
checks.NoError(t, err, "RetrieveMessage error")
|
||||||
|
if msg.ID != messageID {
|
||||||
|
t.Fatalf("unexpected message id: '%s'", msg.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err = client.ModifyMessage(ctx, threadID, messageID,
|
||||||
|
map[string]any{
|
||||||
|
"foo": "bar",
|
||||||
|
})
|
||||||
|
checks.NoError(t, err, "ModifyMessage error")
|
||||||
|
if msg.Metadata["foo"] != "bar" {
|
||||||
|
t.Fatalf("expected message metadata to get modified")
|
||||||
|
}
|
||||||
|
|
||||||
|
// message files
|
||||||
|
var msgFile openai.MessageFile
|
||||||
|
msgFile, err = client.RetrieveMessageFile(ctx, threadID, messageID, fileID)
|
||||||
|
checks.NoError(t, err, "RetrieveMessageFile error")
|
||||||
|
if msgFile.ID != fileID {
|
||||||
|
t.Fatalf("unexpected message file id: '%s'", msgFile.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
var msgFiles openai.MessageFilesList
|
||||||
|
msgFiles, err = client.ListMessageFiles(ctx, threadID, messageID)
|
||||||
|
checks.NoError(t, err, "RetrieveMessageFile error")
|
||||||
|
if len(msgFiles.MessageFiles) != 1 {
|
||||||
|
t.Fatalf("unexpected count of message files: %d", len(msgFiles.MessageFiles))
|
||||||
|
}
|
||||||
|
if msgFiles.MessageFiles[0].ID != fileID {
|
||||||
|
t.Fatalf("unexpected message file id: '%s' in list message files", msgFiles.MessageFiles[0].ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user