Better configuration (#79)
* Configurable Transport (#75) * new functions to allow HTTPClient configuration * updated go.mod for testing from remote * updated go.mod for remote testing * revert go.mod replace directives * Fixed NewOrgClientWithTransport comment * Make client fully configurable * make empty messages limit configurable #70 #71 * make auth token private in config * add docs * lint --------- Co-authored-by: Michael Fox <m.will.fox@gmail.com>
This commit is contained in:
45
api.go
45
api.go
@@ -6,43 +6,34 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const apiURLv1 = "https://api.openai.com/v1"
|
||||
|
||||
func newTransport() *http.Client {
|
||||
return &http.Client{}
|
||||
}
|
||||
|
||||
// Client is OpenAI GPT-3 API client.
|
||||
type Client struct {
|
||||
BaseURL string
|
||||
HTTPClient *http.Client
|
||||
authToken string
|
||||
idOrg string
|
||||
config ClientConfig
|
||||
}
|
||||
|
||||
// NewClient creates new OpenAI API client.
|
||||
func NewClient(authToken string) *Client {
|
||||
return &Client{
|
||||
BaseURL: apiURLv1,
|
||||
HTTPClient: newTransport(),
|
||||
authToken: authToken,
|
||||
idOrg: "",
|
||||
}
|
||||
config := DefaultConfig(authToken)
|
||||
return &Client{config}
|
||||
}
|
||||
|
||||
// NewClientWithConfig creates new OpenAI API client for specified config.
|
||||
func NewClientWithConfig(config ClientConfig) *Client {
|
||||
return &Client{config}
|
||||
}
|
||||
|
||||
// NewOrgClient creates new OpenAI API client for specified Organization ID.
|
||||
//
|
||||
// Deprecated: Please use NewClientWithConfig.
|
||||
func NewOrgClient(authToken, org string) *Client {
|
||||
return &Client{
|
||||
BaseURL: apiURLv1,
|
||||
HTTPClient: newTransport(),
|
||||
authToken: authToken,
|
||||
idOrg: org,
|
||||
}
|
||||
config := DefaultConfig(authToken)
|
||||
config.OrgID = org
|
||||
return &Client{config}
|
||||
}
|
||||
|
||||
func (c *Client) sendRequest(req *http.Request, v interface{}) error {
|
||||
req.Header.Set("Accept", "application/json; charset=utf-8")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.authToken))
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
|
||||
|
||||
// Check whether Content-Type is already set, Upload Files API requires
|
||||
// Content-Type == multipart/form-data
|
||||
@@ -51,11 +42,11 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
|
||||
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||||
}
|
||||
|
||||
if len(c.idOrg) > 0 {
|
||||
req.Header.Set("OpenAI-Organization", c.idOrg)
|
||||
if len(c.config.OrgID) > 0 {
|
||||
req.Header.Set("OpenAI-Organization", c.config.OrgID)
|
||||
}
|
||||
|
||||
res, err := c.HTTPClient.Do(req)
|
||||
res, err := c.config.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -86,5 +77,5 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
|
||||
}
|
||||
|
||||
func (c *Client) fullURL(suffix string) string {
|
||||
return fmt.Sprintf("%s%s", c.BaseURL, suffix)
|
||||
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
|
||||
}
|
||||
|
||||
@@ -110,8 +110,10 @@ func TestAPIError(t *testing.T) {
|
||||
|
||||
func TestRequestError(t *testing.T) {
|
||||
var err error
|
||||
c := NewClient("dummy")
|
||||
c.BaseURL = "https://httpbin.org/status/418?"
|
||||
|
||||
config := DefaultConfig("dummy")
|
||||
config.BaseURL = "https://httpbin.org/status/418?"
|
||||
c := NewClientWithConfig(config)
|
||||
ctx := context.Background()
|
||||
_, err = c.ListEngines(ctx)
|
||||
if err == nil {
|
||||
|
||||
@@ -25,9 +25,10 @@ func TestCompletions(t *testing.T) {
|
||||
ts.Start()
|
||||
defer ts.Close()
|
||||
|
||||
client := NewClient(test.GetTestToken())
|
||||
config := DefaultConfig(test.GetTestToken())
|
||||
config.BaseURL = ts.URL + "/v1"
|
||||
client := NewClientWithConfig(config)
|
||||
ctx := context.Background()
|
||||
client.BaseURL = ts.URL + "/v1"
|
||||
|
||||
req := CompletionRequest{
|
||||
MaxTokens: 5,
|
||||
|
||||
33
config.go
Normal file
33
config.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package gogpt
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
apiURLv1 = "https://api.openai.com/v1"
|
||||
defaultEmptyMessagesLimit uint = 300
|
||||
)
|
||||
|
||||
// ClientConfig is a configuration of a client.
|
||||
type ClientConfig struct {
|
||||
authToken string
|
||||
|
||||
HTTPClient *http.Client
|
||||
|
||||
BaseURL string
|
||||
OrgID string
|
||||
|
||||
EmptyMessagesLimit uint
|
||||
}
|
||||
|
||||
func DefaultConfig(authToken string) ClientConfig {
|
||||
return ClientConfig{
|
||||
HTTPClient: &http.Client{},
|
||||
BaseURL: apiURLv1,
|
||||
OrgID: "",
|
||||
authToken: authToken,
|
||||
|
||||
EmptyMessagesLimit: defaultEmptyMessagesLimit,
|
||||
}
|
||||
}
|
||||
@@ -23,9 +23,10 @@ func TestEdits(t *testing.T) {
|
||||
ts.Start()
|
||||
defer ts.Close()
|
||||
|
||||
client := NewClient(test.GetTestToken())
|
||||
config := DefaultConfig(test.GetTestToken())
|
||||
config.BaseURL = ts.URL + "/v1"
|
||||
client := NewClientWithConfig(config)
|
||||
ctx := context.Background()
|
||||
client.BaseURL = ts.URL + "/v1"
|
||||
|
||||
// create an edit request
|
||||
model := "ada"
|
||||
|
||||
@@ -22,9 +22,10 @@ func TestFileUpload(t *testing.T) {
|
||||
ts.Start()
|
||||
defer ts.Close()
|
||||
|
||||
client := NewClient(test.GetTestToken())
|
||||
config := DefaultConfig(test.GetTestToken())
|
||||
config.BaseURL = ts.URL + "/v1"
|
||||
client := NewClientWithConfig(config)
|
||||
ctx := context.Background()
|
||||
client.BaseURL = ts.URL + "/v1"
|
||||
|
||||
req := FileRequest{
|
||||
FileName: "test.go",
|
||||
|
||||
@@ -23,9 +23,10 @@ func TestImages(t *testing.T) {
|
||||
ts.Start()
|
||||
defer ts.Close()
|
||||
|
||||
client := NewClient(test.GetTestToken())
|
||||
config := DefaultConfig(test.GetTestToken())
|
||||
config.BaseURL = ts.URL + "/v1"
|
||||
client := NewClientWithConfig(config)
|
||||
ctx := context.Background()
|
||||
client.BaseURL = ts.URL + "/v1"
|
||||
|
||||
req := ImageRequest{}
|
||||
req.Prompt = "Lorem ipsum"
|
||||
@@ -94,9 +95,10 @@ func TestImageEdit(t *testing.T) {
|
||||
ts.Start()
|
||||
defer ts.Close()
|
||||
|
||||
client := NewClient(test.GetTestToken())
|
||||
config := DefaultConfig(test.GetTestToken())
|
||||
config.BaseURL = ts.URL + "/v1"
|
||||
client := NewClientWithConfig(config)
|
||||
ctx := context.Background()
|
||||
client.BaseURL = ts.URL + "/v1"
|
||||
|
||||
origin, err := os.Create("image.png")
|
||||
if err != nil {
|
||||
|
||||
@@ -25,9 +25,10 @@ func TestModerations(t *testing.T) {
|
||||
ts.Start()
|
||||
defer ts.Close()
|
||||
|
||||
client := NewClient(test.GetTestToken())
|
||||
config := DefaultConfig(test.GetTestToken())
|
||||
config.BaseURL = ts.URL + "/v1"
|
||||
client := NewClientWithConfig(config)
|
||||
ctx := context.Background()
|
||||
client.BaseURL = ts.URL + "/v1"
|
||||
|
||||
// create an edit request
|
||||
model := "text-moderation-stable"
|
||||
|
||||
13
stream.go
13
stream.go
@@ -11,17 +11,18 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
emptyMessagesLimit = 300
|
||||
ErrTooManyEmptyStreamMessages = errors.New("stream has sent too many empty messages")
|
||||
)
|
||||
|
||||
type CompletionStream struct {
|
||||
emptyMessagesLimit uint
|
||||
|
||||
reader *bufio.Reader
|
||||
response *http.Response
|
||||
}
|
||||
|
||||
func (stream *CompletionStream) Recv() (response CompletionResponse, err error) {
|
||||
emptyMessagesCount := 0
|
||||
var emptyMessagesCount uint
|
||||
|
||||
waitForData:
|
||||
line, err := stream.reader.ReadBytes('\n')
|
||||
@@ -33,7 +34,7 @@ waitForData:
|
||||
line = bytes.TrimSpace(line)
|
||||
if !bytes.HasPrefix(line, headerData) {
|
||||
emptyMessagesCount++
|
||||
if emptyMessagesCount > emptyMessagesLimit {
|
||||
if emptyMessagesCount > stream.emptyMessagesLimit {
|
||||
err = ErrTooManyEmptyStreamMessages
|
||||
return
|
||||
}
|
||||
@@ -74,18 +75,20 @@ func (c *Client) CreateCompletionStream(
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("Cache-Control", "no-cache")
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.authToken))
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
req = req.WithContext(ctx)
|
||||
resp, err := c.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
|
||||
resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
stream = &CompletionStream{
|
||||
emptyMessagesLimit: c.config.EmptyMessagesLimit,
|
||||
|
||||
reader: bufio.NewReader(resp.Body),
|
||||
response: resp,
|
||||
}
|
||||
|
||||
@@ -37,9 +37,15 @@ func TestCreateCompletionStream(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
// Client portion of the test
|
||||
client := NewClient(test.GetTestToken())
|
||||
config := DefaultConfig(test.GetTestToken())
|
||||
config.BaseURL = server.URL + "/v1"
|
||||
config.HTTPClient.Transport = &tokenRoundTripper{
|
||||
test.GetTestToken(),
|
||||
http.DefaultTransport,
|
||||
}
|
||||
|
||||
client := NewClientWithConfig(config)
|
||||
ctx := context.Background()
|
||||
client.BaseURL = server.URL + "/v1"
|
||||
|
||||
request := CompletionRequest{
|
||||
Prompt: "Ex falso quodlibet",
|
||||
@@ -48,11 +54,6 @@ func TestCreateCompletionStream(t *testing.T) {
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
client.HTTPClient.Transport = &tokenRoundTripper{
|
||||
test.GetTestToken(),
|
||||
http.DefaultTransport,
|
||||
}
|
||||
|
||||
stream, err := client.CreateCompletionStream(ctx, request)
|
||||
if err != nil {
|
||||
t.Errorf("CreateCompletionStream returned error: %v", err)
|
||||
|
||||
Reference in New Issue
Block a user