From 2bd65aa720926506c49ddf89d7e619b3b83512c4 Mon Sep 17 00:00:00 2001 From: Ccheers <1048315650@qq.com> Date: Thu, 15 Jun 2023 16:49:54 +0800 Subject: [PATCH] feat(chat): support function call api (#369) * feat(chat): support function call api * rename struct & add const ChatMessageRoleFunction --- chat.go | 77 +++++++++++++++++++++++++++++++++++++++++++++++--- chat_stream.go | 2 +- completion.go | 2 +- 3 files changed, 75 insertions(+), 6 deletions(-) diff --git a/chat.go b/chat.go index a7ce548..c8cff31 100644 --- a/chat.go +++ b/chat.go @@ -11,8 +11,11 @@ const ( ChatMessageRoleSystem = "system" ChatMessageRoleUser = "user" ChatMessageRoleAssistant = "assistant" + ChatMessageRoleFunction = "function" ) +const chatCompletionsSuffix = "/chat/completions" + var ( ErrChatCompletionInvalidModel = errors.New("this model is not supported with this method, please use CreateCompletion client method instead") //nolint:lll ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll @@ -27,6 +30,14 @@ type ChatCompletionMessage struct { // - https://github.com/openai/openai-python/blob/main/chatml.md // - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb Name string `json:"name,omitempty"` + + FunctionCall *FunctionCall `json:"function_call,omitempty"` +} + +type FunctionCall struct { + Name string `json:"name,omitempty"` + // call function with arguments in JSON format + Arguments string `json:"arguments,omitempty"` } // ChatCompletionRequest represents a request structure for chat completion API. @@ -43,12 +54,70 @@ type ChatCompletionRequest struct { FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` LogitBias map[string]int `json:"logit_bias,omitempty"` User string `json:"user,omitempty"` + Functions []*FunctionDefine `json:"functions,omitempty"` + FunctionCall string `json:"function_call,omitempty"` } +type FunctionDefine struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + // it's required in function call + Parameters *FunctionParams `json:"parameters"` +} + +type FunctionParams struct { + // the Type must be JSONSchemaTypeObject + Type JSONSchemaType `json:"type"` + Properties map[string]*JSONSchemaDefine `json:"properties,omitempty"` + Required []string `json:"required,omitempty"` +} + +type JSONSchemaType string + +const ( + JSONSchemaTypeObject JSONSchemaType = "object" + JSONSchemaTypeNumber JSONSchemaType = "number" + JSONSchemaTypeString JSONSchemaType = "string" + JSONSchemaTypeArray JSONSchemaType = "array" + JSONSchemaTypeNull JSONSchemaType = "null" + JSONSchemaTypeBoolean JSONSchemaType = "boolean" +) + +// JSONSchemaDefine is a struct for JSON Schema. +type JSONSchemaDefine struct { + // Type is a type of JSON Schema. + Type JSONSchemaType `json:"type,omitempty"` + // Description is a description of JSON Schema. + Description string `json:"description,omitempty"` + // Enum is a enum of JSON Schema. It used if Type is JSONSchemaTypeString. + Enum []string `json:"enum,omitempty"` + // Properties is a properties of JSON Schema. It used if Type is JSONSchemaTypeObject. + Properties map[string]*JSONSchemaDefine `json:"properties,omitempty"` + // Required is a required of JSON Schema. It used if Type is JSONSchemaTypeObject. + Required []string `json:"required,omitempty"` +} + +type FinishReason string + +const ( + FinishReasonStop FinishReason = "stop" + FinishReasonLength FinishReason = "length" + FinishReasonFunctionCall FinishReason = "function_call" + FinishReasonContentFilter FinishReason = "content_filter" + FinishReasonNull FinishReason = "null" +) + type ChatCompletionChoice struct { - Index int `json:"index"` - Message ChatCompletionMessage `json:"message"` - FinishReason string `json:"finish_reason"` + Index int `json:"index"` + Message ChatCompletionMessage `json:"message"` + // FinishReason + // stop: API returned complete message, + // or a message terminated by one of the stop sequences provided via the stop parameter + // length: Incomplete model output due to max_tokens parameter or token limit + // function_call: The model decided to call a function + // content_filter: Omitted content due to a flag from our content filters + // null: API response still in progress or incomplete + FinishReason FinishReason `json:"finish_reason"` } // ChatCompletionResponse represents a response structure for chat completion API. @@ -71,7 +140,7 @@ func (c *Client) CreateChatCompletion( return } - urlSuffix := "/chat/completions" + urlSuffix := chatCompletionsSuffix if !checkEndpointSupportsModel(urlSuffix, request.Model) { err = ErrChatCompletionInvalidModel return diff --git a/chat_stream.go b/chat_stream.go index 625d436..c7341fe 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -40,7 +40,7 @@ func (c *Client) CreateChatCompletionStream( ctx context.Context, request ChatCompletionRequest, ) (stream *ChatCompletionStream, err error) { - urlSuffix := "/chat/completions" + urlSuffix := chatCompletionsSuffix if !checkEndpointSupportsModel(urlSuffix, request.Model) { err = ErrChatCompletionInvalidModel return diff --git a/completion.go b/completion.go index efded20..e0571b0 100644 --- a/completion.go +++ b/completion.go @@ -65,7 +65,7 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT432K0314: true, GPT432K0613: true, }, - "/chat/completions": { + chatCompletionsSuffix: { CodexCodeDavinci002: true, CodexCodeCushman001: true, CodexCodeDavinci001: true,