Add support for reasoning_content field in chat completion messages for DeepSeek R1 (#925)
* support deepseek field "reasoning_content"
* Comment ends in a period (godot)
* add comment on field reasoning_content
* fix go lint error
* chore: trigger CI
* make field "content" in MarshalJSON function omitempty
* remove reasoning_content in TestO1ModelChatCompletions func
* feat: Add test and handler for deepseek-reasoner chat model completions, including support for reasoning content in responses.
chat.go
@@ -104,6 +104,12 @@ type ChatCompletionMessage struct {
 	// - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 	Name string `json:"name,omitempty"`
 
+	// This property is used for the "reasoning" feature supported by deepseek-reasoner
+	// which is not in the official documentation.
+	// the doc from deepseek:
+	// - https://api-docs.deepseek.com/api/create-chat-completion#responses
+	ReasoningContent string `json:"reasoning_content,omitempty"`
+
 	FunctionCall *FunctionCall `json:"function_call,omitempty"`
 
 	// For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls.
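With the field in place on ChatCompletionMessage, callers can read the model's chain of thought next to the final answer. A minimal usage sketch, not part of this diff — it assumes DeepSeek's OpenAI-compatible endpoint at https://api.deepseek.com and a hypothetical DEEPSEEK_API_KEY environment variable:

package main

import (
	"context"
	"fmt"
	"os"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	// Point the client at DeepSeek's OpenAI-compatible API (assumed base URL).
	config := openai.DefaultConfig(os.Getenv("DEEPSEEK_API_KEY"))
	config.BaseURL = "https://api.deepseek.com"
	client := openai.NewClientWithConfig(config)

	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model: "deepseek-reasoner",
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
	})
	if err != nil {
		panic(err)
	}
	// deepseek-reasoner returns its chain of thought in reasoning_content,
	// separate from the user-facing answer in content.
	fmt.Println("reasoning:", resp.Choices[0].Message.ReasoningContent)
	fmt.Println("answer:", resp.Choices[0].Message.Content)
}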
@@ -119,41 +125,44 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
 	}
 	if len(m.MultiContent) > 0 {
 		msg := struct {
 			Role             string            `json:"role"`
 			Content          string            `json:"-"`
 			Refusal          string            `json:"refusal,omitempty"`
 			MultiContent     []ChatMessagePart `json:"content,omitempty"`
 			Name             string            `json:"name,omitempty"`
+			ReasoningContent string            `json:"reasoning_content,omitempty"`
 			FunctionCall     *FunctionCall     `json:"function_call,omitempty"`
 			ToolCalls        []ToolCall        `json:"tool_calls,omitempty"`
 			ToolCallID       string            `json:"tool_call_id,omitempty"`
 		}(m)
 		return json.Marshal(msg)
 	}
 
 	msg := struct {
 		Role             string            `json:"role"`
 		Content          string            `json:"content,omitempty"`
 		Refusal          string            `json:"refusal,omitempty"`
 		MultiContent     []ChatMessagePart `json:"-"`
 		Name             string            `json:"name,omitempty"`
+		ReasoningContent string            `json:"reasoning_content,omitempty"`
 		FunctionCall     *FunctionCall     `json:"function_call,omitempty"`
 		ToolCalls        []ToolCall        `json:"tool_calls,omitempty"`
 		ToolCallID       string            `json:"tool_call_id,omitempty"`
 	}(m)
 	return json.Marshal(msg)
 }
 
 func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
 	msg := struct {
 		Role             string `json:"role"`
-		Content          string `json:"content,omitempty"`
+		Content          string `json:"content"`
 		Refusal          string `json:"refusal,omitempty"`
 		MultiContent     []ChatMessagePart
 		Name             string        `json:"name,omitempty"`
+		ReasoningContent string        `json:"reasoning_content,omitempty"`
 		FunctionCall     *FunctionCall `json:"function_call,omitempty"`
 		ToolCalls        []ToolCall    `json:"tool_calls,omitempty"`
 		ToolCallID       string        `json:"tool_call_id,omitempty"`
 	}{}
 
 	if err := json.Unmarshal(bs, &msg); err == nil {
@@ -161,14 +170,15 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
 		return nil
 	}
 	multiMsg := struct {
 		Role             string `json:"role"`
 		Content          string
 		Refusal          string            `json:"refusal,omitempty"`
 		MultiContent     []ChatMessagePart `json:"content"`
 		Name             string        `json:"name,omitempty"`
+		ReasoningContent string        `json:"reasoning_content,omitempty"`
 		FunctionCall     *FunctionCall `json:"function_call,omitempty"`
 		ToolCalls        []ToolCall    `json:"tool_calls,omitempty"`
 		ToolCallID       string        `json:"tool_call_id,omitempty"`
 	}{}
 	if err := json.Unmarshal(bs, &multiMsg); err != nil {
 		return err
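Because both MarshalJSON paths and both UnmarshalJSON paths now carry the field, reasoning_content round-trips like any other optional member and stays absent from payloads when empty. An illustrative check, not part of this diff:

package main

import (
	"encoding/json"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	msg := openai.ChatCompletionMessage{
		Role:             openai.ChatMessageRoleAssistant,
		ReasoningContent: "User says hello! And I need to reply",
		Content:          "Hello!",
	}

	// Emitted when set:
	b, _ := json.Marshal(msg)
	fmt.Println(string(b))
	// {"role":"assistant","content":"Hello!","reasoning_content":"User says hello! And I need to reply"}

	// And recovered on the way back in:
	var decoded openai.ChatCompletionMessage
	_ = json.Unmarshal(b, &decoded)
	fmt.Println(decoded.ReasoningContent) // User says hello! And I need to reply

	// Omitted when empty, leaving non-DeepSeek payloads unchanged:
	b, _ = json.Marshal(openai.ChatCompletionMessage{Role: openai.ChatMessageRoleUser, Content: "Hi"})
	fmt.Println(string(b)) // {"role":"user","content":"Hi"}
}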
chat_stream.go
@@ -11,6 +11,12 @@ type ChatCompletionStreamChoiceDelta struct {
 	FunctionCall *FunctionCall `json:"function_call,omitempty"`
 	ToolCalls    []ToolCall    `json:"tool_calls,omitempty"`
 	Refusal      string        `json:"refusal,omitempty"`
+
+	// This property is used for the "reasoning" feature supported by deepseek-reasoner
+	// which is not in the official documentation.
+	// the doc from deepseek:
+	// - https://api-docs.deepseek.com/api/create-chat-completion#responses
+	ReasoningContent string `json:"reasoning_content,omitempty"`
 }
 
 type ChatCompletionStreamChoiceLogprobs struct {
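On the streaming side, the same field lands on each delta, so reasoning tokens can be consumed as they arrive. A minimal sketch, assuming the DeepSeek-configured client from the earlier example:

import (
	"context"
	"errors"
	"fmt"
	"io"

	openai "github.com/sashabaranov/go-openai"
)

func streamReasoning(client *openai.Client) error {
	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
		Model:  "deepseek-reasoner",
		Stream: true,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
	})
	if err != nil {
		return err
	}
	defer stream.Close()

	for {
		resp, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			return nil
		}
		if err != nil {
			return err
		}
		delta := resp.Choices[0].Delta
		// deepseek-reasoner streams reasoning tokens first, then answer tokens;
		// whichever is inactive arrives as an empty string.
		fmt.Print(delta.ReasoningContent)
		fmt.Print(delta.Content)
	}
}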
chat_test.go
@@ -411,6 +411,23 @@ func TestO3ModelChatCompletions(t *testing.T) {
 	checks.NoError(t, err, "CreateChatCompletion error")
 }
 
+func TestDeepseekR1ModelChatCompletions(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/chat/completions", handleDeepseekR1ChatCompletionEndpoint)
+	_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
+		Model:               "deepseek-reasoner",
+		MaxCompletionTokens: 100,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+	})
+	checks.NoError(t, err, "CreateChatCompletion error")
+}
+
 // TestCompletions Tests the completions endpoint of the API using the mocked server.
 func TestChatCompletionsWithHeaders(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
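The new test exercises the mock handler below and can be run on its own with: go test -run TestDeepseekR1ModelChatCompletions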
@@ -822,6 +839,68 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
 	fmt.Fprintln(w, string(resBytes))
 }
 
+func handleDeepseekR1ChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
+	var err error
+	var resBytes []byte
+
+	// completions only accepts POST requests
+	if r.Method != "POST" {
+		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+	}
+	var completionReq openai.ChatCompletionRequest
+	if completionReq, err = getChatCompletionBody(r); err != nil {
+		http.Error(w, "could not read request", http.StatusInternalServerError)
+		return
+	}
+	res := openai.ChatCompletionResponse{
+		ID:      strconv.Itoa(int(time.Now().Unix())),
+		Object:  "test-object",
+		Created: time.Now().Unix(),
+		// would be nice to validate Model during testing, but
+		// this may not be possible with how much upkeep
+		// would be required / wouldn't make much sense
+		Model: completionReq.Model,
+	}
+	// create completions
+	n := completionReq.N
+	if n == 0 {
+		n = 1
+	}
+	if completionReq.MaxCompletionTokens == 0 {
+		completionReq.MaxCompletionTokens = 1000
+	}
+	for i := 0; i < n; i++ {
+		reasoningContent := "User says hello! And I need to reply"
+		completionStr := strings.Repeat("a", completionReq.MaxCompletionTokens-numTokens(reasoningContent))
+		res.Choices = append(res.Choices, openai.ChatCompletionChoice{
+			Message: openai.ChatCompletionMessage{
+				Role:             openai.ChatMessageRoleAssistant,
+				ReasoningContent: reasoningContent,
+				Content:          completionStr,
+			},
+			Index: i,
+		})
+	}
+	inputTokens := numTokens(completionReq.Messages[0].Content) * n
+	completionTokens := completionReq.MaxTokens * n
+	res.Usage = openai.Usage{
+		PromptTokens:     inputTokens,
+		CompletionTokens: completionTokens,
+		TotalTokens:      inputTokens + completionTokens,
+	}
+	resBytes, _ = json.Marshal(res)
+	w.Header().Set(xCustomHeader, xCustomHeaderValue)
+	for k, v := range rateLimitHeaders {
+		switch val := v.(type) {
+		case int:
+			w.Header().Set(k, strconv.Itoa(val))
+		default:
+			w.Header().Set(k, fmt.Sprintf("%s", v))
+		}
+	}
+	fmt.Fprintln(w, string(resBytes))
+}
+
 // getChatCompletionBody Returns the body of the request to create a completion.
 func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) {
 	completion := openai.ChatCompletionRequest{}
openai_test.go
@@ -29,7 +29,7 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea
 
 // numTokens Returns the number of GPT-3 encoded tokens in the given text.
 // This function approximates based on the rule of thumb stated by OpenAI:
-// https://beta.openai.com/tokenizer/
+// https://beta.openai.com/tokenizer.
 //
 // TODO: implement an actual tokenizer for GPT-3 and Codex (once available).
 func numTokens(s string) int {