* test: Add tests for improved coverage before refactoring. This commit adds tests to improve coverage before refactoring, to ensure that the changes do not break anything. * refactor: replace goto statement with loop. This commit refactors the method to make its control flow clearer: goto statements can make code hard to understand and maintain, so this refactor replaces the goto with a loop. * refactor: extract for-loop from Recv into a separate method. This commit improves code readability and maintainability by making the Recv method simpler.
98 lines
1.9 KiB
Go
98 lines
1.9 KiB
Go
package openai
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
|
|
utils "github.com/sashabaranov/go-openai/internal"
|
|
)
|
|
|
|
type streamable interface {
|
|
ChatCompletionStreamResponse | CompletionResponse
|
|
}
|
|
|
|
type streamReader[T streamable] struct {
|
|
emptyMessagesLimit uint
|
|
isFinished bool
|
|
|
|
reader *bufio.Reader
|
|
response *http.Response
|
|
errAccumulator utils.ErrorAccumulator
|
|
unmarshaler utils.Unmarshaler
|
|
}
|
|
|
|
func (stream *streamReader[T]) Recv() (response T, err error) {
|
|
if stream.isFinished {
|
|
err = io.EOF
|
|
return
|
|
}
|
|
|
|
response, err = stream.processLines()
|
|
return
|
|
}
|
|
|
|
func (stream *streamReader[T]) processLines() (T, error) {
|
|
var emptyMessagesCount uint
|
|
|
|
for {
|
|
rawLine, readErr := stream.reader.ReadBytes('\n')
|
|
if readErr != nil {
|
|
respErr := stream.unmarshalError()
|
|
if respErr != nil {
|
|
return *new(T), fmt.Errorf("error, %w", respErr.Error)
|
|
}
|
|
return *new(T), readErr
|
|
}
|
|
|
|
var headerData = []byte("data: ")
|
|
noSpaceLine := bytes.TrimSpace(rawLine)
|
|
if !bytes.HasPrefix(noSpaceLine, headerData) {
|
|
writeErr := stream.errAccumulator.Write(noSpaceLine)
|
|
if writeErr != nil {
|
|
return *new(T), writeErr
|
|
}
|
|
emptyMessagesCount++
|
|
if emptyMessagesCount > stream.emptyMessagesLimit {
|
|
return *new(T), ErrTooManyEmptyStreamMessages
|
|
}
|
|
|
|
continue
|
|
}
|
|
|
|
noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData)
|
|
if string(noPrefixLine) == "[DONE]" {
|
|
stream.isFinished = true
|
|
return *new(T), io.EOF
|
|
}
|
|
|
|
var response T
|
|
unmarshalErr := stream.unmarshaler.Unmarshal(noPrefixLine, &response)
|
|
if unmarshalErr != nil {
|
|
return *new(T), unmarshalErr
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
}
|
|
|
|
func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) {
|
|
errBytes := stream.errAccumulator.Bytes()
|
|
if len(errBytes) == 0 {
|
|
return
|
|
}
|
|
|
|
err := stream.unmarshaler.Unmarshal(errBytes, &errResp)
|
|
if err != nil {
|
|
errResp = nil
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (stream *streamReader[T]) Close() {
|
|
stream.response.Body.Close()
|
|
}
|