Allow structured outputs via function calling (#828)
This commit is contained in:
@@ -239,3 +239,79 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestChatCompletionStructuredOutputsFunctionCalling(t *testing.T) {
|
||||||
|
apiToken := os.Getenv("OPENAI_TOKEN")
|
||||||
|
if apiToken == "" {
|
||||||
|
t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.")
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
c := openai.NewClient(apiToken)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
resp, err := c.CreateChatCompletion(
|
||||||
|
ctx,
|
||||||
|
openai.ChatCompletionRequest{
|
||||||
|
Model: openai.GPT4oMini,
|
||||||
|
Messages: []openai.ChatCompletionMessage{
|
||||||
|
{
|
||||||
|
Role: openai.ChatMessageRoleSystem,
|
||||||
|
Content: "Please enter a string, and we will convert it into the following naming conventions:" +
|
||||||
|
"1. PascalCase: Each word starts with an uppercase letter, with no spaces or separators." +
|
||||||
|
"2. CamelCase: The first word starts with a lowercase letter, " +
|
||||||
|
"and subsequent words start with an uppercase letter, with no spaces or separators." +
|
||||||
|
"3. KebabCase: All letters are lowercase, with words separated by hyphens `-`." +
|
||||||
|
"4. SnakeCase: All letters are lowercase, with words separated by underscores `_`.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: openai.ChatMessageRoleUser,
|
||||||
|
Content: "Hello World",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Tools: []openai.Tool{
|
||||||
|
{
|
||||||
|
Type: openai.ToolTypeFunction,
|
||||||
|
Function: &openai.FunctionDefinition{
|
||||||
|
Name: "display_cases",
|
||||||
|
Strict: true,
|
||||||
|
Parameters: &jsonschema.Definition{
|
||||||
|
Type: jsonschema.Object,
|
||||||
|
Properties: map[string]jsonschema.Definition{
|
||||||
|
"PascalCase": {
|
||||||
|
Type: jsonschema.String,
|
||||||
|
},
|
||||||
|
"CamelCase": {
|
||||||
|
Type: jsonschema.String,
|
||||||
|
},
|
||||||
|
"KebabCase": {
|
||||||
|
Type: jsonschema.String,
|
||||||
|
},
|
||||||
|
"SnakeCase": {
|
||||||
|
Type: jsonschema.String,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"},
|
||||||
|
AdditionalProperties: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ToolChoice: openai.ToolChoice{
|
||||||
|
Type: openai.ToolTypeFunction,
|
||||||
|
Function: openai.ToolFunction{
|
||||||
|
Name: "display_cases",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
checks.NoError(t, err, "CreateChatCompletion (use structured outputs response) returned error")
|
||||||
|
var result = make(map[string]string)
|
||||||
|
err = json.Unmarshal([]byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments), &result)
|
||||||
|
checks.NoError(t, err, "CreateChatCompletion (use structured outputs response) unmarshal error")
|
||||||
|
for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} {
|
||||||
|
if _, ok := result[key]; !ok {
|
||||||
|
t.Errorf("key:%s does not exist.", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
1
chat.go
1
chat.go
@@ -264,6 +264,7 @@ type ToolFunction struct {
|
|||||||
type FunctionDefinition struct {
|
type FunctionDefinition struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Description string `json:"description,omitempty"`
|
Description string `json:"description,omitempty"`
|
||||||
|
Strict bool `json:"strict,omitempty"`
|
||||||
// Parameters is an object describing the function.
|
// Parameters is an object describing the function.
|
||||||
// You can pass json.RawMessage to describe the schema,
|
// You can pass json.RawMessage to describe the schema,
|
||||||
// or you can pass in a struct which serializes to the proper JSON schema.
|
// or you can pass in a struct which serializes to the proper JSON schema.
|
||||||
|
|||||||
26
chat_test.go
26
chat_test.go
@@ -277,6 +277,32 @@ func TestChatCompletionsFunctions(t *testing.T) {
|
|||||||
})
|
})
|
||||||
checks.NoError(t, err, "CreateChatCompletion with functions error")
|
checks.NoError(t, err, "CreateChatCompletion with functions error")
|
||||||
})
|
})
|
||||||
|
t.Run("StructuredOutputs", func(t *testing.T) {
|
||||||
|
type testMessage struct {
|
||||||
|
Count int `json:"count"`
|
||||||
|
Words []string `json:"words"`
|
||||||
|
}
|
||||||
|
msg := testMessage{
|
||||||
|
Count: 2,
|
||||||
|
Words: []string{"hello", "world"},
|
||||||
|
}
|
||||||
|
_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
|
||||||
|
MaxTokens: 5,
|
||||||
|
Model: openai.GPT3Dot5Turbo0613,
|
||||||
|
Messages: []openai.ChatCompletionMessage{
|
||||||
|
{
|
||||||
|
Role: openai.ChatMessageRoleUser,
|
||||||
|
Content: "Hello!",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Functions: []openai.FunctionDefinition{{
|
||||||
|
Name: "test",
|
||||||
|
Strict: true,
|
||||||
|
Parameters: &msg,
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
checks.NoError(t, err, "CreateChatCompletion with functions error")
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAzureChatCompletions(t *testing.T) {
|
func TestAzureChatCompletions(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user