diff --git a/api.go b/api.go index 0a7bf15..715c9dd 100644 --- a/api.go +++ b/api.go @@ -6,43 +6,34 @@ import ( "net/http" ) -const apiURLv1 = "https://api.openai.com/v1" - -func newTransport() *http.Client { - return &http.Client{} -} - // Client is OpenAI GPT-3 API client. type Client struct { - BaseURL string - HTTPClient *http.Client - authToken string - idOrg string + config ClientConfig } // NewClient creates new OpenAI API client. func NewClient(authToken string) *Client { - return &Client{ - BaseURL: apiURLv1, - HTTPClient: newTransport(), - authToken: authToken, - idOrg: "", - } + config := DefaultConfig(authToken) + return &Client{config} +} + +// NewClientWithConfig creates new OpenAI API client for specified config. +func NewClientWithConfig(config ClientConfig) *Client { + return &Client{config} } // NewOrgClient creates new OpenAI API client for specified Organization ID. +// +// Deprecated: Please use NewClientWithConfig. func NewOrgClient(authToken, org string) *Client { - return &Client{ - BaseURL: apiURLv1, - HTTPClient: newTransport(), - authToken: authToken, - idOrg: org, - } + config := DefaultConfig(authToken) + config.OrgID = org + return &Client{config} } func (c *Client) sendRequest(req *http.Request, v interface{}) error { req.Header.Set("Accept", "application/json; charset=utf-8") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.authToken)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) // Check whether Content-Type is already set, Upload Files API requires // Content-Type == multipart/form-data @@ -51,11 +42,11 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error { req.Header.Set("Content-Type", "application/json; charset=utf-8") } - if len(c.idOrg) > 0 { - req.Header.Set("OpenAI-Organization", c.idOrg) + if len(c.config.OrgID) > 0 { + req.Header.Set("OpenAI-Organization", c.config.OrgID) } - res, err := c.HTTPClient.Do(req) + res, err := c.config.HTTPClient.Do(req) if err != nil { return err } @@ -86,5 +77,5 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error { } func (c *Client) fullURL(suffix string) string { - return fmt.Sprintf("%s%s", c.BaseURL, suffix) + return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) } diff --git a/api_test.go b/api_test.go index 4c8732f..e7b6813 100644 --- a/api_test.go +++ b/api_test.go @@ -110,8 +110,10 @@ func TestAPIError(t *testing.T) { func TestRequestError(t *testing.T) { var err error - c := NewClient("dummy") - c.BaseURL = "https://httpbin.org/status/418?" + + config := DefaultConfig("dummy") + config.BaseURL = "https://httpbin.org/status/418?" + c := NewClientWithConfig(config) ctx := context.Background() _, err = c.ListEngines(ctx) if err == nil { diff --git a/completion_test.go b/completion_test.go index c96df1a..594a23c 100644 --- a/completion_test.go +++ b/completion_test.go @@ -25,9 +25,10 @@ func TestCompletions(t *testing.T) { ts.Start() defer ts.Close() - client := NewClient(test.GetTestToken()) + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) ctx := context.Background() - client.BaseURL = ts.URL + "/v1" req := CompletionRequest{ MaxTokens: 5, diff --git a/config.go b/config.go new file mode 100644 index 0000000..561962e --- /dev/null +++ b/config.go @@ -0,0 +1,33 @@ +package gogpt + +import ( + "net/http" +) + +const ( + apiURLv1 = "https://api.openai.com/v1" + defaultEmptyMessagesLimit uint = 300 +) + +// ClientConfig is a configuration of a client. +type ClientConfig struct { + authToken string + + HTTPClient *http.Client + + BaseURL string + OrgID string + + EmptyMessagesLimit uint +} + +func DefaultConfig(authToken string) ClientConfig { + return ClientConfig{ + HTTPClient: &http.Client{}, + BaseURL: apiURLv1, + OrgID: "", + authToken: authToken, + + EmptyMessagesLimit: defaultEmptyMessagesLimit, + } +} diff --git a/edits_test.go b/edits_test.go index 499d098..e61c0ae 100644 --- a/edits_test.go +++ b/edits_test.go @@ -23,9 +23,10 @@ func TestEdits(t *testing.T) { ts.Start() defer ts.Close() - client := NewClient(test.GetTestToken()) + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) ctx := context.Background() - client.BaseURL = ts.URL + "/v1" // create an edit request model := "ada" diff --git a/files_test.go b/files_test.go index 94c8904..5c563f4 100644 --- a/files_test.go +++ b/files_test.go @@ -22,9 +22,10 @@ func TestFileUpload(t *testing.T) { ts.Start() defer ts.Close() - client := NewClient(test.GetTestToken()) + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) ctx := context.Background() - client.BaseURL = ts.URL + "/v1" req := FileRequest{ FileName: "test.go", diff --git a/image_test.go b/image_test.go index 6eaf182..c273fcf 100644 --- a/image_test.go +++ b/image_test.go @@ -23,9 +23,10 @@ func TestImages(t *testing.T) { ts.Start() defer ts.Close() - client := NewClient(test.GetTestToken()) + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) ctx := context.Background() - client.BaseURL = ts.URL + "/v1" req := ImageRequest{} req.Prompt = "Lorem ipsum" @@ -94,9 +95,10 @@ func TestImageEdit(t *testing.T) { ts.Start() defer ts.Close() - client := NewClient(test.GetTestToken()) + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) ctx := context.Background() - client.BaseURL = ts.URL + "/v1" origin, err := os.Create("image.png") if err != nil { diff --git a/moderation_test.go b/moderation_test.go index 3198cb6..68f0e64 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -25,9 +25,10 @@ func TestModerations(t *testing.T) { ts.Start() defer ts.Close() - client := NewClient(test.GetTestToken()) + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) ctx := context.Background() - client.BaseURL = ts.URL + "/v1" // create an edit request model := "text-moderation-stable" diff --git a/stream.go b/stream.go index 2fd6715..51c7f84 100644 --- a/stream.go +++ b/stream.go @@ -11,17 +11,18 @@ import ( ) var ( - emptyMessagesLimit = 300 ErrTooManyEmptyStreamMessages = errors.New("stream has sent too many empty messages") ) type CompletionStream struct { + emptyMessagesLimit uint + reader *bufio.Reader response *http.Response } func (stream *CompletionStream) Recv() (response CompletionResponse, err error) { - emptyMessagesCount := 0 + var emptyMessagesCount uint waitForData: line, err := stream.reader.ReadBytes('\n') @@ -33,7 +34,7 @@ waitForData: line = bytes.TrimSpace(line) if !bytes.HasPrefix(line, headerData) { emptyMessagesCount++ - if emptyMessagesCount > emptyMessagesLimit { + if emptyMessagesCount > stream.emptyMessagesLimit { err = ErrTooManyEmptyStreamMessages return } @@ -74,18 +75,20 @@ func (c *Client) CreateCompletionStream( req.Header.Set("Accept", "text/event-stream") req.Header.Set("Cache-Control", "no-cache") req.Header.Set("Connection", "keep-alive") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.authToken)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) if err != nil { return } req = req.WithContext(ctx) - resp, err := c.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() + resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() if err != nil { return } stream = &CompletionStream{ + emptyMessagesLimit: c.config.EmptyMessagesLimit, + reader: bufio.NewReader(resp.Body), response: resp, } diff --git a/stream_test.go b/stream_test.go index c19e534..cdf574a 100644 --- a/stream_test.go +++ b/stream_test.go @@ -37,9 +37,15 @@ func TestCreateCompletionStream(t *testing.T) { defer server.Close() // Client portion of the test - client := NewClient(test.GetTestToken()) + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = server.URL + "/v1" + config.HTTPClient.Transport = &tokenRoundTripper{ + test.GetTestToken(), + http.DefaultTransport, + } + + client := NewClientWithConfig(config) ctx := context.Background() - client.BaseURL = server.URL + "/v1" request := CompletionRequest{ Prompt: "Ex falso quodlibet", @@ -48,11 +54,6 @@ func TestCreateCompletionStream(t *testing.T) { Stream: true, } - client.HTTPClient.Transport = &tokenRoundTripper{ - test.GetTestToken(), - http.DefaultTransport, - } - stream, err := client.CreateCompletionStream(ctx, request) if err != nil { t.Errorf("CreateCompletionStream returned error: %v", err)