Compare commits


37 Commits

Author SHA1 Message Date
VaalaCat
2436e7afb8 feat: add reasoning format 2025-06-15 12:59:44 +00:00
VaalaCat
67f3b169df feat: add include_reasoning 2025-06-15 12:59:08 +00:00
VaalaCat
3640274cd1 feat: change repo name 2025-06-15 12:58:45 +00:00
JT A.
ff9d83a485 skip json field (#1009)
* skip json field

* backfill some coverage and tests
2025-05-29 11:31:35 +01:00
Axb12
8c65b35c57 update image api *os.File to io.Reader (#994)
* update image api *os.File to io.Reader

* update code style

* add reader test

* supplementary reader test

* update the reader in the form builder test

* add comment

* update comment

* update code style
2025-05-20 14:45:40 +01:00
Alex Baranov
4d2e7ab29d fix lint (#998) 2025-05-13 12:59:06 +01:00
Justa
6aaa732296 add ChatTemplateKwargs to ChatCompletionRequest (#980)
Co-authored-by: Justa <justa.cai@akuvox.com>
2025-05-13 12:52:44 +01:00
Pedro Chaparro
0116f2994d feat: add support for image generation using gpt-image-1 (#971)
* feat: add gpt-image-1 support

* feat: add example to generate image using gpt-image-1 model

* style: missing period in comments

* feat: add missing fields to example

* docs: add GPT Image 1 to README

* revert: keep `examples/images/main.go` unchanged

* docs: remove unnecessary newline from example in README file
2025-05-13 12:51:08 +01:00
Alex Baranov
8ba38f6ba1 remove backup file (#996) 2025-05-13 12:44:16 +01:00
Alex Baranov
6181facea7 update codecov action, pass token (#987) 2025-05-04 15:45:40 +01:00
Alex Baranov
77ccac8d34 Upgrade golangci-lint to 2.1.5 (#986)
* Upgrade golangci-lint to 2.1.5

* update action
2025-05-03 22:39:47 +01:00
Alex Baranov
5ea214a188 Improve unit test coverage (#984)
* add tests for config

* add audio tests

* lint

* lint

* lint
2025-05-03 22:25:14 +01:00
Ben Katz
d65f0cb54e Fix: Corrected typo in O4Mini20250416 model name and endpoint map. (#981) 2025-05-03 21:44:48 +01:00
Daniel Peng
93a611cf4f Add Prediction field (#970)
* Add Prediction field to ChatCompletionRequest

* Include prediction tokens in response
2025-04-29 14:38:27 +01:00
Oleksandr Redko
6836cf6a6f Remove redundant typecheck linter (#955) 2025-04-29 14:36:38 +01:00
Sean McGinnis
da5f9bc9bc Add CompletionRequest.StreamOptions (#959)
The legacy completion API supports a `stream_options` object when
`stream` is set to true [0]. This adds a StreamOptions property to the
CompletionRequest struct to support this setting.

[0] https://platform.openai.com/docs/api-reference/completions/create#completions-create-stream_options

Signed-off-by: Sean McGinnis <sean.mcginnis@gmail.com>
2025-04-29 14:35:26 +01:00
rory malcolm
bb5bc27567 Add support for 4o-mini and 3o (#968)
- This adds support, and tests, for the 3o and 4o-mini class of models
2025-04-29 14:34:33 +01:00
Zhongxian Pan
4cccc6c934 Adapt different stream data prefix, with or without space (#945) 2025-04-29 14:29:15 +01:00
goodenough
306fbbbe6f Add support for reasoning_content field in chat completion messages for DeepSeek R1 (#925)
* support deepseek field "reasoning_content"

* support deepseek field "reasoning_content"

* Comment ends in a period (godot)

* add comment on field reasoning_content

* fix go lint error

* chore: trigger CI

* make field "content" in MarshalJSON function omitempty

* remove reasoning_content in TestO1ModelChatCompletions func

* feat: Add test and handler for deepseek-reasoner chat model completions, including support for reasoning content in responses.

* feat: Add test and handler for deepseek-reasoner chat model completions, including support for reasoning content in responses.

* feat: Add test and handler for deepseek-reasoner chat model completions, including support for reasoning content in responses.
2025-04-29 14:24:45 +01:00
netr
658beda2ba feat: Add missing TTS models and voices (#958)
* feat: Add missing TTS models and voices

* feat: Add new instruction field to create speech request

- From docs: Control the voice of your generated audio with additional instructions. Does not work with tts-1 or tts-1-hd.

* fix: add canary-tts back to SpeechModel
2025-04-26 11:13:43 +01:00
Takahiro Ikeuchi
d68a683815 feat: add new GPT-4.1 model variants to completion.go (#966)
* feat: add new GPT-4.1 model variants to completion.go

* feat: add tests for unsupported models in completion endpoint

* fix: add missing periods to test function comments in completion_test.go
2025-04-23 22:50:47 +01:00
JT A.
e99eb54c9d add enum tag to jsonschema (#962)
* fix jsonschema tests

* ensure all run during PR Github Action

* add test for struct to schema

* add support for enum tag

* support nullable tag
2025-04-13 19:00:48 +01:00
Liu Shuang
74d6449f22 feat: add gpt-4.5-preview models (#947) 2025-03-04 08:26:59 +00:00
Alex Baranov
261721bfdb Fix linter (#943)
* fix lint

* remove linters
2025-02-25 16:56:35 +00:00
Dan Ackerson
be2e2387d4 feat: add Anthropic API support with custom version header (#934)
* feat: add Anthropic API support with custom version header

* refactor: use switch statement for API type header handling

* refactor: add OpenAI & AzureAD types to be exhaustive

* Update client.go

need explicit fallthrough in empty case statements

* constant for APIVersion; addtl tests
2025-02-25 11:03:38 +00:00
Liu Shuang
85f578b865 fix: remove validateO1Specific (#939)
* fix: remove validateO1Specific

* update golangci-lint-action version

* fix actions

* fix actions

* fix actions

* fix actions

* remove some o1 test
2025-02-17 11:29:18 +00:00
Liu Shuang
c0a9a75fe0 feat: add developer role (#936) 2025-02-12 15:05:44 +00:00
Mazyar Yousefiniyae shad
a62919e8c6 ref: add image url support to messages (#933)
* ref: add image url support to messages

* fix linter error

* fix linter error
2025-02-09 18:36:44 +00:00
rory malcolm
2054db016c Add support for O3-mini (#930)
* Add support for O3-mini

- Add support for the o3 mini set of models, including tests that match the constraints in OpenAI's API docs (https://platform.openai.com/docs/models#o3-mini).

* Deprecate and refactor

- Deprecate `ErrO1BetaLimitationsLogprobs` and `ErrO1BetaLimitationsOther`

- Implement `validationRequestForReasoningModels`, which works on both o1 & o3, and has per-model-type restrictions on functionality (eg, o3 class are allowed function calls and system messages, o1 isn't)

* Move reasoning validation to `reasoning_validator.go`

- Add a `NewReasoningValidator` which exposes a `Validate()` method for a given request

- Also adds a test for chat streams

* Final nits
2025-02-06 14:53:19 +00:00
saileshd1402
45aa99607b Make "Content" field in "ChatCompletionMessage" omitempty (#926) 2025-01-31 19:05:29 +00:00
Trevor Creech
9823a8bbbd Chat Completion API: add ReasoningEffort and new o1 models (#928)
* add reasoning_effort param

* add o1 model

* fix lint
2025-01-31 18:57:57 +00:00
Oleksandr Redko
7a2915a37d Simplify tests with T.TempDir (#929) 2025-01-31 18:55:41 +00:00
Sabuhi Gurbani
2a0ff5ac63 Added additional_messages (#914) 2024-12-27 10:01:16 +00:00
Alex Baranov
56a9acf86f Ignore test.mp3 (#913) 2024-12-08 13:16:48 +00:00
Tim Misiak
af5355f5b1 Fix ID field to be optional (#911)
The ID field is not always present for streaming responses. Without omitempty, the entire ToolCall struct will be missing.
2024-12-08 13:12:05 +00:00
Qiying Wang
c203ca001f feat: add RecvRaw (#896) 2024-11-30 10:29:05 +00:00
Liu Shuang
21fa42c18d feat: add gpt-4o-2024-11-20 model (#905) 2024-11-30 09:39:47 +00:00
66 changed files with 1962 additions and 830 deletions


@@ -8,7 +8,7 @@ assignees: ''
---
Your issue may already be reported!
Please search on the [issue tracker](https://github.com/sashabaranov/go-openai/issues) before creating one.
Please search on the [issue tracker](https://git.vaala.cloud/VaalaCat/go-openai/issues) before creating one.
**Describe the bug**
A clear and concise description of what the bug is. If it's an API-related bug, please provide relevant endpoint(s).


@@ -8,7 +8,7 @@ assignees: ''
---
Your issue may already be reported!
Please search on the [issue tracker](https://github.com/sashabaranov/go-openai/issues) before creating one.
Please search on the [issue tracker](https://git.vaala.cloud/VaalaCat/go-openai/issues) before creating one.
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]


@@ -1,5 +1,5 @@
A similar PR may already be submitted!
Please search among the [Pull request](https://github.com/sashabaranov/go-openai/pulls) before creating one.
Please search among the [Pull request](https://git.vaala.cloud/VaalaCat/go-openai/pulls) before creating one.
If your changes introduce breaking changes, please prefix the title of your pull request with "[BREAKING_CHANGES]". This allows for clear identification of such changes in the 'What's Changed' section on the release page, making it developer-friendly.


@@ -13,15 +13,17 @@ jobs:
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: '1.21'
go-version: '1.24'
- name: Run vet
run: |
go vet .
- name: Run golangci-lint
uses: golangci/golangci-lint-action@v4
uses: golangci/golangci-lint-action@v7
with:
version: latest
version: v2.1.5
- name: Run tests
run: go test -race -covermode=atomic -coverprofile=coverage.out -v .
run: go test -race -covermode=atomic -coverprofile=coverage.out -v ./...
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v4
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}

.gitignore

@@ -17,3 +17,6 @@
# Auth token for tests
.openai-token
.idea
# Generated by tests
test.mp3


@@ -1,66 +1,94 @@
## Golden config for golangci-lint v1.47.3
#
# This is the best config for golangci-lint based on my experience and opinion.
# It is very strict, but not extremely strict.
# Feel free to adopt and change it for your needs.
run:
# Timeout for analysis, e.g. 30s, 5m.
# Default: 1m
timeout: 3m
# This file contains only configs which differ from defaults.
# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
linters-settings:
version: "2"
linters:
default: none
enable:
- asciicheck
- bidichk
- bodyclose
- contextcheck
- cyclop
- dupl
- durationcheck
- errcheck
- errname
- errorlint
- exhaustive
- forbidigo
- funlen
- gochecknoinits
- gocognit
- goconst
- gocritic
- gocyclo
- godot
- gomoddirectives
- gomodguard
- goprintffuncname
- gosec
- govet
- ineffassign
- lll
- makezero
- mnd
- nestif
- nilerr
- nilnil
- nolintlint
- nosprintfhostport
- predeclared
- promlinter
- revive
- rowserrcheck
- sqlclosecheck
- staticcheck
- testpackage
- tparallel
- unconvert
- unparam
- unused
- usetesting
- wastedassign
- whitespace
settings:
cyclop:
# The maximal code complexity to report.
# Default: 10
max-complexity: 30
# The maximal average package complexity.
# If it's higher than 0.0 (float) the check is enabled
# Default: 0.0
package-average: 10.0
package-average: 10
errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default.
# Default: false
check-type-assertions: true
funlen:
# Checks the number of lines in a function.
# If lower than 0, disable the check.
# Default: 60
lines: 100
# Checks the number of statements in a function.
# If lower than 0, disable the check.
# Default: 40
statements: 50
gocognit:
# Minimal code complexity to report
# Default: 30 (but we recommend 10-20)
min-complexity: 20
gocritic:
# Settings passed to gocritic.
# The settings key is the name of a supported gocritic checker.
# The list of supported checkers can be find in https://go-critic.github.io/overview.
settings:
captLocal:
# Whether to restrict checker to params only.
# Default: true
paramsOnly: false
underef:
# Whether to skip (*x).method() calls where x is a pointer receiver.
# Default: true
skipRecvDeref: false
gomodguard:
blocked:
modules:
- github.com/golang/protobuf:
recommendations:
- google.golang.org/protobuf
reason: see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules
- github.com/satori/go.uuid:
recommendations:
- github.com/google/uuid
reason: satori's package is not maintained
- github.com/gofrs/uuid:
recommendations:
- github.com/google/uuid
reason: 'see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw'
govet:
disable:
- fieldalignment
enable-all: true
settings:
shadow:
strict: true
mnd:
# List of function patterns to exclude from analysis.
# Values always ignored: `time.Date`
# Default: []
ignored-functions:
- os.Chmod
- os.Mkdir
@@ -76,193 +104,44 @@ linters-settings:
- strconv.ParseFloat
- strconv.ParseInt
- strconv.ParseUint
gomodguard:
blocked:
# List of blocked modules.
# Default: []
modules:
- github.com/golang/protobuf:
recommendations:
- google.golang.org/protobuf
reason: "see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules"
- github.com/satori/go.uuid:
recommendations:
- github.com/google/uuid
reason: "satori's package is not maintained"
- github.com/gofrs/uuid:
recommendations:
- github.com/google/uuid
reason: "see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw"
govet:
# Enable all analyzers.
# Default: false
enable-all: true
# Disable analyzers by name.
# Run `go tool vet help` to see all analyzers.
# Default: []
disable:
- fieldalignment # too strict
# Settings per analyzer.
settings:
shadow:
# Whether to be strict about shadowing; can be noisy.
# Default: false
strict: true
nakedret:
# Make an issue if func has more lines of code than this setting, and it has naked returns.
# Default: 30
max-func-lines: 0
nolintlint:
# Exclude following linters from requiring an explanation.
# Default: []
allow-no-explanation: [ funlen, gocognit, lll ]
# Enable to require an explanation of nonzero length after each nolint directive.
# Default: false
require-explanation: true
# Enable to require nolint directives to mention the specific linter being suppressed.
# Default: false
require-specific: true
allow-no-explanation:
- funlen
- gocognit
- lll
rowserrcheck:
# database/sql is always checked
# Default: []
packages:
- github.com/jmoiron/sqlx
tenv:
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
# Default: false
all: true
varcheck:
# Check usage of exported fields and variables.
# Default: false
exported-fields: false # default false # TODO: enable after fixing false positives
linters:
disable-all: true
enable:
## enabled by default
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
- gosimple # Linter for Go source code that specializes in simplifying a code
- govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- ineffassign # Detects when assignments to existing variables are not used
- staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
- unused # Checks Go code for unused constants, variables, functions and types
## disabled by default
# - asasalint # Check for pass []any as any in variadic func(...any)
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
- bidichk # Checks for dangerous unicode character sequences
- bodyclose # checks whether HTTP response body is closed successfully
- contextcheck # check the function whether use a non-inherited context
- cyclop # checks function and package cyclomatic complexity
- dupl # Tool for code clone detection
- durationcheck # check for two durations multiplied together
- errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error.
- errorlint # errorlint is a linter that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13.
# Removed execinquery (deprecated): a linter that checked query strings passed to Query functions.
- exhaustive # check exhaustiveness of enum switch statements
- exportloopref # checks for pointers to enclosing loop variables
- forbidigo # Forbids identifiers
- funlen # Tool for detection of long functions
# - gochecknoglobals # check that no global variables exist
- gochecknoinits # Checks that no init functions are present in Go code
- gocognit # Computes and checks the cognitive complexity of functions
- goconst # Finds repeated strings that could be replaced by a constant
- gocritic # Provides diagnostics that check for bugs, performance and style issues.
- gocyclo # Computes and checks the cyclomatic complexity of functions
- godot # Check if comments end in a period
- goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt.
- gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod.
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- goprintffuncname # Checks that printf-like functions are named with f at the end
- gosec # Inspects source code for security problems
- lll # Reports long lines
- makezero # Finds slice declarations with non-zero initial length
# - nakedret # Finds naked returns in functions greater than a specified function length
- mnd # An analyzer to detect magic numbers.
- nestif # Reports deeply nested if statements
- nilerr # Finds the code that returns nil even if it checks that the error is not nil.
- nilnil # Checks that there is no simultaneous return of nil error and an invalid value.
# - noctx # noctx finds sending http request without context.Context
- nolintlint # Reports ill-formed or insufficient nolint directives
# - nonamedreturns # Reports all named returns
- nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL.
- predeclared # find code that shadows one of Go's predeclared identifiers
- promlinter # Check Prometheus metrics naming via promlint
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
- rowserrcheck # checks whether Err of rows is checked successfully
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
- stylecheck # Stylecheck is a replacement for golint
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17
- testpackage # linter that makes you use a separate _test package
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
- unconvert # Remove unnecessary type conversions
- unparam # Reports unused function parameters
- wastedassign # wastedassign finds wasted assignment statements.
- whitespace # Tool for detection of leading and trailing whitespace
## you may want to enable
#- decorder # check declaration order and count of types, constants, variables and functions
#- exhaustruct # Checks if all structure fields are initialized
#- goheader # Checks is file header matches to pattern
#- ireturn # Accept Interfaces, Return Concrete Types
#- prealloc # [premature optimization, but can be used in some cases] Finds slice declarations that could potentially be preallocated
#- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope
#- wrapcheck # Checks that errors returned from external packages are wrapped
## disabled
#- containedctx # containedctx is a linter that detects struct contained context.Context field
#- depguard # [replaced by gomodguard] Go linter that checks if package imports are in a list of acceptable packages
#- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
#- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted.
#- forcetypeassert # [replaced by errcheck] finds forced type assertions
#- gci # Gci controls golang package import order and makes it always deterministic.
#- godox # Tool for detection of FIXME, TODO and other comment keywords
#- goerr113 # [too strict] Golang linter to check the errors handling expressions
#- gofmt # [replaced by goimports] Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
#- gofumpt # [replaced by goimports, gofumports is not available yet] Gofumpt checks whether code was gofumpt-ed.
#- grouper # An analyzer to analyze expression groups.
#- ifshort # Checks that your code uses short syntax for if-statements whenever possible
#- importas # Enforces consistent import aliases
#- maintidx # maintidx measures the maintainability index of each function.
#- misspell # [useless] Finds commonly misspelled English words in comments
#- nlreturn # [too strict and mostly code is not more readable] nlreturn checks for a new line before return and branch statements to increase code clarity
#- nosnakecase # Detects snake case of variable naming and function name. # TODO: maybe enable after https://github.com/sivchari/nosnakecase/issues/14
#- paralleltest # [too many false positives] paralleltest detects missing usage of t.Parallel() method in your Go test
#- tagliatelle # Checks the struct tags.
#- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
#- wsl # [too strict and mostly code is not more readable] Whitespace Linter - Forces you to use empty lines!
## deprecated
#- exhaustivestruct # [deprecated, replaced by exhaustruct] Checks if all struct's fields are initialized
#- golint # [deprecated, replaced by revive] Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes
#- interfacer # [deprecated] Linter that suggests narrower interface types
#- maligned # [deprecated, replaced by govet fieldalignment] Tool to detect Go structs that would take less memory if their fields were sorted
#- scopelint # [deprecated, replaced by exportloopref] Scopelint checks for unpinned variables in go programs
issues:
# Maximum count of issues with the same text.
# Set to 0 to disable.
# Default: 3
max-same-issues: 50
exclude-rules:
- source: "^//\\s*go:generate\\s"
linters: [ lll ]
- source: "(noinspection|TODO)"
linters: [ godot ]
- source: "//noinspection"
linters: [ gocritic ]
- source: "^\\s+if _, ok := err\\.\\([^.]+\\.InternalError\\); ok {"
linters: [ errorlint ]
- path: "_test\\.go"
linters:
exclusions:
generated: lax
presets:
- comments
- common-false-positives
- legacy
- std-error-handling
rules:
- linters:
- forbidigo
- mnd
- revive
path : ^examples/.*\.go$
- linters:
- lll
source: ^//\s*go:generate\s
- linters:
- godot
source: (noinspection|TODO)
- linters:
- gocritic
source: //noinspection
- linters:
- errorlint
source: ^\s+if _, ok := err\.\([^.]+\.InternalError\); ok {
- linters:
- bodyclose
- dupl
- funlen
@@ -270,3 +149,20 @@ issues:
- gosec
- noctx
- wrapcheck
- staticcheck
path: _test\.go
paths:
- third_party$
- builtin$
- examples$
issues:
max-same-issues: 50
formatters:
enable:
- goimports
exclusions:
generated: lax
paths:
- third_party$
- builtin$
- examples$


@@ -1,22 +1,22 @@
# Contributing Guidelines
## Overview
Thank you for your interest in contributing to the "Go OpenAI" project! By following this guideline, we hope to ensure that your contributions are made smoothly and efficiently. The Go OpenAI project is licensed under the [Apache 2.0 License](https://github.com/sashabaranov/go-openai/blob/master/LICENSE), and we welcome contributions through GitHub pull requests.
Thank you for your interest in contributing to the "Go OpenAI" project! By following this guideline, we hope to ensure that your contributions are made smoothly and efficiently. The Go OpenAI project is licensed under the [Apache 2.0 License](https://git.vaala.cloud/VaalaCat/go-openai/blob/master/LICENSE), and we welcome contributions through GitHub pull requests.
## Reporting Bugs
If you discover a bug, first check the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to see if the issue has already been reported. If you're reporting a new issue, please use the "Bug report" template and provide detailed information about the problem, including steps to reproduce it.
If you discover a bug, first check the [GitHub Issues page](https://git.vaala.cloud/VaalaCat/go-openai/issues) to see if the issue has already been reported. If you're reporting a new issue, please use the "Bug report" template and provide detailed information about the problem, including steps to reproduce it.
## Suggesting Features
If you want to suggest a new feature or improvement, first check the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to ensure a similar suggestion hasn't already been made. Use the "Feature request" template to provide a detailed description of your suggestion.
If you want to suggest a new feature or improvement, first check the [GitHub Issues page](https://git.vaala.cloud/VaalaCat/go-openai/issues) to ensure a similar suggestion hasn't already been made. Use the "Feature request" template to provide a detailed description of your suggestion.
## Reporting Vulnerabilities
If you identify a security concern, please use the "Report a security vulnerability" template on the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to share the details. This report will only be viewable to repository maintainers. You will be credited if the advisory is published.
If you identify a security concern, please use the "Report a security vulnerability" template on the [GitHub Issues page](https://git.vaala.cloud/VaalaCat/go-openai/issues) to share the details. This report will only be viewable to repository maintainers. You will be credited if the advisory is published.
## Questions for Users
If you have questions, please utilize [StackOverflow](https://stackoverflow.com/) or the [GitHub Discussions page](https://github.com/sashabaranov/go-openai/discussions).
If you have questions, please utilize [StackOverflow](https://stackoverflow.com/) or the [GitHub Discussions page](https://git.vaala.cloud/VaalaCat/go-openai/discussions).
## Contributing Code
A similar pull request might already have been submitted! Please search the [pull requests](https://github.com/sashabaranov/go-openai/pulls) before creating one.
A similar pull request might already have been submitted! Please search the [pull requests](https://git.vaala.cloud/VaalaCat/go-openai/pulls) before creating one.
### Requirements for Merging a Pull Request

README.md

@@ -1,19 +1,19 @@
# Go OpenAI
[![Go Reference](https://pkg.go.dev/badge/github.com/sashabaranov/go-openai.svg)](https://pkg.go.dev/github.com/sashabaranov/go-openai)
[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-openai)](https://goreportcard.com/report/github.com/sashabaranov/go-openai)
[![Go Reference](https://pkg.go.dev/badge/git.vaala.cloud/VaalaCat/go-openai.svg)](https://pkg.go.dev/git.vaala.cloud/VaalaCat/go-openai)
[![Go Report Card](https://goreportcard.com/badge/git.vaala.cloud/VaalaCat/go-openai)](https://goreportcard.com/report/git.vaala.cloud/VaalaCat/go-openai)
[![codecov](https://codecov.io/gh/sashabaranov/go-openai/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/sashabaranov/go-openai)
This library provides unofficial Go clients for [OpenAI API](https://platform.openai.com/). We support:
* ChatGPT 4o, o1
* GPT-3, GPT-4
* DALL·E 2, DALL·E 3
* DALL·E 2, DALL·E 3, GPT Image 1
* Whisper
## Installation
```
go get github.com/sashabaranov/go-openai
go get git.vaala.cloud/VaalaCat/go-openai
```
Currently, go-openai requires Go version 1.18 or greater.
@@ -28,7 +28,7 @@ package main
import (
"context"
"fmt"
openai "github.com/sashabaranov/go-openai"
openai "git.vaala.cloud/VaalaCat/go-openai"
)
func main() {
@@ -80,7 +80,7 @@ import (
"errors"
"fmt"
"io"
openai "github.com/sashabaranov/go-openai"
openai "git.vaala.cloud/VaalaCat/go-openai"
)
func main() {
@@ -133,7 +133,7 @@ package main
import (
"context"
"fmt"
openai "github.com/sashabaranov/go-openai"
openai "git.vaala.cloud/VaalaCat/go-openai"
)
func main() {
@@ -166,7 +166,7 @@ import (
"context"
"fmt"
"io"
openai "github.com/sashabaranov/go-openai"
openai "git.vaala.cloud/VaalaCat/go-openai"
)
func main() {
@@ -215,7 +215,7 @@ import (
"context"
"fmt"
openai "github.com/sashabaranov/go-openai"
openai "git.vaala.cloud/VaalaCat/go-openai"
)
func main() {
@@ -247,7 +247,7 @@ import (
"fmt"
"os"
openai "github.com/sashabaranov/go-openai"
openai "git.vaala.cloud/VaalaCat/go-openai"
)
func main() {
@@ -288,7 +288,7 @@ import (
"context"
"encoding/base64"
"fmt"
openai "github.com/sashabaranov/go-openai"
openai "git.vaala.cloud/VaalaCat/go-openai"
"image/png"
"os"
)
@@ -357,6 +357,66 @@ func main() {
```
</details>
<details>
<summary>GPT Image 1 image generation</summary>
```go
package main
import (
"context"
"encoding/base64"
"fmt"
"os"
openai "github.com/sashabaranov/go-openai"
)
func main() {
c := openai.NewClient("your token")
ctx := context.Background()
req := openai.ImageRequest{
Prompt: "Parrot on a skateboard performing a trick. Large bold text \"SKATE MASTER\" banner at the bottom of the image. Cartoon style, natural light, high detail, 1:1 aspect ratio.",
Background: openai.CreateImageBackgroundOpaque,
Model: openai.CreateImageModelGptImage1,
Size: openai.CreateImageSize1024x1024,
N: 1,
Quality: openai.CreateImageQualityLow,
OutputCompression: 100,
OutputFormat: openai.CreateImageOutputFormatJPEG,
// Moderation: openai.CreateImageModerationLow,
// User: "",
}
resp, err := c.CreateImage(ctx, req)
if err != nil {
fmt.Printf("Image creation error: %v\n", err)
return
}
fmt.Println("Image Base64:", resp.Data[0].B64JSON)
// Decode the base64 data
imgBytes, err := base64.StdEncoding.DecodeString(resp.Data[0].B64JSON)
if err != nil {
fmt.Printf("Base64 decode error: %v\n", err)
return
}
// Write image to file
outputPath := "generated_image.jpg"
err = os.WriteFile(outputPath, imgBytes, 0644)
if err != nil {
fmt.Printf("Failed to write image file: %v\n", err)
return
}
fmt.Printf("The image was saved as %s\n", outputPath)
}
```
</details>
<details>
<summary>Configuring proxy</summary>
@@ -376,7 +436,7 @@ config.HTTPClient = &http.Client{
c := openai.NewClientWithConfig(config)
```
See also: https://pkg.go.dev/github.com/sashabaranov/go-openai#ClientConfig
See also: https://pkg.go.dev/git.vaala.cloud/VaalaCat/go-openai#ClientConfig
</details>
<details>
@@ -392,7 +452,7 @@ import (
"os"
"strings"
"github.com/sashabaranov/go-openai"
"git.vaala.cloud/VaalaCat/go-openai"
)
func main() {
@@ -446,7 +506,7 @@ import (
"context"
"fmt"
openai "github.com/sashabaranov/go-openai"
openai "git.vaala.cloud/VaalaCat/go-openai"
)
func main() {
@@ -492,7 +552,7 @@ package main
import (
"context"
"log"
openai "github.com/sashabaranov/go-openai"
openai "git.vaala.cloud/VaalaCat/go-openai"
)
@@ -549,7 +609,7 @@ import (
"context"
"fmt"
openai "github.com/sashabaranov/go-openai"
openai "git.vaala.cloud/VaalaCat/go-openai"
)
func main() {
@@ -680,7 +740,7 @@ package main
import (
"context"
"fmt"
"github.com/sashabaranov/go-openai"
"git.vaala.cloud/VaalaCat/go-openai"
)
func main() {
@@ -755,8 +815,8 @@ import (
"fmt"
"log"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/jsonschema"
"git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/jsonschema"
)
func main() {
@@ -828,7 +888,7 @@ Due to the factors mentioned above, different answers may be returned even for t
By adopting these strategies, you can expect more consistent results.
**Related Issues:**
[omitempty option of request struct will generate incorrect request when parameter is 0.](https://github.com/sashabaranov/go-openai/issues/9)
[omitempty option of request struct will generate incorrect request when parameter is 0.](https://git.vaala.cloud/VaalaCat/go-openai/issues/9)
### Does Go OpenAI provide a method to count tokens?
@@ -839,15 +899,15 @@ For counting tokens, you might find the following links helpful:
- [How to count tokens with tiktoken](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb)
**Related Issues:**
[Is it possible to join the implementation of GPT3 Tokenizer](https://github.com/sashabaranov/go-openai/issues/62)
[Is it possible to join the implementation of GPT3 Tokenizer](https://git.vaala.cloud/VaalaCat/go-openai/issues/62)
## Contributing
By following [Contributing Guidelines](https://github.com/sashabaranov/go-openai/blob/master/CONTRIBUTING.md), we hope to ensure that your contributions are made smoothly and efficiently.
By following [Contributing Guidelines](https://git.vaala.cloud/VaalaCat/go-openai/blob/master/CONTRIBUTING.md), we hope to ensure that your contributions are made smoothly and efficiently.
## Thank you
We want to take a moment to express our deepest gratitude to the [contributors](https://github.com/sashabaranov/go-openai/graphs/contributors) and sponsors of this project:
We want to take a moment to express our deepest gratitude to the [contributors](https://git.vaala.cloud/VaalaCat/go-openai/graphs/contributors) and sponsors of this project:
- [Carson Kahn](https://carsonkahn.com) of [Spindle AI](https://spindleai.com)
To all of you: thank you. You've helped us achieve more than we ever imagined possible. Can't wait to see where we go next, together!

@@ -10,9 +10,9 @@ import (
"os"
"testing"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"github.com/sashabaranov/go-openai/jsonschema"
"git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
"git.vaala.cloud/VaalaCat/go-openai/jsonschema"
)
func TestAPI(t *testing.T) {

@@ -3,8 +3,8 @@ package openai_test
import (
"context"
openai "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
openai "git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
"encoding/json"
"fmt"

@@ -8,7 +8,7 @@ import (
"net/http"
"os"
utils "github.com/sashabaranov/go-openai/internal"
utils "git.vaala.cloud/VaalaCat/go-openai/internal"
)
// Whisper Defines the models provided by OpenAI to use when processing audio with OpenAI.

@@ -12,9 +12,9 @@ import (
"strings"
"testing"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
"git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
)
// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
@@ -40,12 +40,9 @@ func TestAudio(t *testing.T) {
ctx := context.Background()
dir, cleanup := test.CreateTestDirectory(t)
defer cleanup()
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
path := filepath.Join(dir, "fake.mp3")
path := filepath.Join(t.TempDir(), "fake.mp3")
test.CreateTestFile(t, path)
req := openai.AudioRequest{
@@ -90,12 +87,9 @@ func TestAudioWithOptionalArgs(t *testing.T) {
ctx := context.Background()
dir, cleanup := test.CreateTestDirectory(t)
defer cleanup()
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
path := filepath.Join(dir, "fake.mp3")
path := filepath.Join(t.TempDir(), "fake.mp3")
test.CreateTestFile(t, path)
req := openai.AudioRequest{

@@ -2,20 +2,21 @@ package openai //nolint:testpackage // testing private field
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"testing"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
)
func TestAudioWithFailingFormBuilder(t *testing.T) {
dir, cleanup := test.CreateTestDirectory(t)
defer cleanup()
path := filepath.Join(dir, "fake.mp3")
path := filepath.Join(t.TempDir(), "fake.mp3")
test.CreateTestFile(t, path)
req := AudioRequest{
@@ -63,9 +64,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) {
func TestCreateFileField(t *testing.T) {
t.Run("createFileField failing file", func(t *testing.T) {
dir, cleanup := test.CreateTestDirectory(t)
defer cleanup()
path := filepath.Join(dir, "fake.mp3")
path := filepath.Join(t.TempDir(), "fake.mp3")
test.CreateTestFile(t, path)
req := AudioRequest{
@@ -111,3 +110,131 @@ func TestCreateFileField(t *testing.T) {
checks.HasError(t, err, "createFileField using file should return error when open file fails")
})
}
// failingFormBuilder always returns an error when creating form files.
type failingFormBuilder struct{ err error }
func (f *failingFormBuilder) CreateFormFile(_ string, _ *os.File) error {
return f.err
}
func (f *failingFormBuilder) CreateFormFileReader(_ string, _ io.Reader, _ string) error {
return f.err
}
func (f *failingFormBuilder) WriteField(_, _ string) error {
return nil
}
func (f *failingFormBuilder) Close() error {
return nil
}
func (f *failingFormBuilder) FormDataContentType() string {
return "multipart/form-data"
}
// failingAudioRequestBuilder simulates an error during HTTP request construction.
type failingAudioRequestBuilder struct{ err error }
func (f *failingAudioRequestBuilder) Build(
_ context.Context,
_, _ string,
_ any,
_ http.Header,
) (*http.Request, error) {
return nil, f.err
}
// errorHTTPClient always returns an error when making HTTP calls.
type errorHTTPClient struct{ err error }
func (e *errorHTTPClient) Do(_ *http.Request) (*http.Response, error) {
return nil, e.err
}
func TestCallAudioAPIMultipartFormError(t *testing.T) {
client := NewClient("test-token")
errForm := errors.New("mock create form file failure")
// Override form builder to force an error during multipart form creation.
client.createFormBuilder = func(_ io.Writer) utils.FormBuilder {
return &failingFormBuilder{err: errForm}
}
// Provide a reader so createFileField uses the reader path (no file open).
req := AudioRequest{FilePath: "fake.mp3", Reader: bytes.NewBuffer([]byte("dummy")), Model: Whisper1}
_, err := client.callAudioAPI(context.Background(), req, "transcriptions")
if err == nil {
t.Fatal("expected error but got none")
}
if !errors.Is(err, errForm) {
t.Errorf("expected error %v, got %v", errForm, err)
}
}
func TestCallAudioAPINewRequestError(t *testing.T) {
client := NewClient("test-token")
// Create a real temp file so multipart form succeeds.
tmp := t.TempDir()
path := filepath.Join(tmp, "file.mp3")
if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
t.Fatalf("failed to write temp file: %v", err)
}
errBuild := errors.New("mock build failure")
client.requestBuilder = &failingAudioRequestBuilder{err: errBuild}
req := AudioRequest{FilePath: path, Model: Whisper1}
_, err := client.callAudioAPI(context.Background(), req, "translations")
if err == nil {
t.Fatal("expected error but got none")
}
if !errors.Is(err, errBuild) {
t.Errorf("expected error %v, got %v", errBuild, err)
}
}
func TestCallAudioAPISendRequestErrorJSON(t *testing.T) {
client := NewClient("test-token")
// Create a real temp file so multipart form succeeds.
tmp := t.TempDir()
path := filepath.Join(tmp, "file.mp3")
if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
t.Fatalf("failed to write temp file: %v", err)
}
errHTTP := errors.New("mock HTTPClient failure")
// Override HTTP client to simulate a network error.
client.config.HTTPClient = &errorHTTPClient{err: errHTTP}
req := AudioRequest{FilePath: path, Model: Whisper1}
_, err := client.callAudioAPI(context.Background(), req, "transcriptions")
if err == nil {
t.Fatal("expected error but got none")
}
if !errors.Is(err, errHTTP) {
t.Errorf("expected error %v, got %v", errHTTP, err)
}
}
func TestCallAudioAPISendRequestErrorText(t *testing.T) {
client := NewClient("test-token")
tmp := t.TempDir()
path := filepath.Join(tmp, "file.mp3")
if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
t.Fatalf("failed to write temp file: %v", err)
}
errHTTP := errors.New("mock HTTPClient failure")
client.config.HTTPClient = &errorHTTPClient{err: errHTTP}
// Use a non-JSON response format to exercise the text path.
req := AudioRequest{FilePath: path, Model: Whisper1, Format: AudioResponseFormatText}
_, err := client.callAudioAPI(context.Background(), req, "translations")
if err == nil {
t.Fatal("expected error but got none")
}
if !errors.Is(err, errHTTP) {
t.Errorf("expected error %v, got %v", errHTTP, err)
}
}

@@ -7,8 +7,8 @@ import (
"reflect"
"testing"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
)
func TestUploadBatchFile(t *testing.T) {

chat.go
@@ -14,6 +14,7 @@ const (
ChatMessageRoleAssistant = "assistant"
ChatMessageRoleFunction = "function"
ChatMessageRoleTool = "tool"
ChatMessageRoleDeveloper = "developer"
)
const chatCompletionsSuffix = "/chat/completions"
@@ -93,7 +94,7 @@ type ChatMessagePart struct {
type ChatCompletionMessage struct {
Role string `json:"role"`
Content string `json:"content"`
Content string `json:"content,omitempty"`
Refusal string `json:"refusal,omitempty"`
MultiContent []ChatMessagePart
@@ -103,6 +104,12 @@ type ChatCompletionMessage struct {
// - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
Name string `json:"name,omitempty"`
// This property is used for the "reasoning" feature supported by deepseek-reasoner,
// which is not in the official documentation.
// See the DeepSeek docs:
// - https://api-docs.deepseek.com/api/create-chat-completion#responses
ReasoningContent string `json:"reasoning_content,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"`
// For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls.
@@ -123,6 +130,7 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
Refusal string `json:"refusal,omitempty"`
MultiContent []ChatMessagePart `json:"content,omitempty"`
Name string `json:"name,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
@@ -132,10 +140,11 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
msg := struct {
Role string `json:"role"`
Content string `json:"content"`
Content string `json:"content,omitempty"`
Refusal string `json:"refusal,omitempty"`
MultiContent []ChatMessagePart `json:"-"`
Name string `json:"name,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
@@ -150,6 +159,7 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
Refusal string `json:"refusal,omitempty"`
MultiContent []ChatMessagePart
Name string `json:"name,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
@@ -165,6 +175,7 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
Refusal string `json:"refusal,omitempty"`
MultiContent []ChatMessagePart `json:"content"`
Name string `json:"name,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
@@ -179,7 +190,7 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
type ToolCall struct {
// Index is not nil only in chat completion chunk object
Index *int `json:"index,omitempty"`
ID string `json:"id"`
ID string `json:"id,omitempty"`
Type ToolType `json:"type"`
Function FunctionCall `json:"function"`
}
@@ -258,8 +269,19 @@ type ChatCompletionRequest struct {
// Store can be set to true to store the output of this completion request for use in distillations and evals.
// https://platform.openai.com/docs/api-reference/chat/create#chat-create-store
Store bool `json:"store,omitempty"`
// Controls effort on reasoning for reasoning models. It can be set to "low", "medium", or "high".
ReasoningEffort string `json:"reasoning_effort,omitempty"`
// Metadata to store with the completion.
Metadata map[string]string `json:"metadata,omitempty"`
IncludeReasoning *bool `json:"include_reasoning,omitempty"`
ReasoningFormat *string `json:"reasoning_format,omitempty"`
// Configuration for a predicted output.
Prediction *Prediction `json:"prediction,omitempty"`
// ChatTemplateKwargs provides a way to add non-standard parameters to the request body:
// additional kwargs passed to the template renderer, accessible from the chat template.
// For example, Qwen3's thinking mode: "chat_template_kwargs": {"enable_thinking": false}
// https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes
ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"`
}
type StreamOptions struct {
@@ -327,6 +349,11 @@ type LogProbs struct {
Content []LogProb `json:"content"`
}
type Prediction struct {
Content string `json:"content"`
Type string `json:"type"`
}
type FinishReason string
const (
@@ -390,7 +417,8 @@ func (c *Client) CreateChatCompletion(
return
}
if err = validateRequestForO1Models(request); err != nil {
reasoningValidator := NewReasoningValidator()
if err = reasoningValidator.Validate(request); err != nil {
return
}

@@ -11,6 +11,12 @@ type ChatCompletionStreamChoiceDelta struct {
FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Refusal string `json:"refusal,omitempty"`
// This property is used for the "reasoning" feature supported by deepseek-reasoner,
// which is not in the official documentation.
// See the DeepSeek docs:
// - https://api-docs.deepseek.com/api/create-chat-completion#responses
ReasoningContent string `json:"reasoning_content,omitempty"`
}
type ChatCompletionStreamChoiceLogprobs struct {
@@ -80,7 +86,8 @@ func (c *Client) CreateChatCompletionStream(
}
request.Stream = true
if err = validateRequestForO1Models(request); err != nil {
reasoningValidator := NewReasoningValidator()
if err = reasoningValidator.Validate(request); err != nil {
return
}

@@ -10,8 +10,8 @@ import (
"strconv"
"testing"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
)
func TestChatCompletionsStreamWrongModel(t *testing.T) {
@@ -792,6 +792,223 @@ func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool {
return true
}
func TestCreateChatCompletionStreamWithReasoningModel(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
dataBytes := []byte{}
//nolint:lll
dataBytes = append(dataBytes, []byte(`data: {"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}`)...)
dataBytes = append(dataBytes, []byte("\n\n")...)
//nolint:lll
dataBytes = append(dataBytes, []byte(`data: {"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}`)...)
dataBytes = append(dataBytes, []byte("\n\n")...)
//nolint:lll
dataBytes = append(dataBytes, []byte(`data: {"id":"3","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":" from"},"finish_reason":null}]}`)...)
dataBytes = append(dataBytes, []byte("\n\n")...)
//nolint:lll
dataBytes = append(dataBytes, []byte(`data: {"id":"4","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":" O3Mini"},"finish_reason":null}]}`)...)
dataBytes = append(dataBytes, []byte("\n\n")...)
//nolint:lll
dataBytes = append(dataBytes, []byte(`data: {"id":"5","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`)...)
dataBytes = append(dataBytes, []byte("\n\n")...)
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
_, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error")
})
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
MaxCompletionTokens: 2000,
Model: openai.O3Mini20250131,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
Stream: true,
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()
expectedResponses := []openai.ChatCompletionStreamResponse{
{
ID: "1",
Object: "chat.completion.chunk",
Created: 1729585728,
Model: openai.O3Mini20250131,
SystemFingerprint: "fp_mini",
Choices: []openai.ChatCompletionStreamChoice{
{
Index: 0,
Delta: openai.ChatCompletionStreamChoiceDelta{
Role: "assistant",
},
},
},
},
{
ID: "2",
Object: "chat.completion.chunk",
Created: 1729585728,
Model: openai.O3Mini20250131,
SystemFingerprint: "fp_mini",
Choices: []openai.ChatCompletionStreamChoice{
{
Index: 0,
Delta: openai.ChatCompletionStreamChoiceDelta{
Content: "Hello",
},
},
},
},
{
ID: "3",
Object: "chat.completion.chunk",
Created: 1729585728,
Model: openai.O3Mini20250131,
SystemFingerprint: "fp_mini",
Choices: []openai.ChatCompletionStreamChoice{
{
Index: 0,
Delta: openai.ChatCompletionStreamChoiceDelta{
Content: " from",
},
},
},
},
{
ID: "4",
Object: "chat.completion.chunk",
Created: 1729585728,
Model: openai.O3Mini20250131,
SystemFingerprint: "fp_mini",
Choices: []openai.ChatCompletionStreamChoice{
{
Index: 0,
Delta: openai.ChatCompletionStreamChoiceDelta{
Content: " O3Mini",
},
},
},
},
{
ID: "5",
Object: "chat.completion.chunk",
Created: 1729585728,
Model: openai.O3Mini20250131,
SystemFingerprint: "fp_mini",
Choices: []openai.ChatCompletionStreamChoice{
{
Index: 0,
Delta: openai.ChatCompletionStreamChoiceDelta{},
FinishReason: "stop",
},
},
},
}
for ix, expectedResponse := range expectedResponses {
b, _ := json.Marshal(expectedResponse)
t.Logf("%d: %s", ix, string(b))
receivedResponse, streamErr := stream.Recv()
checks.NoError(t, streamErr, "stream.Recv() failed")
if !compareChatResponses(expectedResponse, receivedResponse) {
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
}
}
_, streamErr := stream.Recv()
if !errors.Is(streamErr, io.EOF) {
t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
}
}
func TestCreateChatCompletionStreamReasoningValidatorFails(t *testing.T) {
client, _, _ := setupOpenAITestServer()
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 100, // This will trigger the validator to fail
Model: openai.O3Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
Stream: true,
})
if stream != nil {
t.Error("Expected nil stream when validation fails")
stream.Close()
}
if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated, got: %v", err)
}
}
func TestCreateChatCompletionStreamO3ReasoningValidatorFails(t *testing.T) {
client, _, _ := setupOpenAITestServer()
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 100, // This will trigger the validator to fail
Model: openai.O3,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
Stream: true,
})
if stream != nil {
t.Error("Expected nil stream when validation fails")
stream.Close()
}
if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O3, got: %v", err)
}
}
func TestCreateChatCompletionStreamO4MiniReasoningValidatorFails(t *testing.T) {
client, _, _ := setupOpenAITestServer()
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
MaxTokens: 100, // This will trigger the validator to fail
Model: openai.O4Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
Stream: true,
})
if stream != nil {
t.Error("Expected nil stream when validation fails")
stream.Close()
}
if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O4Mini, got: %v", err)
}
}
func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool {
if c1.Index != c2.Index {
return false

@@ -12,9 +12,9 @@ import (
"testing"
"time"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"github.com/sashabaranov/go-openai/jsonschema"
"git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
"git.vaala.cloud/VaalaCat/go-openai/jsonschema"
)
const (
@@ -64,7 +64,7 @@ func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) {
MaxTokens: 5,
Model: openai.O1Preview,
},
expectedError: openai.ErrO1MaxTokensDeprecated,
expectedError: openai.ErrReasoningModelMaxTokensDeprecated,
},
{
name: "o1-mini_MaxTokens_deprecated",
@@ -72,7 +72,7 @@ func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) {
MaxTokens: 5,
Model: openai.O1Mini,
},
expectedError: openai.ErrO1MaxTokensDeprecated,
expectedError: openai.ErrReasoningModelMaxTokensDeprecated,
},
}
@@ -104,41 +104,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
LogProbs: true,
Model: openai.O1Preview,
},
expectedError: openai.ErrO1BetaLimitationsLogprobs,
},
{
name: "message_type_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O1Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
},
},
},
expectedError: openai.ErrO1BetaLimitationsMessageTypes,
},
{
name: "tool_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O1Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
Tools: []openai.Tool{
{
Type: openai.ToolTypeFunction,
},
},
},
expectedError: openai.ErrO1BetaLimitationsTools,
expectedError: openai.ErrReasoningModelLimitationsLogprobs,
},
{
name: "set_temperature_unsupported",
@@ -155,7 +121,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
},
Temperature: float32(2),
},
expectedError: openai.ErrO1BetaLimitationsOther,
expectedError: openai.ErrReasoningModelLimitationsOther,
},
{
name: "set_top_unsupported",
@@ -173,7 +139,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
Temperature: float32(1),
TopP: float32(0.1),
},
expectedError: openai.ErrO1BetaLimitationsOther,
expectedError: openai.ErrReasoningModelLimitationsOther,
},
{
name: "set_n_unsupported",
@@ -192,7 +158,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
TopP: float32(1),
N: 2,
},
expectedError: openai.ErrO1BetaLimitationsOther,
expectedError: openai.ErrReasoningModelLimitationsOther,
},
{
name: "set_presence_penalty_unsupported",
@@ -209,7 +175,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
},
PresencePenalty: float32(1),
},
expectedError: openai.ErrO1BetaLimitationsOther,
expectedError: openai.ErrReasoningModelLimitationsOther,
},
{
name: "set_frequency_penalty_unsupported",
@@ -226,7 +192,127 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
},
FrequencyPenalty: float32(0.1),
},
expectedError: openai.ErrO1BetaLimitationsOther,
expectedError: openai.ErrReasoningModelLimitationsOther,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
ctx := context.Background()
_, err := client.CreateChatCompletion(ctx, tt.in)
checks.HasError(t, err)
msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err)
checks.ErrorIs(t, err, tt.expectedError, msg)
})
}
}
func TestO3ModelsChatCompletionsBetaLimitations(t *testing.T) {
tests := []struct {
name string
in openai.ChatCompletionRequest
expectedError error
}{
{
name: "log_probs_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
LogProbs: true,
Model: openai.O3Mini,
},
expectedError: openai.ErrReasoningModelLimitationsLogprobs,
},
{
name: "set_temperature_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O3Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
Temperature: float32(2),
},
expectedError: openai.ErrReasoningModelLimitationsOther,
},
{
name: "set_top_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O3Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
Temperature: float32(1),
TopP: float32(0.1),
},
expectedError: openai.ErrReasoningModelLimitationsOther,
},
{
name: "set_n_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O3Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
Temperature: float32(1),
TopP: float32(1),
N: 2,
},
expectedError: openai.ErrReasoningModelLimitationsOther,
},
{
name: "set_presence_penalty_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O3Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
PresencePenalty: float32(1),
},
expectedError: openai.ErrReasoningModelLimitationsOther,
},
{
name: "set_frequency_penalty_unsupported",
in: openai.ChatCompletionRequest{
MaxCompletionTokens: 1000,
Model: openai.O3Mini,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
},
{
Role: openai.ChatMessageRoleAssistant,
},
},
FrequencyPenalty: float32(0.1),
},
expectedError: openai.ErrReasoningModelLimitationsOther,
},
}
@@ -308,6 +394,40 @@ func TestO1ModelChatCompletions(t *testing.T) {
checks.NoError(t, err, "CreateChatCompletion error")
}
func TestO3ModelChatCompletions(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
Model: openai.O3Mini,
MaxCompletionTokens: 1000,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
})
checks.NoError(t, err, "CreateChatCompletion error")
}
func TestDeepseekR1ModelChatCompletions(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", handleDeepseekR1ChatCompletionEndpoint)
_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
Model: "deepseek-reasoner",
MaxCompletionTokens: 100,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
},
})
checks.NoError(t, err, "CreateChatCompletion error")
}
// TestCompletions Tests the completions endpoint of the API using the mocked server.
func TestChatCompletionsWithHeaders(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
@@ -631,7 +751,7 @@ func TestMultipartChatMessageSerialization(t *testing.T) {
t.Fatalf("Unexpected error")
}
res = strings.ReplaceAll(string(s), " ", "")
if res != `{"role":"user","content":""}` {
if res != `{"role":"user"}` {
t.Fatalf("invalid message: %s", string(s))
}
}
@@ -719,6 +839,68 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, string(resBytes))
}
func handleDeepseekR1ChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte
// completions only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var completionReq openai.ChatCompletionRequest
if completionReq, err = getChatCompletionBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}
res := openai.ChatCompletionResponse{
ID: strconv.Itoa(int(time.Now().Unix())),
Object: "test-object",
Created: time.Now().Unix(),
// It would be nice to validate Model during testing, but the
// upkeep required probably wouldn't be worth it.
Model: completionReq.Model,
}
// create completions
n := completionReq.N
if n == 0 {
n = 1
}
if completionReq.MaxCompletionTokens == 0 {
completionReq.MaxCompletionTokens = 1000
}
for i := 0; i < n; i++ {
reasoningContent := "User says hello! And I need to reply"
completionStr := strings.Repeat("a", completionReq.MaxCompletionTokens-numTokens(reasoningContent))
res.Choices = append(res.Choices, openai.ChatCompletionChoice{
Message: openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
ReasoningContent: reasoningContent,
Content: completionStr,
},
Index: i,
})
}
inputTokens := numTokens(completionReq.Messages[0].Content) * n
completionTokens := completionReq.MaxCompletionTokens * n
res.Usage = openai.Usage{
PromptTokens: inputTokens,
CompletionTokens: completionTokens,
TotalTokens: inputTokens + completionTokens,
}
resBytes, _ = json.Marshal(res)
w.Header().Set(xCustomHeader, xCustomHeaderValue)
for k, v := range rateLimitHeaders {
switch val := v.(type) {
case int:
w.Header().Set(k, strconv.Itoa(val))
default:
w.Header().Set(k, fmt.Sprintf("%s", v))
}
}
fmt.Fprintln(w, string(resBytes))
}
// getChatCompletionBody Returns the body of the request to create a completion.
func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) {
completion := openai.ChatCompletionRequest{}

@@ -10,7 +10,7 @@ import (
"net/url"
"strings"
utils "github.com/sashabaranov/go-openai/internal"
utils "git.vaala.cloud/VaalaCat/go-openai/internal"
)
// Client is OpenAI GPT-3 API client.
@@ -182,13 +182,21 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream
func (c *Client) setCommonHeaders(req *http.Request) {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
switch c.config.APIType {
case APITypeAzure, APITypeCloudflareAzure:
// Azure API Key authentication
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure {
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
} else if c.config.authToken != "" {
// OpenAI or Azure AD authentication
case APITypeAnthropic:
// https://docs.anthropic.com/en/api/versioning
req.Header.Set("anthropic-version", c.config.APIVersion)
case APITypeOpenAI, APITypeAzureAD:
fallthrough
default:
if c.config.authToken != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
}
}
if c.config.OrgID != "" {
req.Header.Set("OpenAI-Organization", c.config.OrgID)
}

View File

@@ -10,8 +10,8 @@ import (
"reflect"
"testing"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
)
var errTestRequestBuilderFailed = errors.New("test request builder failed")
@@ -39,6 +39,21 @@ func TestClient(t *testing.T) {
}
}
func TestSetCommonHeadersAnthropic(t *testing.T) {
config := DefaultAnthropicConfig("mock-token", "")
client := NewClientWithConfig(config)
req, err := http.NewRequest("GET", "http://example.com", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
client.setCommonHeaders(req)
if got := req.Header.Get("anthropic-version"); got != AnthropicAPIVersion {
t.Errorf("Expected anthropic-version header to be %q, got %q", AnthropicAPIVersion, got)
}
}
func TestDecodeResponse(t *testing.T) {
stringInput := ""

View File

@@ -15,6 +15,8 @@ type Usage struct {
type CompletionTokensDetails struct {
AudioTokens int `json:"audio_tokens"`
ReasoningTokens int `json:"reasoning_tokens"`
AcceptedPredictionTokens int `json:"accepted_prediction_tokens"`
RejectedPredictionTokens int `json:"rejected_prediction_tokens"`
}
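The two prediction-token fields added above decode straight from the API's snake_case usage payload. A small sketch of the round trip, using a local copy of the struct rather than the library's type:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// CompletionTokensDetails mirrors the struct in the diff above;
// this is a local sketch, not the library's actual type.
type CompletionTokensDetails struct {
	AudioTokens              int `json:"audio_tokens"`
	ReasoningTokens          int `json:"reasoning_tokens"`
	AcceptedPredictionTokens int `json:"accepted_prediction_tokens"`
	RejectedPredictionTokens int `json:"rejected_prediction_tokens"`
}

// parseDetails decodes a usage-details JSON object; absent fields
// simply stay at their zero value.
func parseDetails(data []byte) (CompletionTokensDetails, error) {
	var d CompletionTokensDetails
	err := json.Unmarshal(data, &d)
	return d, err
}

func main() {
	d, err := parseDetails([]byte(`{"reasoning_tokens": 42, "accepted_prediction_tokens": 7}`))
	if err != nil {
		panic(err)
	}
	fmt.Println(d.ReasoningTokens, d.AcceptedPredictionTokens)
}
```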
// PromptTokensDetails Breakdown of tokens used in the prompt.

View File

@@ -2,24 +2,9 @@ package openai
import (
"context"
"errors"
"net/http"
)
var (
ErrO1MaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") //nolint:lll
ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll
ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll
)
var (
ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll
ErrO1BetaLimitationsTools = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll
ErrO1BetaLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll
ErrO1BetaLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
)
// GPT3 Defines the models provided by OpenAI to use when generating
// completions from OpenAI.
// GPT3 Models are designed for text-based tasks. For code-specific
@@ -29,6 +14,14 @@ const (
O1Mini20240912 = "o1-mini-2024-09-12"
O1Preview = "o1-preview"
O1Preview20240912 = "o1-preview-2024-09-12"
O1 = "o1"
O120241217 = "o1-2024-12-17"
O3 = "o3"
O320250416 = "o3-2025-04-16"
O3Mini = "o3-mini"
O3Mini20250131 = "o3-mini-2025-01-31"
O4Mini = "o4-mini"
O4Mini20250416 = "o4-mini-2025-04-16"
GPT432K0613 = "gpt-4-32k-0613"
GPT432K0314 = "gpt-4-32k-0314"
GPT432K = "gpt-4-32k"
@@ -37,6 +30,7 @@ const (
GPT4o = "gpt-4o"
GPT4o20240513 = "gpt-4o-2024-05-13"
GPT4o20240806 = "gpt-4o-2024-08-06"
GPT4o20241120 = "gpt-4o-2024-11-20"
GPT4oLatest = "chatgpt-4o-latest"
GPT4oMini = "gpt-4o-mini"
GPT4oMini20240718 = "gpt-4o-mini-2024-07-18"
@@ -47,6 +41,14 @@ const (
GPT4TurboPreview = "gpt-4-turbo-preview"
GPT4VisionPreview = "gpt-4-vision-preview"
GPT4 = "gpt-4"
GPT4Dot1 = "gpt-4.1"
GPT4Dot120250414 = "gpt-4.1-2025-04-14"
GPT4Dot1Mini = "gpt-4.1-mini"
GPT4Dot1Mini20250414 = "gpt-4.1-mini-2025-04-14"
GPT4Dot1Nano = "gpt-4.1-nano"
GPT4Dot1Nano20250414 = "gpt-4.1-nano-2025-04-14"
GPT4Dot5Preview = "gpt-4.5-preview"
GPT4Dot5Preview20250227 = "gpt-4.5-preview-2025-02-27"
GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125"
GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106"
GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613"
@@ -93,21 +95,18 @@ const (
CodexCodeDavinci001 = "code-davinci-001"
)
// O1SeriesModels List of new Series of OpenAI models.
// Some old api attributes not supported.
var O1SeriesModels = map[string]struct{}{
O1Mini: {},
O1Mini20240912: {},
O1Preview: {},
O1Preview20240912: {},
}
var disabledModelsForEndpoints = map[string]map[string]bool{
"/completions": {
O1Mini: true,
O1Mini20240912: true,
O1Preview: true,
O1Preview20240912: true,
O3Mini: true,
O3Mini20250131: true,
O4Mini: true,
O4Mini20250416: true,
O3: true,
O320250416: true,
GPT3Dot5Turbo: true,
GPT3Dot5Turbo0301: true,
GPT3Dot5Turbo0613: true,
@@ -116,9 +115,12 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
GPT3Dot5Turbo16K: true,
GPT3Dot5Turbo16K0613: true,
GPT4: true,
GPT4Dot5Preview: true,
GPT4Dot5Preview20250227: true,
GPT4o: true,
GPT4o20240513: true,
GPT4o20240806: true,
GPT4o20241120: true,
GPT4oLatest: true,
GPT4oMini: true,
GPT4oMini20240718: true,
@@ -133,6 +135,13 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
GPT432K: true,
GPT432K0314: true,
GPT432K0613: true,
O1: true,
GPT4Dot1: true,
GPT4Dot120250414: true,
GPT4Dot1Mini: true,
GPT4Dot1Mini20250414: true,
GPT4Dot1Nano: true,
GPT4Dot1Nano20250414: true,
},
chatCompletionsSuffix: {
CodexCodeDavinci002: true,
@@ -179,64 +188,6 @@ func checkPromptType(prompt any) bool {
return true // all items in the slice are string, so it is []string
}
var unsupportedToolsForO1Models = map[ToolType]struct{}{
ToolTypeFunction: {},
}
var availableMessageRoleForO1Models = map[string]struct{}{
ChatMessageRoleUser: {},
ChatMessageRoleAssistant: {},
}
// validateRequestForO1Models checks for deprecated fields of OpenAI models.
func validateRequestForO1Models(request ChatCompletionRequest) error {
if _, found := O1SeriesModels[request.Model]; !found {
return nil
}
if request.MaxTokens > 0 {
return ErrO1MaxTokensDeprecated
}
// Logprobs: not supported.
if request.LogProbs {
return ErrO1BetaLimitationsLogprobs
}
// Message types: user and assistant messages only, system messages are not supported.
for _, m := range request.Messages {
if _, found := availableMessageRoleForO1Models[m.Role]; !found {
return ErrO1BetaLimitationsMessageTypes
}
}
// Tools: tools, function calling, and response format parameters are not supported
for _, t := range request.Tools {
if _, found := unsupportedToolsForO1Models[t.Type]; found {
return ErrO1BetaLimitationsTools
}
}
// Other: temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0.
if request.Temperature > 0 && request.Temperature != 1 {
return ErrO1BetaLimitationsOther
}
if request.TopP > 0 && request.TopP != 1 {
return ErrO1BetaLimitationsOther
}
if request.N > 0 && request.N != 1 {
return ErrO1BetaLimitationsOther
}
if request.PresencePenalty > 0 {
return ErrO1BetaLimitationsOther
}
if request.FrequencyPenalty > 0 {
return ErrO1BetaLimitationsOther
}
return nil
}
// CompletionRequest represents a request structure for completion API.
type CompletionRequest struct {
Model string `json:"model"`
@@ -264,6 +215,8 @@ type CompletionRequest struct {
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
User string `json:"user,omitempty"`
// Options for streaming response. Only set this when you set stream: true.
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
}
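The `StreamOptions` field added above is a pointer precisely so that `omitempty` drops the whole `stream_options` object when it is unset. A minimal sketch of the resulting JSON, with stand-in types rather than the library's own:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// StreamOptions and CompletionRequest are trimmed stand-ins for the
// library's types; only the fields relevant to stream_options are shown.
type StreamOptions struct {
	IncludeUsage bool `json:"include_usage,omitempty"`
}

type CompletionRequest struct {
	Model         string         `json:"model"`
	Stream        bool           `json:"stream,omitempty"`
	StreamOptions *StreamOptions `json:"stream_options,omitempty"`
}

func encode(req CompletionRequest) string {
	b, _ := json.Marshal(req)
	return string(b)
}

func main() {
	// Nil pointer: stream_options is omitted entirely.
	fmt.Println(encode(CompletionRequest{Model: "m"}))
	// Set pointer: the nested object is serialized.
	fmt.Println(encode(CompletionRequest{
		Model:         "m",
		Stream:        true,
		StreamOptions: &StreamOptions{IncludeUsage: true},
	}))
}
```

This matches the comment in the diff: the options only make sense alongside `"stream": true`, and leaving the pointer nil keeps non-streaming request bodies unchanged.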
// CompletionChoice represents one of possible completions.

View File

@@ -12,8 +12,8 @@ import (
"testing"
"time"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
)
func TestCompletionsWrongModel(t *testing.T) {
@@ -33,6 +33,42 @@ func TestCompletionsWrongModel(t *testing.T) {
}
}
// TestCompletionsWrongModelO3 Tests the completions endpoint with O3 model which is not supported.
func TestCompletionsWrongModelO3(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
_, err := client.CreateCompletion(
context.Background(),
openai.CompletionRequest{
MaxTokens: 5,
Model: openai.O3,
},
)
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O3, but returned: %v", err)
}
}
// TestCompletionsWrongModelO4Mini Tests the completions endpoint with O4Mini model which is not supported.
func TestCompletionsWrongModelO4Mini(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
_, err := client.CreateCompletion(
context.Background(),
openai.CompletionRequest{
MaxTokens: 5,
Model: openai.O4Mini,
},
)
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O4Mini, but returned: %v", err)
}
}
func TestCompletionWithStream(t *testing.T) {
config := openai.DefaultConfig("whatever")
client := openai.NewClientWithConfig(config)
@@ -181,3 +217,86 @@ func getCompletionBody(r *http.Request) (openai.CompletionRequest, error) {
}
return completion, nil
}
// TestCompletionWithO1Model Tests that O1 model is not supported for completion endpoint.
func TestCompletionWithO1Model(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
_, err := client.CreateCompletion(
context.Background(),
openai.CompletionRequest{
MaxTokens: 5,
Model: openai.O1,
},
)
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O1 model, but returned: %v", err)
}
}
// TestCompletionWithGPT4DotModels Tests that newer GPT4 models are not supported for completion endpoint.
func TestCompletionWithGPT4DotModels(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
models := []string{
openai.GPT4Dot1,
openai.GPT4Dot120250414,
openai.GPT4Dot1Mini,
openai.GPT4Dot1Mini20250414,
openai.GPT4Dot1Nano,
openai.GPT4Dot1Nano20250414,
openai.GPT4Dot5Preview,
openai.GPT4Dot5Preview20250227,
}
for _, model := range models {
t.Run(model, func(t *testing.T) {
_, err := client.CreateCompletion(
context.Background(),
openai.CompletionRequest{
MaxTokens: 5,
Model: model,
},
)
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err)
}
})
}
}
// TestCompletionWithGPT4oModels Tests that GPT4o models are not supported for completion endpoint.
func TestCompletionWithGPT4oModels(t *testing.T) {
config := openai.DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
client := openai.NewClientWithConfig(config)
models := []string{
openai.GPT4o,
openai.GPT4o20240513,
openai.GPT4o20240806,
openai.GPT4o20241120,
openai.GPT4oLatest,
openai.GPT4oMini,
openai.GPT4oMini20240718,
}
for _, model := range models {
t.Run(model, func(t *testing.T) {
_, err := client.CreateCompletion(
context.Background(),
openai.CompletionRequest{
MaxTokens: 5,
Model: model,
},
)
if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err)
}
})
}
}

View File

@@ -11,6 +11,8 @@ const (
azureAPIPrefix = "openai"
azureDeploymentsPrefix = "deployments"
AnthropicAPIVersion = "2023-06-01"
)
type APIType string
@@ -20,6 +22,7 @@ const (
APITypeAzure APIType = "AZURE"
APITypeAzureAD APIType = "AZURE_AD"
APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE"
APITypeAnthropic APIType = "ANTHROPIC"
)
const AzureAPIKeyHeader = "api-key"
@@ -37,7 +40,7 @@ type ClientConfig struct {
BaseURL string
OrgID string
APIType APIType
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD or APITypeAnthropic
AssistantVersion string
AzureModelMapperFunc func(model string) string // replace model to azure deployment name func
HTTPClient HTTPDoer
@@ -76,6 +79,23 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig {
}
}
func DefaultAnthropicConfig(apiKey, baseURL string) ClientConfig {
if baseURL == "" {
baseURL = "https://api.anthropic.com/v1"
}
return ClientConfig{
authToken: apiKey,
BaseURL: baseURL,
OrgID: "",
APIType: APITypeAnthropic,
APIVersion: AnthropicAPIVersion,
HTTPClient: &http.Client{},
EmptyMessagesLimit: defaultEmptyMessagesLimit,
}
}
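The defaulting logic in `DefaultAnthropicConfig` above is simple enough to isolate: an empty `baseURL` falls back to the public Anthropic endpoint, and the API version is pinned to the constant shipped with the client. A standalone sketch (local types, not the library's `ClientConfig`):

```go
package main

import "fmt"

// anthropicConfig is an illustrative subset of ClientConfig.
type anthropicConfig struct {
	BaseURL    string
	APIVersion string
}

// defaultAnthropicConfig mirrors the fallback in the function above.
func defaultAnthropicConfig(baseURL string) anthropicConfig {
	if baseURL == "" {
		baseURL = "https://api.anthropic.com/v1"
	}
	return anthropicConfig{BaseURL: baseURL, APIVersion: "2023-06-01"}
}

func main() {
	fmt.Println(defaultAnthropicConfig("").BaseURL)
	fmt.Println(defaultAnthropicConfig("http://localhost:8080/v1").BaseURL)
}
```

The pinned version string is what `setCommonHeaders` later sends as the `anthropic-version` header, which is why the tests assert it verbatim.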
func (ClientConfig) String() string {
return "<OpenAI API ClientConfig>"
}

View File

@@ -3,7 +3,7 @@ package openai_test
import (
"testing"
"github.com/sashabaranov/go-openai"
"git.vaala.cloud/VaalaCat/go-openai"
)
func TestGetAzureDeploymentByModel(t *testing.T) {
@@ -60,3 +60,64 @@ func TestGetAzureDeploymentByModel(t *testing.T) {
})
}
}
func TestDefaultAnthropicConfig(t *testing.T) {
apiKey := "test-key"
baseURL := "https://api.anthropic.com/v1"
config := openai.DefaultAnthropicConfig(apiKey, baseURL)
if config.APIType != openai.APITypeAnthropic {
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
}
if config.APIVersion != openai.AnthropicAPIVersion {
t.Errorf("Expected APIVersion to be 2023-06-01, got %v", config.APIVersion)
}
if config.BaseURL != baseURL {
t.Errorf("Expected BaseURL to be %v, got %v", baseURL, config.BaseURL)
}
if config.EmptyMessagesLimit != 300 {
t.Errorf("Expected EmptyMessagesLimit to be 300, got %v", config.EmptyMessagesLimit)
}
}
func TestDefaultAnthropicConfigWithEmptyValues(t *testing.T) {
config := openai.DefaultAnthropicConfig("", "")
if config.APIType != openai.APITypeAnthropic {
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
}
if config.APIVersion != openai.AnthropicAPIVersion {
t.Errorf("Expected APIVersion to be %s, got %v", openai.AnthropicAPIVersion, config.APIVersion)
}
expectedBaseURL := "https://api.anthropic.com/v1"
if config.BaseURL != expectedBaseURL {
t.Errorf("Expected BaseURL to be %v, got %v", expectedBaseURL, config.BaseURL)
}
}
func TestClientConfigString(t *testing.T) {
// String() should always return the constant value
conf := openai.DefaultConfig("dummy-token")
expected := "<OpenAI API ClientConfig>"
got := conf.String()
if got != expected {
t.Errorf("ClientConfig.String() = %q; want %q", got, expected)
}
}
func TestGetAzureDeploymentByModel_NoMapper(t *testing.T) {
// On a zero-value or DefaultConfig, AzureModelMapperFunc is nil,
// so GetAzureDeploymentByModel should just return the input model.
conf := openai.DefaultConfig("dummy-token")
model := "some-model"
got := conf.GetAzureDeploymentByModel(model)
if got != model {
t.Errorf("GetAzureDeploymentByModel(%q) = %q; want %q", model, got, model)
}
}

View File

@@ -9,8 +9,8 @@ import (
"testing"
"time"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
)
// TestEdits Tests the edits endpoint of the API using the mocked server.

View File

@@ -11,8 +11,8 @@ import (
"reflect"
"testing"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
)
func TestEmbedding(t *testing.T) {

View File

@@ -7,8 +7,8 @@ import (
"net/http"
"testing"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
)
// TestGetEngine Tests the retrieve engine endpoint of the API using the mocked server.

View File

@@ -54,7 +54,7 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) {
err = json.Unmarshal(rawMap["message"], &e.Message)
if err != nil {
// If the parameter field of a function call is invalid as a JSON schema
// refs: https://github.com/sashabaranov/go-openai/issues/381
// refs: https://git.vaala.cloud/VaalaCat/go-openai/issues/381
var messages []string
err = json.Unmarshal(rawMap["message"], &messages)
if err != nil {
@@ -64,7 +64,7 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) {
}
// optional fields for azure openai
// refs: https://github.com/sashabaranov/go-openai/issues/343
// refs: https://git.vaala.cloud/VaalaCat/go-openai/issues/343
if _, ok := rawMap["type"]; ok {
err = json.Unmarshal(rawMap["type"], &e.Type)
if err != nil {

View File

@@ -6,7 +6,7 @@ import (
"reflect"
"testing"
"github.com/sashabaranov/go-openai"
"git.vaala.cloud/VaalaCat/go-openai"
)
func TestAPIErrorUnmarshalJSON(t *testing.T) {

View File

@@ -11,7 +11,7 @@ import (
"net/url"
"os"
"github.com/sashabaranov/go-openai"
"git.vaala.cloud/VaalaCat/go-openai"
)
func Example() {

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"os"
"github.com/sashabaranov/go-openai"
"git.vaala.cloud/VaalaCat/go-openai"
)
func main() {

View File

@@ -5,8 +5,8 @@ import (
"fmt"
"os"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/jsonschema"
"git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/jsonschema"
)
func main() {

View File

@@ -5,7 +5,7 @@ import (
"fmt"
"os"
"github.com/sashabaranov/go-openai"
"git.vaala.cloud/VaalaCat/go-openai"
)
func main() {

View File

@@ -5,7 +5,7 @@ import (
"fmt"
"os"
"github.com/sashabaranov/go-openai"
"git.vaala.cloud/VaalaCat/go-openai"
)
func main() {

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"os"
"github.com/sashabaranov/go-openai"
"git.vaala.cloud/VaalaCat/go-openai"
)
func main() {

View File

@@ -12,8 +12,8 @@ import (
"testing"
"time"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
)
func TestFileBytesUpload(t *testing.T) {

View File

@@ -7,8 +7,8 @@ import (
"os"
"testing"
utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test/checks"
utils "git.vaala.cloud/VaalaCat/go-openai/internal"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
)
func TestFileBytesUploadWithFailingFormBuilder(t *testing.T) {

View File

@@ -7,8 +7,8 @@ import (
"net/http"
"testing"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
)
const testFineTuneID = "fine-tune-id"

View File

@@ -7,8 +7,8 @@ import (
"net/http"
"testing"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
)
const testFineTuninigJobID = "fine-tuning-job-id"

go.mod
View File

@@ -1,3 +1,3 @@
module github.com/sashabaranov/go-openai
module git.vaala.cloud/VaalaCat/go-openai
go 1.18

View File

@@ -3,8 +3,8 @@ package openai
import (
"bytes"
"context"
"io"
"net/http"
"os"
"strconv"
)
@@ -13,31 +13,62 @@ const (
CreateImageSize256x256 = "256x256"
CreateImageSize512x512 = "512x512"
CreateImageSize1024x1024 = "1024x1024"
// dall-e-3 supported only.
CreateImageSize1792x1024 = "1792x1024"
CreateImageSize1024x1792 = "1024x1792"
// gpt-image-1 supported only.
CreateImageSize1536x1024 = "1536x1024" // Landscape
CreateImageSize1024x1536 = "1024x1536" // Portrait
)
const (
CreateImageResponseFormatURL = "url"
// dall-e-2 and dall-e-3 only.
CreateImageResponseFormatB64JSON = "b64_json"
CreateImageResponseFormatURL = "url"
)
const (
CreateImageModelDallE2 = "dall-e-2"
CreateImageModelDallE3 = "dall-e-3"
CreateImageModelGptImage1 = "gpt-image-1"
)
const (
CreateImageQualityHD = "hd"
CreateImageQualityStandard = "standard"
// gpt-image-1 only.
CreateImageQualityHigh = "high"
CreateImageQualityMedium = "medium"
CreateImageQualityLow = "low"
)
const (
// dall-e-3 only.
CreateImageStyleVivid = "vivid"
CreateImageStyleNatural = "natural"
)
const (
// gpt-image-1 only.
CreateImageBackgroundTransparent = "transparent"
CreateImageBackgroundOpaque = "opaque"
)
const (
// gpt-image-1 only.
CreateImageModerationLow = "low"
)
const (
// gpt-image-1 only.
CreateImageOutputFormatPNG = "png"
CreateImageOutputFormatJPEG = "jpeg"
CreateImageOutputFormatWEBP = "webp"
)
// ImageRequest represents the request structure for the image API.
type ImageRequest struct {
Prompt string `json:"prompt,omitempty"`
@@ -48,16 +79,35 @@ type ImageRequest struct {
Style string `json:"style,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
User string `json:"user,omitempty"`
Background string `json:"background,omitempty"`
Moderation string `json:"moderation,omitempty"`
OutputCompression int `json:"output_compression,omitempty"`
OutputFormat string `json:"output_format,omitempty"`
}
// ImageResponse represents a response structure for image API.
type ImageResponse struct {
Created int64 `json:"created,omitempty"`
Data []ImageResponseDataInner `json:"data,omitempty"`
Usage ImageResponseUsage `json:"usage,omitempty"`
httpHeader
}
// ImageResponseInputTokensDetails represents the token breakdown for input tokens.
type ImageResponseInputTokensDetails struct {
TextTokens int `json:"text_tokens,omitempty"`
ImageTokens int `json:"image_tokens,omitempty"`
}
// ImageResponseUsage represents the token usage information for image API.
type ImageResponseUsage struct {
TotalTokens int `json:"total_tokens,omitempty"`
InputTokens int `json:"input_tokens,omitempty"`
OutputTokens int `json:"output_tokens,omitempty"`
InputTokensDetails ImageResponseInputTokensDetails `json:"input_tokens_details,omitempty"`
}
// ImageResponseDataInner represents a response data structure for image API.
type ImageResponseDataInner struct {
URL string `json:"url,omitempty"`
@@ -84,13 +134,15 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons
// ImageEditRequest represents the request structure for the image API.
type ImageEditRequest struct {
Image *os.File `json:"image,omitempty"`
Mask *os.File `json:"mask,omitempty"`
Image io.Reader `json:"image,omitempty"`
Mask io.Reader `json:"mask,omitempty"`
Prompt string `json:"prompt,omitempty"`
Model string `json:"model,omitempty"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Quality string `json:"quality,omitempty"`
User string `json:"user,omitempty"`
}
// CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API.
@@ -98,15 +150,16 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
body := &bytes.Buffer{}
builder := c.createFormBuilder(body)
// image
err = builder.CreateFormFile("image", request.Image)
// image, filename is not required
err = builder.CreateFormFileReader("image", request.Image, "")
if err != nil {
return
}
// mask, it is optional
if request.Mask != nil {
err = builder.CreateFormFile("mask", request.Mask)
// mask, filename is not required
err = builder.CreateFormFileReader("mask", request.Mask, "")
if err != nil {
return
}
@@ -154,11 +207,12 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
// ImageVariRequest represents the request structure for the image API.
type ImageVariRequest struct {
Image *os.File `json:"image,omitempty"`
Image io.Reader `json:"image,omitempty"`
Model string `json:"model,omitempty"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
User string `json:"user,omitempty"`
}
// CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API.
@@ -167,8 +221,8 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
body := &bytes.Buffer{}
builder := c.createFormBuilder(body)
// image
err = builder.CreateFormFile("image", request.Image)
// image, filename is not required
err = builder.CreateFormFileReader("image", request.Image, "")
if err != nil {
return
}

View File

@@ -7,11 +7,12 @@ import (
"io"
"net/http"
"os"
"path/filepath"
"testing"
"time"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
)
func TestImages(t *testing.T) {
@@ -86,24 +87,17 @@ func TestImageEdit(t *testing.T) {
defer teardown()
server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
origin, err := os.Create("image.png")
origin, err := os.Create(filepath.Join(t.TempDir(), "image.png"))
if err != nil {
t.Error("open origin file error")
return
t.Fatalf("open origin file error: %v", err)
}
defer origin.Close()
mask, err := os.Create("mask.png")
mask, err := os.Create(filepath.Join(t.TempDir(), "mask.png"))
if err != nil {
t.Error("open mask file error")
return
t.Fatalf("open mask file error: %v", err)
}
defer func() {
mask.Close()
origin.Close()
os.Remove("mask.png")
os.Remove("image.png")
}()
defer mask.Close()
_, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{
Image: origin,
@@ -121,16 +115,11 @@ func TestImageEditWithoutMask(t *testing.T) {
defer teardown()
server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
origin, err := os.Create("image.png")
origin, err := os.Create(filepath.Join(t.TempDir(), "image.png"))
if err != nil {
t.Error("open origin file error")
return
t.Fatalf("open origin file error: %v", err)
}
defer func() {
origin.Close()
os.Remove("image.png")
}()
defer origin.Close()
_, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{
Image: origin,
@@ -178,16 +167,11 @@ func TestImageVariation(t *testing.T) {
defer teardown()
server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint)
origin, err := os.Create("image.png")
origin, err := os.Create(filepath.Join(t.TempDir(), "image.png"))
if err != nil {
t.Error("open origin file error")
return
t.Fatalf("open origin file error: %v", err)
}
defer func() {
origin.Close()
os.Remove("image.png")
}()
defer origin.Close()
_, err = client.CreateVariImage(context.Background(), openai.ImageVariRequest{
Image: origin,

View File

@@ -1,8 +1,8 @@
package openai //nolint:testpackage // testing private field
import (
utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test/checks"
utils "git.vaala.cloud/VaalaCat/go-openai/internal"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
"context"
"fmt"
@@ -54,13 +54,13 @@ func TestImageFormBuilderFailures(t *testing.T) {
}
mockFailedErr := fmt.Errorf("mock form builder fail")
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
return mockFailedErr
}
_, err := client.CreateEditImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
mockBuilder.mockCreateFormFile = func(name string, _ *os.File) error {
mockBuilder.mockCreateFormFileReader = func(name string, _ io.Reader, _ string) error {
if name == "mask" {
return mockFailedErr
}
@@ -119,13 +119,13 @@ func TestVariImageFormBuilderFailures(t *testing.T) {
req := ImageVariRequest{}
mockFailedErr := fmt.Errorf("mock form builder fail")
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
return mockFailedErr
}
_, err := client.CreateVariImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails")
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error {
return nil
}

View File

@@ -5,8 +5,8 @@ import (
"errors"
"testing"
utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
utils "git.vaala.cloud/VaalaCat/go-openai/internal"
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
)
func TestErrorAccumulatorBytes(t *testing.T) {

View File

@@ -4,8 +4,10 @@ import (
"fmt"
"io"
"mime/multipart"
"net/textproto"
"os"
"path"
"path/filepath"
"strings"
)
type FormBuilder interface {
@@ -30,8 +32,37 @@ func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) er
return fb.createFormFile(fieldname, file, file.Name())
}
var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
func escapeQuotes(s string) string {
return quoteEscaper.Replace(s)
}
// CreateFormFileReader creates a form field with a file reader.
// The filename parameter may be an empty string.
// The filename in Content-Disposition is required, but it may be an empty string.
func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
return fb.createFormFile(fieldname, r, path.Base(filename))
h := make(textproto.MIMEHeader)
h.Set(
"Content-Disposition",
fmt.Sprintf(
`form-data; name="%s"; filename="%s"`,
escapeQuotes(fieldname),
escapeQuotes(filepath.Base(filename)),
),
)
fieldWriter, err := fb.writer.CreatePart(h)
if err != nil {
return err
}
_, err = io.Copy(fieldWriter, r)
if err != nil {
return err
}
return nil
}
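The escaping in `CreateFormFileReader` above exists because the field and file names land inside quoted strings in the Content-Disposition header. A standalone sketch of just that header construction:

```go
package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

// Same replacer as in the diff above: backslashes and double quotes
// are escaped before being embedded in a quoted-string.
var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")

func escapeQuotes(s string) string {
	return quoteEscaper.Replace(s)
}

// contentDisposition reproduces the header value built in
// CreateFormFileReader; only the base name of the path is kept.
func contentDisposition(fieldname, filename string) string {
	return fmt.Sprintf(
		`form-data; name="%s"; filename="%s"`,
		escapeQuotes(fieldname),
		escapeQuotes(filepath.Base(filename)),
	)
}

func main() {
	fmt.Println(contentDisposition("image", "dir/image.png"))
}
```

Without the escaping, a filename containing `"` would terminate the quoted string early and corrupt the multipart header.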
func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error {

View File

@@ -1,8 +1,7 @@
package openai //nolint:testpackage // testing private field
import (
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
"bytes"
"errors"
@@ -20,15 +19,11 @@ func (*failingWriter) Write([]byte) (int, error) {
}
func TestFormBuilderWithFailingWriter(t *testing.T) {
dir, cleanup := test.CreateTestDirectory(t)
defer cleanup()
file, err := os.CreateTemp(dir, "")
file, err := os.CreateTemp(t.TempDir(), "")
if err != nil {
t.Errorf("Error creating tmp file: %v", err)
t.Fatalf("Error creating tmp file: %v", err)
}
defer file.Close()
defer os.Remove(file.Name())
builder := NewFormBuilder(&failingWriter{})
err = builder.CreateFormFile("file", file)
@@ -36,15 +31,11 @@ func TestFormBuilderWithFailingWriter(t *testing.T) {
}
func TestFormBuilderWithClosedFile(t *testing.T) {
dir, cleanup := test.CreateTestDirectory(t)
defer cleanup()
file, err := os.CreateTemp(dir, "")
file, err := os.CreateTemp(t.TempDir(), "")
if err != nil {
t.Errorf("Error creating tmp file: %v", err)
t.Fatalf("Error creating tmp file: %v", err)
}
file.Close()
defer os.Remove(file.Name())
body := &bytes.Buffer{}
builder := NewFormBuilder(body)
@@ -52,3 +43,32 @@ func TestFormBuilderWithClosedFile(t *testing.T) {
checks.HasError(t, err, "formbuilder should return error if file is closed")
checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed")
}
type failingReader struct {
}
var errMockFailingReaderError = errors.New("mock reader failed")
func (*failingReader) Read([]byte) (int, error) {
return 0, errMockFailingReaderError
}
func TestFormBuilderWithReader(t *testing.T) {
file, err := os.CreateTemp(t.TempDir(), "")
if err != nil {
t.Fatalf("Error creating tmp file: %v", err)
}
defer file.Close()
builder := NewFormBuilder(&failingWriter{})
err = builder.CreateFormFileReader("file", file, file.Name())
checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails")
builder = NewFormBuilder(&bytes.Buffer{})
reader := &failingReader{}
err = builder.CreateFormFileReader("file", reader, "")
checks.ErrorIs(t, err, errMockFailingReaderError, "formbuilder should return error if copy reader fails")
successReader := &bytes.Buffer{}
err = builder.CreateFormFileReader("file", successReader, "")
checks.NoError(t, err, "formbuilder should not return error")
}

View File

@@ -1,7 +1,7 @@
package test
import (
"github.com/sashabaranov/go-openai/internal/test/checks"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
"net/http"
"os"
@@ -19,16 +19,6 @@ func CreateTestFile(t *testing.T, path string) {
file.Close()
}
// CreateTestDirectory creates a temporary folder which will be deleted when cleanup is called.
func CreateTestDirectory(t *testing.T) (path string, cleanup func()) {
t.Helper()
path, err := os.MkdirTemp(os.TempDir(), "")
checks.NoError(t, err)
return path, func() { os.RemoveAll(path) }
}
// TokenRoundTripper is a struct that implements the RoundTripper
// interface, specifically to handle the authentication token by adding a token
// to the request header. We need this because the API requires that each

View File

@@ -46,6 +46,8 @@ type Definition struct {
// additionalProperties: false
// additionalProperties: jsonschema.Definition{Type: jsonschema.String}
AdditionalProperties any `json:"additionalProperties,omitempty"`
// Whether the schema is nullable or not.
Nullable bool `json:"nullable,omitempty"`
}
func (d *Definition) MarshalJSON() ([]byte, error) {
@@ -124,9 +126,12 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) {
}
jsonTag := field.Tag.Get("json")
var required = true
if jsonTag == "" {
switch {
case jsonTag == "-":
continue
case jsonTag == "":
jsonTag = field.Name
} else if strings.HasSuffix(jsonTag, ",omitempty") {
case strings.HasSuffix(jsonTag, ",omitempty"):
jsonTag = strings.TrimSuffix(jsonTag, ",omitempty")
required = false
}
@@ -139,6 +144,16 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) {
if description != "" {
item.Description = description
}
enum := field.Tag.Get("enum")
if enum != "" {
item.Enum = strings.Split(enum, ",")
}
if n := field.Tag.Get("nullable"); n != "" {
nullable, _ := strconv.ParseBool(n)
item.Nullable = nullable
}
properties[jsonTag] = *item
if s := field.Tag.Get("required"); s != "" {

View File

@@ -5,7 +5,7 @@ import (
"reflect"
"testing"
"github.com/sashabaranov/go-openai/jsonschema"
"git.vaala.cloud/VaalaCat/go-openai/jsonschema"
)
func TestDefinition_MarshalJSON(t *testing.T) {
@@ -17,7 +17,7 @@ func TestDefinition_MarshalJSON(t *testing.T) {
{
name: "Test with empty Definition",
def: jsonschema.Definition{},
want: `{"properties":{}}`,
want: `{}`,
},
{
name: "Test with Definition properties set",
@@ -35,11 +35,10 @@ func TestDefinition_MarshalJSON(t *testing.T) {
"description":"A string type",
"properties":{
"name":{
"type":"string",
"properties":{}
"type":"string"
}
}
}`,
}`,
},
{
name: "Test with nested Definition properties",
@@ -66,17 +65,15 @@ func TestDefinition_MarshalJSON(t *testing.T) {
"type":"object",
"properties":{
"name":{
"type":"string",
"properties":{}
"type":"string"
},
"age":{
"type":"integer",
"properties":{}
"type":"integer"
}
}
}
}
}`,
}`,
},
{
name: "Test with complex nested Definition",
@@ -114,30 +111,26 @@ func TestDefinition_MarshalJSON(t *testing.T) {
"type":"object",
"properties":{
"name":{
"type":"string",
"properties":{}
"type":"string"
},
"age":{
"type":"integer",
"properties":{}
"type":"integer"
},
"address":{
"type":"object",
"properties":{
"city":{
"type":"string",
"properties":{}
"type":"string"
},
"country":{
"type":"string",
"properties":{}
"type":"string"
}
}
}
}
}
}
}`,
}`,
},
{
name: "Test with Array type Definition",
@@ -155,18 +148,14 @@ func TestDefinition_MarshalJSON(t *testing.T) {
want: `{
"type":"array",
"items":{
"type":"string",
"properties":{
}
"type":"string"
},
"properties":{
"name":{
"type":"string",
"properties":{}
"type":"string"
}
}
}`,
}`,
},
}
@@ -193,6 +182,232 @@ func TestDefinition_MarshalJSON(t *testing.T) {
}
}
func TestStructToSchema(t *testing.T) {
tests := []struct {
name string
in any
want string
}{
{
name: "Test with empty struct",
in: struct{}{},
want: `{
"type":"object",
"additionalProperties":false
}`,
},
{
name: "Test with struct containing many fields",
in: struct {
Name string `json:"name"`
Age int `json:"age"`
Active bool `json:"active"`
Height float64 `json:"height"`
Cities []struct {
Name string `json:"name"`
State string `json:"state"`
} `json:"cities"`
}{
Name: "John Doe",
Age: 30,
Cities: []struct {
Name string `json:"name"`
State string `json:"state"`
}{
{Name: "New York", State: "NY"},
{Name: "Los Angeles", State: "CA"},
},
},
want: `{
"type":"object",
"properties":{
"name":{
"type":"string"
},
"age":{
"type":"integer"
},
"active":{
"type":"boolean"
},
"height":{
"type":"number"
},
"cities":{
"type":"array",
"items":{
"additionalProperties":false,
"type":"object",
"properties":{
"name":{
"type":"string"
},
"state":{
"type":"string"
}
},
"required":["name","state"]
}
}
},
"required":["name","age","active","height","cities"],
"additionalProperties":false
}`,
},
{
name: "Test with description tag",
in: struct {
Name string `json:"name" description:"The name of the person"`
}{
Name: "John Doe",
},
want: `{
"type":"object",
"properties":{
"name":{
"type":"string",
"description":"The name of the person"
}
},
"required":["name"],
"additionalProperties":false
}`,
},
{
name: "Test with required tag",
in: struct {
Name string `json:"name" required:"false"`
}{
Name: "John Doe",
},
want: `{
"type":"object",
"properties":{
"name":{
"type":"string"
}
},
"additionalProperties":false
}`,
},
{
name: "Test with enum tag",
in: struct {
Color string `json:"color" enum:"red,green,blue"`
}{
Color: "red",
},
want: `{
"type":"object",
"properties":{
"color":{
"type":"string",
"enum":["red","green","blue"]
}
},
"required":["color"],
"additionalProperties":false
}`,
},
{
name: "Test with nullable tag",
in: struct {
Name *string `json:"name" nullable:"true"`
}{
Name: nil,
},
want: `{
"type":"object",
"properties":{
"name":{
"type":"string",
"nullable":true
}
},
"required":["name"],
"additionalProperties":false
}`,
},
{
name: "Test with exclude mark",
in: struct {
Name string `json:"-"`
}{
Name: "Name",
},
want: `{
"type":"object",
"additionalProperties":false
}`,
},
{
name: "Test with no json tag",
in: struct {
Name string
}{
Name: "",
},
want: `{
"type":"object",
"properties":{
"Name":{
"type":"string"
}
},
"required":["Name"],
"additionalProperties":false
}`,
},
{
name: "Test with omitempty tag",
in: struct {
Name string `json:"name,omitempty"`
}{
Name: "",
},
want: `{
"type":"object",
"properties":{
"name":{
"type":"string"
}
},
"additionalProperties":false
}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wantBytes := []byte(tt.want)
schema, err := jsonschema.GenerateSchemaForType(tt.in)
if err != nil {
t.Errorf("Failed to generate schema: error = %v", err)
return
}
var want map[string]interface{}
err = json.Unmarshal(wantBytes, &want)
if err != nil {
t.Errorf("Failed to Unmarshal JSON: error = %v", err)
return
}
got := structToMap(t, schema)
gotPtr := structToMap(t, &schema)
if !reflect.DeepEqual(got, want) {
t.Errorf("MarshalJSON() got = %v, want %v", got, want)
}
if !reflect.DeepEqual(gotPtr, want) {
t.Errorf("MarshalJSON() gotPtr = %v, want %v", gotPtr, want)
}
})
}
}
func structToMap(t *testing.T, v any) map[string]any {
t.Helper()
gotBytes, err := json.Marshal(v)

View File

@@ -3,7 +3,7 @@ package jsonschema_test
import (
"testing"
"github.com/sashabaranov/go-openai/jsonschema"
"git.vaala.cloud/VaalaCat/go-openai/jsonschema"
)
func Test_Validate(t *testing.T) {

View File

@@ -41,6 +41,7 @@ type MessageContent struct {
Type string `json:"type"`
Text *MessageText `json:"text,omitempty"`
ImageFile *ImageFile `json:"image_file,omitempty"`
ImageURL *ImageURL `json:"image_url,omitempty"`
}
type MessageText struct {
Value string `json:"value"`
@@ -51,6 +52,11 @@ type ImageFile struct {
FileID string `json:"file_id"`
}
type ImageURL struct {
URL string `json:"url"`
Detail string `json:"detail"`
}
type MessageRequest struct {
Role string `json:"role"`
Content string `json:"content"`

View File

@@ -7,9 +7,9 @@ import (
"net/http"
"testing"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
"git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
)
var emptyStr = ""

View File

@@ -9,8 +9,8 @@ import (
"testing"
"time"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
)
const testFineTuneModelID = "fine-tune-model-id"
@@ -47,6 +47,24 @@ func TestGetModel(t *testing.T) {
checks.NoError(t, err, "GetModel error")
}
// TestGetModelO3 Tests the retrieve O3 model endpoint of the API using the mocked server.
func TestGetModelO3(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/models/o3", handleGetModelEndpoint)
_, err := client.GetModel(context.Background(), "o3")
checks.NoError(t, err, "GetModel error for O3")
}
// TestGetModelO4Mini Tests the retrieve O4Mini model endpoint of the API using the mocked server.
func TestGetModelO4Mini(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/models/o4-mini", handleGetModelEndpoint)
_, err := client.GetModel(context.Background(), "o4-mini")
checks.NoError(t, err, "GetModel error for O4Mini")
}
func TestAzureGetModel(t *testing.T) {
client, server, teardown := setupAzureTestServer()
defer teardown()

View File

@@ -11,8 +11,8 @@ import (
"testing"
"time"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
)
// TestModeration Tests the moderations endpoint of the API using the mocked server.

View File

@@ -1,8 +1,8 @@
package openai_test
import (
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"
"git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
)
func setupOpenAITestServer() (client *openai.Client, server *test.ServerTest, teardown func()) {
@@ -29,9 +29,9 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea
// numTokens Returns the number of GPT-3 encoded tokens in the given text.
// This function approximates based on the rule of thumb stated by OpenAI:
// https://beta.openai.com/tokenizer
// https://beta.openai.com/tokenizer.
//
// TODO: implement an actual tokenizer for GPT-3 and Codex (once available)
// TODO: implement an actual tokenizer for GPT-3 and Codex (once available).
func numTokens(s string) int {
return int(float32(len(s)) / 4)
}

reasoning_validator.go (new file, 81 lines)
View File

@@ -0,0 +1,81 @@
package openai
import (
"errors"
"strings"
)
var (
// Deprecated: use ErrReasoningModelMaxTokensDeprecated instead.
ErrO1MaxTokensDeprecated = errors.New("this model does not support MaxTokens, please use MaxCompletionTokens") //nolint:lll
ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll
ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll
)
var (
ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll
ErrO1BetaLimitationsTools = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll
// Deprecated: use ErrReasoningModelLimitations* instead.
ErrO1BetaLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll
ErrO1BetaLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
)
var (
//nolint:lll
ErrReasoningModelMaxTokensDeprecated = errors.New("this model does not support MaxTokens, please use MaxCompletionTokens")
ErrReasoningModelLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll
ErrReasoningModelLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
)
// ReasoningValidator handles validation for o-series model requests.
type ReasoningValidator struct{}
// NewReasoningValidator creates a new validator for o-series models.
func NewReasoningValidator() *ReasoningValidator {
return &ReasoningValidator{}
}
// Validate performs all validation checks for o-series models.
func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error {
o1Series := strings.HasPrefix(request.Model, "o1")
o3Series := strings.HasPrefix(request.Model, "o3")
o4Series := strings.HasPrefix(request.Model, "o4")
if !o1Series && !o3Series && !o4Series {
return nil
}
if err := v.validateReasoningModelParams(request); err != nil {
return err
}
return nil
}
// validateReasoningModelParams checks reasoning model parameters.
func (v *ReasoningValidator) validateReasoningModelParams(request ChatCompletionRequest) error {
if request.MaxTokens > 0 {
return ErrReasoningModelMaxTokensDeprecated
}
if request.LogProbs {
return ErrReasoningModelLimitationsLogprobs
}
if request.Temperature > 0 && request.Temperature != 1 {
return ErrReasoningModelLimitationsOther
}
if request.TopP > 0 && request.TopP != 1 {
return ErrReasoningModelLimitationsOther
}
if request.N > 0 && request.N != 1 {
return ErrReasoningModelLimitationsOther
}
if request.PresencePenalty > 0 {
return ErrReasoningModelLimitationsOther
}
if request.FrequencyPenalty > 0 {
return ErrReasoningModelLimitationsOther
}
return nil
}

run.go (1 line added)
View File

@@ -87,6 +87,7 @@ type RunRequest struct {
Model string `json:"model,omitempty"`
Instructions string `json:"instructions,omitempty"`
AdditionalInstructions string `json:"additional_instructions,omitempty"`
AdditionalMessages []ThreadMessage `json:"additional_messages,omitempty"`
Tools []Tool `json:"tools,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`

View File

@@ -3,8 +3,8 @@ package openai_test
import (
"context"
openai "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
openai "git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
"encoding/json"
"fmt"

View File

@@ -11,17 +11,22 @@ const (
TTSModel1 SpeechModel = "tts-1"
TTSModel1HD SpeechModel = "tts-1-hd"
TTSModelCanary SpeechModel = "canary-tts"
TTSModelGPT4oMini SpeechModel = "gpt-4o-mini-tts"
)
type SpeechVoice string
const (
VoiceAlloy SpeechVoice = "alloy"
VoiceAsh SpeechVoice = "ash"
VoiceBallad SpeechVoice = "ballad"
VoiceCoral SpeechVoice = "coral"
VoiceEcho SpeechVoice = "echo"
VoiceFable SpeechVoice = "fable"
VoiceOnyx SpeechVoice = "onyx"
VoiceNova SpeechVoice = "nova"
VoiceShimmer SpeechVoice = "shimmer"
VoiceVerse SpeechVoice = "verse"
)
type SpeechResponseFormat string
@@ -39,6 +44,7 @@ type CreateSpeechRequest struct {
Model SpeechModel `json:"model"`
Input string `json:"input"`
Voice SpeechVoice `json:"voice"`
Instructions string `json:"instructions,omitempty"` // Optional. Doesn't work with tts-1 or tts-1-hd.
ResponseFormat SpeechResponseFormat `json:"response_format,omitempty"` // Optional, default to mp3
Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0
}

View File

@@ -11,9 +11,9 @@ import (
"path/filepath"
"testing"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
"git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
)
func TestSpeechIntegration(t *testing.T) {
@@ -21,10 +21,8 @@ func TestSpeechIntegration(t *testing.T) {
defer teardown()
server.RegisterHandler("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) {
dir, cleanup := test.CreateTestDirectory(t)
path := filepath.Join(dir, "fake.mp3")
path := filepath.Join(t.TempDir(), "fake.mp3")
test.CreateTestFile(t, path)
defer cleanup()
// audio endpoints only accept POST requests
if r.Method != "POST" {

View File

@@ -6,13 +6,14 @@ import (
"fmt"
"io"
"net/http"
"regexp"
utils "github.com/sashabaranov/go-openai/internal"
utils "git.vaala.cloud/VaalaCat/go-openai/internal"
)
var (
headerData = []byte("data: ")
errorPrefix = []byte(`data: {"error":`)
headerData = regexp.MustCompile(`^data:\s*`)
errorPrefix = regexp.MustCompile(`^data:\s*{"error":`)
)
type streamable interface {
@@ -32,17 +33,28 @@ type streamReader[T streamable] struct {
}
func (stream *streamReader[T]) Recv() (response T, err error) {
if stream.isFinished {
err = io.EOF
rawLine, err := stream.RecvRaw()
if err != nil {
return
}
response, err = stream.processLines()
err = stream.unmarshaler.Unmarshal(rawLine, &response)
if err != nil {
return
}
return response, nil
}
func (stream *streamReader[T]) RecvRaw() ([]byte, error) {
if stream.isFinished {
return nil, io.EOF
}
return stream.processLines()
}
//nolint:gocognit
func (stream *streamReader[T]) processLines() (T, error) {
func (stream *streamReader[T]) processLines() ([]byte, error) {
var (
emptyMessagesCount uint
hasErrorPrefix bool
@@ -53,44 +65,38 @@ func (stream *streamReader[T]) processLines() (T, error) {
if readErr != nil || hasErrorPrefix {
respErr := stream.unmarshalError()
if respErr != nil {
return *new(T), fmt.Errorf("error, %w", respErr.Error)
return nil, fmt.Errorf("error, %w", respErr.Error)
}
return *new(T), readErr
return nil, readErr
}
noSpaceLine := bytes.TrimSpace(rawLine)
if bytes.HasPrefix(noSpaceLine, errorPrefix) {
if errorPrefix.Match(noSpaceLine) {
hasErrorPrefix = true
}
if !bytes.HasPrefix(noSpaceLine, headerData) || hasErrorPrefix {
if !headerData.Match(noSpaceLine) || hasErrorPrefix {
if hasErrorPrefix {
noSpaceLine = bytes.TrimPrefix(noSpaceLine, headerData)
noSpaceLine = headerData.ReplaceAll(noSpaceLine, nil)
}
writeErr := stream.errAccumulator.Write(noSpaceLine)
if writeErr != nil {
return *new(T), writeErr
return nil, writeErr
}
emptyMessagesCount++
if emptyMessagesCount > stream.emptyMessagesLimit {
return *new(T), ErrTooManyEmptyStreamMessages
return nil, ErrTooManyEmptyStreamMessages
}
continue
}
noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData)
noPrefixLine := headerData.ReplaceAll(noSpaceLine, nil)
if string(noPrefixLine) == "[DONE]" {
stream.isFinished = true
return *new(T), io.EOF
return nil, io.EOF
}
var response T
unmarshalErr := stream.unmarshaler.Unmarshal(noPrefixLine, &response)
if unmarshalErr != nil {
return *new(T), unmarshalErr
}
return response, nil
return noPrefixLine, nil
}
}

View File

@@ -6,9 +6,9 @@ import (
"errors"
"testing"
utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
utils "git.vaala.cloud/VaalaCat/go-openai/internal"
"git.vaala.cloud/VaalaCat/go-openai/internal/test"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
)
var errTestUnmarshalerFailed = errors.New("test unmarshaler failed")
@@ -63,3 +63,16 @@ func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) {
_, err := stream.Recv()
checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
}
func TestStreamReaderRecvRaw(t *testing.T) {
stream := &streamReader[ChatCompletionStreamResponse]{
reader: bufio.NewReader(bytes.NewReader([]byte("data: {\"key\": \"value\"}\n"))),
}
rawLine, err := stream.RecvRaw()
if err != nil {
t.Fatalf("Did not return raw line: %v", err)
}
if !bytes.Equal(rawLine, []byte("{\"key\": \"value\"}")) {
t.Fatalf("Did not return raw line: %v", string(rawLine))
}
}

View File

@@ -10,8 +10,8 @@ import (
"testing"
"time"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
)
func TestCompletionsStreamWrongModel(t *testing.T) {

View File

@@ -7,8 +7,8 @@ import (
"net/http"
"testing"
openai "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
openai "git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
)
// TestThread Tests the thread endpoint of the API using the mocked server.

View File

@@ -3,8 +3,8 @@ package openai_test
import (
"context"
openai "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
openai "git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
"encoding/json"
"fmt"