Compare commits
27 Commits
1337a4b683
...
2436e7afb8
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2436e7afb8 | ||
|
|
67f3b169df | ||
|
|
3640274cd1 | ||
|
|
ff9d83a485 | ||
|
|
8c65b35c57 | ||
|
|
4d2e7ab29d | ||
|
|
6aaa732296 | ||
|
|
0116f2994d | ||
|
|
8ba38f6ba1 | ||
|
|
6181facea7 | ||
|
|
77ccac8d34 | ||
|
|
5ea214a188 | ||
|
|
d65f0cb54e | ||
|
|
93a611cf4f | ||
|
|
6836cf6a6f | ||
|
|
da5f9bc9bc | ||
|
|
bb5bc27567 | ||
|
|
4cccc6c934 | ||
|
|
306fbbbe6f | ||
|
|
658beda2ba | ||
|
|
d68a683815 | ||
|
|
e99eb54c9d | ||
|
|
74d6449f22 | ||
|
|
261721bfdb | ||
|
|
be2e2387d4 | ||
|
|
85f578b865 | ||
|
|
c0a9a75fe0 |
12
.github/workflows/pr.yml
vendored
12
.github/workflows/pr.yml
vendored
@@ -13,15 +13,17 @@ jobs:
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.21'
|
||||
go-version: '1.24'
|
||||
- name: Run vet
|
||||
run: |
|
||||
go vet .
|
||||
- name: Run golangci-lint
|
||||
uses: golangci/golangci-lint-action@v4
|
||||
uses: golangci/golangci-lint-action@v7
|
||||
with:
|
||||
version: latest
|
||||
version: v2.1.5
|
||||
- name: Run tests
|
||||
run: go test -race -covermode=atomic -coverprofile=coverage.out -v .
|
||||
run: go test -race -covermode=atomic -coverprofile=coverage.out -v ./...
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
349
.golangci.yml
349
.golangci.yml
@@ -1,66 +1,94 @@
|
||||
## Golden config for golangci-lint v1.47.3
|
||||
#
|
||||
# This is the best config for golangci-lint based on my experience and opinion.
|
||||
# It is very strict, but not extremely strict.
|
||||
# Feel free to adopt and change it for your needs.
|
||||
|
||||
run:
|
||||
# Timeout for analysis, e.g. 30s, 5m.
|
||||
# Default: 1m
|
||||
timeout: 3m
|
||||
|
||||
|
||||
# This file contains only configs which differ from defaults.
|
||||
# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
|
||||
linters-settings:
|
||||
version: "2"
|
||||
linters:
|
||||
default: none
|
||||
enable:
|
||||
- asciicheck
|
||||
- bidichk
|
||||
- bodyclose
|
||||
- contextcheck
|
||||
- cyclop
|
||||
- dupl
|
||||
- durationcheck
|
||||
- errcheck
|
||||
- errname
|
||||
- errorlint
|
||||
- exhaustive
|
||||
- forbidigo
|
||||
- funlen
|
||||
- gochecknoinits
|
||||
- gocognit
|
||||
- goconst
|
||||
- gocritic
|
||||
- gocyclo
|
||||
- godot
|
||||
- gomoddirectives
|
||||
- gomodguard
|
||||
- goprintffuncname
|
||||
- gosec
|
||||
- govet
|
||||
- ineffassign
|
||||
- lll
|
||||
- makezero
|
||||
- mnd
|
||||
- nestif
|
||||
- nilerr
|
||||
- nilnil
|
||||
- nolintlint
|
||||
- nosprintfhostport
|
||||
- predeclared
|
||||
- promlinter
|
||||
- revive
|
||||
- rowserrcheck
|
||||
- sqlclosecheck
|
||||
- staticcheck
|
||||
- testpackage
|
||||
- tparallel
|
||||
- unconvert
|
||||
- unparam
|
||||
- unused
|
||||
- usetesting
|
||||
- wastedassign
|
||||
- whitespace
|
||||
settings:
|
||||
cyclop:
|
||||
# The maximal code complexity to report.
|
||||
# Default: 10
|
||||
max-complexity: 30
|
||||
# The maximal average package complexity.
|
||||
# If it's higher than 0.0 (float) the check is enabled
|
||||
# Default: 0.0
|
||||
package-average: 10.0
|
||||
|
||||
package-average: 10
|
||||
errcheck:
|
||||
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
|
||||
# Such cases aren't reported by default.
|
||||
# Default: false
|
||||
check-type-assertions: true
|
||||
|
||||
funlen:
|
||||
# Checks the number of lines in a function.
|
||||
# If lower than 0, disable the check.
|
||||
# Default: 60
|
||||
lines: 100
|
||||
# Checks the number of statements in a function.
|
||||
# If lower than 0, disable the check.
|
||||
# Default: 40
|
||||
statements: 50
|
||||
|
||||
gocognit:
|
||||
# Minimal code complexity to report
|
||||
# Default: 30 (but we recommend 10-20)
|
||||
min-complexity: 20
|
||||
|
||||
gocritic:
|
||||
# Settings passed to gocritic.
|
||||
# The settings key is the name of a supported gocritic checker.
|
||||
# The list of supported checkers can be find in https://go-critic.github.io/overview.
|
||||
settings:
|
||||
captLocal:
|
||||
# Whether to restrict checker to params only.
|
||||
# Default: true
|
||||
paramsOnly: false
|
||||
underef:
|
||||
# Whether to skip (*x).method() calls where x is a pointer receiver.
|
||||
# Default: true
|
||||
skipRecvDeref: false
|
||||
|
||||
gomodguard:
|
||||
blocked:
|
||||
modules:
|
||||
- github.com/golang/protobuf:
|
||||
recommendations:
|
||||
- google.golang.org/protobuf
|
||||
reason: see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules
|
||||
- github.com/satori/go.uuid:
|
||||
recommendations:
|
||||
- github.com/google/uuid
|
||||
reason: satori's package is not maintained
|
||||
- github.com/gofrs/uuid:
|
||||
recommendations:
|
||||
- github.com/google/uuid
|
||||
reason: 'see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw'
|
||||
govet:
|
||||
disable:
|
||||
- fieldalignment
|
||||
enable-all: true
|
||||
settings:
|
||||
shadow:
|
||||
strict: true
|
||||
mnd:
|
||||
# List of function patterns to exclude from analysis.
|
||||
# Values always ignored: `time.Date`
|
||||
# Default: []
|
||||
ignored-functions:
|
||||
- os.Chmod
|
||||
- os.Mkdir
|
||||
@@ -76,194 +104,44 @@ linters-settings:
|
||||
- strconv.ParseFloat
|
||||
- strconv.ParseInt
|
||||
- strconv.ParseUint
|
||||
|
||||
gomodguard:
|
||||
blocked:
|
||||
# List of blocked modules.
|
||||
# Default: []
|
||||
modules:
|
||||
- github.com/golang/protobuf:
|
||||
recommendations:
|
||||
- google.golang.org/protobuf
|
||||
reason: "see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules"
|
||||
- github.com/satori/go.uuid:
|
||||
recommendations:
|
||||
- github.com/google/uuid
|
||||
reason: "satori's package is not maintained"
|
||||
- github.com/gofrs/uuid:
|
||||
recommendations:
|
||||
- github.com/google/uuid
|
||||
reason: "see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw"
|
||||
|
||||
govet:
|
||||
# Enable all analyzers.
|
||||
# Default: false
|
||||
enable-all: true
|
||||
# Disable analyzers by name.
|
||||
# Run `go tool vet help` to see all analyzers.
|
||||
# Default: []
|
||||
disable:
|
||||
- fieldalignment # too strict
|
||||
# Settings per analyzer.
|
||||
settings:
|
||||
shadow:
|
||||
# Whether to be strict about shadowing; can be noisy.
|
||||
# Default: false
|
||||
strict: true
|
||||
|
||||
nakedret:
|
||||
# Make an issue if func has more lines of code than this setting, and it has naked returns.
|
||||
# Default: 30
|
||||
max-func-lines: 0
|
||||
|
||||
nolintlint:
|
||||
# Exclude following linters from requiring an explanation.
|
||||
# Default: []
|
||||
allow-no-explanation: [ funlen, gocognit, lll ]
|
||||
# Enable to require an explanation of nonzero length after each nolint directive.
|
||||
# Default: false
|
||||
require-explanation: true
|
||||
# Enable to require nolint directives to mention the specific linter being suppressed.
|
||||
# Default: false
|
||||
require-specific: true
|
||||
|
||||
allow-no-explanation:
|
||||
- funlen
|
||||
- gocognit
|
||||
- lll
|
||||
rowserrcheck:
|
||||
# database/sql is always checked
|
||||
# Default: []
|
||||
packages:
|
||||
- github.com/jmoiron/sqlx
|
||||
|
||||
tenv:
|
||||
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
|
||||
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
|
||||
# Default: false
|
||||
all: true
|
||||
|
||||
varcheck:
|
||||
# Check usage of exported fields and variables.
|
||||
# Default: false
|
||||
exported-fields: false # default false # TODO: enable after fixing false positives
|
||||
|
||||
|
||||
linters:
|
||||
disable-all: true
|
||||
enable:
|
||||
## enabled by default
|
||||
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
|
||||
- gosimple # Linter for Go source code that specializes in simplifying a code
|
||||
- govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
|
||||
- ineffassign # Detects when assignments to existing variables are not used
|
||||
- staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks
|
||||
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
|
||||
- unused # Checks Go code for unused constants, variables, functions and types
|
||||
## disabled by default
|
||||
# - asasalint # Check for pass []any as any in variadic func(...any)
|
||||
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
|
||||
- bidichk # Checks for dangerous unicode character sequences
|
||||
- bodyclose # checks whether HTTP response body is closed successfully
|
||||
- contextcheck # check the function whether use a non-inherited context
|
||||
- cyclop # checks function and package cyclomatic complexity
|
||||
- dupl # Tool for code clone detection
|
||||
- durationcheck # check for two durations multiplied together
|
||||
- errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error.
|
||||
- errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13.
|
||||
# Removed execinquery (deprecated). execinquery is a linter about query string checker in Query function which reads your Go src files and warning it finds
|
||||
- exhaustive # check exhaustiveness of enum switch statements
|
||||
- exportloopref # checks for pointers to enclosing loop variables
|
||||
- forbidigo # Forbids identifiers
|
||||
- funlen # Tool for detection of long functions
|
||||
# - gochecknoglobals # check that no global variables exist
|
||||
- gochecknoinits # Checks that no init functions are present in Go code
|
||||
- gocognit # Computes and checks the cognitive complexity of functions
|
||||
- goconst # Finds repeated strings that could be replaced by a constant
|
||||
- gocritic # Provides diagnostics that check for bugs, performance and style issues.
|
||||
- gocyclo # Computes and checks the cyclomatic complexity of functions
|
||||
- godot # Check if comments end in a period
|
||||
- goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt.
|
||||
- gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod.
|
||||
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
|
||||
- goprintffuncname # Checks that printf-like functions are named with f at the end
|
||||
- gosec # Inspects source code for security problems
|
||||
- lll # Reports long lines
|
||||
- makezero # Finds slice declarations with non-zero initial length
|
||||
# - nakedret # Finds naked returns in functions greater than a specified function length
|
||||
- mnd # An analyzer to detect magic numbers.
|
||||
- nestif # Reports deeply nested if statements
|
||||
- nilerr # Finds the code that returns nil even if it checks that the error is not nil.
|
||||
- nilnil # Checks that there is no simultaneous return of nil error and an invalid value.
|
||||
# - noctx # noctx finds sending http request without context.Context
|
||||
- nolintlint # Reports ill-formed or insufficient nolint directives
|
||||
# - nonamedreturns # Reports all named returns
|
||||
- nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL.
|
||||
- predeclared # find code that shadows one of Go's predeclared identifiers
|
||||
- promlinter # Check Prometheus metrics naming via promlint
|
||||
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
|
||||
- rowserrcheck # checks whether Err of rows is checked successfully
|
||||
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
|
||||
- stylecheck # Stylecheck is a replacement for golint
|
||||
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17
|
||||
- testpackage # linter that makes you use a separate _test package
|
||||
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
|
||||
- unconvert # Remove unnecessary type conversions
|
||||
- unparam # Reports unused function parameters
|
||||
- usetesting # Reports uses of functions with replacement inside the testing package
|
||||
- wastedassign # wastedassign finds wasted assignment statements.
|
||||
- whitespace # Tool for detection of leading and trailing whitespace
|
||||
## you may want to enable
|
||||
#- decorder # check declaration order and count of types, constants, variables and functions
|
||||
#- exhaustruct # Checks if all structure fields are initialized
|
||||
#- goheader # Checks is file header matches to pattern
|
||||
#- ireturn # Accept Interfaces, Return Concrete Types
|
||||
#- prealloc # [premature optimization, but can be used in some cases] Finds slice declarations that could potentially be preallocated
|
||||
#- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope
|
||||
#- wrapcheck # Checks that errors returned from external packages are wrapped
|
||||
## disabled
|
||||
#- containedctx # containedctx is a linter that detects struct contained context.Context field
|
||||
#- depguard # [replaced by gomodguard] Go linter that checks if package imports are in a list of acceptable packages
|
||||
#- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
|
||||
#- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted.
|
||||
#- forcetypeassert # [replaced by errcheck] finds forced type assertions
|
||||
#- gci # Gci controls golang package import order and makes it always deterministic.
|
||||
#- godox # Tool for detection of FIXME, TODO and other comment keywords
|
||||
#- goerr113 # [too strict] Golang linter to check the errors handling expressions
|
||||
#- gofmt # [replaced by goimports] Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
|
||||
#- gofumpt # [replaced by goimports, gofumports is not available yet] Gofumpt checks whether code was gofumpt-ed.
|
||||
#- grouper # An analyzer to analyze expression groups.
|
||||
#- ifshort # Checks that your code uses short syntax for if-statements whenever possible
|
||||
#- importas # Enforces consistent import aliases
|
||||
#- maintidx # maintidx measures the maintainability index of each function.
|
||||
#- misspell # [useless] Finds commonly misspelled English words in comments
|
||||
#- nlreturn # [too strict and mostly code is not more readable] nlreturn checks for a new line before return and branch statements to increase code clarity
|
||||
#- nosnakecase # Detects snake case of variable naming and function name. # TODO: maybe enable after https://github.com/sivchari/nosnakecase/issues/14
|
||||
#- paralleltest # [too many false positives] paralleltest detects missing usage of t.Parallel() method in your Go test
|
||||
#- tagliatelle # Checks the struct tags.
|
||||
#- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
|
||||
#- wsl # [too strict and mostly code is not more readable] Whitespace Linter - Forces you to use empty lines!
|
||||
## deprecated
|
||||
#- exhaustivestruct # [deprecated, replaced by exhaustruct] Checks if all struct's fields are initialized
|
||||
#- golint # [deprecated, replaced by revive] Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes
|
||||
#- interfacer # [deprecated] Linter that suggests narrower interface types
|
||||
#- maligned # [deprecated, replaced by govet fieldalignment] Tool to detect Go structs that would take less memory if their fields were sorted
|
||||
#- scopelint # [deprecated, replaced by exportloopref] Scopelint checks for unpinned variables in go programs
|
||||
|
||||
|
||||
issues:
|
||||
# Maximum count of issues with the same text.
|
||||
# Set to 0 to disable.
|
||||
# Default: 3
|
||||
max-same-issues: 50
|
||||
|
||||
exclude-rules:
|
||||
- source: "^//\\s*go:generate\\s"
|
||||
linters: [ lll ]
|
||||
- source: "(noinspection|TODO)"
|
||||
linters: [ godot ]
|
||||
- source: "//noinspection"
|
||||
linters: [ gocritic ]
|
||||
- source: "^\\s+if _, ok := err\\.\\([^.]+\\.InternalError\\); ok {"
|
||||
linters: [ errorlint ]
|
||||
- path: "_test\\.go"
|
||||
linters:
|
||||
exclusions:
|
||||
generated: lax
|
||||
presets:
|
||||
- comments
|
||||
- common-false-positives
|
||||
- legacy
|
||||
- std-error-handling
|
||||
rules:
|
||||
- linters:
|
||||
- forbidigo
|
||||
- mnd
|
||||
- revive
|
||||
path : ^examples/.*\.go$
|
||||
- linters:
|
||||
- lll
|
||||
source: ^//\s*go:generate\s
|
||||
- linters:
|
||||
- godot
|
||||
source: (noinspection|TODO)
|
||||
- linters:
|
||||
- gocritic
|
||||
source: //noinspection
|
||||
- linters:
|
||||
- errorlint
|
||||
source: ^\s+if _, ok := err\.\([^.]+\.InternalError\); ok {
|
||||
- linters:
|
||||
- bodyclose
|
||||
- dupl
|
||||
- funlen
|
||||
@@ -271,3 +149,20 @@ issues:
|
||||
- gosec
|
||||
- noctx
|
||||
- wrapcheck
|
||||
- staticcheck
|
||||
path: _test\.go
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
issues:
|
||||
max-same-issues: 50
|
||||
formatters:
|
||||
enable:
|
||||
- goimports
|
||||
exclusions:
|
||||
generated: lax
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
|
||||
62
README.md
62
README.md
@@ -7,7 +7,7 @@ This library provides unofficial Go clients for [OpenAI API](https://platform.op
|
||||
|
||||
* ChatGPT 4o, o1
|
||||
* GPT-3, GPT-4
|
||||
* DALL·E 2, DALL·E 3
|
||||
* DALL·E 2, DALL·E 3, GPT Image 1
|
||||
* Whisper
|
||||
|
||||
## Installation
|
||||
@@ -357,6 +357,66 @@ func main() {
|
||||
```
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>GPT Image 1 image generation</summary>
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
func main() {
|
||||
c := openai.NewClient("your token")
|
||||
ctx := context.Background()
|
||||
|
||||
req := openai.ImageRequest{
|
||||
Prompt: "Parrot on a skateboard performing a trick. Large bold text \"SKATE MASTER\" banner at the bottom of the image. Cartoon style, natural light, high detail, 1:1 aspect ratio.",
|
||||
Background: openai.CreateImageBackgroundOpaque,
|
||||
Model: openai.CreateImageModelGptImage1,
|
||||
Size: openai.CreateImageSize1024x1024,
|
||||
N: 1,
|
||||
Quality: openai.CreateImageQualityLow,
|
||||
OutputCompression: 100,
|
||||
OutputFormat: openai.CreateImageOutputFormatJPEG,
|
||||
// Moderation: openai.CreateImageModerationLow,
|
||||
// User: "",
|
||||
}
|
||||
|
||||
resp, err := c.CreateImage(ctx, req)
|
||||
if err != nil {
|
||||
fmt.Printf("Image creation Image generation with GPT Image 1error: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("Image Base64:", resp.Data[0].B64JSON)
|
||||
|
||||
// Decode the base64 data
|
||||
imgBytes, err := base64.StdEncoding.DecodeString(resp.Data[0].B64JSON)
|
||||
if err != nil {
|
||||
fmt.Printf("Base64 decode error: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Write image to file
|
||||
outputPath := "generated_image.jpg"
|
||||
err = os.WriteFile(outputPath, imgBytes, 0644)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to write image file: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("The image was saved as %s\n", outputPath)
|
||||
}
|
||||
```
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Configuring proxy</summary>
|
||||
|
||||
|
||||
131
audio_test.go
131
audio_test.go
@@ -2,8 +2,11 @@ package openai //nolint:testpackage // testing private field
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -107,3 +110,131 @@ func TestCreateFileField(t *testing.T) {
|
||||
checks.HasError(t, err, "createFileField using file should return error when open file fails")
|
||||
})
|
||||
}
|
||||
|
||||
// failingFormBuilder always returns an error when creating form files.
|
||||
type failingFormBuilder struct{ err error }
|
||||
|
||||
func (f *failingFormBuilder) CreateFormFile(_ string, _ *os.File) error {
|
||||
return f.err
|
||||
}
|
||||
|
||||
func (f *failingFormBuilder) CreateFormFileReader(_ string, _ io.Reader, _ string) error {
|
||||
return f.err
|
||||
}
|
||||
|
||||
func (f *failingFormBuilder) WriteField(_, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *failingFormBuilder) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *failingFormBuilder) FormDataContentType() string {
|
||||
return "multipart/form-data"
|
||||
}
|
||||
|
||||
// failingAudioRequestBuilder simulates an error during HTTP request construction.
|
||||
type failingAudioRequestBuilder struct{ err error }
|
||||
|
||||
func (f *failingAudioRequestBuilder) Build(
|
||||
_ context.Context,
|
||||
_, _ string,
|
||||
_ any,
|
||||
_ http.Header,
|
||||
) (*http.Request, error) {
|
||||
return nil, f.err
|
||||
}
|
||||
|
||||
// errorHTTPClient always returns an error when making HTTP calls.
|
||||
type errorHTTPClient struct{ err error }
|
||||
|
||||
func (e *errorHTTPClient) Do(_ *http.Request) (*http.Response, error) {
|
||||
return nil, e.err
|
||||
}
|
||||
|
||||
func TestCallAudioAPIMultipartFormError(t *testing.T) {
|
||||
client := NewClient("test-token")
|
||||
errForm := errors.New("mock create form file failure")
|
||||
// Override form builder to force an error during multipart form creation.
|
||||
client.createFormBuilder = func(_ io.Writer) utils.FormBuilder {
|
||||
return &failingFormBuilder{err: errForm}
|
||||
}
|
||||
|
||||
// Provide a reader so createFileField uses the reader path (no file open).
|
||||
req := AudioRequest{FilePath: "fake.mp3", Reader: bytes.NewBuffer([]byte("dummy")), Model: Whisper1}
|
||||
_, err := client.callAudioAPI(context.Background(), req, "transcriptions")
|
||||
if err == nil {
|
||||
t.Fatal("expected error but got none")
|
||||
}
|
||||
if !errors.Is(err, errForm) {
|
||||
t.Errorf("expected error %v, got %v", errForm, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallAudioAPINewRequestError(t *testing.T) {
|
||||
client := NewClient("test-token")
|
||||
// Create a real temp file so multipart form succeeds.
|
||||
tmp := t.TempDir()
|
||||
path := filepath.Join(tmp, "file.mp3")
|
||||
if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
errBuild := errors.New("mock build failure")
|
||||
client.requestBuilder = &failingAudioRequestBuilder{err: errBuild}
|
||||
|
||||
req := AudioRequest{FilePath: path, Model: Whisper1}
|
||||
_, err := client.callAudioAPI(context.Background(), req, "translations")
|
||||
if err == nil {
|
||||
t.Fatal("expected error but got none")
|
||||
}
|
||||
if !errors.Is(err, errBuild) {
|
||||
t.Errorf("expected error %v, got %v", errBuild, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallAudioAPISendRequestErrorJSON(t *testing.T) {
|
||||
client := NewClient("test-token")
|
||||
// Create a real temp file so multipart form succeeds.
|
||||
tmp := t.TempDir()
|
||||
path := filepath.Join(tmp, "file.mp3")
|
||||
if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
errHTTP := errors.New("mock HTTPClient failure")
|
||||
// Override HTTP client to simulate a network error.
|
||||
client.config.HTTPClient = &errorHTTPClient{err: errHTTP}
|
||||
|
||||
req := AudioRequest{FilePath: path, Model: Whisper1}
|
||||
_, err := client.callAudioAPI(context.Background(), req, "transcriptions")
|
||||
if err == nil {
|
||||
t.Fatal("expected error but got none")
|
||||
}
|
||||
if !errors.Is(err, errHTTP) {
|
||||
t.Errorf("expected error %v, got %v", errHTTP, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallAudioAPISendRequestErrorText(t *testing.T) {
|
||||
client := NewClient("test-token")
|
||||
tmp := t.TempDir()
|
||||
path := filepath.Join(tmp, "file.mp3")
|
||||
if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
errHTTP := errors.New("mock HTTPClient failure")
|
||||
client.config.HTTPClient = &errorHTTPClient{err: errHTTP}
|
||||
|
||||
// Use a non-JSON response format to exercise the text path.
|
||||
req := AudioRequest{FilePath: path, Model: Whisper1, Format: AudioResponseFormatText}
|
||||
_, err := client.callAudioAPI(context.Background(), req, "translations")
|
||||
if err == nil {
|
||||
t.Fatal("expected error but got none")
|
||||
}
|
||||
if !errors.Is(err, errHTTP) {
|
||||
t.Errorf("expected error %v, got %v", errHTTP, err)
|
||||
}
|
||||
}
|
||||
|
||||
25
chat.go
25
chat.go
@@ -14,6 +14,7 @@ const (
|
||||
ChatMessageRoleAssistant = "assistant"
|
||||
ChatMessageRoleFunction = "function"
|
||||
ChatMessageRoleTool = "tool"
|
||||
ChatMessageRoleDeveloper = "developer"
|
||||
)
|
||||
|
||||
const chatCompletionsSuffix = "/chat/completions"
|
||||
@@ -103,6 +104,12 @@ type ChatCompletionMessage struct {
|
||||
// - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
Name string `json:"name,omitempty"`
|
||||
|
||||
// This property is used for the "reasoning" feature supported by deepseek-reasoner
|
||||
// which is not in the official documentation.
|
||||
// the doc from deepseek:
|
||||
// - https://api-docs.deepseek.com/api/create-chat-completion#responses
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
|
||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||
|
||||
// For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls.
|
||||
@@ -123,6 +130,7 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
MultiContent []ChatMessagePart `json:"content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
@@ -136,6 +144,7 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
MultiContent []ChatMessagePart `json:"-"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
@@ -146,10 +155,11 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
|
||||
func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
|
||||
msg := struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Content string `json:"content"`
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
MultiContent []ChatMessagePart
|
||||
Name string `json:"name,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
@@ -165,6 +175,7 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
MultiContent []ChatMessagePart `json:"content"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
@@ -264,6 +275,13 @@ type ChatCompletionRequest struct {
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
IncludeReasoning *bool `json:"include_reasoning,omitempty"`
|
||||
ReasoningFormat *string `json:"reasoning_format,omitempty"`
|
||||
// Configuration for a predicted output.
|
||||
Prediction *Prediction `json:"prediction,omitempty"`
|
||||
// ChatTemplateKwargs provides a way to add non-standard parameters to the request body.
|
||||
// Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
|
||||
// Such as think mode for qwen3. "chat_template_kwargs": {"enable_thinking": false}
|
||||
// https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes
|
||||
ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"`
|
||||
}
|
||||
|
||||
type StreamOptions struct {
|
||||
@@ -331,6 +349,11 @@ type LogProbs struct {
|
||||
Content []LogProb `json:"content"`
|
||||
}
|
||||
|
||||
type Prediction struct {
|
||||
Content string `json:"content"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type FinishReason string
|
||||
|
||||
const (
|
||||
|
||||
@@ -11,6 +11,12 @@ type ChatCompletionStreamChoiceDelta struct {
|
||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
|
||||
// This property is used for the "reasoning" feature supported by deepseek-reasoner
|
||||
// which is not in the official documentation.
|
||||
// the doc from deepseek:
|
||||
// - https://api-docs.deepseek.com/api/create-chat-completion#responses
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
}
|
||||
|
||||
type ChatCompletionStreamChoiceLogprobs struct {
|
||||
|
||||
@@ -959,6 +959,56 @@ func TestCreateChatCompletionStreamReasoningValidatorFails(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateChatCompletionStreamO3ReasoningValidatorFails(t *testing.T) {
|
||||
client, _, _ := setupOpenAITestServer()
|
||||
|
||||
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
|
||||
MaxTokens: 100, // This will trigger the validator to fail
|
||||
Model: openai.O3,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello!",
|
||||
},
|
||||
},
|
||||
Stream: true,
|
||||
})
|
||||
|
||||
if stream != nil {
|
||||
t.Error("Expected nil stream when validation fails")
|
||||
stream.Close()
|
||||
}
|
||||
|
||||
if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
|
||||
t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O3, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateChatCompletionStreamO4MiniReasoningValidatorFails(t *testing.T) {
|
||||
client, _, _ := setupOpenAITestServer()
|
||||
|
||||
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
|
||||
MaxTokens: 100, // This will trigger the validator to fail
|
||||
Model: openai.O4Mini,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello!",
|
||||
},
|
||||
},
|
||||
Stream: true,
|
||||
})
|
||||
|
||||
if stream != nil {
|
||||
t.Error("Expected nil stream when validation fails")
|
||||
stream.Close()
|
||||
}
|
||||
|
||||
if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
|
||||
t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O4Mini, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool {
|
||||
if c1.Index != c2.Index {
|
||||
return false
|
||||
|
||||
113
chat_test.go
113
chat_test.go
@@ -106,40 +106,6 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
|
||||
},
|
||||
expectedError: openai.ErrReasoningModelLimitationsLogprobs,
|
||||
},
|
||||
{
|
||||
name: "message_type_unsupported",
|
||||
in: openai.ChatCompletionRequest{
|
||||
MaxCompletionTokens: 1000,
|
||||
Model: openai.O1Mini,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleSystem,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: openai.ErrO1BetaLimitationsMessageTypes,
|
||||
},
|
||||
{
|
||||
name: "tool_unsupported",
|
||||
in: openai.ChatCompletionRequest{
|
||||
MaxCompletionTokens: 1000,
|
||||
Model: openai.O1Mini,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
},
|
||||
{
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
},
|
||||
},
|
||||
Tools: []openai.Tool{
|
||||
{
|
||||
Type: openai.ToolTypeFunction,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: openai.ErrO1BetaLimitationsTools,
|
||||
},
|
||||
{
|
||||
name: "set_temperature_unsupported",
|
||||
in: openai.ChatCompletionRequest{
|
||||
@@ -445,6 +411,23 @@ func TestO3ModelChatCompletions(t *testing.T) {
|
||||
checks.NoError(t, err, "CreateChatCompletion error")
|
||||
}
|
||||
|
||||
func TestDeepseekR1ModelChatCompletions(t *testing.T) {
|
||||
client, server, teardown := setupOpenAITestServer()
|
||||
defer teardown()
|
||||
server.RegisterHandler("/v1/chat/completions", handleDeepseekR1ChatCompletionEndpoint)
|
||||
_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
|
||||
Model: "deepseek-reasoner",
|
||||
MaxCompletionTokens: 100,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello!",
|
||||
},
|
||||
},
|
||||
})
|
||||
checks.NoError(t, err, "CreateChatCompletion error")
|
||||
}
|
||||
|
||||
// TestCompletions Tests the completions endpoint of the API using the mocked server.
|
||||
func TestChatCompletionsWithHeaders(t *testing.T) {
|
||||
client, server, teardown := setupOpenAITestServer()
|
||||
@@ -856,6 +839,68 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, string(resBytes))
|
||||
}
|
||||
|
||||
func handleDeepseekR1ChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
var err error
|
||||
var resBytes []byte
|
||||
|
||||
// completions only accepts POST requests
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
var completionReq openai.ChatCompletionRequest
|
||||
if completionReq, err = getChatCompletionBody(r); err != nil {
|
||||
http.Error(w, "could not read request", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
res := openai.ChatCompletionResponse{
|
||||
ID: strconv.Itoa(int(time.Now().Unix())),
|
||||
Object: "test-object",
|
||||
Created: time.Now().Unix(),
|
||||
// would be nice to validate Model during testing, but
|
||||
// this may not be possible with how much upkeep
|
||||
// would be required / wouldn't make much sense
|
||||
Model: completionReq.Model,
|
||||
}
|
||||
// create completions
|
||||
n := completionReq.N
|
||||
if n == 0 {
|
||||
n = 1
|
||||
}
|
||||
if completionReq.MaxCompletionTokens == 0 {
|
||||
completionReq.MaxCompletionTokens = 1000
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
reasoningContent := "User says hello! And I need to reply"
|
||||
completionStr := strings.Repeat("a", completionReq.MaxCompletionTokens-numTokens(reasoningContent))
|
||||
res.Choices = append(res.Choices, openai.ChatCompletionChoice{
|
||||
Message: openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
ReasoningContent: reasoningContent,
|
||||
Content: completionStr,
|
||||
},
|
||||
Index: i,
|
||||
})
|
||||
}
|
||||
inputTokens := numTokens(completionReq.Messages[0].Content) * n
|
||||
completionTokens := completionReq.MaxTokens * n
|
||||
res.Usage = openai.Usage{
|
||||
PromptTokens: inputTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: inputTokens + completionTokens,
|
||||
}
|
||||
resBytes, _ = json.Marshal(res)
|
||||
w.Header().Set(xCustomHeader, xCustomHeaderValue)
|
||||
for k, v := range rateLimitHeaders {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
w.Header().Set(k, strconv.Itoa(val))
|
||||
default:
|
||||
w.Header().Set(k, fmt.Sprintf("%s", v))
|
||||
}
|
||||
}
|
||||
fmt.Fprintln(w, string(resBytes))
|
||||
}
|
||||
|
||||
// getChatCompletionBody Returns the body of the request to create a completion.
|
||||
func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) {
|
||||
completion := openai.ChatCompletionRequest{}
|
||||
|
||||
14
client.go
14
client.go
@@ -182,13 +182,21 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream
|
||||
|
||||
func (c *Client) setCommonHeaders(req *http.Request) {
|
||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
|
||||
switch c.config.APIType {
|
||||
case APITypeAzure, APITypeCloudflareAzure:
|
||||
// Azure API Key authentication
|
||||
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure {
|
||||
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
|
||||
} else if c.config.authToken != "" {
|
||||
// OpenAI or Azure AD authentication
|
||||
case APITypeAnthropic:
|
||||
// https://docs.anthropic.com/en/api/versioning
|
||||
req.Header.Set("anthropic-version", c.config.APIVersion)
|
||||
case APITypeOpenAI, APITypeAzureAD:
|
||||
fallthrough
|
||||
default:
|
||||
if c.config.authToken != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
|
||||
}
|
||||
}
|
||||
|
||||
if c.config.OrgID != "" {
|
||||
req.Header.Set("OpenAI-Organization", c.config.OrgID)
|
||||
}
|
||||
|
||||
@@ -39,6 +39,21 @@ func TestClient(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCommonHeadersAnthropic(t *testing.T) {
|
||||
config := DefaultAnthropicConfig("mock-token", "")
|
||||
client := NewClientWithConfig(config)
|
||||
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
|
||||
client.setCommonHeaders(req)
|
||||
|
||||
if got := req.Header.Get("anthropic-version"); got != AnthropicAPIVersion {
|
||||
t.Errorf("Expected anthropic-version header to be %q, got %q", AnthropicAPIVersion, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeResponse(t *testing.T) {
|
||||
stringInput := ""
|
||||
|
||||
|
||||
@@ -15,6 +15,8 @@ type Usage struct {
|
||||
type CompletionTokensDetails struct {
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
ReasoningTokens int `json:"reasoning_tokens"`
|
||||
AcceptedPredictionTokens int `json:"accepted_prediction_tokens"`
|
||||
RejectedPredictionTokens int `json:"rejected_prediction_tokens"`
|
||||
}
|
||||
|
||||
// PromptTokensDetails Breakdown of tokens used in the prompt.
|
||||
|
||||
@@ -16,8 +16,12 @@ const (
|
||||
O1Preview20240912 = "o1-preview-2024-09-12"
|
||||
O1 = "o1"
|
||||
O120241217 = "o1-2024-12-17"
|
||||
O3 = "o3"
|
||||
O320250416 = "o3-2025-04-16"
|
||||
O3Mini = "o3-mini"
|
||||
O3Mini20250131 = "o3-mini-2025-01-31"
|
||||
O4Mini = "o4-mini"
|
||||
O4Mini20250416 = "o4-mini-2025-04-16"
|
||||
GPT432K0613 = "gpt-4-32k-0613"
|
||||
GPT432K0314 = "gpt-4-32k-0314"
|
||||
GPT432K = "gpt-4-32k"
|
||||
@@ -37,6 +41,14 @@ const (
|
||||
GPT4TurboPreview = "gpt-4-turbo-preview"
|
||||
GPT4VisionPreview = "gpt-4-vision-preview"
|
||||
GPT4 = "gpt-4"
|
||||
GPT4Dot1 = "gpt-4.1"
|
||||
GPT4Dot120250414 = "gpt-4.1-2025-04-14"
|
||||
GPT4Dot1Mini = "gpt-4.1-mini"
|
||||
GPT4Dot1Mini20250414 = "gpt-4.1-mini-2025-04-14"
|
||||
GPT4Dot1Nano = "gpt-4.1-nano"
|
||||
GPT4Dot1Nano20250414 = "gpt-4.1-nano-2025-04-14"
|
||||
GPT4Dot5Preview = "gpt-4.5-preview"
|
||||
GPT4Dot5Preview20250227 = "gpt-4.5-preview-2025-02-27"
|
||||
GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125"
|
||||
GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106"
|
||||
GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613"
|
||||
@@ -91,6 +103,10 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
|
||||
O1Preview20240912: true,
|
||||
O3Mini: true,
|
||||
O3Mini20250131: true,
|
||||
O4Mini: true,
|
||||
O4Mini20250416: true,
|
||||
O3: true,
|
||||
O320250416: true,
|
||||
GPT3Dot5Turbo: true,
|
||||
GPT3Dot5Turbo0301: true,
|
||||
GPT3Dot5Turbo0613: true,
|
||||
@@ -99,6 +115,8 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
|
||||
GPT3Dot5Turbo16K: true,
|
||||
GPT3Dot5Turbo16K0613: true,
|
||||
GPT4: true,
|
||||
GPT4Dot5Preview: true,
|
||||
GPT4Dot5Preview20250227: true,
|
||||
GPT4o: true,
|
||||
GPT4o20240513: true,
|
||||
GPT4o20240806: true,
|
||||
@@ -117,6 +135,13 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
|
||||
GPT432K: true,
|
||||
GPT432K0314: true,
|
||||
GPT432K0613: true,
|
||||
O1: true,
|
||||
GPT4Dot1: true,
|
||||
GPT4Dot120250414: true,
|
||||
GPT4Dot1Mini: true,
|
||||
GPT4Dot1Mini20250414: true,
|
||||
GPT4Dot1Nano: true,
|
||||
GPT4Dot1Nano20250414: true,
|
||||
},
|
||||
chatCompletionsSuffix: {
|
||||
CodexCodeDavinci002: true,
|
||||
@@ -190,6 +215,8 @@ type CompletionRequest struct {
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
TopP float32 `json:"top_p,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
// Options for streaming response. Only set this when you set stream: true.
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
}
|
||||
|
||||
// CompletionChoice represents one of possible completions.
|
||||
|
||||
@@ -33,6 +33,42 @@ func TestCompletionsWrongModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompletionsWrongModelO3 Tests the completions endpoint with O3 model which is not supported.
|
||||
func TestCompletionsWrongModelO3(t *testing.T) {
|
||||
config := openai.DefaultConfig("whatever")
|
||||
config.BaseURL = "http://localhost/v1"
|
||||
client := openai.NewClientWithConfig(config)
|
||||
|
||||
_, err := client.CreateCompletion(
|
||||
context.Background(),
|
||||
openai.CompletionRequest{
|
||||
MaxTokens: 5,
|
||||
Model: openai.O3,
|
||||
},
|
||||
)
|
||||
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
|
||||
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O3, but returned: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompletionsWrongModelO4Mini Tests the completions endpoint with O4Mini model which is not supported.
|
||||
func TestCompletionsWrongModelO4Mini(t *testing.T) {
|
||||
config := openai.DefaultConfig("whatever")
|
||||
config.BaseURL = "http://localhost/v1"
|
||||
client := openai.NewClientWithConfig(config)
|
||||
|
||||
_, err := client.CreateCompletion(
|
||||
context.Background(),
|
||||
openai.CompletionRequest{
|
||||
MaxTokens: 5,
|
||||
Model: openai.O4Mini,
|
||||
},
|
||||
)
|
||||
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
|
||||
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O4Mini, but returned: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletionWithStream(t *testing.T) {
|
||||
config := openai.DefaultConfig("whatever")
|
||||
client := openai.NewClientWithConfig(config)
|
||||
@@ -181,3 +217,86 @@ func getCompletionBody(r *http.Request) (openai.CompletionRequest, error) {
|
||||
}
|
||||
return completion, nil
|
||||
}
|
||||
|
||||
// TestCompletionWithO1Model Tests that O1 model is not supported for completion endpoint.
|
||||
func TestCompletionWithO1Model(t *testing.T) {
|
||||
config := openai.DefaultConfig("whatever")
|
||||
config.BaseURL = "http://localhost/v1"
|
||||
client := openai.NewClientWithConfig(config)
|
||||
|
||||
_, err := client.CreateCompletion(
|
||||
context.Background(),
|
||||
openai.CompletionRequest{
|
||||
MaxTokens: 5,
|
||||
Model: openai.O1,
|
||||
},
|
||||
)
|
||||
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
|
||||
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O1 model, but returned: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompletionWithGPT4DotModels Tests that newer GPT4 models are not supported for completion endpoint.
|
||||
func TestCompletionWithGPT4DotModels(t *testing.T) {
|
||||
config := openai.DefaultConfig("whatever")
|
||||
config.BaseURL = "http://localhost/v1"
|
||||
client := openai.NewClientWithConfig(config)
|
||||
|
||||
models := []string{
|
||||
openai.GPT4Dot1,
|
||||
openai.GPT4Dot120250414,
|
||||
openai.GPT4Dot1Mini,
|
||||
openai.GPT4Dot1Mini20250414,
|
||||
openai.GPT4Dot1Nano,
|
||||
openai.GPT4Dot1Nano20250414,
|
||||
openai.GPT4Dot5Preview,
|
||||
openai.GPT4Dot5Preview20250227,
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
_, err := client.CreateCompletion(
|
||||
context.Background(),
|
||||
openai.CompletionRequest{
|
||||
MaxTokens: 5,
|
||||
Model: model,
|
||||
},
|
||||
)
|
||||
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
|
||||
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompletionWithGPT4oModels Tests that GPT4o models are not supported for completion endpoint.
|
||||
func TestCompletionWithGPT4oModels(t *testing.T) {
|
||||
config := openai.DefaultConfig("whatever")
|
||||
config.BaseURL = "http://localhost/v1"
|
||||
client := openai.NewClientWithConfig(config)
|
||||
|
||||
models := []string{
|
||||
openai.GPT4o,
|
||||
openai.GPT4o20240513,
|
||||
openai.GPT4o20240806,
|
||||
openai.GPT4o20241120,
|
||||
openai.GPT4oLatest,
|
||||
openai.GPT4oMini,
|
||||
openai.GPT4oMini20240718,
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
_, err := client.CreateCompletion(
|
||||
context.Background(),
|
||||
openai.CompletionRequest{
|
||||
MaxTokens: 5,
|
||||
Model: model,
|
||||
},
|
||||
)
|
||||
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
|
||||
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
22
config.go
22
config.go
@@ -11,6 +11,8 @@ const (
|
||||
|
||||
azureAPIPrefix = "openai"
|
||||
azureDeploymentsPrefix = "deployments"
|
||||
|
||||
AnthropicAPIVersion = "2023-06-01"
|
||||
)
|
||||
|
||||
type APIType string
|
||||
@@ -20,6 +22,7 @@ const (
|
||||
APITypeAzure APIType = "AZURE"
|
||||
APITypeAzureAD APIType = "AZURE_AD"
|
||||
APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE"
|
||||
APITypeAnthropic APIType = "ANTHROPIC"
|
||||
)
|
||||
|
||||
const AzureAPIKeyHeader = "api-key"
|
||||
@@ -37,7 +40,7 @@ type ClientConfig struct {
|
||||
BaseURL string
|
||||
OrgID string
|
||||
APIType APIType
|
||||
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD
|
||||
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD or APITypeAnthropic
|
||||
AssistantVersion string
|
||||
AzureModelMapperFunc func(model string) string // replace model to azure deployment name func
|
||||
HTTPClient HTTPDoer
|
||||
@@ -76,6 +79,23 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig {
|
||||
}
|
||||
}
|
||||
|
||||
func DefaultAnthropicConfig(apiKey, baseURL string) ClientConfig {
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.anthropic.com/v1"
|
||||
}
|
||||
return ClientConfig{
|
||||
authToken: apiKey,
|
||||
BaseURL: baseURL,
|
||||
OrgID: "",
|
||||
APIType: APITypeAnthropic,
|
||||
APIVersion: AnthropicAPIVersion,
|
||||
|
||||
HTTPClient: &http.Client{},
|
||||
|
||||
EmptyMessagesLimit: defaultEmptyMessagesLimit,
|
||||
}
|
||||
}
|
||||
|
||||
func (ClientConfig) String() string {
|
||||
return "<OpenAI API ClientConfig>"
|
||||
}
|
||||
|
||||
@@ -60,3 +60,64 @@ func TestGetAzureDeploymentByModel(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAnthropicConfig(t *testing.T) {
|
||||
apiKey := "test-key"
|
||||
baseURL := "https://api.anthropic.com/v1"
|
||||
|
||||
config := openai.DefaultAnthropicConfig(apiKey, baseURL)
|
||||
|
||||
if config.APIType != openai.APITypeAnthropic {
|
||||
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
|
||||
}
|
||||
|
||||
if config.APIVersion != openai.AnthropicAPIVersion {
|
||||
t.Errorf("Expected APIVersion to be 2023-06-01, got %v", config.APIVersion)
|
||||
}
|
||||
|
||||
if config.BaseURL != baseURL {
|
||||
t.Errorf("Expected BaseURL to be %v, got %v", baseURL, config.BaseURL)
|
||||
}
|
||||
|
||||
if config.EmptyMessagesLimit != 300 {
|
||||
t.Errorf("Expected EmptyMessagesLimit to be 300, got %v", config.EmptyMessagesLimit)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAnthropicConfigWithEmptyValues(t *testing.T) {
|
||||
config := openai.DefaultAnthropicConfig("", "")
|
||||
|
||||
if config.APIType != openai.APITypeAnthropic {
|
||||
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
|
||||
}
|
||||
|
||||
if config.APIVersion != openai.AnthropicAPIVersion {
|
||||
t.Errorf("Expected APIVersion to be %s, got %v", openai.AnthropicAPIVersion, config.APIVersion)
|
||||
}
|
||||
|
||||
expectedBaseURL := "https://api.anthropic.com/v1"
|
||||
if config.BaseURL != expectedBaseURL {
|
||||
t.Errorf("Expected BaseURL to be %v, got %v", expectedBaseURL, config.BaseURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientConfigString(t *testing.T) {
|
||||
// String() should always return the constant value
|
||||
conf := openai.DefaultConfig("dummy-token")
|
||||
expected := "<OpenAI API ClientConfig>"
|
||||
got := conf.String()
|
||||
if got != expected {
|
||||
t.Errorf("ClientConfig.String() = %q; want %q", got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAzureDeploymentByModel_NoMapper(t *testing.T) {
|
||||
// On a zero-value or DefaultConfig, AzureModelMapperFunc is nil,
|
||||
// so GetAzureDeploymentByModel should just return the input model.
|
||||
conf := openai.DefaultConfig("dummy-token")
|
||||
model := "some-model"
|
||||
got := conf.GetAzureDeploymentByModel(model)
|
||||
if got != model {
|
||||
t.Errorf("GetAzureDeploymentByModel(%q) = %q; want %q", model, got, model)
|
||||
}
|
||||
}
|
||||
|
||||
74
image.go
74
image.go
@@ -3,8 +3,8 @@ package openai
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
@@ -13,31 +13,62 @@ const (
|
||||
CreateImageSize256x256 = "256x256"
|
||||
CreateImageSize512x512 = "512x512"
|
||||
CreateImageSize1024x1024 = "1024x1024"
|
||||
|
||||
// dall-e-3 supported only.
|
||||
CreateImageSize1792x1024 = "1792x1024"
|
||||
CreateImageSize1024x1792 = "1024x1792"
|
||||
|
||||
// gpt-image-1 supported only.
|
||||
CreateImageSize1536x1024 = "1536x1024" // Landscape
|
||||
CreateImageSize1024x1536 = "1024x1536" // Portrait
|
||||
)
|
||||
|
||||
const (
|
||||
CreateImageResponseFormatURL = "url"
|
||||
// dall-e-2 and dall-e-3 only.
|
||||
CreateImageResponseFormatB64JSON = "b64_json"
|
||||
CreateImageResponseFormatURL = "url"
|
||||
)
|
||||
|
||||
const (
|
||||
CreateImageModelDallE2 = "dall-e-2"
|
||||
CreateImageModelDallE3 = "dall-e-3"
|
||||
CreateImageModelGptImage1 = "gpt-image-1"
|
||||
)
|
||||
|
||||
const (
|
||||
CreateImageQualityHD = "hd"
|
||||
CreateImageQualityStandard = "standard"
|
||||
|
||||
// gpt-image-1 only.
|
||||
CreateImageQualityHigh = "high"
|
||||
CreateImageQualityMedium = "medium"
|
||||
CreateImageQualityLow = "low"
|
||||
)
|
||||
|
||||
const (
|
||||
// dall-e-3 only.
|
||||
CreateImageStyleVivid = "vivid"
|
||||
CreateImageStyleNatural = "natural"
|
||||
)
|
||||
|
||||
const (
|
||||
// gpt-image-1 only.
|
||||
CreateImageBackgroundTransparent = "transparent"
|
||||
CreateImageBackgroundOpaque = "opaque"
|
||||
)
|
||||
|
||||
const (
|
||||
// gpt-image-1 only.
|
||||
CreateImageModerationLow = "low"
|
||||
)
|
||||
|
||||
const (
|
||||
// gpt-image-1 only.
|
||||
CreateImageOutputFormatPNG = "png"
|
||||
CreateImageOutputFormatJPEG = "jpeg"
|
||||
CreateImageOutputFormatWEBP = "webp"
|
||||
)
|
||||
|
||||
// ImageRequest represents the request structure for the image API.
|
||||
type ImageRequest struct {
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
@@ -48,16 +79,35 @@ type ImageRequest struct {
|
||||
Style string `json:"style,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Background string `json:"background,omitempty"`
|
||||
Moderation string `json:"moderation,omitempty"`
|
||||
OutputCompression int `json:"output_compression,omitempty"`
|
||||
OutputFormat string `json:"output_format,omitempty"`
|
||||
}
|
||||
|
||||
// ImageResponse represents a response structure for image API.
|
||||
type ImageResponse struct {
|
||||
Created int64 `json:"created,omitempty"`
|
||||
Data []ImageResponseDataInner `json:"data,omitempty"`
|
||||
Usage ImageResponseUsage `json:"usage,omitempty"`
|
||||
|
||||
httpHeader
|
||||
}
|
||||
|
||||
// ImageResponseInputTokensDetails represents the token breakdown for input tokens.
|
||||
type ImageResponseInputTokensDetails struct {
|
||||
TextTokens int `json:"text_tokens,omitempty"`
|
||||
ImageTokens int `json:"image_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// ImageResponseUsage represents the token usage information for image API.
|
||||
type ImageResponseUsage struct {
|
||||
TotalTokens int `json:"total_tokens,omitempty"`
|
||||
InputTokens int `json:"input_tokens,omitempty"`
|
||||
OutputTokens int `json:"output_tokens,omitempty"`
|
||||
InputTokensDetails ImageResponseInputTokensDetails `json:"input_tokens_details,omitempty"`
|
||||
}
|
||||
|
||||
// ImageResponseDataInner represents a response data structure for image API.
|
||||
type ImageResponseDataInner struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
@@ -84,13 +134,15 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons
|
||||
|
||||
// ImageEditRequest represents the request structure for the image API.
|
||||
type ImageEditRequest struct {
|
||||
Image *os.File `json:"image,omitempty"`
|
||||
Mask *os.File `json:"mask,omitempty"`
|
||||
Image io.Reader `json:"image,omitempty"`
|
||||
Mask io.Reader `json:"mask,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
// CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API.
|
||||
@@ -98,15 +150,16 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
|
||||
body := &bytes.Buffer{}
|
||||
builder := c.createFormBuilder(body)
|
||||
|
||||
// image
|
||||
err = builder.CreateFormFile("image", request.Image)
|
||||
// image, filename is not required
|
||||
err = builder.CreateFormFileReader("image", request.Image, "")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// mask, it is optional
|
||||
if request.Mask != nil {
|
||||
err = builder.CreateFormFile("mask", request.Mask)
|
||||
// mask, filename is not required
|
||||
err = builder.CreateFormFileReader("mask", request.Mask, "")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -154,11 +207,12 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
|
||||
|
||||
// ImageVariRequest represents the request structure for the image API.
|
||||
type ImageVariRequest struct {
|
||||
Image *os.File `json:"image,omitempty"`
|
||||
Image io.Reader `json:"image,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
// CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API.
|
||||
@@ -167,8 +221,8 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
|
||||
body := &bytes.Buffer{}
|
||||
builder := c.createFormBuilder(body)
|
||||
|
||||
// image
|
||||
err = builder.CreateFormFile("image", request.Image)
|
||||
// image, filename is not required
|
||||
err = builder.CreateFormFileReader("image", request.Image, "")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -54,13 +54,13 @@ func TestImageFormBuilderFailures(t *testing.T) {
|
||||
}
|
||||
|
||||
mockFailedErr := fmt.Errorf("mock form builder fail")
|
||||
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
|
||||
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
|
||||
return mockFailedErr
|
||||
}
|
||||
_, err := client.CreateEditImage(ctx, req)
|
||||
checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
|
||||
|
||||
mockBuilder.mockCreateFormFile = func(name string, _ *os.File) error {
|
||||
mockBuilder.mockCreateFormFileReader = func(name string, _ io.Reader, _ string) error {
|
||||
if name == "mask" {
|
||||
return mockFailedErr
|
||||
}
|
||||
@@ -119,13 +119,13 @@ func TestVariImageFormBuilderFailures(t *testing.T) {
|
||||
req := ImageVariRequest{}
|
||||
|
||||
mockFailedErr := fmt.Errorf("mock form builder fail")
|
||||
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
|
||||
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
|
||||
return mockFailedErr
|
||||
}
|
||||
_, err := client.CreateVariImage(ctx, req)
|
||||
checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails")
|
||||
|
||||
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
|
||||
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -4,8 +4,10 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/textproto"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type FormBuilder interface {
|
||||
@@ -30,8 +32,37 @@ func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) er
|
||||
return fb.createFormFile(fieldname, file, file.Name())
|
||||
}
|
||||
|
||||
var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
|
||||
|
||||
func escapeQuotes(s string) string {
|
||||
return quoteEscaper.Replace(s)
|
||||
}
|
||||
|
||||
// CreateFormFileReader creates a form field with a file reader.
|
||||
// The filename in parameters can be an empty string.
|
||||
// The filename in Content-Disposition is required, But it can be an empty string.
|
||||
func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
|
||||
return fb.createFormFile(fieldname, r, path.Base(filename))
|
||||
h := make(textproto.MIMEHeader)
|
||||
h.Set(
|
||||
"Content-Disposition",
|
||||
fmt.Sprintf(
|
||||
`form-data; name="%s"; filename="%s"`,
|
||||
escapeQuotes(fieldname),
|
||||
escapeQuotes(filepath.Base(filename)),
|
||||
),
|
||||
)
|
||||
|
||||
fieldWriter, err := fb.writer.CreatePart(h)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = io.Copy(fieldWriter, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error {
|
||||
|
||||
@@ -43,3 +43,32 @@ func TestFormBuilderWithClosedFile(t *testing.T) {
|
||||
checks.HasError(t, err, "formbuilder should return error if file is closed")
|
||||
checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed")
|
||||
}
|
||||
|
||||
type failingReader struct {
|
||||
}
|
||||
|
||||
var errMockFailingReaderError = errors.New("mock reader failed")
|
||||
|
||||
func (*failingReader) Read([]byte) (int, error) {
|
||||
return 0, errMockFailingReaderError
|
||||
}
|
||||
|
||||
func TestFormBuilderWithReader(t *testing.T) {
|
||||
file, err := os.CreateTemp(t.TempDir(), "")
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating tmp file: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
builder := NewFormBuilder(&failingWriter{})
|
||||
err = builder.CreateFormFileReader("file", file, file.Name())
|
||||
checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails")
|
||||
|
||||
builder = NewFormBuilder(&bytes.Buffer{})
|
||||
reader := &failingReader{}
|
||||
err = builder.CreateFormFileReader("file", reader, "")
|
||||
checks.ErrorIs(t, err, errMockFailingReaderError, "formbuilder should return error if copy reader fails")
|
||||
|
||||
successReader := &bytes.Buffer{}
|
||||
err = builder.CreateFormFileReader("file", successReader, "")
|
||||
checks.NoError(t, err, "formbuilder should not return error")
|
||||
}
|
||||
|
||||
@@ -46,6 +46,8 @@ type Definition struct {
|
||||
// additionalProperties: false
|
||||
// additionalProperties: jsonschema.Definition{Type: jsonschema.String}
|
||||
AdditionalProperties any `json:"additionalProperties,omitempty"`
|
||||
// Whether the schema is nullable or not.
|
||||
Nullable bool `json:"nullable,omitempty"`
|
||||
}
|
||||
|
||||
func (d *Definition) MarshalJSON() ([]byte, error) {
|
||||
@@ -124,9 +126,12 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) {
|
||||
}
|
||||
jsonTag := field.Tag.Get("json")
|
||||
var required = true
|
||||
if jsonTag == "" {
|
||||
switch {
|
||||
case jsonTag == "-":
|
||||
continue
|
||||
case jsonTag == "":
|
||||
jsonTag = field.Name
|
||||
} else if strings.HasSuffix(jsonTag, ",omitempty") {
|
||||
case strings.HasSuffix(jsonTag, ",omitempty"):
|
||||
jsonTag = strings.TrimSuffix(jsonTag, ",omitempty")
|
||||
required = false
|
||||
}
|
||||
@@ -139,6 +144,16 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) {
|
||||
if description != "" {
|
||||
item.Description = description
|
||||
}
|
||||
enum := field.Tag.Get("enum")
|
||||
if enum != "" {
|
||||
item.Enum = strings.Split(enum, ",")
|
||||
}
|
||||
|
||||
if n := field.Tag.Get("nullable"); n != "" {
|
||||
nullable, _ := strconv.ParseBool(n)
|
||||
item.Nullable = nullable
|
||||
}
|
||||
|
||||
properties[jsonTag] = *item
|
||||
|
||||
if s := field.Tag.Get("required"); s != "" {
|
||||
|
||||
@@ -17,7 +17,7 @@ func TestDefinition_MarshalJSON(t *testing.T) {
|
||||
{
|
||||
name: "Test with empty Definition",
|
||||
def: jsonschema.Definition{},
|
||||
want: `{"properties":{}}`,
|
||||
want: `{}`,
|
||||
},
|
||||
{
|
||||
name: "Test with Definition properties set",
|
||||
@@ -35,11 +35,10 @@ func TestDefinition_MarshalJSON(t *testing.T) {
|
||||
"description":"A string type",
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string",
|
||||
"properties":{}
|
||||
"type":"string"
|
||||
}
|
||||
}
|
||||
}`,
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Test with nested Definition properties",
|
||||
@@ -66,17 +65,15 @@ func TestDefinition_MarshalJSON(t *testing.T) {
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string",
|
||||
"properties":{}
|
||||
"type":"string"
|
||||
},
|
||||
"age":{
|
||||
"type":"integer",
|
||||
"properties":{}
|
||||
"type":"integer"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Test with complex nested Definition",
|
||||
@@ -114,30 +111,26 @@ func TestDefinition_MarshalJSON(t *testing.T) {
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string",
|
||||
"properties":{}
|
||||
"type":"string"
|
||||
},
|
||||
"age":{
|
||||
"type":"integer",
|
||||
"properties":{}
|
||||
"type":"integer"
|
||||
},
|
||||
"address":{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"city":{
|
||||
"type":"string",
|
||||
"properties":{}
|
||||
"type":"string"
|
||||
},
|
||||
"country":{
|
||||
"type":"string",
|
||||
"properties":{}
|
||||
"type":"string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Test with Array type Definition",
|
||||
@@ -155,18 +148,14 @@ func TestDefinition_MarshalJSON(t *testing.T) {
|
||||
want: `{
|
||||
"type":"array",
|
||||
"items":{
|
||||
"type":"string",
|
||||
"properties":{
|
||||
|
||||
}
|
||||
"type":"string"
|
||||
},
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string",
|
||||
"properties":{}
|
||||
"type":"string"
|
||||
}
|
||||
}
|
||||
}`,
|
||||
}`,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -193,6 +182,232 @@ func TestDefinition_MarshalJSON(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStructToSchema(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in any
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "Test with empty struct",
|
||||
in: struct{}{},
|
||||
want: `{
|
||||
"type":"object",
|
||||
"additionalProperties":false
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Test with struct containing many fields",
|
||||
in: struct {
|
||||
Name string `json:"name"`
|
||||
Age int `json:"age"`
|
||||
Active bool `json:"active"`
|
||||
Height float64 `json:"height"`
|
||||
Cities []struct {
|
||||
Name string `json:"name"`
|
||||
State string `json:"state"`
|
||||
} `json:"cities"`
|
||||
}{
|
||||
Name: "John Doe",
|
||||
Age: 30,
|
||||
Cities: []struct {
|
||||
Name string `json:"name"`
|
||||
State string `json:"state"`
|
||||
}{
|
||||
{Name: "New York", State: "NY"},
|
||||
{Name: "Los Angeles", State: "CA"},
|
||||
},
|
||||
},
|
||||
want: `{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string"
|
||||
},
|
||||
"age":{
|
||||
"type":"integer"
|
||||
},
|
||||
"active":{
|
||||
"type":"boolean"
|
||||
},
|
||||
"height":{
|
||||
"type":"number"
|
||||
},
|
||||
"cities":{
|
||||
"type":"array",
|
||||
"items":{
|
||||
"additionalProperties":false,
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string"
|
||||
},
|
||||
"state":{
|
||||
"type":"string"
|
||||
}
|
||||
},
|
||||
"required":["name","state"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"required":["name","age","active","height","cities"],
|
||||
"additionalProperties":false
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Test with description tag",
|
||||
in: struct {
|
||||
Name string `json:"name" description:"The name of the person"`
|
||||
}{
|
||||
Name: "John Doe",
|
||||
},
|
||||
want: `{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string",
|
||||
"description":"The name of the person"
|
||||
}
|
||||
},
|
||||
"required":["name"],
|
||||
"additionalProperties":false
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Test with required tag",
|
||||
in: struct {
|
||||
Name string `json:"name" required:"false"`
|
||||
}{
|
||||
Name: "John Doe",
|
||||
},
|
||||
want: `{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string"
|
||||
}
|
||||
},
|
||||
"additionalProperties":false
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Test with enum tag",
|
||||
in: struct {
|
||||
Color string `json:"color" enum:"red,green,blue"`
|
||||
}{
|
||||
Color: "red",
|
||||
},
|
||||
want: `{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"color":{
|
||||
"type":"string",
|
||||
"enum":["red","green","blue"]
|
||||
}
|
||||
},
|
||||
"required":["color"],
|
||||
"additionalProperties":false
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Test with nullable tag",
|
||||
in: struct {
|
||||
Name *string `json:"name" nullable:"true"`
|
||||
}{
|
||||
Name: nil,
|
||||
},
|
||||
want: `{
|
||||
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string",
|
||||
"nullable":true
|
||||
}
|
||||
},
|
||||
"required":["name"],
|
||||
"additionalProperties":false
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Test with exclude mark",
|
||||
in: struct {
|
||||
Name string `json:"-"`
|
||||
}{
|
||||
Name: "Name",
|
||||
},
|
||||
want: `{
|
||||
"type":"object",
|
||||
"additionalProperties":false
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Test with no json tag",
|
||||
in: struct {
|
||||
Name string
|
||||
}{
|
||||
Name: "",
|
||||
},
|
||||
want: `{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"Name":{
|
||||
"type":"string"
|
||||
}
|
||||
},
|
||||
"required":["Name"],
|
||||
"additionalProperties":false
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Test with omitempty tag",
|
||||
in: struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
}{
|
||||
Name: "",
|
||||
},
|
||||
want: `{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string"
|
||||
}
|
||||
},
|
||||
"additionalProperties":false
|
||||
}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
wantBytes := []byte(tt.want)
|
||||
|
||||
schema, err := jsonschema.GenerateSchemaForType(tt.in)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to generate schema: error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var want map[string]interface{}
|
||||
err = json.Unmarshal(wantBytes, &want)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to Unmarshal JSON: error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
got := structToMap(t, schema)
|
||||
gotPtr := structToMap(t, &schema)
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("MarshalJSON() got = %v, want %v", got, want)
|
||||
}
|
||||
if !reflect.DeepEqual(gotPtr, want) {
|
||||
t.Errorf("MarshalJSON() gotPtr = %v, want %v", gotPtr, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func structToMap(t *testing.T, v any) map[string]any {
|
||||
t.Helper()
|
||||
gotBytes, err := json.Marshal(v)
|
||||
|
||||
@@ -47,6 +47,24 @@ func TestGetModel(t *testing.T) {
|
||||
checks.NoError(t, err, "GetModel error")
|
||||
}
|
||||
|
||||
// TestGetModelO3 Tests the retrieve O3 model endpoint of the API using the mocked server.
|
||||
func TestGetModelO3(t *testing.T) {
|
||||
client, server, teardown := setupOpenAITestServer()
|
||||
defer teardown()
|
||||
server.RegisterHandler("/v1/models/o3", handleGetModelEndpoint)
|
||||
_, err := client.GetModel(context.Background(), "o3")
|
||||
checks.NoError(t, err, "GetModel error for O3")
|
||||
}
|
||||
|
||||
// TestGetModelO4Mini Tests the retrieve O4Mini model endpoint of the API using the mocked server.
|
||||
func TestGetModelO4Mini(t *testing.T) {
|
||||
client, server, teardown := setupOpenAITestServer()
|
||||
defer teardown()
|
||||
server.RegisterHandler("/v1/models/o4-mini", handleGetModelEndpoint)
|
||||
_, err := client.GetModel(context.Background(), "o4-mini")
|
||||
checks.NoError(t, err, "GetModel error for O4Mini")
|
||||
}
|
||||
|
||||
func TestAzureGetModel(t *testing.T) {
|
||||
client, server, teardown := setupAzureTestServer()
|
||||
defer teardown()
|
||||
|
||||
@@ -29,7 +29,7 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea
|
||||
|
||||
// numTokens Returns the number of GPT-3 encoded tokens in the given text.
|
||||
// This function approximates based on the rule of thumb stated by OpenAI:
|
||||
// https://beta.openai.com/tokenizer/
|
||||
// https://beta.openai.com/tokenizer.
|
||||
//
|
||||
// TODO: implement an actual tokenizer for GPT-3 and Codex (once available).
|
||||
func numTokens(s string) int {
|
||||
|
||||
@@ -28,15 +28,6 @@ var (
|
||||
ErrReasoningModelLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
|
||||
)
|
||||
|
||||
var unsupportedToolsForO1Models = map[ToolType]struct{}{
|
||||
ToolTypeFunction: {},
|
||||
}
|
||||
|
||||
var availableMessageRoleForO1Models = map[string]struct{}{
|
||||
ChatMessageRoleUser: {},
|
||||
ChatMessageRoleAssistant: {},
|
||||
}
|
||||
|
||||
// ReasoningValidator handles validation for o-series model requests.
|
||||
type ReasoningValidator struct{}
|
||||
|
||||
@@ -49,8 +40,9 @@ func NewReasoningValidator() *ReasoningValidator {
|
||||
func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error {
|
||||
o1Series := strings.HasPrefix(request.Model, "o1")
|
||||
o3Series := strings.HasPrefix(request.Model, "o3")
|
||||
o4Series := strings.HasPrefix(request.Model, "o4")
|
||||
|
||||
if !o1Series && !o3Series {
|
||||
if !o1Series && !o3Series && !o4Series {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -58,12 +50,6 @@ func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if o1Series {
|
||||
if err := v.validateO1Specific(request); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -93,19 +79,3 @@ func (v *ReasoningValidator) validateReasoningModelParams(request ChatCompletion
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateO1Specific checks O1-specific limitations.
|
||||
func (v *ReasoningValidator) validateO1Specific(request ChatCompletionRequest) error {
|
||||
for _, m := range request.Messages {
|
||||
if _, found := availableMessageRoleForO1Models[m.Role]; !found {
|
||||
return ErrO1BetaLimitationsMessageTypes
|
||||
}
|
||||
}
|
||||
|
||||
for _, t := range request.Tools {
|
||||
if _, found := unsupportedToolsForO1Models[t.Type]; found {
|
||||
return ErrO1BetaLimitationsTools
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -11,17 +11,22 @@ const (
|
||||
TTSModel1 SpeechModel = "tts-1"
|
||||
TTSModel1HD SpeechModel = "tts-1-hd"
|
||||
TTSModelCanary SpeechModel = "canary-tts"
|
||||
TTSModelGPT4oMini SpeechModel = "gpt-4o-mini-tts"
|
||||
)
|
||||
|
||||
type SpeechVoice string
|
||||
|
||||
const (
|
||||
VoiceAlloy SpeechVoice = "alloy"
|
||||
VoiceAsh SpeechVoice = "ash"
|
||||
VoiceBallad SpeechVoice = "ballad"
|
||||
VoiceCoral SpeechVoice = "coral"
|
||||
VoiceEcho SpeechVoice = "echo"
|
||||
VoiceFable SpeechVoice = "fable"
|
||||
VoiceOnyx SpeechVoice = "onyx"
|
||||
VoiceNova SpeechVoice = "nova"
|
||||
VoiceShimmer SpeechVoice = "shimmer"
|
||||
VoiceVerse SpeechVoice = "verse"
|
||||
)
|
||||
|
||||
type SpeechResponseFormat string
|
||||
@@ -39,6 +44,7 @@ type CreateSpeechRequest struct {
|
||||
Model SpeechModel `json:"model"`
|
||||
Input string `json:"input"`
|
||||
Voice SpeechVoice `json:"voice"`
|
||||
Instructions string `json:"instructions,omitempty"` // Optional, Doesnt work with tts-1 or tts-1-hd.
|
||||
ResponseFormat SpeechResponseFormat `json:"response_format,omitempty"` // Optional, default to mp3
|
||||
Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0
|
||||
}
|
||||
|
||||
@@ -6,13 +6,14 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
|
||||
utils "git.vaala.cloud/VaalaCat/go-openai/internal"
|
||||
)
|
||||
|
||||
var (
|
||||
headerData = []byte("data: ")
|
||||
errorPrefix = []byte(`data: {"error":`)
|
||||
headerData = regexp.MustCompile(`^data:\s*`)
|
||||
errorPrefix = regexp.MustCompile(`^data:\s*{"error":`)
|
||||
)
|
||||
|
||||
type streamable interface {
|
||||
@@ -70,12 +71,12 @@ func (stream *streamReader[T]) processLines() ([]byte, error) {
|
||||
}
|
||||
|
||||
noSpaceLine := bytes.TrimSpace(rawLine)
|
||||
if bytes.HasPrefix(noSpaceLine, errorPrefix) {
|
||||
if errorPrefix.Match(noSpaceLine) {
|
||||
hasErrorPrefix = true
|
||||
}
|
||||
if !bytes.HasPrefix(noSpaceLine, headerData) || hasErrorPrefix {
|
||||
if !headerData.Match(noSpaceLine) || hasErrorPrefix {
|
||||
if hasErrorPrefix {
|
||||
noSpaceLine = bytes.TrimPrefix(noSpaceLine, headerData)
|
||||
noSpaceLine = headerData.ReplaceAll(noSpaceLine, nil)
|
||||
}
|
||||
writeErr := stream.errAccumulator.Write(noSpaceLine)
|
||||
if writeErr != nil {
|
||||
@@ -89,7 +90,7 @@ func (stream *streamReader[T]) processLines() ([]byte, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData)
|
||||
noPrefixLine := headerData.ReplaceAll(noSpaceLine, nil)
|
||||
if string(noPrefixLine) == "[DONE]" {
|
||||
stream.isFinished = true
|
||||
return nil, io.EOF
|
||||
|
||||
Reference in New Issue
Block a user