Files
go-openai/stream_reader.go
Yuki Bobier Koshimizu b8c13e4c01 Refactor streamReader: Replace goto Statement with Loop in Recv Method (#339)
* test: Add tests for improved coverage before refactoring

This commit adds tests to improve coverage before refactoring
to ensure that the changes do not break anything.

* refactor: replace goto statement with loop

This commit refactors the Recv method to make its control flow clearer.
A goto statement can make code hard to follow and maintain; replacing it with a loop addresses that.

* refactor: extract for-loop from Recv to another method

This commit improves code readability and maintainability
by making the Recv method simpler.
2023-06-08 19:31:25 +04:00

98 lines
1.9 KiB
Go

package openai
import (
"bufio"
"bytes"
"fmt"
"io"
"net/http"
utils "github.com/sashabaranov/go-openai/internal"
)
type streamable interface {
ChatCompletionStreamResponse | CompletionResponse
}
// streamReader decodes a line-oriented HTTP response body into a sequence of
// T values, one per successful Recv call. Payload lines carry a "data: "
// prefix and the payload "[DONE]" ends the stream (the format resembles
// server-sent events — NOTE(review): confirm against the API docs).
type streamReader[T streamable] struct {
	// emptyMessagesLimit is the maximum number of lines without the "data: "
	// prefix (blank lines included) tolerated within a single Recv call before
	// ErrTooManyEmptyStreamMessages is returned.
	emptyMessagesLimit uint
	// isFinished is set once the "[DONE]" payload has been read; afterwards
	// Recv returns io.EOF without reading further.
	isFinished bool
	// reader is read one '\n'-terminated line at a time. Presumably it wraps
	// response.Body — it is set by a constructor not shown here.
	reader *bufio.Reader
	// response is the underlying HTTP response; Close closes its Body.
	response *http.Response
	// errAccumulator collects the non-"data: " lines so that, when reading
	// fails, they can be decoded as an ErrorResponse.
	errAccumulator utils.ErrorAccumulator
	// unmarshaler decodes both data payloads (into T) and the accumulated
	// error bytes (into *ErrorResponse).
	unmarshaler utils.Unmarshaler
}
// Recv returns the next value decoded from the stream.
//
// After the stream has been marked finished (the "[DONE]" payload was seen),
// every call returns the zero T and io.EOF without touching the underlying
// reader. Otherwise the work is delegated to processLines, whose value and
// error are returned unchanged.
func (stream *streamReader[T]) Recv() (response T, err error) {
	if !stream.isFinished {
		return stream.processLines()
	}
	return response, io.EOF
}
// processLines reads the stream line by line until it can return exactly one
// decoded value or an error.
//
// Each line is first trimmed of surrounding whitespace, then:
//   - If it starts with "data: ", the remainder is the payload. The payload
//     "[DONE]" marks the end of the stream: isFinished is set and io.EOF is
//     returned. Any other payload is decoded into T with stream.unmarshaler,
//     and the decoded value (or the decode error) is returned.
//   - Otherwise (blank lines included) the line is appended to
//     stream.errAccumulator and counted. Once more than
//     stream.emptyMessagesLimit such lines have been read in this call,
//     ErrTooManyEmptyStreamMessages is returned.
//
// When reading fails (including io.EOF at the end of the body), the
// accumulated non-data bytes are decoded via unmarshalError. If that yields a
// response, its Error is returned wrapped with "error, "; otherwise the read
// error itself is returned. Bytes that bufio.Reader.ReadBytes returns together
// with an error (a final line lacking a trailing '\n') are not processed.
func (stream *streamReader[T]) processLines() (T, error) {
	// Loop-invariant: build the prefix once per call instead of per line.
	headerData := []byte("data: ")

	var emptyMessagesCount uint
	for {
		rawLine, readErr := stream.reader.ReadBytes('\n')
		if readErr != nil {
			// NOTE(review): if the accumulated bytes decode into an ErrorResponse
			// whose Error field is empty, this branch still wraps it and the
			// original readErr is lost — confirm ErrorResponse's shape.
			if respErr := stream.unmarshalError(); respErr != nil {
				return *new(T), fmt.Errorf("error, %w", respErr.Error)
			}
			return *new(T), readErr
		}

		noSpaceLine := bytes.TrimSpace(rawLine)
		if !bytes.HasPrefix(noSpaceLine, headerData) {
			// Not a payload line: keep it in case the body turns out to be an
			// error response, and bound how many such lines one call accepts.
			if writeErr := stream.errAccumulator.Write(noSpaceLine); writeErr != nil {
				return *new(T), writeErr
			}
			emptyMessagesCount++
			if emptyMessagesCount > stream.emptyMessagesLimit {
				return *new(T), ErrTooManyEmptyStreamMessages
			}
			continue
		}

		noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData)
		if string(noPrefixLine) == "[DONE]" {
			stream.isFinished = true
			return *new(T), io.EOF
		}

		var response T
		if unmarshalErr := stream.unmarshaler.Unmarshal(noPrefixLine, &response); unmarshalErr != nil {
			return *new(T), unmarshalErr
		}
		return response, nil
	}
}
// unmarshalError tries to decode the bytes gathered in stream.errAccumulator
// into an *ErrorResponse.
//
// It returns nil when nothing has been accumulated or when the unmarshaler
// reports an error; otherwise it returns whatever the unmarshaler stored in
// errResp.
func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) {
	accumulated := stream.errAccumulator.Bytes()
	if len(accumulated) == 0 {
		return nil
	}
	if decodeErr := stream.unmarshaler.Unmarshal(accumulated, &errResp); decodeErr != nil {
		return nil
	}
	return errResp
}
// Close closes the body of the underlying HTTP response.
// The error returned by Body.Close is discarded: this method has no error
// result through which to report it.
func (stream *streamReader[T]) Close() {
	body := stream.response.Body
	_ = body.Close()
}