Add support for reasoning_content field in chat completion messages for DeepSeek R1 (#925)
* support deepseek field "reasoning_content"
* Comment ends in a period (godot)
* add comment on field reasoning_content
* fix go lint error
* chore: trigger CI
* make field "content" in MarshalJSON function omitempty
* remove reasoning_content in TestO1ModelChatCompletions func
* feat: Add test and handler for deepseek-reasoner chat model completions, including support for reasoning content in responses.
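For context, a minimal sketch of how a caller might use the new field once this lands. It assumes the go-openai client pointed at DeepSeek's OpenAI-compatible endpoint; the base URL and API key below are illustrative placeholders:

```go
package main

import (
	"context"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	// Placeholder credentials/endpoint: DeepSeek exposes an OpenAI-compatible API.
	config := openai.DefaultConfig("YOUR_API_KEY")
	config.BaseURL = "https://api.deepseek.com"
	client := openai.NewClientWithConfig(config)

	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model: "deepseek-reasoner",
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
	})
	if err != nil {
		panic(err)
	}

	// The model's chain of thought arrives in the new field;
	// the final answer stays in Content.
	fmt.Println("reasoning:", resp.Choices[0].Message.ReasoningContent)
	fmt.Println("answer:", resp.Choices[0].Message.Content)
}
```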
chat.go

@@ -104,6 +104,12 @@ type ChatCompletionMessage struct {
 	// - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 	Name string `json:"name,omitempty"`
 
+	// This property is used for the "reasoning" feature supported by deepseek-reasoner
+	// which is not in the official documentation.
+	// the doc from deepseek:
+	// - https://api-docs.deepseek.com/api/create-chat-completion#responses
+	ReasoningContent string `json:"reasoning_content,omitempty"`
+
 	FunctionCall *FunctionCall `json:"function_call,omitempty"`
 
 	// For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls.
@@ -119,41 +125,44 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
 	}
 	if len(m.MultiContent) > 0 {
 		msg := struct {
-			Role         string            `json:"role"`
-			Content      string            `json:"-"`
-			Refusal      string            `json:"refusal,omitempty"`
-			MultiContent []ChatMessagePart `json:"content,omitempty"`
-			Name         string            `json:"name,omitempty"`
-			FunctionCall *FunctionCall     `json:"function_call,omitempty"`
-			ToolCalls    []ToolCall        `json:"tool_calls,omitempty"`
-			ToolCallID   string            `json:"tool_call_id,omitempty"`
+			Role             string            `json:"role"`
+			Content          string            `json:"-"`
+			Refusal          string            `json:"refusal,omitempty"`
+			MultiContent     []ChatMessagePart `json:"content,omitempty"`
+			Name             string            `json:"name,omitempty"`
+			ReasoningContent string            `json:"reasoning_content,omitempty"`
+			FunctionCall     *FunctionCall     `json:"function_call,omitempty"`
+			ToolCalls        []ToolCall        `json:"tool_calls,omitempty"`
+			ToolCallID       string            `json:"tool_call_id,omitempty"`
 		}(m)
 		return json.Marshal(msg)
 	}
 
 	msg := struct {
-		Role         string            `json:"role"`
-		Content      string            `json:"content,omitempty"`
-		Refusal      string            `json:"refusal,omitempty"`
-		MultiContent []ChatMessagePart `json:"-"`
-		Name         string            `json:"name,omitempty"`
-		FunctionCall *FunctionCall     `json:"function_call,omitempty"`
-		ToolCalls    []ToolCall        `json:"tool_calls,omitempty"`
-		ToolCallID   string            `json:"tool_call_id,omitempty"`
+		Role             string            `json:"role"`
+		Content          string            `json:"content,omitempty"`
+		Refusal          string            `json:"refusal,omitempty"`
+		MultiContent     []ChatMessagePart `json:"-"`
+		Name             string            `json:"name,omitempty"`
+		ReasoningContent string            `json:"reasoning_content,omitempty"`
+		FunctionCall     *FunctionCall     `json:"function_call,omitempty"`
+		ToolCalls        []ToolCall        `json:"tool_calls,omitempty"`
+		ToolCallID       string            `json:"tool_call_id,omitempty"`
 	}(m)
 	return json.Marshal(msg)
 }
 
 func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
 	msg := struct {
-		Role         string `json:"role"`
-		Content      string `json:"content,omitempty"`
-		Refusal      string `json:"refusal,omitempty"`
-		MultiContent []ChatMessagePart
-		Name         string        `json:"name,omitempty"`
-		FunctionCall *FunctionCall `json:"function_call,omitempty"`
-		ToolCalls    []ToolCall    `json:"tool_calls,omitempty"`
-		ToolCallID   string        `json:"tool_call_id,omitempty"`
+		Role             string `json:"role"`
+		Content          string `json:"content"`
+		Refusal          string `json:"refusal,omitempty"`
+		MultiContent     []ChatMessagePart
+		Name             string        `json:"name,omitempty"`
+		ReasoningContent string        `json:"reasoning_content,omitempty"`
+		FunctionCall     *FunctionCall `json:"function_call,omitempty"`
+		ToolCalls        []ToolCall    `json:"tool_calls,omitempty"`
+		ToolCallID       string        `json:"tool_call_id,omitempty"`
 	}{}
 
 	if err := json.Unmarshal(bs, &msg); err == nil {
@@ -161,14 +170,15 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
 		return nil
 	}
 	multiMsg := struct {
-		Role         string `json:"role"`
-		Content      string
-		Refusal      string            `json:"refusal,omitempty"`
-		MultiContent []ChatMessagePart `json:"content"`
-		Name         string            `json:"name,omitempty"`
-		FunctionCall *FunctionCall     `json:"function_call,omitempty"`
-		ToolCalls    []ToolCall        `json:"tool_calls,omitempty"`
-		ToolCallID   string            `json:"tool_call_id,omitempty"`
+		Role             string `json:"role"`
+		Content          string
+		Refusal          string            `json:"refusal,omitempty"`
+		MultiContent     []ChatMessagePart `json:"content"`
+		Name             string            `json:"name,omitempty"`
+		ReasoningContent string            `json:"reasoning_content,omitempty"`
+		FunctionCall     *FunctionCall     `json:"function_call,omitempty"`
+		ToolCalls        []ToolCall        `json:"tool_calls,omitempty"`
+		ToolCallID       string            `json:"tool_call_id,omitempty"`
 	}{}
 	if err := json.Unmarshal(bs, &multiMsg); err != nil {
 		return err
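Since ReasoningContent is tagged `omitempty` in all three anonymous structs, it only appears on the wire when set. A small sketch of the resulting round-trip behavior (values are made up; assumes `encoding/json` and `fmt` plus the go-openai import from the earlier example):

```go
msg := openai.ChatCompletionMessage{
	Role:             openai.ChatMessageRoleAssistant,
	Content:          "Hi there!",
	ReasoningContent: "User greeted me; reply with a short greeting.",
}

b, _ := json.Marshal(msg)
fmt.Println(string(b))
// {"role":"assistant","content":"Hi there!","reasoning_content":"User greeted me; reply with a short greeting."}

// Round-trip: UnmarshalJSON restores the field.
var decoded openai.ChatCompletionMessage
_ = json.Unmarshal(b, &decoded)
fmt.Println(decoded.ReasoningContent == msg.ReasoningContent) // true

// With the field empty, omitempty keeps it off the wire entirely.
b, _ = json.Marshal(openai.ChatCompletionMessage{
	Role:    openai.ChatMessageRoleUser,
	Content: "Hello!",
})
fmt.Println(string(b)) // {"role":"user","content":"Hello!"}
```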
chat_stream.go

@@ -11,6 +11,12 @@ type ChatCompletionStreamChoiceDelta struct {
 	FunctionCall *FunctionCall `json:"function_call,omitempty"`
 	ToolCalls    []ToolCall    `json:"tool_calls,omitempty"`
 	Refusal      string        `json:"refusal,omitempty"`
+
+	// This property is used for the "reasoning" feature supported by deepseek-reasoner
+	// which is not in the official documentation.
+	// the doc from deepseek:
+	// - https://api-docs.deepseek.com/api/create-chat-completion#responses
+	ReasoningContent string `json:"reasoning_content,omitempty"`
 }
 
 type ChatCompletionStreamChoiceLogprobs struct {
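The streaming delta gets the same field, so reasoning tokens can be separated from answer tokens as chunks arrive. A sketch, reusing the client from the first example (plus `errors`, `io`, and `strings` imports):

```go
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
	Model:  "deepseek-reasoner",
	Stream: true,
	Messages: []openai.ChatCompletionMessage{
		{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
	},
})
if err != nil {
	panic(err)
}
defer stream.Close()

var reasoning, answer strings.Builder
for {
	chunk, err := stream.Recv()
	if errors.Is(err, io.EOF) {
		break
	}
	if err != nil {
		panic(err)
	}
	if len(chunk.Choices) == 0 {
		continue
	}
	// Reasoning and answer tokens stream in separate delta fields.
	delta := chunk.Choices[0].Delta
	reasoning.WriteString(delta.ReasoningContent)
	answer.WriteString(delta.Content)
}
```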
chat_test.go

@@ -411,6 +411,23 @@ func TestO3ModelChatCompletions(t *testing.T) {
 	checks.NoError(t, err, "CreateChatCompletion error")
 }
 
+func TestDeepseekR1ModelChatCompletions(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/chat/completions", handleDeepseekR1ChatCompletionEndpoint)
+	_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
+		Model:               "deepseek-reasoner",
+		MaxCompletionTokens: 100,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+	})
+	checks.NoError(t, err, "CreateChatCompletion error")
+}
+
 // TestCompletions Tests the completions endpoint of the API using the mocked server.
 func TestChatCompletionsWithHeaders(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
@@ -822,6 +839,68 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
 	fmt.Fprintln(w, string(resBytes))
 }
 
+func handleDeepseekR1ChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
+	var err error
+	var resBytes []byte
+
+	// completions only accepts POST requests
+	if r.Method != "POST" {
+		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+	}
+	var completionReq openai.ChatCompletionRequest
+	if completionReq, err = getChatCompletionBody(r); err != nil {
+		http.Error(w, "could not read request", http.StatusInternalServerError)
+		return
+	}
+	res := openai.ChatCompletionResponse{
+		ID:      strconv.Itoa(int(time.Now().Unix())),
+		Object:  "test-object",
+		Created: time.Now().Unix(),
+		// would be nice to validate Model during testing, but
+		// this may not be possible with how much upkeep
+		// would be required / wouldn't make much sense
+		Model: completionReq.Model,
+	}
+	// create completions
+	n := completionReq.N
+	if n == 0 {
+		n = 1
+	}
+	if completionReq.MaxCompletionTokens == 0 {
+		completionReq.MaxCompletionTokens = 1000
+	}
+	for i := 0; i < n; i++ {
+		reasoningContent := "User says hello! And I need to reply"
+		completionStr := strings.Repeat("a", completionReq.MaxCompletionTokens-numTokens(reasoningContent))
+		res.Choices = append(res.Choices, openai.ChatCompletionChoice{
+			Message: openai.ChatCompletionMessage{
+				Role:             openai.ChatMessageRoleAssistant,
+				ReasoningContent: reasoningContent,
+				Content:          completionStr,
+			},
+			Index: i,
+		})
+	}
+	inputTokens := numTokens(completionReq.Messages[0].Content) * n
+	completionTokens := completionReq.MaxTokens * n
+	res.Usage = openai.Usage{
+		PromptTokens:     inputTokens,
+		CompletionTokens: completionTokens,
+		TotalTokens:      inputTokens + completionTokens,
+	}
+	resBytes, _ = json.Marshal(res)
+	w.Header().Set(xCustomHeader, xCustomHeaderValue)
+	for k, v := range rateLimitHeaders {
+		switch val := v.(type) {
+		case int:
+			w.Header().Set(k, strconv.Itoa(val))
+		default:
+			w.Header().Set(k, fmt.Sprintf("%s", v))
+		}
+	}
+	fmt.Fprintln(w, string(resBytes))
+}
+
 // getChatCompletionBody Returns the body of the request to create a completion.
 func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) {
 	completion := openai.ChatCompletionRequest{}

openai_test.go

@@ -29,7 +29,7 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea
 
 // numTokens Returns the number of GPT-3 encoded tokens in the given text.
 // This function approximates based on the rule of thumb stated by OpenAI:
-// https://beta.openai.com/tokenizer/
+// https://beta.openai.com/tokenizer.
 //
 // TODO: implement an actual tokenizer for GPT-3 and Codex (once available).
 func numTokens(s string) int {
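The body of numTokens is outside this diff; going by the comment above, it is presumably the usual rule-of-thumb approximation of roughly four characters per token, something like:

```go
// Assumed body (not shown in the diff): a rough character-count
// approximation, not a real tokenizer.
func numTokens(s string) int {
	return int(float32(len(s)) / 4)
}
```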