Add OpenAI Mock Server (#31)

* add constants for completions, refactor usage, add test server

Signed-off-by: Oleg <97077423+RobotSail@users.noreply.github.com>

* append v1 endpoint to test

Signed-off-by: Oleg <97077423+RobotSail@users.noreply.github.com>

* add makefile for easy targets

Signed-off-by: Oleg <97077423+RobotSail@users.noreply.github.com>

* lint files & add linter

Signed-off-by: Oleg <97077423+RobotSail@users.noreply.github.com>

* disable real API tests in short mode

Signed-off-by: Oleg <97077423+RobotSail@users.noreply.github.com>

Signed-off-by: Oleg <97077423+RobotSail@users.noreply.github.com>
This commit is contained in:
Oleg
2022-08-11 05:29:23 -04:00
committed by GitHub
parent 8b463ceb2b
commit d63df93c65
12 changed files with 619 additions and 61 deletions

275
.golangci.yml Normal file
View File

@@ -0,0 +1,275 @@
## Golden config for golangci-lint v1.47.3
#
# This is the best config for golangci-lint based on my experience and opinion.
# It is very strict, but not extremely strict.
# Feel free to adopt and change it for your needs.
run:
# Timeout for analysis, e.g. 30s, 5m.
# Default: 1m
timeout: 3m
# This file contains only configs which differ from defaults.
# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
linters-settings:
cyclop:
# The maximal code complexity to report.
# Default: 10
max-complexity: 30
# The maximal average package complexity.
# If it's higher than 0.0 (float) the check is enabled
# Default: 0.0
package-average: 10.0
errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default.
# Default: false
check-type-assertions: true
funlen:
# Checks the number of lines in a function.
# If lower than 0, disable the check.
# Default: 60
lines: 100
# Checks the number of statements in a function.
# If lower than 0, disable the check.
# Default: 40
statements: 50
gocognit:
# Minimal code complexity to report
# Default: 30 (but we recommend 10-20)
min-complexity: 20
gocritic:
# Settings passed to gocritic.
# The settings key is the name of a supported gocritic checker.
# The list of supported checkers can be find in https://go-critic.github.io/overview.
settings:
captLocal:
# Whether to restrict checker to params only.
# Default: true
paramsOnly: false
underef:
# Whether to skip (*x).method() calls where x is a pointer receiver.
# Default: true
skipRecvDeref: false
gomnd:
# List of function patterns to exclude from analysis.
# Values always ignored: `time.Date`
# Default: []
ignored-functions:
- os.Chmod
- os.Mkdir
- os.MkdirAll
- os.OpenFile
- os.WriteFile
- prometheus.ExponentialBuckets
- prometheus.ExponentialBucketsRange
- prometheus.LinearBuckets
- strconv.FormatFloat
- strconv.FormatInt
- strconv.FormatUint
- strconv.ParseFloat
- strconv.ParseInt
- strconv.ParseUint
gomodguard:
blocked:
# List of blocked modules.
# Default: []
modules:
- github.com/golang/protobuf:
recommendations:
- google.golang.org/protobuf
reason: "see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules"
- github.com/satori/go.uuid:
recommendations:
- github.com/google/uuid
reason: "satori's package is not maintained"
- github.com/gofrs/uuid:
recommendations:
- github.com/google/uuid
reason: "see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw"
govet:
# Enable all analyzers.
# Default: false
enable-all: true
# Disable analyzers by name.
# Run `go tool vet help` to see all analyzers.
# Default: []
disable:
- fieldalignment # too strict
# Settings per analyzer.
settings:
shadow:
# Whether to be strict about shadowing; can be noisy.
# Default: false
strict: true
nakedret:
# Make an issue if func has more lines of code than this setting, and it has naked returns.
# Default: 30
max-func-lines: 0
nolintlint:
# Exclude following linters from requiring an explanation.
# Default: []
allow-no-explanation: [ funlen, gocognit, lll ]
# Enable to require an explanation of nonzero length after each nolint directive.
# Default: false
require-explanation: true
# Enable to require nolint directives to mention the specific linter being suppressed.
# Default: false
require-specific: true
rowserrcheck:
# database/sql is always checked
# Default: []
packages:
- github.com/jmoiron/sqlx
tenv:
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
# Default: false
all: true
varcheck:
# Check usage of exported fields and variables.
# Default: false
exported-fields: false # default false # TODO: enable after fixing false positives
linters:
disable-all: true
enable:
## enabled by default
- deadcode # Finds unused code
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
- gosimple # Linter for Go source code that specializes in simplifying a code
- govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- ineffassign # Detects when assignments to existing variables are not used
- staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks
- structcheck # Finds unused struct fields
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
- unused # Checks Go code for unused constants, variables, functions and types
- varcheck # Finds unused global variables and constants
## disabled by default
# - asasalint # Check for pass []any as any in variadic func(...any)
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
- bidichk # Checks for dangerous unicode character sequences
- bodyclose # checks whether HTTP response body is closed successfully
- contextcheck # check the function whether use a non-inherited context
- cyclop # checks function and package cyclomatic complexity
- dupl # Tool for code clone detection
- durationcheck # check for two durations multiplied together
- errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error.
- errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13.
- execinquery # execinquery is a linter about query string checker in Query function which reads your Go src files and warning it finds
- exhaustive # check exhaustiveness of enum switch statements
- exportloopref # checks for pointers to enclosing loop variables
- forbidigo # Forbids identifiers
- funlen # Tool for detection of long functions
# - gochecknoglobals # check that no global variables exist
- gochecknoinits # Checks that no init functions are present in Go code
- gocognit # Computes and checks the cognitive complexity of functions
- goconst # Finds repeated strings that could be replaced by a constant
- gocritic # Provides diagnostics that check for bugs, performance and style issues.
- gocyclo # Computes and checks the cyclomatic complexity of functions
- godot # Check if comments end in a period
- goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt.
- gomnd # An analyzer to detect magic numbers.
- gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod.
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- goprintffuncname # Checks that printf-like functions are named with f at the end
- gosec # Inspects source code for security problems
- lll # Reports long lines
- makezero # Finds slice declarations with non-zero initial length
# - nakedret # Finds naked returns in functions greater than a specified function length
- nestif # Reports deeply nested if statements
- nilerr # Finds the code that returns nil even if it checks that the error is not nil.
- nilnil # Checks that there is no simultaneous return of nil error and an invalid value.
# - noctx # noctx finds sending http request without context.Context
- nolintlint # Reports ill-formed or insufficient nolint directives
# - nonamedreturns # Reports all named returns
- nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL.
- predeclared # find code that shadows one of Go's predeclared identifiers
- promlinter # Check Prometheus metrics naming via promlint
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
- rowserrcheck # checks whether Err of rows is checked successfully
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
- stylecheck # Stylecheck is a replacement for golint
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17
- testpackage # linter that makes you use a separate _test package
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
- unconvert # Remove unnecessary type conversions
- unparam # Reports unused function parameters
- wastedassign # wastedassign finds wasted assignment statements.
- whitespace # Tool for detection of leading and trailing whitespace
## you may want to enable
#- decorder # check declaration order and count of types, constants, variables and functions
#- exhaustruct # Checks if all structure fields are initialized
#- goheader # Checks is file header matches to pattern
#- ireturn # Accept Interfaces, Return Concrete Types
#- prealloc # [premature optimization, but can be used in some cases] Finds slice declarations that could potentially be preallocated
#- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope
#- wrapcheck # Checks that errors returned from external packages are wrapped
## disabled
#- containedctx # containedctx is a linter that detects struct contained context.Context field
#- depguard # [replaced by gomodguard] Go linter that checks if package imports are in a list of acceptable packages
#- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
#- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted.
#- forcetypeassert # [replaced by errcheck] finds forced type assertions
#- gci # Gci controls golang package import order and makes it always deterministic.
#- godox # Tool for detection of FIXME, TODO and other comment keywords
#- goerr113 # [too strict] Golang linter to check the errors handling expressions
#- gofmt # [replaced by goimports] Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
#- gofumpt # [replaced by goimports, gofumports is not available yet] Gofumpt checks whether code was gofumpt-ed.
#- grouper # An analyzer to analyze expression groups.
#- ifshort # Checks that your code uses short syntax for if-statements whenever possible
#- importas # Enforces consistent import aliases
#- maintidx # maintidx measures the maintainability index of each function.
#- misspell # [useless] Finds commonly misspelled English words in comments
#- nlreturn # [too strict and mostly code is not more readable] nlreturn checks for a new line before return and branch statements to increase code clarity
#- nosnakecase # Detects snake case of variable naming and function name. # TODO: maybe enable after https://github.com/sivchari/nosnakecase/issues/14
#- paralleltest # [too many false positives] paralleltest detects missing usage of t.Parallel() method in your Go test
#- tagliatelle # Checks the struct tags.
#- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
#- wsl # [too strict and mostly code is not more readable] Whitespace Linter - Forces you to use empty lines!
## deprecated
#- exhaustivestruct # [deprecated, replaced by exhaustruct] Checks if all struct's fields are initialized
#- golint # [deprecated, replaced by revive] Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes
#- interfacer # [deprecated] Linter that suggests narrower interface types
#- maligned # [deprecated, replaced by govet fieldalignment] Tool to detect Go structs that would take less memory if their fields were sorted
#- scopelint # [deprecated, replaced by exportloopref] Scopelint checks for unpinned variables in go programs
issues:
# Maximum count of issues with the same text.
# Set to 0 to disable.
# Default: 3
max-same-issues: 50
exclude-rules:
- source: "^//\\s*go:generate\\s"
linters: [ lll ]
- source: "(noinspection|TODO)"
linters: [ godot ]
- source: "//noinspection"
linters: [ gocritic ]
- source: "^\\s+if _, ok := err\\.\\([^.]+\\.InternalError\\); ok {"
linters: [ errorlint ]
- path: "_test\\.go"
linters:
- bodyclose
- dupl
- funlen
- goconst
- gosec
- noctx
- wrapcheck

35
Makefile Normal file
View File

@@ -0,0 +1,35 @@
##@ General
# The help target prints out all targets with their descriptions organized
# beneath their categories. The categories are represented by '##@' and the
# target descriptions by '##'. The awk commands is responsible for reading the
# entire set of makefiles included in this invocation, looking for lines of the
# file as xyz: ## something, and then pretty-format the target and help. Then,
# if there's a line with ##@ something, that gets pretty-printed as a category.
# More info on the usage of ANSI control characters for terminal formatting:
# https://en.wikipedia.org/wiki/ANSI_escape_code#SGR_parameters
# More info on the awk command:
# http://linuxcommand.org/lc3_adv_awk.php
.PHONY: help
help: ## Display this help.
@awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m<target>\033[0m\n"} /^[a-zA-Z_0-9-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST)
##@ Development
.PHONY: test
TEST_ARGS ?= -v
TEST_TARGETS ?= ./...
test: ## Test the Go modules within this package.
@ echo ▶️ go test $(TEST_ARGS) $(TEST_TARGETS)
go test $(TEST_ARGS) $(TEST_TARGETS)
@ echo ✅ success!
.PHONY: lint
LINT_TARGETS ?= ./...
lint: ## Lint Go code with the installed golangci-lint
@ echo "▶️ golangci-lint run"
golangci-lint run $(LINT_TARGETS)
@ echo "✅ golangci-lint run"

6
api.go
View File

@@ -15,7 +15,7 @@ func newTransport() *http.Client {
} }
} }
// Client is OpenAI GPT-3 API client // Client is OpenAI GPT-3 API client.
type Client struct { type Client struct {
BaseURL string BaseURL string
HTTPClient *http.Client HTTPClient *http.Client
@@ -23,7 +23,7 @@ type Client struct {
idOrg string idOrg string
} }
// NewClient creates new OpenAI API client // NewClient creates new OpenAI API client.
func NewClient(authToken string) *Client { func NewClient(authToken string) *Client {
return &Client{ return &Client{
BaseURL: apiURLv1, BaseURL: apiURLv1,
@@ -33,7 +33,7 @@ func NewClient(authToken string) *Client {
} }
} }
// NewOrgClient creates new OpenAI API client for specified Organization ID // NewOrgClient creates new OpenAI API client for specified Organization ID.
func NewOrgClient(authToken, org string) *Client { func NewOrgClient(authToken, org string) *Client {
return &Client{ return &Client{
BaseURL: apiURLv1, BaseURL: apiURLv1,

View File

@@ -1,14 +1,30 @@
package gogpt package gogpt_test
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"io/ioutil" "io/ioutil"
"log"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing" "testing"
"time"
. "github.com/sashabaranov/go-gpt3"
)
const (
testAPIToken = "this-is-my-secure-token-do-not-steal!!"
) )
func TestAPI(t *testing.T) { func TestAPI(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode")
}
tokenBytes, err := ioutil.ReadFile(".openai-token") tokenBytes, err := ioutil.ReadFile(".openai-token")
if err != nil { if err != nil {
t.Fatalf("Could not load auth token from .openai-token file") t.Fatalf("Could not load auth token from .openai-token file")
@@ -38,16 +54,6 @@ func TestAPI(t *testing.T) {
} }
} // else skip } // else skip
req := CompletionRequest{
MaxTokens: 5,
Model: "ada",
}
req.Prompt = "Lorem ipsum"
_, err = c.CreateCompletion(ctx, req)
if err != nil {
t.Fatalf("CreateCompletion error: %v", err)
}
searchReq := SearchRequest{ searchReq := SearchRequest{
Documents: []string{"White House", "hospital", "school"}, Documents: []string{"White House", "hospital", "school"},
Query: "the president", Query: "the president",
@@ -70,6 +76,60 @@ func TestAPI(t *testing.T) {
} }
} }
// TestCompletions Tests the completions endpoint of the API using the mocked server.
func TestCompletions(t *testing.T) {
// create the test server
var err error
ts := OpenAITestServer()
ts.Start()
defer ts.Close()
client := NewClient(testAPIToken)
ctx := context.Background()
client.BaseURL = ts.URL + "/v1"
req := CompletionRequest{
MaxTokens: 5,
Model: "ada",
}
req.Prompt = "Lorem ipsum"
_, err = client.CreateCompletion(ctx, req)
if err != nil {
t.Fatalf("CreateCompletion error: %v", err)
}
}
// TestEdits Tests the edits endpoint of the API using the mocked server.
func TestEdits(t *testing.T) {
// create the test server
var err error
ts := OpenAITestServer()
ts.Start()
defer ts.Close()
client := NewClient(testAPIToken)
ctx := context.Background()
client.BaseURL = ts.URL + "/v1"
// create an edit request
model := "ada"
editReq := EditsRequest{
Model: &model,
Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " +
"sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" +
" ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip" +
" ex ea commodo consequat. Duis aute irure dolor in reprehe",
Instruction: "test instruction",
N: 3,
}
response, err := client.Edits(ctx, editReq)
if err != nil {
t.Fatalf("Edits error: %v", err)
}
if len(response.Choices) != editReq.N {
t.Fatalf("edits does not properly return the correct number of choices")
}
}
func TestEmbedding(t *testing.T) { func TestEmbedding(t *testing.T) {
embeddedModels := []EmbeddingModel{ embeddedModels := []EmbeddingModel{
AdaSimilarity, AdaSimilarity,
@@ -108,3 +168,156 @@ func TestEmbedding(t *testing.T) {
} }
} }
} }
// getEditBody Returns the body of the request to create an edit.
func getEditBody(r *http.Request) (EditsRequest, error) {
edit := EditsRequest{}
// read the request body
reqBody, err := ioutil.ReadAll(r.Body)
if err != nil {
return EditsRequest{}, err
}
err = json.Unmarshal(reqBody, &edit)
if err != nil {
return EditsRequest{}, err
}
return edit, nil
}
// handleEditEndpoint Handles the edit endpoint by the test server.
func handleEditEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte
// edits only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
var editReq EditsRequest
editReq, err = getEditBody(r)
if err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}
// create a response
res := EditsResponse{
Object: "test-object",
Created: uint64(time.Now().Unix()),
}
// edit and calculate token usage
editString := "edited by mocked OpenAI server :)"
inputTokens := numTokens(editReq.Input+editReq.Instruction) * editReq.N
completionTokens := int(float32(len(editString))/4) * editReq.N
for i := 0; i < editReq.N; i++ {
// instruction will be hidden and only seen by OpenAI
res.Choices = append(res.Choices, EditsChoice{
Text: editReq.Input + editString,
Index: i,
})
}
res.Usage = Usage{
PromptTokens: inputTokens,
CompletionTokens: completionTokens,
TotalTokens: inputTokens + completionTokens,
}
resBytes, _ = json.Marshal(res)
fmt.Fprint(w, string(resBytes))
}
// handleCompletionEndpoint Handles the completion endpoint by the test server.
func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte
// completions only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
var completionReq CompletionRequest
if completionReq, err = getCompletionBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}
res := CompletionResponse{
ID: strconv.Itoa(int(time.Now().Unix())),
Object: "test-object",
Created: uint64(time.Now().Unix()),
// would be nice to validate Model during testing, but
// this may not be possible with how much upkeep
// would be required / wouldn't make much sense
Model: completionReq.Model,
}
// create completions
for i := 0; i < completionReq.N; i++ {
// generate a random string of length completionReq.Length
completionStr := strings.Repeat("a", completionReq.MaxTokens)
if completionReq.Echo {
completionStr = completionReq.Prompt + completionStr
}
res.Choices = append(res.Choices, CompletionChoice{
Text: completionStr,
Index: i,
})
}
inputTokens := numTokens(completionReq.Prompt) * completionReq.N
completionTokens := completionReq.MaxTokens * completionReq.N
res.Usage = Usage{
PromptTokens: inputTokens,
CompletionTokens: completionTokens,
TotalTokens: inputTokens + completionTokens,
}
resBytes, _ = json.Marshal(res)
fmt.Fprintln(w, string(resBytes))
}
// getCompletionBody Returns the body of the request to create a completion.
func getCompletionBody(r *http.Request) (CompletionRequest, error) {
completion := CompletionRequest{}
// read the request body
reqBody, err := ioutil.ReadAll(r.Body)
if err != nil {
return CompletionRequest{}, err
}
err = json.Unmarshal(reqBody, &completion)
if err != nil {
return CompletionRequest{}, err
}
return completion, nil
}
// numTokens Returns the number of GPT-3 encoded tokens in the given text.
// This function approximates based on the rule of thumb stated by OpenAI:
// https://beta.openai.com/tokenizer
//
// TODO: implement an actual tokenizer for GPT-3 and Codex (once available)
func numTokens(s string) int {
return int(float32(len(s)) / 4)
}
// OpenAITestServer Creates a mocked OpenAI server which can pretend to handle requests during testing.
func OpenAITestServer() *httptest.Server {
return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("received request at path %q\n", r.URL.Path)
// check auth
if r.Header.Get("Authorization") != "Bearer "+testAPIToken {
w.WriteHeader(http.StatusUnauthorized)
return
}
// OPTIMIZE: create separate handler functions for these
switch r.URL.Path {
case "/v1/edits":
handleEditEndpoint(w, r)
return
case "/v1/completions":
handleCompletionEndpoint(w, r)
return
// TODO: implement the other endpoints
default:
// the endpoint doesn't exist
http.Error(w, "the resource path doesn't exist", http.StatusNotFound)
return
}
}))
}

9
common.go Normal file
View File

@@ -0,0 +1,9 @@
// common.go defines common types used throughout the OpenAI API.
package gogpt
// Usage Represents the total token usage per request to OpenAI.
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}

View File

@@ -7,7 +7,34 @@ import (
"net/http" "net/http"
) )
// CompletionRequest represents a request structure for completion API // GPT3 Defines the models provided by OpenAI to use when generating
// completions from OpenAI.
// GPT3 Models are designed for text-based tasks. For code-specific
// tasks, please refer to the Codex series of models.
const (
GPT3TextDavinci002 = "text-davinci-002"
GPT3TextCurie001 = "text-curie-001"
GPT3TextBabbage001 = "text-babbage-001"
GPT3TextAda001 = "text-ada-001"
GPT3TextDavinci001 = "text-davinci-001"
GPT3DavinciInstructBeta = "davinci-instruct-beta"
GPT3Davinci = "davinci"
GPT3CurieInstructBeta = "curie-instruct-beta"
GPT3Curie = "curie"
GPT3Ada = "ada"
GPT3Babbage = "babbage"
)
// Codex Defines the models provided by OpenAI.
// These models are designed for code-specific tasks, and use
// a different tokenizer which optimizes for whitespace.
const (
CodexCodeDavinci002 = "code-davinci-002"
CodexCodeCushman001 = "code-cushman-001"
CodexCodeDavinci001 = "code-davinci-001"
)
// CompletionRequest represents a request structure for completion API.
type CompletionRequest struct { type CompletionRequest struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt,omitempty"` Prompt string `json:"prompt,omitempty"`
@@ -26,7 +53,7 @@ type CompletionRequest struct {
User string `json:"user,omitempty"` User string `json:"user,omitempty"`
} }
// CompletionChoice represents one of possible completions // CompletionChoice represents one of possible completions.
type CompletionChoice struct { type CompletionChoice struct {
Text string `json:"text"` Text string `json:"text"`
Index int `json:"index"` Index int `json:"index"`
@@ -34,7 +61,7 @@ type CompletionChoice struct {
LogProbs LogprobResult `json:"logprobs"` LogProbs LogprobResult `json:"logprobs"`
} }
// LogprobResult represents logprob result of Choice // LogprobResult represents logprob result of Choice.
type LogprobResult struct { type LogprobResult struct {
Tokens []string `json:"tokens"` Tokens []string `json:"tokens"`
TokenLogprobs []float32 `json:"token_logprobs"` TokenLogprobs []float32 `json:"token_logprobs"`
@@ -42,21 +69,14 @@ type LogprobResult struct {
TextOffset []int `json:"text_offset"` TextOffset []int `json:"text_offset"`
} }
// CompletionUsage represents Usage of CompletionResponse // CompletionResponse represents a response structure for completion API.
type CompletionUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
// CompletionResponse represents a response structure for completion API
type CompletionResponse struct { type CompletionResponse struct {
ID string `json:"id"` ID string `json:"id"`
Object string `json:"object"` Object string `json:"object"`
Created uint64 `json:"created"` Created uint64 `json:"created"`
Model string `json:"model"` Model string `json:"model"`
Choices []CompletionChoice `json:"choices"` Choices []CompletionChoice `json:"choices"`
Usage CompletionUsage `json:"usage"` Usage Usage `json:"usage"`
} }
// CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well // CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well
@@ -64,7 +84,10 @@ type CompletionResponse struct {
// //
// If using a fine-tuned model, simply provide the model's ID in the CompletionRequest object, // If using a fine-tuned model, simply provide the model's ID in the CompletionRequest object,
// and the server will use the model's parameters to generate the completion. // and the server will use the model's parameters to generate the completion.
func (c *Client) CreateCompletion(ctx context.Context, request CompletionRequest) (response CompletionResponse, err error) { func (c *Client) CreateCompletion(
ctx context.Context,
request CompletionRequest,
) (response CompletionResponse, err error) {
var reqBytes []byte var reqBytes []byte
reqBytes, err = json.Marshal(request) reqBytes, err = json.Marshal(request)
if err != nil { if err != nil {

View File

@@ -7,7 +7,7 @@ import (
"net/http" "net/http"
) )
// EditsRequest represents a request structure for Edits API // EditsRequest represents a request structure for Edits API.
type EditsRequest struct { type EditsRequest struct {
Model *string `json:"model,omitempty"` Model *string `json:"model,omitempty"`
Input string `json:"input,omitempty"` Input string `json:"input,omitempty"`
@@ -17,24 +17,17 @@ type EditsRequest struct {
TopP float32 `json:"top_p,omitempty"` TopP float32 `json:"top_p,omitempty"`
} }
// EditsUsage represents Usage of EditsResponse // EditsChoice represents one of possible edits.
type EditsUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
// EditsChoice represents one of possible edits
type EditsChoice struct { type EditsChoice struct {
Text string `json:"text"` Text string `json:"text"`
Index int `json:"index"` Index int `json:"index"`
} }
// EditsResponse represents a response structure for Edits API // EditsResponse represents a response structure for Edits API.
type EditsResponse struct { type EditsResponse struct {
Object string `json:"object"` Object string `json:"object"`
Created uint64 `json:"created"` Created uint64 `json:"created"`
Usage EditsUsage `json:"usage"` Usage Usage `json:"usage"`
Choices []EditsChoice `json:"choices"` Choices []EditsChoice `json:"choices"`
} }

View File

@@ -92,11 +92,12 @@ var stringToEnum = map[string]EmbeddingModel{
"code-search-babbage-text-001": BabbageCodeSearchText, "code-search-babbage-text-001": BabbageCodeSearchText,
} }
// Embedding is a special format of data representation that can be easily utilized by machine learning models and algorithms. // Embedding is a special format of data representation that can be easily utilized by machine
// The embedding is an information dense representation of the semantic meaning of a piece of text. Each embedding is a vector of // learning models and algorithms. The embedding is an information dense representation of the
// floating point numbers, such that the distance between two embeddings in the vector space is correlated with semantic similarity // semantic meaning of a piece of text. Each embedding is a vector of floating point numbers,
// between two inputs in the original format. For example, if two texts are similar, then their vector representations should // such that the distance between two embeddings in the vector space is correlated with semantic similarity
// also be similar. // between two inputs in the original format. For example, if two texts are similar,
// then their vector representations should also be similar.
type Embedding struct { type Embedding struct {
Object string `json:"object"` Object string `json:"object"`
Embedding []float64 `json:"embedding"` Embedding []float64 `json:"embedding"`

View File

@@ -6,7 +6,7 @@ import (
"net/http" "net/http"
) )
// Engine struct represents engine from OpenAPI API // Engine struct represents engine from OpenAPI API.
type Engine struct { type Engine struct {
ID string `json:"id"` ID string `json:"id"`
Object string `json:"object"` Object string `json:"object"`
@@ -14,12 +14,13 @@ type Engine struct {
Ready bool `json:"ready"` Ready bool `json:"ready"`
} }
// EnginesList is a list of engines // EnginesList is a list of engines.
type EnginesList struct { type EnginesList struct {
Engines []Engine `json:"data"` Engines []Engine `json:"data"`
} }
// ListEngines Lists the currently available engines, and provides basic information about each option such as the owner and availability. // ListEngines Lists the currently available engines, and provides basic
// information about each option such as the owner and availability.
func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) { func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) {
req, err := http.NewRequest("GET", c.fullURL("/engines"), nil) req, err := http.NewRequest("GET", c.fullURL("/engines"), nil)
if err != nil { if err != nil {
@@ -31,8 +32,12 @@ func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err erro
return return
} }
// GetEngine Retrieves an engine instance, providing basic information about the engine such as the owner and availability. // GetEngine Retrieves an engine instance, providing basic information about
func (c *Client) GetEngine(ctx context.Context, engineID string) (engine Engine, err error) { // the engine such as the owner and availability.
func (c *Client) GetEngine(
ctx context.Context,
engineID string,
) (engine Engine, err error) {
urlSuffix := fmt.Sprintf("/engines/%s", engineID) urlSuffix := fmt.Sprintf("/engines/%s", engineID)
req, err := http.NewRequest("GET", c.fullURL(urlSuffix), nil) req, err := http.NewRequest("GET", c.fullURL(urlSuffix), nil)
if err != nil { if err != nil {

View File

@@ -18,7 +18,7 @@ type FileRequest struct {
Purpose string `json:"purpose"` Purpose string `json:"purpose"`
} }
// File struct represents an OpenAPI file // File struct represents an OpenAPI file.
type File struct { type File struct {
Bytes int `json:"bytes"` Bytes int `json:"bytes"`
CreatedAt int `json:"created_at"` CreatedAt int `json:"created_at"`
@@ -29,13 +29,13 @@ type File struct {
Purpose string `json:"purpose"` Purpose string `json:"purpose"`
} }
// FilesList is a list of files that belong to the user or organization // FilesList is a list of files that belong to the user or organization.
type FilesList struct { type FilesList struct {
Files []File `json:"data"` Files []File `json:"data"`
} }
// isUrl is a helper function that determines whether the given FilePath // isUrl is a helper function that determines whether the given FilePath
// is a remote URL or a local file path // is a remote URL or a local file path.
func isURL(path string) bool { func isURL(path string) bool {
_, err := url.ParseRequestURI(path) _, err := url.ParseRequestURI(path)
if err != nil { if err != nil {
@@ -51,7 +51,7 @@ func isURL(path string) bool {
} }
// CreateFile uploads a jsonl file to GPT3 // CreateFile uploads a jsonl file to GPT3
// FilePath can be either a local file path or a URL // FilePath can be either a local file path or a URL.
func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File, err error) { func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File, err error) {
var b bytes.Buffer var b bytes.Buffer
w := multipart.NewWriter(&b) w := multipart.NewWriter(&b)
@@ -116,7 +116,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
return return
} }
// DeleteFile deletes an existing file // DeleteFile deletes an existing file.
func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
req, err := http.NewRequest("DELETE", c.fullURL("/files/"+fileID), nil) req, err := http.NewRequest("DELETE", c.fullURL("/files/"+fileID), nil)
if err != nil { if err != nil {

View File

@@ -7,20 +7,20 @@ import (
"net/http" "net/http"
) )
// ModerationRequest represents a request structure for moderation API // ModerationRequest represents a request structure for moderation API.
type ModerationRequest struct { type ModerationRequest struct {
Input string `json:"input,omitempty"` Input string `json:"input,omitempty"`
Model *string `json:"model,omitempty"` Model *string `json:"model,omitempty"`
} }
// Result represents one of possible moderation results // Result represents one of possible moderation results.
type Result struct { type Result struct {
Categories ResultCategories `json:"categories"` Categories ResultCategories `json:"categories"`
CategoryScores ResultCategoryScores `json:"category_scores"` CategoryScores ResultCategoryScores `json:"category_scores"`
Flagged bool `json:"flagged"` Flagged bool `json:"flagged"`
} }
// ResultCategories represents Categories of Result // ResultCategories represents Categories of Result.
type ResultCategories struct { type ResultCategories struct {
Hate bool `json:"hate"` Hate bool `json:"hate"`
HateThreatening bool `json:"hate/threatening"` HateThreatening bool `json:"hate/threatening"`
@@ -31,7 +31,7 @@ type ResultCategories struct {
ViolenceGraphic bool `json:"violence/graphic"` ViolenceGraphic bool `json:"violence/graphic"`
} }
// ResultCategoryScores represents CategoryScores of Result // ResultCategoryScores represents CategoryScores of Result.
type ResultCategoryScores struct { type ResultCategoryScores struct {
Hate float32 `json:"hate"` Hate float32 `json:"hate"`
HateThreatening float32 `json:"hate/threatening"` HateThreatening float32 `json:"hate/threatening"`
@@ -42,7 +42,7 @@ type ResultCategoryScores struct {
ViolenceGraphic float32 `json:"violence/graphic"` ViolenceGraphic float32 `json:"violence/graphic"`
} }
// ModerationResponse represents a response structure for moderation API // ModerationResponse represents a response structure for moderation API.
type ModerationResponse struct { type ModerationResponse struct {
ID string `json:"id"` ID string `json:"id"`
Model string `json:"model"` Model string `json:"model"`

View File

@@ -24,7 +24,7 @@ type SearchRequest struct {
User string `json:"user,omitempty"` User string `json:"user,omitempty"`
} }
// SearchResult represents single result from search API // SearchResult represents single result from search API.
type SearchResult struct { type SearchResult struct {
Document int `json:"document"` Document int `json:"document"`
Object string `json:"object"` Object string `json:"object"`
@@ -32,14 +32,18 @@ type SearchResult struct {
Metadata string `json:"metadata"` // 2* Metadata string `json:"metadata"` // 2*
} }
// SearchResponse represents a response structure for search API // SearchResponse represents a response structure for search API.
type SearchResponse struct { type SearchResponse struct {
SearchResults []SearchResult `json:"data"` SearchResults []SearchResult `json:"data"`
Object string `json:"object"` Object string `json:"object"`
} }
// Search — perform a semantic search api call over a list of documents. // Search — perform a semantic search api call over a list of documents.
func (c *Client) Search(ctx context.Context, engineID string, request SearchRequest) (response SearchResponse, err error) { func (c *Client) Search(
ctx context.Context,
engineID string,
request SearchRequest,
) (response SearchResponse, err error) {
var reqBytes []byte var reqBytes []byte
reqBytes, err = json.Marshal(request) reqBytes, err = json.Marshal(request)
if err != nil { if err != nil {