Compare commits
37 Commits
74ed75f291
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2436e7afb8 | ||
|
|
67f3b169df | ||
|
|
3640274cd1 | ||
|
|
ff9d83a485 | ||
|
|
8c65b35c57 | ||
|
|
4d2e7ab29d | ||
|
|
6aaa732296 | ||
|
|
0116f2994d | ||
|
|
8ba38f6ba1 | ||
|
|
6181facea7 | ||
|
|
77ccac8d34 | ||
|
|
5ea214a188 | ||
|
|
d65f0cb54e | ||
|
|
93a611cf4f | ||
|
|
6836cf6a6f | ||
|
|
da5f9bc9bc | ||
|
|
bb5bc27567 | ||
|
|
4cccc6c934 | ||
|
|
306fbbbe6f | ||
|
|
658beda2ba | ||
|
|
d68a683815 | ||
|
|
e99eb54c9d | ||
|
|
74d6449f22 | ||
|
|
261721bfdb | ||
|
|
be2e2387d4 | ||
|
|
85f578b865 | ||
|
|
c0a9a75fe0 | ||
|
|
a62919e8c6 | ||
|
|
2054db016c | ||
|
|
45aa99607b | ||
|
|
9823a8bbbd | ||
|
|
7a2915a37d | ||
|
|
2a0ff5ac63 | ||
|
|
56a9acf86f | ||
|
|
af5355f5b1 | ||
|
|
c203ca001f | ||
|
|
21fa42c18d |
2
.github/ISSUE_TEMPLATE/bug_report.md
vendored
2
.github/ISSUE_TEMPLATE/bug_report.md
vendored
@@ -8,7 +8,7 @@ assignees: ''
|
||||
---
|
||||
|
||||
Your issue may already be reported!
|
||||
Please search on the [issue tracker](https://github.com/sashabaranov/go-openai/issues) before creating one.
|
||||
Please search on the [issue tracker](https://git.vaala.cloud/VaalaCat/go-openai/issues) before creating one.
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is. If it's an API-related bug, please provide relevant endpoint(s).
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/feature_request.md
vendored
2
.github/ISSUE_TEMPLATE/feature_request.md
vendored
@@ -8,7 +8,7 @@ assignees: ''
|
||||
---
|
||||
|
||||
Your issue may already be reported!
|
||||
Please search on the [issue tracker](https://github.com/sashabaranov/go-openai/issues) before creating one.
|
||||
Please search on the [issue tracker](https://git.vaala.cloud/VaalaCat/go-openai/issues) before creating one.
|
||||
|
||||
**Is your feature request related to a problem? Please describe.**
|
||||
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
||||
|
||||
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,5 +1,5 @@
|
||||
A similar PR may already be submitted!
|
||||
Please search among the [Pull request](https://github.com/sashabaranov/go-openai/pulls) before creating one.
|
||||
Please search among the [Pull request](https://git.vaala.cloud/VaalaCat/go-openai/pulls) before creating one.
|
||||
|
||||
If your changes introduce breaking changes, please prefix the title of your pull request with "[BREAKING_CHANGES]". This allows for clear identification of such changes in the 'What's Changed' section on the release page, making it developer-friendly.
|
||||
|
||||
|
||||
12
.github/workflows/pr.yml
vendored
12
.github/workflows/pr.yml
vendored
@@ -13,15 +13,17 @@ jobs:
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.21'
|
||||
go-version: '1.24'
|
||||
- name: Run vet
|
||||
run: |
|
||||
go vet .
|
||||
- name: Run golangci-lint
|
||||
uses: golangci/golangci-lint-action@v4
|
||||
uses: golangci/golangci-lint-action@v7
|
||||
with:
|
||||
version: latest
|
||||
version: v2.1.5
|
||||
- name: Run tests
|
||||
run: go test -race -covermode=atomic -coverprofile=coverage.out -v .
|
||||
run: go test -race -covermode=atomic -coverprofile=coverage.out -v ./...
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -17,3 +17,6 @@
|
||||
# Auth token for tests
|
||||
.openai-token
|
||||
.idea
|
||||
|
||||
# Generated by tests
|
||||
test.mp3
|
||||
434
.golangci.yml
434
.golangci.yml
@@ -1,272 +1,168 @@
|
||||
## Golden config for golangci-lint v1.47.3
|
||||
#
|
||||
# This is the best config for golangci-lint based on my experience and opinion.
|
||||
# It is very strict, but not extremely strict.
|
||||
# Feel free to adopt and change it for your needs.
|
||||
|
||||
run:
|
||||
# Timeout for analysis, e.g. 30s, 5m.
|
||||
# Default: 1m
|
||||
timeout: 3m
|
||||
|
||||
|
||||
# This file contains only configs which differ from defaults.
|
||||
# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
|
||||
linters-settings:
|
||||
cyclop:
|
||||
# The maximal code complexity to report.
|
||||
# Default: 10
|
||||
max-complexity: 30
|
||||
# The maximal average package complexity.
|
||||
# If it's higher than 0.0 (float) the check is enabled
|
||||
# Default: 0.0
|
||||
package-average: 10.0
|
||||
|
||||
errcheck:
|
||||
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
|
||||
# Such cases aren't reported by default.
|
||||
# Default: false
|
||||
check-type-assertions: true
|
||||
|
||||
funlen:
|
||||
# Checks the number of lines in a function.
|
||||
# If lower than 0, disable the check.
|
||||
# Default: 60
|
||||
lines: 100
|
||||
# Checks the number of statements in a function.
|
||||
# If lower than 0, disable the check.
|
||||
# Default: 40
|
||||
statements: 50
|
||||
|
||||
gocognit:
|
||||
# Minimal code complexity to report
|
||||
# Default: 30 (but we recommend 10-20)
|
||||
min-complexity: 20
|
||||
|
||||
gocritic:
|
||||
# Settings passed to gocritic.
|
||||
# The settings key is the name of a supported gocritic checker.
|
||||
# The list of supported checkers can be find in https://go-critic.github.io/overview.
|
||||
settings:
|
||||
captLocal:
|
||||
# Whether to restrict checker to params only.
|
||||
# Default: true
|
||||
paramsOnly: false
|
||||
underef:
|
||||
# Whether to skip (*x).method() calls where x is a pointer receiver.
|
||||
# Default: true
|
||||
skipRecvDeref: false
|
||||
|
||||
mnd:
|
||||
# List of function patterns to exclude from analysis.
|
||||
# Values always ignored: `time.Date`
|
||||
# Default: []
|
||||
ignored-functions:
|
||||
- os.Chmod
|
||||
- os.Mkdir
|
||||
- os.MkdirAll
|
||||
- os.OpenFile
|
||||
- os.WriteFile
|
||||
- prometheus.ExponentialBuckets
|
||||
- prometheus.ExponentialBucketsRange
|
||||
- prometheus.LinearBuckets
|
||||
- strconv.FormatFloat
|
||||
- strconv.FormatInt
|
||||
- strconv.FormatUint
|
||||
- strconv.ParseFloat
|
||||
- strconv.ParseInt
|
||||
- strconv.ParseUint
|
||||
|
||||
gomodguard:
|
||||
blocked:
|
||||
# List of blocked modules.
|
||||
# Default: []
|
||||
modules:
|
||||
- github.com/golang/protobuf:
|
||||
recommendations:
|
||||
- google.golang.org/protobuf
|
||||
reason: "see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules"
|
||||
- github.com/satori/go.uuid:
|
||||
recommendations:
|
||||
- github.com/google/uuid
|
||||
reason: "satori's package is not maintained"
|
||||
- github.com/gofrs/uuid:
|
||||
recommendations:
|
||||
- github.com/google/uuid
|
||||
reason: "see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw"
|
||||
|
||||
govet:
|
||||
# Enable all analyzers.
|
||||
# Default: false
|
||||
enable-all: true
|
||||
# Disable analyzers by name.
|
||||
# Run `go tool vet help` to see all analyzers.
|
||||
# Default: []
|
||||
disable:
|
||||
- fieldalignment # too strict
|
||||
# Settings per analyzer.
|
||||
settings:
|
||||
shadow:
|
||||
# Whether to be strict about shadowing; can be noisy.
|
||||
# Default: false
|
||||
strict: true
|
||||
|
||||
nakedret:
|
||||
# Make an issue if func has more lines of code than this setting, and it has naked returns.
|
||||
# Default: 30
|
||||
max-func-lines: 0
|
||||
|
||||
nolintlint:
|
||||
# Exclude following linters from requiring an explanation.
|
||||
# Default: []
|
||||
allow-no-explanation: [ funlen, gocognit, lll ]
|
||||
# Enable to require an explanation of nonzero length after each nolint directive.
|
||||
# Default: false
|
||||
require-explanation: true
|
||||
# Enable to require nolint directives to mention the specific linter being suppressed.
|
||||
# Default: false
|
||||
require-specific: true
|
||||
|
||||
rowserrcheck:
|
||||
# database/sql is always checked
|
||||
# Default: []
|
||||
packages:
|
||||
- github.com/jmoiron/sqlx
|
||||
|
||||
tenv:
|
||||
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
|
||||
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
|
||||
# Default: false
|
||||
all: true
|
||||
|
||||
varcheck:
|
||||
# Check usage of exported fields and variables.
|
||||
# Default: false
|
||||
exported-fields: false # default false # TODO: enable after fixing false positives
|
||||
|
||||
|
||||
version: "2"
|
||||
linters:
|
||||
disable-all: true
|
||||
default: none
|
||||
enable:
|
||||
## enabled by default
|
||||
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
|
||||
- gosimple # Linter for Go source code that specializes in simplifying a code
|
||||
- govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
|
||||
- ineffassign # Detects when assignments to existing variables are not used
|
||||
- staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks
|
||||
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
|
||||
- unused # Checks Go code for unused constants, variables, functions and types
|
||||
## disabled by default
|
||||
# - asasalint # Check for pass []any as any in variadic func(...any)
|
||||
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
|
||||
- bidichk # Checks for dangerous unicode character sequences
|
||||
- bodyclose # checks whether HTTP response body is closed successfully
|
||||
- contextcheck # check the function whether use a non-inherited context
|
||||
- cyclop # checks function and package cyclomatic complexity
|
||||
- dupl # Tool for code clone detection
|
||||
- durationcheck # check for two durations multiplied together
|
||||
- errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error.
|
||||
- errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13.
|
||||
# Removed execinquery (deprecated). execinquery is a linter about query string checker in Query function which reads your Go src files and warning it finds
|
||||
- exhaustive # check exhaustiveness of enum switch statements
|
||||
- exportloopref # checks for pointers to enclosing loop variables
|
||||
- forbidigo # Forbids identifiers
|
||||
- funlen # Tool for detection of long functions
|
||||
# - gochecknoglobals # check that no global variables exist
|
||||
- gochecknoinits # Checks that no init functions are present in Go code
|
||||
- gocognit # Computes and checks the cognitive complexity of functions
|
||||
- goconst # Finds repeated strings that could be replaced by a constant
|
||||
- gocritic # Provides diagnostics that check for bugs, performance and style issues.
|
||||
- gocyclo # Computes and checks the cyclomatic complexity of functions
|
||||
- godot # Check if comments end in a period
|
||||
- goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt.
|
||||
- gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod.
|
||||
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
|
||||
- goprintffuncname # Checks that printf-like functions are named with f at the end
|
||||
- gosec # Inspects source code for security problems
|
||||
- lll # Reports long lines
|
||||
- makezero # Finds slice declarations with non-zero initial length
|
||||
# - nakedret # Finds naked returns in functions greater than a specified function length
|
||||
- mnd # An analyzer to detect magic numbers.
|
||||
- nestif # Reports deeply nested if statements
|
||||
- nilerr # Finds the code that returns nil even if it checks that the error is not nil.
|
||||
- nilnil # Checks that there is no simultaneous return of nil error and an invalid value.
|
||||
# - noctx # noctx finds sending http request without context.Context
|
||||
- nolintlint # Reports ill-formed or insufficient nolint directives
|
||||
# - nonamedreturns # Reports all named returns
|
||||
- nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL.
|
||||
- predeclared # find code that shadows one of Go's predeclared identifiers
|
||||
- promlinter # Check Prometheus metrics naming via promlint
|
||||
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
|
||||
- rowserrcheck # checks whether Err of rows is checked successfully
|
||||
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
|
||||
- stylecheck # Stylecheck is a replacement for golint
|
||||
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17
|
||||
- testpackage # linter that makes you use a separate _test package
|
||||
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
|
||||
- unconvert # Remove unnecessary type conversions
|
||||
- unparam # Reports unused function parameters
|
||||
- wastedassign # wastedassign finds wasted assignment statements.
|
||||
- whitespace # Tool for detection of leading and trailing whitespace
|
||||
## you may want to enable
|
||||
#- decorder # check declaration order and count of types, constants, variables and functions
|
||||
#- exhaustruct # Checks if all structure fields are initialized
|
||||
#- goheader # Checks is file header matches to pattern
|
||||
#- ireturn # Accept Interfaces, Return Concrete Types
|
||||
#- prealloc # [premature optimization, but can be used in some cases] Finds slice declarations that could potentially be preallocated
|
||||
#- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope
|
||||
#- wrapcheck # Checks that errors returned from external packages are wrapped
|
||||
## disabled
|
||||
#- containedctx # containedctx is a linter that detects struct contained context.Context field
|
||||
#- depguard # [replaced by gomodguard] Go linter that checks if package imports are in a list of acceptable packages
|
||||
#- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
|
||||
#- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted.
|
||||
#- forcetypeassert # [replaced by errcheck] finds forced type assertions
|
||||
#- gci # Gci controls golang package import order and makes it always deterministic.
|
||||
#- godox # Tool for detection of FIXME, TODO and other comment keywords
|
||||
#- goerr113 # [too strict] Golang linter to check the errors handling expressions
|
||||
#- gofmt # [replaced by goimports] Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
|
||||
#- gofumpt # [replaced by goimports, gofumports is not available yet] Gofumpt checks whether code was gofumpt-ed.
|
||||
#- grouper # An analyzer to analyze expression groups.
|
||||
#- ifshort # Checks that your code uses short syntax for if-statements whenever possible
|
||||
#- importas # Enforces consistent import aliases
|
||||
#- maintidx # maintidx measures the maintainability index of each function.
|
||||
#- misspell # [useless] Finds commonly misspelled English words in comments
|
||||
#- nlreturn # [too strict and mostly code is not more readable] nlreturn checks for a new line before return and branch statements to increase code clarity
|
||||
#- nosnakecase # Detects snake case of variable naming and function name. # TODO: maybe enable after https://github.com/sivchari/nosnakecase/issues/14
|
||||
#- paralleltest # [too many false positives] paralleltest detects missing usage of t.Parallel() method in your Go test
|
||||
#- tagliatelle # Checks the struct tags.
|
||||
#- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
|
||||
#- wsl # [too strict and mostly code is not more readable] Whitespace Linter - Forces you to use empty lines!
|
||||
## deprecated
|
||||
#- exhaustivestruct # [deprecated, replaced by exhaustruct] Checks if all struct's fields are initialized
|
||||
#- golint # [deprecated, replaced by revive] Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes
|
||||
#- interfacer # [deprecated] Linter that suggests narrower interface types
|
||||
#- maligned # [deprecated, replaced by govet fieldalignment] Tool to detect Go structs that would take less memory if their fields were sorted
|
||||
#- scopelint # [deprecated, replaced by exportloopref] Scopelint checks for unpinned variables in go programs
|
||||
|
||||
|
||||
issues:
|
||||
# Maximum count of issues with the same text.
|
||||
# Set to 0 to disable.
|
||||
# Default: 3
|
||||
max-same-issues: 50
|
||||
|
||||
exclude-rules:
|
||||
- source: "^//\\s*go:generate\\s"
|
||||
linters: [ lll ]
|
||||
- source: "(noinspection|TODO)"
|
||||
linters: [ godot ]
|
||||
- source: "//noinspection"
|
||||
linters: [ gocritic ]
|
||||
- source: "^\\s+if _, ok := err\\.\\([^.]+\\.InternalError\\); ok {"
|
||||
linters: [ errorlint ]
|
||||
- path: "_test\\.go"
|
||||
linters:
|
||||
- bodyclose
|
||||
- dupl
|
||||
- asciicheck
|
||||
- bidichk
|
||||
- bodyclose
|
||||
- contextcheck
|
||||
- cyclop
|
||||
- dupl
|
||||
- durationcheck
|
||||
- errcheck
|
||||
- errname
|
||||
- errorlint
|
||||
- exhaustive
|
||||
- forbidigo
|
||||
- funlen
|
||||
- gochecknoinits
|
||||
- gocognit
|
||||
- goconst
|
||||
- gocritic
|
||||
- gocyclo
|
||||
- godot
|
||||
- gomoddirectives
|
||||
- gomodguard
|
||||
- goprintffuncname
|
||||
- gosec
|
||||
- govet
|
||||
- ineffassign
|
||||
- lll
|
||||
- makezero
|
||||
- mnd
|
||||
- nestif
|
||||
- nilerr
|
||||
- nilnil
|
||||
- nolintlint
|
||||
- nosprintfhostport
|
||||
- predeclared
|
||||
- promlinter
|
||||
- revive
|
||||
- rowserrcheck
|
||||
- sqlclosecheck
|
||||
- staticcheck
|
||||
- testpackage
|
||||
- tparallel
|
||||
- unconvert
|
||||
- unparam
|
||||
- unused
|
||||
- usetesting
|
||||
- wastedassign
|
||||
- whitespace
|
||||
settings:
|
||||
cyclop:
|
||||
max-complexity: 30
|
||||
package-average: 10
|
||||
errcheck:
|
||||
check-type-assertions: true
|
||||
funlen:
|
||||
lines: 100
|
||||
statements: 50
|
||||
gocognit:
|
||||
min-complexity: 20
|
||||
gocritic:
|
||||
settings:
|
||||
captLocal:
|
||||
paramsOnly: false
|
||||
underef:
|
||||
skipRecvDeref: false
|
||||
gomodguard:
|
||||
blocked:
|
||||
modules:
|
||||
- github.com/golang/protobuf:
|
||||
recommendations:
|
||||
- google.golang.org/protobuf
|
||||
reason: see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules
|
||||
- github.com/satori/go.uuid:
|
||||
recommendations:
|
||||
- github.com/google/uuid
|
||||
reason: satori's package is not maintained
|
||||
- github.com/gofrs/uuid:
|
||||
recommendations:
|
||||
- github.com/google/uuid
|
||||
reason: 'see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw'
|
||||
govet:
|
||||
disable:
|
||||
- fieldalignment
|
||||
enable-all: true
|
||||
settings:
|
||||
shadow:
|
||||
strict: true
|
||||
mnd:
|
||||
ignored-functions:
|
||||
- os.Chmod
|
||||
- os.Mkdir
|
||||
- os.MkdirAll
|
||||
- os.OpenFile
|
||||
- os.WriteFile
|
||||
- prometheus.ExponentialBuckets
|
||||
- prometheus.ExponentialBucketsRange
|
||||
- prometheus.LinearBuckets
|
||||
- strconv.FormatFloat
|
||||
- strconv.FormatInt
|
||||
- strconv.FormatUint
|
||||
- strconv.ParseFloat
|
||||
- strconv.ParseInt
|
||||
- strconv.ParseUint
|
||||
nakedret:
|
||||
max-func-lines: 0
|
||||
nolintlint:
|
||||
require-explanation: true
|
||||
require-specific: true
|
||||
allow-no-explanation:
|
||||
- funlen
|
||||
- goconst
|
||||
- gosec
|
||||
- noctx
|
||||
- wrapcheck
|
||||
- gocognit
|
||||
- lll
|
||||
rowserrcheck:
|
||||
packages:
|
||||
- github.com/jmoiron/sqlx
|
||||
exclusions:
|
||||
generated: lax
|
||||
presets:
|
||||
- comments
|
||||
- common-false-positives
|
||||
- legacy
|
||||
- std-error-handling
|
||||
rules:
|
||||
- linters:
|
||||
- forbidigo
|
||||
- mnd
|
||||
- revive
|
||||
path : ^examples/.*\.go$
|
||||
- linters:
|
||||
- lll
|
||||
source: ^//\s*go:generate\s
|
||||
- linters:
|
||||
- godot
|
||||
source: (noinspection|TODO)
|
||||
- linters:
|
||||
- gocritic
|
||||
source: //noinspection
|
||||
- linters:
|
||||
- errorlint
|
||||
source: ^\s+if _, ok := err\.\([^.]+\.InternalError\); ok {
|
||||
- linters:
|
||||
- bodyclose
|
||||
- dupl
|
||||
- funlen
|
||||
- goconst
|
||||
- gosec
|
||||
- noctx
|
||||
- wrapcheck
|
||||
- staticcheck
|
||||
path: _test\.go
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
issues:
|
||||
max-same-issues: 50
|
||||
formatters:
|
||||
enable:
|
||||
- goimports
|
||||
exclusions:
|
||||
generated: lax
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
# Contributing Guidelines
|
||||
|
||||
## Overview
|
||||
Thank you for your interest in contributing to the "Go OpenAI" project! By following this guideline, we hope to ensure that your contributions are made smoothly and efficiently. The Go OpenAI project is licensed under the [Apache 2.0 License](https://github.com/sashabaranov/go-openai/blob/master/LICENSE), and we welcome contributions through GitHub pull requests.
|
||||
Thank you for your interest in contributing to the "Go OpenAI" project! By following this guideline, we hope to ensure that your contributions are made smoothly and efficiently. The Go OpenAI project is licensed under the [Apache 2.0 License](https://git.vaala.cloud/VaalaCat/go-openai/blob/master/LICENSE), and we welcome contributions through GitHub pull requests.
|
||||
|
||||
## Reporting Bugs
|
||||
If you discover a bug, first check the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to see if the issue has already been reported. If you're reporting a new issue, please use the "Bug report" template and provide detailed information about the problem, including steps to reproduce it.
|
||||
If you discover a bug, first check the [GitHub Issues page](https://git.vaala.cloud/VaalaCat/go-openai/issues) to see if the issue has already been reported. If you're reporting a new issue, please use the "Bug report" template and provide detailed information about the problem, including steps to reproduce it.
|
||||
|
||||
## Suggesting Features
|
||||
If you want to suggest a new feature or improvement, first check the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to ensure a similar suggestion hasn't already been made. Use the "Feature request" template to provide a detailed description of your suggestion.
|
||||
If you want to suggest a new feature or improvement, first check the [GitHub Issues page](https://git.vaala.cloud/VaalaCat/go-openai/issues) to ensure a similar suggestion hasn't already been made. Use the "Feature request" template to provide a detailed description of your suggestion.
|
||||
|
||||
## Reporting Vulnerabilities
|
||||
If you identify a security concern, please use the "Report a security vulnerability" template on the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to share the details. This report will only be viewable to repository maintainers. You will be credited if the advisory is published.
|
||||
If you identify a security concern, please use the "Report a security vulnerability" template on the [GitHub Issues page](https://git.vaala.cloud/VaalaCat/go-openai/issues) to share the details. This report will only be viewable to repository maintainers. You will be credited if the advisory is published.
|
||||
|
||||
## Questions for Users
|
||||
If you have questions, please utilize [StackOverflow](https://stackoverflow.com/) or the [GitHub Discussions page](https://github.com/sashabaranov/go-openai/discussions).
|
||||
If you have questions, please utilize [StackOverflow](https://stackoverflow.com/) or the [GitHub Discussions page](https://git.vaala.cloud/VaalaCat/go-openai/discussions).
|
||||
|
||||
## Contributing Code
|
||||
There might already be a similar pull requests submitted! Please search for [pull requests](https://github.com/sashabaranov/go-openai/pulls) before creating one.
|
||||
There might already be a similar pull requests submitted! Please search for [pull requests](https://git.vaala.cloud/VaalaCat/go-openai/pulls) before creating one.
|
||||
|
||||
### Requirements for Merging a Pull Request
|
||||
|
||||
|
||||
106
README.md
106
README.md
@@ -1,19 +1,19 @@
|
||||
# Go OpenAI
|
||||
[](https://pkg.go.dev/github.com/sashabaranov/go-openai)
|
||||
[](https://goreportcard.com/report/github.com/sashabaranov/go-openai)
|
||||
[](https://pkg.go.dev/git.vaala.cloud/VaalaCat/go-openai)
|
||||
[](https://goreportcard.com/report/git.vaala.cloud/VaalaCat/go-openai)
|
||||
[](https://codecov.io/gh/sashabaranov/go-openai)
|
||||
|
||||
This library provides unofficial Go clients for [OpenAI API](https://platform.openai.com/). We support:
|
||||
|
||||
* ChatGPT 4o, o1
|
||||
* GPT-3, GPT-4
|
||||
* DALL·E 2, DALL·E 3
|
||||
* DALL·E 2, DALL·E 3, GPT Image 1
|
||||
* Whisper
|
||||
|
||||
## Installation
|
||||
|
||||
```
|
||||
go get github.com/sashabaranov/go-openai
|
||||
go get git.vaala.cloud/VaalaCat/go-openai
|
||||
```
|
||||
Currently, go-openai requires Go version 1.18 or greater.
|
||||
|
||||
@@ -28,7 +28,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -80,7 +80,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -133,7 +133,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -166,7 +166,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -215,7 +215,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -247,7 +247,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -288,7 +288,7 @@ import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||
"image/png"
|
||||
"os"
|
||||
)
|
||||
@@ -357,6 +357,66 @@ func main() {
|
||||
```
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>GPT Image 1 image generation</summary>
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
func main() {
|
||||
c := openai.NewClient("your token")
|
||||
ctx := context.Background()
|
||||
|
||||
req := openai.ImageRequest{
|
||||
Prompt: "Parrot on a skateboard performing a trick. Large bold text \"SKATE MASTER\" banner at the bottom of the image. Cartoon style, natural light, high detail, 1:1 aspect ratio.",
|
||||
Background: openai.CreateImageBackgroundOpaque,
|
||||
Model: openai.CreateImageModelGptImage1,
|
||||
Size: openai.CreateImageSize1024x1024,
|
||||
N: 1,
|
||||
Quality: openai.CreateImageQualityLow,
|
||||
OutputCompression: 100,
|
||||
OutputFormat: openai.CreateImageOutputFormatJPEG,
|
||||
// Moderation: openai.CreateImageModerationLow,
|
||||
// User: "",
|
||||
}
|
||||
|
||||
resp, err := c.CreateImage(ctx, req)
|
||||
if err != nil {
|
||||
fmt.Printf("Image creation Image generation with GPT Image 1error: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("Image Base64:", resp.Data[0].B64JSON)
|
||||
|
||||
// Decode the base64 data
|
||||
imgBytes, err := base64.StdEncoding.DecodeString(resp.Data[0].B64JSON)
|
||||
if err != nil {
|
||||
fmt.Printf("Base64 decode error: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Write image to file
|
||||
outputPath := "generated_image.jpg"
|
||||
err = os.WriteFile(outputPath, imgBytes, 0644)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to write image file: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("The image was saved as %s\n", outputPath)
|
||||
}
|
||||
```
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Configuring proxy</summary>
|
||||
|
||||
@@ -376,7 +436,7 @@ config.HTTPClient = &http.Client{
|
||||
c := openai.NewClientWithConfig(config)
|
||||
```
|
||||
|
||||
See also: https://pkg.go.dev/github.com/sashabaranov/go-openai#ClientConfig
|
||||
See also: https://pkg.go.dev/git.vaala.cloud/VaalaCat/go-openai#ClientConfig
|
||||
</details>
|
||||
|
||||
<details>
|
||||
@@ -392,7 +452,7 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -446,7 +506,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -492,7 +552,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||
|
||||
)
|
||||
|
||||
@@ -549,7 +609,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -680,7 +740,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -755,8 +815,8 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -828,7 +888,7 @@ Due to the factors mentioned above, different answers may be returned even for t
|
||||
By adopting these strategies, you can expect more consistent results.
|
||||
|
||||
**Related Issues:**
|
||||
[omitempty option of request struct will generate incorrect request when parameter is 0.](https://github.com/sashabaranov/go-openai/issues/9)
|
||||
[omitempty option of request struct will generate incorrect request when parameter is 0.](https://git.vaala.cloud/VaalaCat/go-openai/issues/9)
|
||||
|
||||
### Does Go OpenAI provide a method to count tokens?
|
||||
|
||||
@@ -839,15 +899,15 @@ For counting tokens, you might find the following links helpful:
|
||||
- [How to count tokens with tiktoken](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb)
|
||||
|
||||
**Related Issues:**
|
||||
[Is it possible to join the implementation of GPT3 Tokenizer](https://github.com/sashabaranov/go-openai/issues/62)
|
||||
[Is it possible to join the implementation of GPT3 Tokenizer](https://git.vaala.cloud/VaalaCat/go-openai/issues/62)
|
||||
|
||||
## Contributing
|
||||
|
||||
By following [Contributing Guidelines](https://github.com/sashabaranov/go-openai/blob/master/CONTRIBUTING.md), we hope to ensure that your contributions are made smoothly and efficiently.
|
||||
By following [Contributing Guidelines](https://git.vaala.cloud/VaalaCat/go-openai/blob/master/CONTRIBUTING.md), we hope to ensure that your contributions are made smoothly and efficiently.
|
||||
|
||||
## Thank you
|
||||
|
||||
We want to take a moment to express our deepest gratitude to the [contributors](https://github.com/sashabaranov/go-openai/graphs/contributors) and sponsors of this project:
|
||||
We want to take a moment to express our deepest gratitude to the [contributors](https://git.vaala.cloud/VaalaCat/go-openai/graphs/contributors) and sponsors of this project:
|
||||
- [Carson Kahn](https://carsonkahn.com) of [Spindle AI](https://spindleai.com)
|
||||
|
||||
To all of you: thank you. You've helped us achieve more than we ever imagined possible. Can't wait to see where we go next, together!
|
||||
|
||||
@@ -10,9 +10,9 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
func TestAPI(t *testing.T) {
|
||||
|
||||
@@ -3,8 +3,8 @@ package openai_test
|
||||
import (
|
||||
"context"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
2
audio.go
2
audio.go
@@ -8,7 +8,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
utils "github.com/sashabaranov/go-openai/internal"
|
||||
utils "git.vaala.cloud/VaalaCat/go-openai/internal"
|
||||
)
|
||||
|
||||
// Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI.
|
||||
|
||||
@@ -12,9 +12,9 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
|
||||
@@ -40,12 +40,9 @@ func TestAudio(t *testing.T) {
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
dir, cleanup := test.CreateTestDirectory(t)
|
||||
defer cleanup()
|
||||
|
||||
for _, tc := range testcases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
path := filepath.Join(dir, "fake.mp3")
|
||||
path := filepath.Join(t.TempDir(), "fake.mp3")
|
||||
test.CreateTestFile(t, path)
|
||||
|
||||
req := openai.AudioRequest{
|
||||
@@ -90,12 +87,9 @@ func TestAudioWithOptionalArgs(t *testing.T) {
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
dir, cleanup := test.CreateTestDirectory(t)
|
||||
defer cleanup()
|
||||
|
||||
for _, tc := range testcases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
path := filepath.Join(dir, "fake.mp3")
|
||||
path := filepath.Join(t.TempDir(), "fake.mp3")
|
||||
test.CreateTestFile(t, path)
|
||||
|
||||
req := openai.AudioRequest{
|
||||
|
||||
143
audio_test.go
143
audio_test.go
@@ -2,20 +2,21 @@ package openai //nolint:testpackage // testing private field
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sashabaranov/go-openai/internal/test"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
func TestAudioWithFailingFormBuilder(t *testing.T) {
|
||||
dir, cleanup := test.CreateTestDirectory(t)
|
||||
defer cleanup()
|
||||
path := filepath.Join(dir, "fake.mp3")
|
||||
path := filepath.Join(t.TempDir(), "fake.mp3")
|
||||
test.CreateTestFile(t, path)
|
||||
|
||||
req := AudioRequest{
|
||||
@@ -63,9 +64,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
|
||||
|
||||
func TestCreateFileField(t *testing.T) {
|
||||
t.Run("createFileField failing file", func(t *testing.T) {
|
||||
dir, cleanup := test.CreateTestDirectory(t)
|
||||
defer cleanup()
|
||||
path := filepath.Join(dir, "fake.mp3")
|
||||
path := filepath.Join(t.TempDir(), "fake.mp3")
|
||||
test.CreateTestFile(t, path)
|
||||
|
||||
req := AudioRequest{
|
||||
@@ -111,3 +110,131 @@ func TestCreateFileField(t *testing.T) {
|
||||
checks.HasError(t, err, "createFileField using file should return error when open file fails")
|
||||
})
|
||||
}
|
||||
|
||||
// failingFormBuilder always returns an error when creating form files.
|
||||
type failingFormBuilder struct{ err error }
|
||||
|
||||
func (f *failingFormBuilder) CreateFormFile(_ string, _ *os.File) error {
|
||||
return f.err
|
||||
}
|
||||
|
||||
func (f *failingFormBuilder) CreateFormFileReader(_ string, _ io.Reader, _ string) error {
|
||||
return f.err
|
||||
}
|
||||
|
||||
func (f *failingFormBuilder) WriteField(_, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *failingFormBuilder) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *failingFormBuilder) FormDataContentType() string {
|
||||
return "multipart/form-data"
|
||||
}
|
||||
|
||||
// failingAudioRequestBuilder simulates an error during HTTP request construction.
|
||||
type failingAudioRequestBuilder struct{ err error }
|
||||
|
||||
func (f *failingAudioRequestBuilder) Build(
|
||||
_ context.Context,
|
||||
_, _ string,
|
||||
_ any,
|
||||
_ http.Header,
|
||||
) (*http.Request, error) {
|
||||
return nil, f.err
|
||||
}
|
||||
|
||||
// errorHTTPClient always returns an error when making HTTP calls.
|
||||
type errorHTTPClient struct{ err error }
|
||||
|
||||
func (e *errorHTTPClient) Do(_ *http.Request) (*http.Response, error) {
|
||||
return nil, e.err
|
||||
}
|
||||
|
||||
func TestCallAudioAPIMultipartFormError(t *testing.T) {
|
||||
client := NewClient("test-token")
|
||||
errForm := errors.New("mock create form file failure")
|
||||
// Override form builder to force an error during multipart form creation.
|
||||
client.createFormBuilder = func(_ io.Writer) utils.FormBuilder {
|
||||
return &failingFormBuilder{err: errForm}
|
||||
}
|
||||
|
||||
// Provide a reader so createFileField uses the reader path (no file open).
|
||||
req := AudioRequest{FilePath: "fake.mp3", Reader: bytes.NewBuffer([]byte("dummy")), Model: Whisper1}
|
||||
_, err := client.callAudioAPI(context.Background(), req, "transcriptions")
|
||||
if err == nil {
|
||||
t.Fatal("expected error but got none")
|
||||
}
|
||||
if !errors.Is(err, errForm) {
|
||||
t.Errorf("expected error %v, got %v", errForm, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallAudioAPINewRequestError(t *testing.T) {
|
||||
client := NewClient("test-token")
|
||||
// Create a real temp file so multipart form succeeds.
|
||||
tmp := t.TempDir()
|
||||
path := filepath.Join(tmp, "file.mp3")
|
||||
if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
errBuild := errors.New("mock build failure")
|
||||
client.requestBuilder = &failingAudioRequestBuilder{err: errBuild}
|
||||
|
||||
req := AudioRequest{FilePath: path, Model: Whisper1}
|
||||
_, err := client.callAudioAPI(context.Background(), req, "translations")
|
||||
if err == nil {
|
||||
t.Fatal("expected error but got none")
|
||||
}
|
||||
if !errors.Is(err, errBuild) {
|
||||
t.Errorf("expected error %v, got %v", errBuild, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallAudioAPISendRequestErrorJSON(t *testing.T) {
|
||||
client := NewClient("test-token")
|
||||
// Create a real temp file so multipart form succeeds.
|
||||
tmp := t.TempDir()
|
||||
path := filepath.Join(tmp, "file.mp3")
|
||||
if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
errHTTP := errors.New("mock HTTPClient failure")
|
||||
// Override HTTP client to simulate a network error.
|
||||
client.config.HTTPClient = &errorHTTPClient{err: errHTTP}
|
||||
|
||||
req := AudioRequest{FilePath: path, Model: Whisper1}
|
||||
_, err := client.callAudioAPI(context.Background(), req, "transcriptions")
|
||||
if err == nil {
|
||||
t.Fatal("expected error but got none")
|
||||
}
|
||||
if !errors.Is(err, errHTTP) {
|
||||
t.Errorf("expected error %v, got %v", errHTTP, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallAudioAPISendRequestErrorText(t *testing.T) {
|
||||
client := NewClient("test-token")
|
||||
tmp := t.TempDir()
|
||||
path := filepath.Join(tmp, "file.mp3")
|
||||
if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
errHTTP := errors.New("mock HTTPClient failure")
|
||||
client.config.HTTPClient = &errorHTTPClient{err: errHTTP}
|
||||
|
||||
// Use a non-JSON response format to exercise the text path.
|
||||
req := AudioRequest{FilePath: path, Model: Whisper1, Format: AudioResponseFormatText}
|
||||
_, err := client.callAudioAPI(context.Background(), req, "translations")
|
||||
if err == nil {
|
||||
t.Fatal("expected error but got none")
|
||||
}
|
||||
if !errors.Is(err, errHTTP) {
|
||||
t.Errorf("expected error %v, got %v", errHTTP, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
func TestUploadBatchFile(t *testing.T) {
|
||||
|
||||
100
chat.go
100
chat.go
@@ -14,6 +14,7 @@ const (
|
||||
ChatMessageRoleAssistant = "assistant"
|
||||
ChatMessageRoleFunction = "function"
|
||||
ChatMessageRoleTool = "tool"
|
||||
ChatMessageRoleDeveloper = "developer"
|
||||
)
|
||||
|
||||
const chatCompletionsSuffix = "/chat/completions"
|
||||
@@ -93,7 +94,7 @@ type ChatMessagePart struct {
|
||||
|
||||
type ChatCompletionMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
MultiContent []ChatMessagePart
|
||||
|
||||
@@ -103,6 +104,12 @@ type ChatCompletionMessage struct {
|
||||
// - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
Name string `json:"name,omitempty"`
|
||||
|
||||
// This property is used for the "reasoning" feature supported by deepseek-reasoner
|
||||
// which is not in the official documentation.
|
||||
// the doc from deepseek:
|
||||
// - https://api-docs.deepseek.com/api/create-chat-completion#responses
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
|
||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||
|
||||
// For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls.
|
||||
@@ -118,41 +125,44 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
|
||||
}
|
||||
if len(m.MultiContent) > 0 {
|
||||
msg := struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"-"`
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
MultiContent []ChatMessagePart `json:"content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"-"`
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
MultiContent []ChatMessagePart `json:"content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}(m)
|
||||
return json.Marshal(msg)
|
||||
}
|
||||
|
||||
msg := struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
MultiContent []ChatMessagePart `json:"-"`
|
||||
Name string `json:"name,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
MultiContent []ChatMessagePart `json:"-"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}(m)
|
||||
return json.Marshal(msg)
|
||||
}
|
||||
|
||||
func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
|
||||
msg := struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
MultiContent []ChatMessagePart
|
||||
Name string `json:"name,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
MultiContent []ChatMessagePart
|
||||
Name string `json:"name,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}{}
|
||||
|
||||
if err := json.Unmarshal(bs, &msg); err == nil {
|
||||
@@ -160,14 +170,15 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
|
||||
return nil
|
||||
}
|
||||
multiMsg := struct {
|
||||
Role string `json:"role"`
|
||||
Content string
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
MultiContent []ChatMessagePart `json:"content"`
|
||||
Name string `json:"name,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
Role string `json:"role"`
|
||||
Content string
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
MultiContent []ChatMessagePart `json:"content"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}{}
|
||||
if err := json.Unmarshal(bs, &multiMsg); err != nil {
|
||||
return err
|
||||
@@ -179,7 +190,7 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
|
||||
type ToolCall struct {
|
||||
// Index is not nil only in chat completion chunk object
|
||||
Index *int `json:"index,omitempty"`
|
||||
ID string `json:"id"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Type ToolType `json:"type"`
|
||||
Function FunctionCall `json:"function"`
|
||||
}
|
||||
@@ -258,8 +269,19 @@ type ChatCompletionRequest struct {
|
||||
// Store can be set to true to store the output of this completion request for use in distillations and evals.
|
||||
// https://platform.openai.com/docs/api-reference/chat/create#chat-create-store
|
||||
Store bool `json:"store,omitempty"`
|
||||
// Controls effort on reasoning for reasoning models. It can be set to "low", "medium", or "high".
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
// Metadata to store with the completion.
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
IncludeReasoning *bool `json:"include_reasoning,omitempty"`
|
||||
ReasoningFormat *string `json:"reasoning_format,omitempty"`
|
||||
// Configuration for a predicted output.
|
||||
Prediction *Prediction `json:"prediction,omitempty"`
|
||||
// ChatTemplateKwargs provides a way to add non-standard parameters to the request body.
|
||||
// Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
|
||||
// Such as think mode for qwen3. "chat_template_kwargs": {"enable_thinking": false}
|
||||
// https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes
|
||||
ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"`
|
||||
}
|
||||
|
||||
type StreamOptions struct {
|
||||
@@ -327,6 +349,11 @@ type LogProbs struct {
|
||||
Content []LogProb `json:"content"`
|
||||
}
|
||||
|
||||
type Prediction struct {
|
||||
Content string `json:"content"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type FinishReason string
|
||||
|
||||
const (
|
||||
@@ -390,7 +417,8 @@ func (c *Client) CreateChatCompletion(
|
||||
return
|
||||
}
|
||||
|
||||
if err = validateRequestForO1Models(request); err != nil {
|
||||
reasoningValidator := NewReasoningValidator()
|
||||
if err = reasoningValidator.Validate(request); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,12 @@ type ChatCompletionStreamChoiceDelta struct {
|
||||
FunctionCall *FunctionCall `json:"function_call,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
|
||||
// This property is used for the "reasoning" feature supported by deepseek-reasoner
|
||||
// which is not in the official documentation.
|
||||
// the doc from deepseek:
|
||||
// - https://api-docs.deepseek.com/api/create-chat-completion#responses
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
}
|
||||
|
||||
type ChatCompletionStreamChoiceLogprobs struct {
|
||||
@@ -80,7 +86,8 @@ func (c *Client) CreateChatCompletionStream(
|
||||
}
|
||||
|
||||
request.Stream = true
|
||||
if err = validateRequestForO1Models(request); err != nil {
|
||||
reasoningValidator := NewReasoningValidator()
|
||||
if err = reasoningValidator.Validate(request); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
func TestChatCompletionsStreamWrongModel(t *testing.T) {
|
||||
@@ -792,6 +792,223 @@ func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func TestCreateChatCompletionStreamWithReasoningModel(t *testing.T) {
|
||||
client, server, teardown := setupOpenAITestServer()
|
||||
defer teardown()
|
||||
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
dataBytes := []byte{}
|
||||
|
||||
//nolint:lll
|
||||
dataBytes = append(dataBytes, []byte(`data: {"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}`)...)
|
||||
dataBytes = append(dataBytes, []byte("\n\n")...)
|
||||
|
||||
//nolint:lll
|
||||
dataBytes = append(dataBytes, []byte(`data: {"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}`)...)
|
||||
dataBytes = append(dataBytes, []byte("\n\n")...)
|
||||
|
||||
//nolint:lll
|
||||
dataBytes = append(dataBytes, []byte(`data: {"id":"3","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":" from"},"finish_reason":null}]}`)...)
|
||||
dataBytes = append(dataBytes, []byte("\n\n")...)
|
||||
|
||||
//nolint:lll
|
||||
dataBytes = append(dataBytes, []byte(`data: {"id":"4","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":" O3Mini"},"finish_reason":null}]}`)...)
|
||||
dataBytes = append(dataBytes, []byte("\n\n")...)
|
||||
|
||||
//nolint:lll
|
||||
dataBytes = append(dataBytes, []byte(`data: {"id":"5","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`)...)
|
||||
dataBytes = append(dataBytes, []byte("\n\n")...)
|
||||
|
||||
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
|
||||
|
||||
_, err := w.Write(dataBytes)
|
||||
checks.NoError(t, err, "Write error")
|
||||
})
|
||||
|
||||
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
|
||||
MaxCompletionTokens: 2000,
|
||||
Model: openai.O3Mini20250131,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello!",
|
||||
},
|
||||
},
|
||||
Stream: true,
|
||||
})
|
||||
checks.NoError(t, err, "CreateCompletionStream returned error")
|
||||
defer stream.Close()
|
||||
|
||||
expectedResponses := []openai.ChatCompletionStreamResponse{
|
||||
{
|
||||
ID: "1",
|
||||
Object: "chat.completion.chunk",
|
||||
Created: 1729585728,
|
||||
Model: openai.O3Mini20250131,
|
||||
SystemFingerprint: "fp_mini",
|
||||
Choices: []openai.ChatCompletionStreamChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: openai.ChatCompletionStreamChoiceDelta{
|
||||
Role: "assistant",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "2",
|
||||
Object: "chat.completion.chunk",
|
||||
Created: 1729585728,
|
||||
Model: openai.O3Mini20250131,
|
||||
SystemFingerprint: "fp_mini",
|
||||
Choices: []openai.ChatCompletionStreamChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: openai.ChatCompletionStreamChoiceDelta{
|
||||
Content: "Hello",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "3",
|
||||
Object: "chat.completion.chunk",
|
||||
Created: 1729585728,
|
||||
Model: openai.O3Mini20250131,
|
||||
SystemFingerprint: "fp_mini",
|
||||
Choices: []openai.ChatCompletionStreamChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: openai.ChatCompletionStreamChoiceDelta{
|
||||
Content: " from",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "4",
|
||||
Object: "chat.completion.chunk",
|
||||
Created: 1729585728,
|
||||
Model: openai.O3Mini20250131,
|
||||
SystemFingerprint: "fp_mini",
|
||||
Choices: []openai.ChatCompletionStreamChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: openai.ChatCompletionStreamChoiceDelta{
|
||||
Content: " O3Mini",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "5",
|
||||
Object: "chat.completion.chunk",
|
||||
Created: 1729585728,
|
||||
Model: openai.O3Mini20250131,
|
||||
SystemFingerprint: "fp_mini",
|
||||
Choices: []openai.ChatCompletionStreamChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: openai.ChatCompletionStreamChoiceDelta{},
|
||||
FinishReason: "stop",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for ix, expectedResponse := range expectedResponses {
|
||||
b, _ := json.Marshal(expectedResponse)
|
||||
t.Logf("%d: %s", ix, string(b))
|
||||
|
||||
receivedResponse, streamErr := stream.Recv()
|
||||
checks.NoError(t, streamErr, "stream.Recv() failed")
|
||||
if !compareChatResponses(expectedResponse, receivedResponse) {
|
||||
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
|
||||
}
|
||||
}
|
||||
|
||||
_, streamErr := stream.Recv()
|
||||
if !errors.Is(streamErr, io.EOF) {
|
||||
t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateChatCompletionStreamReasoningValidatorFails(t *testing.T) {
|
||||
client, _, _ := setupOpenAITestServer()
|
||||
|
||||
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
|
||||
MaxTokens: 100, // This will trigger the validator to fail
|
||||
Model: openai.O3Mini,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello!",
|
||||
},
|
||||
},
|
||||
Stream: true,
|
||||
})
|
||||
|
||||
if stream != nil {
|
||||
t.Error("Expected nil stream when validation fails")
|
||||
stream.Close()
|
||||
}
|
||||
|
||||
if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
|
||||
t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateChatCompletionStreamO3ReasoningValidatorFails(t *testing.T) {
|
||||
client, _, _ := setupOpenAITestServer()
|
||||
|
||||
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
|
||||
MaxTokens: 100, // This will trigger the validator to fail
|
||||
Model: openai.O3,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello!",
|
||||
},
|
||||
},
|
||||
Stream: true,
|
||||
})
|
||||
|
||||
if stream != nil {
|
||||
t.Error("Expected nil stream when validation fails")
|
||||
stream.Close()
|
||||
}
|
||||
|
||||
if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
|
||||
t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O3, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateChatCompletionStreamO4MiniReasoningValidatorFails(t *testing.T) {
|
||||
client, _, _ := setupOpenAITestServer()
|
||||
|
||||
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
|
||||
MaxTokens: 100, // This will trigger the validator to fail
|
||||
Model: openai.O4Mini,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello!",
|
||||
},
|
||||
},
|
||||
Stream: true,
|
||||
})
|
||||
|
||||
if stream != nil {
|
||||
t.Error("Expected nil stream when validation fails")
|
||||
stream.Close()
|
||||
}
|
||||
|
||||
if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
|
||||
t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O4Mini, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool {
|
||||
if c1.Index != c2.Index {
|
||||
return false
|
||||
|
||||
274
chat_test.go
274
chat_test.go
@@ -12,9 +12,9 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -64,7 +64,7 @@ func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) {
|
||||
MaxTokens: 5,
|
||||
Model: openai.O1Preview,
|
||||
},
|
||||
expectedError: openai.ErrO1MaxTokensDeprecated,
|
||||
expectedError: openai.ErrReasoningModelMaxTokensDeprecated,
|
||||
},
|
||||
{
|
||||
name: "o1-mini_MaxTokens_deprecated",
|
||||
@@ -72,7 +72,7 @@ func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) {
|
||||
MaxTokens: 5,
|
||||
Model: openai.O1Mini,
|
||||
},
|
||||
expectedError: openai.ErrO1MaxTokensDeprecated,
|
||||
expectedError: openai.ErrReasoningModelMaxTokensDeprecated,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -104,41 +104,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
|
||||
LogProbs: true,
|
||||
Model: openai.O1Preview,
|
||||
},
|
||||
expectedError: openai.ErrO1BetaLimitationsLogprobs,
|
||||
},
|
||||
{
|
||||
name: "message_type_unsupported",
|
||||
in: openai.ChatCompletionRequest{
|
||||
MaxCompletionTokens: 1000,
|
||||
Model: openai.O1Mini,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleSystem,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: openai.ErrO1BetaLimitationsMessageTypes,
|
||||
},
|
||||
{
|
||||
name: "tool_unsupported",
|
||||
in: openai.ChatCompletionRequest{
|
||||
MaxCompletionTokens: 1000,
|
||||
Model: openai.O1Mini,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
},
|
||||
{
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
},
|
||||
},
|
||||
Tools: []openai.Tool{
|
||||
{
|
||||
Type: openai.ToolTypeFunction,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: openai.ErrO1BetaLimitationsTools,
|
||||
expectedError: openai.ErrReasoningModelLimitationsLogprobs,
|
||||
},
|
||||
{
|
||||
name: "set_temperature_unsupported",
|
||||
@@ -155,7 +121,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
|
||||
},
|
||||
Temperature: float32(2),
|
||||
},
|
||||
expectedError: openai.ErrO1BetaLimitationsOther,
|
||||
expectedError: openai.ErrReasoningModelLimitationsOther,
|
||||
},
|
||||
{
|
||||
name: "set_top_unsupported",
|
||||
@@ -173,7 +139,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
|
||||
Temperature: float32(1),
|
||||
TopP: float32(0.1),
|
||||
},
|
||||
expectedError: openai.ErrO1BetaLimitationsOther,
|
||||
expectedError: openai.ErrReasoningModelLimitationsOther,
|
||||
},
|
||||
{
|
||||
name: "set_n_unsupported",
|
||||
@@ -192,7 +158,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
|
||||
TopP: float32(1),
|
||||
N: 2,
|
||||
},
|
||||
expectedError: openai.ErrO1BetaLimitationsOther,
|
||||
expectedError: openai.ErrReasoningModelLimitationsOther,
|
||||
},
|
||||
{
|
||||
name: "set_presence_penalty_unsupported",
|
||||
@@ -209,7 +175,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
|
||||
},
|
||||
PresencePenalty: float32(1),
|
||||
},
|
||||
expectedError: openai.ErrO1BetaLimitationsOther,
|
||||
expectedError: openai.ErrReasoningModelLimitationsOther,
|
||||
},
|
||||
{
|
||||
name: "set_frequency_penalty_unsupported",
|
||||
@@ -226,7 +192,127 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
|
||||
},
|
||||
FrequencyPenalty: float32(0.1),
|
||||
},
|
||||
expectedError: openai.ErrO1BetaLimitationsOther,
|
||||
expectedError: openai.ErrReasoningModelLimitationsOther,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := openai.DefaultConfig("whatever")
|
||||
config.BaseURL = "http://localhost/v1"
|
||||
client := openai.NewClientWithConfig(config)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := client.CreateChatCompletion(ctx, tt.in)
|
||||
checks.HasError(t, err)
|
||||
msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err)
|
||||
checks.ErrorIs(t, err, tt.expectedError, msg)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestO3ModelsChatCompletionsBetaLimitations(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in openai.ChatCompletionRequest
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "log_probs_unsupported",
|
||||
in: openai.ChatCompletionRequest{
|
||||
MaxCompletionTokens: 1000,
|
||||
LogProbs: true,
|
||||
Model: openai.O3Mini,
|
||||
},
|
||||
expectedError: openai.ErrReasoningModelLimitationsLogprobs,
|
||||
},
|
||||
{
|
||||
name: "set_temperature_unsupported",
|
||||
in: openai.ChatCompletionRequest{
|
||||
MaxCompletionTokens: 1000,
|
||||
Model: openai.O3Mini,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
},
|
||||
{
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
},
|
||||
},
|
||||
Temperature: float32(2),
|
||||
},
|
||||
expectedError: openai.ErrReasoningModelLimitationsOther,
|
||||
},
|
||||
{
|
||||
name: "set_top_unsupported",
|
||||
in: openai.ChatCompletionRequest{
|
||||
MaxCompletionTokens: 1000,
|
||||
Model: openai.O3Mini,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
},
|
||||
{
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
},
|
||||
},
|
||||
Temperature: float32(1),
|
||||
TopP: float32(0.1),
|
||||
},
|
||||
expectedError: openai.ErrReasoningModelLimitationsOther,
|
||||
},
|
||||
{
|
||||
name: "set_n_unsupported",
|
||||
in: openai.ChatCompletionRequest{
|
||||
MaxCompletionTokens: 1000,
|
||||
Model: openai.O3Mini,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
},
|
||||
{
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
},
|
||||
},
|
||||
Temperature: float32(1),
|
||||
TopP: float32(1),
|
||||
N: 2,
|
||||
},
|
||||
expectedError: openai.ErrReasoningModelLimitationsOther,
|
||||
},
|
||||
{
|
||||
name: "set_presence_penalty_unsupported",
|
||||
in: openai.ChatCompletionRequest{
|
||||
MaxCompletionTokens: 1000,
|
||||
Model: openai.O3Mini,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
},
|
||||
{
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
},
|
||||
},
|
||||
PresencePenalty: float32(1),
|
||||
},
|
||||
expectedError: openai.ErrReasoningModelLimitationsOther,
|
||||
},
|
||||
{
|
||||
name: "set_frequency_penalty_unsupported",
|
||||
in: openai.ChatCompletionRequest{
|
||||
MaxCompletionTokens: 1000,
|
||||
Model: openai.O3Mini,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
},
|
||||
{
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
},
|
||||
},
|
||||
FrequencyPenalty: float32(0.1),
|
||||
},
|
||||
expectedError: openai.ErrReasoningModelLimitationsOther,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -308,6 +394,40 @@ func TestO1ModelChatCompletions(t *testing.T) {
|
||||
checks.NoError(t, err, "CreateChatCompletion error")
|
||||
}
|
||||
|
||||
func TestO3ModelChatCompletions(t *testing.T) {
|
||||
client, server, teardown := setupOpenAITestServer()
|
||||
defer teardown()
|
||||
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
|
||||
_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
|
||||
Model: openai.O3Mini,
|
||||
MaxCompletionTokens: 1000,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello!",
|
||||
},
|
||||
},
|
||||
})
|
||||
checks.NoError(t, err, "CreateChatCompletion error")
|
||||
}
|
||||
|
||||
func TestDeepseekR1ModelChatCompletions(t *testing.T) {
|
||||
client, server, teardown := setupOpenAITestServer()
|
||||
defer teardown()
|
||||
server.RegisterHandler("/v1/chat/completions", handleDeepseekR1ChatCompletionEndpoint)
|
||||
_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
|
||||
Model: "deepseek-reasoner",
|
||||
MaxCompletionTokens: 100,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello!",
|
||||
},
|
||||
},
|
||||
})
|
||||
checks.NoError(t, err, "CreateChatCompletion error")
|
||||
}
|
||||
|
||||
// TestCompletions Tests the completions endpoint of the API using the mocked server.
|
||||
func TestChatCompletionsWithHeaders(t *testing.T) {
|
||||
client, server, teardown := setupOpenAITestServer()
|
||||
@@ -631,7 +751,7 @@ func TestMultipartChatMessageSerialization(t *testing.T) {
|
||||
t.Fatalf("Unexpected error")
|
||||
}
|
||||
res = strings.ReplaceAll(string(s), " ", "")
|
||||
if res != `{"role":"user","content":""}` {
|
||||
if res != `{"role":"user"}` {
|
||||
t.Fatalf("invalid message: %s", string(s))
|
||||
}
|
||||
}
|
||||
@@ -719,6 +839,68 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, string(resBytes))
|
||||
}
|
||||
|
||||
func handleDeepseekR1ChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
var err error
|
||||
var resBytes []byte
|
||||
|
||||
// completions only accepts POST requests
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
var completionReq openai.ChatCompletionRequest
|
||||
if completionReq, err = getChatCompletionBody(r); err != nil {
|
||||
http.Error(w, "could not read request", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
res := openai.ChatCompletionResponse{
|
||||
ID: strconv.Itoa(int(time.Now().Unix())),
|
||||
Object: "test-object",
|
||||
Created: time.Now().Unix(),
|
||||
// would be nice to validate Model during testing, but
|
||||
// this may not be possible with how much upkeep
|
||||
// would be required / wouldn't make much sense
|
||||
Model: completionReq.Model,
|
||||
}
|
||||
// create completions
|
||||
n := completionReq.N
|
||||
if n == 0 {
|
||||
n = 1
|
||||
}
|
||||
if completionReq.MaxCompletionTokens == 0 {
|
||||
completionReq.MaxCompletionTokens = 1000
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
reasoningContent := "User says hello! And I need to reply"
|
||||
completionStr := strings.Repeat("a", completionReq.MaxCompletionTokens-numTokens(reasoningContent))
|
||||
res.Choices = append(res.Choices, openai.ChatCompletionChoice{
|
||||
Message: openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
ReasoningContent: reasoningContent,
|
||||
Content: completionStr,
|
||||
},
|
||||
Index: i,
|
||||
})
|
||||
}
|
||||
inputTokens := numTokens(completionReq.Messages[0].Content) * n
|
||||
completionTokens := completionReq.MaxTokens * n
|
||||
res.Usage = openai.Usage{
|
||||
PromptTokens: inputTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: inputTokens + completionTokens,
|
||||
}
|
||||
resBytes, _ = json.Marshal(res)
|
||||
w.Header().Set(xCustomHeader, xCustomHeaderValue)
|
||||
for k, v := range rateLimitHeaders {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
w.Header().Set(k, strconv.Itoa(val))
|
||||
default:
|
||||
w.Header().Set(k, fmt.Sprintf("%s", v))
|
||||
}
|
||||
}
|
||||
fmt.Fprintln(w, string(resBytes))
|
||||
}
|
||||
|
||||
// getChatCompletionBody Returns the body of the request to create a completion.
|
||||
func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) {
|
||||
completion := openai.ChatCompletionRequest{}
|
||||
|
||||
20
client.go
20
client.go
@@ -10,7 +10,7 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
utils "github.com/sashabaranov/go-openai/internal"
|
||||
utils "git.vaala.cloud/VaalaCat/go-openai/internal"
|
||||
)
|
||||
|
||||
// Client is OpenAI GPT-3 API client.
|
||||
@@ -182,13 +182,21 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream
|
||||
|
||||
func (c *Client) setCommonHeaders(req *http.Request) {
|
||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
|
||||
// Azure API Key authentication
|
||||
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure {
|
||||
switch c.config.APIType {
|
||||
case APITypeAzure, APITypeCloudflareAzure:
|
||||
// Azure API Key authentication
|
||||
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
|
||||
} else if c.config.authToken != "" {
|
||||
// OpenAI or Azure AD authentication
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
|
||||
case APITypeAnthropic:
|
||||
// https://docs.anthropic.com/en/api/versioning
|
||||
req.Header.Set("anthropic-version", c.config.APIVersion)
|
||||
case APITypeOpenAI, APITypeAzureAD:
|
||||
fallthrough
|
||||
default:
|
||||
if c.config.authToken != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
|
||||
}
|
||||
}
|
||||
|
||||
if c.config.OrgID != "" {
|
||||
req.Header.Set("OpenAI-Organization", c.config.OrgID)
|
||||
}
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/sashabaranov/go-openai/internal/test"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
var errTestRequestBuilderFailed = errors.New("test request builder failed")
|
||||
@@ -39,6 +39,21 @@ func TestClient(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCommonHeadersAnthropic(t *testing.T) {
|
||||
config := DefaultAnthropicConfig("mock-token", "")
|
||||
client := NewClientWithConfig(config)
|
||||
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
|
||||
client.setCommonHeaders(req)
|
||||
|
||||
if got := req.Header.Get("anthropic-version"); got != AnthropicAPIVersion {
|
||||
t.Errorf("Expected anthropic-version header to be %q, got %q", AnthropicAPIVersion, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeResponse(t *testing.T) {
|
||||
stringInput := ""
|
||||
|
||||
|
||||
@@ -13,8 +13,10 @@ type Usage struct {
|
||||
|
||||
// CompletionTokensDetails Breakdown of tokens used in a completion.
|
||||
type CompletionTokensDetails struct {
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
ReasoningTokens int `json:"reasoning_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
ReasoningTokens int `json:"reasoning_tokens"`
|
||||
AcceptedPredictionTokens int `json:"accepted_prediction_tokens"`
|
||||
RejectedPredictionTokens int `json:"rejected_prediction_tokens"`
|
||||
}
|
||||
|
||||
// PromptTokensDetails Breakdown of tokens used in the prompt.
|
||||
|
||||
235
completion.go
235
completion.go
@@ -2,59 +2,61 @@ package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrO1MaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") //nolint:lll
|
||||
ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
|
||||
ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll
|
||||
ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll
|
||||
)
|
||||
|
||||
var (
|
||||
ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll
|
||||
ErrO1BetaLimitationsTools = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll
|
||||
ErrO1BetaLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll
|
||||
ErrO1BetaLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
|
||||
)
|
||||
|
||||
// GPT3 Defines the models provided by OpenAI to use when generating
|
||||
// completions from OpenAI.
|
||||
// GPT3 Models are designed for text-based tasks. For code-specific
|
||||
// tasks, please refer to the Codex series of models.
|
||||
const (
|
||||
O1Mini = "o1-mini"
|
||||
O1Mini20240912 = "o1-mini-2024-09-12"
|
||||
O1Preview = "o1-preview"
|
||||
O1Preview20240912 = "o1-preview-2024-09-12"
|
||||
GPT432K0613 = "gpt-4-32k-0613"
|
||||
GPT432K0314 = "gpt-4-32k-0314"
|
||||
GPT432K = "gpt-4-32k"
|
||||
GPT40613 = "gpt-4-0613"
|
||||
GPT40314 = "gpt-4-0314"
|
||||
GPT4o = "gpt-4o"
|
||||
GPT4o20240513 = "gpt-4o-2024-05-13"
|
||||
GPT4o20240806 = "gpt-4o-2024-08-06"
|
||||
GPT4oLatest = "chatgpt-4o-latest"
|
||||
GPT4oMini = "gpt-4o-mini"
|
||||
GPT4oMini20240718 = "gpt-4o-mini-2024-07-18"
|
||||
GPT4Turbo = "gpt-4-turbo"
|
||||
GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09"
|
||||
GPT4Turbo0125 = "gpt-4-0125-preview"
|
||||
GPT4Turbo1106 = "gpt-4-1106-preview"
|
||||
GPT4TurboPreview = "gpt-4-turbo-preview"
|
||||
GPT4VisionPreview = "gpt-4-vision-preview"
|
||||
GPT4 = "gpt-4"
|
||||
GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125"
|
||||
GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106"
|
||||
GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613"
|
||||
GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301"
|
||||
GPT3Dot5Turbo16K = "gpt-3.5-turbo-16k"
|
||||
GPT3Dot5Turbo16K0613 = "gpt-3.5-turbo-16k-0613"
|
||||
GPT3Dot5Turbo = "gpt-3.5-turbo"
|
||||
GPT3Dot5TurboInstruct = "gpt-3.5-turbo-instruct"
|
||||
O1Mini = "o1-mini"
|
||||
O1Mini20240912 = "o1-mini-2024-09-12"
|
||||
O1Preview = "o1-preview"
|
||||
O1Preview20240912 = "o1-preview-2024-09-12"
|
||||
O1 = "o1"
|
||||
O120241217 = "o1-2024-12-17"
|
||||
O3 = "o3"
|
||||
O320250416 = "o3-2025-04-16"
|
||||
O3Mini = "o3-mini"
|
||||
O3Mini20250131 = "o3-mini-2025-01-31"
|
||||
O4Mini = "o4-mini"
|
||||
O4Mini20250416 = "o4-mini-2025-04-16"
|
||||
GPT432K0613 = "gpt-4-32k-0613"
|
||||
GPT432K0314 = "gpt-4-32k-0314"
|
||||
GPT432K = "gpt-4-32k"
|
||||
GPT40613 = "gpt-4-0613"
|
||||
GPT40314 = "gpt-4-0314"
|
||||
GPT4o = "gpt-4o"
|
||||
GPT4o20240513 = "gpt-4o-2024-05-13"
|
||||
GPT4o20240806 = "gpt-4o-2024-08-06"
|
||||
GPT4o20241120 = "gpt-4o-2024-11-20"
|
||||
GPT4oLatest = "chatgpt-4o-latest"
|
||||
GPT4oMini = "gpt-4o-mini"
|
||||
GPT4oMini20240718 = "gpt-4o-mini-2024-07-18"
|
||||
GPT4Turbo = "gpt-4-turbo"
|
||||
GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09"
|
||||
GPT4Turbo0125 = "gpt-4-0125-preview"
|
||||
GPT4Turbo1106 = "gpt-4-1106-preview"
|
||||
GPT4TurboPreview = "gpt-4-turbo-preview"
|
||||
GPT4VisionPreview = "gpt-4-vision-preview"
|
||||
GPT4 = "gpt-4"
|
||||
GPT4Dot1 = "gpt-4.1"
|
||||
GPT4Dot120250414 = "gpt-4.1-2025-04-14"
|
||||
GPT4Dot1Mini = "gpt-4.1-mini"
|
||||
GPT4Dot1Mini20250414 = "gpt-4.1-mini-2025-04-14"
|
||||
GPT4Dot1Nano = "gpt-4.1-nano"
|
||||
GPT4Dot1Nano20250414 = "gpt-4.1-nano-2025-04-14"
|
||||
GPT4Dot5Preview = "gpt-4.5-preview"
|
||||
GPT4Dot5Preview20250227 = "gpt-4.5-preview-2025-02-27"
|
||||
GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125"
|
||||
GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106"
|
||||
GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613"
|
||||
GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301"
|
||||
GPT3Dot5Turbo16K = "gpt-3.5-turbo-16k"
|
||||
GPT3Dot5Turbo16K0613 = "gpt-3.5-turbo-16k-0613"
|
||||
GPT3Dot5Turbo = "gpt-3.5-turbo"
|
||||
GPT3Dot5TurboInstruct = "gpt-3.5-turbo-instruct"
|
||||
// Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead.
|
||||
GPT3TextDavinci003 = "text-davinci-003"
|
||||
// Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead.
|
||||
@@ -93,46 +95,53 @@ const (
|
||||
CodexCodeDavinci001 = "code-davinci-001"
|
||||
)
|
||||
|
||||
// O1SeriesModels List of new Series of OpenAI models.
|
||||
// Some old api attributes not supported.
|
||||
var O1SeriesModels = map[string]struct{}{
|
||||
O1Mini: {},
|
||||
O1Mini20240912: {},
|
||||
O1Preview: {},
|
||||
O1Preview20240912: {},
|
||||
}
|
||||
|
||||
var disabledModelsForEndpoints = map[string]map[string]bool{
|
||||
"/completions": {
|
||||
O1Mini: true,
|
||||
O1Mini20240912: true,
|
||||
O1Preview: true,
|
||||
O1Preview20240912: true,
|
||||
GPT3Dot5Turbo: true,
|
||||
GPT3Dot5Turbo0301: true,
|
||||
GPT3Dot5Turbo0613: true,
|
||||
GPT3Dot5Turbo1106: true,
|
||||
GPT3Dot5Turbo0125: true,
|
||||
GPT3Dot5Turbo16K: true,
|
||||
GPT3Dot5Turbo16K0613: true,
|
||||
GPT4: true,
|
||||
GPT4o: true,
|
||||
GPT4o20240513: true,
|
||||
GPT4o20240806: true,
|
||||
GPT4oLatest: true,
|
||||
GPT4oMini: true,
|
||||
GPT4oMini20240718: true,
|
||||
GPT4TurboPreview: true,
|
||||
GPT4VisionPreview: true,
|
||||
GPT4Turbo1106: true,
|
||||
GPT4Turbo0125: true,
|
||||
GPT4Turbo: true,
|
||||
GPT4Turbo20240409: true,
|
||||
GPT40314: true,
|
||||
GPT40613: true,
|
||||
GPT432K: true,
|
||||
GPT432K0314: true,
|
||||
GPT432K0613: true,
|
||||
O1Mini: true,
|
||||
O1Mini20240912: true,
|
||||
O1Preview: true,
|
||||
O1Preview20240912: true,
|
||||
O3Mini: true,
|
||||
O3Mini20250131: true,
|
||||
O4Mini: true,
|
||||
O4Mini20250416: true,
|
||||
O3: true,
|
||||
O320250416: true,
|
||||
GPT3Dot5Turbo: true,
|
||||
GPT3Dot5Turbo0301: true,
|
||||
GPT3Dot5Turbo0613: true,
|
||||
GPT3Dot5Turbo1106: true,
|
||||
GPT3Dot5Turbo0125: true,
|
||||
GPT3Dot5Turbo16K: true,
|
||||
GPT3Dot5Turbo16K0613: true,
|
||||
GPT4: true,
|
||||
GPT4Dot5Preview: true,
|
||||
GPT4Dot5Preview20250227: true,
|
||||
GPT4o: true,
|
||||
GPT4o20240513: true,
|
||||
GPT4o20240806: true,
|
||||
GPT4o20241120: true,
|
||||
GPT4oLatest: true,
|
||||
GPT4oMini: true,
|
||||
GPT4oMini20240718: true,
|
||||
GPT4TurboPreview: true,
|
||||
GPT4VisionPreview: true,
|
||||
GPT4Turbo1106: true,
|
||||
GPT4Turbo0125: true,
|
||||
GPT4Turbo: true,
|
||||
GPT4Turbo20240409: true,
|
||||
GPT40314: true,
|
||||
GPT40613: true,
|
||||
GPT432K: true,
|
||||
GPT432K0314: true,
|
||||
GPT432K0613: true,
|
||||
O1: true,
|
||||
GPT4Dot1: true,
|
||||
GPT4Dot120250414: true,
|
||||
GPT4Dot1Mini: true,
|
||||
GPT4Dot1Mini20250414: true,
|
||||
GPT4Dot1Nano: true,
|
||||
GPT4Dot1Nano20250414: true,
|
||||
},
|
||||
chatCompletionsSuffix: {
|
||||
CodexCodeDavinci002: true,
|
||||
@@ -179,64 +188,6 @@ func checkPromptType(prompt any) bool {
|
||||
return true // all items in the slice are string, so it is []string
|
||||
}
|
||||
|
||||
var unsupportedToolsForO1Models = map[ToolType]struct{}{
|
||||
ToolTypeFunction: {},
|
||||
}
|
||||
|
||||
var availableMessageRoleForO1Models = map[string]struct{}{
|
||||
ChatMessageRoleUser: {},
|
||||
ChatMessageRoleAssistant: {},
|
||||
}
|
||||
|
||||
// validateRequestForO1Models checks for deprecated fields of OpenAI models.
|
||||
func validateRequestForO1Models(request ChatCompletionRequest) error {
|
||||
if _, found := O1SeriesModels[request.Model]; !found {
|
||||
return nil
|
||||
}
|
||||
|
||||
if request.MaxTokens > 0 {
|
||||
return ErrO1MaxTokensDeprecated
|
||||
}
|
||||
|
||||
// Logprobs: not supported.
|
||||
if request.LogProbs {
|
||||
return ErrO1BetaLimitationsLogprobs
|
||||
}
|
||||
|
||||
// Message types: user and assistant messages only, system messages are not supported.
|
||||
for _, m := range request.Messages {
|
||||
if _, found := availableMessageRoleForO1Models[m.Role]; !found {
|
||||
return ErrO1BetaLimitationsMessageTypes
|
||||
}
|
||||
}
|
||||
|
||||
// Tools: tools, function calling, and response format parameters are not supported
|
||||
for _, t := range request.Tools {
|
||||
if _, found := unsupportedToolsForO1Models[t.Type]; found {
|
||||
return ErrO1BetaLimitationsTools
|
||||
}
|
||||
}
|
||||
|
||||
// Other: temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0.
|
||||
if request.Temperature > 0 && request.Temperature != 1 {
|
||||
return ErrO1BetaLimitationsOther
|
||||
}
|
||||
if request.TopP > 0 && request.TopP != 1 {
|
||||
return ErrO1BetaLimitationsOther
|
||||
}
|
||||
if request.N > 0 && request.N != 1 {
|
||||
return ErrO1BetaLimitationsOther
|
||||
}
|
||||
if request.PresencePenalty > 0 {
|
||||
return ErrO1BetaLimitationsOther
|
||||
}
|
||||
if request.FrequencyPenalty > 0 {
|
||||
return ErrO1BetaLimitationsOther
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CompletionRequest represents a request structure for completion API.
|
||||
type CompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
@@ -264,6 +215,8 @@ type CompletionRequest struct {
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
TopP float32 `json:"top_p,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
// Options for streaming response. Only set this when you set stream: true.
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
}
|
||||
|
||||
// CompletionChoice represents one of possible completions.
|
||||
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
func TestCompletionsWrongModel(t *testing.T) {
|
||||
@@ -33,6 +33,42 @@ func TestCompletionsWrongModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompletionsWrongModelO3 Tests the completions endpoint with O3 model which is not supported.
|
||||
func TestCompletionsWrongModelO3(t *testing.T) {
|
||||
config := openai.DefaultConfig("whatever")
|
||||
config.BaseURL = "http://localhost/v1"
|
||||
client := openai.NewClientWithConfig(config)
|
||||
|
||||
_, err := client.CreateCompletion(
|
||||
context.Background(),
|
||||
openai.CompletionRequest{
|
||||
MaxTokens: 5,
|
||||
Model: openai.O3,
|
||||
},
|
||||
)
|
||||
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
|
||||
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O3, but returned: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompletionsWrongModelO4Mini Tests the completions endpoint with O4Mini model which is not supported.
|
||||
func TestCompletionsWrongModelO4Mini(t *testing.T) {
|
||||
config := openai.DefaultConfig("whatever")
|
||||
config.BaseURL = "http://localhost/v1"
|
||||
client := openai.NewClientWithConfig(config)
|
||||
|
||||
_, err := client.CreateCompletion(
|
||||
context.Background(),
|
||||
openai.CompletionRequest{
|
||||
MaxTokens: 5,
|
||||
Model: openai.O4Mini,
|
||||
},
|
||||
)
|
||||
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
|
||||
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O4Mini, but returned: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletionWithStream(t *testing.T) {
|
||||
config := openai.DefaultConfig("whatever")
|
||||
client := openai.NewClientWithConfig(config)
|
||||
@@ -181,3 +217,86 @@ func getCompletionBody(r *http.Request) (openai.CompletionRequest, error) {
|
||||
}
|
||||
return completion, nil
|
||||
}
|
||||
|
||||
// TestCompletionWithO1Model Tests that O1 model is not supported for completion endpoint.
|
||||
func TestCompletionWithO1Model(t *testing.T) {
|
||||
config := openai.DefaultConfig("whatever")
|
||||
config.BaseURL = "http://localhost/v1"
|
||||
client := openai.NewClientWithConfig(config)
|
||||
|
||||
_, err := client.CreateCompletion(
|
||||
context.Background(),
|
||||
openai.CompletionRequest{
|
||||
MaxTokens: 5,
|
||||
Model: openai.O1,
|
||||
},
|
||||
)
|
||||
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
|
||||
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O1 model, but returned: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompletionWithGPT4DotModels Tests that newer GPT4 models are not supported for completion endpoint.
|
||||
func TestCompletionWithGPT4DotModels(t *testing.T) {
|
||||
config := openai.DefaultConfig("whatever")
|
||||
config.BaseURL = "http://localhost/v1"
|
||||
client := openai.NewClientWithConfig(config)
|
||||
|
||||
models := []string{
|
||||
openai.GPT4Dot1,
|
||||
openai.GPT4Dot120250414,
|
||||
openai.GPT4Dot1Mini,
|
||||
openai.GPT4Dot1Mini20250414,
|
||||
openai.GPT4Dot1Nano,
|
||||
openai.GPT4Dot1Nano20250414,
|
||||
openai.GPT4Dot5Preview,
|
||||
openai.GPT4Dot5Preview20250227,
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
_, err := client.CreateCompletion(
|
||||
context.Background(),
|
||||
openai.CompletionRequest{
|
||||
MaxTokens: 5,
|
||||
Model: model,
|
||||
},
|
||||
)
|
||||
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
|
||||
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompletionWithGPT4oModels Tests that GPT4o models are not supported for completion endpoint.
|
||||
func TestCompletionWithGPT4oModels(t *testing.T) {
|
||||
config := openai.DefaultConfig("whatever")
|
||||
config.BaseURL = "http://localhost/v1"
|
||||
client := openai.NewClientWithConfig(config)
|
||||
|
||||
models := []string{
|
||||
openai.GPT4o,
|
||||
openai.GPT4o20240513,
|
||||
openai.GPT4o20240806,
|
||||
openai.GPT4o20241120,
|
||||
openai.GPT4oLatest,
|
||||
openai.GPT4oMini,
|
||||
openai.GPT4oMini20240718,
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
_, err := client.CreateCompletion(
|
||||
context.Background(),
|
||||
openai.CompletionRequest{
|
||||
MaxTokens: 5,
|
||||
Model: model,
|
||||
},
|
||||
)
|
||||
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
|
||||
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
22
config.go
22
config.go
@@ -11,6 +11,8 @@ const (
|
||||
|
||||
azureAPIPrefix = "openai"
|
||||
azureDeploymentsPrefix = "deployments"
|
||||
|
||||
AnthropicAPIVersion = "2023-06-01"
|
||||
)
|
||||
|
||||
type APIType string
|
||||
@@ -20,6 +22,7 @@ const (
|
||||
APITypeAzure APIType = "AZURE"
|
||||
APITypeAzureAD APIType = "AZURE_AD"
|
||||
APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE"
|
||||
APITypeAnthropic APIType = "ANTHROPIC"
|
||||
)
|
||||
|
||||
const AzureAPIKeyHeader = "api-key"
|
||||
@@ -37,7 +40,7 @@ type ClientConfig struct {
|
||||
BaseURL string
|
||||
OrgID string
|
||||
APIType APIType
|
||||
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD
|
||||
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD or APITypeAnthropic
|
||||
AssistantVersion string
|
||||
AzureModelMapperFunc func(model string) string // replace model to azure deployment name func
|
||||
HTTPClient HTTPDoer
|
||||
@@ -76,6 +79,23 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig {
|
||||
}
|
||||
}
|
||||
|
||||
func DefaultAnthropicConfig(apiKey, baseURL string) ClientConfig {
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.anthropic.com/v1"
|
||||
}
|
||||
return ClientConfig{
|
||||
authToken: apiKey,
|
||||
BaseURL: baseURL,
|
||||
OrgID: "",
|
||||
APIType: APITypeAnthropic,
|
||||
APIVersion: AnthropicAPIVersion,
|
||||
|
||||
HTTPClient: &http.Client{},
|
||||
|
||||
EmptyMessagesLimit: defaultEmptyMessagesLimit,
|
||||
}
|
||||
}
|
||||
|
||||
func (ClientConfig) String() string {
|
||||
return "<OpenAI API ClientConfig>"
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package openai_test
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
)
|
||||
|
||||
func TestGetAzureDeploymentByModel(t *testing.T) {
|
||||
@@ -60,3 +60,64 @@ func TestGetAzureDeploymentByModel(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAnthropicConfig(t *testing.T) {
|
||||
apiKey := "test-key"
|
||||
baseURL := "https://api.anthropic.com/v1"
|
||||
|
||||
config := openai.DefaultAnthropicConfig(apiKey, baseURL)
|
||||
|
||||
if config.APIType != openai.APITypeAnthropic {
|
||||
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
|
||||
}
|
||||
|
||||
if config.APIVersion != openai.AnthropicAPIVersion {
|
||||
t.Errorf("Expected APIVersion to be 2023-06-01, got %v", config.APIVersion)
|
||||
}
|
||||
|
||||
if config.BaseURL != baseURL {
|
||||
t.Errorf("Expected BaseURL to be %v, got %v", baseURL, config.BaseURL)
|
||||
}
|
||||
|
||||
if config.EmptyMessagesLimit != 300 {
|
||||
t.Errorf("Expected EmptyMessagesLimit to be 300, got %v", config.EmptyMessagesLimit)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAnthropicConfigWithEmptyValues(t *testing.T) {
|
||||
config := openai.DefaultAnthropicConfig("", "")
|
||||
|
||||
if config.APIType != openai.APITypeAnthropic {
|
||||
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
|
||||
}
|
||||
|
||||
if config.APIVersion != openai.AnthropicAPIVersion {
|
||||
t.Errorf("Expected APIVersion to be %s, got %v", openai.AnthropicAPIVersion, config.APIVersion)
|
||||
}
|
||||
|
||||
expectedBaseURL := "https://api.anthropic.com/v1"
|
||||
if config.BaseURL != expectedBaseURL {
|
||||
t.Errorf("Expected BaseURL to be %v, got %v", expectedBaseURL, config.BaseURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientConfigString(t *testing.T) {
|
||||
// String() should always return the constant value
|
||||
conf := openai.DefaultConfig("dummy-token")
|
||||
expected := "<OpenAI API ClientConfig>"
|
||||
got := conf.String()
|
||||
if got != expected {
|
||||
t.Errorf("ClientConfig.String() = %q; want %q", got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAzureDeploymentByModel_NoMapper(t *testing.T) {
|
||||
// On a zero-value or DefaultConfig, AzureModelMapperFunc is nil,
|
||||
// so GetAzureDeploymentByModel should just return the input model.
|
||||
conf := openai.DefaultConfig("dummy-token")
|
||||
model := "some-model"
|
||||
got := conf.GetAzureDeploymentByModel(model)
|
||||
if got != model {
|
||||
t.Errorf("GetAzureDeploymentByModel(%q) = %q; want %q", model, got, model)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
// TestEdits Tests the edits endpoint of the API using the mocked server.
|
||||
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
func TestEmbedding(t *testing.T) {
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
// TestGetEngine Tests the retrieve engine endpoint of the API using the mocked server.
|
||||
|
||||
4
error.go
4
error.go
@@ -54,7 +54,7 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) {
|
||||
err = json.Unmarshal(rawMap["message"], &e.Message)
|
||||
if err != nil {
|
||||
// If the parameter field of a function call is invalid as a JSON schema
|
||||
// refs: https://github.com/sashabaranov/go-openai/issues/381
|
||||
// refs: https://git.vaala.cloud/VaalaCat/go-openai/issues/381
|
||||
var messages []string
|
||||
err = json.Unmarshal(rawMap["message"], &messages)
|
||||
if err != nil {
|
||||
@@ -64,7 +64,7 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) {
|
||||
}
|
||||
|
||||
// optional fields for azure openai
|
||||
// refs: https://github.com/sashabaranov/go-openai/issues/343
|
||||
// refs: https://git.vaala.cloud/VaalaCat/go-openai/issues/343
|
||||
if _, ok := rawMap["type"]; ok {
|
||||
err = json.Unmarshal(rawMap["type"], &e.Type)
|
||||
if err != nil {
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
)
|
||||
|
||||
func TestAPIErrorUnmarshalJSON(t *testing.T) {
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
)
|
||||
|
||||
func Example() {
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
func TestFileBytesUpload(t *testing.T) {
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
utils "github.com/sashabaranov/go-openai/internal"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
utils "git.vaala.cloud/VaalaCat/go-openai/internal"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
func TestFileBytesUploadWithFailingFormBuilder(t *testing.T) {
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
const testFineTuneID = "fine-tune-id"
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
const testFineTuninigJobID = "fine-tuning-job-id"
|
||||
|
||||
2
go.mod
2
go.mod
@@ -1,3 +1,3 @@
|
||||
module github.com/sashabaranov/go-openai
|
||||
module git.vaala.cloud/VaalaCat/go-openai
|
||||
|
||||
go 1.18
|
||||
|
||||
112
image.go
112
image.go
@@ -3,8 +3,8 @@ package openai
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
@@ -13,51 +13,101 @@ const (
|
||||
CreateImageSize256x256 = "256x256"
|
||||
CreateImageSize512x512 = "512x512"
|
||||
CreateImageSize1024x1024 = "1024x1024"
|
||||
|
||||
// dall-e-3 supported only.
|
||||
CreateImageSize1792x1024 = "1792x1024"
|
||||
CreateImageSize1024x1792 = "1024x1792"
|
||||
|
||||
// gpt-image-1 supported only.
|
||||
CreateImageSize1536x1024 = "1536x1024" // Landscape
|
||||
CreateImageSize1024x1536 = "1024x1536" // Portrait
|
||||
)
|
||||
|
||||
const (
|
||||
CreateImageResponseFormatURL = "url"
|
||||
// dall-e-2 and dall-e-3 only.
|
||||
CreateImageResponseFormatB64JSON = "b64_json"
|
||||
CreateImageResponseFormatURL = "url"
|
||||
)
|
||||
|
||||
const (
|
||||
CreateImageModelDallE2 = "dall-e-2"
|
||||
CreateImageModelDallE3 = "dall-e-3"
|
||||
CreateImageModelDallE2 = "dall-e-2"
|
||||
CreateImageModelDallE3 = "dall-e-3"
|
||||
CreateImageModelGptImage1 = "gpt-image-1"
|
||||
)
|
||||
|
||||
const (
|
||||
CreateImageQualityHD = "hd"
|
||||
CreateImageQualityStandard = "standard"
|
||||
|
||||
// gpt-image-1 only.
|
||||
CreateImageQualityHigh = "high"
|
||||
CreateImageQualityMedium = "medium"
|
||||
CreateImageQualityLow = "low"
|
||||
)
|
||||
|
||||
const (
|
||||
// dall-e-3 only.
|
||||
CreateImageStyleVivid = "vivid"
|
||||
CreateImageStyleNatural = "natural"
|
||||
)
|
||||
|
||||
const (
|
||||
// gpt-image-1 only.
|
||||
CreateImageBackgroundTransparent = "transparent"
|
||||
CreateImageBackgroundOpaque = "opaque"
|
||||
)
|
||||
|
||||
const (
|
||||
// gpt-image-1 only.
|
||||
CreateImageModerationLow = "low"
|
||||
)
|
||||
|
||||
const (
|
||||
// gpt-image-1 only.
|
||||
CreateImageOutputFormatPNG = "png"
|
||||
CreateImageOutputFormatJPEG = "jpeg"
|
||||
CreateImageOutputFormatWEBP = "webp"
|
||||
)
|
||||
|
||||
// ImageRequest represents the request structure for the image API.
|
||||
type ImageRequest struct {
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Style string `json:"style,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Style string `json:"style,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Background string `json:"background,omitempty"`
|
||||
Moderation string `json:"moderation,omitempty"`
|
||||
OutputCompression int `json:"output_compression,omitempty"`
|
||||
OutputFormat string `json:"output_format,omitempty"`
|
||||
}
|
||||
|
||||
// ImageResponse represents a response structure for image API.
|
||||
type ImageResponse struct {
|
||||
Created int64 `json:"created,omitempty"`
|
||||
Data []ImageResponseDataInner `json:"data,omitempty"`
|
||||
Usage ImageResponseUsage `json:"usage,omitempty"`
|
||||
|
||||
httpHeader
|
||||
}
|
||||
|
||||
// ImageResponseInputTokensDetails represents the token breakdown for input tokens.
|
||||
type ImageResponseInputTokensDetails struct {
|
||||
TextTokens int `json:"text_tokens,omitempty"`
|
||||
ImageTokens int `json:"image_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// ImageResponseUsage represents the token usage information for image API.
|
||||
type ImageResponseUsage struct {
|
||||
TotalTokens int `json:"total_tokens,omitempty"`
|
||||
InputTokens int `json:"input_tokens,omitempty"`
|
||||
OutputTokens int `json:"output_tokens,omitempty"`
|
||||
InputTokensDetails ImageResponseInputTokensDetails `json:"input_tokens_details,omitempty"`
|
||||
}
|
||||
|
||||
// ImageResponseDataInner represents a response data structure for image API.
|
||||
type ImageResponseDataInner struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
@@ -84,13 +134,15 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons
|
||||
|
||||
// ImageEditRequest represents the request structure for the image API.
|
||||
type ImageEditRequest struct {
|
||||
Image *os.File `json:"image,omitempty"`
|
||||
Mask *os.File `json:"mask,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Image io.Reader `json:"image,omitempty"`
|
||||
Mask io.Reader `json:"mask,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
// CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API.
|
||||
@@ -98,15 +150,16 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
|
||||
body := &bytes.Buffer{}
|
||||
builder := c.createFormBuilder(body)
|
||||
|
||||
// image
|
||||
err = builder.CreateFormFile("image", request.Image)
|
||||
// image, filename is not required
|
||||
err = builder.CreateFormFileReader("image", request.Image, "")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// mask, it is optional
|
||||
if request.Mask != nil {
|
||||
err = builder.CreateFormFile("mask", request.Mask)
|
||||
// mask, filename is not required
|
||||
err = builder.CreateFormFileReader("mask", request.Mask, "")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -154,11 +207,12 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
|
||||
|
||||
// ImageVariRequest represents the request structure for the image API.
|
||||
type ImageVariRequest struct {
|
||||
Image *os.File `json:"image,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Image io.Reader `json:"image,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
// CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API.
|
||||
@@ -167,8 +221,8 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
|
||||
body := &bytes.Buffer{}
|
||||
builder := c.createFormBuilder(body)
|
||||
|
||||
// image
|
||||
err = builder.CreateFormFile("image", request.Image)
|
||||
// image, filename is not required
|
||||
err = builder.CreateFormFileReader("image", request.Image, "")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -7,11 +7,12 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
func TestImages(t *testing.T) {
|
||||
@@ -86,24 +87,17 @@ func TestImageEdit(t *testing.T) {
|
||||
defer teardown()
|
||||
server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
|
||||
|
||||
origin, err := os.Create("image.png")
|
||||
origin, err := os.Create(filepath.Join(t.TempDir(), "image.png"))
|
||||
if err != nil {
|
||||
t.Error("open origin file error")
|
||||
return
|
||||
t.Fatalf("open origin file error: %v", err)
|
||||
}
|
||||
defer origin.Close()
|
||||
|
||||
mask, err := os.Create("mask.png")
|
||||
mask, err := os.Create(filepath.Join(t.TempDir(), "mask.png"))
|
||||
if err != nil {
|
||||
t.Error("open mask file error")
|
||||
return
|
||||
t.Fatalf("open mask file error: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
mask.Close()
|
||||
origin.Close()
|
||||
os.Remove("mask.png")
|
||||
os.Remove("image.png")
|
||||
}()
|
||||
defer mask.Close()
|
||||
|
||||
_, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{
|
||||
Image: origin,
|
||||
@@ -121,16 +115,11 @@ func TestImageEditWithoutMask(t *testing.T) {
|
||||
defer teardown()
|
||||
server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
|
||||
|
||||
origin, err := os.Create("image.png")
|
||||
origin, err := os.Create(filepath.Join(t.TempDir(), "image.png"))
|
||||
if err != nil {
|
||||
t.Error("open origin file error")
|
||||
return
|
||||
t.Fatalf("open origin file error: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
origin.Close()
|
||||
os.Remove("image.png")
|
||||
}()
|
||||
defer origin.Close()
|
||||
|
||||
_, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{
|
||||
Image: origin,
|
||||
@@ -178,16 +167,11 @@ func TestImageVariation(t *testing.T) {
|
||||
defer teardown()
|
||||
server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint)
|
||||
|
||||
origin, err := os.Create("image.png")
|
||||
origin, err := os.Create(filepath.Join(t.TempDir(), "image.png"))
|
||||
if err != nil {
|
||||
t.Error("open origin file error")
|
||||
return
|
||||
t.Fatalf("open origin file error: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
origin.Close()
|
||||
os.Remove("image.png")
|
||||
}()
|
||||
defer origin.Close()
|
||||
|
||||
_, err = client.CreateVariImage(context.Background(), openai.ImageVariRequest{
|
||||
Image: origin,
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package openai //nolint:testpackage // testing private field
|
||||
|
||||
import (
|
||||
utils "github.com/sashabaranov/go-openai/internal"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
utils "git.vaala.cloud/VaalaCat/go-openai/internal"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
|
||||
"context"
|
||||
"fmt"
|
||||
@@ -54,13 +54,13 @@ func TestImageFormBuilderFailures(t *testing.T) {
|
||||
}
|
||||
|
||||
mockFailedErr := fmt.Errorf("mock form builder fail")
|
||||
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
|
||||
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
|
||||
return mockFailedErr
|
||||
}
|
||||
_, err := client.CreateEditImage(ctx, req)
|
||||
checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
|
||||
|
||||
mockBuilder.mockCreateFormFile = func(name string, _ *os.File) error {
|
||||
mockBuilder.mockCreateFormFileReader = func(name string, _ io.Reader, _ string) error {
|
||||
if name == "mask" {
|
||||
return mockFailedErr
|
||||
}
|
||||
@@ -119,13 +119,13 @@ func TestVariImageFormBuilderFailures(t *testing.T) {
|
||||
req := ImageVariRequest{}
|
||||
|
||||
mockFailedErr := fmt.Errorf("mock form builder fail")
|
||||
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
|
||||
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
|
||||
return mockFailedErr
|
||||
}
|
||||
_, err := client.CreateVariImage(ctx, req)
|
||||
checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails")
|
||||
|
||||
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
|
||||
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
utils "github.com/sashabaranov/go-openai/internal"
|
||||
"github.com/sashabaranov/go-openai/internal/test"
|
||||
utils "git.vaala.cloud/VaalaCat/go-openai/internal"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
|
||||
)
|
||||
|
||||
func TestErrorAccumulatorBytes(t *testing.T) {
|
||||
|
||||
@@ -4,8 +4,10 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/textproto"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type FormBuilder interface {
|
||||
@@ -30,8 +32,37 @@ func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) er
|
||||
return fb.createFormFile(fieldname, file, file.Name())
|
||||
}
|
||||
|
||||
var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
|
||||
|
||||
func escapeQuotes(s string) string {
|
||||
return quoteEscaper.Replace(s)
|
||||
}
|
||||
|
||||
// CreateFormFileReader creates a form field with a file reader.
|
||||
// The filename in parameters can be an empty string.
|
||||
// The filename in Content-Disposition is required, But it can be an empty string.
|
||||
func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
|
||||
return fb.createFormFile(fieldname, r, path.Base(filename))
|
||||
h := make(textproto.MIMEHeader)
|
||||
h.Set(
|
||||
"Content-Disposition",
|
||||
fmt.Sprintf(
|
||||
`form-data; name="%s"; filename="%s"`,
|
||||
escapeQuotes(fieldname),
|
||||
escapeQuotes(filepath.Base(filename)),
|
||||
),
|
||||
)
|
||||
|
||||
fieldWriter, err := fb.writer.CreatePart(h)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = io.Copy(fieldWriter, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error {
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
package openai //nolint:testpackage // testing private field
|
||||
|
||||
import (
|
||||
"github.com/sashabaranov/go-openai/internal/test"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
|
||||
"bytes"
|
||||
"errors"
|
||||
@@ -20,15 +19,11 @@ func (*failingWriter) Write([]byte) (int, error) {
|
||||
}
|
||||
|
||||
func TestFormBuilderWithFailingWriter(t *testing.T) {
|
||||
dir, cleanup := test.CreateTestDirectory(t)
|
||||
defer cleanup()
|
||||
|
||||
file, err := os.CreateTemp(dir, "")
|
||||
file, err := os.CreateTemp(t.TempDir(), "")
|
||||
if err != nil {
|
||||
t.Errorf("Error creating tmp file: %v", err)
|
||||
t.Fatalf("Error creating tmp file: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
defer os.Remove(file.Name())
|
||||
|
||||
builder := NewFormBuilder(&failingWriter{})
|
||||
err = builder.CreateFormFile("file", file)
|
||||
@@ -36,15 +31,11 @@ func TestFormBuilderWithFailingWriter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFormBuilderWithClosedFile(t *testing.T) {
|
||||
dir, cleanup := test.CreateTestDirectory(t)
|
||||
defer cleanup()
|
||||
|
||||
file, err := os.CreateTemp(dir, "")
|
||||
file, err := os.CreateTemp(t.TempDir(), "")
|
||||
if err != nil {
|
||||
t.Errorf("Error creating tmp file: %v", err)
|
||||
t.Fatalf("Error creating tmp file: %v", err)
|
||||
}
|
||||
file.Close()
|
||||
defer os.Remove(file.Name())
|
||||
|
||||
body := &bytes.Buffer{}
|
||||
builder := NewFormBuilder(body)
|
||||
@@ -52,3 +43,32 @@ func TestFormBuilderWithClosedFile(t *testing.T) {
|
||||
checks.HasError(t, err, "formbuilder should return error if file is closed")
|
||||
checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed")
|
||||
}
|
||||
|
||||
type failingReader struct {
|
||||
}
|
||||
|
||||
var errMockFailingReaderError = errors.New("mock reader failed")
|
||||
|
||||
func (*failingReader) Read([]byte) (int, error) {
|
||||
return 0, errMockFailingReaderError
|
||||
}
|
||||
|
||||
func TestFormBuilderWithReader(t *testing.T) {
|
||||
file, err := os.CreateTemp(t.TempDir(), "")
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating tmp file: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
builder := NewFormBuilder(&failingWriter{})
|
||||
err = builder.CreateFormFileReader("file", file, file.Name())
|
||||
checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails")
|
||||
|
||||
builder = NewFormBuilder(&bytes.Buffer{})
|
||||
reader := &failingReader{}
|
||||
err = builder.CreateFormFileReader("file", reader, "")
|
||||
checks.ErrorIs(t, err, errMockFailingReaderError, "formbuilder should return error if copy reader fails")
|
||||
|
||||
successReader := &bytes.Buffer{}
|
||||
err = builder.CreateFormFileReader("file", successReader, "")
|
||||
checks.NoError(t, err, "formbuilder should not return error")
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -19,16 +19,6 @@ func CreateTestFile(t *testing.T, path string) {
|
||||
file.Close()
|
||||
}
|
||||
|
||||
// CreateTestDirectory creates a temporary folder which will be deleted when cleanup is called.
|
||||
func CreateTestDirectory(t *testing.T) (path string, cleanup func()) {
|
||||
t.Helper()
|
||||
|
||||
path, err := os.MkdirTemp(os.TempDir(), "")
|
||||
checks.NoError(t, err)
|
||||
|
||||
return path, func() { os.RemoveAll(path) }
|
||||
}
|
||||
|
||||
// TokenRoundTripper is a struct that implements the RoundTripper
|
||||
// interface, specifically to handle the authentication token by adding a token
|
||||
// to the request header. We need this because the API requires that each
|
||||
|
||||
@@ -46,6 +46,8 @@ type Definition struct {
|
||||
// additionalProperties: false
|
||||
// additionalProperties: jsonschema.Definition{Type: jsonschema.String}
|
||||
AdditionalProperties any `json:"additionalProperties,omitempty"`
|
||||
// Whether the schema is nullable or not.
|
||||
Nullable bool `json:"nullable,omitempty"`
|
||||
}
|
||||
|
||||
func (d *Definition) MarshalJSON() ([]byte, error) {
|
||||
@@ -124,9 +126,12 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) {
|
||||
}
|
||||
jsonTag := field.Tag.Get("json")
|
||||
var required = true
|
||||
if jsonTag == "" {
|
||||
switch {
|
||||
case jsonTag == "-":
|
||||
continue
|
||||
case jsonTag == "":
|
||||
jsonTag = field.Name
|
||||
} else if strings.HasSuffix(jsonTag, ",omitempty") {
|
||||
case strings.HasSuffix(jsonTag, ",omitempty"):
|
||||
jsonTag = strings.TrimSuffix(jsonTag, ",omitempty")
|
||||
required = false
|
||||
}
|
||||
@@ -139,6 +144,16 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) {
|
||||
if description != "" {
|
||||
item.Description = description
|
||||
}
|
||||
enum := field.Tag.Get("enum")
|
||||
if enum != "" {
|
||||
item.Enum = strings.Split(enum, ",")
|
||||
}
|
||||
|
||||
if n := field.Tag.Get("nullable"); n != "" {
|
||||
nullable, _ := strconv.ParseBool(n)
|
||||
item.Nullable = nullable
|
||||
}
|
||||
|
||||
properties[jsonTag] = *item
|
||||
|
||||
if s := field.Tag.Get("required"); s != "" {
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
func TestDefinition_MarshalJSON(t *testing.T) {
|
||||
@@ -17,7 +17,7 @@ func TestDefinition_MarshalJSON(t *testing.T) {
|
||||
{
|
||||
name: "Test with empty Definition",
|
||||
def: jsonschema.Definition{},
|
||||
want: `{"properties":{}}`,
|
||||
want: `{}`,
|
||||
},
|
||||
{
|
||||
name: "Test with Definition properties set",
|
||||
@@ -31,15 +31,14 @@ func TestDefinition_MarshalJSON(t *testing.T) {
|
||||
},
|
||||
},
|
||||
want: `{
|
||||
"type":"string",
|
||||
"description":"A string type",
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string",
|
||||
"properties":{}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
"type":"string",
|
||||
"description":"A string type",
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string"
|
||||
}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Test with nested Definition properties",
|
||||
@@ -60,23 +59,21 @@ func TestDefinition_MarshalJSON(t *testing.T) {
|
||||
},
|
||||
},
|
||||
want: `{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"user":{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string",
|
||||
"properties":{}
|
||||
},
|
||||
"age":{
|
||||
"type":"integer",
|
||||
"properties":{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"user":{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string"
|
||||
},
|
||||
"age":{
|
||||
"type":"integer"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Test with complex nested Definition",
|
||||
@@ -108,36 +105,32 @@ func TestDefinition_MarshalJSON(t *testing.T) {
|
||||
},
|
||||
},
|
||||
want: `{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"user":{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string",
|
||||
"properties":{}
|
||||
},
|
||||
"age":{
|
||||
"type":"integer",
|
||||
"properties":{}
|
||||
},
|
||||
"address":{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"city":{
|
||||
"type":"string",
|
||||
"properties":{}
|
||||
},
|
||||
"country":{
|
||||
"type":"string",
|
||||
"properties":{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"user":{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string"
|
||||
},
|
||||
"age":{
|
||||
"type":"integer"
|
||||
},
|
||||
"address":{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"city":{
|
||||
"type":"string"
|
||||
},
|
||||
"country":{
|
||||
"type":"string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Test with Array type Definition",
|
||||
@@ -153,20 +146,16 @@ func TestDefinition_MarshalJSON(t *testing.T) {
|
||||
},
|
||||
},
|
||||
want: `{
|
||||
"type":"array",
|
||||
"items":{
|
||||
"type":"string",
|
||||
"properties":{
|
||||
|
||||
}
|
||||
},
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string",
|
||||
"properties":{}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
"type":"array",
|
||||
"items":{
|
||||
"type":"string"
|
||||
},
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string"
|
||||
}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -193,6 +182,232 @@ func TestDefinition_MarshalJSON(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStructToSchema(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in any
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "Test with empty struct",
|
||||
in: struct{}{},
|
||||
want: `{
|
||||
"type":"object",
|
||||
"additionalProperties":false
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Test with struct containing many fields",
|
||||
in: struct {
|
||||
Name string `json:"name"`
|
||||
Age int `json:"age"`
|
||||
Active bool `json:"active"`
|
||||
Height float64 `json:"height"`
|
||||
Cities []struct {
|
||||
Name string `json:"name"`
|
||||
State string `json:"state"`
|
||||
} `json:"cities"`
|
||||
}{
|
||||
Name: "John Doe",
|
||||
Age: 30,
|
||||
Cities: []struct {
|
||||
Name string `json:"name"`
|
||||
State string `json:"state"`
|
||||
}{
|
||||
{Name: "New York", State: "NY"},
|
||||
{Name: "Los Angeles", State: "CA"},
|
||||
},
|
||||
},
|
||||
want: `{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string"
|
||||
},
|
||||
"age":{
|
||||
"type":"integer"
|
||||
},
|
||||
"active":{
|
||||
"type":"boolean"
|
||||
},
|
||||
"height":{
|
||||
"type":"number"
|
||||
},
|
||||
"cities":{
|
||||
"type":"array",
|
||||
"items":{
|
||||
"additionalProperties":false,
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string"
|
||||
},
|
||||
"state":{
|
||||
"type":"string"
|
||||
}
|
||||
},
|
||||
"required":["name","state"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"required":["name","age","active","height","cities"],
|
||||
"additionalProperties":false
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Test with description tag",
|
||||
in: struct {
|
||||
Name string `json:"name" description:"The name of the person"`
|
||||
}{
|
||||
Name: "John Doe",
|
||||
},
|
||||
want: `{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string",
|
||||
"description":"The name of the person"
|
||||
}
|
||||
},
|
||||
"required":["name"],
|
||||
"additionalProperties":false
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Test with required tag",
|
||||
in: struct {
|
||||
Name string `json:"name" required:"false"`
|
||||
}{
|
||||
Name: "John Doe",
|
||||
},
|
||||
want: `{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string"
|
||||
}
|
||||
},
|
||||
"additionalProperties":false
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Test with enum tag",
|
||||
in: struct {
|
||||
Color string `json:"color" enum:"red,green,blue"`
|
||||
}{
|
||||
Color: "red",
|
||||
},
|
||||
want: `{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"color":{
|
||||
"type":"string",
|
||||
"enum":["red","green","blue"]
|
||||
}
|
||||
},
|
||||
"required":["color"],
|
||||
"additionalProperties":false
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Test with nullable tag",
|
||||
in: struct {
|
||||
Name *string `json:"name" nullable:"true"`
|
||||
}{
|
||||
Name: nil,
|
||||
},
|
||||
want: `{
|
||||
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string",
|
||||
"nullable":true
|
||||
}
|
||||
},
|
||||
"required":["name"],
|
||||
"additionalProperties":false
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Test with exclude mark",
|
||||
in: struct {
|
||||
Name string `json:"-"`
|
||||
}{
|
||||
Name: "Name",
|
||||
},
|
||||
want: `{
|
||||
"type":"object",
|
||||
"additionalProperties":false
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Test with no json tag",
|
||||
in: struct {
|
||||
Name string
|
||||
}{
|
||||
Name: "",
|
||||
},
|
||||
want: `{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"Name":{
|
||||
"type":"string"
|
||||
}
|
||||
},
|
||||
"required":["Name"],
|
||||
"additionalProperties":false
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "Test with omitempty tag",
|
||||
in: struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
}{
|
||||
Name: "",
|
||||
},
|
||||
want: `{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"name":{
|
||||
"type":"string"
|
||||
}
|
||||
},
|
||||
"additionalProperties":false
|
||||
}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
wantBytes := []byte(tt.want)
|
||||
|
||||
schema, err := jsonschema.GenerateSchemaForType(tt.in)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to generate schema: error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var want map[string]interface{}
|
||||
err = json.Unmarshal(wantBytes, &want)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to Unmarshal JSON: error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
got := structToMap(t, schema)
|
||||
gotPtr := structToMap(t, &schema)
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("MarshalJSON() got = %v, want %v", got, want)
|
||||
}
|
||||
if !reflect.DeepEqual(gotPtr, want) {
|
||||
t.Errorf("MarshalJSON() gotPtr = %v, want %v", gotPtr, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func structToMap(t *testing.T, v any) map[string]any {
|
||||
t.Helper()
|
||||
gotBytes, err := json.Marshal(v)
|
||||
|
||||
@@ -3,7 +3,7 @@ package jsonschema_test
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
func Test_Validate(t *testing.T) {
|
||||
|
||||
@@ -41,6 +41,7 @@ type MessageContent struct {
|
||||
Type string `json:"type"`
|
||||
Text *MessageText `json:"text,omitempty"`
|
||||
ImageFile *ImageFile `json:"image_file,omitempty"`
|
||||
ImageURL *ImageURL `json:"image_url,omitempty"`
|
||||
}
|
||||
type MessageText struct {
|
||||
Value string `json:"value"`
|
||||
@@ -51,6 +52,11 @@ type ImageFile struct {
|
||||
FileID string `json:"file_id"`
|
||||
}
|
||||
|
||||
type ImageURL struct {
|
||||
URL string `json:"url"`
|
||||
Detail string `json:"detail"`
|
||||
}
|
||||
|
||||
type MessageRequest struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
var emptyStr = ""
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
const testFineTuneModelID = "fine-tune-model-id"
|
||||
@@ -47,6 +47,24 @@ func TestGetModel(t *testing.T) {
|
||||
checks.NoError(t, err, "GetModel error")
|
||||
}
|
||||
|
||||
// TestGetModelO3 Tests the retrieve O3 model endpoint of the API using the mocked server.
|
||||
func TestGetModelO3(t *testing.T) {
|
||||
client, server, teardown := setupOpenAITestServer()
|
||||
defer teardown()
|
||||
server.RegisterHandler("/v1/models/o3", handleGetModelEndpoint)
|
||||
_, err := client.GetModel(context.Background(), "o3")
|
||||
checks.NoError(t, err, "GetModel error for O3")
|
||||
}
|
||||
|
||||
// TestGetModelO4Mini Tests the retrieve O4Mini model endpoint of the API using the mocked server.
|
||||
func TestGetModelO4Mini(t *testing.T) {
|
||||
client, server, teardown := setupOpenAITestServer()
|
||||
defer teardown()
|
||||
server.RegisterHandler("/v1/models/o4-mini", handleGetModelEndpoint)
|
||||
_, err := client.GetModel(context.Background(), "o4-mini")
|
||||
checks.NoError(t, err, "GetModel error for O4Mini")
|
||||
}
|
||||
|
||||
func TestAzureGetModel(t *testing.T) {
|
||||
client, server, teardown := setupAzureTestServer()
|
||||
defer teardown()
|
||||
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
// TestModeration Tests the moderations endpoint of the API using the mocked server.
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package openai_test
|
||||
|
||||
import (
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
|
||||
)
|
||||
|
||||
func setupOpenAITestServer() (client *openai.Client, server *test.ServerTest, teardown func()) {
|
||||
@@ -29,9 +29,9 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea
|
||||
|
||||
// numTokens Returns the number of GPT-3 encoded tokens in the given text.
|
||||
// This function approximates based on the rule of thumb stated by OpenAI:
|
||||
// https://beta.openai.com/tokenizer
|
||||
// https://beta.openai.com/tokenizer.
|
||||
//
|
||||
// TODO: implement an actual tokenizer for GPT-3 and Codex (once available)
|
||||
// TODO: implement an actual tokenizer for GPT-3 and Codex (once available).
|
||||
func numTokens(s string) int {
|
||||
return int(float32(len(s)) / 4)
|
||||
}
|
||||
|
||||
81
reasoning_validator.go
Normal file
81
reasoning_validator.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
// Deprecated: use ErrReasoningModelMaxTokensDeprecated instead.
|
||||
ErrO1MaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") //nolint:lll
|
||||
ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
|
||||
ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll
|
||||
ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll
|
||||
)
|
||||
|
||||
var (
|
||||
ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll
|
||||
ErrO1BetaLimitationsTools = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll
|
||||
// Deprecated: use ErrReasoningModelLimitations* instead.
|
||||
ErrO1BetaLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll
|
||||
ErrO1BetaLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
|
||||
)
|
||||
|
||||
var (
|
||||
//nolint:lll
|
||||
ErrReasoningModelMaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens")
|
||||
ErrReasoningModelLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll
|
||||
ErrReasoningModelLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
|
||||
)
|
||||
|
||||
// ReasoningValidator handles validation for o-series model requests.
|
||||
type ReasoningValidator struct{}
|
||||
|
||||
// NewReasoningValidator creates a new validator for o-series models.
|
||||
func NewReasoningValidator() *ReasoningValidator {
|
||||
return &ReasoningValidator{}
|
||||
}
|
||||
|
||||
// Validate performs all validation checks for o-series models.
|
||||
func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error {
|
||||
o1Series := strings.HasPrefix(request.Model, "o1")
|
||||
o3Series := strings.HasPrefix(request.Model, "o3")
|
||||
o4Series := strings.HasPrefix(request.Model, "o4")
|
||||
|
||||
if !o1Series && !o3Series && !o4Series {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := v.validateReasoningModelParams(request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateReasoningModelParams checks reasoning model parameters.
|
||||
func (v *ReasoningValidator) validateReasoningModelParams(request ChatCompletionRequest) error {
|
||||
if request.MaxTokens > 0 {
|
||||
return ErrReasoningModelMaxTokensDeprecated
|
||||
}
|
||||
if request.LogProbs {
|
||||
return ErrReasoningModelLimitationsLogprobs
|
||||
}
|
||||
if request.Temperature > 0 && request.Temperature != 1 {
|
||||
return ErrReasoningModelLimitationsOther
|
||||
}
|
||||
if request.TopP > 0 && request.TopP != 1 {
|
||||
return ErrReasoningModelLimitationsOther
|
||||
}
|
||||
if request.N > 0 && request.N != 1 {
|
||||
return ErrReasoningModelLimitationsOther
|
||||
}
|
||||
if request.PresencePenalty > 0 {
|
||||
return ErrReasoningModelLimitationsOther
|
||||
}
|
||||
if request.FrequencyPenalty > 0 {
|
||||
return ErrReasoningModelLimitationsOther
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
13
run.go
13
run.go
@@ -83,12 +83,13 @@ const (
|
||||
)
|
||||
|
||||
type RunRequest struct {
|
||||
AssistantID string `json:"assistant_id"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Instructions string `json:"instructions,omitempty"`
|
||||
AdditionalInstructions string `json:"additional_instructions,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
AssistantID string `json:"assistant_id"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Instructions string `json:"instructions,omitempty"`
|
||||
AdditionalInstructions string `json:"additional_instructions,omitempty"`
|
||||
AdditionalMessages []ThreadMessage `json:"additional_messages,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
|
||||
// Sampling temperature between 0 and 2. Higher values like 0.8 are more random.
|
||||
// lower values are more focused and deterministic.
|
||||
|
||||
@@ -3,8 +3,8 @@ package openai_test
|
||||
import (
|
||||
"context"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
12
speech.go
12
speech.go
@@ -8,20 +8,25 @@ import (
|
||||
type SpeechModel string
|
||||
|
||||
const (
|
||||
TTSModel1 SpeechModel = "tts-1"
|
||||
TTSModel1HD SpeechModel = "tts-1-hd"
|
||||
TTSModelCanary SpeechModel = "canary-tts"
|
||||
TTSModel1 SpeechModel = "tts-1"
|
||||
TTSModel1HD SpeechModel = "tts-1-hd"
|
||||
TTSModelCanary SpeechModel = "canary-tts"
|
||||
TTSModelGPT4oMini SpeechModel = "gpt-4o-mini-tts"
|
||||
)
|
||||
|
||||
type SpeechVoice string
|
||||
|
||||
const (
|
||||
VoiceAlloy SpeechVoice = "alloy"
|
||||
VoiceAsh SpeechVoice = "ash"
|
||||
VoiceBallad SpeechVoice = "ballad"
|
||||
VoiceCoral SpeechVoice = "coral"
|
||||
VoiceEcho SpeechVoice = "echo"
|
||||
VoiceFable SpeechVoice = "fable"
|
||||
VoiceOnyx SpeechVoice = "onyx"
|
||||
VoiceNova SpeechVoice = "nova"
|
||||
VoiceShimmer SpeechVoice = "shimmer"
|
||||
VoiceVerse SpeechVoice = "verse"
|
||||
)
|
||||
|
||||
type SpeechResponseFormat string
|
||||
@@ -39,6 +44,7 @@ type CreateSpeechRequest struct {
|
||||
Model SpeechModel `json:"model"`
|
||||
Input string `json:"input"`
|
||||
Voice SpeechVoice `json:"voice"`
|
||||
Instructions string `json:"instructions,omitempty"` // Optional, Doesnt work with tts-1 or tts-1-hd.
|
||||
ResponseFormat SpeechResponseFormat `json:"response_format,omitempty"` // Optional, default to mp3
|
||||
Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0
|
||||
}
|
||||
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
func TestSpeechIntegration(t *testing.T) {
|
||||
@@ -21,10 +21,8 @@ func TestSpeechIntegration(t *testing.T) {
|
||||
defer teardown()
|
||||
|
||||
server.RegisterHandler("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) {
|
||||
dir, cleanup := test.CreateTestDirectory(t)
|
||||
path := filepath.Join(dir, "fake.mp3")
|
||||
path := filepath.Join(t.TempDir(), "fake.mp3")
|
||||
test.CreateTestFile(t, path)
|
||||
defer cleanup()
|
||||
|
||||
// audio endpoints only accept POST requests
|
||||
if r.Method != "POST" {
|
||||
|
||||
@@ -6,13 +6,14 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
|
||||
utils "github.com/sashabaranov/go-openai/internal"
|
||||
utils "git.vaala.cloud/VaalaCat/go-openai/internal"
|
||||
)
|
||||
|
||||
var (
|
||||
headerData = []byte("data: ")
|
||||
errorPrefix = []byte(`data: {"error":`)
|
||||
headerData = regexp.MustCompile(`^data:\s*`)
|
||||
errorPrefix = regexp.MustCompile(`^data:\s*{"error":`)
|
||||
)
|
||||
|
||||
type streamable interface {
|
||||
@@ -32,17 +33,28 @@ type streamReader[T streamable] struct {
|
||||
}
|
||||
|
||||
func (stream *streamReader[T]) Recv() (response T, err error) {
|
||||
if stream.isFinished {
|
||||
err = io.EOF
|
||||
rawLine, err := stream.RecvRaw()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
response, err = stream.processLines()
|
||||
return
|
||||
err = stream.unmarshaler.Unmarshal(rawLine, &response)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (stream *streamReader[T]) RecvRaw() ([]byte, error) {
|
||||
if stream.isFinished {
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
return stream.processLines()
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (stream *streamReader[T]) processLines() (T, error) {
|
||||
func (stream *streamReader[T]) processLines() ([]byte, error) {
|
||||
var (
|
||||
emptyMessagesCount uint
|
||||
hasErrorPrefix bool
|
||||
@@ -53,44 +65,38 @@ func (stream *streamReader[T]) processLines() (T, error) {
|
||||
if readErr != nil || hasErrorPrefix {
|
||||
respErr := stream.unmarshalError()
|
||||
if respErr != nil {
|
||||
return *new(T), fmt.Errorf("error, %w", respErr.Error)
|
||||
return nil, fmt.Errorf("error, %w", respErr.Error)
|
||||
}
|
||||
return *new(T), readErr
|
||||
return nil, readErr
|
||||
}
|
||||
|
||||
noSpaceLine := bytes.TrimSpace(rawLine)
|
||||
if bytes.HasPrefix(noSpaceLine, errorPrefix) {
|
||||
if errorPrefix.Match(noSpaceLine) {
|
||||
hasErrorPrefix = true
|
||||
}
|
||||
if !bytes.HasPrefix(noSpaceLine, headerData) || hasErrorPrefix {
|
||||
if !headerData.Match(noSpaceLine) || hasErrorPrefix {
|
||||
if hasErrorPrefix {
|
||||
noSpaceLine = bytes.TrimPrefix(noSpaceLine, headerData)
|
||||
noSpaceLine = headerData.ReplaceAll(noSpaceLine, nil)
|
||||
}
|
||||
writeErr := stream.errAccumulator.Write(noSpaceLine)
|
||||
if writeErr != nil {
|
||||
return *new(T), writeErr
|
||||
return nil, writeErr
|
||||
}
|
||||
emptyMessagesCount++
|
||||
if emptyMessagesCount > stream.emptyMessagesLimit {
|
||||
return *new(T), ErrTooManyEmptyStreamMessages
|
||||
return nil, ErrTooManyEmptyStreamMessages
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData)
|
||||
noPrefixLine := headerData.ReplaceAll(noSpaceLine, nil)
|
||||
if string(noPrefixLine) == "[DONE]" {
|
||||
stream.isFinished = true
|
||||
return *new(T), io.EOF
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
var response T
|
||||
unmarshalErr := stream.unmarshaler.Unmarshal(noPrefixLine, &response)
|
||||
if unmarshalErr != nil {
|
||||
return *new(T), unmarshalErr
|
||||
}
|
||||
|
||||
return response, nil
|
||||
return noPrefixLine, nil
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,9 +6,9 @@ import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
utils "github.com/sashabaranov/go-openai/internal"
|
||||
"github.com/sashabaranov/go-openai/internal/test"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
utils "git.vaala.cloud/VaalaCat/go-openai/internal"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
var errTestUnmarshalerFailed = errors.New("test unmarshaler failed")
|
||||
@@ -63,3 +63,16 @@ func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) {
|
||||
_, err := stream.Recv()
|
||||
checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
|
||||
}
|
||||
|
||||
func TestStreamReaderRecvRaw(t *testing.T) {
|
||||
stream := &streamReader[ChatCompletionStreamResponse]{
|
||||
reader: bufio.NewReader(bytes.NewReader([]byte("data: {\"key\": \"value\"}\n"))),
|
||||
}
|
||||
rawLine, err := stream.RecvRaw()
|
||||
if err != nil {
|
||||
t.Fatalf("Did not return raw line: %v", err)
|
||||
}
|
||||
if !bytes.Equal(rawLine, []byte("{\"key\": \"value\"}")) {
|
||||
t.Fatalf("Did not return raw line: %v", string(rawLine))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
"git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
func TestCompletionsStreamWrongModel(t *testing.T) {
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
// TestThread Tests the thread endpoint of the API using the mocked server.
|
||||
|
||||
@@ -3,8 +3,8 @@ package openai_test
|
||||
import (
|
||||
"context"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
openai "git.vaala.cloud/VaalaCat/go-openai"
|
||||
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
|
||||
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
Reference in New Issue
Block a user