refactor: refactoring http request creation and sending (#395)
* refactoring http request creation and sending * fix lint error * increase the test coverage of client.go * refactor: Change the style of HTTPRequestBuilder.Build func to one-argument-per-line.
This commit is contained in:
committed by
GitHub
parent
157de0680f
commit
f1b66967a4
@@ -94,7 +94,7 @@ func TestRequestAuthHeader(t *testing.T) {
|
|||||||
az.OrgID = c.OrgID
|
az.OrgID = c.OrgID
|
||||||
|
|
||||||
cli := NewClientWithConfig(az)
|
cli := NewClientWithConfig(az)
|
||||||
req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil, "")
|
req, err := cli.newRequest(context.Background(), "POST", "/chat/completions")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Failed to create request: %v", err)
|
t.Errorf("Failed to create request: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
4
audio.go
4
audio.go
@@ -95,11 +95,11 @@ func (c *Client) callAudioAPI(
|
|||||||
}
|
}
|
||||||
|
|
||||||
urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix)
|
urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix)
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), &formBody)
|
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model),
|
||||||
|
withBody(&formBody), withContentType(builder.FormDataContentType()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return AudioResponse{}, err
|
return AudioResponse{}, err
|
||||||
}
|
}
|
||||||
req.Header.Add("Content-Type", builder.FormDataContentType())
|
|
||||||
|
|
||||||
if request.HasJSONResponse() {
|
if request.HasJSONResponse() {
|
||||||
err = c.sendRequest(req, &response)
|
err = c.sendRequest(req, &response)
|
||||||
|
|||||||
2
chat.go
2
chat.go
@@ -152,7 +152,7 @@ func (c *Client) CreateChatCompletion(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
|
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"context"
|
"context"
|
||||||
|
"net/http"
|
||||||
utils "github.com/sashabaranov/go-openai/internal"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ChatCompletionStreamChoiceDelta struct {
|
type ChatCompletionStreamChoiceDelta struct {
|
||||||
@@ -48,27 +46,17 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
}
|
}
|
||||||
|
|
||||||
request.Stream = true
|
request.Stream = true
|
||||||
req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model)
|
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := sendRequestStream[ChatCompletionStreamResponse](c, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if isFailureStatusCode(resp) {
|
|
||||||
return nil, c.handleErrorResp(resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
stream = &ChatCompletionStream{
|
stream = &ChatCompletionStream{
|
||||||
streamReader: &streamReader[ChatCompletionStreamResponse]{
|
streamReader: resp,
|
||||||
emptyMessagesLimit: c.config.EmptyMessagesLimit,
|
|
||||||
reader: bufio.NewReader(resp.Body),
|
|
||||||
response: resp,
|
|
||||||
errAccumulator: utils.NewErrorAccumulator(),
|
|
||||||
unmarshaler: &utils.JSONUnmarshaler{},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
94
client.go
94
client.go
@@ -1,6 +1,7 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -45,6 +46,42 @@ func NewOrgClient(authToken, org string) *Client {
|
|||||||
return NewClientWithConfig(config)
|
return NewClientWithConfig(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type requestOptions struct {
|
||||||
|
body any
|
||||||
|
header http.Header
|
||||||
|
}
|
||||||
|
|
||||||
|
type requestOption func(*requestOptions)
|
||||||
|
|
||||||
|
func withBody(body any) requestOption {
|
||||||
|
return func(args *requestOptions) {
|
||||||
|
args.body = body
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func withContentType(contentType string) requestOption {
|
||||||
|
return func(args *requestOptions) {
|
||||||
|
args.header.Set("Content-Type", contentType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) newRequest(ctx context.Context, method, url string, setters ...requestOption) (*http.Request, error) {
|
||||||
|
// Default Options
|
||||||
|
args := &requestOptions{
|
||||||
|
body: nil,
|
||||||
|
header: make(http.Header),
|
||||||
|
}
|
||||||
|
for _, setter := range setters {
|
||||||
|
setter(args)
|
||||||
|
}
|
||||||
|
req, err := c.requestBuilder.Build(ctx, method, url, args.body, args.header)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
c.setCommonHeaders(req)
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) sendRequest(req *http.Request, v any) error {
|
func (c *Client) sendRequest(req *http.Request, v any) error {
|
||||||
req.Header.Set("Accept", "application/json; charset=utf-8")
|
req.Header.Set("Accept", "application/json; charset=utf-8")
|
||||||
|
|
||||||
@@ -55,8 +92,6 @@ func (c *Client) sendRequest(req *http.Request, v any) error {
|
|||||||
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||||||
}
|
}
|
||||||
|
|
||||||
c.setCommonHeaders(req)
|
|
||||||
|
|
||||||
res, err := c.config.HTTPClient.Do(req)
|
res, err := c.config.HTTPClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -71,6 +106,41 @@ func (c *Client) sendRequest(req *http.Request, v any) error {
|
|||||||
return decodeResponse(res.Body, v)
|
return decodeResponse(res.Body, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) sendRequestRaw(req *http.Request) (body io.ReadCloser, err error) {
|
||||||
|
resp, err := c.config.HTTPClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if isFailureStatusCode(resp) {
|
||||||
|
err = c.handleErrorResp(resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return resp.Body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendRequestStream[T streamable](client *Client, req *http.Request) (*streamReader[T], error) {
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "text/event-stream")
|
||||||
|
req.Header.Set("Cache-Control", "no-cache")
|
||||||
|
req.Header.Set("Connection", "keep-alive")
|
||||||
|
|
||||||
|
resp, err := client.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
|
||||||
|
if err != nil {
|
||||||
|
return new(streamReader[T]), err
|
||||||
|
}
|
||||||
|
if isFailureStatusCode(resp) {
|
||||||
|
return new(streamReader[T]), client.handleErrorResp(resp)
|
||||||
|
}
|
||||||
|
return &streamReader[T]{
|
||||||
|
emptyMessagesLimit: client.config.EmptyMessagesLimit,
|
||||||
|
reader: bufio.NewReader(resp.Body),
|
||||||
|
response: resp,
|
||||||
|
errAccumulator: utils.NewErrorAccumulator(),
|
||||||
|
unmarshaler: &utils.JSONUnmarshaler{},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) setCommonHeaders(req *http.Request) {
|
func (c *Client) setCommonHeaders(req *http.Request) {
|
||||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
|
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
|
||||||
// Azure API Key authentication
|
// Azure API Key authentication
|
||||||
@@ -138,26 +208,6 @@ func (c *Client) fullURL(suffix string, args ...any) string {
|
|||||||
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
|
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) newStreamRequest(
|
|
||||||
ctx context.Context,
|
|
||||||
method string,
|
|
||||||
urlSuffix string,
|
|
||||||
body any,
|
|
||||||
model string) (*http.Request, error) {
|
|
||||||
req, err := c.requestBuilder.Build(ctx, method, c.fullURL(urlSuffix, model), body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Accept", "text/event-stream")
|
|
||||||
req.Header.Set("Cache-Control", "no-cache")
|
|
||||||
req.Header.Set("Connection", "keep-alive")
|
|
||||||
|
|
||||||
c.setCommonHeaders(req)
|
|
||||||
return req, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) handleErrorResp(resp *http.Response) error {
|
func (c *Client) handleErrorResp(resp *http.Response) error {
|
||||||
var errRes ErrorResponse
|
var errRes ErrorResponse
|
||||||
err := json.NewDecoder(resp.Body).Decode(&errRes)
|
err := json.NewDecoder(resp.Body).Decode(&errRes)
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ var errTestRequestBuilderFailed = errors.New("test request builder failed")
|
|||||||
|
|
||||||
type failingRequestBuilder struct{}
|
type failingRequestBuilder struct{}
|
||||||
|
|
||||||
func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any) (*http.Request, error) {
|
func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any, _ http.Header) (*http.Request, error) {
|
||||||
return nil, errTestRequestBuilderFailed
|
return nil, errTestRequestBuilderFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -41,9 +41,10 @@ func TestDecodeResponse(t *testing.T) {
|
|||||||
stringInput := ""
|
stringInput := ""
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
value interface{}
|
value interface{}
|
||||||
body io.Reader
|
body io.Reader
|
||||||
|
hasError bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "nil input",
|
name: "nil input",
|
||||||
@@ -60,18 +61,32 @@ func TestDecodeResponse(t *testing.T) {
|
|||||||
value: &map[string]interface{}{},
|
value: &map[string]interface{}{},
|
||||||
body: bytes.NewReader([]byte(`{"test": "test"}`)),
|
body: bytes.NewReader([]byte(`{"test": "test"}`)),
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "reader return error",
|
||||||
|
value: &stringInput,
|
||||||
|
body: &errorReader{err: errors.New("dummy")},
|
||||||
|
hasError: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
err := decodeResponse(tc.body, tc.value)
|
err := decodeResponse(tc.body, tc.value)
|
||||||
if err != nil {
|
if (err != nil) != tc.hasError {
|
||||||
t.Errorf("Unexpected error: %v", err)
|
t.Errorf("Unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type errorReader struct {
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *errorReader) Read(_ []byte) (n int, err error) {
|
||||||
|
return 0, e.err
|
||||||
|
}
|
||||||
|
|
||||||
func TestHandleErrorResp(t *testing.T) {
|
func TestHandleErrorResp(t *testing.T) {
|
||||||
// var errRes *ErrorResponse
|
// var errRes *ErrorResponse
|
||||||
var errRes ErrorResponse
|
var errRes ErrorResponse
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ func (c *Client) CreateCompletion(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
|
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
2
edits.go
2
edits.go
@@ -32,7 +32,7 @@ type EditsResponse struct {
|
|||||||
|
|
||||||
// Perform an API call to the Edits endpoint.
|
// Perform an API call to the Edits endpoint.
|
||||||
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
|
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
|
||||||
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request)
|
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), withBody(request))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -132,7 +132,7 @@ type EmbeddingRequest struct {
|
|||||||
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
|
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
|
||||||
// https://beta.openai.com/docs/api-reference/embeddings/create
|
// https://beta.openai.com/docs/api-reference/embeddings/create
|
||||||
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
|
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
|
||||||
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request)
|
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), withBody(request))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ type EnginesList struct {
|
|||||||
// ListEngines Lists the currently available engines, and provides basic
|
// ListEngines Lists the currently available engines, and provides basic
|
||||||
// information about each option such as the owner and availability.
|
// information about each option such as the owner and availability.
|
||||||
func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) {
|
func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) {
|
||||||
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/engines"), nil)
|
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/engines"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -38,7 +38,7 @@ func (c *Client) GetEngine(
|
|||||||
engineID string,
|
engineID string,
|
||||||
) (engine Engine, err error) {
|
) (engine Engine, err error) {
|
||||||
urlSuffix := fmt.Sprintf("/engines/%s", engineID)
|
urlSuffix := fmt.Sprintf("/engines/%s", engineID)
|
||||||
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
28
files.go
28
files.go
@@ -57,21 +57,19 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/files"), &b)
|
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/files"),
|
||||||
|
withBody(&b), withContentType(builder.FormDataContentType()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("Content-Type", builder.FormDataContentType())
|
|
||||||
|
|
||||||
err = c.sendRequest(req, &file)
|
err = c.sendRequest(req, &file)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteFile deletes an existing file.
|
// DeleteFile deletes an existing file.
|
||||||
func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
|
func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
|
||||||
req, err := c.requestBuilder.Build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil)
|
req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/files/"+fileID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -83,7 +81,7 @@ func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) {
|
|||||||
// ListFiles Lists the currently available files,
|
// ListFiles Lists the currently available files,
|
||||||
// and provides basic information about each file such as the file name and purpose.
|
// and provides basic information about each file such as the file name and purpose.
|
||||||
func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
|
func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
|
||||||
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/files"), nil)
|
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/files"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -96,7 +94,7 @@ func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) {
|
|||||||
// such as the file name and purpose.
|
// such as the file name and purpose.
|
||||||
func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) {
|
func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) {
|
||||||
urlSuffix := fmt.Sprintf("/files/%s", fileID)
|
urlSuffix := fmt.Sprintf("/files/%s", fileID)
|
||||||
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -107,23 +105,11 @@ func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err err
|
|||||||
|
|
||||||
func (c *Client) GetFileContent(ctx context.Context, fileID string) (content io.ReadCloser, err error) {
|
func (c *Client) GetFileContent(ctx context.Context, fileID string) (content io.ReadCloser, err error) {
|
||||||
urlSuffix := fmt.Sprintf("/files/%s/content", fileID)
|
urlSuffix := fmt.Sprintf("/files/%s/content", fileID)
|
||||||
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.setCommonHeaders(req)
|
content, err = c.sendRequestRaw(req)
|
||||||
|
|
||||||
res, err := c.config.HTTPClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if isFailureStatusCode(res) {
|
|
||||||
err = c.handleErrorResp(res)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
content = res.Body
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ type FineTuneDeleteResponse struct {
|
|||||||
|
|
||||||
func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) {
|
func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) {
|
||||||
urlSuffix := "/fine-tunes"
|
urlSuffix := "/fine-tunes"
|
||||||
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
|
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -79,7 +79,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r
|
|||||||
|
|
||||||
// CancelFineTune cancel a fine-tune job.
|
// CancelFineTune cancel a fine-tune job.
|
||||||
func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
|
func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
|
||||||
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil)
|
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -89,7 +89,7 @@ func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (respons
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) {
|
func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) {
|
||||||
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil)
|
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/fine-tunes"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -100,7 +100,7 @@ func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err
|
|||||||
|
|
||||||
func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
|
func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
|
||||||
urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID)
|
urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID)
|
||||||
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -110,7 +110,7 @@ func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response F
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) {
|
func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) {
|
||||||
req, err := c.requestBuilder.Build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil)
|
req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -120,7 +120,7 @@ func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (respons
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) {
|
func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) {
|
||||||
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil)
|
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
13
image.go
13
image.go
@@ -44,7 +44,7 @@ type ImageResponseDataInner struct {
|
|||||||
// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
|
// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
|
||||||
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
|
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
|
||||||
urlSuffix := "/images/generations"
|
urlSuffix := "/images/generations"
|
||||||
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
|
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -107,13 +107,12 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
urlSuffix := "/images/edits"
|
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/edits"),
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body)
|
withBody(body), withContentType(builder.FormDataContentType()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("Content-Type", builder.FormDataContentType())
|
|
||||||
err = c.sendRequest(req, &response)
|
err = c.sendRequest(req, &response)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -158,14 +157,12 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//https://platform.openai.com/docs/api-reference/images/create-variation
|
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/variations"),
|
||||||
urlSuffix := "/images/variations"
|
withBody(body), withContentType(builder.FormDataContentType()))
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("Content-Type", builder.FormDataContentType())
|
|
||||||
err = c.sendRequest(req, &response)
|
err = c.sendRequest(req, &response)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,11 +3,12 @@ package openai
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RequestBuilder interface {
|
type RequestBuilder interface {
|
||||||
Build(ctx context.Context, method, url string, request any) (*http.Request, error)
|
Build(ctx context.Context, method, url string, body any, header http.Header) (*http.Request, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type HTTPRequestBuilder struct {
|
type HTTPRequestBuilder struct {
|
||||||
@@ -20,21 +21,32 @@ func NewRequestBuilder() *HTTPRequestBuilder {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *HTTPRequestBuilder) Build(ctx context.Context, method, url string, request any) (*http.Request, error) {
|
func (b *HTTPRequestBuilder) Build(
|
||||||
if request == nil {
|
ctx context.Context,
|
||||||
return http.NewRequestWithContext(ctx, method, url, nil)
|
method string,
|
||||||
|
url string,
|
||||||
|
body any,
|
||||||
|
header http.Header,
|
||||||
|
) (req *http.Request, err error) {
|
||||||
|
var bodyReader io.Reader
|
||||||
|
if body != nil {
|
||||||
|
if v, ok := body.(io.Reader); ok {
|
||||||
|
bodyReader = v
|
||||||
|
} else {
|
||||||
|
var reqBytes []byte
|
||||||
|
reqBytes, err = b.marshaller.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
bodyReader = bytes.NewBuffer(reqBytes)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
req, err = http.NewRequestWithContext(ctx, method, url, bodyReader)
|
||||||
var reqBytes []byte
|
|
||||||
reqBytes, err := b.marshaller.Marshal(request)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return
|
||||||
}
|
}
|
||||||
|
if header != nil {
|
||||||
return http.NewRequestWithContext(
|
req.Header = header
|
||||||
ctx,
|
}
|
||||||
method,
|
return
|
||||||
url,
|
|
||||||
bytes.NewBuffer(reqBytes),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) {
|
|||||||
marshaller: &failingMarshaller{},
|
marshaller: &failingMarshaller{},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := builder.Build(context.Background(), "", "", struct{}{})
|
_, err := builder.Build(context.Background(), "", "", struct{}{}, nil)
|
||||||
if !errors.Is(err, errTestMarshallerFailed) {
|
if !errors.Is(err, errTestMarshallerFailed) {
|
||||||
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
t.Fatalf("Did not return error when marshaller failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -38,7 +38,7 @@ func TestRequestBuilderReturnsRequest(t *testing.T) {
|
|||||||
reqBytes, _ = b.marshaller.Marshal(request)
|
reqBytes, _ = b.marshaller.Marshal(request)
|
||||||
want, _ = http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(reqBytes))
|
want, _ = http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(reqBytes))
|
||||||
)
|
)
|
||||||
got, _ := b.Build(ctx, method, url, request)
|
got, _ := b.Build(ctx, method, url, request, nil)
|
||||||
if !reflect.DeepEqual(got.Body, want.Body) ||
|
if !reflect.DeepEqual(got.Body, want.Body) ||
|
||||||
!reflect.DeepEqual(got.URL, want.URL) ||
|
!reflect.DeepEqual(got.URL, want.URL) ||
|
||||||
!reflect.DeepEqual(got.Method, want.Method) {
|
!reflect.DeepEqual(got.Method, want.Method) {
|
||||||
@@ -54,7 +54,7 @@ func TestRequestBuilderReturnsRequestWhenRequestOfArgsIsNil(t *testing.T) {
|
|||||||
want, _ = http.NewRequestWithContext(ctx, method, url, nil)
|
want, _ = http.NewRequestWithContext(ctx, method, url, nil)
|
||||||
)
|
)
|
||||||
b := NewRequestBuilder()
|
b := NewRequestBuilder()
|
||||||
got, _ := b.Build(ctx, method, url, nil)
|
got, _ := b.Build(ctx, method, url, nil, nil)
|
||||||
if !reflect.DeepEqual(got, want) {
|
if !reflect.DeepEqual(got, want) {
|
||||||
t.Errorf("Build() got = %v, want %v", got, want)
|
t.Errorf("Build() got = %v, want %v", got, want)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ type ModelsList struct {
|
|||||||
// ListModels Lists the currently available models,
|
// ListModels Lists the currently available models,
|
||||||
// and provides basic information about each model such as the model id and parent.
|
// and provides basic information about each model such as the model id and parent.
|
||||||
func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) {
|
func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) {
|
||||||
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/models"), nil)
|
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/models"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -54,7 +54,7 @@ func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error)
|
|||||||
// the model such as the owner and permissioning.
|
// the model such as the owner and permissioning.
|
||||||
func (c *Client) GetModel(ctx context.Context, modelID string) (model Model, err error) {
|
func (c *Client) GetModel(ctx context.Context, modelID string) (model Model, err error) {
|
||||||
urlSuffix := fmt.Sprintf("/models/%s", modelID)
|
urlSuffix := fmt.Sprintf("/models/%s", modelID)
|
||||||
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
|
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
package openai_test
|
package openai_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
@@ -56,3 +59,22 @@ func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) {
|
|||||||
resBytes, _ := json.Marshal(Model{})
|
resBytes, _ := json.Marshal(Model{})
|
||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetModelReturnTimeoutError(t *testing.T) {
|
||||||
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/models/text-davinci-003", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(10 * time.Nanosecond)
|
||||||
|
})
|
||||||
|
ctx := context.Background()
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, time.Nanosecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, err := client.GetModel(ctx, "text-davinci-003")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Did not return error")
|
||||||
|
}
|
||||||
|
if !os.IsTimeout(err) {
|
||||||
|
t.Fatal("Did not return timeout error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ type ModerationResponse struct {
|
|||||||
// Moderations — perform a moderation api call over a string.
|
// Moderations — perform a moderation api call over a string.
|
||||||
// Input can be an array or slice but a string will reduce the complexity.
|
// Input can be an array or slice but a string will reduce the complexity.
|
||||||
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
|
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
|
||||||
req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), request)
|
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), withBody(&request))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
27
stream.go
27
stream.go
@@ -1,11 +1,8 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
utils "github.com/sashabaranov/go-openai/internal"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -36,27 +33,17 @@ func (c *Client) CreateCompletionStream(
|
|||||||
}
|
}
|
||||||
|
|
||||||
request.Stream = true
|
request.Stream = true
|
||||||
req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model)
|
req, err := c.newRequest(ctx, "POST", c.fullURL(urlSuffix, request.Model), withBody(request))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := sendRequestStream[CompletionResponse](c, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if isFailureStatusCode(resp) {
|
|
||||||
return nil, c.handleErrorResp(resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
stream = &CompletionStream{
|
stream = &CompletionStream{
|
||||||
streamReader: &streamReader[CompletionResponse]{
|
streamReader: resp,
|
||||||
emptyMessagesLimit: c.config.EmptyMessagesLimit,
|
|
||||||
reader: bufio.NewReader(resp.Body),
|
|
||||||
response: resp,
|
|
||||||
errAccumulator: utils.NewErrorAccumulator(),
|
|
||||||
unmarshaler: &utils.JSONUnmarshaler{},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,9 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
. "github.com/sashabaranov/go-openai"
|
. "github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
@@ -300,6 +302,30 @@ func TestCreateCompletionStreamBrokenJSONError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateCompletionStreamReturnTimeoutError(t *testing.T) {
|
||||||
|
client, server, teardown := setupOpenAITestServer()
|
||||||
|
defer teardown()
|
||||||
|
server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(10 * time.Nanosecond)
|
||||||
|
})
|
||||||
|
ctx := context.Background()
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, time.Nanosecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, err := client.CreateCompletionStream(ctx, CompletionRequest{
|
||||||
|
Prompt: "Ex falso quodlibet",
|
||||||
|
Model: "text-davinci-002",
|
||||||
|
MaxTokens: 10,
|
||||||
|
Stream: true,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Did not return error")
|
||||||
|
}
|
||||||
|
if !os.IsTimeout(err) {
|
||||||
|
t.Fatal("Did not return timeout error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Helper funcs.
|
// Helper funcs.
|
||||||
func compareResponses(r1, r2 CompletionResponse) bool {
|
func compareResponses(r1, r2 CompletionResponse) bool {
|
||||||
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
|
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
|
||||||
|
|||||||
Reference in New Issue
Block a user