Better configuration (#79)

* Configurable Transport (#75)

* new functions to allow HTTPClient configuration

* updated go.mod for testing from remote

* updated go.mod for remote testing

* revert go.mod replace directives

* Fixed NewOrgClientWithTransport comment

* Make client fully configurable

* make empty messages limit configurable #70 #71

* make auth token private in config

* add docs

* lint

---------

Co-authored-by: Michael Fox <m.will.fox@gmail.com>
This commit is contained in:
sashabaranov
2023-02-21 00:16:44 +04:00
committed by GitHub
parent 133d2c9184
commit 1eb5d625f8
10 changed files with 89 additions and 53 deletions

43
api.go
View File

@@ -6,43 +6,34 @@ import (
"net/http"
)
const apiURLv1 = "https://api.openai.com/v1"
func newTransport() *http.Client {
return &http.Client{}
}
// Client is OpenAI GPT-3 API client.
type Client struct {
BaseURL string
HTTPClient *http.Client
authToken string
idOrg string
config ClientConfig
}
// NewClient creates new OpenAI API client.
func NewClient(authToken string) *Client {
return &Client{
BaseURL: apiURLv1,
HTTPClient: newTransport(),
authToken: authToken,
idOrg: "",
config := DefaultConfig(authToken)
return &Client{config}
}
// NewClientWithConfig creates new OpenAI API client for specified config.
func NewClientWithConfig(config ClientConfig) *Client {
return &Client{config}
}
// NewOrgClient creates new OpenAI API client for specified Organization ID.
//
// Deprecated: Please use NewClientWithConfig.
func NewOrgClient(authToken, org string) *Client {
return &Client{
BaseURL: apiURLv1,
HTTPClient: newTransport(),
authToken: authToken,
idOrg: org,
}
config := DefaultConfig(authToken)
config.OrgID = org
return &Client{config}
}
func (c *Client) sendRequest(req *http.Request, v interface{}) error {
req.Header.Set("Accept", "application/json; charset=utf-8")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.authToken))
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
// Check whether Content-Type is already set, Upload Files API requires
// Content-Type == multipart/form-data
@@ -51,11 +42,11 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
req.Header.Set("Content-Type", "application/json; charset=utf-8")
}
if len(c.idOrg) > 0 {
req.Header.Set("OpenAI-Organization", c.idOrg)
if len(c.config.OrgID) > 0 {
req.Header.Set("OpenAI-Organization", c.config.OrgID)
}
res, err := c.HTTPClient.Do(req)
res, err := c.config.HTTPClient.Do(req)
if err != nil {
return err
}
@@ -86,5 +77,5 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
}
func (c *Client) fullURL(suffix string) string {
return fmt.Sprintf("%s%s", c.BaseURL, suffix)
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
}

View File

@@ -110,8 +110,10 @@ func TestAPIError(t *testing.T) {
func TestRequestError(t *testing.T) {
var err error
c := NewClient("dummy")
c.BaseURL = "https://httpbin.org/status/418?"
config := DefaultConfig("dummy")
config.BaseURL = "https://httpbin.org/status/418?"
c := NewClientWithConfig(config)
ctx := context.Background()
_, err = c.ListEngines(ctx)
if err == nil {

View File

@@ -25,9 +25,10 @@ func TestCompletions(t *testing.T) {
ts.Start()
defer ts.Close()
client := NewClient(test.GetTestToken())
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()
client.BaseURL = ts.URL + "/v1"
req := CompletionRequest{
MaxTokens: 5,

33
config.go Normal file
View File

@@ -0,0 +1,33 @@
package gogpt
import (
"net/http"
)
const (
apiURLv1 = "https://api.openai.com/v1"
defaultEmptyMessagesLimit uint = 300
)
// ClientConfig is a configuration of a client.
type ClientConfig struct {
authToken string
HTTPClient *http.Client
BaseURL string
OrgID string
EmptyMessagesLimit uint
}
func DefaultConfig(authToken string) ClientConfig {
return ClientConfig{
HTTPClient: &http.Client{},
BaseURL: apiURLv1,
OrgID: "",
authToken: authToken,
EmptyMessagesLimit: defaultEmptyMessagesLimit,
}
}

View File

@@ -23,9 +23,10 @@ func TestEdits(t *testing.T) {
ts.Start()
defer ts.Close()
client := NewClient(test.GetTestToken())
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()
client.BaseURL = ts.URL + "/v1"
// create an edit request
model := "ada"

View File

@@ -22,9 +22,10 @@ func TestFileUpload(t *testing.T) {
ts.Start()
defer ts.Close()
client := NewClient(test.GetTestToken())
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()
client.BaseURL = ts.URL + "/v1"
req := FileRequest{
FileName: "test.go",

View File

@@ -23,9 +23,10 @@ func TestImages(t *testing.T) {
ts.Start()
defer ts.Close()
client := NewClient(test.GetTestToken())
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()
client.BaseURL = ts.URL + "/v1"
req := ImageRequest{}
req.Prompt = "Lorem ipsum"
@@ -94,9 +95,10 @@ func TestImageEdit(t *testing.T) {
ts.Start()
defer ts.Close()
client := NewClient(test.GetTestToken())
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()
client.BaseURL = ts.URL + "/v1"
origin, err := os.Create("image.png")
if err != nil {

View File

@@ -25,9 +25,10 @@ func TestModerations(t *testing.T) {
ts.Start()
defer ts.Close()
client := NewClient(test.GetTestToken())
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()
client.BaseURL = ts.URL + "/v1"
// create an edit request
model := "text-moderation-stable"

View File

@@ -11,17 +11,18 @@ import (
)
var (
emptyMessagesLimit = 300
ErrTooManyEmptyStreamMessages = errors.New("stream has sent too many empty messages")
)
type CompletionStream struct {
emptyMessagesLimit uint
reader *bufio.Reader
response *http.Response
}
func (stream *CompletionStream) Recv() (response CompletionResponse, err error) {
emptyMessagesCount := 0
var emptyMessagesCount uint
waitForData:
line, err := stream.reader.ReadBytes('\n')
@@ -33,7 +34,7 @@ waitForData:
line = bytes.TrimSpace(line)
if !bytes.HasPrefix(line, headerData) {
emptyMessagesCount++
if emptyMessagesCount > emptyMessagesLimit {
if emptyMessagesCount > stream.emptyMessagesLimit {
err = ErrTooManyEmptyStreamMessages
return
}
@@ -74,18 +75,20 @@ func (c *Client) CreateCompletionStream(
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Connection", "keep-alive")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.authToken))
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
if err != nil {
return
}
req = req.WithContext(ctx)
resp, err := c.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
if err != nil {
return
}
stream = &CompletionStream{
emptyMessagesLimit: c.config.EmptyMessagesLimit,
reader: bufio.NewReader(resp.Body),
response: resp,
}

View File

@@ -37,9 +37,15 @@ func TestCreateCompletionStream(t *testing.T) {
defer server.Close()
// Client portion of the test
client := NewClient(test.GetTestToken())
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
}
client := NewClientWithConfig(config)
ctx := context.Background()
client.BaseURL = server.URL + "/v1"
request := CompletionRequest{
Prompt: "Ex falso quodlibet",
@@ -48,11 +54,6 @@ func TestCreateCompletionStream(t *testing.T) {
Stream: true,
}
client.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
}
stream, err := client.CreateCompletionStream(ctx, request)
if err != nil {
t.Errorf("CreateCompletionStream returned error: %v", err)