feat: allow more input types to functions, fix tests (#377)
* feat: use json.rawMessage, test functions
* chore: lint
* fix: tests
the ChatCompletion mock server doesn't actually run otherwise. N=0
is the default request but the server will treat it as n=1
* fix: tests should default to n=1 completions
* chore: add back removed interfaces, custom marshal
* chore: lint
* chore: lint
* chore: add some tests
* chore: appease lint
* clean up JSON schema + tests
* chore: lint
* feat: remove backwards compatible functions
for illustrative purposes
* fix: revert params change
* chore: use interface{}
* chore: add test
* chore: add back FunctionDefine
* chore: /s/interface{}/any
* chore: add back jsonschemadefinition
* chore: testcov
* chore: lint
* chore: remove pointers
* chore: update comment
* chore: address CR
added test for compatibility as well
---------
Co-authored-by: James <jmacwhyte@MacBooger-II.local>
This commit is contained in:
34
chat.go
34
chat.go
@@ -54,23 +54,23 @@ type ChatCompletionRequest struct {
|
|||||||
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
|
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
|
||||||
LogitBias map[string]int `json:"logit_bias,omitempty"`
|
LogitBias map[string]int `json:"logit_bias,omitempty"`
|
||||||
User string `json:"user,omitempty"`
|
User string `json:"user,omitempty"`
|
||||||
Functions []*FunctionDefine `json:"functions,omitempty"`
|
Functions []FunctionDefinition `json:"functions,omitempty"`
|
||||||
FunctionCall string `json:"function_call,omitempty"`
|
FunctionCall any `json:"function_call,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type FunctionDefine struct {
|
type FunctionDefinition struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Description string `json:"description,omitempty"`
|
Description string `json:"description,omitempty"`
|
||||||
// it's required in function call
|
// Parameters is an object describing the function.
|
||||||
Parameters *FunctionParams `json:"parameters"`
|
// You can pass a raw byte array describing the schema,
|
||||||
|
// or you can pass in a struct which serializes to the proper JSONSchema.
|
||||||
|
// The JSONSchemaDefinition struct is provided for convenience, but you should
|
||||||
|
// consider another specialized library for more complex schemas.
|
||||||
|
Parameters any `json:"parameters"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type FunctionParams struct {
|
// Deprecated: use FunctionDefinition instead.
|
||||||
// the Type must be JSONSchemaTypeObject
|
type FunctionDefine = FunctionDefinition
|
||||||
Type JSONSchemaType `json:"type"`
|
|
||||||
Properties map[string]*JSONSchemaDefine `json:"properties,omitempty"`
|
|
||||||
Required []string `json:"required,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type JSONSchemaType string
|
type JSONSchemaType string
|
||||||
|
|
||||||
@@ -83,8 +83,9 @@ const (
|
|||||||
JSONSchemaTypeBoolean JSONSchemaType = "boolean"
|
JSONSchemaTypeBoolean JSONSchemaType = "boolean"
|
||||||
)
|
)
|
||||||
|
|
||||||
// JSONSchemaDefine is a struct for JSON Schema.
|
// JSONSchemaDefinition is a struct for JSON Schema.
|
||||||
type JSONSchemaDefine struct {
|
// It is fairly limited and you may have better luck using a third-party library.
|
||||||
|
type JSONSchemaDefinition struct {
|
||||||
// Type is a type of JSON Schema.
|
// Type is a type of JSON Schema.
|
||||||
Type JSONSchemaType `json:"type,omitempty"`
|
Type JSONSchemaType `json:"type,omitempty"`
|
||||||
// Description is a description of JSON Schema.
|
// Description is a description of JSON Schema.
|
||||||
@@ -92,13 +93,16 @@ type JSONSchemaDefine struct {
|
|||||||
// Enum is a enum of JSON Schema. It used if Type is JSONSchemaTypeString.
|
// Enum is a enum of JSON Schema. It used if Type is JSONSchemaTypeString.
|
||||||
Enum []string `json:"enum,omitempty"`
|
Enum []string `json:"enum,omitempty"`
|
||||||
// Properties is a properties of JSON Schema. It used if Type is JSONSchemaTypeObject.
|
// Properties is a properties of JSON Schema. It used if Type is JSONSchemaTypeObject.
|
||||||
Properties map[string]*JSONSchemaDefine `json:"properties,omitempty"`
|
Properties map[string]JSONSchemaDefinition `json:"properties,omitempty"`
|
||||||
// Required is a required of JSON Schema. It used if Type is JSONSchemaTypeObject.
|
// Required is a required of JSON Schema. It used if Type is JSONSchemaTypeObject.
|
||||||
Required []string `json:"required,omitempty"`
|
Required []string `json:"required,omitempty"`
|
||||||
// Items is a property of JSON Schema. It used if Type is JSONSchemaTypeArray.
|
// Items is a property of JSON Schema. It used if Type is JSONSchemaTypeArray.
|
||||||
Items *JSONSchemaDefine `json:"items,omitempty"`
|
Items *JSONSchemaDefinition `json:"items,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Deprecated: use JSONSchemaDefinition instead.
|
||||||
|
type JSONSchemaDefine = JSONSchemaDefinition
|
||||||
|
|
||||||
type FinishReason string
|
type FinishReason string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
157
chat_test.go
157
chat_test.go
@@ -67,6 +67,130 @@ func TestChatCompletions(t *testing.T) {
|
|||||||
checks.NoError(t, err, "CreateChatCompletion error")
|
checks.NoError(t, err, "CreateChatCompletion error")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestChatCompletionsFunctions tests including a function call.
|
||||||
|
func TestChatCompletionsFunctions(t *testing.T) {
|
||||||
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
|
||||||
|
t.Run("bytes", func(t *testing.T) {
|
||||||
|
//nolint:lll
|
||||||
|
msg := json.RawMessage(`{"properties":{"count":{"type":"integer","description":"total number of words in sentence"},"words":{"items":{"type":"string"},"type":"array","description":"list of words in sentence"}},"type":"object","required":["count","words"]}`)
|
||||||
|
_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
|
||||||
|
MaxTokens: 5,
|
||||||
|
Model: GPT3Dot5Turbo0613,
|
||||||
|
Messages: []ChatCompletionMessage{
|
||||||
|
{
|
||||||
|
Role: ChatMessageRoleUser,
|
||||||
|
Content: "Hello!",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Functions: []FunctionDefine{{
|
||||||
|
Name: "test",
|
||||||
|
Parameters: &msg,
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
checks.NoError(t, err, "CreateChatCompletion with functions error")
|
||||||
|
})
|
||||||
|
t.Run("struct", func(t *testing.T) {
|
||||||
|
type testMessage struct {
|
||||||
|
Count int `json:"count"`
|
||||||
|
Words []string `json:"words"`
|
||||||
|
}
|
||||||
|
msg := testMessage{
|
||||||
|
Count: 2,
|
||||||
|
Words: []string{"hello", "world"},
|
||||||
|
}
|
||||||
|
_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
|
||||||
|
MaxTokens: 5,
|
||||||
|
Model: GPT3Dot5Turbo0613,
|
||||||
|
Messages: []ChatCompletionMessage{
|
||||||
|
{
|
||||||
|
Role: ChatMessageRoleUser,
|
||||||
|
Content: "Hello!",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Functions: []FunctionDefinition{{
|
||||||
|
Name: "test",
|
||||||
|
Parameters: &msg,
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
checks.NoError(t, err, "CreateChatCompletion with functions error")
|
||||||
|
})
|
||||||
|
t.Run("JSONSchemaDefine", func(t *testing.T) {
|
||||||
|
_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
|
||||||
|
MaxTokens: 5,
|
||||||
|
Model: GPT3Dot5Turbo0613,
|
||||||
|
Messages: []ChatCompletionMessage{
|
||||||
|
{
|
||||||
|
Role: ChatMessageRoleUser,
|
||||||
|
Content: "Hello!",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Functions: []FunctionDefinition{{
|
||||||
|
Name: "test",
|
||||||
|
Parameters: &JSONSchemaDefinition{
|
||||||
|
Type: JSONSchemaTypeObject,
|
||||||
|
Properties: map[string]JSONSchemaDefinition{
|
||||||
|
"count": {
|
||||||
|
Type: JSONSchemaTypeNumber,
|
||||||
|
Description: "total number of words in sentence",
|
||||||
|
},
|
||||||
|
"words": {
|
||||||
|
Type: JSONSchemaTypeArray,
|
||||||
|
Description: "list of words in sentence",
|
||||||
|
Items: &JSONSchemaDefinition{
|
||||||
|
Type: JSONSchemaTypeString,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"enumTest": {
|
||||||
|
Type: JSONSchemaTypeString,
|
||||||
|
Enum: []string{"hello", "world"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
checks.NoError(t, err, "CreateChatCompletion with functions error")
|
||||||
|
})
|
||||||
|
t.Run("JSONSchemaDefineWithFunctionDefine", func(t *testing.T) {
|
||||||
|
// this is a compatibility check
|
||||||
|
_, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
|
||||||
|
MaxTokens: 5,
|
||||||
|
Model: GPT3Dot5Turbo0613,
|
||||||
|
Messages: []ChatCompletionMessage{
|
||||||
|
{
|
||||||
|
Role: ChatMessageRoleUser,
|
||||||
|
Content: "Hello!",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Functions: []FunctionDefine{{
|
||||||
|
Name: "test",
|
||||||
|
Parameters: &JSONSchemaDefine{
|
||||||
|
Type: JSONSchemaTypeObject,
|
||||||
|
Properties: map[string]JSONSchemaDefine{
|
||||||
|
"count": {
|
||||||
|
Type: JSONSchemaTypeNumber,
|
||||||
|
Description: "total number of words in sentence",
|
||||||
|
},
|
||||||
|
"words": {
|
||||||
|
Type: JSONSchemaTypeArray,
|
||||||
|
Description: "list of words in sentence",
|
||||||
|
Items: &JSONSchemaDefine{
|
||||||
|
Type: JSONSchemaTypeString,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"enumTest": {
|
||||||
|
Type: JSONSchemaTypeString,
|
||||||
|
Enum: []string{"hello", "world"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
checks.NoError(t, err, "CreateChatCompletion with functions error")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestAzureChatCompletions(t *testing.T) {
|
func TestAzureChatCompletions(t *testing.T) {
|
||||||
client, server, teardown := setupAzureTestServer()
|
client, server, teardown := setupAzureTestServer()
|
||||||
defer teardown()
|
defer teardown()
|
||||||
@@ -109,7 +233,34 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
|
|||||||
Model: completionReq.Model,
|
Model: completionReq.Model,
|
||||||
}
|
}
|
||||||
// create completions
|
// create completions
|
||||||
for i := 0; i < completionReq.N; i++ {
|
n := completionReq.N
|
||||||
|
if n == 0 {
|
||||||
|
n = 1
|
||||||
|
}
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
// if there are functions, include them
|
||||||
|
if len(completionReq.Functions) > 0 {
|
||||||
|
var fcb []byte
|
||||||
|
b := completionReq.Functions[0].Parameters
|
||||||
|
fcb, err = json.Marshal(b)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "could not marshal function parameters", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res.Choices = append(res.Choices, ChatCompletionChoice{
|
||||||
|
Message: ChatCompletionMessage{
|
||||||
|
Role: ChatMessageRoleFunction,
|
||||||
|
// this is valid json so it should be fine
|
||||||
|
FunctionCall: &FunctionCall{
|
||||||
|
Name: completionReq.Functions[0].Name,
|
||||||
|
Arguments: string(fcb),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Index: i,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
// generate a random string of length completionReq.Length
|
// generate a random string of length completionReq.Length
|
||||||
completionStr := strings.Repeat("a", completionReq.MaxTokens)
|
completionStr := strings.Repeat("a", completionReq.MaxTokens)
|
||||||
|
|
||||||
@@ -121,8 +272,8 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
|
|||||||
Index: i,
|
Index: i,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
inputTokens := numTokens(completionReq.Messages[0].Content) * completionReq.N
|
inputTokens := numTokens(completionReq.Messages[0].Content) * n
|
||||||
completionTokens := completionReq.MaxTokens * completionReq.N
|
completionTokens := completionReq.MaxTokens * n
|
||||||
res.Usage = Usage{
|
res.Usage = Usage{
|
||||||
PromptTokens: inputTokens,
|
PromptTokens: inputTokens,
|
||||||
CompletionTokens: completionTokens,
|
CompletionTokens: completionTokens,
|
||||||
|
|||||||
@@ -83,7 +83,11 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
|
|||||||
Model: completionReq.Model,
|
Model: completionReq.Model,
|
||||||
}
|
}
|
||||||
// create completions
|
// create completions
|
||||||
for i := 0; i < completionReq.N; i++ {
|
n := completionReq.N
|
||||||
|
if n == 0 {
|
||||||
|
n = 1
|
||||||
|
}
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
// generate a random string of length completionReq.Length
|
// generate a random string of length completionReq.Length
|
||||||
completionStr := strings.Repeat("a", completionReq.MaxTokens)
|
completionStr := strings.Repeat("a", completionReq.MaxTokens)
|
||||||
if completionReq.Echo {
|
if completionReq.Echo {
|
||||||
@@ -94,8 +98,8 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
|
|||||||
Index: i,
|
Index: i,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
inputTokens := numTokens(completionReq.Prompt.(string)) * completionReq.N
|
inputTokens := numTokens(completionReq.Prompt.(string)) * n
|
||||||
completionTokens := completionReq.MaxTokens * completionReq.N
|
completionTokens := completionReq.MaxTokens * n
|
||||||
res.Usage = Usage{
|
res.Usage = Usage{
|
||||||
PromptTokens: inputTokens,
|
PromptTokens: inputTokens,
|
||||||
CompletionTokens: completionTokens,
|
CompletionTokens: completionTokens,
|
||||||
|
|||||||
Reference in New Issue
Block a user