Compare commits

..

3 Commits

Author SHA1 Message Date
VaalaCat
1337a4b683 feat: add reasoning format 2025-03-07 13:22:30 +00:00
VaalaCat
3f53ae6ab1 feat: add include_reasoning 2025-02-12 13:24:21 +00:00
VaalaCat
40de0deb41 feat: change repo name 2025-02-12 13:22:09 +00:00
26 changed files with 566 additions and 1369 deletions

View File

@@ -13,17 +13,15 @@ jobs:
- name: Setup Go - name: Setup Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version: '1.24' go-version: '1.21'
- name: Run vet - name: Run vet
run: | run: |
go vet . go vet .
- name: Run golangci-lint - name: Run golangci-lint
uses: golangci/golangci-lint-action@v7 uses: golangci/golangci-lint-action@v4
with: with:
version: v2.1.5 version: latest
- name: Run tests - name: Run tests
run: go test -race -covermode=atomic -coverprofile=coverage.out -v ./... run: go test -race -covermode=atomic -coverprofile=coverage.out -v .
- name: Upload coverage reports to Codecov - name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5 uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}

View File

@@ -1,94 +1,66 @@
version: "2" ## Golden config for golangci-lint v1.47.3
linters: #
default: none # This is the best config for golangci-lint based on my experience and opinion.
enable: # It is very strict, but not extremely strict.
- asciicheck # Feel free to adopt and change it for your needs.
- bidichk
- bodyclose run:
- contextcheck # Timeout for analysis, e.g. 30s, 5m.
- cyclop # Default: 1m
- dupl timeout: 3m
- durationcheck
- errcheck
- errname # This file contains only configs which differ from defaults.
- errorlint # All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
- exhaustive linters-settings:
- forbidigo
- funlen
- gochecknoinits
- gocognit
- goconst
- gocritic
- gocyclo
- godot
- gomoddirectives
- gomodguard
- goprintffuncname
- gosec
- govet
- ineffassign
- lll
- makezero
- mnd
- nestif
- nilerr
- nilnil
- nolintlint
- nosprintfhostport
- predeclared
- promlinter
- revive
- rowserrcheck
- sqlclosecheck
- staticcheck
- testpackage
- tparallel
- unconvert
- unparam
- unused
- usetesting
- wastedassign
- whitespace
settings:
cyclop: cyclop:
# The maximal code complexity to report.
# Default: 10
max-complexity: 30 max-complexity: 30
package-average: 10 # The maximal average package complexity.
# If it's higher than 0.0 (float) the check is enabled
# Default: 0.0
package-average: 10.0
errcheck: errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default.
# Default: false
check-type-assertions: true check-type-assertions: true
funlen: funlen:
# Checks the number of lines in a function.
# If lower than 0, disable the check.
# Default: 60
lines: 100 lines: 100
# Checks the number of statements in a function.
# If lower than 0, disable the check.
# Default: 40
statements: 50 statements: 50
gocognit: gocognit:
# Minimal code complexity to report
# Default: 30 (but we recommend 10-20)
min-complexity: 20 min-complexity: 20
gocritic: gocritic:
# Settings passed to gocritic.
# The settings key is the name of a supported gocritic checker.
# The list of supported checkers can be find in https://go-critic.github.io/overview.
settings: settings:
captLocal: captLocal:
# Whether to restrict checker to params only.
# Default: true
paramsOnly: false paramsOnly: false
underef: underef:
# Whether to skip (*x).method() calls where x is a pointer receiver.
# Default: true
skipRecvDeref: false skipRecvDeref: false
gomodguard:
blocked:
modules:
- github.com/golang/protobuf:
recommendations:
- google.golang.org/protobuf
reason: see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules
- github.com/satori/go.uuid:
recommendations:
- github.com/google/uuid
reason: satori's package is not maintained
- github.com/gofrs/uuid:
recommendations:
- github.com/google/uuid
reason: 'see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw'
govet:
disable:
- fieldalignment
enable-all: true
settings:
shadow:
strict: true
mnd: mnd:
# List of function patterns to exclude from analysis.
# Values always ignored: `time.Date`
# Default: []
ignored-functions: ignored-functions:
- os.Chmod - os.Chmod
- os.Mkdir - os.Mkdir
@@ -104,44 +76,194 @@ linters:
- strconv.ParseFloat - strconv.ParseFloat
- strconv.ParseInt - strconv.ParseInt
- strconv.ParseUint - strconv.ParseUint
gomodguard:
blocked:
# List of blocked modules.
# Default: []
modules:
- github.com/golang/protobuf:
recommendations:
- google.golang.org/protobuf
reason: "see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules"
- github.com/satori/go.uuid:
recommendations:
- github.com/google/uuid
reason: "satori's package is not maintained"
- github.com/gofrs/uuid:
recommendations:
- github.com/google/uuid
reason: "see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw"
govet:
# Enable all analyzers.
# Default: false
enable-all: true
# Disable analyzers by name.
# Run `go tool vet help` to see all analyzers.
# Default: []
disable:
- fieldalignment # too strict
# Settings per analyzer.
settings:
shadow:
# Whether to be strict about shadowing; can be noisy.
# Default: false
strict: true
nakedret: nakedret:
# Make an issue if func has more lines of code than this setting, and it has naked returns.
# Default: 30
max-func-lines: 0 max-func-lines: 0
nolintlint: nolintlint:
# Exclude following linters from requiring an explanation.
# Default: []
allow-no-explanation: [ funlen, gocognit, lll ]
# Enable to require an explanation of nonzero length after each nolint directive.
# Default: false
require-explanation: true require-explanation: true
# Enable to require nolint directives to mention the specific linter being suppressed.
# Default: false
require-specific: true require-specific: true
allow-no-explanation:
- funlen
- gocognit
- lll
rowserrcheck: rowserrcheck:
# database/sql is always checked
# Default: []
packages: packages:
- github.com/jmoiron/sqlx - github.com/jmoiron/sqlx
exclusions:
generated: lax tenv:
presets: # The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
- comments # Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
- common-false-positives # Default: false
- legacy all: true
- std-error-handling
rules: varcheck:
- linters: # Check usage of exported fields and variables.
- forbidigo # Default: false
- mnd exported-fields: false # default false # TODO: enable after fixing false positives
- revive
path : ^examples/.*\.go$
- linters: linters:
- lll disable-all: true
source: ^//\s*go:generate\s enable:
- linters: ## enabled by default
- godot - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
source: (noinspection|TODO) - gosimple # Linter for Go source code that specializes in simplifying a code
- linters: - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- gocritic - ineffassign # Detects when assignments to existing variables are not used
source: //noinspection - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks
- linters: - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
- errorlint - unused # Checks Go code for unused constants, variables, functions and types
source: ^\s+if _, ok := err\.\([^.]+\.InternalError\); ok { ## disabled by default
- linters: # - asasalint # Check for pass []any as any in variadic func(...any)
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
- bidichk # Checks for dangerous unicode character sequences
- bodyclose # checks whether HTTP response body is closed successfully
- contextcheck # check the function whether use a non-inherited context
- cyclop # checks function and package cyclomatic complexity
- dupl # Tool for code clone detection
- durationcheck # check for two durations multiplied together
- errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error.
- errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13.
# Removed execinquery (deprecated). execinquery is a linter about query string checker in Query function which reads your Go src files and warning it finds
- exhaustive # check exhaustiveness of enum switch statements
- exportloopref # checks for pointers to enclosing loop variables
- forbidigo # Forbids identifiers
- funlen # Tool for detection of long functions
# - gochecknoglobals # check that no global variables exist
- gochecknoinits # Checks that no init functions are present in Go code
- gocognit # Computes and checks the cognitive complexity of functions
- goconst # Finds repeated strings that could be replaced by a constant
- gocritic # Provides diagnostics that check for bugs, performance and style issues.
- gocyclo # Computes and checks the cyclomatic complexity of functions
- godot # Check if comments end in a period
- goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt.
- gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod.
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- goprintffuncname # Checks that printf-like functions are named with f at the end
- gosec # Inspects source code for security problems
- lll # Reports long lines
- makezero # Finds slice declarations with non-zero initial length
# - nakedret # Finds naked returns in functions greater than a specified function length
- mnd # An analyzer to detect magic numbers.
- nestif # Reports deeply nested if statements
- nilerr # Finds the code that returns nil even if it checks that the error is not nil.
- nilnil # Checks that there is no simultaneous return of nil error and an invalid value.
# - noctx # noctx finds sending http request without context.Context
- nolintlint # Reports ill-formed or insufficient nolint directives
# - nonamedreturns # Reports all named returns
- nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL.
- predeclared # find code that shadows one of Go's predeclared identifiers
- promlinter # Check Prometheus metrics naming via promlint
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
- rowserrcheck # checks whether Err of rows is checked successfully
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
- stylecheck # Stylecheck is a replacement for golint
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17
- testpackage # linter that makes you use a separate _test package
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
- unconvert # Remove unnecessary type conversions
- unparam # Reports unused function parameters
- usetesting # Reports uses of functions with replacement inside the testing package
- wastedassign # wastedassign finds wasted assignment statements.
- whitespace # Tool for detection of leading and trailing whitespace
## you may want to enable
#- decorder # check declaration order and count of types, constants, variables and functions
#- exhaustruct # Checks if all structure fields are initialized
#- goheader # Checks is file header matches to pattern
#- ireturn # Accept Interfaces, Return Concrete Types
#- prealloc # [premature optimization, but can be used in some cases] Finds slice declarations that could potentially be preallocated
#- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope
#- wrapcheck # Checks that errors returned from external packages are wrapped
## disabled
#- containedctx # containedctx is a linter that detects struct contained context.Context field
#- depguard # [replaced by gomodguard] Go linter that checks if package imports are in a list of acceptable packages
#- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
#- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted.
#- forcetypeassert # [replaced by errcheck] finds forced type assertions
#- gci # Gci controls golang package import order and makes it always deterministic.
#- godox # Tool for detection of FIXME, TODO and other comment keywords
#- goerr113 # [too strict] Golang linter to check the errors handling expressions
#- gofmt # [replaced by goimports] Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
#- gofumpt # [replaced by goimports, gofumports is not available yet] Gofumpt checks whether code was gofumpt-ed.
#- grouper # An analyzer to analyze expression groups.
#- ifshort # Checks that your code uses short syntax for if-statements whenever possible
#- importas # Enforces consistent import aliases
#- maintidx # maintidx measures the maintainability index of each function.
#- misspell # [useless] Finds commonly misspelled English words in comments
#- nlreturn # [too strict and mostly code is not more readable] nlreturn checks for a new line before return and branch statements to increase code clarity
#- nosnakecase # Detects snake case of variable naming and function name. # TODO: maybe enable after https://github.com/sivchari/nosnakecase/issues/14
#- paralleltest # [too many false positives] paralleltest detects missing usage of t.Parallel() method in your Go test
#- tagliatelle # Checks the struct tags.
#- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
#- wsl # [too strict and mostly code is not more readable] Whitespace Linter - Forces you to use empty lines!
## deprecated
#- exhaustivestruct # [deprecated, replaced by exhaustruct] Checks if all struct's fields are initialized
#- golint # [deprecated, replaced by revive] Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes
#- interfacer # [deprecated] Linter that suggests narrower interface types
#- maligned # [deprecated, replaced by govet fieldalignment] Tool to detect Go structs that would take less memory if their fields were sorted
#- scopelint # [deprecated, replaced by exportloopref] Scopelint checks for unpinned variables in go programs
issues:
# Maximum count of issues with the same text.
# Set to 0 to disable.
# Default: 3
max-same-issues: 50
exclude-rules:
- source: "^//\\s*go:generate\\s"
linters: [ lll ]
- source: "(noinspection|TODO)"
linters: [ godot ]
- source: "//noinspection"
linters: [ gocritic ]
- source: "^\\s+if _, ok := err\\.\\([^.]+\\.InternalError\\); ok {"
linters: [ errorlint ]
- path: "_test\\.go"
linters:
- bodyclose - bodyclose
- dupl - dupl
- funlen - funlen
@@ -149,20 +271,3 @@ linters:
- gosec - gosec
- noctx - noctx
- wrapcheck - wrapcheck
- staticcheck
path: _test\.go
paths:
- third_party$
- builtin$
- examples$
issues:
max-same-issues: 50
formatters:
enable:
- goimports
exclusions:
generated: lax
paths:
- third_party$
- builtin$
- examples$

View File

@@ -7,7 +7,7 @@ This library provides unofficial Go clients for [OpenAI API](https://platform.op
* ChatGPT 4o, o1 * ChatGPT 4o, o1
* GPT-3, GPT-4 * GPT-3, GPT-4
* DALL·E 2, DALL·E 3, GPT Image 1 * DALL·E 2, DALL·E 3
* Whisper * Whisper
## Installation ## Installation
@@ -357,66 +357,6 @@ func main() {
``` ```
</details> </details>
<details>
<summary>GPT Image 1 image generation</summary>
```go
package main
import (
"context"
"encoding/base64"
"fmt"
"os"
openai "github.com/sashabaranov/go-openai"
)
func main() {
c := openai.NewClient("your token")
ctx := context.Background()
req := openai.ImageRequest{
Prompt: "Parrot on a skateboard performing a trick. Large bold text \"SKATE MASTER\" banner at the bottom of the image. Cartoon style, natural light, high detail, 1:1 aspect ratio.",
Background: openai.CreateImageBackgroundOpaque,
Model: openai.CreateImageModelGptImage1,
Size: openai.CreateImageSize1024x1024,
N: 1,
Quality: openai.CreateImageQualityLow,
OutputCompression: 100,
OutputFormat: openai.CreateImageOutputFormatJPEG,
// Moderation: openai.CreateImageModerationLow,
// User: "",
}
resp, err := c.CreateImage(ctx, req)
if err != nil {
fmt.Printf("Image creation Image generation with GPT Image 1error: %v\n", err)
return
}
fmt.Println("Image Base64:", resp.Data[0].B64JSON)
// Decode the base64 data
imgBytes, err := base64.StdEncoding.DecodeString(resp.Data[0].B64JSON)
if err != nil {
fmt.Printf("Base64 decode error: %v\n", err)
return
}
// Write image to file
outputPath := "generated_image.jpg"
err = os.WriteFile(outputPath, imgBytes, 0644)
if err != nil {
fmt.Printf("Failed to write image file: %v\n", err)
return
}
fmt.Printf("The image was saved as %s\n", outputPath)
}
```
</details>
<details> <details>
<summary>Configuring proxy</summary> <summary>Configuring proxy</summary>

View File

@@ -2,11 +2,8 @@ package openai //nolint:testpackage // testing private field
import ( import (
"bytes" "bytes"
"context"
"errors"
"fmt" "fmt"
"io" "io"
"net/http"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
@@ -110,131 +107,3 @@ func TestCreateFileField(t *testing.T) {
checks.HasError(t, err, "createFileField using file should return error when open file fails") checks.HasError(t, err, "createFileField using file should return error when open file fails")
}) })
} }
// failingFormBuilder always returns an error when creating form files.
type failingFormBuilder struct{ err error }
func (f *failingFormBuilder) CreateFormFile(_ string, _ *os.File) error {
return f.err
}
func (f *failingFormBuilder) CreateFormFileReader(_ string, _ io.Reader, _ string) error {
return f.err
}
func (f *failingFormBuilder) WriteField(_, _ string) error {
return nil
}
func (f *failingFormBuilder) Close() error {
return nil
}
func (f *failingFormBuilder) FormDataContentType() string {
return "multipart/form-data"
}
// failingAudioRequestBuilder simulates an error during HTTP request construction.
type failingAudioRequestBuilder struct{ err error }
func (f *failingAudioRequestBuilder) Build(
_ context.Context,
_, _ string,
_ any,
_ http.Header,
) (*http.Request, error) {
return nil, f.err
}
// errorHTTPClient always returns an error when making HTTP calls.
type errorHTTPClient struct{ err error }
func (e *errorHTTPClient) Do(_ *http.Request) (*http.Response, error) {
return nil, e.err
}
func TestCallAudioAPIMultipartFormError(t *testing.T) {
client := NewClient("test-token")
errForm := errors.New("mock create form file failure")
// Override form builder to force an error during multipart form creation.
client.createFormBuilder = func(_ io.Writer) utils.FormBuilder {
return &failingFormBuilder{err: errForm}
}
// Provide a reader so createFileField uses the reader path (no file open).
req := AudioRequest{FilePath: "fake.mp3", Reader: bytes.NewBuffer([]byte("dummy")), Model: Whisper1}
_, err := client.callAudioAPI(context.Background(), req, "transcriptions")
if err == nil {
t.Fatal("expected error but got none")
}
if !errors.Is(err, errForm) {
t.Errorf("expected error %v, got %v", errForm, err)
}
}
func TestCallAudioAPINewRequestError(t *testing.T) {
client := NewClient("test-token")
// Create a real temp file so multipart form succeeds.
tmp := t.TempDir()
path := filepath.Join(tmp, "file.mp3")
if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
t.Fatalf("failed to write temp file: %v", err)
}
errBuild := errors.New("mock build failure")
client.requestBuilder = &failingAudioRequestBuilder{err: errBuild}
req := AudioRequest{FilePath: path, Model: Whisper1}
_, err := client.callAudioAPI(context.Background(), req, "translations")
if err == nil {
t.Fatal("expected error but got none")
}
if !errors.Is(err, errBuild) {
t.Errorf("expected error %v, got %v", errBuild, err)
}
}
func TestCallAudioAPISendRequestErrorJSON(t *testing.T) {
client := NewClient("test-token")
// Create a real temp file so multipart form succeeds.
tmp := t.TempDir()
path := filepath.Join(tmp, "file.mp3")
if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
t.Fatalf("failed to write temp file: %v", err)
}
errHTTP := errors.New("mock HTTPClient failure")
// Override HTTP client to simulate a network error.
client.config.HTTPClient = &errorHTTPClient{err: errHTTP}
req := AudioRequest{FilePath: path, Model: Whisper1}
_, err := client.callAudioAPI(context.Background(), req, "transcriptions")
if err == nil {
t.Fatal("expected error but got none")
}
if !errors.Is(err, errHTTP) {
t.Errorf("expected error %v, got %v", errHTTP, err)
}
}
func TestCallAudioAPISendRequestErrorText(t *testing.T) {
client := NewClient("test-token")
tmp := t.TempDir()
path := filepath.Join(tmp, "file.mp3")
if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
t.Fatalf("failed to write temp file: %v", err)
}
errHTTP := errors.New("mock HTTPClient failure")
client.config.HTTPClient = &errorHTTPClient{err: errHTTP}
// Use a non-JSON response format to exercise the text path.
req := AudioRequest{FilePath: path, Model: Whisper1, Format: AudioResponseFormatText}
_, err := client.callAudioAPI(context.Background(), req, "translations")
if err == nil {
t.Fatal("expected error but got none")
}
if !errors.Is(err, errHTTP) {
t.Errorf("expected error %v, got %v", errHTTP, err)
}
}

25
chat.go
View File

@@ -14,7 +14,6 @@ const (
ChatMessageRoleAssistant = "assistant" ChatMessageRoleAssistant = "assistant"
ChatMessageRoleFunction = "function" ChatMessageRoleFunction = "function"
ChatMessageRoleTool = "tool" ChatMessageRoleTool = "tool"
ChatMessageRoleDeveloper = "developer"
) )
const chatCompletionsSuffix = "/chat/completions" const chatCompletionsSuffix = "/chat/completions"
@@ -104,12 +103,6 @@ type ChatCompletionMessage struct {
// - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb // - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
// This property is used for the "reasoning" feature supported by deepseek-reasoner
// which is not in the official documentation.
// the doc from deepseek:
// - https://api-docs.deepseek.com/api/create-chat-completion#responses
ReasoningContent string `json:"reasoning_content,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"`
// For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls. // For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls.
@@ -130,7 +123,6 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
Refusal string `json:"refusal,omitempty"` Refusal string `json:"refusal,omitempty"`
MultiContent []ChatMessagePart `json:"content,omitempty"` MultiContent []ChatMessagePart `json:"content,omitempty"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"` ToolCallID string `json:"tool_call_id,omitempty"`
@@ -144,7 +136,6 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
Refusal string `json:"refusal,omitempty"` Refusal string `json:"refusal,omitempty"`
MultiContent []ChatMessagePart `json:"-"` MultiContent []ChatMessagePart `json:"-"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"` ToolCallID string `json:"tool_call_id,omitempty"`
@@ -155,11 +146,10 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
msg := struct { msg := struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content string `json:"content,omitempty"`
Refusal string `json:"refusal,omitempty"` Refusal string `json:"refusal,omitempty"`
MultiContent []ChatMessagePart MultiContent []ChatMessagePart
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"` ToolCallID string `json:"tool_call_id,omitempty"`
@@ -175,7 +165,6 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
Refusal string `json:"refusal,omitempty"` Refusal string `json:"refusal,omitempty"`
MultiContent []ChatMessagePart `json:"content"` MultiContent []ChatMessagePart `json:"content"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"` ToolCallID string `json:"tool_call_id,omitempty"`
@@ -275,13 +264,6 @@ type ChatCompletionRequest struct {
Metadata map[string]string `json:"metadata,omitempty"` Metadata map[string]string `json:"metadata,omitempty"`
IncludeReasoning *bool `json:"include_reasoning,omitempty"` IncludeReasoning *bool `json:"include_reasoning,omitempty"`
ReasoningFormat *string `json:"reasoning_format,omitempty"` ReasoningFormat *string `json:"reasoning_format,omitempty"`
// Configuration for a predicted output.
Prediction *Prediction `json:"prediction,omitempty"`
// ChatTemplateKwargs provides a way to add non-standard parameters to the request body.
// Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
// Such as think mode for qwen3. "chat_template_kwargs": {"enable_thinking": false}
// https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes
ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"`
} }
type StreamOptions struct { type StreamOptions struct {
@@ -349,11 +331,6 @@ type LogProbs struct {
Content []LogProb `json:"content"` Content []LogProb `json:"content"`
} }
type Prediction struct {
Content string `json:"content"`
Type string `json:"type"`
}
type FinishReason string type FinishReason string
const ( const (

View File

@@ -11,12 +11,6 @@ type ChatCompletionStreamChoiceDelta struct {
FunctionCall *FunctionCall `json:"function_call,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Refusal string `json:"refusal,omitempty"` Refusal string `json:"refusal,omitempty"`
// This property is used for the "reasoning" feature supported by deepseek-reasoner
// which is not in the official documentation.
// the doc from deepseek:
// - https://api-docs.deepseek.com/api/create-chat-completion#responses
ReasoningContent string `json:"reasoning_content,omitempty"`
} }
type ChatCompletionStreamChoiceLogprobs struct { type ChatCompletionStreamChoiceLogprobs struct {

View File

@@ -959,56 +959,6 @@ func TestCreateChatCompletionStreamReasoningValidatorFails(t *testing.T) {
} }
} }
func TestCreateChatCompletionStreamO3ReasoningValidatorFails(t *testing.T) {
client, _, _ := setupOpenAITestServer()
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 100, // This will trigger the validator to fail
Model: openai.O3,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
Stream: true,
})
if stream != nil {
t.Error("Expected nil stream when validation fails")
stream.Close()
}
if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O3, got: %v", err)
}
}
func TestCreateChatCompletionStreamO4MiniReasoningValidatorFails(t *testing.T) {
client, _, _ := setupOpenAITestServer()
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 100, // This will trigger the validator to fail
Model: openai.O4Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
Stream: true,
})
if stream != nil {
t.Error("Expected nil stream when validation fails")
stream.Close()
}
if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O4Mini, got: %v", err)
}
}
func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool { func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool {
if c1.Index != c2.Index { if c1.Index != c2.Index {
return false return false

View File

@@ -106,6 +106,40 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
}, },
expectedError: openai.ErrReasoningModelLimitationsLogprobs, expectedError: openai.ErrReasoningModelLimitationsLogprobs,
}, },
{
name: "message_type_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O1Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
},
},
},
expectedError: openai.ErrO1BetaLimitationsMessageTypes,
},
{
name: "tool_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O1Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
Tools: []openai.Tool{
{
Type: openai.ToolTypeFunction,
},
},
},
expectedError: openai.ErrO1BetaLimitationsTools,
},
{ {
name: "set_temperature_unsupported", name: "set_temperature_unsupported",
in: openai.ChatCompletionRequest{ in: openai.ChatCompletionRequest{
@@ -411,23 +445,6 @@ func TestO3ModelChatCompletions(t *testing.T) {
checks.NoError(t, err, "CreateChatCompletion error") checks.NoError(t, err, "CreateChatCompletion error")
} }
func TestDeepseekR1ModelChatCompletions(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", handleDeepseekR1ChatCompletionEndpoint)
_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
Model: "deepseek-reasoner",
MaxCompletionTokens: 100,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
})
checks.NoError(t, err, "CreateChatCompletion error")
}
// TestCompletions Tests the completions endpoint of the API using the mocked server. // TestCompletions Tests the completions endpoint of the API using the mocked server.
func TestChatCompletionsWithHeaders(t *testing.T) { func TestChatCompletionsWithHeaders(t *testing.T) {
client, server, teardown := setupOpenAITestServer() client, server, teardown := setupOpenAITestServer()
@@ -839,68 +856,6 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, string(resBytes)) fmt.Fprintln(w, string(resBytes))
} }
func handleDeepseekR1ChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte
// completions only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
var completionReq openai.ChatCompletionRequest
if completionReq, err = getChatCompletionBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}
res := openai.ChatCompletionResponse{
ID: strconv.Itoa(int(time.Now().Unix())),
Object: "test-object",
Created: time.Now().Unix(),
// would be nice to validate Model during testing, but
// this may not be possible with how much upkeep
// would be required / wouldn't make much sense
Model: completionReq.Model,
}
// create completions
n := completionReq.N
if n == 0 {
n = 1
}
if completionReq.MaxCompletionTokens == 0 {
completionReq.MaxCompletionTokens = 1000
}
for i := 0; i < n; i++ {
reasoningContent := "User says hello! And I need to reply"
completionStr := strings.Repeat("a", completionReq.MaxCompletionTokens-numTokens(reasoningContent))
res.Choices = append(res.Choices, openai.ChatCompletionChoice{
Message: openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
ReasoningContent: reasoningContent,
Content: completionStr,
},
Index: i,
})
}
inputTokens := numTokens(completionReq.Messages[0].Content) * n
completionTokens := completionReq.MaxTokens * n
res.Usage = openai.Usage{
PromptTokens: inputTokens,
CompletionTokens: completionTokens,
TotalTokens: inputTokens + completionTokens,
}
resBytes, _ = json.Marshal(res)
w.Header().Set(xCustomHeader, xCustomHeaderValue)
for k, v := range rateLimitHeaders {
switch val := v.(type) {
case int:
w.Header().Set(k, strconv.Itoa(val))
default:
w.Header().Set(k, fmt.Sprintf("%s", v))
}
}
fmt.Fprintln(w, string(resBytes))
}
// getChatCompletionBody Returns the body of the request to create a completion. // getChatCompletionBody Returns the body of the request to create a completion.
func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) { func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) {
completion := openai.ChatCompletionRequest{} completion := openai.ChatCompletionRequest{}

View File

@@ -182,21 +182,13 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream
func (c *Client) setCommonHeaders(req *http.Request) { func (c *Client) setCommonHeaders(req *http.Request) {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
switch c.config.APIType {
case APITypeAzure, APITypeCloudflareAzure:
// Azure API Key authentication // Azure API Key authentication
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure {
req.Header.Set(AzureAPIKeyHeader, c.config.authToken) req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
case APITypeAnthropic: } else if c.config.authToken != "" {
// https://docs.anthropic.com/en/api/versioning // OpenAI or Azure AD authentication
req.Header.Set("anthropic-version", c.config.APIVersion)
case APITypeOpenAI, APITypeAzureAD:
fallthrough
default:
if c.config.authToken != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
} }
}
if c.config.OrgID != "" { if c.config.OrgID != "" {
req.Header.Set("OpenAI-Organization", c.config.OrgID) req.Header.Set("OpenAI-Organization", c.config.OrgID)
} }

View File

@@ -39,21 +39,6 @@ func TestClient(t *testing.T) {
} }
} }
func TestSetCommonHeadersAnthropic(t *testing.T) {
config := DefaultAnthropicConfig("mock-token", "")
client := NewClientWithConfig(config)
req, err := http.NewRequest("GET", "http://example.com", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
client.setCommonHeaders(req)
if got := req.Header.Get("anthropic-version"); got != AnthropicAPIVersion {
t.Errorf("Expected anthropic-version header to be %q, got %q", AnthropicAPIVersion, got)
}
}
func TestDecodeResponse(t *testing.T) { func TestDecodeResponse(t *testing.T) {
stringInput := "" stringInput := ""

View File

@@ -15,8 +15,6 @@ type Usage struct {
type CompletionTokensDetails struct { type CompletionTokensDetails struct {
AudioTokens int `json:"audio_tokens"` AudioTokens int `json:"audio_tokens"`
ReasoningTokens int `json:"reasoning_tokens"` ReasoningTokens int `json:"reasoning_tokens"`
AcceptedPredictionTokens int `json:"accepted_prediction_tokens"`
RejectedPredictionTokens int `json:"rejected_prediction_tokens"`
} }
// PromptTokensDetails Breakdown of tokens used in the prompt. // PromptTokensDetails Breakdown of tokens used in the prompt.

View File

@@ -16,12 +16,8 @@ const (
O1Preview20240912 = "o1-preview-2024-09-12" O1Preview20240912 = "o1-preview-2024-09-12"
O1 = "o1" O1 = "o1"
O120241217 = "o1-2024-12-17" O120241217 = "o1-2024-12-17"
O3 = "o3"
O320250416 = "o3-2025-04-16"
O3Mini = "o3-mini" O3Mini = "o3-mini"
O3Mini20250131 = "o3-mini-2025-01-31" O3Mini20250131 = "o3-mini-2025-01-31"
O4Mini = "o4-mini"
O4Mini20250416 = "o4-mini-2025-04-16"
GPT432K0613 = "gpt-4-32k-0613" GPT432K0613 = "gpt-4-32k-0613"
GPT432K0314 = "gpt-4-32k-0314" GPT432K0314 = "gpt-4-32k-0314"
GPT432K = "gpt-4-32k" GPT432K = "gpt-4-32k"
@@ -41,14 +37,6 @@ const (
GPT4TurboPreview = "gpt-4-turbo-preview" GPT4TurboPreview = "gpt-4-turbo-preview"
GPT4VisionPreview = "gpt-4-vision-preview" GPT4VisionPreview = "gpt-4-vision-preview"
GPT4 = "gpt-4" GPT4 = "gpt-4"
GPT4Dot1 = "gpt-4.1"
GPT4Dot120250414 = "gpt-4.1-2025-04-14"
GPT4Dot1Mini = "gpt-4.1-mini"
GPT4Dot1Mini20250414 = "gpt-4.1-mini-2025-04-14"
GPT4Dot1Nano = "gpt-4.1-nano"
GPT4Dot1Nano20250414 = "gpt-4.1-nano-2025-04-14"
GPT4Dot5Preview = "gpt-4.5-preview"
GPT4Dot5Preview20250227 = "gpt-4.5-preview-2025-02-27"
GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125" GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125"
GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106" GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106"
GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613"
@@ -103,10 +91,6 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
O1Preview20240912: true, O1Preview20240912: true,
O3Mini: true, O3Mini: true,
O3Mini20250131: true, O3Mini20250131: true,
O4Mini: true,
O4Mini20250416: true,
O3: true,
O320250416: true,
GPT3Dot5Turbo: true, GPT3Dot5Turbo: true,
GPT3Dot5Turbo0301: true, GPT3Dot5Turbo0301: true,
GPT3Dot5Turbo0613: true, GPT3Dot5Turbo0613: true,
@@ -115,8 +99,6 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
GPT3Dot5Turbo16K: true, GPT3Dot5Turbo16K: true,
GPT3Dot5Turbo16K0613: true, GPT3Dot5Turbo16K0613: true,
GPT4: true, GPT4: true,
GPT4Dot5Preview: true,
GPT4Dot5Preview20250227: true,
GPT4o: true, GPT4o: true,
GPT4o20240513: true, GPT4o20240513: true,
GPT4o20240806: true, GPT4o20240806: true,
@@ -135,13 +117,6 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
GPT432K: true, GPT432K: true,
GPT432K0314: true, GPT432K0314: true,
GPT432K0613: true, GPT432K0613: true,
O1: true,
GPT4Dot1: true,
GPT4Dot120250414: true,
GPT4Dot1Mini: true,
GPT4Dot1Mini20250414: true,
GPT4Dot1Nano: true,
GPT4Dot1Nano20250414: true,
}, },
chatCompletionsSuffix: { chatCompletionsSuffix: {
CodexCodeDavinci002: true, CodexCodeDavinci002: true,
@@ -215,8 +190,6 @@ type CompletionRequest struct {
Temperature float32 `json:"temperature,omitempty"` Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"` TopP float32 `json:"top_p,omitempty"`
User string `json:"user,omitempty"` User string `json:"user,omitempty"`
// Options for streaming response. Only set this when you set stream: true.
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
} }
// CompletionChoice represents one of possible completions. // CompletionChoice represents one of possible completions.

View File

@@ -33,42 +33,6 @@ func TestCompletionsWrongModel(t *testing.T) {
} }
} }
// TestCompletionsWrongModelO3 Tests the completions endpoint with O3 model which is not supported.
func TestCompletionsWrongModelO3(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
_, err := client.CreateCompletion(
context.Background(),
openai.CompletionRequest{
MaxTokens: 5,
Model: openai.O3,
},
)
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O3, but returned: %v", err)
}
}
// TestCompletionsWrongModelO4Mini Tests the completions endpoint with O4Mini model which is not supported.
func TestCompletionsWrongModelO4Mini(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
_, err := client.CreateCompletion(
context.Background(),
openai.CompletionRequest{
MaxTokens: 5,
Model: openai.O4Mini,
},
)
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O4Mini, but returned: %v", err)
}
}
func TestCompletionWithStream(t *testing.T) { func TestCompletionWithStream(t *testing.T) {
config := openai.DefaultConfig("whatever") config := openai.DefaultConfig("whatever")
client := openai.NewClientWithConfig(config) client := openai.NewClientWithConfig(config)
@@ -217,86 +181,3 @@ func getCompletionBody(r *http.Request) (openai.CompletionRequest, error) {
} }
return completion, nil return completion, nil
} }
// TestCompletionWithO1Model Tests that O1 model is not supported for completion endpoint.
func TestCompletionWithO1Model(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
_, err := client.CreateCompletion(
context.Background(),
openai.CompletionRequest{
MaxTokens: 5,
Model: openai.O1,
},
)
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O1 model, but returned: %v", err)
}
}
// TestCompletionWithGPT4DotModels Tests that newer GPT4 models are not supported for completion endpoint.
func TestCompletionWithGPT4DotModels(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
models := []string{
openai.GPT4Dot1,
openai.GPT4Dot120250414,
openai.GPT4Dot1Mini,
openai.GPT4Dot1Mini20250414,
openai.GPT4Dot1Nano,
openai.GPT4Dot1Nano20250414,
openai.GPT4Dot5Preview,
openai.GPT4Dot5Preview20250227,
}
for _, model := range models {
t.Run(model, func(t *testing.T) {
_, err := client.CreateCompletion(
context.Background(),
openai.CompletionRequest{
MaxTokens: 5,
Model: model,
},
)
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err)
}
})
}
}
// TestCompletionWithGPT4oModels Tests that GPT4o models are not supported for completion endpoint.
func TestCompletionWithGPT4oModels(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
models := []string{
openai.GPT4o,
openai.GPT4o20240513,
openai.GPT4o20240806,
openai.GPT4o20241120,
openai.GPT4oLatest,
openai.GPT4oMini,
openai.GPT4oMini20240718,
}
for _, model := range models {
t.Run(model, func(t *testing.T) {
_, err := client.CreateCompletion(
context.Background(),
openai.CompletionRequest{
MaxTokens: 5,
Model: model,
},
)
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err)
}
})
}
}

View File

@@ -11,8 +11,6 @@ const (
azureAPIPrefix = "openai" azureAPIPrefix = "openai"
azureDeploymentsPrefix = "deployments" azureDeploymentsPrefix = "deployments"
AnthropicAPIVersion = "2023-06-01"
) )
type APIType string type APIType string
@@ -22,7 +20,6 @@ const (
APITypeAzure APIType = "AZURE" APITypeAzure APIType = "AZURE"
APITypeAzureAD APIType = "AZURE_AD" APITypeAzureAD APIType = "AZURE_AD"
APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE" APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE"
APITypeAnthropic APIType = "ANTHROPIC"
) )
const AzureAPIKeyHeader = "api-key" const AzureAPIKeyHeader = "api-key"
@@ -40,7 +37,7 @@ type ClientConfig struct {
BaseURL string BaseURL string
OrgID string OrgID string
APIType APIType APIType APIType
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD or APITypeAnthropic APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD
AssistantVersion string AssistantVersion string
AzureModelMapperFunc func(model string) string // replace model to azure deployment name func AzureModelMapperFunc func(model string) string // replace model to azure deployment name func
HTTPClient HTTPDoer HTTPClient HTTPDoer
@@ -79,23 +76,6 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig {
} }
} }
func DefaultAnthropicConfig(apiKey, baseURL string) ClientConfig {
if baseURL == "" {
baseURL = "https://api.anthropic.com/v1"
}
return ClientConfig{
authToken: apiKey,
BaseURL: baseURL,
OrgID: "",
APIType: APITypeAnthropic,
APIVersion: AnthropicAPIVersion,
HTTPClient: &http.Client{},
EmptyMessagesLimit: defaultEmptyMessagesLimit,
}
}
func (ClientConfig) String() string { func (ClientConfig) String() string {
return "<OpenAI API ClientConfig>" return "<OpenAI API ClientConfig>"
} }

View File

@@ -60,64 +60,3 @@ func TestGetAzureDeploymentByModel(t *testing.T) {
}) })
} }
} }
func TestDefaultAnthropicConfig(t *testing.T) {
apiKey := "test-key"
baseURL := "https://api.anthropic.com/v1"
config := openai.DefaultAnthropicConfig(apiKey, baseURL)
if config.APIType != openai.APITypeAnthropic {
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
}
if config.APIVersion != openai.AnthropicAPIVersion {
t.Errorf("Expected APIVersion to be 2023-06-01, got %v", config.APIVersion)
}
if config.BaseURL != baseURL {
t.Errorf("Expected BaseURL to be %v, got %v", baseURL, config.BaseURL)
}
if config.EmptyMessagesLimit != 300 {
t.Errorf("Expected EmptyMessagesLimit to be 300, got %v", config.EmptyMessagesLimit)
}
}
func TestDefaultAnthropicConfigWithEmptyValues(t *testing.T) {
config := openai.DefaultAnthropicConfig("", "")
if config.APIType != openai.APITypeAnthropic {
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
}
if config.APIVersion != openai.AnthropicAPIVersion {
t.Errorf("Expected APIVersion to be %s, got %v", openai.AnthropicAPIVersion, config.APIVersion)
}
expectedBaseURL := "https://api.anthropic.com/v1"
if config.BaseURL != expectedBaseURL {
t.Errorf("Expected BaseURL to be %v, got %v", expectedBaseURL, config.BaseURL)
}
}
func TestClientConfigString(t *testing.T) {
// String() should always return the constant value
conf := openai.DefaultConfig("dummy-token")
expected := "<OpenAI API ClientConfig>"
got := conf.String()
if got != expected {
t.Errorf("ClientConfig.String() = %q; want %q", got, expected)
}
}
func TestGetAzureDeploymentByModel_NoMapper(t *testing.T) {
// On a zero-value or DefaultConfig, AzureModelMapperFunc is nil,
// so GetAzureDeploymentByModel should just return the input model.
conf := openai.DefaultConfig("dummy-token")
model := "some-model"
got := conf.GetAzureDeploymentByModel(model)
if got != model {
t.Errorf("GetAzureDeploymentByModel(%q) = %q; want %q", model, got, model)
}
}

View File

@@ -3,8 +3,8 @@ package openai
import ( import (
"bytes" "bytes"
"context" "context"
"io"
"net/http" "net/http"
"os"
"strconv" "strconv"
) )
@@ -13,62 +13,31 @@ const (
CreateImageSize256x256 = "256x256" CreateImageSize256x256 = "256x256"
CreateImageSize512x512 = "512x512" CreateImageSize512x512 = "512x512"
CreateImageSize1024x1024 = "1024x1024" CreateImageSize1024x1024 = "1024x1024"
// dall-e-3 supported only. // dall-e-3 supported only.
CreateImageSize1792x1024 = "1792x1024" CreateImageSize1792x1024 = "1792x1024"
CreateImageSize1024x1792 = "1024x1792" CreateImageSize1024x1792 = "1024x1792"
// gpt-image-1 supported only.
CreateImageSize1536x1024 = "1536x1024" // Landscape
CreateImageSize1024x1536 = "1024x1536" // Portrait
) )
const ( const (
// dall-e-2 and dall-e-3 only.
CreateImageResponseFormatB64JSON = "b64_json"
CreateImageResponseFormatURL = "url" CreateImageResponseFormatURL = "url"
CreateImageResponseFormatB64JSON = "b64_json"
) )
const ( const (
CreateImageModelDallE2 = "dall-e-2" CreateImageModelDallE2 = "dall-e-2"
CreateImageModelDallE3 = "dall-e-3" CreateImageModelDallE3 = "dall-e-3"
CreateImageModelGptImage1 = "gpt-image-1"
) )
const ( const (
CreateImageQualityHD = "hd" CreateImageQualityHD = "hd"
CreateImageQualityStandard = "standard" CreateImageQualityStandard = "standard"
// gpt-image-1 only.
CreateImageQualityHigh = "high"
CreateImageQualityMedium = "medium"
CreateImageQualityLow = "low"
) )
const ( const (
// dall-e-3 only.
CreateImageStyleVivid = "vivid" CreateImageStyleVivid = "vivid"
CreateImageStyleNatural = "natural" CreateImageStyleNatural = "natural"
) )
const (
// gpt-image-1 only.
CreateImageBackgroundTransparent = "transparent"
CreateImageBackgroundOpaque = "opaque"
)
const (
// gpt-image-1 only.
CreateImageModerationLow = "low"
)
const (
// gpt-image-1 only.
CreateImageOutputFormatPNG = "png"
CreateImageOutputFormatJPEG = "jpeg"
CreateImageOutputFormatWEBP = "webp"
)
// ImageRequest represents the request structure for the image API. // ImageRequest represents the request structure for the image API.
type ImageRequest struct { type ImageRequest struct {
Prompt string `json:"prompt,omitempty"` Prompt string `json:"prompt,omitempty"`
@@ -79,35 +48,16 @@ type ImageRequest struct {
Style string `json:"style,omitempty"` Style string `json:"style,omitempty"`
ResponseFormat string `json:"response_format,omitempty"` ResponseFormat string `json:"response_format,omitempty"`
User string `json:"user,omitempty"` User string `json:"user,omitempty"`
Background string `json:"background,omitempty"`
Moderation string `json:"moderation,omitempty"`
OutputCompression int `json:"output_compression,omitempty"`
OutputFormat string `json:"output_format,omitempty"`
} }
// ImageResponse represents a response structure for image API. // ImageResponse represents a response structure for image API.
type ImageResponse struct { type ImageResponse struct {
Created int64 `json:"created,omitempty"` Created int64 `json:"created,omitempty"`
Data []ImageResponseDataInner `json:"data,omitempty"` Data []ImageResponseDataInner `json:"data,omitempty"`
Usage ImageResponseUsage `json:"usage,omitempty"`
httpHeader httpHeader
} }
// ImageResponseInputTokensDetails represents the token breakdown for input tokens.
type ImageResponseInputTokensDetails struct {
TextTokens int `json:"text_tokens,omitempty"`
ImageTokens int `json:"image_tokens,omitempty"`
}
// ImageResponseUsage represents the token usage information for image API.
type ImageResponseUsage struct {
TotalTokens int `json:"total_tokens,omitempty"`
InputTokens int `json:"input_tokens,omitempty"`
OutputTokens int `json:"output_tokens,omitempty"`
InputTokensDetails ImageResponseInputTokensDetails `json:"input_tokens_details,omitempty"`
}
// ImageResponseDataInner represents a response data structure for image API. // ImageResponseDataInner represents a response data structure for image API.
type ImageResponseDataInner struct { type ImageResponseDataInner struct {
URL string `json:"url,omitempty"` URL string `json:"url,omitempty"`
@@ -134,15 +84,13 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons
// ImageEditRequest represents the request structure for the image API. // ImageEditRequest represents the request structure for the image API.
type ImageEditRequest struct { type ImageEditRequest struct {
Image io.Reader `json:"image,omitempty"` Image *os.File `json:"image,omitempty"`
Mask io.Reader `json:"mask,omitempty"` Mask *os.File `json:"mask,omitempty"`
Prompt string `json:"prompt,omitempty"` Prompt string `json:"prompt,omitempty"`
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
N int `json:"n,omitempty"` N int `json:"n,omitempty"`
Size string `json:"size,omitempty"` Size string `json:"size,omitempty"`
ResponseFormat string `json:"response_format,omitempty"` ResponseFormat string `json:"response_format,omitempty"`
Quality string `json:"quality,omitempty"`
User string `json:"user,omitempty"`
} }
// CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API. // CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API.
@@ -150,16 +98,15 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
body := &bytes.Buffer{} body := &bytes.Buffer{}
builder := c.createFormBuilder(body) builder := c.createFormBuilder(body)
// image, filename is not required // image
err = builder.CreateFormFileReader("image", request.Image, "") err = builder.CreateFormFile("image", request.Image)
if err != nil { if err != nil {
return return
} }
// mask, it is optional // mask, it is optional
if request.Mask != nil { if request.Mask != nil {
// mask, filename is not required err = builder.CreateFormFile("mask", request.Mask)
err = builder.CreateFormFileReader("mask", request.Mask, "")
if err != nil { if err != nil {
return return
} }
@@ -207,12 +154,11 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
// ImageVariRequest represents the request structure for the image API. // ImageVariRequest represents the request structure for the image API.
type ImageVariRequest struct { type ImageVariRequest struct {
Image io.Reader `json:"image,omitempty"` Image *os.File `json:"image,omitempty"`
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
N int `json:"n,omitempty"` N int `json:"n,omitempty"`
Size string `json:"size,omitempty"` Size string `json:"size,omitempty"`
ResponseFormat string `json:"response_format,omitempty"` ResponseFormat string `json:"response_format,omitempty"`
User string `json:"user,omitempty"`
} }
// CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API. // CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API.
@@ -221,8 +167,8 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
body := &bytes.Buffer{} body := &bytes.Buffer{}
builder := c.createFormBuilder(body) builder := c.createFormBuilder(body)
// image, filename is not required // image
err = builder.CreateFormFileReader("image", request.Image, "") err = builder.CreateFormFile("image", request.Image)
if err != nil { if err != nil {
return return
} }

View File

@@ -54,13 +54,13 @@ func TestImageFormBuilderFailures(t *testing.T) {
} }
mockFailedErr := fmt.Errorf("mock form builder fail") mockFailedErr := fmt.Errorf("mock form builder fail")
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { mockBuilder.mockCreateFormFile = func(string, *os.File) error {
return mockFailedErr return mockFailedErr
} }
_, err := client.CreateEditImage(ctx, req) _, err := client.CreateEditImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
mockBuilder.mockCreateFormFileReader = func(name string, _ io.Reader, _ string) error { mockBuilder.mockCreateFormFile = func(name string, _ *os.File) error {
if name == "mask" { if name == "mask" {
return mockFailedErr return mockFailedErr
} }
@@ -119,13 +119,13 @@ func TestVariImageFormBuilderFailures(t *testing.T) {
req := ImageVariRequest{} req := ImageVariRequest{}
mockFailedErr := fmt.Errorf("mock form builder fail") mockFailedErr := fmt.Errorf("mock form builder fail")
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { mockBuilder.mockCreateFormFile = func(string, *os.File) error {
return mockFailedErr return mockFailedErr
} }
_, err := client.CreateVariImage(ctx, req) _, err := client.CreateVariImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails")
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { mockBuilder.mockCreateFormFile = func(string, *os.File) error {
return nil return nil
} }

View File

@@ -4,10 +4,8 @@ import (
"fmt" "fmt"
"io" "io"
"mime/multipart" "mime/multipart"
"net/textproto"
"os" "os"
"path/filepath" "path"
"strings"
) )
type FormBuilder interface { type FormBuilder interface {
@@ -32,37 +30,8 @@ func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) er
return fb.createFormFile(fieldname, file, file.Name()) return fb.createFormFile(fieldname, file, file.Name())
} }
var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
func escapeQuotes(s string) string {
return quoteEscaper.Replace(s)
}
// CreateFormFileReader creates a form field with a file reader.
// The filename in parameters can be an empty string.
// The filename in Content-Disposition is required, But it can be an empty string.
func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error { func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
h := make(textproto.MIMEHeader) return fb.createFormFile(fieldname, r, path.Base(filename))
h.Set(
"Content-Disposition",
fmt.Sprintf(
`form-data; name="%s"; filename="%s"`,
escapeQuotes(fieldname),
escapeQuotes(filepath.Base(filename)),
),
)
fieldWriter, err := fb.writer.CreatePart(h)
if err != nil {
return err
}
_, err = io.Copy(fieldWriter, r)
if err != nil {
return err
}
return nil
} }
func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error { func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error {

View File

@@ -43,32 +43,3 @@ func TestFormBuilderWithClosedFile(t *testing.T) {
checks.HasError(t, err, "formbuilder should return error if file is closed") checks.HasError(t, err, "formbuilder should return error if file is closed")
checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed") checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed")
} }
type failingReader struct {
}
var errMockFailingReaderError = errors.New("mock reader failed")
func (*failingReader) Read([]byte) (int, error) {
return 0, errMockFailingReaderError
}
func TestFormBuilderWithReader(t *testing.T) {
file, err := os.CreateTemp(t.TempDir(), "")
if err != nil {
t.Fatalf("Error creating tmp file: %v", err)
}
defer file.Close()
builder := NewFormBuilder(&failingWriter{})
err = builder.CreateFormFileReader("file", file, file.Name())
checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails")
builder = NewFormBuilder(&bytes.Buffer{})
reader := &failingReader{}
err = builder.CreateFormFileReader("file", reader, "")
checks.ErrorIs(t, err, errMockFailingReaderError, "formbuilder should return error if copy reader fails")
successReader := &bytes.Buffer{}
err = builder.CreateFormFileReader("file", successReader, "")
checks.NoError(t, err, "formbuilder should not return error")
}

View File

@@ -46,8 +46,6 @@ type Definition struct {
// additionalProperties: false // additionalProperties: false
// additionalProperties: jsonschema.Definition{Type: jsonschema.String} // additionalProperties: jsonschema.Definition{Type: jsonschema.String}
AdditionalProperties any `json:"additionalProperties,omitempty"` AdditionalProperties any `json:"additionalProperties,omitempty"`
// Whether the schema is nullable or not.
Nullable bool `json:"nullable,omitempty"`
} }
func (d *Definition) MarshalJSON() ([]byte, error) { func (d *Definition) MarshalJSON() ([]byte, error) {
@@ -126,12 +124,9 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) {
} }
jsonTag := field.Tag.Get("json") jsonTag := field.Tag.Get("json")
var required = true var required = true
switch { if jsonTag == "" {
case jsonTag == "-":
continue
case jsonTag == "":
jsonTag = field.Name jsonTag = field.Name
case strings.HasSuffix(jsonTag, ",omitempty"): } else if strings.HasSuffix(jsonTag, ",omitempty") {
jsonTag = strings.TrimSuffix(jsonTag, ",omitempty") jsonTag = strings.TrimSuffix(jsonTag, ",omitempty")
required = false required = false
} }
@@ -144,16 +139,6 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) {
if description != "" { if description != "" {
item.Description = description item.Description = description
} }
enum := field.Tag.Get("enum")
if enum != "" {
item.Enum = strings.Split(enum, ",")
}
if n := field.Tag.Get("nullable"); n != "" {
nullable, _ := strconv.ParseBool(n)
item.Nullable = nullable
}
properties[jsonTag] = *item properties[jsonTag] = *item
if s := field.Tag.Get("required"); s != "" { if s := field.Tag.Get("required"); s != "" {

View File

@@ -17,7 +17,7 @@ func TestDefinition_MarshalJSON(t *testing.T) {
{ {
name: "Test with empty Definition", name: "Test with empty Definition",
def: jsonschema.Definition{}, def: jsonschema.Definition{},
want: `{}`, want: `{"properties":{}}`,
}, },
{ {
name: "Test with Definition properties set", name: "Test with Definition properties set",
@@ -35,10 +35,11 @@ func TestDefinition_MarshalJSON(t *testing.T) {
"description":"A string type", "description":"A string type",
"properties":{ "properties":{
"name":{ "name":{
"type":"string" "type":"string",
"properties":{}
} }
} }
}`, }`,
}, },
{ {
name: "Test with nested Definition properties", name: "Test with nested Definition properties",
@@ -65,15 +66,17 @@ func TestDefinition_MarshalJSON(t *testing.T) {
"type":"object", "type":"object",
"properties":{ "properties":{
"name":{ "name":{
"type":"string" "type":"string",
"properties":{}
}, },
"age":{ "age":{
"type":"integer" "type":"integer",
"properties":{}
} }
} }
} }
} }
}`, }`,
}, },
{ {
name: "Test with complex nested Definition", name: "Test with complex nested Definition",
@@ -111,26 +114,30 @@ func TestDefinition_MarshalJSON(t *testing.T) {
"type":"object", "type":"object",
"properties":{ "properties":{
"name":{ "name":{
"type":"string" "type":"string",
"properties":{}
}, },
"age":{ "age":{
"type":"integer" "type":"integer",
"properties":{}
}, },
"address":{ "address":{
"type":"object", "type":"object",
"properties":{ "properties":{
"city":{ "city":{
"type":"string" "type":"string",
"properties":{}
}, },
"country":{ "country":{
"type":"string" "type":"string",
"properties":{}
} }
} }
} }
} }
} }
} }
}`, }`,
}, },
{ {
name: "Test with Array type Definition", name: "Test with Array type Definition",
@@ -148,14 +155,18 @@ func TestDefinition_MarshalJSON(t *testing.T) {
want: `{ want: `{
"type":"array", "type":"array",
"items":{ "items":{
"type":"string" "type":"string",
"properties":{
}
}, },
"properties":{ "properties":{
"name":{ "name":{
"type":"string" "type":"string",
"properties":{}
} }
} }
}`, }`,
}, },
} }
@@ -182,232 +193,6 @@ func TestDefinition_MarshalJSON(t *testing.T) {
} }
} }
func TestStructToSchema(t *testing.T) {
tests := []struct {
name string
in any
want string
}{
{
name: "Test with empty struct",
in: struct{}{},
want: `{
"type":"object",
"additionalProperties":false
}`,
},
{
name: "Test with struct containing many fields",
in: struct {
Name string `json:"name"`
Age int `json:"age"`
Active bool `json:"active"`
Height float64 `json:"height"`
Cities []struct {
Name string `json:"name"`
State string `json:"state"`
} `json:"cities"`
}{
Name: "John Doe",
Age: 30,
Cities: []struct {
Name string `json:"name"`
State string `json:"state"`
}{
{Name: "New York", State: "NY"},
{Name: "Los Angeles", State: "CA"},
},
},
want: `{
"type":"object",
"properties":{
"name":{
"type":"string"
},
"age":{
"type":"integer"
},
"active":{
"type":"boolean"
},
"height":{
"type":"number"
},
"cities":{
"type":"array",
"items":{
"additionalProperties":false,
"type":"object",
"properties":{
"name":{
"type":"string"
},
"state":{
"type":"string"
}
},
"required":["name","state"]
}
}
},
"required":["name","age","active","height","cities"],
"additionalProperties":false
}`,
},
{
name: "Test with description tag",
in: struct {
Name string `json:"name" description:"The name of the person"`
}{
Name: "John Doe",
},
want: `{
"type":"object",
"properties":{
"name":{
"type":"string",
"description":"The name of the person"
}
},
"required":["name"],
"additionalProperties":false
}`,
},
{
name: "Test with required tag",
in: struct {
Name string `json:"name" required:"false"`
}{
Name: "John Doe",
},
want: `{
"type":"object",
"properties":{
"name":{
"type":"string"
}
},
"additionalProperties":false
}`,
},
{
name: "Test with enum tag",
in: struct {
Color string `json:"color" enum:"red,green,blue"`
}{
Color: "red",
},
want: `{
"type":"object",
"properties":{
"color":{
"type":"string",
"enum":["red","green","blue"]
}
},
"required":["color"],
"additionalProperties":false
}`,
},
{
name: "Test with nullable tag",
in: struct {
Name *string `json:"name" nullable:"true"`
}{
Name: nil,
},
want: `{
"type":"object",
"properties":{
"name":{
"type":"string",
"nullable":true
}
},
"required":["name"],
"additionalProperties":false
}`,
},
{
name: "Test with exclude mark",
in: struct {
Name string `json:"-"`
}{
Name: "Name",
},
want: `{
"type":"object",
"additionalProperties":false
}`,
},
{
name: "Test with no json tag",
in: struct {
Name string
}{
Name: "",
},
want: `{
"type":"object",
"properties":{
"Name":{
"type":"string"
}
},
"required":["Name"],
"additionalProperties":false
}`,
},
{
name: "Test with omitempty tag",
in: struct {
Name string `json:"name,omitempty"`
}{
Name: "",
},
want: `{
"type":"object",
"properties":{
"name":{
"type":"string"
}
},
"additionalProperties":false
}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wantBytes := []byte(tt.want)
schema, err := jsonschema.GenerateSchemaForType(tt.in)
if err != nil {
t.Errorf("Failed to generate schema: error = %v", err)
return
}
var want map[string]interface{}
err = json.Unmarshal(wantBytes, &want)
if err != nil {
t.Errorf("Failed to Unmarshal JSON: error = %v", err)
return
}
got := structToMap(t, schema)
gotPtr := structToMap(t, &schema)
if !reflect.DeepEqual(got, want) {
t.Errorf("MarshalJSON() got = %v, want %v", got, want)
}
if !reflect.DeepEqual(gotPtr, want) {
t.Errorf("MarshalJSON() gotPtr = %v, want %v", gotPtr, want)
}
})
}
}
func structToMap(t *testing.T, v any) map[string]any { func structToMap(t *testing.T, v any) map[string]any {
t.Helper() t.Helper()
gotBytes, err := json.Marshal(v) gotBytes, err := json.Marshal(v)

View File

@@ -47,24 +47,6 @@ func TestGetModel(t *testing.T) {
checks.NoError(t, err, "GetModel error") checks.NoError(t, err, "GetModel error")
} }
// TestGetModelO3 Tests the retrieve O3 model endpoint of the API using the mocked server.
func TestGetModelO3(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/models/o3", handleGetModelEndpoint)
_, err := client.GetModel(context.Background(), "o3")
checks.NoError(t, err, "GetModel error for O3")
}
// TestGetModelO4Mini Tests the retrieve O4Mini model endpoint of the API using the mocked server.
func TestGetModelO4Mini(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/models/o4-mini", handleGetModelEndpoint)
_, err := client.GetModel(context.Background(), "o4-mini")
checks.NoError(t, err, "GetModel error for O4Mini")
}
func TestAzureGetModel(t *testing.T) { func TestAzureGetModel(t *testing.T) {
client, server, teardown := setupAzureTestServer() client, server, teardown := setupAzureTestServer()
defer teardown() defer teardown()

View File

@@ -29,7 +29,7 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea
// numTokens Returns the number of GPT-3 encoded tokens in the given text. // numTokens Returns the number of GPT-3 encoded tokens in the given text.
// This function approximates based on the rule of thumb stated by OpenAI: // This function approximates based on the rule of thumb stated by OpenAI:
// https://beta.openai.com/tokenizer. // https://beta.openai.com/tokenizer/
// //
// TODO: implement an actual tokenizer for GPT-3 and Codex (once available). // TODO: implement an actual tokenizer for GPT-3 and Codex (once available).
func numTokens(s string) int { func numTokens(s string) int {

View File

@@ -28,6 +28,15 @@ var (
ErrReasoningModelLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll ErrReasoningModelLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
) )
var unsupportedToolsForO1Models = map[ToolType]struct{}{
ToolTypeFunction: {},
}
var availableMessageRoleForO1Models = map[string]struct{}{
ChatMessageRoleUser: {},
ChatMessageRoleAssistant: {},
}
// ReasoningValidator handles validation for o-series model requests. // ReasoningValidator handles validation for o-series model requests.
type ReasoningValidator struct{} type ReasoningValidator struct{}
@@ -40,9 +49,8 @@ func NewReasoningValidator() *ReasoningValidator {
func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error { func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error {
o1Series := strings.HasPrefix(request.Model, "o1") o1Series := strings.HasPrefix(request.Model, "o1")
o3Series := strings.HasPrefix(request.Model, "o3") o3Series := strings.HasPrefix(request.Model, "o3")
o4Series := strings.HasPrefix(request.Model, "o4")
if !o1Series && !o3Series && !o4Series { if !o1Series && !o3Series {
return nil return nil
} }
@@ -50,6 +58,12 @@ func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error {
return err return err
} }
if o1Series {
if err := v.validateO1Specific(request); err != nil {
return err
}
}
return nil return nil
} }
@@ -79,3 +93,19 @@ func (v *ReasoningValidator) validateReasoningModelParams(request ChatCompletion
return nil return nil
} }
// validateO1Specific checks O1-specific limitations.
func (v *ReasoningValidator) validateO1Specific(request ChatCompletionRequest) error {
for _, m := range request.Messages {
if _, found := availableMessageRoleForO1Models[m.Role]; !found {
return ErrO1BetaLimitationsMessageTypes
}
}
for _, t := range request.Tools {
if _, found := unsupportedToolsForO1Models[t.Type]; found {
return ErrO1BetaLimitationsTools
}
}
return nil
}

View File

@@ -11,22 +11,17 @@ const (
TTSModel1 SpeechModel = "tts-1" TTSModel1 SpeechModel = "tts-1"
TTSModel1HD SpeechModel = "tts-1-hd" TTSModel1HD SpeechModel = "tts-1-hd"
TTSModelCanary SpeechModel = "canary-tts" TTSModelCanary SpeechModel = "canary-tts"
TTSModelGPT4oMini SpeechModel = "gpt-4o-mini-tts"
) )
type SpeechVoice string type SpeechVoice string
const ( const (
VoiceAlloy SpeechVoice = "alloy" VoiceAlloy SpeechVoice = "alloy"
VoiceAsh SpeechVoice = "ash"
VoiceBallad SpeechVoice = "ballad"
VoiceCoral SpeechVoice = "coral"
VoiceEcho SpeechVoice = "echo" VoiceEcho SpeechVoice = "echo"
VoiceFable SpeechVoice = "fable" VoiceFable SpeechVoice = "fable"
VoiceOnyx SpeechVoice = "onyx" VoiceOnyx SpeechVoice = "onyx"
VoiceNova SpeechVoice = "nova" VoiceNova SpeechVoice = "nova"
VoiceShimmer SpeechVoice = "shimmer" VoiceShimmer SpeechVoice = "shimmer"
VoiceVerse SpeechVoice = "verse"
) )
type SpeechResponseFormat string type SpeechResponseFormat string
@@ -44,7 +39,6 @@ type CreateSpeechRequest struct {
Model SpeechModel `json:"model"` Model SpeechModel `json:"model"`
Input string `json:"input"` Input string `json:"input"`
Voice SpeechVoice `json:"voice"` Voice SpeechVoice `json:"voice"`
Instructions string `json:"instructions,omitempty"` // Optional, Doesnt work with tts-1 or tts-1-hd.
ResponseFormat SpeechResponseFormat `json:"response_format,omitempty"` // Optional, default to mp3 ResponseFormat SpeechResponseFormat `json:"response_format,omitempty"` // Optional, default to mp3
Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0 Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0
} }

View File

@@ -6,14 +6,13 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"regexp"
utils "git.vaala.cloud/VaalaCat/go-openai/internal" utils "git.vaala.cloud/VaalaCat/go-openai/internal"
) )
var ( var (
headerData = regexp.MustCompile(`^data:\s*`) headerData = []byte("data: ")
errorPrefix = regexp.MustCompile(`^data:\s*{"error":`) errorPrefix = []byte(`data: {"error":`)
) )
type streamable interface { type streamable interface {
@@ -71,12 +70,12 @@ func (stream *streamReader[T]) processLines() ([]byte, error) {
} }
noSpaceLine := bytes.TrimSpace(rawLine) noSpaceLine := bytes.TrimSpace(rawLine)
if errorPrefix.Match(noSpaceLine) { if bytes.HasPrefix(noSpaceLine, errorPrefix) {
hasErrorPrefix = true hasErrorPrefix = true
} }
if !headerData.Match(noSpaceLine) || hasErrorPrefix { if !bytes.HasPrefix(noSpaceLine, headerData) || hasErrorPrefix {
if hasErrorPrefix { if hasErrorPrefix {
noSpaceLine = headerData.ReplaceAll(noSpaceLine, nil) noSpaceLine = bytes.TrimPrefix(noSpaceLine, headerData)
} }
writeErr := stream.errAccumulator.Write(noSpaceLine) writeErr := stream.errAccumulator.Write(noSpaceLine)
if writeErr != nil { if writeErr != nil {
@@ -90,7 +89,7 @@ func (stream *streamReader[T]) processLines() ([]byte, error) {
continue continue
} }
noPrefixLine := headerData.ReplaceAll(noSpaceLine, nil) noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData)
if string(noPrefixLine) == "[DONE]" { if string(noPrefixLine) == "[DONE]" {
stream.isFinished = true stream.isFinished = true
return nil, io.EOF return nil, io.EOF