Fix broken implementation AssistantModify implementation (#685)

* add custom marshaller, documentation and isolate tests

* fix linter
This commit is contained in:
Quest Henkart
2024-03-15 18:59:16 +08:00
committed by GitHub
parent 699f397c36
commit 0925563e86
2 changed files with 112 additions and 33 deletions

View File

@@ -2,6 +2,7 @@ package openai
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
@@ -21,7 +22,7 @@ type Assistant struct {
Description *string `json:"description,omitempty"` Description *string `json:"description,omitempty"`
Model string `json:"model"` Model string `json:"model"`
Instructions *string `json:"instructions,omitempty"` Instructions *string `json:"instructions,omitempty"`
Tools []AssistantTool `json:"tools,omitempty"` Tools []AssistantTool `json:"tools"`
FileIDs []string `json:"file_ids,omitempty"` FileIDs []string `json:"file_ids,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"` Metadata map[string]any `json:"metadata,omitempty"`
@@ -41,16 +42,41 @@ type AssistantTool struct {
Function *FunctionDefinition `json:"function,omitempty"` Function *FunctionDefinition `json:"function,omitempty"`
} }
// AssistantRequest provides the assistant request parameters.
// When modifying the tools the API functions as the following:
// If Tools is undefined, no changes are made to the Assistant's tools.
// If Tools is empty slice it will effectively delete all of the Assistant's tools.
// If Tools is populated, it will replace all of the existing Assistant's tools with the provided tools.
type AssistantRequest struct { type AssistantRequest struct {
Model string `json:"model"` Model string `json:"model"`
Name *string `json:"name,omitempty"` Name *string `json:"name,omitempty"`
Description *string `json:"description,omitempty"` Description *string `json:"description,omitempty"`
Instructions *string `json:"instructions,omitempty"` Instructions *string `json:"instructions,omitempty"`
Tools []AssistantTool `json:"tools"` Tools []AssistantTool `json:"-"`
FileIDs []string `json:"file_ids,omitempty"` FileIDs []string `json:"file_ids,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"` Metadata map[string]any `json:"metadata,omitempty"`
} }
// MarshalJSON provides a custom marshaller for the assistant request to handle the API use cases
// If Tools is nil, the field is omitted from the JSON.
// If Tools is an empty slice, it's included in the JSON as an empty array ([]).
// If Tools is populated, it's included in the JSON with the elements.
func (a AssistantRequest) MarshalJSON() ([]byte, error) {
type Alias AssistantRequest
assistantAlias := &struct {
Tools *[]AssistantTool `json:"tools,omitempty"`
*Alias
}{
Alias: (*Alias)(&a),
}
if a.Tools != nil {
assistantAlias.Tools = &a.Tools
}
return json.Marshal(assistantAlias)
}
// AssistantsList is a list of assistants. // AssistantsList is a list of assistants.
type AssistantsList struct { type AssistantsList struct {
Assistants []Assistant `json:"data"` Assistants []Assistant `json:"data"`

View File

@@ -96,7 +96,7 @@ When asked a question, write and run Python code to answer the question.`
}) })
fmt.Fprintln(w, string(resBytes)) fmt.Fprintln(w, string(resBytes))
case http.MethodPost: case http.MethodPost:
var request openai.AssistantRequest var request openai.Assistant
err := json.NewDecoder(r.Body).Decode(&request) err := json.NewDecoder(r.Body).Decode(&request)
checks.NoError(t, err, "Decode error") checks.NoError(t, err, "Decode error")
@@ -163,44 +163,97 @@ When asked a question, write and run Python code to answer the question.`
ctx := context.Background() ctx := context.Background()
_, err := client.CreateAssistant(ctx, openai.AssistantRequest{ t.Run("create_assistant", func(t *testing.T) {
Name: &assistantName, _, err := client.CreateAssistant(ctx, openai.AssistantRequest{
Description: &assistantDescription, Name: &assistantName,
Model: openai.GPT4TurboPreview, Description: &assistantDescription,
Instructions: &assistantInstructions, Model: openai.GPT4TurboPreview,
Instructions: &assistantInstructions,
})
checks.NoError(t, err, "CreateAssistant error")
}) })
checks.NoError(t, err, "CreateAssistant error")
_, err = client.RetrieveAssistant(ctx, assistantID) t.Run("retrieve_assistant", func(t *testing.T) {
checks.NoError(t, err, "RetrieveAssistant error") _, err := client.RetrieveAssistant(ctx, assistantID)
checks.NoError(t, err, "RetrieveAssistant error")
_, err = client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{
Name: &assistantName,
Description: &assistantDescription,
Model: openai.GPT4TurboPreview,
Instructions: &assistantInstructions,
}) })
checks.NoError(t, err, "ModifyAssistant error")
_, err = client.DeleteAssistant(ctx, assistantID) t.Run("delete_assistant", func(t *testing.T) {
checks.NoError(t, err, "DeleteAssistant error") _, err := client.DeleteAssistant(ctx, assistantID)
checks.NoError(t, err, "DeleteAssistant error")
_, err = client.ListAssistants(ctx, &limit, &order, &after, &before)
checks.NoError(t, err, "ListAssistants error")
_, err = client.CreateAssistantFile(ctx, assistantID, openai.AssistantFileRequest{
FileID: assistantFileID,
}) })
checks.NoError(t, err, "CreateAssistantFile error")
_, err = client.ListAssistantFiles(ctx, assistantID, &limit, &order, &after, &before) t.Run("list_assistant", func(t *testing.T) {
checks.NoError(t, err, "ListAssistantFiles error") _, err := client.ListAssistants(ctx, &limit, &order, &after, &before)
checks.NoError(t, err, "ListAssistants error")
})
_, err = client.RetrieveAssistantFile(ctx, assistantID, assistantFileID) t.Run("create_assistant_file", func(t *testing.T) {
checks.NoError(t, err, "RetrieveAssistantFile error") _, err := client.CreateAssistantFile(ctx, assistantID, openai.AssistantFileRequest{
FileID: assistantFileID,
})
checks.NoError(t, err, "CreateAssistantFile error")
})
err = client.DeleteAssistantFile(ctx, assistantID, assistantFileID) t.Run("list_assistant_files", func(t *testing.T) {
checks.NoError(t, err, "DeleteAssistantFile error") _, err := client.ListAssistantFiles(ctx, assistantID, &limit, &order, &after, &before)
checks.NoError(t, err, "ListAssistantFiles error")
})
t.Run("retrieve_assistant_file", func(t *testing.T) {
_, err := client.RetrieveAssistantFile(ctx, assistantID, assistantFileID)
checks.NoError(t, err, "RetrieveAssistantFile error")
})
t.Run("delete_assistant_file", func(t *testing.T) {
err := client.DeleteAssistantFile(ctx, assistantID, assistantFileID)
checks.NoError(t, err, "DeleteAssistantFile error")
})
t.Run("modify_assistant_no_tools", func(t *testing.T) {
assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{
Name: &assistantName,
Description: &assistantDescription,
Model: openai.GPT4TurboPreview,
Instructions: &assistantInstructions,
})
checks.NoError(t, err, "ModifyAssistant error")
if assistant.Tools != nil {
t.Errorf("expected nil got %v", assistant.Tools)
}
})
t.Run("modify_assistant_with_tools", func(t *testing.T) {
assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{
Name: &assistantName,
Description: &assistantDescription,
Model: openai.GPT4TurboPreview,
Instructions: &assistantInstructions,
Tools: []openai.AssistantTool{{Type: openai.AssistantToolTypeFunction}},
})
checks.NoError(t, err, "ModifyAssistant error")
if assistant.Tools == nil || len(assistant.Tools) != 1 {
t.Errorf("expected a slice got %v", assistant.Tools)
}
})
t.Run("modify_assistant_empty_tools", func(t *testing.T) {
assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{
Name: &assistantName,
Description: &assistantDescription,
Model: openai.GPT4TurboPreview,
Instructions: &assistantInstructions,
Tools: make([]openai.AssistantTool, 0),
})
checks.NoError(t, err, "ModifyAssistant error")
if assistant.Tools == nil {
t.Errorf("expected a slice got %v", assistant.Tools)
}
})
} }
func TestAzureAssistant(t *testing.T) { func TestAzureAssistant(t *testing.T) {