Compare commits
27 Commits
1337a4b683
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2436e7afb8 | ||
|
|
67f3b169df | ||
|
|
3640274cd1 | ||
|
|
ff9d83a485 | ||
|
|
8c65b35c57 | ||
|
|
4d2e7ab29d | ||
|
|
6aaa732296 | ||
|
|
0116f2994d | ||
|
|
8ba38f6ba1 | ||
|
|
6181facea7 | ||
|
|
77ccac8d34 | ||
|
|
5ea214a188 | ||
|
|
d65f0cb54e | ||
|
|
93a611cf4f | ||
|
|
6836cf6a6f | ||
|
|
da5f9bc9bc | ||
|
|
bb5bc27567 | ||
|
|
4cccc6c934 | ||
|
|
306fbbbe6f | ||
|
|
658beda2ba | ||
|
|
d68a683815 | ||
|
|
e99eb54c9d | ||
|
|
74d6449f22 | ||
|
|
261721bfdb | ||
|
|
be2e2387d4 | ||
|
|
85f578b865 | ||
|
|
c0a9a75fe0 |
2
.github/ISSUE_TEMPLATE/bug_report.md
vendored
2
.github/ISSUE_TEMPLATE/bug_report.md
vendored
@@ -8,7 +8,7 @@ assignees: ''
|
|||||||
---
|
---
|
||||||
|
|
||||||
Your issue may already be reported!
|
Your issue may already be reported!
|
||||||
Please search on the [issue tracker](https://github.com/sashabaranov/go-openai/issues) before creating one.
|
Please search on the [issue tracker](https://git.vaala.cloud/VaalaCat/go-openai/issues) before creating one.
|
||||||
|
|
||||||
**Describe the bug**
|
**Describe the bug**
|
||||||
A clear and concise description of what the bug is. If it's an API-related bug, please provide relevant endpoint(s).
|
A clear and concise description of what the bug is. If it's an API-related bug, please provide relevant endpoint(s).
|
||||||
|
|||||||
2
.github/ISSUE_TEMPLATE/feature_request.md
vendored
2
.github/ISSUE_TEMPLATE/feature_request.md
vendored
@@ -8,7 +8,7 @@ assignees: ''
|
|||||||
---
|
---
|
||||||
|
|
||||||
Your issue may already be reported!
|
Your issue may already be reported!
|
||||||
Please search on the [issue tracker](https://github.com/sashabaranov/go-openai/issues) before creating one.
|
Please search on the [issue tracker](https://git.vaala.cloud/VaalaCat/go-openai/issues) before creating one.
|
||||||
|
|
||||||
**Is your feature request related to a problem? Please describe.**
|
**Is your feature request related to a problem? Please describe.**
|
||||||
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
||||||
|
|||||||
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,5 +1,5 @@
|
|||||||
A similar PR may already be submitted!
|
A similar PR may already be submitted!
|
||||||
Please search among the [Pull request](https://github.com/sashabaranov/go-openai/pulls) before creating one.
|
Please search among the [Pull request](https://git.vaala.cloud/VaalaCat/go-openai/pulls) before creating one.
|
||||||
|
|
||||||
If your changes introduce breaking changes, please prefix the title of your pull request with "[BREAKING_CHANGES]". This allows for clear identification of such changes in the 'What's Changed' section on the release page, making it developer-friendly.
|
If your changes introduce breaking changes, please prefix the title of your pull request with "[BREAKING_CHANGES]". This allows for clear identification of such changes in the 'What's Changed' section on the release page, making it developer-friendly.
|
||||||
|
|
||||||
|
|||||||
12
.github/workflows/pr.yml
vendored
12
.github/workflows/pr.yml
vendored
@@ -13,15 +13,17 @@ jobs:
|
|||||||
- name: Setup Go
|
- name: Setup Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: '1.21'
|
go-version: '1.24'
|
||||||
- name: Run vet
|
- name: Run vet
|
||||||
run: |
|
run: |
|
||||||
go vet .
|
go vet .
|
||||||
- name: Run golangci-lint
|
- name: Run golangci-lint
|
||||||
uses: golangci/golangci-lint-action@v4
|
uses: golangci/golangci-lint-action@v7
|
||||||
with:
|
with:
|
||||||
version: latest
|
version: v2.1.5
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: go test -race -covermode=atomic -coverprofile=coverage.out -v .
|
run: go test -race -covermode=atomic -coverprofile=coverage.out -v ./...
|
||||||
- name: Upload coverage reports to Codecov
|
- name: Upload coverage reports to Codecov
|
||||||
uses: codecov/codecov-action@v4
|
uses: codecov/codecov-action@v5
|
||||||
|
with:
|
||||||
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
|
|||||||
349
.golangci.yml
349
.golangci.yml
@@ -1,66 +1,94 @@
|
|||||||
## Golden config for golangci-lint v1.47.3
|
version: "2"
|
||||||
#
|
linters:
|
||||||
# This is the best config for golangci-lint based on my experience and opinion.
|
default: none
|
||||||
# It is very strict, but not extremely strict.
|
enable:
|
||||||
# Feel free to adopt and change it for your needs.
|
- asciicheck
|
||||||
|
- bidichk
|
||||||
run:
|
- bodyclose
|
||||||
# Timeout for analysis, e.g. 30s, 5m.
|
- contextcheck
|
||||||
# Default: 1m
|
- cyclop
|
||||||
timeout: 3m
|
- dupl
|
||||||
|
- durationcheck
|
||||||
|
- errcheck
|
||||||
# This file contains only configs which differ from defaults.
|
- errname
|
||||||
# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
|
- errorlint
|
||||||
linters-settings:
|
- exhaustive
|
||||||
|
- forbidigo
|
||||||
|
- funlen
|
||||||
|
- gochecknoinits
|
||||||
|
- gocognit
|
||||||
|
- goconst
|
||||||
|
- gocritic
|
||||||
|
- gocyclo
|
||||||
|
- godot
|
||||||
|
- gomoddirectives
|
||||||
|
- gomodguard
|
||||||
|
- goprintffuncname
|
||||||
|
- gosec
|
||||||
|
- govet
|
||||||
|
- ineffassign
|
||||||
|
- lll
|
||||||
|
- makezero
|
||||||
|
- mnd
|
||||||
|
- nestif
|
||||||
|
- nilerr
|
||||||
|
- nilnil
|
||||||
|
- nolintlint
|
||||||
|
- nosprintfhostport
|
||||||
|
- predeclared
|
||||||
|
- promlinter
|
||||||
|
- revive
|
||||||
|
- rowserrcheck
|
||||||
|
- sqlclosecheck
|
||||||
|
- staticcheck
|
||||||
|
- testpackage
|
||||||
|
- tparallel
|
||||||
|
- unconvert
|
||||||
|
- unparam
|
||||||
|
- unused
|
||||||
|
- usetesting
|
||||||
|
- wastedassign
|
||||||
|
- whitespace
|
||||||
|
settings:
|
||||||
cyclop:
|
cyclop:
|
||||||
# The maximal code complexity to report.
|
|
||||||
# Default: 10
|
|
||||||
max-complexity: 30
|
max-complexity: 30
|
||||||
# The maximal average package complexity.
|
package-average: 10
|
||||||
# If it's higher than 0.0 (float) the check is enabled
|
|
||||||
# Default: 0.0
|
|
||||||
package-average: 10.0
|
|
||||||
|
|
||||||
errcheck:
|
errcheck:
|
||||||
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
|
|
||||||
# Such cases aren't reported by default.
|
|
||||||
# Default: false
|
|
||||||
check-type-assertions: true
|
check-type-assertions: true
|
||||||
|
|
||||||
funlen:
|
funlen:
|
||||||
# Checks the number of lines in a function.
|
|
||||||
# If lower than 0, disable the check.
|
|
||||||
# Default: 60
|
|
||||||
lines: 100
|
lines: 100
|
||||||
# Checks the number of statements in a function.
|
|
||||||
# If lower than 0, disable the check.
|
|
||||||
# Default: 40
|
|
||||||
statements: 50
|
statements: 50
|
||||||
|
|
||||||
gocognit:
|
gocognit:
|
||||||
# Minimal code complexity to report
|
|
||||||
# Default: 30 (but we recommend 10-20)
|
|
||||||
min-complexity: 20
|
min-complexity: 20
|
||||||
|
|
||||||
gocritic:
|
gocritic:
|
||||||
# Settings passed to gocritic.
|
|
||||||
# The settings key is the name of a supported gocritic checker.
|
|
||||||
# The list of supported checkers can be find in https://go-critic.github.io/overview.
|
|
||||||
settings:
|
settings:
|
||||||
captLocal:
|
captLocal:
|
||||||
# Whether to restrict checker to params only.
|
|
||||||
# Default: true
|
|
||||||
paramsOnly: false
|
paramsOnly: false
|
||||||
underef:
|
underef:
|
||||||
# Whether to skip (*x).method() calls where x is a pointer receiver.
|
|
||||||
# Default: true
|
|
||||||
skipRecvDeref: false
|
skipRecvDeref: false
|
||||||
|
gomodguard:
|
||||||
|
blocked:
|
||||||
|
modules:
|
||||||
|
- github.com/golang/protobuf:
|
||||||
|
recommendations:
|
||||||
|
- google.golang.org/protobuf
|
||||||
|
reason: see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules
|
||||||
|
- github.com/satori/go.uuid:
|
||||||
|
recommendations:
|
||||||
|
- github.com/google/uuid
|
||||||
|
reason: satori's package is not maintained
|
||||||
|
- github.com/gofrs/uuid:
|
||||||
|
recommendations:
|
||||||
|
- github.com/google/uuid
|
||||||
|
reason: 'see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw'
|
||||||
|
govet:
|
||||||
|
disable:
|
||||||
|
- fieldalignment
|
||||||
|
enable-all: true
|
||||||
|
settings:
|
||||||
|
shadow:
|
||||||
|
strict: true
|
||||||
mnd:
|
mnd:
|
||||||
# List of function patterns to exclude from analysis.
|
|
||||||
# Values always ignored: `time.Date`
|
|
||||||
# Default: []
|
|
||||||
ignored-functions:
|
ignored-functions:
|
||||||
- os.Chmod
|
- os.Chmod
|
||||||
- os.Mkdir
|
- os.Mkdir
|
||||||
@@ -76,194 +104,44 @@ linters-settings:
|
|||||||
- strconv.ParseFloat
|
- strconv.ParseFloat
|
||||||
- strconv.ParseInt
|
- strconv.ParseInt
|
||||||
- strconv.ParseUint
|
- strconv.ParseUint
|
||||||
|
|
||||||
gomodguard:
|
|
||||||
blocked:
|
|
||||||
# List of blocked modules.
|
|
||||||
# Default: []
|
|
||||||
modules:
|
|
||||||
- github.com/golang/protobuf:
|
|
||||||
recommendations:
|
|
||||||
- google.golang.org/protobuf
|
|
||||||
reason: "see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules"
|
|
||||||
- github.com/satori/go.uuid:
|
|
||||||
recommendations:
|
|
||||||
- github.com/google/uuid
|
|
||||||
reason: "satori's package is not maintained"
|
|
||||||
- github.com/gofrs/uuid:
|
|
||||||
recommendations:
|
|
||||||
- github.com/google/uuid
|
|
||||||
reason: "see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw"
|
|
||||||
|
|
||||||
govet:
|
|
||||||
# Enable all analyzers.
|
|
||||||
# Default: false
|
|
||||||
enable-all: true
|
|
||||||
# Disable analyzers by name.
|
|
||||||
# Run `go tool vet help` to see all analyzers.
|
|
||||||
# Default: []
|
|
||||||
disable:
|
|
||||||
- fieldalignment # too strict
|
|
||||||
# Settings per analyzer.
|
|
||||||
settings:
|
|
||||||
shadow:
|
|
||||||
# Whether to be strict about shadowing; can be noisy.
|
|
||||||
# Default: false
|
|
||||||
strict: true
|
|
||||||
|
|
||||||
nakedret:
|
nakedret:
|
||||||
# Make an issue if func has more lines of code than this setting, and it has naked returns.
|
|
||||||
# Default: 30
|
|
||||||
max-func-lines: 0
|
max-func-lines: 0
|
||||||
|
|
||||||
nolintlint:
|
nolintlint:
|
||||||
# Exclude following linters from requiring an explanation.
|
|
||||||
# Default: []
|
|
||||||
allow-no-explanation: [ funlen, gocognit, lll ]
|
|
||||||
# Enable to require an explanation of nonzero length after each nolint directive.
|
|
||||||
# Default: false
|
|
||||||
require-explanation: true
|
require-explanation: true
|
||||||
# Enable to require nolint directives to mention the specific linter being suppressed.
|
|
||||||
# Default: false
|
|
||||||
require-specific: true
|
require-specific: true
|
||||||
|
allow-no-explanation:
|
||||||
|
- funlen
|
||||||
|
- gocognit
|
||||||
|
- lll
|
||||||
rowserrcheck:
|
rowserrcheck:
|
||||||
# database/sql is always checked
|
|
||||||
# Default: []
|
|
||||||
packages:
|
packages:
|
||||||
- github.com/jmoiron/sqlx
|
- github.com/jmoiron/sqlx
|
||||||
|
exclusions:
|
||||||
tenv:
|
generated: lax
|
||||||
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
|
presets:
|
||||||
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
|
- comments
|
||||||
# Default: false
|
- common-false-positives
|
||||||
all: true
|
- legacy
|
||||||
|
- std-error-handling
|
||||||
varcheck:
|
rules:
|
||||||
# Check usage of exported fields and variables.
|
- linters:
|
||||||
# Default: false
|
- forbidigo
|
||||||
exported-fields: false # default false # TODO: enable after fixing false positives
|
- mnd
|
||||||
|
- revive
|
||||||
|
path : ^examples/.*\.go$
|
||||||
linters:
|
- linters:
|
||||||
disable-all: true
|
- lll
|
||||||
enable:
|
source: ^//\s*go:generate\s
|
||||||
## enabled by default
|
- linters:
|
||||||
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
|
- godot
|
||||||
- gosimple # Linter for Go source code that specializes in simplifying a code
|
source: (noinspection|TODO)
|
||||||
- govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
|
- linters:
|
||||||
- ineffassign # Detects when assignments to existing variables are not used
|
- gocritic
|
||||||
- staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks
|
source: //noinspection
|
||||||
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
|
- linters:
|
||||||
- unused # Checks Go code for unused constants, variables, functions and types
|
- errorlint
|
||||||
## disabled by default
|
source: ^\s+if _, ok := err\.\([^.]+\.InternalError\); ok {
|
||||||
# - asasalint # Check for pass []any as any in variadic func(...any)
|
- linters:
|
||||||
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
|
|
||||||
- bidichk # Checks for dangerous unicode character sequences
|
|
||||||
- bodyclose # checks whether HTTP response body is closed successfully
|
|
||||||
- contextcheck # check the function whether use a non-inherited context
|
|
||||||
- cyclop # checks function and package cyclomatic complexity
|
|
||||||
- dupl # Tool for code clone detection
|
|
||||||
- durationcheck # check for two durations multiplied together
|
|
||||||
- errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error.
|
|
||||||
- errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13.
|
|
||||||
# Removed execinquery (deprecated). execinquery is a linter about query string checker in Query function which reads your Go src files and warning it finds
|
|
||||||
- exhaustive # check exhaustiveness of enum switch statements
|
|
||||||
- exportloopref # checks for pointers to enclosing loop variables
|
|
||||||
- forbidigo # Forbids identifiers
|
|
||||||
- funlen # Tool for detection of long functions
|
|
||||||
# - gochecknoglobals # check that no global variables exist
|
|
||||||
- gochecknoinits # Checks that no init functions are present in Go code
|
|
||||||
- gocognit # Computes and checks the cognitive complexity of functions
|
|
||||||
- goconst # Finds repeated strings that could be replaced by a constant
|
|
||||||
- gocritic # Provides diagnostics that check for bugs, performance and style issues.
|
|
||||||
- gocyclo # Computes and checks the cyclomatic complexity of functions
|
|
||||||
- godot # Check if comments end in a period
|
|
||||||
- goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt.
|
|
||||||
- gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod.
|
|
||||||
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
|
|
||||||
- goprintffuncname # Checks that printf-like functions are named with f at the end
|
|
||||||
- gosec # Inspects source code for security problems
|
|
||||||
- lll # Reports long lines
|
|
||||||
- makezero # Finds slice declarations with non-zero initial length
|
|
||||||
# - nakedret # Finds naked returns in functions greater than a specified function length
|
|
||||||
- mnd # An analyzer to detect magic numbers.
|
|
||||||
- nestif # Reports deeply nested if statements
|
|
||||||
- nilerr # Finds the code that returns nil even if it checks that the error is not nil.
|
|
||||||
- nilnil # Checks that there is no simultaneous return of nil error and an invalid value.
|
|
||||||
# - noctx # noctx finds sending http request without context.Context
|
|
||||||
- nolintlint # Reports ill-formed or insufficient nolint directives
|
|
||||||
# - nonamedreturns # Reports all named returns
|
|
||||||
- nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL.
|
|
||||||
- predeclared # find code that shadows one of Go's predeclared identifiers
|
|
||||||
- promlinter # Check Prometheus metrics naming via promlint
|
|
||||||
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
|
|
||||||
- rowserrcheck # checks whether Err of rows is checked successfully
|
|
||||||
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
|
|
||||||
- stylecheck # Stylecheck is a replacement for golint
|
|
||||||
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17
|
|
||||||
- testpackage # linter that makes you use a separate _test package
|
|
||||||
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
|
|
||||||
- unconvert # Remove unnecessary type conversions
|
|
||||||
- unparam # Reports unused function parameters
|
|
||||||
- usetesting # Reports uses of functions with replacement inside the testing package
|
|
||||||
- wastedassign # wastedassign finds wasted assignment statements.
|
|
||||||
- whitespace # Tool for detection of leading and trailing whitespace
|
|
||||||
## you may want to enable
|
|
||||||
#- decorder # check declaration order and count of types, constants, variables and functions
|
|
||||||
#- exhaustruct # Checks if all structure fields are initialized
|
|
||||||
#- goheader # Checks is file header matches to pattern
|
|
||||||
#- ireturn # Accept Interfaces, Return Concrete Types
|
|
||||||
#- prealloc # [premature optimization, but can be used in some cases] Finds slice declarations that could potentially be preallocated
|
|
||||||
#- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope
|
|
||||||
#- wrapcheck # Checks that errors returned from external packages are wrapped
|
|
||||||
## disabled
|
|
||||||
#- containedctx # containedctx is a linter that detects struct contained context.Context field
|
|
||||||
#- depguard # [replaced by gomodguard] Go linter that checks if package imports are in a list of acceptable packages
|
|
||||||
#- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
|
|
||||||
#- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted.
|
|
||||||
#- forcetypeassert # [replaced by errcheck] finds forced type assertions
|
|
||||||
#- gci # Gci controls golang package import order and makes it always deterministic.
|
|
||||||
#- godox # Tool for detection of FIXME, TODO and other comment keywords
|
|
||||||
#- goerr113 # [too strict] Golang linter to check the errors handling expressions
|
|
||||||
#- gofmt # [replaced by goimports] Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
|
|
||||||
#- gofumpt # [replaced by goimports, gofumports is not available yet] Gofumpt checks whether code was gofumpt-ed.
|
|
||||||
#- grouper # An analyzer to analyze expression groups.
|
|
||||||
#- ifshort # Checks that your code uses short syntax for if-statements whenever possible
|
|
||||||
#- importas # Enforces consistent import aliases
|
|
||||||
#- maintidx # maintidx measures the maintainability index of each function.
|
|
||||||
#- misspell # [useless] Finds commonly misspelled English words in comments
|
|
||||||
#- nlreturn # [too strict and mostly code is not more readable] nlreturn checks for a new line before return and branch statements to increase code clarity
|
|
||||||
#- nosnakecase # Detects snake case of variable naming and function name. # TODO: maybe enable after https://github.com/sivchari/nosnakecase/issues/14
|
|
||||||
#- paralleltest # [too many false positives] paralleltest detects missing usage of t.Parallel() method in your Go test
|
|
||||||
#- tagliatelle # Checks the struct tags.
|
|
||||||
#- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
|
|
||||||
#- wsl # [too strict and mostly code is not more readable] Whitespace Linter - Forces you to use empty lines!
|
|
||||||
## deprecated
|
|
||||||
#- exhaustivestruct # [deprecated, replaced by exhaustruct] Checks if all struct's fields are initialized
|
|
||||||
#- golint # [deprecated, replaced by revive] Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes
|
|
||||||
#- interfacer # [deprecated] Linter that suggests narrower interface types
|
|
||||||
#- maligned # [deprecated, replaced by govet fieldalignment] Tool to detect Go structs that would take less memory if their fields were sorted
|
|
||||||
#- scopelint # [deprecated, replaced by exportloopref] Scopelint checks for unpinned variables in go programs
|
|
||||||
|
|
||||||
|
|
||||||
issues:
|
|
||||||
# Maximum count of issues with the same text.
|
|
||||||
# Set to 0 to disable.
|
|
||||||
# Default: 3
|
|
||||||
max-same-issues: 50
|
|
||||||
|
|
||||||
exclude-rules:
|
|
||||||
- source: "^//\\s*go:generate\\s"
|
|
||||||
linters: [ lll ]
|
|
||||||
- source: "(noinspection|TODO)"
|
|
||||||
linters: [ godot ]
|
|
||||||
- source: "//noinspection"
|
|
||||||
linters: [ gocritic ]
|
|
||||||
- source: "^\\s+if _, ok := err\\.\\([^.]+\\.InternalError\\); ok {"
|
|
||||||
linters: [ errorlint ]
|
|
||||||
- path: "_test\\.go"
|
|
||||||
linters:
|
|
||||||
- bodyclose
|
- bodyclose
|
||||||
- dupl
|
- dupl
|
||||||
- funlen
|
- funlen
|
||||||
@@ -271,3 +149,20 @@ issues:
|
|||||||
- gosec
|
- gosec
|
||||||
- noctx
|
- noctx
|
||||||
- wrapcheck
|
- wrapcheck
|
||||||
|
- staticcheck
|
||||||
|
path: _test\.go
|
||||||
|
paths:
|
||||||
|
- third_party$
|
||||||
|
- builtin$
|
||||||
|
- examples$
|
||||||
|
issues:
|
||||||
|
max-same-issues: 50
|
||||||
|
formatters:
|
||||||
|
enable:
|
||||||
|
- goimports
|
||||||
|
exclusions:
|
||||||
|
generated: lax
|
||||||
|
paths:
|
||||||
|
- third_party$
|
||||||
|
- builtin$
|
||||||
|
- examples$
|
||||||
|
|||||||
@@ -1,22 +1,22 @@
|
|||||||
# Contributing Guidelines
|
# Contributing Guidelines
|
||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
Thank you for your interest in contributing to the "Go OpenAI" project! By following this guideline, we hope to ensure that your contributions are made smoothly and efficiently. The Go OpenAI project is licensed under the [Apache 2.0 License](https://github.com/sashabaranov/go-openai/blob/master/LICENSE), and we welcome contributions through GitHub pull requests.
|
Thank you for your interest in contributing to the "Go OpenAI" project! By following this guideline, we hope to ensure that your contributions are made smoothly and efficiently. The Go OpenAI project is licensed under the [Apache 2.0 License](https://git.vaala.cloud/VaalaCat/go-openai/blob/master/LICENSE), and we welcome contributions through GitHub pull requests.
|
||||||
|
|
||||||
## Reporting Bugs
|
## Reporting Bugs
|
||||||
If you discover a bug, first check the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to see if the issue has already been reported. If you're reporting a new issue, please use the "Bug report" template and provide detailed information about the problem, including steps to reproduce it.
|
If you discover a bug, first check the [GitHub Issues page](https://git.vaala.cloud/VaalaCat/go-openai/issues) to see if the issue has already been reported. If you're reporting a new issue, please use the "Bug report" template and provide detailed information about the problem, including steps to reproduce it.
|
||||||
|
|
||||||
## Suggesting Features
|
## Suggesting Features
|
||||||
If you want to suggest a new feature or improvement, first check the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to ensure a similar suggestion hasn't already been made. Use the "Feature request" template to provide a detailed description of your suggestion.
|
If you want to suggest a new feature or improvement, first check the [GitHub Issues page](https://git.vaala.cloud/VaalaCat/go-openai/issues) to ensure a similar suggestion hasn't already been made. Use the "Feature request" template to provide a detailed description of your suggestion.
|
||||||
|
|
||||||
## Reporting Vulnerabilities
|
## Reporting Vulnerabilities
|
||||||
If you identify a security concern, please use the "Report a security vulnerability" template on the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to share the details. This report will only be viewable to repository maintainers. You will be credited if the advisory is published.
|
If you identify a security concern, please use the "Report a security vulnerability" template on the [GitHub Issues page](https://git.vaala.cloud/VaalaCat/go-openai/issues) to share the details. This report will only be viewable to repository maintainers. You will be credited if the advisory is published.
|
||||||
|
|
||||||
## Questions for Users
|
## Questions for Users
|
||||||
If you have questions, please utilize [StackOverflow](https://stackoverflow.com/) or the [GitHub Discussions page](https://github.com/sashabaranov/go-openai/discussions).
|
If you have questions, please utilize [StackOverflow](https://stackoverflow.com/) or the [GitHub Discussions page](https://git.vaala.cloud/VaalaCat/go-openai/discussions).
|
||||||
|
|
||||||
## Contributing Code
|
## Contributing Code
|
||||||
There might already be a similar pull requests submitted! Please search for [pull requests](https://github.com/sashabaranov/go-openai/pulls) before creating one.
|
There might already be a similar pull requests submitted! Please search for [pull requests](https://git.vaala.cloud/VaalaCat/go-openai/pulls) before creating one.
|
||||||
|
|
||||||
### Requirements for Merging a Pull Request
|
### Requirements for Merging a Pull Request
|
||||||
|
|
||||||
|
|||||||
106
README.md
106
README.md
@@ -1,19 +1,19 @@
|
|||||||
# Go OpenAI
|
# Go OpenAI
|
||||||
[](https://pkg.go.dev/github.com/sashabaranov/go-openai)
|
[](https://pkg.go.dev/git.vaala.cloud/VaalaCat/go-openai)
|
||||||
[](https://goreportcard.com/report/github.com/sashabaranov/go-openai)
|
[](https://goreportcard.com/report/git.vaala.cloud/VaalaCat/go-openai)
|
||||||
[](https://codecov.io/gh/sashabaranov/go-openai)
|
[](https://codecov.io/gh/sashabaranov/go-openai)
|
||||||
|
|
||||||
This library provides unofficial Go clients for [OpenAI API](https://platform.openai.com/). We support:
|
This library provides unofficial Go clients for [OpenAI API](https://platform.openai.com/). We support:
|
||||||
|
|
||||||
* ChatGPT 4o, o1
|
* ChatGPT 4o, o1
|
||||||
* GPT-3, GPT-4
|
* GPT-3, GPT-4
|
||||||
* DALL·E 2, DALL·E 3
|
* DALL·E 2, DALL·E 3, GPT Image 1
|
||||||
* Whisper
|
* Whisper
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
```
|
```
|
||||||
go get github.com/sashabaranov/go-openai
|
go get git.vaala.cloud/VaalaCat/go-openai
|
||||||
```
|
```
|
||||||
Currently, go-openai requires Go version 1.18 or greater.
|
Currently, go-openai requires Go version 1.18 or greater.
|
||||||
|
|
||||||
@@ -28,7 +28,7 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
openai "github.com/sashabaranov/go-openai"
|
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -80,7 +80,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
openai "github.com/sashabaranov/go-openai"
|
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -133,7 +133,7 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
openai "github.com/sashabaranov/go-openai"
|
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -166,7 +166,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
openai "github.com/sashabaranov/go-openai"
|
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -215,7 +215,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
openai "github.com/sashabaranov/go-openai"
|
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -247,7 +247,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
openai "github.com/sashabaranov/go-openai"
|
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -288,7 +288,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
openai "github.com/sashabaranov/go-openai"
|
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"image/png"
|
"image/png"
|
||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
@@ -357,6 +357,66 @@ func main() {
|
|||||||
```
|
```
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>GPT Image 1 image generation</summary>
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
openai "github.com/sashabaranov/go-openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
c := openai.NewClient("your token")
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
req := openai.ImageRequest{
|
||||||
|
Prompt: "Parrot on a skateboard performing a trick. Large bold text \"SKATE MASTER\" banner at the bottom of the image. Cartoon style, natural light, high detail, 1:1 aspect ratio.",
|
||||||
|
Background: openai.CreateImageBackgroundOpaque,
|
||||||
|
Model: openai.CreateImageModelGptImage1,
|
||||||
|
Size: openai.CreateImageSize1024x1024,
|
||||||
|
N: 1,
|
||||||
|
Quality: openai.CreateImageQualityLow,
|
||||||
|
OutputCompression: 100,
|
||||||
|
OutputFormat: openai.CreateImageOutputFormatJPEG,
|
||||||
|
// Moderation: openai.CreateImageModerationLow,
|
||||||
|
// User: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.CreateImage(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Image creation Image generation with GPT Image 1error: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Image Base64:", resp.Data[0].B64JSON)
|
||||||
|
|
||||||
|
// Decode the base64 data
|
||||||
|
imgBytes, err := base64.StdEncoding.DecodeString(resp.Data[0].B64JSON)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Base64 decode error: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write image to file
|
||||||
|
outputPath := "generated_image.jpg"
|
||||||
|
err = os.WriteFile(outputPath, imgBytes, 0644)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to write image file: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("The image was saved as %s\n", outputPath)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
</details>
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary>Configuring proxy</summary>
|
<summary>Configuring proxy</summary>
|
||||||
|
|
||||||
@@ -376,7 +436,7 @@ config.HTTPClient = &http.Client{
|
|||||||
c := openai.NewClientWithConfig(config)
|
c := openai.NewClientWithConfig(config)
|
||||||
```
|
```
|
||||||
|
|
||||||
See also: https://pkg.go.dev/github.com/sashabaranov/go-openai#ClientConfig
|
See also: https://pkg.go.dev/git.vaala.cloud/VaalaCat/go-openai#ClientConfig
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
@@ -392,7 +452,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -446,7 +506,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
openai "github.com/sashabaranov/go-openai"
|
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -492,7 +552,7 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"log"
|
"log"
|
||||||
openai "github.com/sashabaranov/go-openai"
|
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -549,7 +609,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
openai "github.com/sashabaranov/go-openai"
|
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -680,7 +740,7 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -755,8 +815,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
"git.vaala.cloud/VaalaCat/go-openai/jsonschema"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -828,7 +888,7 @@ Due to the factors mentioned above, different answers may be returned even for t
|
|||||||
By adopting these strategies, you can expect more consistent results.
|
By adopting these strategies, you can expect more consistent results.
|
||||||
|
|
||||||
**Related Issues:**
|
**Related Issues:**
|
||||||
[omitempty option of request struct will generate incorrect request when parameter is 0.](https://github.com/sashabaranov/go-openai/issues/9)
|
[omitempty option of request struct will generate incorrect request when parameter is 0.](https://git.vaala.cloud/VaalaCat/go-openai/issues/9)
|
||||||
|
|
||||||
### Does Go OpenAI provide a method to count tokens?
|
### Does Go OpenAI provide a method to count tokens?
|
||||||
|
|
||||||
@@ -839,15 +899,15 @@ For counting tokens, you might find the following links helpful:
|
|||||||
- [How to count tokens with tiktoken](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb)
|
- [How to count tokens with tiktoken](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb)
|
||||||
|
|
||||||
**Related Issues:**
|
**Related Issues:**
|
||||||
[Is it possible to join the implementation of GPT3 Tokenizer](https://github.com/sashabaranov/go-openai/issues/62)
|
[Is it possible to join the implementation of GPT3 Tokenizer](https://git.vaala.cloud/VaalaCat/go-openai/issues/62)
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
By following [Contributing Guidelines](https://github.com/sashabaranov/go-openai/blob/master/CONTRIBUTING.md), we hope to ensure that your contributions are made smoothly and efficiently.
|
By following [Contributing Guidelines](https://git.vaala.cloud/VaalaCat/go-openai/blob/master/CONTRIBUTING.md), we hope to ensure that your contributions are made smoothly and efficiently.
|
||||||
|
|
||||||
## Thank you
|
## Thank you
|
||||||
|
|
||||||
We want to take a moment to express our deepest gratitude to the [contributors](https://github.com/sashabaranov/go-openai/graphs/contributors) and sponsors of this project:
|
We want to take a moment to express our deepest gratitude to the [contributors](https://git.vaala.cloud/VaalaCat/go-openai/graphs/contributors) and sponsors of this project:
|
||||||
- [Carson Kahn](https://carsonkahn.com) of [Spindle AI](https://spindleai.com)
|
- [Carson Kahn](https://carsonkahn.com) of [Spindle AI](https://spindleai.com)
|
||||||
|
|
||||||
To all of you: thank you. You've helped us achieve more than we ever imagined possible. Can't wait to see where we go next, together!
|
To all of you: thank you. You've helped us achieve more than we ever imagined possible. Can't wait to see where we go next, together!
|
||||||
|
|||||||
@@ -10,9 +10,9 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
"git.vaala.cloud/VaalaCat/go-openai/jsonschema"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAPI(t *testing.T) {
|
func TestAPI(t *testing.T) {
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ package openai_test
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
openai "github.com/sashabaranov/go-openai"
|
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|||||||
2
audio.go
2
audio.go
@@ -8,7 +8,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
utils "github.com/sashabaranov/go-openai/internal"
|
utils "git.vaala.cloud/VaalaCat/go-openai/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI.
|
// Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI.
|
||||||
|
|||||||
@@ -12,9 +12,9 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
|
// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
|
||||||
|
|||||||
135
audio_test.go
135
audio_test.go
@@ -2,14 +2,17 @@ package openai //nolint:testpackage // testing private field
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAudioWithFailingFormBuilder(t *testing.T) {
|
func TestAudioWithFailingFormBuilder(t *testing.T) {
|
||||||
@@ -107,3 +110,131 @@ func TestCreateFileField(t *testing.T) {
|
|||||||
checks.HasError(t, err, "createFileField using file should return error when open file fails")
|
checks.HasError(t, err, "createFileField using file should return error when open file fails")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// failingFormBuilder always returns an error when creating form files.
|
||||||
|
type failingFormBuilder struct{ err error }
|
||||||
|
|
||||||
|
func (f *failingFormBuilder) CreateFormFile(_ string, _ *os.File) error {
|
||||||
|
return f.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *failingFormBuilder) CreateFormFileReader(_ string, _ io.Reader, _ string) error {
|
||||||
|
return f.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *failingFormBuilder) WriteField(_, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *failingFormBuilder) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *failingFormBuilder) FormDataContentType() string {
|
||||||
|
return "multipart/form-data"
|
||||||
|
}
|
||||||
|
|
||||||
|
// failingAudioRequestBuilder simulates an error during HTTP request construction.
|
||||||
|
type failingAudioRequestBuilder struct{ err error }
|
||||||
|
|
||||||
|
func (f *failingAudioRequestBuilder) Build(
|
||||||
|
_ context.Context,
|
||||||
|
_, _ string,
|
||||||
|
_ any,
|
||||||
|
_ http.Header,
|
||||||
|
) (*http.Request, error) {
|
||||||
|
return nil, f.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// errorHTTPClient always returns an error when making HTTP calls.
|
||||||
|
type errorHTTPClient struct{ err error }
|
||||||
|
|
||||||
|
func (e *errorHTTPClient) Do(_ *http.Request) (*http.Response, error) {
|
||||||
|
return nil, e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallAudioAPIMultipartFormError(t *testing.T) {
|
||||||
|
client := NewClient("test-token")
|
||||||
|
errForm := errors.New("mock create form file failure")
|
||||||
|
// Override form builder to force an error during multipart form creation.
|
||||||
|
client.createFormBuilder = func(_ io.Writer) utils.FormBuilder {
|
||||||
|
return &failingFormBuilder{err: errForm}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Provide a reader so createFileField uses the reader path (no file open).
|
||||||
|
req := AudioRequest{FilePath: "fake.mp3", Reader: bytes.NewBuffer([]byte("dummy")), Model: Whisper1}
|
||||||
|
_, err := client.callAudioAPI(context.Background(), req, "transcriptions")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error but got none")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, errForm) {
|
||||||
|
t.Errorf("expected error %v, got %v", errForm, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallAudioAPINewRequestError(t *testing.T) {
|
||||||
|
client := NewClient("test-token")
|
||||||
|
// Create a real temp file so multipart form succeeds.
|
||||||
|
tmp := t.TempDir()
|
||||||
|
path := filepath.Join(tmp, "file.mp3")
|
||||||
|
if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
|
||||||
|
t.Fatalf("failed to write temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
errBuild := errors.New("mock build failure")
|
||||||
|
client.requestBuilder = &failingAudioRequestBuilder{err: errBuild}
|
||||||
|
|
||||||
|
req := AudioRequest{FilePath: path, Model: Whisper1}
|
||||||
|
_, err := client.callAudioAPI(context.Background(), req, "translations")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error but got none")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, errBuild) {
|
||||||
|
t.Errorf("expected error %v, got %v", errBuild, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallAudioAPISendRequestErrorJSON(t *testing.T) {
|
||||||
|
client := NewClient("test-token")
|
||||||
|
// Create a real temp file so multipart form succeeds.
|
||||||
|
tmp := t.TempDir()
|
||||||
|
path := filepath.Join(tmp, "file.mp3")
|
||||||
|
if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
|
||||||
|
t.Fatalf("failed to write temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
errHTTP := errors.New("mock HTTPClient failure")
|
||||||
|
// Override HTTP client to simulate a network error.
|
||||||
|
client.config.HTTPClient = &errorHTTPClient{err: errHTTP}
|
||||||
|
|
||||||
|
req := AudioRequest{FilePath: path, Model: Whisper1}
|
||||||
|
_, err := client.callAudioAPI(context.Background(), req, "transcriptions")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error but got none")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, errHTTP) {
|
||||||
|
t.Errorf("expected error %v, got %v", errHTTP, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallAudioAPISendRequestErrorText(t *testing.T) {
|
||||||
|
client := NewClient("test-token")
|
||||||
|
tmp := t.TempDir()
|
||||||
|
path := filepath.Join(tmp, "file.mp3")
|
||||||
|
if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
|
||||||
|
t.Fatalf("failed to write temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
errHTTP := errors.New("mock HTTPClient failure")
|
||||||
|
client.config.HTTPClient = &errorHTTPClient{err: errHTTP}
|
||||||
|
|
||||||
|
// Use a non-JSON response format to exercise the text path.
|
||||||
|
req := AudioRequest{FilePath: path, Model: Whisper1, Format: AudioResponseFormatText}
|
||||||
|
_, err := client.callAudioAPI(context.Background(), req, "translations")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error but got none")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, errHTTP) {
|
||||||
|
t.Errorf("expected error %v, got %v", errHTTP, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUploadBatchFile(t *testing.T) {
|
func TestUploadBatchFile(t *testing.T) {
|
||||||
|
|||||||
27
chat.go
27
chat.go
@@ -14,6 +14,7 @@ const (
|
|||||||
ChatMessageRoleAssistant = "assistant"
|
ChatMessageRoleAssistant = "assistant"
|
||||||
ChatMessageRoleFunction = "function"
|
ChatMessageRoleFunction = "function"
|
||||||
ChatMessageRoleTool = "tool"
|
ChatMessageRoleTool = "tool"
|
||||||
|
ChatMessageRoleDeveloper = "developer"
|
||||||
)
|
)
|
||||||
|
|
||||||
const chatCompletionsSuffix = "/chat/completions"
|
const chatCompletionsSuffix = "/chat/completions"
|
||||||
@@ -103,6 +104,12 @@ type ChatCompletionMessage struct {
|
|||||||
// - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
// - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
|
|
||||||
|
// This property is used for the "reasoning" feature supported by deepseek-reasoner
|
||||||
|
// which is not in the official documentation.
|
||||||
|
// the doc from deepseek:
|
||||||
|
// - https://api-docs.deepseek.com/api/create-chat-completion#responses
|
||||||
|
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||||
|
|
||||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||||
|
|
||||||
// For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls.
|
// For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls.
|
||||||
@@ -123,6 +130,7 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
|
|||||||
Refusal string `json:"refusal,omitempty"`
|
Refusal string `json:"refusal,omitempty"`
|
||||||
MultiContent []ChatMessagePart `json:"content,omitempty"`
|
MultiContent []ChatMessagePart `json:"content,omitempty"`
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
|
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||||
@@ -136,6 +144,7 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
|
|||||||
Refusal string `json:"refusal,omitempty"`
|
Refusal string `json:"refusal,omitempty"`
|
||||||
MultiContent []ChatMessagePart `json:"-"`
|
MultiContent []ChatMessagePart `json:"-"`
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
|
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||||
@@ -146,10 +155,11 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
|
|||||||
func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
|
func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
|
||||||
msg := struct {
|
msg := struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content string `json:"content,omitempty"`
|
Content string `json:"content"`
|
||||||
Refusal string `json:"refusal,omitempty"`
|
Refusal string `json:"refusal,omitempty"`
|
||||||
MultiContent []ChatMessagePart
|
MultiContent []ChatMessagePart
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
|
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||||
@@ -165,6 +175,7 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
|
|||||||
Refusal string `json:"refusal,omitempty"`
|
Refusal string `json:"refusal,omitempty"`
|
||||||
MultiContent []ChatMessagePart `json:"content"`
|
MultiContent []ChatMessagePart `json:"content"`
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
|
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||||
@@ -262,6 +273,15 @@ type ChatCompletionRequest struct {
|
|||||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||||
// Metadata to store with the completion.
|
// Metadata to store with the completion.
|
||||||
Metadata map[string]string `json:"metadata,omitempty"`
|
Metadata map[string]string `json:"metadata,omitempty"`
|
||||||
|
IncludeReasoning *bool `json:"include_reasoning,omitempty"`
|
||||||
|
ReasoningFormat *string `json:"reasoning_format,omitempty"`
|
||||||
|
// Configuration for a predicted output.
|
||||||
|
Prediction *Prediction `json:"prediction,omitempty"`
|
||||||
|
// ChatTemplateKwargs provides a way to add non-standard parameters to the request body.
|
||||||
|
// Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
|
||||||
|
// Such as think mode for qwen3. "chat_template_kwargs": {"enable_thinking": false}
|
||||||
|
// https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes
|
||||||
|
ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type StreamOptions struct {
|
type StreamOptions struct {
|
||||||
@@ -329,6 +349,11 @@ type LogProbs struct {
|
|||||||
Content []LogProb `json:"content"`
|
Content []LogProb `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Prediction struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
type FinishReason string
|
type FinishReason string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -11,6 +11,12 @@ type ChatCompletionStreamChoiceDelta struct {
|
|||||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
Refusal string `json:"refusal,omitempty"`
|
Refusal string `json:"refusal,omitempty"`
|
||||||
|
|
||||||
|
// This property is used for the "reasoning" feature supported by deepseek-reasoner
|
||||||
|
// which is not in the official documentation.
|
||||||
|
// the doc from deepseek:
|
||||||
|
// - https://api-docs.deepseek.com/api/create-chat-completion#responses
|
||||||
|
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatCompletionStreamChoiceLogprobs struct {
|
type ChatCompletionStreamChoiceLogprobs struct {
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestChatCompletionsStreamWrongModel(t *testing.T) {
|
func TestChatCompletionsStreamWrongModel(t *testing.T) {
|
||||||
@@ -959,6 +959,56 @@ func TestCreateChatCompletionStreamReasoningValidatorFails(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateChatCompletionStreamO3ReasoningValidatorFails(t *testing.T) {
|
||||||
|
client, _, _ := setupOpenAITestServer()
|
||||||
|
|
||||||
|
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
|
||||||
|
MaxTokens: 100, // This will trigger the validator to fail
|
||||||
|
Model: openai.O3,
|
||||||
|
Messages: []openai.ChatCompletionMessage{
|
||||||
|
{
|
||||||
|
Role: openai.ChatMessageRoleUser,
|
||||||
|
Content: "Hello!",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Stream: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
if stream != nil {
|
||||||
|
t.Error("Expected nil stream when validation fails")
|
||||||
|
stream.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
|
||||||
|
t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O3, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateChatCompletionStreamO4MiniReasoningValidatorFails(t *testing.T) {
|
||||||
|
client, _, _ := setupOpenAITestServer()
|
||||||
|
|
||||||
|
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
|
||||||
|
MaxTokens: 100, // This will trigger the validator to fail
|
||||||
|
Model: openai.O4Mini,
|
||||||
|
Messages: []openai.ChatCompletionMessage{
|
||||||
|
{
|
||||||
|
Role: openai.ChatMessageRoleUser,
|
||||||
|
Content: "Hello!",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Stream: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
if stream != nil {
|
||||||
|
t.Error("Expected nil stream when validation fails")
|
||||||
|
stream.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
|
||||||
|
t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O4Mini, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool {
|
func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool {
|
||||||
if c1.Index != c2.Index {
|
if c1.Index != c2.Index {
|
||||||
return false
|
return false
|
||||||
|
|||||||
119
chat_test.go
119
chat_test.go
@@ -12,9 +12,9 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
"git.vaala.cloud/VaalaCat/go-openai/jsonschema"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -106,40 +106,6 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
|
|||||||
},
|
},
|
||||||
expectedError: openai.ErrReasoningModelLimitationsLogprobs,
|
expectedError: openai.ErrReasoningModelLimitationsLogprobs,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "message_type_unsupported",
|
|
||||||
in: openai.ChatCompletionRequest{
|
|
||||||
MaxCompletionTokens: 1000,
|
|
||||||
Model: openai.O1Mini,
|
|
||||||
Messages: []openai.ChatCompletionMessage{
|
|
||||||
{
|
|
||||||
Role: openai.ChatMessageRoleSystem,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedError: openai.ErrO1BetaLimitationsMessageTypes,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "tool_unsupported",
|
|
||||||
in: openai.ChatCompletionRequest{
|
|
||||||
MaxCompletionTokens: 1000,
|
|
||||||
Model: openai.O1Mini,
|
|
||||||
Messages: []openai.ChatCompletionMessage{
|
|
||||||
{
|
|
||||||
Role: openai.ChatMessageRoleUser,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Role: openai.ChatMessageRoleAssistant,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Tools: []openai.Tool{
|
|
||||||
{
|
|
||||||
Type: openai.ToolTypeFunction,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedError: openai.ErrO1BetaLimitationsTools,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "set_temperature_unsupported",
|
name: "set_temperature_unsupported",
|
||||||
in: openai.ChatCompletionRequest{
|
in: openai.ChatCompletionRequest{
|
||||||
@@ -445,6 +411,23 @@ func TestO3ModelChatCompletions(t *testing.T) {
|
|||||||
checks.NoError(t, err, "CreateChatCompletion error")
|
checks.NoError(t, err, "CreateChatCompletion error")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDeepseekR1ModelChatCompletions(t *testing.T) {
|
||||||
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/chat/completions", handleDeepseekR1ChatCompletionEndpoint)
|
||||||
|
_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
|
||||||
|
Model: "deepseek-reasoner",
|
||||||
|
MaxCompletionTokens: 100,
|
||||||
|
Messages: []openai.ChatCompletionMessage{
|
||||||
|
{
|
||||||
|
Role: openai.ChatMessageRoleUser,
|
||||||
|
Content: "Hello!",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
checks.NoError(t, err, "CreateChatCompletion error")
|
||||||
|
}
|
||||||
|
|
||||||
// TestCompletions Tests the completions endpoint of the API using the mocked server.
|
// TestCompletions Tests the completions endpoint of the API using the mocked server.
|
||||||
func TestChatCompletionsWithHeaders(t *testing.T) {
|
func TestChatCompletionsWithHeaders(t *testing.T) {
|
||||||
client, server, teardown := setupOpenAITestServer()
|
client, server, teardown := setupOpenAITestServer()
|
||||||
@@ -856,6 +839,68 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
|
|||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func handleDeepseekR1ChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var err error
|
||||||
|
var resBytes []byte
|
||||||
|
|
||||||
|
// completions only accepts POST requests
|
||||||
|
if r.Method != "POST" {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
|
var completionReq openai.ChatCompletionRequest
|
||||||
|
if completionReq, err = getChatCompletionBody(r); err != nil {
|
||||||
|
http.Error(w, "could not read request", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res := openai.ChatCompletionResponse{
|
||||||
|
ID: strconv.Itoa(int(time.Now().Unix())),
|
||||||
|
Object: "test-object",
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
// would be nice to validate Model during testing, but
|
||||||
|
// this may not be possible with how much upkeep
|
||||||
|
// would be required / wouldn't make much sense
|
||||||
|
Model: completionReq.Model,
|
||||||
|
}
|
||||||
|
// create completions
|
||||||
|
n := completionReq.N
|
||||||
|
if n == 0 {
|
||||||
|
n = 1
|
||||||
|
}
|
||||||
|
if completionReq.MaxCompletionTokens == 0 {
|
||||||
|
completionReq.MaxCompletionTokens = 1000
|
||||||
|
}
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
reasoningContent := "User says hello! And I need to reply"
|
||||||
|
completionStr := strings.Repeat("a", completionReq.MaxCompletionTokens-numTokens(reasoningContent))
|
||||||
|
res.Choices = append(res.Choices, openai.ChatCompletionChoice{
|
||||||
|
Message: openai.ChatCompletionMessage{
|
||||||
|
Role: openai.ChatMessageRoleAssistant,
|
||||||
|
ReasoningContent: reasoningContent,
|
||||||
|
Content: completionStr,
|
||||||
|
},
|
||||||
|
Index: i,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
inputTokens := numTokens(completionReq.Messages[0].Content) * n
|
||||||
|
completionTokens := completionReq.MaxTokens * n
|
||||||
|
res.Usage = openai.Usage{
|
||||||
|
PromptTokens: inputTokens,
|
||||||
|
CompletionTokens: completionTokens,
|
||||||
|
TotalTokens: inputTokens + completionTokens,
|
||||||
|
}
|
||||||
|
resBytes, _ = json.Marshal(res)
|
||||||
|
w.Header().Set(xCustomHeader, xCustomHeaderValue)
|
||||||
|
for k, v := range rateLimitHeaders {
|
||||||
|
switch val := v.(type) {
|
||||||
|
case int:
|
||||||
|
w.Header().Set(k, strconv.Itoa(val))
|
||||||
|
default:
|
||||||
|
w.Header().Set(k, fmt.Sprintf("%s", v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Fprintln(w, string(resBytes))
|
||||||
|
}
|
||||||
|
|
||||||
// getChatCompletionBody Returns the body of the request to create a completion.
|
// getChatCompletionBody Returns the body of the request to create a completion.
|
||||||
func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) {
|
func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) {
|
||||||
completion := openai.ChatCompletionRequest{}
|
completion := openai.ChatCompletionRequest{}
|
||||||
|
|||||||
16
client.go
16
client.go
@@ -10,7 +10,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
utils "github.com/sashabaranov/go-openai/internal"
|
utils "git.vaala.cloud/VaalaCat/go-openai/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Client is OpenAI GPT-3 API client.
|
// Client is OpenAI GPT-3 API client.
|
||||||
@@ -182,13 +182,21 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream
|
|||||||
|
|
||||||
func (c *Client) setCommonHeaders(req *http.Request) {
|
func (c *Client) setCommonHeaders(req *http.Request) {
|
||||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
|
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
|
||||||
|
switch c.config.APIType {
|
||||||
|
case APITypeAzure, APITypeCloudflareAzure:
|
||||||
// Azure API Key authentication
|
// Azure API Key authentication
|
||||||
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure {
|
|
||||||
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
|
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
|
||||||
} else if c.config.authToken != "" {
|
case APITypeAnthropic:
|
||||||
// OpenAI or Azure AD authentication
|
// https://docs.anthropic.com/en/api/versioning
|
||||||
|
req.Header.Set("anthropic-version", c.config.APIVersion)
|
||||||
|
case APITypeOpenAI, APITypeAzureAD:
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
if c.config.authToken != "" {
|
||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if c.config.OrgID != "" {
|
if c.config.OrgID != "" {
|
||||||
req.Header.Set("OpenAI-Organization", c.config.OrgID)
|
req.Header.Set("OpenAI-Organization", c.config.OrgID)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errTestRequestBuilderFailed = errors.New("test request builder failed")
|
var errTestRequestBuilderFailed = errors.New("test request builder failed")
|
||||||
@@ -39,6 +39,21 @@ func TestClient(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSetCommonHeadersAnthropic(t *testing.T) {
|
||||||
|
config := DefaultAnthropicConfig("mock-token", "")
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
|
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client.setCommonHeaders(req)
|
||||||
|
|
||||||
|
if got := req.Header.Get("anthropic-version"); got != AnthropicAPIVersion {
|
||||||
|
t.Errorf("Expected anthropic-version header to be %q, got %q", AnthropicAPIVersion, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDecodeResponse(t *testing.T) {
|
func TestDecodeResponse(t *testing.T) {
|
||||||
stringInput := ""
|
stringInput := ""
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,8 @@ type Usage struct {
|
|||||||
type CompletionTokensDetails struct {
|
type CompletionTokensDetails struct {
|
||||||
AudioTokens int `json:"audio_tokens"`
|
AudioTokens int `json:"audio_tokens"`
|
||||||
ReasoningTokens int `json:"reasoning_tokens"`
|
ReasoningTokens int `json:"reasoning_tokens"`
|
||||||
|
AcceptedPredictionTokens int `json:"accepted_prediction_tokens"`
|
||||||
|
RejectedPredictionTokens int `json:"rejected_prediction_tokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PromptTokensDetails Breakdown of tokens used in the prompt.
|
// PromptTokensDetails Breakdown of tokens used in the prompt.
|
||||||
|
|||||||
@@ -16,8 +16,12 @@ const (
|
|||||||
O1Preview20240912 = "o1-preview-2024-09-12"
|
O1Preview20240912 = "o1-preview-2024-09-12"
|
||||||
O1 = "o1"
|
O1 = "o1"
|
||||||
O120241217 = "o1-2024-12-17"
|
O120241217 = "o1-2024-12-17"
|
||||||
|
O3 = "o3"
|
||||||
|
O320250416 = "o3-2025-04-16"
|
||||||
O3Mini = "o3-mini"
|
O3Mini = "o3-mini"
|
||||||
O3Mini20250131 = "o3-mini-2025-01-31"
|
O3Mini20250131 = "o3-mini-2025-01-31"
|
||||||
|
O4Mini = "o4-mini"
|
||||||
|
O4Mini20250416 = "o4-mini-2025-04-16"
|
||||||
GPT432K0613 = "gpt-4-32k-0613"
|
GPT432K0613 = "gpt-4-32k-0613"
|
||||||
GPT432K0314 = "gpt-4-32k-0314"
|
GPT432K0314 = "gpt-4-32k-0314"
|
||||||
GPT432K = "gpt-4-32k"
|
GPT432K = "gpt-4-32k"
|
||||||
@@ -37,6 +41,14 @@ const (
|
|||||||
GPT4TurboPreview = "gpt-4-turbo-preview"
|
GPT4TurboPreview = "gpt-4-turbo-preview"
|
||||||
GPT4VisionPreview = "gpt-4-vision-preview"
|
GPT4VisionPreview = "gpt-4-vision-preview"
|
||||||
GPT4 = "gpt-4"
|
GPT4 = "gpt-4"
|
||||||
|
GPT4Dot1 = "gpt-4.1"
|
||||||
|
GPT4Dot120250414 = "gpt-4.1-2025-04-14"
|
||||||
|
GPT4Dot1Mini = "gpt-4.1-mini"
|
||||||
|
GPT4Dot1Mini20250414 = "gpt-4.1-mini-2025-04-14"
|
||||||
|
GPT4Dot1Nano = "gpt-4.1-nano"
|
||||||
|
GPT4Dot1Nano20250414 = "gpt-4.1-nano-2025-04-14"
|
||||||
|
GPT4Dot5Preview = "gpt-4.5-preview"
|
||||||
|
GPT4Dot5Preview20250227 = "gpt-4.5-preview-2025-02-27"
|
||||||
GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125"
|
GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125"
|
||||||
GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106"
|
GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106"
|
||||||
GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613"
|
GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613"
|
||||||
@@ -91,6 +103,10 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
|
|||||||
O1Preview20240912: true,
|
O1Preview20240912: true,
|
||||||
O3Mini: true,
|
O3Mini: true,
|
||||||
O3Mini20250131: true,
|
O3Mini20250131: true,
|
||||||
|
O4Mini: true,
|
||||||
|
O4Mini20250416: true,
|
||||||
|
O3: true,
|
||||||
|
O320250416: true,
|
||||||
GPT3Dot5Turbo: true,
|
GPT3Dot5Turbo: true,
|
||||||
GPT3Dot5Turbo0301: true,
|
GPT3Dot5Turbo0301: true,
|
||||||
GPT3Dot5Turbo0613: true,
|
GPT3Dot5Turbo0613: true,
|
||||||
@@ -99,6 +115,8 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
|
|||||||
GPT3Dot5Turbo16K: true,
|
GPT3Dot5Turbo16K: true,
|
||||||
GPT3Dot5Turbo16K0613: true,
|
GPT3Dot5Turbo16K0613: true,
|
||||||
GPT4: true,
|
GPT4: true,
|
||||||
|
GPT4Dot5Preview: true,
|
||||||
|
GPT4Dot5Preview20250227: true,
|
||||||
GPT4o: true,
|
GPT4o: true,
|
||||||
GPT4o20240513: true,
|
GPT4o20240513: true,
|
||||||
GPT4o20240806: true,
|
GPT4o20240806: true,
|
||||||
@@ -117,6 +135,13 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
|
|||||||
GPT432K: true,
|
GPT432K: true,
|
||||||
GPT432K0314: true,
|
GPT432K0314: true,
|
||||||
GPT432K0613: true,
|
GPT432K0613: true,
|
||||||
|
O1: true,
|
||||||
|
GPT4Dot1: true,
|
||||||
|
GPT4Dot120250414: true,
|
||||||
|
GPT4Dot1Mini: true,
|
||||||
|
GPT4Dot1Mini20250414: true,
|
||||||
|
GPT4Dot1Nano: true,
|
||||||
|
GPT4Dot1Nano20250414: true,
|
||||||
},
|
},
|
||||||
chatCompletionsSuffix: {
|
chatCompletionsSuffix: {
|
||||||
CodexCodeDavinci002: true,
|
CodexCodeDavinci002: true,
|
||||||
@@ -190,6 +215,8 @@ type CompletionRequest struct {
|
|||||||
Temperature float32 `json:"temperature,omitempty"`
|
Temperature float32 `json:"temperature,omitempty"`
|
||||||
TopP float32 `json:"top_p,omitempty"`
|
TopP float32 `json:"top_p,omitempty"`
|
||||||
User string `json:"user,omitempty"`
|
User string `json:"user,omitempty"`
|
||||||
|
// Options for streaming response. Only set this when you set stream: true.
|
||||||
|
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// CompletionChoice represents one of possible completions.
|
// CompletionChoice represents one of possible completions.
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCompletionsWrongModel(t *testing.T) {
|
func TestCompletionsWrongModel(t *testing.T) {
|
||||||
@@ -33,6 +33,42 @@ func TestCompletionsWrongModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestCompletionsWrongModelO3 Tests the completions endpoint with O3 model which is not supported.
|
||||||
|
func TestCompletionsWrongModelO3(t *testing.T) {
|
||||||
|
config := openai.DefaultConfig("whatever")
|
||||||
|
config.BaseURL = "http://localhost/v1"
|
||||||
|
client := openai.NewClientWithConfig(config)
|
||||||
|
|
||||||
|
_, err := client.CreateCompletion(
|
||||||
|
context.Background(),
|
||||||
|
openai.CompletionRequest{
|
||||||
|
MaxTokens: 5,
|
||||||
|
Model: openai.O3,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
|
||||||
|
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O3, but returned: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCompletionsWrongModelO4Mini Tests the completions endpoint with O4Mini model which is not supported.
|
||||||
|
func TestCompletionsWrongModelO4Mini(t *testing.T) {
|
||||||
|
config := openai.DefaultConfig("whatever")
|
||||||
|
config.BaseURL = "http://localhost/v1"
|
||||||
|
client := openai.NewClientWithConfig(config)
|
||||||
|
|
||||||
|
_, err := client.CreateCompletion(
|
||||||
|
context.Background(),
|
||||||
|
openai.CompletionRequest{
|
||||||
|
MaxTokens: 5,
|
||||||
|
Model: openai.O4Mini,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
|
||||||
|
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O4Mini, but returned: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCompletionWithStream(t *testing.T) {
|
func TestCompletionWithStream(t *testing.T) {
|
||||||
config := openai.DefaultConfig("whatever")
|
config := openai.DefaultConfig("whatever")
|
||||||
client := openai.NewClientWithConfig(config)
|
client := openai.NewClientWithConfig(config)
|
||||||
@@ -181,3 +217,86 @@ func getCompletionBody(r *http.Request) (openai.CompletionRequest, error) {
|
|||||||
}
|
}
|
||||||
return completion, nil
|
return completion, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestCompletionWithO1Model Tests that O1 model is not supported for completion endpoint.
|
||||||
|
func TestCompletionWithO1Model(t *testing.T) {
|
||||||
|
config := openai.DefaultConfig("whatever")
|
||||||
|
config.BaseURL = "http://localhost/v1"
|
||||||
|
client := openai.NewClientWithConfig(config)
|
||||||
|
|
||||||
|
_, err := client.CreateCompletion(
|
||||||
|
context.Background(),
|
||||||
|
openai.CompletionRequest{
|
||||||
|
MaxTokens: 5,
|
||||||
|
Model: openai.O1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
|
||||||
|
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O1 model, but returned: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCompletionWithGPT4DotModels Tests that newer GPT4 models are not supported for completion endpoint.
|
||||||
|
func TestCompletionWithGPT4DotModels(t *testing.T) {
|
||||||
|
config := openai.DefaultConfig("whatever")
|
||||||
|
config.BaseURL = "http://localhost/v1"
|
||||||
|
client := openai.NewClientWithConfig(config)
|
||||||
|
|
||||||
|
models := []string{
|
||||||
|
openai.GPT4Dot1,
|
||||||
|
openai.GPT4Dot120250414,
|
||||||
|
openai.GPT4Dot1Mini,
|
||||||
|
openai.GPT4Dot1Mini20250414,
|
||||||
|
openai.GPT4Dot1Nano,
|
||||||
|
openai.GPT4Dot1Nano20250414,
|
||||||
|
openai.GPT4Dot5Preview,
|
||||||
|
openai.GPT4Dot5Preview20250227,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, model := range models {
|
||||||
|
t.Run(model, func(t *testing.T) {
|
||||||
|
_, err := client.CreateCompletion(
|
||||||
|
context.Background(),
|
||||||
|
openai.CompletionRequest{
|
||||||
|
MaxTokens: 5,
|
||||||
|
Model: model,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
|
||||||
|
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCompletionWithGPT4oModels Tests that GPT4o models are not supported for completion endpoint.
|
||||||
|
func TestCompletionWithGPT4oModels(t *testing.T) {
|
||||||
|
config := openai.DefaultConfig("whatever")
|
||||||
|
config.BaseURL = "http://localhost/v1"
|
||||||
|
client := openai.NewClientWithConfig(config)
|
||||||
|
|
||||||
|
models := []string{
|
||||||
|
openai.GPT4o,
|
||||||
|
openai.GPT4o20240513,
|
||||||
|
openai.GPT4o20240806,
|
||||||
|
openai.GPT4o20241120,
|
||||||
|
openai.GPT4oLatest,
|
||||||
|
openai.GPT4oMini,
|
||||||
|
openai.GPT4oMini20240718,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, model := range models {
|
||||||
|
t.Run(model, func(t *testing.T) {
|
||||||
|
_, err := client.CreateCompletion(
|
||||||
|
context.Background(),
|
||||||
|
openai.CompletionRequest{
|
||||||
|
MaxTokens: 5,
|
||||||
|
Model: model,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
|
||||||
|
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
22
config.go
22
config.go
@@ -11,6 +11,8 @@ const (
|
|||||||
|
|
||||||
azureAPIPrefix = "openai"
|
azureAPIPrefix = "openai"
|
||||||
azureDeploymentsPrefix = "deployments"
|
azureDeploymentsPrefix = "deployments"
|
||||||
|
|
||||||
|
AnthropicAPIVersion = "2023-06-01"
|
||||||
)
|
)
|
||||||
|
|
||||||
type APIType string
|
type APIType string
|
||||||
@@ -20,6 +22,7 @@ const (
|
|||||||
APITypeAzure APIType = "AZURE"
|
APITypeAzure APIType = "AZURE"
|
||||||
APITypeAzureAD APIType = "AZURE_AD"
|
APITypeAzureAD APIType = "AZURE_AD"
|
||||||
APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE"
|
APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE"
|
||||||
|
APITypeAnthropic APIType = "ANTHROPIC"
|
||||||
)
|
)
|
||||||
|
|
||||||
const AzureAPIKeyHeader = "api-key"
|
const AzureAPIKeyHeader = "api-key"
|
||||||
@@ -37,7 +40,7 @@ type ClientConfig struct {
|
|||||||
BaseURL string
|
BaseURL string
|
||||||
OrgID string
|
OrgID string
|
||||||
APIType APIType
|
APIType APIType
|
||||||
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD
|
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD or APITypeAnthropic
|
||||||
AssistantVersion string
|
AssistantVersion string
|
||||||
AzureModelMapperFunc func(model string) string // replace model to azure deployment name func
|
AzureModelMapperFunc func(model string) string // replace model to azure deployment name func
|
||||||
HTTPClient HTTPDoer
|
HTTPClient HTTPDoer
|
||||||
@@ -76,6 +79,23 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DefaultAnthropicConfig(apiKey, baseURL string) ClientConfig {
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = "https://api.anthropic.com/v1"
|
||||||
|
}
|
||||||
|
return ClientConfig{
|
||||||
|
authToken: apiKey,
|
||||||
|
BaseURL: baseURL,
|
||||||
|
OrgID: "",
|
||||||
|
APIType: APITypeAnthropic,
|
||||||
|
APIVersion: AnthropicAPIVersion,
|
||||||
|
|
||||||
|
HTTPClient: &http.Client{},
|
||||||
|
|
||||||
|
EmptyMessagesLimit: defaultEmptyMessagesLimit,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (ClientConfig) String() string {
|
func (ClientConfig) String() string {
|
||||||
return "<OpenAI API ClientConfig>"
|
return "<OpenAI API ClientConfig>"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package openai_test
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetAzureDeploymentByModel(t *testing.T) {
|
func TestGetAzureDeploymentByModel(t *testing.T) {
|
||||||
@@ -60,3 +60,64 @@ func TestGetAzureDeploymentByModel(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultAnthropicConfig(t *testing.T) {
|
||||||
|
apiKey := "test-key"
|
||||||
|
baseURL := "https://api.anthropic.com/v1"
|
||||||
|
|
||||||
|
config := openai.DefaultAnthropicConfig(apiKey, baseURL)
|
||||||
|
|
||||||
|
if config.APIType != openai.APITypeAnthropic {
|
||||||
|
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.APIVersion != openai.AnthropicAPIVersion {
|
||||||
|
t.Errorf("Expected APIVersion to be 2023-06-01, got %v", config.APIVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.BaseURL != baseURL {
|
||||||
|
t.Errorf("Expected BaseURL to be %v, got %v", baseURL, config.BaseURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.EmptyMessagesLimit != 300 {
|
||||||
|
t.Errorf("Expected EmptyMessagesLimit to be 300, got %v", config.EmptyMessagesLimit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultAnthropicConfigWithEmptyValues(t *testing.T) {
|
||||||
|
config := openai.DefaultAnthropicConfig("", "")
|
||||||
|
|
||||||
|
if config.APIType != openai.APITypeAnthropic {
|
||||||
|
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.APIVersion != openai.AnthropicAPIVersion {
|
||||||
|
t.Errorf("Expected APIVersion to be %s, got %v", openai.AnthropicAPIVersion, config.APIVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedBaseURL := "https://api.anthropic.com/v1"
|
||||||
|
if config.BaseURL != expectedBaseURL {
|
||||||
|
t.Errorf("Expected BaseURL to be %v, got %v", expectedBaseURL, config.BaseURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientConfigString(t *testing.T) {
|
||||||
|
// String() should always return the constant value
|
||||||
|
conf := openai.DefaultConfig("dummy-token")
|
||||||
|
expected := "<OpenAI API ClientConfig>"
|
||||||
|
got := conf.String()
|
||||||
|
if got != expected {
|
||||||
|
t.Errorf("ClientConfig.String() = %q; want %q", got, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAzureDeploymentByModel_NoMapper(t *testing.T) {
|
||||||
|
// On a zero-value or DefaultConfig, AzureModelMapperFunc is nil,
|
||||||
|
// so GetAzureDeploymentByModel should just return the input model.
|
||||||
|
conf := openai.DefaultConfig("dummy-token")
|
||||||
|
model := "some-model"
|
||||||
|
got := conf.GetAzureDeploymentByModel(model)
|
||||||
|
if got != model {
|
||||||
|
t.Errorf("GetAzureDeploymentByModel(%q) = %q; want %q", model, got, model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestEdits Tests the edits endpoint of the API using the mocked server.
|
// TestEdits Tests the edits endpoint of the API using the mocked server.
|
||||||
|
|||||||
@@ -11,8 +11,8 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestEmbedding(t *testing.T) {
|
func TestEmbedding(t *testing.T) {
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestGetEngine Tests the retrieve engine endpoint of the API using the mocked server.
|
// TestGetEngine Tests the retrieve engine endpoint of the API using the mocked server.
|
||||||
|
|||||||
4
error.go
4
error.go
@@ -54,7 +54,7 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) {
|
|||||||
err = json.Unmarshal(rawMap["message"], &e.Message)
|
err = json.Unmarshal(rawMap["message"], &e.Message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If the parameter field of a function call is invalid as a JSON schema
|
// If the parameter field of a function call is invalid as a JSON schema
|
||||||
// refs: https://github.com/sashabaranov/go-openai/issues/381
|
// refs: https://git.vaala.cloud/VaalaCat/go-openai/issues/381
|
||||||
var messages []string
|
var messages []string
|
||||||
err = json.Unmarshal(rawMap["message"], &messages)
|
err = json.Unmarshal(rawMap["message"], &messages)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -64,7 +64,7 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// optional fields for azure openai
|
// optional fields for azure openai
|
||||||
// refs: https://github.com/sashabaranov/go-openai/issues/343
|
// refs: https://git.vaala.cloud/VaalaCat/go-openai/issues/343
|
||||||
if _, ok := rawMap["type"]; ok {
|
if _, ok := rawMap["type"]; ok {
|
||||||
err = json.Unmarshal(rawMap["type"], &e.Type)
|
err = json.Unmarshal(rawMap["type"], &e.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAPIErrorUnmarshalJSON(t *testing.T) {
|
func TestAPIErrorUnmarshalJSON(t *testing.T) {
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Example() {
|
func Example() {
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
"git.vaala.cloud/VaalaCat/go-openai/jsonschema"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestFileBytesUpload(t *testing.T) {
|
func TestFileBytesUpload(t *testing.T) {
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
utils "github.com/sashabaranov/go-openai/internal"
|
utils "git.vaala.cloud/VaalaCat/go-openai/internal"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestFileBytesUploadWithFailingFormBuilder(t *testing.T) {
|
func TestFileBytesUploadWithFailingFormBuilder(t *testing.T) {
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
const testFineTuneID = "fine-tune-id"
|
const testFineTuneID = "fine-tune-id"
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
const testFineTuninigJobID = "fine-tuning-job-id"
|
const testFineTuninigJobID = "fine-tuning-job-id"
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -1,3 +1,3 @@
|
|||||||
module github.com/sashabaranov/go-openai
|
module git.vaala.cloud/VaalaCat/go-openai
|
||||||
|
|
||||||
go 1.18
|
go 1.18
|
||||||
|
|||||||
74
image.go
74
image.go
@@ -3,8 +3,8 @@ package openai
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -13,31 +13,62 @@ const (
|
|||||||
CreateImageSize256x256 = "256x256"
|
CreateImageSize256x256 = "256x256"
|
||||||
CreateImageSize512x512 = "512x512"
|
CreateImageSize512x512 = "512x512"
|
||||||
CreateImageSize1024x1024 = "1024x1024"
|
CreateImageSize1024x1024 = "1024x1024"
|
||||||
|
|
||||||
// dall-e-3 supported only.
|
// dall-e-3 supported only.
|
||||||
CreateImageSize1792x1024 = "1792x1024"
|
CreateImageSize1792x1024 = "1792x1024"
|
||||||
CreateImageSize1024x1792 = "1024x1792"
|
CreateImageSize1024x1792 = "1024x1792"
|
||||||
|
|
||||||
|
// gpt-image-1 supported only.
|
||||||
|
CreateImageSize1536x1024 = "1536x1024" // Landscape
|
||||||
|
CreateImageSize1024x1536 = "1024x1536" // Portrait
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
CreateImageResponseFormatURL = "url"
|
// dall-e-2 and dall-e-3 only.
|
||||||
CreateImageResponseFormatB64JSON = "b64_json"
|
CreateImageResponseFormatB64JSON = "b64_json"
|
||||||
|
CreateImageResponseFormatURL = "url"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
CreateImageModelDallE2 = "dall-e-2"
|
CreateImageModelDallE2 = "dall-e-2"
|
||||||
CreateImageModelDallE3 = "dall-e-3"
|
CreateImageModelDallE3 = "dall-e-3"
|
||||||
|
CreateImageModelGptImage1 = "gpt-image-1"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
CreateImageQualityHD = "hd"
|
CreateImageQualityHD = "hd"
|
||||||
CreateImageQualityStandard = "standard"
|
CreateImageQualityStandard = "standard"
|
||||||
|
|
||||||
|
// gpt-image-1 only.
|
||||||
|
CreateImageQualityHigh = "high"
|
||||||
|
CreateImageQualityMedium = "medium"
|
||||||
|
CreateImageQualityLow = "low"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// dall-e-3 only.
|
||||||
CreateImageStyleVivid = "vivid"
|
CreateImageStyleVivid = "vivid"
|
||||||
CreateImageStyleNatural = "natural"
|
CreateImageStyleNatural = "natural"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// gpt-image-1 only.
|
||||||
|
CreateImageBackgroundTransparent = "transparent"
|
||||||
|
CreateImageBackgroundOpaque = "opaque"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// gpt-image-1 only.
|
||||||
|
CreateImageModerationLow = "low"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// gpt-image-1 only.
|
||||||
|
CreateImageOutputFormatPNG = "png"
|
||||||
|
CreateImageOutputFormatJPEG = "jpeg"
|
||||||
|
CreateImageOutputFormatWEBP = "webp"
|
||||||
|
)
|
||||||
|
|
||||||
// ImageRequest represents the request structure for the image API.
|
// ImageRequest represents the request structure for the image API.
|
||||||
type ImageRequest struct {
|
type ImageRequest struct {
|
||||||
Prompt string `json:"prompt,omitempty"`
|
Prompt string `json:"prompt,omitempty"`
|
||||||
@@ -48,16 +79,35 @@ type ImageRequest struct {
|
|||||||
Style string `json:"style,omitempty"`
|
Style string `json:"style,omitempty"`
|
||||||
ResponseFormat string `json:"response_format,omitempty"`
|
ResponseFormat string `json:"response_format,omitempty"`
|
||||||
User string `json:"user,omitempty"`
|
User string `json:"user,omitempty"`
|
||||||
|
Background string `json:"background,omitempty"`
|
||||||
|
Moderation string `json:"moderation,omitempty"`
|
||||||
|
OutputCompression int `json:"output_compression,omitempty"`
|
||||||
|
OutputFormat string `json:"output_format,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ImageResponse represents a response structure for image API.
|
// ImageResponse represents a response structure for image API.
|
||||||
type ImageResponse struct {
|
type ImageResponse struct {
|
||||||
Created int64 `json:"created,omitempty"`
|
Created int64 `json:"created,omitempty"`
|
||||||
Data []ImageResponseDataInner `json:"data,omitempty"`
|
Data []ImageResponseDataInner `json:"data,omitempty"`
|
||||||
|
Usage ImageResponseUsage `json:"usage,omitempty"`
|
||||||
|
|
||||||
httpHeader
|
httpHeader
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ImageResponseInputTokensDetails represents the token breakdown for input tokens.
|
||||||
|
type ImageResponseInputTokensDetails struct {
|
||||||
|
TextTokens int `json:"text_tokens,omitempty"`
|
||||||
|
ImageTokens int `json:"image_tokens,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImageResponseUsage represents the token usage information for image API.
|
||||||
|
type ImageResponseUsage struct {
|
||||||
|
TotalTokens int `json:"total_tokens,omitempty"`
|
||||||
|
InputTokens int `json:"input_tokens,omitempty"`
|
||||||
|
OutputTokens int `json:"output_tokens,omitempty"`
|
||||||
|
InputTokensDetails ImageResponseInputTokensDetails `json:"input_tokens_details,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// ImageResponseDataInner represents a response data structure for image API.
|
// ImageResponseDataInner represents a response data structure for image API.
|
||||||
type ImageResponseDataInner struct {
|
type ImageResponseDataInner struct {
|
||||||
URL string `json:"url,omitempty"`
|
URL string `json:"url,omitempty"`
|
||||||
@@ -84,13 +134,15 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons
|
|||||||
|
|
||||||
// ImageEditRequest represents the request structure for the image API.
|
// ImageEditRequest represents the request structure for the image API.
|
||||||
type ImageEditRequest struct {
|
type ImageEditRequest struct {
|
||||||
Image *os.File `json:"image,omitempty"`
|
Image io.Reader `json:"image,omitempty"`
|
||||||
Mask *os.File `json:"mask,omitempty"`
|
Mask io.Reader `json:"mask,omitempty"`
|
||||||
Prompt string `json:"prompt,omitempty"`
|
Prompt string `json:"prompt,omitempty"`
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
N int `json:"n,omitempty"`
|
N int `json:"n,omitempty"`
|
||||||
Size string `json:"size,omitempty"`
|
Size string `json:"size,omitempty"`
|
||||||
ResponseFormat string `json:"response_format,omitempty"`
|
ResponseFormat string `json:"response_format,omitempty"`
|
||||||
|
Quality string `json:"quality,omitempty"`
|
||||||
|
User string `json:"user,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API.
|
// CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API.
|
||||||
@@ -98,15 +150,16 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
|
|||||||
body := &bytes.Buffer{}
|
body := &bytes.Buffer{}
|
||||||
builder := c.createFormBuilder(body)
|
builder := c.createFormBuilder(body)
|
||||||
|
|
||||||
// image
|
// image, filename is not required
|
||||||
err = builder.CreateFormFile("image", request.Image)
|
err = builder.CreateFormFileReader("image", request.Image, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// mask, it is optional
|
// mask, it is optional
|
||||||
if request.Mask != nil {
|
if request.Mask != nil {
|
||||||
err = builder.CreateFormFile("mask", request.Mask)
|
// mask, filename is not required
|
||||||
|
err = builder.CreateFormFileReader("mask", request.Mask, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -154,11 +207,12 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
|
|||||||
|
|
||||||
// ImageVariRequest represents the request structure for the image API.
|
// ImageVariRequest represents the request structure for the image API.
|
||||||
type ImageVariRequest struct {
|
type ImageVariRequest struct {
|
||||||
Image *os.File `json:"image,omitempty"`
|
Image io.Reader `json:"image,omitempty"`
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
N int `json:"n,omitempty"`
|
N int `json:"n,omitempty"`
|
||||||
Size string `json:"size,omitempty"`
|
Size string `json:"size,omitempty"`
|
||||||
ResponseFormat string `json:"response_format,omitempty"`
|
ResponseFormat string `json:"response_format,omitempty"`
|
||||||
|
User string `json:"user,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API.
|
// CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API.
|
||||||
@@ -167,8 +221,8 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
|
|||||||
body := &bytes.Buffer{}
|
body := &bytes.Buffer{}
|
||||||
builder := c.createFormBuilder(body)
|
builder := c.createFormBuilder(body)
|
||||||
|
|
||||||
// image
|
// image, filename is not required
|
||||||
err = builder.CreateFormFile("image", request.Image)
|
err = builder.CreateFormFileReader("image", request.Image, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,8 +11,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestImages(t *testing.T) {
|
func TestImages(t *testing.T) {
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
package openai //nolint:testpackage // testing private field
|
package openai //nolint:testpackage // testing private field
|
||||||
|
|
||||||
import (
|
import (
|
||||||
utils "github.com/sashabaranov/go-openai/internal"
|
utils "git.vaala.cloud/VaalaCat/go-openai/internal"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -54,13 +54,13 @@ func TestImageFormBuilderFailures(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
mockFailedErr := fmt.Errorf("mock form builder fail")
|
mockFailedErr := fmt.Errorf("mock form builder fail")
|
||||||
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
|
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
|
||||||
return mockFailedErr
|
return mockFailedErr
|
||||||
}
|
}
|
||||||
_, err := client.CreateEditImage(ctx, req)
|
_, err := client.CreateEditImage(ctx, req)
|
||||||
checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
|
checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
|
||||||
|
|
||||||
mockBuilder.mockCreateFormFile = func(name string, _ *os.File) error {
|
mockBuilder.mockCreateFormFileReader = func(name string, _ io.Reader, _ string) error {
|
||||||
if name == "mask" {
|
if name == "mask" {
|
||||||
return mockFailedErr
|
return mockFailedErr
|
||||||
}
|
}
|
||||||
@@ -119,13 +119,13 @@ func TestVariImageFormBuilderFailures(t *testing.T) {
|
|||||||
req := ImageVariRequest{}
|
req := ImageVariRequest{}
|
||||||
|
|
||||||
mockFailedErr := fmt.Errorf("mock form builder fail")
|
mockFailedErr := fmt.Errorf("mock form builder fail")
|
||||||
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
|
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
|
||||||
return mockFailedErr
|
return mockFailedErr
|
||||||
}
|
}
|
||||||
_, err := client.CreateVariImage(ctx, req)
|
_, err := client.CreateVariImage(ctx, req)
|
||||||
checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails")
|
checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails")
|
||||||
|
|
||||||
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
|
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
utils "github.com/sashabaranov/go-openai/internal"
|
utils "git.vaala.cloud/VaalaCat/go-openai/internal"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestErrorAccumulatorBytes(t *testing.T) {
|
func TestErrorAccumulatorBytes(t *testing.T) {
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
|
"net/textproto"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type FormBuilder interface {
|
type FormBuilder interface {
|
||||||
@@ -30,8 +32,37 @@ func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) er
|
|||||||
return fb.createFormFile(fieldname, file, file.Name())
|
return fb.createFormFile(fieldname, file, file.Name())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
|
||||||
|
|
||||||
|
func escapeQuotes(s string) string {
|
||||||
|
return quoteEscaper.Replace(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateFormFileReader creates a form field with a file reader.
|
||||||
|
// The filename in parameters can be an empty string.
|
||||||
|
// The filename in Content-Disposition is required, But it can be an empty string.
|
||||||
func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
|
func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
|
||||||
return fb.createFormFile(fieldname, r, path.Base(filename))
|
h := make(textproto.MIMEHeader)
|
||||||
|
h.Set(
|
||||||
|
"Content-Disposition",
|
||||||
|
fmt.Sprintf(
|
||||||
|
`form-data; name="%s"; filename="%s"`,
|
||||||
|
escapeQuotes(fieldname),
|
||||||
|
escapeQuotes(filepath.Base(filename)),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
fieldWriter, err := fb.writer.CreatePart(h)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = io.Copy(fieldWriter, r)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error {
|
func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error {
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package openai //nolint:testpackage // testing private field
|
package openai //nolint:testpackage // testing private field
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
@@ -43,3 +43,32 @@ func TestFormBuilderWithClosedFile(t *testing.T) {
|
|||||||
checks.HasError(t, err, "formbuilder should return error if file is closed")
|
checks.HasError(t, err, "formbuilder should return error if file is closed")
|
||||||
checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed")
|
checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type failingReader struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
var errMockFailingReaderError = errors.New("mock reader failed")
|
||||||
|
|
||||||
|
func (*failingReader) Read([]byte) (int, error) {
|
||||||
|
return 0, errMockFailingReaderError
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormBuilderWithReader(t *testing.T) {
|
||||||
|
file, err := os.CreateTemp(t.TempDir(), "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error creating tmp file: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
builder := NewFormBuilder(&failingWriter{})
|
||||||
|
err = builder.CreateFormFileReader("file", file, file.Name())
|
||||||
|
checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails")
|
||||||
|
|
||||||
|
builder = NewFormBuilder(&bytes.Buffer{})
|
||||||
|
reader := &failingReader{}
|
||||||
|
err = builder.CreateFormFileReader("file", reader, "")
|
||||||
|
checks.ErrorIs(t, err, errMockFailingReaderError, "formbuilder should return error if copy reader fails")
|
||||||
|
|
||||||
|
successReader := &bytes.Buffer{}
|
||||||
|
err = builder.CreateFormFileReader("file", successReader, "")
|
||||||
|
checks.NoError(t, err, "formbuilder should not return error")
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package test
|
package test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
|||||||
@@ -46,6 +46,8 @@ type Definition struct {
|
|||||||
// additionalProperties: false
|
// additionalProperties: false
|
||||||
// additionalProperties: jsonschema.Definition{Type: jsonschema.String}
|
// additionalProperties: jsonschema.Definition{Type: jsonschema.String}
|
||||||
AdditionalProperties any `json:"additionalProperties,omitempty"`
|
AdditionalProperties any `json:"additionalProperties,omitempty"`
|
||||||
|
// Whether the schema is nullable or not.
|
||||||
|
Nullable bool `json:"nullable,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Definition) MarshalJSON() ([]byte, error) {
|
func (d *Definition) MarshalJSON() ([]byte, error) {
|
||||||
@@ -124,9 +126,12 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) {
|
|||||||
}
|
}
|
||||||
jsonTag := field.Tag.Get("json")
|
jsonTag := field.Tag.Get("json")
|
||||||
var required = true
|
var required = true
|
||||||
if jsonTag == "" {
|
switch {
|
||||||
|
case jsonTag == "-":
|
||||||
|
continue
|
||||||
|
case jsonTag == "":
|
||||||
jsonTag = field.Name
|
jsonTag = field.Name
|
||||||
} else if strings.HasSuffix(jsonTag, ",omitempty") {
|
case strings.HasSuffix(jsonTag, ",omitempty"):
|
||||||
jsonTag = strings.TrimSuffix(jsonTag, ",omitempty")
|
jsonTag = strings.TrimSuffix(jsonTag, ",omitempty")
|
||||||
required = false
|
required = false
|
||||||
}
|
}
|
||||||
@@ -139,6 +144,16 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) {
|
|||||||
if description != "" {
|
if description != "" {
|
||||||
item.Description = description
|
item.Description = description
|
||||||
}
|
}
|
||||||
|
enum := field.Tag.Get("enum")
|
||||||
|
if enum != "" {
|
||||||
|
item.Enum = strings.Split(enum, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
if n := field.Tag.Get("nullable"); n != "" {
|
||||||
|
nullable, _ := strconv.ParseBool(n)
|
||||||
|
item.Nullable = nullable
|
||||||
|
}
|
||||||
|
|
||||||
properties[jsonTag] = *item
|
properties[jsonTag] = *item
|
||||||
|
|
||||||
if s := field.Tag.Get("required"); s != "" {
|
if s := field.Tag.Get("required"); s != "" {
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
"git.vaala.cloud/VaalaCat/go-openai/jsonschema"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDefinition_MarshalJSON(t *testing.T) {
|
func TestDefinition_MarshalJSON(t *testing.T) {
|
||||||
@@ -17,7 +17,7 @@ func TestDefinition_MarshalJSON(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Test with empty Definition",
|
name: "Test with empty Definition",
|
||||||
def: jsonschema.Definition{},
|
def: jsonschema.Definition{},
|
||||||
want: `{"properties":{}}`,
|
want: `{}`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Test with Definition properties set",
|
name: "Test with Definition properties set",
|
||||||
@@ -35,11 +35,10 @@ func TestDefinition_MarshalJSON(t *testing.T) {
|
|||||||
"description":"A string type",
|
"description":"A string type",
|
||||||
"properties":{
|
"properties":{
|
||||||
"name":{
|
"name":{
|
||||||
"type":"string",
|
"type":"string"
|
||||||
"properties":{}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}`,
|
}`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Test with nested Definition properties",
|
name: "Test with nested Definition properties",
|
||||||
@@ -66,17 +65,15 @@ func TestDefinition_MarshalJSON(t *testing.T) {
|
|||||||
"type":"object",
|
"type":"object",
|
||||||
"properties":{
|
"properties":{
|
||||||
"name":{
|
"name":{
|
||||||
"type":"string",
|
"type":"string"
|
||||||
"properties":{}
|
|
||||||
},
|
},
|
||||||
"age":{
|
"age":{
|
||||||
"type":"integer",
|
"type":"integer"
|
||||||
"properties":{}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}`,
|
}`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Test with complex nested Definition",
|
name: "Test with complex nested Definition",
|
||||||
@@ -114,30 +111,26 @@ func TestDefinition_MarshalJSON(t *testing.T) {
|
|||||||
"type":"object",
|
"type":"object",
|
||||||
"properties":{
|
"properties":{
|
||||||
"name":{
|
"name":{
|
||||||
"type":"string",
|
"type":"string"
|
||||||
"properties":{}
|
|
||||||
},
|
},
|
||||||
"age":{
|
"age":{
|
||||||
"type":"integer",
|
"type":"integer"
|
||||||
"properties":{}
|
|
||||||
},
|
},
|
||||||
"address":{
|
"address":{
|
||||||
"type":"object",
|
"type":"object",
|
||||||
"properties":{
|
"properties":{
|
||||||
"city":{
|
"city":{
|
||||||
"type":"string",
|
"type":"string"
|
||||||
"properties":{}
|
|
||||||
},
|
},
|
||||||
"country":{
|
"country":{
|
||||||
"type":"string",
|
"type":"string"
|
||||||
"properties":{}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}`,
|
}`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Test with Array type Definition",
|
name: "Test with Array type Definition",
|
||||||
@@ -155,18 +148,14 @@ func TestDefinition_MarshalJSON(t *testing.T) {
|
|||||||
want: `{
|
want: `{
|
||||||
"type":"array",
|
"type":"array",
|
||||||
"items":{
|
"items":{
|
||||||
"type":"string",
|
"type":"string"
|
||||||
"properties":{
|
|
||||||
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"properties":{
|
"properties":{
|
||||||
"name":{
|
"name":{
|
||||||
"type":"string",
|
"type":"string"
|
||||||
"properties":{}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}`,
|
}`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -193,6 +182,232 @@ func TestDefinition_MarshalJSON(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStructToSchema(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
in any
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Test with empty struct",
|
||||||
|
in: struct{}{},
|
||||||
|
want: `{
|
||||||
|
"type":"object",
|
||||||
|
"additionalProperties":false
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Test with struct containing many fields",
|
||||||
|
in: struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Age int `json:"age"`
|
||||||
|
Active bool `json:"active"`
|
||||||
|
Height float64 `json:"height"`
|
||||||
|
Cities []struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
State string `json:"state"`
|
||||||
|
} `json:"cities"`
|
||||||
|
}{
|
||||||
|
Name: "John Doe",
|
||||||
|
Age: 30,
|
||||||
|
Cities: []struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
State string `json:"state"`
|
||||||
|
}{
|
||||||
|
{Name: "New York", State: "NY"},
|
||||||
|
{Name: "Los Angeles", State: "CA"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: `{
|
||||||
|
"type":"object",
|
||||||
|
"properties":{
|
||||||
|
"name":{
|
||||||
|
"type":"string"
|
||||||
|
},
|
||||||
|
"age":{
|
||||||
|
"type":"integer"
|
||||||
|
},
|
||||||
|
"active":{
|
||||||
|
"type":"boolean"
|
||||||
|
},
|
||||||
|
"height":{
|
||||||
|
"type":"number"
|
||||||
|
},
|
||||||
|
"cities":{
|
||||||
|
"type":"array",
|
||||||
|
"items":{
|
||||||
|
"additionalProperties":false,
|
||||||
|
"type":"object",
|
||||||
|
"properties":{
|
||||||
|
"name":{
|
||||||
|
"type":"string"
|
||||||
|
},
|
||||||
|
"state":{
|
||||||
|
"type":"string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required":["name","state"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required":["name","age","active","height","cities"],
|
||||||
|
"additionalProperties":false
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Test with description tag",
|
||||||
|
in: struct {
|
||||||
|
Name string `json:"name" description:"The name of the person"`
|
||||||
|
}{
|
||||||
|
Name: "John Doe",
|
||||||
|
},
|
||||||
|
want: `{
|
||||||
|
"type":"object",
|
||||||
|
"properties":{
|
||||||
|
"name":{
|
||||||
|
"type":"string",
|
||||||
|
"description":"The name of the person"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required":["name"],
|
||||||
|
"additionalProperties":false
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Test with required tag",
|
||||||
|
in: struct {
|
||||||
|
Name string `json:"name" required:"false"`
|
||||||
|
}{
|
||||||
|
Name: "John Doe",
|
||||||
|
},
|
||||||
|
want: `{
|
||||||
|
"type":"object",
|
||||||
|
"properties":{
|
||||||
|
"name":{
|
||||||
|
"type":"string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties":false
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Test with enum tag",
|
||||||
|
in: struct {
|
||||||
|
Color string `json:"color" enum:"red,green,blue"`
|
||||||
|
}{
|
||||||
|
Color: "red",
|
||||||
|
},
|
||||||
|
want: `{
|
||||||
|
"type":"object",
|
||||||
|
"properties":{
|
||||||
|
"color":{
|
||||||
|
"type":"string",
|
||||||
|
"enum":["red","green","blue"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required":["color"],
|
||||||
|
"additionalProperties":false
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Test with nullable tag",
|
||||||
|
in: struct {
|
||||||
|
Name *string `json:"name" nullable:"true"`
|
||||||
|
}{
|
||||||
|
Name: nil,
|
||||||
|
},
|
||||||
|
want: `{
|
||||||
|
|
||||||
|
"type":"object",
|
||||||
|
"properties":{
|
||||||
|
"name":{
|
||||||
|
"type":"string",
|
||||||
|
"nullable":true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required":["name"],
|
||||||
|
"additionalProperties":false
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Test with exclude mark",
|
||||||
|
in: struct {
|
||||||
|
Name string `json:"-"`
|
||||||
|
}{
|
||||||
|
Name: "Name",
|
||||||
|
},
|
||||||
|
want: `{
|
||||||
|
"type":"object",
|
||||||
|
"additionalProperties":false
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Test with no json tag",
|
||||||
|
in: struct {
|
||||||
|
Name string
|
||||||
|
}{
|
||||||
|
Name: "",
|
||||||
|
},
|
||||||
|
want: `{
|
||||||
|
"type":"object",
|
||||||
|
"properties":{
|
||||||
|
"Name":{
|
||||||
|
"type":"string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required":["Name"],
|
||||||
|
"additionalProperties":false
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Test with omitempty tag",
|
||||||
|
in: struct {
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
}{
|
||||||
|
Name: "",
|
||||||
|
},
|
||||||
|
want: `{
|
||||||
|
"type":"object",
|
||||||
|
"properties":{
|
||||||
|
"name":{
|
||||||
|
"type":"string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties":false
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
wantBytes := []byte(tt.want)
|
||||||
|
|
||||||
|
schema, err := jsonschema.GenerateSchemaForType(tt.in)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to generate schema: error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var want map[string]interface{}
|
||||||
|
err = json.Unmarshal(wantBytes, &want)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to Unmarshal JSON: error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
got := structToMap(t, schema)
|
||||||
|
gotPtr := structToMap(t, &schema)
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Errorf("MarshalJSON() got = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(gotPtr, want) {
|
||||||
|
t.Errorf("MarshalJSON() gotPtr = %v, want %v", gotPtr, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func structToMap(t *testing.T, v any) map[string]any {
|
func structToMap(t *testing.T, v any) map[string]any {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
gotBytes, err := json.Marshal(v)
|
gotBytes, err := json.Marshal(v)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package jsonschema_test
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
"git.vaala.cloud/VaalaCat/go-openai/jsonschema"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_Validate(t *testing.T) {
|
func Test_Validate(t *testing.T) {
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
var emptyStr = ""
|
var emptyStr = ""
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
const testFineTuneModelID = "fine-tune-model-id"
|
const testFineTuneModelID = "fine-tune-model-id"
|
||||||
@@ -47,6 +47,24 @@ func TestGetModel(t *testing.T) {
|
|||||||
checks.NoError(t, err, "GetModel error")
|
checks.NoError(t, err, "GetModel error")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestGetModelO3 Tests the retrieve O3 model endpoint of the API using the mocked server.
|
||||||
|
func TestGetModelO3(t *testing.T) {
|
||||||
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/models/o3", handleGetModelEndpoint)
|
||||||
|
_, err := client.GetModel(context.Background(), "o3")
|
||||||
|
checks.NoError(t, err, "GetModel error for O3")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetModelO4Mini Tests the retrieve O4Mini model endpoint of the API using the mocked server.
|
||||||
|
func TestGetModelO4Mini(t *testing.T) {
|
||||||
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/models/o4-mini", handleGetModelEndpoint)
|
||||||
|
_, err := client.GetModel(context.Background(), "o4-mini")
|
||||||
|
checks.NoError(t, err, "GetModel error for O4Mini")
|
||||||
|
}
|
||||||
|
|
||||||
func TestAzureGetModel(t *testing.T) {
|
func TestAzureGetModel(t *testing.T) {
|
||||||
client, server, teardown := setupAzureTestServer()
|
client, server, teardown := setupAzureTestServer()
|
||||||
defer teardown()
|
defer teardown()
|
||||||
|
|||||||
@@ -11,8 +11,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestModeration Tests the moderations endpoint of the API using the mocked server.
|
// TestModeration Tests the moderations endpoint of the API using the mocked server.
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
package openai_test
|
package openai_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupOpenAITestServer() (client *openai.Client, server *test.ServerTest, teardown func()) {
|
func setupOpenAITestServer() (client *openai.Client, server *test.ServerTest, teardown func()) {
|
||||||
@@ -29,7 +29,7 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea
|
|||||||
|
|
||||||
// numTokens Returns the number of GPT-3 encoded tokens in the given text.
|
// numTokens Returns the number of GPT-3 encoded tokens in the given text.
|
||||||
// This function approximates based on the rule of thumb stated by OpenAI:
|
// This function approximates based on the rule of thumb stated by OpenAI:
|
||||||
// https://beta.openai.com/tokenizer/
|
// https://beta.openai.com/tokenizer.
|
||||||
//
|
//
|
||||||
// TODO: implement an actual tokenizer for GPT-3 and Codex (once available).
|
// TODO: implement an actual tokenizer for GPT-3 and Codex (once available).
|
||||||
func numTokens(s string) int {
|
func numTokens(s string) int {
|
||||||
|
|||||||
@@ -28,15 +28,6 @@ var (
|
|||||||
ErrReasoningModelLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
|
ErrReasoningModelLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
|
||||||
)
|
)
|
||||||
|
|
||||||
var unsupportedToolsForO1Models = map[ToolType]struct{}{
|
|
||||||
ToolTypeFunction: {},
|
|
||||||
}
|
|
||||||
|
|
||||||
var availableMessageRoleForO1Models = map[string]struct{}{
|
|
||||||
ChatMessageRoleUser: {},
|
|
||||||
ChatMessageRoleAssistant: {},
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReasoningValidator handles validation for o-series model requests.
|
// ReasoningValidator handles validation for o-series model requests.
|
||||||
type ReasoningValidator struct{}
|
type ReasoningValidator struct{}
|
||||||
|
|
||||||
@@ -49,8 +40,9 @@ func NewReasoningValidator() *ReasoningValidator {
|
|||||||
func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error {
|
func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error {
|
||||||
o1Series := strings.HasPrefix(request.Model, "o1")
|
o1Series := strings.HasPrefix(request.Model, "o1")
|
||||||
o3Series := strings.HasPrefix(request.Model, "o3")
|
o3Series := strings.HasPrefix(request.Model, "o3")
|
||||||
|
o4Series := strings.HasPrefix(request.Model, "o4")
|
||||||
|
|
||||||
if !o1Series && !o3Series {
|
if !o1Series && !o3Series && !o4Series {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -58,12 +50,6 @@ func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if o1Series {
|
|
||||||
if err := v.validateO1Specific(request); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -93,19 +79,3 @@ func (v *ReasoningValidator) validateReasoningModelParams(request ChatCompletion
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateO1Specific checks O1-specific limitations.
|
|
||||||
func (v *ReasoningValidator) validateO1Specific(request ChatCompletionRequest) error {
|
|
||||||
for _, m := range request.Messages {
|
|
||||||
if _, found := availableMessageRoleForO1Models[m.Role]; !found {
|
|
||||||
return ErrO1BetaLimitationsMessageTypes
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, t := range request.Tools {
|
|
||||||
if _, found := unsupportedToolsForO1Models[t.Type]; found {
|
|
||||||
return ErrO1BetaLimitationsTools
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ package openai_test
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
openai "github.com/sashabaranov/go-openai"
|
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|||||||
@@ -11,17 +11,22 @@ const (
|
|||||||
TTSModel1 SpeechModel = "tts-1"
|
TTSModel1 SpeechModel = "tts-1"
|
||||||
TTSModel1HD SpeechModel = "tts-1-hd"
|
TTSModel1HD SpeechModel = "tts-1-hd"
|
||||||
TTSModelCanary SpeechModel = "canary-tts"
|
TTSModelCanary SpeechModel = "canary-tts"
|
||||||
|
TTSModelGPT4oMini SpeechModel = "gpt-4o-mini-tts"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SpeechVoice string
|
type SpeechVoice string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
VoiceAlloy SpeechVoice = "alloy"
|
VoiceAlloy SpeechVoice = "alloy"
|
||||||
|
VoiceAsh SpeechVoice = "ash"
|
||||||
|
VoiceBallad SpeechVoice = "ballad"
|
||||||
|
VoiceCoral SpeechVoice = "coral"
|
||||||
VoiceEcho SpeechVoice = "echo"
|
VoiceEcho SpeechVoice = "echo"
|
||||||
VoiceFable SpeechVoice = "fable"
|
VoiceFable SpeechVoice = "fable"
|
||||||
VoiceOnyx SpeechVoice = "onyx"
|
VoiceOnyx SpeechVoice = "onyx"
|
||||||
VoiceNova SpeechVoice = "nova"
|
VoiceNova SpeechVoice = "nova"
|
||||||
VoiceShimmer SpeechVoice = "shimmer"
|
VoiceShimmer SpeechVoice = "shimmer"
|
||||||
|
VoiceVerse SpeechVoice = "verse"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SpeechResponseFormat string
|
type SpeechResponseFormat string
|
||||||
@@ -39,6 +44,7 @@ type CreateSpeechRequest struct {
|
|||||||
Model SpeechModel `json:"model"`
|
Model SpeechModel `json:"model"`
|
||||||
Input string `json:"input"`
|
Input string `json:"input"`
|
||||||
Voice SpeechVoice `json:"voice"`
|
Voice SpeechVoice `json:"voice"`
|
||||||
|
Instructions string `json:"instructions,omitempty"` // Optional, Doesnt work with tts-1 or tts-1-hd.
|
||||||
ResponseFormat SpeechResponseFormat `json:"response_format,omitempty"` // Optional, default to mp3
|
ResponseFormat SpeechResponseFormat `json:"response_format,omitempty"` // Optional, default to mp3
|
||||||
Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0
|
Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,9 +11,9 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSpeechIntegration(t *testing.T) {
|
func TestSpeechIntegration(t *testing.T) {
|
||||||
|
|||||||
@@ -6,13 +6,14 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
|
||||||
utils "github.com/sashabaranov/go-openai/internal"
|
utils "git.vaala.cloud/VaalaCat/go-openai/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
headerData = []byte("data: ")
|
headerData = regexp.MustCompile(`^data:\s*`)
|
||||||
errorPrefix = []byte(`data: {"error":`)
|
errorPrefix = regexp.MustCompile(`^data:\s*{"error":`)
|
||||||
)
|
)
|
||||||
|
|
||||||
type streamable interface {
|
type streamable interface {
|
||||||
@@ -70,12 +71,12 @@ func (stream *streamReader[T]) processLines() ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
noSpaceLine := bytes.TrimSpace(rawLine)
|
noSpaceLine := bytes.TrimSpace(rawLine)
|
||||||
if bytes.HasPrefix(noSpaceLine, errorPrefix) {
|
if errorPrefix.Match(noSpaceLine) {
|
||||||
hasErrorPrefix = true
|
hasErrorPrefix = true
|
||||||
}
|
}
|
||||||
if !bytes.HasPrefix(noSpaceLine, headerData) || hasErrorPrefix {
|
if !headerData.Match(noSpaceLine) || hasErrorPrefix {
|
||||||
if hasErrorPrefix {
|
if hasErrorPrefix {
|
||||||
noSpaceLine = bytes.TrimPrefix(noSpaceLine, headerData)
|
noSpaceLine = headerData.ReplaceAll(noSpaceLine, nil)
|
||||||
}
|
}
|
||||||
writeErr := stream.errAccumulator.Write(noSpaceLine)
|
writeErr := stream.errAccumulator.Write(noSpaceLine)
|
||||||
if writeErr != nil {
|
if writeErr != nil {
|
||||||
@@ -89,7 +90,7 @@ func (stream *streamReader[T]) processLines() ([]byte, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData)
|
noPrefixLine := headerData.ReplaceAll(noSpaceLine, nil)
|
||||||
if string(noPrefixLine) == "[DONE]" {
|
if string(noPrefixLine) == "[DONE]" {
|
||||||
stream.isFinished = true
|
stream.isFinished = true
|
||||||
return nil, io.EOF
|
return nil, io.EOF
|
||||||
|
|||||||
@@ -6,9 +6,9 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
utils "github.com/sashabaranov/go-openai/internal"
|
utils "git.vaala.cloud/VaalaCat/go-openai/internal"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errTestUnmarshalerFailed = errors.New("test unmarshaler failed")
|
var errTestUnmarshalerFailed = errors.New("test unmarshaler failed")
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCompletionsStreamWrongModel(t *testing.T) {
|
func TestCompletionsStreamWrongModel(t *testing.T) {
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
openai "github.com/sashabaranov/go-openai"
|
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestThread Tests the thread endpoint of the API using the mocked server.
|
// TestThread Tests the thread endpoint of the API using the mocked server.
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ package openai_test
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
openai "github.com/sashabaranov/go-openai"
|
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||||
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|||||||
Reference in New Issue
Block a user