Add support for O3-mini (#930)

* Add support for O3-mini

- Add support for the o3-mini set of models, including tests that match the constraints in OpenAI's API docs (https://platform.openai.com/docs/models#o3-mini).
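
  For context, a minimal sketch of what the new model identifiers could look like, following the library's existing `O1Mini`/`O1Preview` naming convention. Only `O3Mini` is confirmed by the tests in this diff; the dated-snapshot constant is an assumption:

  ```go
  // Hypothetical constant declarations for the o3-mini family (sketch).
  const (
      O3Mini         = "o3-mini"            // confirmed by the tests below
      O3Mini20250131 = "o3-mini-2025-01-31" // dated snapshot; an assumption
  )
  ```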

* Deprecate and refactor

- Deprecate `ErrO1BetaLimitationsLogprobs` and `ErrO1BetaLimitationsOther`

- Implement `validationRequestForReasoningModels`, which works on both o1 and o3 models and applies per-model-type restrictions on functionality (e.g., the o3 class allows function calls and system messages; o1 doesn't)
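
  A rough sketch of the shape such a validator could take, assuming it lives inside the `openai` package. The helper `isO1Model` and the exact field checks are assumptions; the error identifiers are the ones the tests in this diff assert:

  ```go
  // Sketch, not the commit's actual implementation. isO1Model is a
  // hypothetical helper; the Err* values match the tests below.
  func validationRequestForReasoningModels(req ChatCompletionRequest) error {
      if req.MaxTokens > 0 {
          // Reasoning models take MaxCompletionTokens; MaxTokens is deprecated.
          return ErrReasoningModelMaxTokensDeprecated
      }
      if req.LogProbs {
          return ErrReasoningModelLimitationsLogprobs
      }
      // Sampling parameters are fixed to their defaults for all reasoning models.
      if (req.Temperature != 0 && req.Temperature != 1) ||
          (req.TopP != 0 && req.TopP != 1) ||
          req.N > 1 ||
          req.PresencePenalty != 0 ||
          req.FrequencyPenalty != 0 {
          return ErrReasoningModelLimitationsOther
      }
      // The o1 class additionally rejects system messages and tool calls;
      // o3-class models accept both.
      if isO1Model(req.Model) {
          for _, msg := range req.Messages {
              if msg.Role == ChatMessageRoleSystem || len(msg.ToolCalls) > 0 {
                  return ErrReasoningModelLimitationsOther
              }
          }
      }
      return nil
  }
  ```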

* Move reasoning validation to `reasoning_validator.go`

- Add a `NewReasoningValidator` which exposes a `Validate()` method for a given request
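
  A minimal sketch of the validator's surface. Only `NewReasoningValidator` and `Validate` are named in this commit, so the struct shape and the delegation to the helper above are assumptions:

  ```go
  // ReasoningValidator bundles the request checks for reasoning models
  // (o1 and o3 classes). The empty struct is an assumption.
  type ReasoningValidator struct{}

  // NewReasoningValidator returns a validator for reasoning-model requests.
  func NewReasoningValidator() *ReasoningValidator {
      return &ReasoningValidator{}
  }

  // Validate returns the first reasoning-model constraint the request
  // violates, or nil if the request is acceptable.
  func (v *ReasoningValidator) Validate(req ChatCompletionRequest) error {
      return validationRequestForReasoningModels(req) // delegation assumed
  }
  ```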

- Also add a test for chat streams
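
  The stream test could look roughly like this; the test name and the specific limitation it triggers are illustrative, not the commit's actual test. Since validation runs before any network I/O, the stream constructor should surface the error immediately:

  ```go
  func TestO3ModelChatCompletionsStreamValidation(t *testing.T) {
      config := openai.DefaultConfig("whatever")
      config.BaseURL = "http://localhost/v1"
      client := openai.NewClientWithConfig(config)

      // LogProbs is disallowed for reasoning models, so no server is needed:
      // the request should fail client-side before being sent.
      _, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
          Model:               openai.O3Mini,
          MaxCompletionTokens: 1000,
          LogProbs:            true,
          Stream:              true,
          Messages: []openai.ChatCompletionMessage{
              {Role: openai.ChatMessageRoleUser, Content: "Hello!"},
          },
      })
      checks.ErrorIs(t, err, openai.ErrReasoningModelLimitationsLogprobs,
          "CreateChatCompletionStream should reject logprobs for reasoning models")
  }
  ```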

* Final nits
Author: rory malcolm
Date: 2025-02-06 14:53:19 +00:00
Committed by: GitHub
Parent: 45aa99607b
Commit: 2054db016c
6 changed files with 431 additions and 92 deletions


@@ -64,7 +64,7 @@ func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) {
MaxTokens: 5,
Model: openai.O1Preview,
},
- expectedError: openai.ErrO1MaxTokensDeprecated,
+ expectedError: openai.ErrReasoningModelMaxTokensDeprecated,
},
{
name: "o1-mini_MaxTokens_deprecated",
@@ -72,7 +72,7 @@ func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) {
MaxTokens: 5,
Model: openai.O1Mini,
},
- expectedError: openai.ErrO1MaxTokensDeprecated,
+ expectedError: openai.ErrReasoningModelMaxTokensDeprecated,
},
}
@@ -104,7 +104,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
LogProbs: true,
Model: openai.O1Preview,
},
- expectedError: openai.ErrO1BetaLimitationsLogprobs,
+ expectedError: openai.ErrReasoningModelLimitationsLogprobs,
},
{
name: "message_type_unsupported",
@@ -155,7 +155,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
},
Temperature: float32(2),
},
- expectedError: openai.ErrO1BetaLimitationsOther,
+ expectedError: openai.ErrReasoningModelLimitationsOther,
},
{
name: "set_top_unsupported",
@@ -173,7 +173,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
Temperature: float32(1),
TopP: float32(0.1),
},
- expectedError: openai.ErrO1BetaLimitationsOther,
+ expectedError: openai.ErrReasoningModelLimitationsOther,
},
{
name: "set_n_unsupported",
@@ -192,7 +192,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
TopP: float32(1),
N: 2,
},
- expectedError: openai.ErrO1BetaLimitationsOther,
+ expectedError: openai.ErrReasoningModelLimitationsOther,
},
{
name: "set_presence_penalty_unsupported",
@@ -209,7 +209,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
},
PresencePenalty: float32(1),
},
- expectedError: openai.ErrO1BetaLimitationsOther,
+ expectedError: openai.ErrReasoningModelLimitationsOther,
},
{
name: "set_frequency_penalty_unsupported",
@@ -226,7 +226,127 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
},
FrequencyPenalty: float32(0.1),
},
- expectedError: openai.ErrO1BetaLimitationsOther,
+ expectedError: openai.ErrReasoningModelLimitationsOther,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
ctx := context.Background()
_, err := client.CreateChatCompletion(ctx, tt.in)
checks.HasError(t, err)
msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err)
checks.ErrorIs(t, err, tt.expectedError, msg)
})
}
}
func TestO3ModelsChatCompletionsBetaLimitations(t *testing.T) {
tests := []struct {
name string
in openai.ChatCompletionRequest
expectedError error
}{
{
name: "log_probs_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
LogProbs: true,
Model: openai.O3Mini,
},
expectedError: openai.ErrReasoningModelLimitationsLogprobs,
},
{
name: "set_temperature_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O3Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
Temperature: float32(2),
},
expectedError: openai.ErrReasoningModelLimitationsOther,
},
{
name: "set_top_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O3Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
Temperature: float32(1),
TopP: float32(0.1),
},
expectedError: openai.ErrReasoningModelLimitationsOther,
},
{
name: "set_n_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O3Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
Temperature: float32(1),
TopP: float32(1),
N: 2,
},
expectedError: openai.ErrReasoningModelLimitationsOther,
},
{
name: "set_presence_penalty_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O3Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
PresencePenalty: float32(1),
},
expectedError: openai.ErrReasoningModelLimitationsOther,
},
{
name: "set_frequency_penalty_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O3Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
FrequencyPenalty: float32(0.1),
},
expectedError: openai.ErrReasoningModelLimitationsOther,
},
}
@@ -308,6 +428,23 @@ func TestO1ModelChatCompletions(t *testing.T) {
checks.NoError(t, err, "CreateChatCompletion error")
}
func TestO3ModelChatCompletions(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
Model: openai.O3Mini,
MaxCompletionTokens: 1000,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
})
checks.NoError(t, err, "CreateChatCompletion error")
}
// TestChatCompletionsWithHeaders Tests the chat completions endpoint of the API with custom headers using the mocked server.
func TestChatCompletionsWithHeaders(t *testing.T) {
client, server, teardown := setupOpenAITestServer()