fix: stream return EOF when openai return error (#184)

* fix: stream return EOF when openai return error

* perf: add error accumulator

* fix: golangci-lint

* fix: unmarshal error possibly null

* fix: error accumulator

* test: error accumulator use interface and add test code

* test: error accumulator add test code

* refactor: use stream reader to re-use stream code

* refactor: stream reader use generics
This commit is contained in:
Liu Shuang
2023-03-22 13:32:47 +08:00
committed by GitHub
parent aa149c1bf8
commit a5a945ad14
8 changed files with 372 additions and 107 deletions

View File

@@ -2,12 +2,8 @@ package openai
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net/http"
)
var (
@@ -15,52 +11,7 @@ var (
)
type CompletionStream struct {
emptyMessagesLimit uint
isFinished bool
reader *bufio.Reader
response *http.Response
}
func (stream *CompletionStream) Recv() (response CompletionResponse, err error) {
if stream.isFinished {
err = io.EOF
return
}
var emptyMessagesCount uint
waitForData:
line, err := stream.reader.ReadBytes('\n')
if err != nil {
return
}
var headerData = []byte("data: ")
line = bytes.TrimSpace(line)
if !bytes.HasPrefix(line, headerData) {
emptyMessagesCount++
if emptyMessagesCount > stream.emptyMessagesLimit {
err = ErrTooManyEmptyStreamMessages
return
}
goto waitForData
}
line = bytes.TrimPrefix(line, headerData)
if string(line) == "[DONE]" {
stream.isFinished = true
err = io.EOF
return
}
err = json.Unmarshal(line, &response)
return
}
func (stream *CompletionStream) Close() {
stream.response.Body.Close()
*streamReader[CompletionResponse]
}
// CreateCompletionStream — API call to create a completion w/ streaming
@@ -83,10 +34,13 @@ func (c *Client) CreateCompletionStream(
}
stream = &CompletionStream{
emptyMessagesLimit: c.config.EmptyMessagesLimit,
reader: bufio.NewReader(resp.Body),
response: resp,
streamReader: &streamReader[CompletionResponse]{
emptyMessagesLimit: c.config.EmptyMessagesLimit,
reader: bufio.NewReader(resp.Body),
response: resp,
errAccumulator: newErrorAccumulator(),
unmarshaler: &jsonUnmarshaler{},
},
}
return
}