Files
go-openai/chat_stream.go
Liu Shuang a5a945ad14 fix: stream return EOF when openai return error (#184)
* fix: stream return EOF when openai return error

* perf: add error accumulator

* fix: golangci-lint

* fix: unmarshal error possibly null

* fix: error accumulator

* test: error accumulator use interface and add test code

* test: error accumulator add test code

* refactor: use stream reader to re-use stream code

* refactor: stream reader use generics
2023-03-22 09:32:47 +04:00

62 lines
1.8 KiB
Go

package openai
import (
"bufio"
"context"
)
// ChatCompletionStreamChoiceDelta is the incremental message fragment carried
// by a single streamed choice.
//
// Content is the next piece of generated text for the choice and may be empty.
// Role is the author role of the message being streamed. In the OpenAI
// streaming format it is typically sent only in the first chunk (for example
// "assistant") and omitted from later chunks (NOTE(review): confirm against the
// API reference). It is tagged omitempty so that encoding a delta without a
// role produces the same JSON as before the field existed.
type ChatCompletionStreamChoiceDelta struct {
	Content string `json:"content"`
	Role    string `json:"role,omitempty"`
}
// ChatCompletionStreamChoice is one element of the Choices slice in a streamed
// chat completion chunk.
//
// Index identifies the choice this fragment belongs to, Delta holds the
// fragment itself, and FinishReason is decoded from "finish_reason". When that
// key is absent or JSON null, FinishReason keeps its zero value "".
type ChatCompletionStreamChoice struct {
	Index        int                             `json:"index"`
	Delta        ChatCompletionStreamChoiceDelta `json:"delta"`
	FinishReason string                          `json:"finish_reason"`
}
// ChatCompletionStreamResponse is a single decoded chunk of a streaming chat
// completion. ChatCompletionStream yields values of this type.
//
// ID, Object, Created and Model map to the "id", "object", "created" and
// "model" keys of the chunk. Choices carries the per-choice deltas contained in
// this chunk.
type ChatCompletionStreamResponse struct {
	ID      string                       `json:"id"`
	Object  string                       `json:"object"`
	Created int64                        `json:"created"`
	Model   string                       `json:"model"`
	Choices []ChatCompletionStreamChoice `json:"choices"`
}
// ChatCompletionStream is the stream of chat completion chunks returned by
// CreateChatCompletionStream.
//
// It embeds the generic streamReader specialized for
// ChatCompletionStreamResponse, so its reading and closing methods are promoted
// from that reader. NOTE(review): the method set is defined with streamReader,
// which is not visible here; presumably it includes Recv and Close. Close must
// be called when the caller is done, because it releases the underlying HTTP
// response body.
type ChatCompletionStream struct {
	*streamReader[ChatCompletionStreamResponse]
}
// CreateChatCompletionStream sends a chat completion request with streaming
// enabled and returns a stream from which partial results can be read.
//
// The request's Stream field is always forced to true. The server then sends
// tokens as data-only server-sent events as they become available, and ends
// the stream with a "data: [DONE]" message.
//
// The returned stream owns the HTTP response body, so the caller must call
// Close on it when finished.
//
// A non-nil error is returned only when building the request or performing the
// HTTP round trip fails; in that case the stream is nil. The HTTP status code
// is not inspected here. Error payloads from the server are left to the stream
// reader's error accumulator (NOTE(review): presumably surfaced when reading
// from the stream; confirm in the stream reader implementation).
func (c *Client) CreateChatCompletionStream(
	ctx context.Context,
	request ChatCompletionRequest,
) (*ChatCompletionStream, error) {
	// request is a by-value copy, so this does not mutate the caller's struct.
	request.Stream = true
	req, err := c.newStreamRequest(ctx, "POST", "/chat/completions", request)
	if err != nil {
		return nil, err
	}

	resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
	if err != nil {
		return nil, err
	}

	stream := &ChatCompletionStream{
		streamReader: &streamReader[ChatCompletionStreamResponse]{
			emptyMessagesLimit: c.config.EmptyMessagesLimit,
			reader:             bufio.NewReader(resp.Body),
			response:           resp,
			errAccumulator:     newErrorAccumulator(),
			unmarshaler:        &jsonUnmarshaler{},
		},
	}
	return stream, nil
}