diff --git a/chat_stream_test.go b/chat_stream_test.go index 4d992e4..eabb0f3 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -959,6 +959,56 @@ func TestCreateChatCompletionStreamReasoningValidatorFails(t *testing.T) { } } +func TestCreateChatCompletionStreamO3ReasoningValidatorFails(t *testing.T) { + client, _, _ := setupOpenAITestServer() + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 100, // This will trigger the validator to fail + Model: openai.O3, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + + if stream != nil { + t.Error("Expected nil stream when validation fails") + stream.Close() + } + + if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) { + t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O3, got: %v", err) + } +} + +func TestCreateChatCompletionStreamO4MiniReasoningValidatorFails(t *testing.T) { + client, _, _ := setupOpenAITestServer() + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 100, // This will trigger the validator to fail + Model: openai.O4Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + + if stream != nil { + t.Error("Expected nil stream when validation fails") + stream.Close() + } + + if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) { + t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O4Mini, got: %v", err) + } +} + func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool { if c1.Index != c2.Index { return false diff --git a/completion.go b/completion.go index 0d0c1a8..7a6de30 100644 --- a/completion.go +++ b/completion.go @@ -16,8 +16,12 @@ const ( O1Preview20240912 = "o1-preview-2024-09-12" O1 = "o1" O120241217 = "o1-2024-12-17" + O3 = "o3" + O320250416 = "o3-2025-04-16" O3Mini = "o3-mini" O3Mini20250131 = "o3-mini-2025-01-31" + O4Mini = "o4-mini" + O4Mini2020416 = "o4-mini-2025-04-16" GPT432K0613 = "gpt-4-32k-0613" GPT432K0314 = "gpt-4-32k-0314" GPT432K = "gpt-4-32k" @@ -99,6 +103,10 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ O1Preview20240912: true, O3Mini: true, O3Mini20250131: true, + O4Mini: true, + O4Mini2020416: true, + O3: true, + O320250416: true, GPT3Dot5Turbo: true, GPT3Dot5Turbo0301: true, GPT3Dot5Turbo0613: true, diff --git a/completion_test.go b/completion_test.go index 83bd899..27e2d15 100644 --- a/completion_test.go +++ b/completion_test.go @@ -33,6 +33,42 @@ func TestCompletionsWrongModel(t *testing.T) { } } +// TestCompletionsWrongModelO3 Tests the completions endpoint with O3 model which is not supported. +func TestCompletionsWrongModelO3(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: openai.O3, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O3, but returned: %v", err) + } +} + +// TestCompletionsWrongModelO4Mini Tests the completions endpoint with O4Mini model which is not supported. +func TestCompletionsWrongModelO4Mini(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: openai.O4Mini, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O4Mini, but returned: %v", err) + } +} + func TestCompletionWithStream(t *testing.T) { config := openai.DefaultConfig("whatever") client := openai.NewClientWithConfig(config) diff --git a/models_test.go b/models_test.go index 24a28ed..7fd010c 100644 --- a/models_test.go +++ b/models_test.go @@ -47,6 +47,24 @@ func TestGetModel(t *testing.T) { checks.NoError(t, err, "GetModel error") } +// TestGetModelO3 Tests the retrieve O3 model endpoint of the API using the mocked server. +func TestGetModelO3(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/models/o3", handleGetModelEndpoint) + _, err := client.GetModel(context.Background(), "o3") + checks.NoError(t, err, "GetModel error for O3") +} + +// TestGetModelO4Mini Tests the retrieve O4Mini model endpoint of the API using the mocked server. +func TestGetModelO4Mini(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/models/o4-mini", handleGetModelEndpoint) + _, err := client.GetModel(context.Background(), "o4-mini") + checks.NoError(t, err, "GetModel error for O4Mini") +} + func TestAzureGetModel(t *testing.T) { client, server, teardown := setupAzureTestServer() defer teardown() diff --git a/reasoning_validator.go b/reasoning_validator.go index 040d6b4..2910b13 100644 --- a/reasoning_validator.go +++ b/reasoning_validator.go @@ -40,8 +40,9 @@ func NewReasoningValidator() *ReasoningValidator { func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error { o1Series := strings.HasPrefix(request.Model, "o1") o3Series := strings.HasPrefix(request.Model, "o3") + o4Series := strings.HasPrefix(request.Model, "o4") - if !o1Series && !o3Series { + if !o1Series && !o3Series && !o4Series { return nil }