Compare commits


3 Commits

Author     SHA1        Message                        Date
VaalaCat   1337a4b683  feat: add reasoning format     2025-03-07 13:22:30 +00:00
VaalaCat   3f53ae6ab1  feat: add include_reasoning    2025-02-12 13:24:21 +00:00
VaalaCat   40de0deb41  feat: change repo name         2025-02-12 13:22:09 +00:00
26 changed files with 566 additions and 1369 deletions

View File

@@ -13,17 +13,15 @@ jobs:
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: '1.24'
go-version: '1.21'
- name: Run vet
run: |
go vet .
- name: Run golangci-lint
uses: golangci/golangci-lint-action@v7
uses: golangci/golangci-lint-action@v4
with:
version: v2.1.5
version: latest
- name: Run tests
run: go test -race -covermode=atomic -coverprofile=coverage.out -v ./...
run: go test -race -covermode=atomic -coverprofile=coverage.out -v .
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
uses: codecov/codecov-action@v4

View File

@@ -1,94 +1,66 @@
version: "2"
linters:
default: none
enable:
- asciicheck
- bidichk
- bodyclose
- contextcheck
- cyclop
- dupl
- durationcheck
- errcheck
- errname
- errorlint
- exhaustive
- forbidigo
- funlen
- gochecknoinits
- gocognit
- goconst
- gocritic
- gocyclo
- godot
- gomoddirectives
- gomodguard
- goprintffuncname
- gosec
- govet
- ineffassign
- lll
- makezero
- mnd
- nestif
- nilerr
- nilnil
- nolintlint
- nosprintfhostport
- predeclared
- promlinter
- revive
- rowserrcheck
- sqlclosecheck
- staticcheck
- testpackage
- tparallel
- unconvert
- unparam
- unused
- usetesting
- wastedassign
- whitespace
settings:
## Golden config for golangci-lint v1.47.3
#
# This is the best config for golangci-lint based on my experience and opinion.
# It is very strict, but not extremely strict.
# Feel free to adopt and change it for your needs.
run:
# Timeout for analysis, e.g. 30s, 5m.
# Default: 1m
timeout: 3m
# This file contains only configs which differ from defaults.
# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
linters-settings:
cyclop:
# The maximal code complexity to report.
# Default: 10
max-complexity: 30
package-average: 10
# The maximal average package complexity.
# If it's higher than 0.0 (float) the check is enabled
# Default: 0.0
package-average: 10.0
errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default.
# Default: false
check-type-assertions: true
funlen:
# Checks the number of lines in a function.
# If lower than 0, disable the check.
# Default: 60
lines: 100
# Checks the number of statements in a function.
# If lower than 0, disable the check.
# Default: 40
statements: 50
gocognit:
# Minimal code complexity to report
# Default: 30 (but we recommend 10-20)
min-complexity: 20
gocritic:
# Settings passed to gocritic.
# The settings key is the name of a supported gocritic checker.
# The list of supported checkers can be find in https://go-critic.github.io/overview.
settings:
captLocal:
# Whether to restrict checker to params only.
# Default: true
paramsOnly: false
underef:
# Whether to skip (*x).method() calls where x is a pointer receiver.
# Default: true
skipRecvDeref: false
gomodguard:
blocked:
modules:
- github.com/golang/protobuf:
recommendations:
- google.golang.org/protobuf
reason: see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules
- github.com/satori/go.uuid:
recommendations:
- github.com/google/uuid
reason: satori's package is not maintained
- github.com/gofrs/uuid:
recommendations:
- github.com/google/uuid
reason: 'see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw'
govet:
disable:
- fieldalignment
enable-all: true
settings:
shadow:
strict: true
mnd:
# List of function patterns to exclude from analysis.
# Values always ignored: `time.Date`
# Default: []
ignored-functions:
- os.Chmod
- os.Mkdir
@@ -104,44 +76,194 @@ linters:
- strconv.ParseFloat
- strconv.ParseInt
- strconv.ParseUint
gomodguard:
blocked:
# List of blocked modules.
# Default: []
modules:
- github.com/golang/protobuf:
recommendations:
- google.golang.org/protobuf
reason: "see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules"
- github.com/satori/go.uuid:
recommendations:
- github.com/google/uuid
reason: "satori's package is not maintained"
- github.com/gofrs/uuid:
recommendations:
- github.com/google/uuid
reason: "see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw"
govet:
# Enable all analyzers.
# Default: false
enable-all: true
# Disable analyzers by name.
# Run `go tool vet help` to see all analyzers.
# Default: []
disable:
- fieldalignment # too strict
# Settings per analyzer.
settings:
shadow:
# Whether to be strict about shadowing; can be noisy.
# Default: false
strict: true
nakedret:
# Make an issue if func has more lines of code than this setting, and it has naked returns.
# Default: 30
max-func-lines: 0
nolintlint:
# Exclude following linters from requiring an explanation.
# Default: []
allow-no-explanation: [ funlen, gocognit, lll ]
# Enable to require an explanation of nonzero length after each nolint directive.
# Default: false
require-explanation: true
# Enable to require nolint directives to mention the specific linter being suppressed.
# Default: false
require-specific: true
allow-no-explanation:
- funlen
- gocognit
- lll
rowserrcheck:
# database/sql is always checked
# Default: []
packages:
- github.com/jmoiron/sqlx
exclusions:
generated: lax
presets:
- comments
- common-false-positives
- legacy
- std-error-handling
rules:
- linters:
- forbidigo
- mnd
- revive
path : ^examples/.*\.go$
- linters:
- lll
source: ^//\s*go:generate\s
- linters:
- godot
source: (noinspection|TODO)
- linters:
- gocritic
source: //noinspection
- linters:
- errorlint
source: ^\s+if _, ok := err\.\([^.]+\.InternalError\); ok {
- linters:
tenv:
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
# Default: false
all: true
varcheck:
# Check usage of exported fields and variables.
# Default: false
exported-fields: false # default false # TODO: enable after fixing false positives
linters:
disable-all: true
enable:
## enabled by default
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
- gosimple # Linter for Go source code that specializes in simplifying a code
- govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- ineffassign # Detects when assignments to existing variables are not used
- staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
- unused # Checks Go code for unused constants, variables, functions and types
## disabled by default
# - asasalint # Check for pass []any as any in variadic func(...any)
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
- bidichk # Checks for dangerous unicode character sequences
- bodyclose # checks whether HTTP response body is closed successfully
- contextcheck # check the function whether use a non-inherited context
- cyclop # checks function and package cyclomatic complexity
- dupl # Tool for code clone detection
- durationcheck # check for two durations multiplied together
- errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error.
- errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13.
# Removed execinquery (deprecated). execinquery is a linter about query string checker in Query function which reads your Go src files and warning it finds
- exhaustive # check exhaustiveness of enum switch statements
- exportloopref # checks for pointers to enclosing loop variables
- forbidigo # Forbids identifiers
- funlen # Tool for detection of long functions
# - gochecknoglobals # check that no global variables exist
- gochecknoinits # Checks that no init functions are present in Go code
- gocognit # Computes and checks the cognitive complexity of functions
- goconst # Finds repeated strings that could be replaced by a constant
- gocritic # Provides diagnostics that check for bugs, performance and style issues.
- gocyclo # Computes and checks the cyclomatic complexity of functions
- godot # Check if comments end in a period
- goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt.
- gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod.
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- goprintffuncname # Checks that printf-like functions are named with f at the end
- gosec # Inspects source code for security problems
- lll # Reports long lines
- makezero # Finds slice declarations with non-zero initial length
# - nakedret # Finds naked returns in functions greater than a specified function length
- mnd # An analyzer to detect magic numbers.
- nestif # Reports deeply nested if statements
- nilerr # Finds the code that returns nil even if it checks that the error is not nil.
- nilnil # Checks that there is no simultaneous return of nil error and an invalid value.
# - noctx # noctx finds sending http request without context.Context
- nolintlint # Reports ill-formed or insufficient nolint directives
# - nonamedreturns # Reports all named returns
- nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL.
- predeclared # find code that shadows one of Go's predeclared identifiers
- promlinter # Check Prometheus metrics naming via promlint
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
- rowserrcheck # checks whether Err of rows is checked successfully
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
- stylecheck # Stylecheck is a replacement for golint
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17
- testpackage # linter that makes you use a separate _test package
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
- unconvert # Remove unnecessary type conversions
- unparam # Reports unused function parameters
- usetesting # Reports uses of functions with replacement inside the testing package
- wastedassign # wastedassign finds wasted assignment statements.
- whitespace # Tool for detection of leading and trailing whitespace
## you may want to enable
#- decorder # check declaration order and count of types, constants, variables and functions
#- exhaustruct # Checks if all structure fields are initialized
#- goheader # Checks is file header matches to pattern
#- ireturn # Accept Interfaces, Return Concrete Types
#- prealloc # [premature optimization, but can be used in some cases] Finds slice declarations that could potentially be preallocated
#- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope
#- wrapcheck # Checks that errors returned from external packages are wrapped
## disabled
#- containedctx # containedctx is a linter that detects struct contained context.Context field
#- depguard # [replaced by gomodguard] Go linter that checks if package imports are in a list of acceptable packages
#- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
#- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted.
#- forcetypeassert # [replaced by errcheck] finds forced type assertions
#- gci # Gci controls golang package import order and makes it always deterministic.
#- godox # Tool for detection of FIXME, TODO and other comment keywords
#- goerr113 # [too strict] Golang linter to check the errors handling expressions
#- gofmt # [replaced by goimports] Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
#- gofumpt # [replaced by goimports, gofumports is not available yet] Gofumpt checks whether code was gofumpt-ed.
#- grouper # An analyzer to analyze expression groups.
#- ifshort # Checks that your code uses short syntax for if-statements whenever possible
#- importas # Enforces consistent import aliases
#- maintidx # maintidx measures the maintainability index of each function.
#- misspell # [useless] Finds commonly misspelled English words in comments
#- nlreturn # [too strict and mostly code is not more readable] nlreturn checks for a new line before return and branch statements to increase code clarity
#- nosnakecase # Detects snake case of variable naming and function name. # TODO: maybe enable after https://github.com/sivchari/nosnakecase/issues/14
#- paralleltest # [too many false positives] paralleltest detects missing usage of t.Parallel() method in your Go test
#- tagliatelle # Checks the struct tags.
#- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
#- wsl # [too strict and mostly code is not more readable] Whitespace Linter - Forces you to use empty lines!
## deprecated
#- exhaustivestruct # [deprecated, replaced by exhaustruct] Checks if all struct's fields are initialized
#- golint # [deprecated, replaced by revive] Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes
#- interfacer # [deprecated] Linter that suggests narrower interface types
#- maligned # [deprecated, replaced by govet fieldalignment] Tool to detect Go structs that would take less memory if their fields were sorted
#- scopelint # [deprecated, replaced by exportloopref] Scopelint checks for unpinned variables in go programs
issues:
# Maximum count of issues with the same text.
# Set to 0 to disable.
# Default: 3
max-same-issues: 50
exclude-rules:
- source: "^//\\s*go:generate\\s"
linters: [ lll ]
- source: "(noinspection|TODO)"
linters: [ godot ]
- source: "//noinspection"
linters: [ gocritic ]
- source: "^\\s+if _, ok := err\\.\\([^.]+\\.InternalError\\); ok {"
linters: [ errorlint ]
- path: "_test\\.go"
linters:
- bodyclose
- dupl
- funlen
@@ -149,20 +271,3 @@ linters:
- gosec
- noctx
- wrapcheck
- staticcheck
path: _test\.go
paths:
- third_party$
- builtin$
- examples$
issues:
max-same-issues: 50
formatters:
enable:
- goimports
exclusions:
generated: lax
paths:
- third_party$
- builtin$
- examples$

View File

@@ -7,7 +7,7 @@ This library provides unofficial Go clients for [OpenAI API](https://platform.op
* ChatGPT 4o, o1
* GPT-3, GPT-4
* DALL·E 2, DALL·E 3, GPT Image 1
* DALL·E 2, DALL·E 3
* Whisper
## Installation
@@ -357,66 +357,6 @@ func main() {
```
</details>
<details>
<summary>GPT Image 1 image generation</summary>
```go
package main
import (
"context"
"encoding/base64"
"fmt"
"os"
openai "github.com/sashabaranov/go-openai"
)
func main() {
c := openai.NewClient("your token")
ctx := context.Background()
req := openai.ImageRequest{
Prompt: "Parrot on a skateboard performing a trick. Large bold text \"SKATE MASTER\" banner at the bottom of the image. Cartoon style, natural light, high detail, 1:1 aspect ratio.",
Background: openai.CreateImageBackgroundOpaque,
Model: openai.CreateImageModelGptImage1,
Size: openai.CreateImageSize1024x1024,
N: 1,
Quality: openai.CreateImageQualityLow,
OutputCompression: 100,
OutputFormat: openai.CreateImageOutputFormatJPEG,
// Moderation: openai.CreateImageModerationLow,
// User: "",
}
resp, err := c.CreateImage(ctx, req)
if err != nil {
fmt.Printf("Image creation Image generation with GPT Image 1error: %v\n", err)
return
}
fmt.Println("Image Base64:", resp.Data[0].B64JSON)
// Decode the base64 data
imgBytes, err := base64.StdEncoding.DecodeString(resp.Data[0].B64JSON)
if err != nil {
fmt.Printf("Base64 decode error: %v\n", err)
return
}
// Write image to file
outputPath := "generated_image.jpg"
err = os.WriteFile(outputPath, imgBytes, 0644)
if err != nil {
fmt.Printf("Failed to write image file: %v\n", err)
return
}
fmt.Printf("The image was saved as %s\n", outputPath)
}
```
</details>
<details>
<summary>Configuring proxy</summary>

View File

@@ -2,11 +2,8 @@ package openai //nolint:testpackage // testing private field
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"testing"
@@ -110,131 +107,3 @@ func TestCreateFileField(t *testing.T) {
checks.HasError(t, err, "createFileField using file should return error when open file fails")
})
}
// failingFormBuilder always returns an error when creating form files.
type failingFormBuilder struct{ err error }
func (f *failingFormBuilder) CreateFormFile(_ string, _ *os.File) error {
return f.err
}
func (f *failingFormBuilder) CreateFormFileReader(_ string, _ io.Reader, _ string) error {
return f.err
}
func (f *failingFormBuilder) WriteField(_, _ string) error {
return nil
}
func (f *failingFormBuilder) Close() error {
return nil
}
func (f *failingFormBuilder) FormDataContentType() string {
return "multipart/form-data"
}
// failingAudioRequestBuilder simulates an error during HTTP request construction.
type failingAudioRequestBuilder struct{ err error }
func (f *failingAudioRequestBuilder) Build(
_ context.Context,
_, _ string,
_ any,
_ http.Header,
) (*http.Request, error) {
return nil, f.err
}
// errorHTTPClient always returns an error when making HTTP calls.
type errorHTTPClient struct{ err error }
func (e *errorHTTPClient) Do(_ *http.Request) (*http.Response, error) {
return nil, e.err
}
func TestCallAudioAPIMultipartFormError(t *testing.T) {
client := NewClient("test-token")
errForm := errors.New("mock create form file failure")
// Override form builder to force an error during multipart form creation.
client.createFormBuilder = func(_ io.Writer) utils.FormBuilder {
return &failingFormBuilder{err: errForm}
}
// Provide a reader so createFileField uses the reader path (no file open).
req := AudioRequest{FilePath: "fake.mp3", Reader: bytes.NewBuffer([]byte("dummy")), Model: Whisper1}
_, err := client.callAudioAPI(context.Background(), req, "transcriptions")
if err == nil {
t.Fatal("expected error but got none")
}
if !errors.Is(err, errForm) {
t.Errorf("expected error %v, got %v", errForm, err)
}
}
func TestCallAudioAPINewRequestError(t *testing.T) {
client := NewClient("test-token")
// Create a real temp file so multipart form succeeds.
tmp := t.TempDir()
path := filepath.Join(tmp, "file.mp3")
if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
t.Fatalf("failed to write temp file: %v", err)
}
errBuild := errors.New("mock build failure")
client.requestBuilder = &failingAudioRequestBuilder{err: errBuild}
req := AudioRequest{FilePath: path, Model: Whisper1}
_, err := client.callAudioAPI(context.Background(), req, "translations")
if err == nil {
t.Fatal("expected error but got none")
}
if !errors.Is(err, errBuild) {
t.Errorf("expected error %v, got %v", errBuild, err)
}
}
func TestCallAudioAPISendRequestErrorJSON(t *testing.T) {
client := NewClient("test-token")
// Create a real temp file so multipart form succeeds.
tmp := t.TempDir()
path := filepath.Join(tmp, "file.mp3")
if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
t.Fatalf("failed to write temp file: %v", err)
}
errHTTP := errors.New("mock HTTPClient failure")
// Override HTTP client to simulate a network error.
client.config.HTTPClient = &errorHTTPClient{err: errHTTP}
req := AudioRequest{FilePath: path, Model: Whisper1}
_, err := client.callAudioAPI(context.Background(), req, "transcriptions")
if err == nil {
t.Fatal("expected error but got none")
}
if !errors.Is(err, errHTTP) {
t.Errorf("expected error %v, got %v", errHTTP, err)
}
}
func TestCallAudioAPISendRequestErrorText(t *testing.T) {
client := NewClient("test-token")
tmp := t.TempDir()
path := filepath.Join(tmp, "file.mp3")
if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
t.Fatalf("failed to write temp file: %v", err)
}
errHTTP := errors.New("mock HTTPClient failure")
client.config.HTTPClient = &errorHTTPClient{err: errHTTP}
// Use a non-JSON response format to exercise the text path.
req := AudioRequest{FilePath: path, Model: Whisper1, Format: AudioResponseFormatText}
_, err := client.callAudioAPI(context.Background(), req, "translations")
if err == nil {
t.Fatal("expected error but got none")
}
if !errors.Is(err, errHTTP) {
t.Errorf("expected error %v, got %v", errHTTP, err)
}
}

chat.go
View File

@@ -14,7 +14,6 @@ const (
ChatMessageRoleAssistant = "assistant"
ChatMessageRoleFunction = "function"
ChatMessageRoleTool = "tool"
ChatMessageRoleDeveloper = "developer"
)
const chatCompletionsSuffix = "/chat/completions"
@@ -104,12 +103,6 @@ type ChatCompletionMessage struct {
// - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
Name string `json:"name,omitempty"`
// This property is used for the "reasoning" feature supported by deepseek-reasoner
// which is not in the official documentation.
// the doc from deepseek:
// - https://api-docs.deepseek.com/api/create-chat-completion#responses
ReasoningContent string `json:"reasoning_content,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"`
// For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls.
@@ -130,7 +123,6 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
Refusal string `json:"refusal,omitempty"`
MultiContent []ChatMessagePart `json:"content,omitempty"`
Name string `json:"name,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
@@ -144,7 +136,6 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
Refusal string `json:"refusal,omitempty"`
MultiContent []ChatMessagePart `json:"-"`
Name string `json:"name,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
@@ -155,11 +146,10 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
msg := struct {
Role string `json:"role"`
Content string `json:"content"`
Content string `json:"content,omitempty"`
Refusal string `json:"refusal,omitempty"`
MultiContent []ChatMessagePart
Name string `json:"name,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
@@ -175,7 +165,6 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
Refusal string `json:"refusal,omitempty"`
MultiContent []ChatMessagePart `json:"content"`
Name string `json:"name,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
@@ -275,13 +264,6 @@ type ChatCompletionRequest struct {
Metadata map[string]string `json:"metadata,omitempty"`
IncludeReasoning *bool `json:"include_reasoning,omitempty"`
ReasoningFormat *string `json:"reasoning_format,omitempty"`
// Configuration for a predicted output.
Prediction *Prediction `json:"prediction,omitempty"`
// ChatTemplateKwargs provides a way to add non-standard parameters to the request body.
// Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
// Such as think mode for qwen3. "chat_template_kwargs": {"enable_thinking": false}
// https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes
ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"`
}
type StreamOptions struct {
@@ -349,11 +331,6 @@ type LogProbs struct {
Content []LogProb `json:"content"`
}
type Prediction struct {
Content string `json:"content"`
Type string `json:"type"`
}
type FinishReason string
const (
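The two feature commits ("add include_reasoning", "add reasoning format") introduce the IncludeReasoning and ReasoningFormat pointer fields on ChatCompletionRequest shown in the hunk above. A minimal sketch of setting them from a caller, assuming the fork's module path (git.vaala.cloud/VaalaCat/go-openai, as imported elsewhere in this diff); the "parsed" format string and the model name are placeholders, since the diff does not show which values a backend accepts:

```go
package main

import (
	"context"
	"fmt"

	openai "git.vaala.cloud/VaalaCat/go-openai"
)

func main() {
	client := openai.NewClient("your token")

	includeReasoning := true
	reasoningFormat := "parsed" // placeholder; accepted formats depend on the backend

	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model: "deepseek-reasoner", // example model only
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
		// Pointer fields added by the "add include_reasoning" / "add reasoning format" commits.
		IncludeReasoning: &includeReasoning,
		ReasoningFormat:  &reasoningFormat,
	})
	if err != nil {
		fmt.Printf("ChatCompletion error: %v\n", err)
		return
	}
	fmt.Println(resp.Choices[0].Message.Content)
}
```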

View File

@@ -11,12 +11,6 @@ type ChatCompletionStreamChoiceDelta struct {
FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Refusal string `json:"refusal,omitempty"`
// This property is used for the "reasoning" feature supported by deepseek-reasoner
// which is not in the official documentation.
// the doc from deepseek:
// - https://api-docs.deepseek.com/api/create-chat-completion#responses
ReasoningContent string `json:"reasoning_content,omitempty"`
}
type ChatCompletionStreamChoiceLogprobs struct {

View File

@@ -959,56 +959,6 @@ func TestCreateChatCompletionStreamReasoningValidatorFails(t *testing.T) {
}
}
func TestCreateChatCompletionStreamO3ReasoningValidatorFails(t *testing.T) {
client, _, _ := setupOpenAITestServer()
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 100, // This will trigger the validator to fail
Model: openai.O3,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
Stream: true,
})
if stream != nil {
t.Error("Expected nil stream when validation fails")
stream.Close()
}
if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O3, got: %v", err)
}
}
func TestCreateChatCompletionStreamO4MiniReasoningValidatorFails(t *testing.T) {
client, _, _ := setupOpenAITestServer()
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 100, // This will trigger the validator to fail
Model: openai.O4Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
Stream: true,
})
if stream != nil {
t.Error("Expected nil stream when validation fails")
stream.Close()
}
if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O4Mini, got: %v", err)
}
}
func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool {
if c1.Index != c2.Index {
return false

View File

@@ -106,6 +106,40 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
},
expectedError: openai.ErrReasoningModelLimitationsLogprobs,
},
{
name: "message_type_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O1Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
},
},
},
expectedError: openai.ErrO1BetaLimitationsMessageTypes,
},
{
name: "tool_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O1Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
Tools: []openai.Tool{
{
Type: openai.ToolTypeFunction,
},
},
},
expectedError: openai.ErrO1BetaLimitationsTools,
},
{
name: "set_temperature_unsupported",
in: openai.ChatCompletionRequest{
@@ -411,23 +445,6 @@ func TestO3ModelChatCompletions(t *testing.T) {
checks.NoError(t, err, "CreateChatCompletion error")
}
func TestDeepseekR1ModelChatCompletions(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", handleDeepseekR1ChatCompletionEndpoint)
_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
Model: "deepseek-reasoner",
MaxCompletionTokens: 100,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
})
checks.NoError(t, err, "CreateChatCompletion error")
}
// TestCompletions Tests the completions endpoint of the API using the mocked server.
func TestChatCompletionsWithHeaders(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
@@ -839,68 +856,6 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, string(resBytes))
}
func handleDeepseekR1ChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte
// completions only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
var completionReq openai.ChatCompletionRequest
if completionReq, err = getChatCompletionBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}
res := openai.ChatCompletionResponse{
ID: strconv.Itoa(int(time.Now().Unix())),
Object: "test-object",
Created: time.Now().Unix(),
// would be nice to validate Model during testing, but
// this may not be possible with how much upkeep
// would be required / wouldn't make much sense
Model: completionReq.Model,
}
// create completions
n := completionReq.N
if n == 0 {
n = 1
}
if completionReq.MaxCompletionTokens == 0 {
completionReq.MaxCompletionTokens = 1000
}
for i := 0; i < n; i++ {
reasoningContent := "User says hello! And I need to reply"
completionStr := strings.Repeat("a", completionReq.MaxCompletionTokens-numTokens(reasoningContent))
res.Choices = append(res.Choices, openai.ChatCompletionChoice{
Message: openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
ReasoningContent: reasoningContent,
Content: completionStr,
},
Index: i,
})
}
inputTokens := numTokens(completionReq.Messages[0].Content) * n
completionTokens := completionReq.MaxTokens * n
res.Usage = openai.Usage{
PromptTokens: inputTokens,
CompletionTokens: completionTokens,
TotalTokens: inputTokens + completionTokens,
}
resBytes, _ = json.Marshal(res)
w.Header().Set(xCustomHeader, xCustomHeaderValue)
for k, v := range rateLimitHeaders {
switch val := v.(type) {
case int:
w.Header().Set(k, strconv.Itoa(val))
default:
w.Header().Set(k, fmt.Sprintf("%s", v))
}
}
fmt.Fprintln(w, string(resBytes))
}
// getChatCompletionBody Returns the body of the request to create a completion.
func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) {
completion := openai.ChatCompletionRequest{}

View File

@@ -182,21 +182,13 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream
func (c *Client) setCommonHeaders(req *http.Request) {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
switch c.config.APIType {
case APITypeAzure, APITypeCloudflareAzure:
// Azure API Key authentication
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure {
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
case APITypeAnthropic:
// https://docs.anthropic.com/en/api/versioning
req.Header.Set("anthropic-version", c.config.APIVersion)
case APITypeOpenAI, APITypeAzureAD:
fallthrough
default:
if c.config.authToken != "" {
} else if c.config.authToken != "" {
// OpenAI or Azure AD authentication
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
}
}
if c.config.OrgID != "" {
req.Header.Set("OpenAI-Organization", c.config.OrgID)
}
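For context on the header logic above: the branch taken in setCommonHeaders follows the config's APIType, so an Azure-style config gets the api-key header while the default config falls through to the Bearer token. A small sketch using constructors that appear elsewhere in this comparison (DefaultAzureConfig, NewClientWithConfig, NewClient); the Azure endpoint URL is illustrative only:

```go
package main

import openai "git.vaala.cloud/VaalaCat/go-openai"

func main() {
	// Azure-style config: setCommonHeaders sets the "api-key" header.
	azureClient := openai.NewClientWithConfig(
		openai.DefaultAzureConfig("azure-api-key", "https://example-resource.openai.azure.com"),
	)

	// Default config: setCommonHeaders falls through to "Authorization: Bearer <token>".
	openaiClient := openai.NewClient("your token")

	_, _ = azureClient, openaiClient
}
```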

View File

@@ -39,21 +39,6 @@ func TestClient(t *testing.T) {
}
}
func TestSetCommonHeadersAnthropic(t *testing.T) {
config := DefaultAnthropicConfig("mock-token", "")
client := NewClientWithConfig(config)
req, err := http.NewRequest("GET", "http://example.com", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
client.setCommonHeaders(req)
if got := req.Header.Get("anthropic-version"); got != AnthropicAPIVersion {
t.Errorf("Expected anthropic-version header to be %q, got %q", AnthropicAPIVersion, got)
}
}
func TestDecodeResponse(t *testing.T) {
stringInput := ""

View File

@@ -15,8 +15,6 @@ type Usage struct {
type CompletionTokensDetails struct {
AudioTokens int `json:"audio_tokens"`
ReasoningTokens int `json:"reasoning_tokens"`
AcceptedPredictionTokens int `json:"accepted_prediction_tokens"`
RejectedPredictionTokens int `json:"rejected_prediction_tokens"`
}
// PromptTokensDetails Breakdown of tokens used in the prompt.

View File

@@ -16,12 +16,8 @@ const (
O1Preview20240912 = "o1-preview-2024-09-12"
O1 = "o1"
O120241217 = "o1-2024-12-17"
O3 = "o3"
O320250416 = "o3-2025-04-16"
O3Mini = "o3-mini"
O3Mini20250131 = "o3-mini-2025-01-31"
O4Mini = "o4-mini"
O4Mini20250416 = "o4-mini-2025-04-16"
GPT432K0613 = "gpt-4-32k-0613"
GPT432K0314 = "gpt-4-32k-0314"
GPT432K = "gpt-4-32k"
@@ -41,14 +37,6 @@ const (
GPT4TurboPreview = "gpt-4-turbo-preview"
GPT4VisionPreview = "gpt-4-vision-preview"
GPT4 = "gpt-4"
GPT4Dot1 = "gpt-4.1"
GPT4Dot120250414 = "gpt-4.1-2025-04-14"
GPT4Dot1Mini = "gpt-4.1-mini"
GPT4Dot1Mini20250414 = "gpt-4.1-mini-2025-04-14"
GPT4Dot1Nano = "gpt-4.1-nano"
GPT4Dot1Nano20250414 = "gpt-4.1-nano-2025-04-14"
GPT4Dot5Preview = "gpt-4.5-preview"
GPT4Dot5Preview20250227 = "gpt-4.5-preview-2025-02-27"
GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125"
GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106"
GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613"
@@ -103,10 +91,6 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
O1Preview20240912: true,
O3Mini: true,
O3Mini20250131: true,
O4Mini: true,
O4Mini20250416: true,
O3: true,
O320250416: true,
GPT3Dot5Turbo: true,
GPT3Dot5Turbo0301: true,
GPT3Dot5Turbo0613: true,
@@ -115,8 +99,6 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
GPT3Dot5Turbo16K: true,
GPT3Dot5Turbo16K0613: true,
GPT4: true,
GPT4Dot5Preview: true,
GPT4Dot5Preview20250227: true,
GPT4o: true,
GPT4o20240513: true,
GPT4o20240806: true,
@@ -135,13 +117,6 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
GPT432K: true,
GPT432K0314: true,
GPT432K0613: true,
O1: true,
GPT4Dot1: true,
GPT4Dot120250414: true,
GPT4Dot1Mini: true,
GPT4Dot1Mini20250414: true,
GPT4Dot1Nano: true,
GPT4Dot1Nano20250414: true,
},
chatCompletionsSuffix: {
CodexCodeDavinci002: true,
@@ -215,8 +190,6 @@ type CompletionRequest struct {
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
User string `json:"user,omitempty"`
// Options for streaming response. Only set this when you set stream: true.
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
}
// CompletionChoice represents one of possible completions.

View File

@@ -33,42 +33,6 @@ func TestCompletionsWrongModel(t *testing.T) {
}
}
// TestCompletionsWrongModelO3 Tests the completions endpoint with O3 model which is not supported.
func TestCompletionsWrongModelO3(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
_, err := client.CreateCompletion(
context.Background(),
openai.CompletionRequest{
MaxTokens: 5,
Model: openai.O3,
},
)
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O3, but returned: %v", err)
}
}
// TestCompletionsWrongModelO4Mini Tests the completions endpoint with O4Mini model which is not supported.
func TestCompletionsWrongModelO4Mini(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
_, err := client.CreateCompletion(
context.Background(),
openai.CompletionRequest{
MaxTokens: 5,
Model: openai.O4Mini,
},
)
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O4Mini, but returned: %v", err)
}
}
func TestCompletionWithStream(t *testing.T) {
config := openai.DefaultConfig("whatever")
client := openai.NewClientWithConfig(config)
@@ -217,86 +181,3 @@ func getCompletionBody(r *http.Request) (openai.CompletionRequest, error) {
}
return completion, nil
}
// TestCompletionWithO1Model Tests that O1 model is not supported for completion endpoint.
func TestCompletionWithO1Model(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
_, err := client.CreateCompletion(
context.Background(),
openai.CompletionRequest{
MaxTokens: 5,
Model: openai.O1,
},
)
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O1 model, but returned: %v", err)
}
}
// TestCompletionWithGPT4DotModels Tests that newer GPT4 models are not supported for completion endpoint.
func TestCompletionWithGPT4DotModels(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
models := []string{
openai.GPT4Dot1,
openai.GPT4Dot120250414,
openai.GPT4Dot1Mini,
openai.GPT4Dot1Mini20250414,
openai.GPT4Dot1Nano,
openai.GPT4Dot1Nano20250414,
openai.GPT4Dot5Preview,
openai.GPT4Dot5Preview20250227,
}
for _, model := range models {
t.Run(model, func(t *testing.T) {
_, err := client.CreateCompletion(
context.Background(),
openai.CompletionRequest{
MaxTokens: 5,
Model: model,
},
)
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err)
}
})
}
}
// TestCompletionWithGPT4oModels Tests that GPT4o models are not supported for completion endpoint.
func TestCompletionWithGPT4oModels(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
models := []string{
openai.GPT4o,
openai.GPT4o20240513,
openai.GPT4o20240806,
openai.GPT4o20241120,
openai.GPT4oLatest,
openai.GPT4oMini,
openai.GPT4oMini20240718,
}
for _, model := range models {
t.Run(model, func(t *testing.T) {
_, err := client.CreateCompletion(
context.Background(),
openai.CompletionRequest{
MaxTokens: 5,
Model: model,
},
)
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err)
}
})
}
}

View File

@@ -11,8 +11,6 @@ const (
azureAPIPrefix = "openai"
azureDeploymentsPrefix = "deployments"
AnthropicAPIVersion = "2023-06-01"
)
type APIType string
@@ -22,7 +20,6 @@ const (
APITypeAzure APIType = "AZURE"
APITypeAzureAD APIType = "AZURE_AD"
APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE"
APITypeAnthropic APIType = "ANTHROPIC"
)
const AzureAPIKeyHeader = "api-key"
@@ -40,7 +37,7 @@ type ClientConfig struct {
BaseURL string
OrgID string
APIType APIType
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD or APITypeAnthropic
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD
AssistantVersion string
AzureModelMapperFunc func(model string) string // replace model to azure deployment name func
HTTPClient HTTPDoer
@@ -79,23 +76,6 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig {
}
}
func DefaultAnthropicConfig(apiKey, baseURL string) ClientConfig {
if baseURL == "" {
baseURL = "https://api.anthropic.com/v1"
}
return ClientConfig{
authToken: apiKey,
BaseURL: baseURL,
OrgID: "",
APIType: APITypeAnthropic,
APIVersion: AnthropicAPIVersion,
HTTPClient: &http.Client{},
EmptyMessagesLimit: defaultEmptyMessagesLimit,
}
}
func (ClientConfig) String() string {
return "<OpenAI API ClientConfig>"
}

View File

@@ -60,64 +60,3 @@ func TestGetAzureDeploymentByModel(t *testing.T) {
})
}
}
func TestDefaultAnthropicConfig(t *testing.T) {
apiKey := "test-key"
baseURL := "https://api.anthropic.com/v1"
config := openai.DefaultAnthropicConfig(apiKey, baseURL)
if config.APIType != openai.APITypeAnthropic {
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
}
if config.APIVersion != openai.AnthropicAPIVersion {
t.Errorf("Expected APIVersion to be 2023-06-01, got %v", config.APIVersion)
}
if config.BaseURL != baseURL {
t.Errorf("Expected BaseURL to be %v, got %v", baseURL, config.BaseURL)
}
if config.EmptyMessagesLimit != 300 {
t.Errorf("Expected EmptyMessagesLimit to be 300, got %v", config.EmptyMessagesLimit)
}
}
func TestDefaultAnthropicConfigWithEmptyValues(t *testing.T) {
config := openai.DefaultAnthropicConfig("", "")
if config.APIType != openai.APITypeAnthropic {
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
}
if config.APIVersion != openai.AnthropicAPIVersion {
t.Errorf("Expected APIVersion to be %s, got %v", openai.AnthropicAPIVersion, config.APIVersion)
}
expectedBaseURL := "https://api.anthropic.com/v1"
if config.BaseURL != expectedBaseURL {
t.Errorf("Expected BaseURL to be %v, got %v", expectedBaseURL, config.BaseURL)
}
}
func TestClientConfigString(t *testing.T) {
// String() should always return the constant value
conf := openai.DefaultConfig("dummy-token")
expected := "<OpenAI API ClientConfig>"
got := conf.String()
if got != expected {
t.Errorf("ClientConfig.String() = %q; want %q", got, expected)
}
}
func TestGetAzureDeploymentByModel_NoMapper(t *testing.T) {
// On a zero-value or DefaultConfig, AzureModelMapperFunc is nil,
// so GetAzureDeploymentByModel should just return the input model.
conf := openai.DefaultConfig("dummy-token")
model := "some-model"
got := conf.GetAzureDeploymentByModel(model)
if got != model {
t.Errorf("GetAzureDeploymentByModel(%q) = %q; want %q", model, got, model)
}
}

View File

@@ -3,8 +3,8 @@ package openai
import (
"bytes"
"context"
"io"
"net/http"
"os"
"strconv"
)
@@ -13,62 +13,31 @@ const (
CreateImageSize256x256 = "256x256"
CreateImageSize512x512 = "512x512"
CreateImageSize1024x1024 = "1024x1024"
// dall-e-3 supported only.
CreateImageSize1792x1024 = "1792x1024"
CreateImageSize1024x1792 = "1024x1792"
// gpt-image-1 supported only.
CreateImageSize1536x1024 = "1536x1024" // Landscape
CreateImageSize1024x1536 = "1024x1536" // Portrait
)
const (
// dall-e-2 and dall-e-3 only.
CreateImageResponseFormatB64JSON = "b64_json"
CreateImageResponseFormatURL = "url"
CreateImageResponseFormatB64JSON = "b64_json"
)
const (
CreateImageModelDallE2 = "dall-e-2"
CreateImageModelDallE3 = "dall-e-3"
CreateImageModelGptImage1 = "gpt-image-1"
)
const (
CreateImageQualityHD = "hd"
CreateImageQualityStandard = "standard"
// gpt-image-1 only.
CreateImageQualityHigh = "high"
CreateImageQualityMedium = "medium"
CreateImageQualityLow = "low"
)
const (
// dall-e-3 only.
CreateImageStyleVivid = "vivid"
CreateImageStyleNatural = "natural"
)
const (
// gpt-image-1 only.
CreateImageBackgroundTransparent = "transparent"
CreateImageBackgroundOpaque = "opaque"
)
const (
// gpt-image-1 only.
CreateImageModerationLow = "low"
)
const (
// gpt-image-1 only.
CreateImageOutputFormatPNG = "png"
CreateImageOutputFormatJPEG = "jpeg"
CreateImageOutputFormatWEBP = "webp"
)
// ImageRequest represents the request structure for the image API.
type ImageRequest struct {
Prompt string `json:"prompt,omitempty"`
@@ -79,35 +48,16 @@ type ImageRequest struct {
Style string `json:"style,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
User string `json:"user,omitempty"`
Background string `json:"background,omitempty"`
Moderation string `json:"moderation,omitempty"`
OutputCompression int `json:"output_compression,omitempty"`
OutputFormat string `json:"output_format,omitempty"`
}
// ImageResponse represents a response structure for image API.
type ImageResponse struct {
Created int64 `json:"created,omitempty"`
Data []ImageResponseDataInner `json:"data,omitempty"`
Usage ImageResponseUsage `json:"usage,omitempty"`
httpHeader
}
// ImageResponseInputTokensDetails represents the token breakdown for input tokens.
type ImageResponseInputTokensDetails struct {
TextTokens int `json:"text_tokens,omitempty"`
ImageTokens int `json:"image_tokens,omitempty"`
}
// ImageResponseUsage represents the token usage information for image API.
type ImageResponseUsage struct {
TotalTokens int `json:"total_tokens,omitempty"`
InputTokens int `json:"input_tokens,omitempty"`
OutputTokens int `json:"output_tokens,omitempty"`
InputTokensDetails ImageResponseInputTokensDetails `json:"input_tokens_details,omitempty"`
}
// ImageResponseDataInner represents a response data structure for image API.
type ImageResponseDataInner struct {
URL string `json:"url,omitempty"`
@@ -134,15 +84,13 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons
// ImageEditRequest represents the request structure for the image API.
type ImageEditRequest struct {
Image io.Reader `json:"image,omitempty"`
Mask io.Reader `json:"mask,omitempty"`
Image *os.File `json:"image,omitempty"`
Mask *os.File `json:"mask,omitempty"`
Prompt string `json:"prompt,omitempty"`
Model string `json:"model,omitempty"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Quality string `json:"quality,omitempty"`
User string `json:"user,omitempty"`
}
// CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API.
@@ -150,16 +98,15 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
body := &bytes.Buffer{}
builder := c.createFormBuilder(body)
// image, filename is not required
err = builder.CreateFormFileReader("image", request.Image, "")
// image
err = builder.CreateFormFile("image", request.Image)
if err != nil {
return
}
// mask, it is optional
if request.Mask != nil {
// mask, filename is not required
err = builder.CreateFormFileReader("mask", request.Mask, "")
err = builder.CreateFormFile("mask", request.Mask)
if err != nil {
return
}
@@ -207,12 +154,11 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
// ImageVariRequest represents the request structure for the image API.
type ImageVariRequest struct {
Image io.Reader `json:"image,omitempty"`
Image *os.File `json:"image,omitempty"`
Model string `json:"model,omitempty"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
User string `json:"user,omitempty"`
}
// CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API.
@@ -221,8 +167,8 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
body := &bytes.Buffer{}
builder := c.createFormBuilder(body)
// image, filename is not required
err = builder.CreateFormFileReader("image", request.Image, "")
// image
err = builder.CreateFormFile("image", request.Image)
if err != nil {
return
}
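The ImageEditRequest and ImageVariRequest hunks above switch the Image and Mask fields between io.Reader and *os.File. A hedged sketch of calling CreateEditImage with an *os.File, which also satisfies io.Reader and so compiles against either side of the diff; the file name and prompt are illustrative only:

```go
package main

import (
	"context"
	"fmt"
	"os"

	openai "git.vaala.cloud/VaalaCat/go-openai"
)

func main() {
	client := openai.NewClient("your token")

	img, err := os.Open("input.png")
	if err != nil {
		fmt.Printf("open image error: %v\n", err)
		return
	}
	defer img.Close()

	resp, err := client.CreateEditImage(context.Background(), openai.ImageEditRequest{
		Image:          img, // *os.File here; the io.Reader variant accepts it as well
		Prompt:         "Add a red hat to the subject",
		Model:          openai.CreateImageModelDallE2,
		N:              1,
		Size:           openai.CreateImageSize1024x1024,
		ResponseFormat: openai.CreateImageResponseFormatURL,
	})
	if err != nil {
		fmt.Printf("CreateEditImage error: %v\n", err)
		return
	}
	fmt.Println(resp.Data[0].URL)
}
```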

View File

@@ -54,13 +54,13 @@ func TestImageFormBuilderFailures(t *testing.T) {
}
mockFailedErr := fmt.Errorf("mock form builder fail")
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
return mockFailedErr
}
_, err := client.CreateEditImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
mockBuilder.mockCreateFormFileReader = func(name string, _ io.Reader, _ string) error {
mockBuilder.mockCreateFormFile = func(name string, _ *os.File) error {
if name == "mask" {
return mockFailedErr
}
@@ -119,13 +119,13 @@ func TestVariImageFormBuilderFailures(t *testing.T) {
req := ImageVariRequest{}
mockFailedErr := fmt.Errorf("mock form builder fail")
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
return mockFailedErr
}
_, err := client.CreateVariImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails")
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
return nil
}

View File

@@ -4,10 +4,8 @@ import (
"fmt"
"io"
"mime/multipart"
"net/textproto"
"os"
"path/filepath"
"strings"
"path"
)
type FormBuilder interface {
@@ -32,37 +30,8 @@ func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) er
return fb.createFormFile(fieldname, file, file.Name())
}
var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
func escapeQuotes(s string) string {
return quoteEscaper.Replace(s)
}
// CreateFormFileReader creates a form field with a file reader.
// The filename in parameters can be an empty string.
// The filename in Content-Disposition is required, But it can be an empty string.
func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
h := make(textproto.MIMEHeader)
h.Set(
"Content-Disposition",
fmt.Sprintf(
`form-data; name="%s"; filename="%s"`,
escapeQuotes(fieldname),
escapeQuotes(filepath.Base(filename)),
),
)
fieldWriter, err := fb.writer.CreatePart(h)
if err != nil {
return err
}
_, err = io.Copy(fieldWriter, r)
if err != nil {
return err
}
return nil
return fb.createFormFile(fieldname, r, path.Base(filename))
}
func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error {

View File

@@ -43,32 +43,3 @@ func TestFormBuilderWithClosedFile(t *testing.T) {
checks.HasError(t, err, "formbuilder should return error if file is closed")
checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed")
}
type failingReader struct {
}
var errMockFailingReaderError = errors.New("mock reader failed")
func (*failingReader) Read([]byte) (int, error) {
return 0, errMockFailingReaderError
}
func TestFormBuilderWithReader(t *testing.T) {
file, err := os.CreateTemp(t.TempDir(), "")
if err != nil {
t.Fatalf("Error creating tmp file: %v", err)
}
defer file.Close()
builder := NewFormBuilder(&failingWriter{})
err = builder.CreateFormFileReader("file", file, file.Name())
checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails")
builder = NewFormBuilder(&bytes.Buffer{})
reader := &failingReader{}
err = builder.CreateFormFileReader("file", reader, "")
checks.ErrorIs(t, err, errMockFailingReaderError, "formbuilder should return error if copy reader fails")
successReader := &bytes.Buffer{}
err = builder.CreateFormFileReader("file", successReader, "")
checks.NoError(t, err, "formbuilder should not return error")
}

View File

@@ -46,8 +46,6 @@ type Definition struct {
// additionalProperties: false
// additionalProperties: jsonschema.Definition{Type: jsonschema.String}
AdditionalProperties any `json:"additionalProperties,omitempty"`
// Whether the schema is nullable or not.
Nullable bool `json:"nullable,omitempty"`
}
func (d *Definition) MarshalJSON() ([]byte, error) {
@@ -126,12 +124,9 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) {
}
jsonTag := field.Tag.Get("json")
var required = true
switch {
case jsonTag == "-":
continue
case jsonTag == "":
if jsonTag == "" {
jsonTag = field.Name
case strings.HasSuffix(jsonTag, ",omitempty"):
} else if strings.HasSuffix(jsonTag, ",omitempty") {
jsonTag = strings.TrimSuffix(jsonTag, ",omitempty")
required = false
}
@@ -144,16 +139,6 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) {
if description != "" {
item.Description = description
}
enum := field.Tag.Get("enum")
if enum != "" {
item.Enum = strings.Split(enum, ",")
}
if n := field.Tag.Get("nullable"); n != "" {
nullable, _ := strconv.ParseBool(n)
item.Nullable = nullable
}
properties[jsonTag] = *item
if s := field.Tag.Get("required"); s != "" {
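The reflectSchemaObject hunk above is the tag-handling core: the json tag names the property, an ,omitempty suffix drops the field from the required list, and the description tag fills Description. A small sketch using GenerateSchemaForType (seen in the test file below); the jsonschema import path is assumed to follow the fork's module path:

```go
package main

import (
	"encoding/json"
	"fmt"

	"git.vaala.cloud/VaalaCat/go-openai/jsonschema"
)

type Person struct {
	Name string `json:"name" description:"The name of the person"`
	Age  int    `json:"age,omitempty"` // omitempty keeps the property but removes it from "required"
}

func main() {
	schema, err := jsonschema.GenerateSchemaForType(Person{})
	if err != nil {
		fmt.Printf("schema error: %v\n", err)
		return
	}
	out, _ := json.MarshalIndent(schema, "", "  ")
	fmt.Println(string(out))
}
```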

View File

@@ -17,7 +17,7 @@ func TestDefinition_MarshalJSON(t *testing.T) {
{
name: "Test with empty Definition",
def: jsonschema.Definition{},
want: `{}`,
want: `{"properties":{}}`,
},
{
name: "Test with Definition properties set",
@@ -35,7 +35,8 @@ func TestDefinition_MarshalJSON(t *testing.T) {
"description":"A string type",
"properties":{
"name":{
"type":"string"
"type":"string",
"properties":{}
}
}
}`,
@@ -65,10 +66,12 @@ func TestDefinition_MarshalJSON(t *testing.T) {
"type":"object",
"properties":{
"name":{
"type":"string"
"type":"string",
"properties":{}
},
"age":{
"type":"integer"
"type":"integer",
"properties":{}
}
}
}
@@ -111,19 +114,23 @@ func TestDefinition_MarshalJSON(t *testing.T) {
"type":"object",
"properties":{
"name":{
"type":"string"
"type":"string",
"properties":{}
},
"age":{
"type":"integer"
"type":"integer",
"properties":{}
},
"address":{
"type":"object",
"properties":{
"city":{
"type":"string"
"type":"string",
"properties":{}
},
"country":{
"type":"string"
"type":"string",
"properties":{}
}
}
}
@@ -148,11 +155,15 @@ func TestDefinition_MarshalJSON(t *testing.T) {
want: `{
"type":"array",
"items":{
"type":"string"
"type":"string",
"properties":{
}
},
"properties":{
"name":{
"type":"string"
"type":"string",
"properties":{}
}
}
}`,
@@ -182,232 +193,6 @@ func TestDefinition_MarshalJSON(t *testing.T) {
}
}
func TestStructToSchema(t *testing.T) {
tests := []struct {
name string
in any
want string
}{
{
name: "Test with empty struct",
in: struct{}{},
want: `{
"type":"object",
"additionalProperties":false
}`,
},
{
name: "Test with struct containing many fields",
in: struct {
Name string `json:"name"`
Age int `json:"age"`
Active bool `json:"active"`
Height float64 `json:"height"`
Cities []struct {
Name string `json:"name"`
State string `json:"state"`
} `json:"cities"`
}{
Name: "John Doe",
Age: 30,
Cities: []struct {
Name string `json:"name"`
State string `json:"state"`
}{
{Name: "New York", State: "NY"},
{Name: "Los Angeles", State: "CA"},
},
},
want: `{
"type":"object",
"properties":{
"name":{
"type":"string"
},
"age":{
"type":"integer"
},
"active":{
"type":"boolean"
},
"height":{
"type":"number"
},
"cities":{
"type":"array",
"items":{
"additionalProperties":false,
"type":"object",
"properties":{
"name":{
"type":"string"
},
"state":{
"type":"string"
}
},
"required":["name","state"]
}
}
},
"required":["name","age","active","height","cities"],
"additionalProperties":false
}`,
},
{
name: "Test with description tag",
in: struct {
Name string `json:"name" description:"The name of the person"`
}{
Name: "John Doe",
},
want: `{
"type":"object",
"properties":{
"name":{
"type":"string",
"description":"The name of the person"
}
},
"required":["name"],
"additionalProperties":false
}`,
},
{
name: "Test with required tag",
in: struct {
Name string `json:"name" required:"false"`
}{
Name: "John Doe",
},
want: `{
"type":"object",
"properties":{
"name":{
"type":"string"
}
},
"additionalProperties":false
}`,
},
{
name: "Test with enum tag",
in: struct {
Color string `json:"color" enum:"red,green,blue"`
}{
Color: "red",
},
want: `{
"type":"object",
"properties":{
"color":{
"type":"string",
"enum":["red","green","blue"]
}
},
"required":["color"],
"additionalProperties":false
}`,
},
{
name: "Test with nullable tag",
in: struct {
Name *string `json:"name" nullable:"true"`
}{
Name: nil,
},
want: `{
"type":"object",
"properties":{
"name":{
"type":"string",
"nullable":true
}
},
"required":["name"],
"additionalProperties":false
}`,
},
{
name: "Test with exclude mark",
in: struct {
Name string `json:"-"`
}{
Name: "Name",
},
want: `{
"type":"object",
"additionalProperties":false
}`,
},
{
name: "Test with no json tag",
in: struct {
Name string
}{
Name: "",
},
want: `{
"type":"object",
"properties":{
"Name":{
"type":"string"
}
},
"required":["Name"],
"additionalProperties":false
}`,
},
{
name: "Test with omitempty tag",
in: struct {
Name string `json:"name,omitempty"`
}{
Name: "",
},
want: `{
"type":"object",
"properties":{
"name":{
"type":"string"
}
},
"additionalProperties":false
}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wantBytes := []byte(tt.want)
schema, err := jsonschema.GenerateSchemaForType(tt.in)
if err != nil {
t.Errorf("Failed to generate schema: error = %v", err)
return
}
var want map[string]interface{}
err = json.Unmarshal(wantBytes, &want)
if err != nil {
t.Errorf("Failed to Unmarshal JSON: error = %v", err)
return
}
got := structToMap(t, schema)
gotPtr := structToMap(t, &schema)
if !reflect.DeepEqual(got, want) {
t.Errorf("MarshalJSON() got = %v, want %v", got, want)
}
if !reflect.DeepEqual(gotPtr, want) {
t.Errorf("MarshalJSON() gotPtr = %v, want %v", gotPtr, want)
}
})
}
}
func structToMap(t *testing.T, v any) map[string]any {
t.Helper()
gotBytes, err := json.Marshal(v)

View File

@@ -47,24 +47,6 @@ func TestGetModel(t *testing.T) {
checks.NoError(t, err, "GetModel error")
}
// TestGetModelO3 Tests the retrieve O3 model endpoint of the API using the mocked server.
func TestGetModelO3(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/models/o3", handleGetModelEndpoint)
_, err := client.GetModel(context.Background(), "o3")
checks.NoError(t, err, "GetModel error for O3")
}
// TestGetModelO4Mini Tests the retrieve O4Mini model endpoint of the API using the mocked server.
func TestGetModelO4Mini(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/models/o4-mini", handleGetModelEndpoint)
_, err := client.GetModel(context.Background(), "o4-mini")
checks.NoError(t, err, "GetModel error for O4Mini")
}
func TestAzureGetModel(t *testing.T) {
client, server, teardown := setupAzureTestServer()
defer teardown()

View File

@@ -29,7 +29,7 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea
// numTokens Returns the number of GPT-3 encoded tokens in the given text.
// This function approximates based on the rule of thumb stated by OpenAI:
// https://beta.openai.com/tokenizer.
// https://beta.openai.com/tokenizer/
//
// TODO: implement an actual tokenizer for GPT-3 and Codex (once available).
func numTokens(s string) int {

View File

@@ -28,6 +28,15 @@ var (
ErrReasoningModelLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
)
var unsupportedToolsForO1Models = map[ToolType]struct{}{
ToolTypeFunction: {},
}
var availableMessageRoleForO1Models = map[string]struct{}{
ChatMessageRoleUser: {},
ChatMessageRoleAssistant: {},
}
// ReasoningValidator handles validation for o-series model requests.
type ReasoningValidator struct{}
@@ -40,9 +49,8 @@ func NewReasoningValidator() *ReasoningValidator {
func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error {
o1Series := strings.HasPrefix(request.Model, "o1")
o3Series := strings.HasPrefix(request.Model, "o3")
o4Series := strings.HasPrefix(request.Model, "o4")
if !o1Series && !o3Series && !o4Series {
if !o1Series && !o3Series {
return nil
}
@@ -50,6 +58,12 @@ func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error {
return err
}
if o1Series {
if err := v.validateO1Specific(request); err != nil {
return err
}
}
return nil
}
@@ -79,3 +93,19 @@ func (v *ReasoningValidator) validateReasoningModelParams(request ChatCompletion
return nil
}
// validateO1Specific checks O1-specific limitations.
func (v *ReasoningValidator) validateO1Specific(request ChatCompletionRequest) error {
for _, m := range request.Messages {
if _, found := availableMessageRoleForO1Models[m.Role]; !found {
return ErrO1BetaLimitationsMessageTypes
}
}
for _, t := range request.Tools {
if _, found := unsupportedToolsForO1Models[t.Type]; found {
return ErrO1BetaLimitationsTools
}
}
return nil
}
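The validator above can also be exercised directly. A sketch assuming the exported NewReasoningValidator constructor shown in the hunk; setting MaxTokens on an o-series model trips ErrReasoningModelMaxTokensDeprecated, which is the behaviour the stream tests elsewhere in this comparison expect:

```go
package main

import (
	"errors"
	"fmt"

	openai "git.vaala.cloud/VaalaCat/go-openai"
)

func main() {
	req := openai.ChatCompletionRequest{
		Model:     openai.O1Mini,
		MaxTokens: 100, // reasoning models reject MaxTokens in favour of MaxCompletionTokens
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
	}

	if err := openai.NewReasoningValidator().Validate(req); err != nil {
		if errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
			fmt.Println("use MaxCompletionTokens instead of MaxTokens for o-series models")
			return
		}
		fmt.Printf("validation error: %v\n", err)
	}
}
```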

View File

@@ -11,22 +11,17 @@ const (
TTSModel1 SpeechModel = "tts-1"
TTSModel1HD SpeechModel = "tts-1-hd"
TTSModelCanary SpeechModel = "canary-tts"
TTSModelGPT4oMini SpeechModel = "gpt-4o-mini-tts"
)
type SpeechVoice string
const (
VoiceAlloy SpeechVoice = "alloy"
VoiceAsh SpeechVoice = "ash"
VoiceBallad SpeechVoice = "ballad"
VoiceCoral SpeechVoice = "coral"
VoiceEcho SpeechVoice = "echo"
VoiceFable SpeechVoice = "fable"
VoiceOnyx SpeechVoice = "onyx"
VoiceNova SpeechVoice = "nova"
VoiceShimmer SpeechVoice = "shimmer"
VoiceVerse SpeechVoice = "verse"
)
type SpeechResponseFormat string
@@ -44,7 +39,6 @@ type CreateSpeechRequest struct {
Model SpeechModel `json:"model"`
Input string `json:"input"`
Voice SpeechVoice `json:"voice"`
Instructions string `json:"instructions,omitempty"` // Optional, Doesnt work with tts-1 or tts-1-hd.
ResponseFormat SpeechResponseFormat `json:"response_format,omitempty"` // Optional, default to mp3
Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0
}
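For reference, the CreateSpeechRequest hunk above defines the wire shape of a speech request. A minimal sketch that only marshals the struct (no API call), using the TTSModel1 and VoiceAlloy constants present on both sides of the diff:

```go
package main

import (
	"encoding/json"
	"fmt"

	openai "git.vaala.cloud/VaalaCat/go-openai"
)

func main() {
	req := openai.CreateSpeechRequest{
		Model: openai.TTSModel1,
		Input: "Hello from the speech endpoint.",
		Voice: openai.VoiceAlloy,
		Speed: 1.0, // optional; 1.0 is the documented default
	}

	body, _ := json.Marshal(req)
	fmt.Println(string(body)) // prints the JSON request body
}
```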

View File

@@ -6,14 +6,13 @@ import (
"fmt"
"io"
"net/http"
"regexp"
utils "git.vaala.cloud/VaalaCat/go-openai/internal"
)
var (
headerData = regexp.MustCompile(`^data:\s*`)
errorPrefix = regexp.MustCompile(`^data:\s*{"error":`)
headerData = []byte("data: ")
errorPrefix = []byte(`data: {"error":`)
)
type streamable interface {
@@ -71,12 +70,12 @@ func (stream *streamReader[T]) processLines() ([]byte, error) {
}
noSpaceLine := bytes.TrimSpace(rawLine)
if errorPrefix.Match(noSpaceLine) {
if bytes.HasPrefix(noSpaceLine, errorPrefix) {
hasErrorPrefix = true
}
if !headerData.Match(noSpaceLine) || hasErrorPrefix {
if !bytes.HasPrefix(noSpaceLine, headerData) || hasErrorPrefix {
if hasErrorPrefix {
noSpaceLine = headerData.ReplaceAll(noSpaceLine, nil)
noSpaceLine = bytes.TrimPrefix(noSpaceLine, headerData)
}
writeErr := stream.errAccumulator.Write(noSpaceLine)
if writeErr != nil {
@@ -90,7 +89,7 @@ func (stream *streamReader[T]) processLines() ([]byte, error) {
continue
}
noPrefixLine := headerData.ReplaceAll(noSpaceLine, nil)
noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData)
if string(noPrefixLine) == "[DONE]" {
stream.isFinished = true
return nil, io.EOF
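The hunk above swaps regexp matching of SSE lines for plain byte-prefix checks (or the reverse, depending on which side you read). One behavioural difference worth noting: `^data:\s*` also matches a payload with no space after the colon, while the []byte("data: ") prefix requires the space. A standalone, stdlib-only sketch illustrating that:

```go
package main

import (
	"bytes"
	"fmt"
	"regexp"
)

func main() {
	headerDataRe := regexp.MustCompile(`^data:\s*`)
	headerDataBytes := []byte("data: ")

	lines := [][]byte{
		[]byte(`data: {"id":"1"}`), // typical SSE payload line: both variants match
		[]byte(`data:{"id":"2"}`),  // no space after the colon: only the regexp matches
	}

	for _, line := range lines {
		fmt.Printf("%-20s regexp=%v bytes=%v\n",
			line, headerDataRe.Match(line), bytes.HasPrefix(line, headerDataBytes))
	}
}
```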