From 39ca4e94882215d59857cd8791c12082edec2c97 Mon Sep 17 00:00:00 2001 From: Afeyer Date: Fri, 3 Mar 2023 13:52:02 +0800 Subject: [PATCH] Implement chat completion streaming (#101) * Implement chat completion streaming * Optimize the implementation of chat completion stream * Fix linter error --- api.go | 30 ++++++++++++++ chat.go | 2 +- chat_stream.go | 110 +++++++++++++++++++++++++++++++++++++++++++++++++ stream.go | 14 +------ 4 files changed, 142 insertions(+), 14 deletions(-) create mode 100644 chat_stream.go diff --git a/api.go b/api.go index 715c9dd..ba37826 100644 --- a/api.go +++ b/api.go @@ -1,6 +1,8 @@ package gogpt import ( + "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -79,3 +81,31 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error { func (c *Client) fullURL(suffix string) string { return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) } + +func (c *Client) newStreamRequest( + ctx context.Context, + method string, + urlSuffix string, + body interface{}) (*http.Request, error) { + var reqBody []byte + if body != nil { + var err error + reqBody, err = json.Marshal(body) + if err != nil { + return nil, err + } + } + + req, err := http.NewRequestWithContext(ctx, method, c.fullURL(urlSuffix), bytes.NewBuffer(reqBody)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Connection", "keep-alive") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + + return req, nil +} diff --git a/chat.go b/chat.go index 81e5a39..c6ef5fc 100644 --- a/chat.go +++ b/chat.go @@ -49,7 +49,7 @@ type ChatCompletionResponse struct { Usage Usage `json:"usage"` } -// CreateChatCompletion — API call to Creates a completion for the chat message. +// CreateChatCompletion — API call to Create a completion for the chat message. func (c *Client) CreateChatCompletion( ctx context.Context, request ChatCompletionRequest, diff --git a/chat_stream.go b/chat_stream.go new file mode 100644 index 0000000..c4813cb --- /dev/null +++ b/chat_stream.go @@ -0,0 +1,110 @@ +package gogpt + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "io" + "net/http" +) + +type ChatCompletionStreamChoiceDelta struct { + Content string `json:"content"` +} + +type ChatCompletionStreamChoice struct { + Index int `json:"index"` + Delta ChatCompletionStreamChoiceDelta `json:"delta"` + FinishReason string `json:"finish_reason"` +} + +type ChatCompletionStreamResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionStreamChoice `json:"choices"` +} + +// ChatCompletionStream +// Note: Perhaps it is more elegant to abstract Stream using generics. +type ChatCompletionStream struct { + emptyMessagesLimit uint + isFinished bool + + reader *bufio.Reader + response *http.Response +} + +func (stream *ChatCompletionStream) Recv() (response ChatCompletionStreamResponse, err error) { + if stream.isFinished { + err = io.EOF + return + } + + var emptyMessagesCount uint + +waitForData: + line, err := stream.reader.ReadBytes('\n') + if err != nil { + return + } + + var headerData = []byte("data: ") + line = bytes.TrimSpace(line) + if !bytes.HasPrefix(line, headerData) { + emptyMessagesCount++ + if emptyMessagesCount > stream.emptyMessagesLimit { + err = ErrTooManyEmptyStreamMessages + return + } + + goto waitForData + } + + line = bytes.TrimPrefix(line, headerData) + if string(line) == "[DONE]" { + stream.isFinished = true + err = io.EOF + return + } + + err = json.Unmarshal(line, &response) + return +} + +func (stream *ChatCompletionStream) Close() { + stream.response.Body.Close() +} + +func (stream *ChatCompletionStream) GetResponse() *http.Response { + return stream.response +} + +// CreateChatCompletionStream — API call to create a chat completion w/ streaming +// support. It sets whether to stream back partial progress. If set, tokens will be +// sent as data-only server-sent events as they become available, with the +// stream terminated by a data: [DONE] message. +func (c *Client) CreateChatCompletionStream( + ctx context.Context, + request ChatCompletionRequest, +) (stream *ChatCompletionStream, err error) { + request.Stream = true + req, err := c.newStreamRequest(ctx, "POST", "/chat/completions", request) + if err != nil { + return + } + + resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() + if err != nil { + return + } + + stream = &ChatCompletionStream{ + emptyMessagesLimit: c.config.EmptyMessagesLimit, + reader: bufio.NewReader(resp.Body), + response: resp, + } + return +} diff --git a/stream.go b/stream.go index d1bdf48..4745b47 100644 --- a/stream.go +++ b/stream.go @@ -6,7 +6,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "io" "net/http" ) @@ -73,18 +72,7 @@ func (c *Client) CreateCompletionStream( request CompletionRequest, ) (stream *CompletionStream, err error) { request.Stream = true - reqBytes, err := json.Marshal(request) - if err != nil { - return - } - - urlSuffix := "/completions" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "text/event-stream") - req.Header.Set("Cache-Control", "no-cache") - req.Header.Set("Connection", "keep-alive") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + req, err := c.newStreamRequest(ctx, "POST", "/completions", request) if err != nil { return }