add max_completions_tokens for o1 series models (#857)
* add max_completions_tokens for o1 series models
* add validation for o1 series models + beta limitations
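For context, a minimal usage sketch of the new field (assuming this is the sashabaranov/go-openai module; the API key and prompt are placeholders, and the identifiers come from the diff below):

package main

import (
	"context"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("your-api-key") // placeholder key
	// For o1 series models, set MaxCompletionsTokens; MaxTokens is rejected.
	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model:                openai.O1Preview,
		MaxCompletionsTokens: 1000,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
	})
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(resp.Choices[0].Message.Content)
}

With an o1 series model the token budget goes in MaxCompletionsTokens; setting MaxTokens instead is rejected client-side by the validation added in this commit.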
chat.go (35 changes)
@@ -200,18 +200,25 @@ type ChatCompletionResponseFormatJSONSchema struct {

// ChatCompletionRequest represents a request structure for chat completion API.
type ChatCompletionRequest struct {
	Model    string                  `json:"model"`
	Messages []ChatCompletionMessage `json:"messages"`
	// MaxTokens The maximum number of tokens that can be generated in the chat completion.
	// This value can be used to control costs for text generated via API.
	// This value is now deprecated in favor of max_completion_tokens, and is not compatible with o1 series models.
	// refs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens
	MaxTokens int `json:"max_tokens,omitempty"`
	// MaxCompletionsTokens An upper bound for the number of tokens that can be generated for a completion,
	// including visible output tokens and reasoning tokens https://platform.openai.com/docs/guides/reasoning
	MaxCompletionsTokens int                           `json:"max_completions_tokens,omitempty"`
	Temperature          float32                       `json:"temperature,omitempty"`
	TopP                 float32                       `json:"top_p,omitempty"`
	N                    int                           `json:"n,omitempty"`
	Stream               bool                          `json:"stream,omitempty"`
	Stop                 []string                      `json:"stop,omitempty"`
	PresencePenalty      float32                       `json:"presence_penalty,omitempty"`
	ResponseFormat       *ChatCompletionResponseFormat `json:"response_format,omitempty"`
	Seed                 *int                          `json:"seed,omitempty"`
	FrequencyPenalty     float32                       `json:"frequency_penalty,omitempty"`
	// LogitBias must be a token ID string (specified by the token ID in the tokenizer), not a word string.
	// incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}`
	// refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias
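A short sketch of how the deprecation behaves for o1 series models (client, ctx, and the standard errors import assumed as in the sketch above; the validation runs client-side, before any HTTP request is made):

req := openai.ChatCompletionRequest{
	Model:     openai.O1Preview,
	MaxTokens: 5, // deprecated; not compatible with o1 series models
	Messages: []openai.ChatCompletionMessage{
		{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
	},
}
if _, err := client.CreateChatCompletion(ctx, req); errors.Is(err, openai.ErrO1MaxTokensDeprecated) {
	// Move the token budget to the replacement field and retry.
	req.MaxTokens = 0
	req.MaxCompletionsTokens = 1000
}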
@@ -364,6 +371,10 @@ func (c *Client) CreateChatCompletion(

		return
	}

	if err = validateRequestForO1Models(request); err != nil {
		return
	}

	req, err := c.newRequest(
		ctx,
		http.MethodPost,
chat_stream.go

@@ -60,6 +60,10 @@ func (c *Client) CreateChatCompletionStream(

	}

	request.Stream = true
	if err = validateRequestForO1Models(request); err != nil {
		return
	}

	req, err := c.newRequest(
		ctx,
		http.MethodPost,
chat_stream_test.go

@@ -36,6 +36,27 @@ func TestChatCompletionsStreamWrongModel(t *testing.T) {

	}
}

func TestChatCompletionsStreamWithO1BetaLimitations(t *testing.T) {
	config := openai.DefaultConfig("whatever")
	config.BaseURL = "http://localhost/v1/chat/completions"
	client := openai.NewClientWithConfig(config)
	ctx := context.Background()

	req := openai.ChatCompletionRequest{
		Model: openai.O1Preview,
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleUser,
				Content: "Hello!",
			},
		},
	}
	_, err := client.CreateChatCompletionStream(ctx, req)
	if !errors.Is(err, openai.ErrO1BetaLimitationsStreaming) {
		t.Fatalf("CreateChatCompletionStream should return ErrO1BetaLimitationsStreaming, but returned: %v", err)
	}
}

func TestCreateChatCompletionStream(t *testing.T) {
	client, server, teardown := setupOpenAITestServer()
	defer teardown()
chat_test.go (211 changes)
@@ -52,6 +52,199 @@ func TestChatCompletionsWrongModel(t *testing.T) {

	checks.ErrorIs(t, err, openai.ErrChatCompletionInvalidModel, msg)
}

func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) {
	tests := []struct {
		name          string
		in            openai.ChatCompletionRequest
		expectedError error
	}{
		{
			name: "o1-preview_MaxTokens_deprecated",
			in: openai.ChatCompletionRequest{
				MaxTokens: 5,
				Model:     openai.O1Preview,
			},
			expectedError: openai.ErrO1MaxTokensDeprecated,
		},
		{
			name: "o1-mini_MaxTokens_deprecated",
			in: openai.ChatCompletionRequest{
				MaxTokens: 5,
				Model:     openai.O1Mini,
			},
			expectedError: openai.ErrO1MaxTokensDeprecated,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			config := openai.DefaultConfig("whatever")
			config.BaseURL = "http://localhost/v1"
			client := openai.NewClientWithConfig(config)
			ctx := context.Background()

			_, err := client.CreateChatCompletion(ctx, tt.in)
			checks.HasError(t, err)
			msg := fmt.Sprintf("CreateChatCompletion should return deprecated-field error, returned: %s", err)
			checks.ErrorIs(t, err, tt.expectedError, msg)
		})
	}
}

func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
	tests := []struct {
		name          string
		in            openai.ChatCompletionRequest
		expectedError error
	}{
		{
			name: "log_probs_unsupported",
			in: openai.ChatCompletionRequest{
				MaxCompletionsTokens: 1000,
				LogProbs:             true,
				Model:                openai.O1Preview,
			},
			expectedError: openai.ErrO1BetaLimitationsLogprobs,
		},
		{
			name: "message_type_unsupported",
			in: openai.ChatCompletionRequest{
				MaxCompletionsTokens: 1000,
				Model:                openai.O1Mini,
				Messages: []openai.ChatCompletionMessage{
					{
						Role: openai.ChatMessageRoleSystem,
					},
				},
			},
			expectedError: openai.ErrO1BetaLimitationsMessageTypes,
		},
		{
			name: "tool_unsupported",
			in: openai.ChatCompletionRequest{
				MaxCompletionsTokens: 1000,
				Model:                openai.O1Mini,
				Messages: []openai.ChatCompletionMessage{
					{
						Role: openai.ChatMessageRoleUser,
					},
					{
						Role: openai.ChatMessageRoleAssistant,
					},
				},
				Tools: []openai.Tool{
					{
						Type: openai.ToolTypeFunction,
					},
				},
			},
			expectedError: openai.ErrO1BetaLimitationsTools,
		},
		{
			name: "set_temperature_unsupported",
			in: openai.ChatCompletionRequest{
				MaxCompletionsTokens: 1000,
				Model:                openai.O1Mini,
				Messages: []openai.ChatCompletionMessage{
					{
						Role: openai.ChatMessageRoleUser,
					},
					{
						Role: openai.ChatMessageRoleAssistant,
					},
				},
				Temperature: float32(2),
			},
			expectedError: openai.ErrO1BetaLimitationsOther,
		},
		{
			name: "set_top_unsupported",
			in: openai.ChatCompletionRequest{
				MaxCompletionsTokens: 1000,
				Model:                openai.O1Mini,
				Messages: []openai.ChatCompletionMessage{
					{
						Role: openai.ChatMessageRoleUser,
					},
					{
						Role: openai.ChatMessageRoleAssistant,
					},
				},
				Temperature: float32(1),
				TopP:        float32(0.1),
			},
			expectedError: openai.ErrO1BetaLimitationsOther,
		},
		{
			name: "set_n_unsupported",
			in: openai.ChatCompletionRequest{
				MaxCompletionsTokens: 1000,
				Model:                openai.O1Mini,
				Messages: []openai.ChatCompletionMessage{
					{
						Role: openai.ChatMessageRoleUser,
					},
					{
						Role: openai.ChatMessageRoleAssistant,
					},
				},
				Temperature: float32(1),
				TopP:        float32(1),
				N:           2,
			},
			expectedError: openai.ErrO1BetaLimitationsOther,
		},
		{
			name: "set_presence_penalty_unsupported",
			in: openai.ChatCompletionRequest{
				MaxCompletionsTokens: 1000,
				Model:                openai.O1Mini,
				Messages: []openai.ChatCompletionMessage{
					{
						Role: openai.ChatMessageRoleUser,
					},
					{
						Role: openai.ChatMessageRoleAssistant,
					},
				},
				PresencePenalty: float32(1),
			},
			expectedError: openai.ErrO1BetaLimitationsOther,
		},
		{
			name: "set_frequency_penalty_unsupported",
			in: openai.ChatCompletionRequest{
				MaxCompletionsTokens: 1000,
				Model:                openai.O1Mini,
				Messages: []openai.ChatCompletionMessage{
					{
						Role: openai.ChatMessageRoleUser,
					},
					{
						Role: openai.ChatMessageRoleAssistant,
					},
				},
				FrequencyPenalty: float32(0.1),
			},
			expectedError: openai.ErrO1BetaLimitationsOther,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			config := openai.DefaultConfig("whatever")
			config.BaseURL = "http://localhost/v1"
			client := openai.NewClientWithConfig(config)
			ctx := context.Background()

			_, err := client.CreateChatCompletion(ctx, tt.in)
			checks.HasError(t, err)
			msg := fmt.Sprintf("CreateChatCompletion should return beta-limitation error, returned: %s", err)
			checks.ErrorIs(t, err, tt.expectedError, msg)
		})
	}
}

func TestChatRequestOmitEmpty(t *testing.T) {
	data, err := json.Marshal(openai.ChatCompletionRequest{
		// We set model b/c it's required, so omitempty doesn't make sense
@@ -97,6 +290,24 @@ func TestChatCompletions(t *testing.T) {

	checks.NoError(t, err, "CreateChatCompletion error")
}

func TestO1ModelChatCompletions(t *testing.T) {
	client, server, teardown := setupOpenAITestServer()
	defer teardown()
	server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
	_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model:                openai.O1Preview,
		MaxCompletionsTokens: 1000,
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleUser,
				Content: "Hello!",
			},
		},
	})
	checks.NoError(t, err, "CreateChatCompletion error")
}

// TestCompletions Tests the completions endpoint of the API using the mocked server.
func TestChatCompletionsWithHeaders(t *testing.T) {
	client, server, teardown := setupOpenAITestServer()
completion.go

@@ -7,11 +7,20 @@ import (

)

var (
	ErrO1MaxTokensDeprecated                   = errors.New("this model does not support MaxTokens, please use MaxCompletionsTokens")                               //nolint:lll
	ErrCompletionUnsupportedModel              = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
	ErrCompletionStreamNotSupported            = errors.New("streaming is not supported with this method, please use CreateCompletionStream")                      //nolint:lll
	ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string")                               //nolint:lll
)

var (
	ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported")                                  //nolint:lll
	ErrO1BetaLimitationsStreaming    = errors.New("this model has beta-limitations, streaming not supported")                                                                              //nolint:lll
	ErrO1BetaLimitationsTools        = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported")                           //nolint:lll
	ErrO1BetaLimitationsLogprobs     = errors.New("this model has beta-limitations, logprobs not supported")                                                                               //nolint:lll
	ErrO1BetaLimitationsOther        = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
)

// GPT3 Defines the models provided by OpenAI to use when generating
// completions from OpenAI.
// GPT3 Models are designed for text-based tasks. For code-specific
@@ -85,6 +94,15 @@ const (

	CodexCodeDavinci001 = "code-davinci-001"
)

// O1SeriesModels lists the new o1 series of OpenAI models.
// Some legacy API attributes are not supported for them.
var O1SeriesModels = map[string]struct{}{
	O1Mini:            {},
	O1Mini20240912:    {},
	O1Preview:         {},
	O1Preview20240912: {},
}

var disabledModelsForEndpoints = map[string]map[string]bool{
	"/completions": {
		O1Mini: true,
@@ -146,6 +164,70 @@ func checkPromptType(prompt any) bool {

	return isString || isStringSlice
}

var unsupportedToolsForO1Models = map[ToolType]struct{}{
	ToolTypeFunction: {},
}

var availableMessageRoleForO1Models = map[string]struct{}{
	ChatMessageRoleUser:      {},
	ChatMessageRoleAssistant: {},
}

// validateRequestForO1Models checks for fields that are deprecated or
// unsupported (beta limitations) for o1 series models.
func validateRequestForO1Models(request ChatCompletionRequest) error {
	if _, found := O1SeriesModels[request.Model]; !found {
		return nil
	}

	if request.MaxTokens > 0 {
		return ErrO1MaxTokensDeprecated
	}

	// Beta Limitations
	// refs: https://platform.openai.com/docs/guides/reasoning/beta-limitations
	// Streaming: not supported.
	if request.Stream {
		return ErrO1BetaLimitationsStreaming
	}
	// Logprobs: not supported.
	if request.LogProbs {
		return ErrO1BetaLimitationsLogprobs
	}

	// Message types: user and assistant messages only, system messages are not supported.
	for _, m := range request.Messages {
		if _, found := availableMessageRoleForO1Models[m.Role]; !found {
			return ErrO1BetaLimitationsMessageTypes
		}
	}

	// Tools: tools, function calling, and response format parameters are not supported.
	for _, t := range request.Tools {
		if _, found := unsupportedToolsForO1Models[t.Type]; found {
			return ErrO1BetaLimitationsTools
		}
	}

	// Other: temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0.
	if request.Temperature > 0 && request.Temperature != 1 {
		return ErrO1BetaLimitationsOther
	}
	if request.TopP > 0 && request.TopP != 1 {
		return ErrO1BetaLimitationsOther
	}
	if request.N > 0 && request.N != 1 {
		return ErrO1BetaLimitationsOther
	}
	if request.PresencePenalty > 0 {
		return ErrO1BetaLimitationsOther
	}
	if request.FrequencyPenalty > 0 {
		return ErrO1BetaLimitationsOther
	}

	return nil
}

// CompletionRequest represents a request structure for completion API.
type CompletionRequest struct {
	Model string `json:"model"`
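Since the limitations surface as sentinel errors, callers can branch on them with errors.Is. A hedged sketch of a fallback for the streaming limitation (client, ctx, and req assumed as in the sketches above):

stream, err := client.CreateChatCompletionStream(ctx, req)
switch {
case errors.Is(err, openai.ErrO1BetaLimitationsStreaming):
	// Beta limitation: o1 models cannot stream, so fall back to the blocking call.
	if resp, rerr := client.CreateChatCompletion(ctx, req); rerr == nil {
		fmt.Println(resp.Choices[0].Message.Content)
	}
case err == nil:
	defer stream.Close()
	// ... consume the stream as usual.
}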