feat: add RecvRaw (#896)

This commit is contained in:
Qiying Wang
2024-11-30 18:29:05 +08:00
committed by GitHub
parent 21fa42c18d
commit c203ca001f
2 changed files with 35 additions and 17 deletions

View File

@@ -32,17 +32,28 @@ type streamReader[T streamable] struct {
} }
func (stream *streamReader[T]) Recv() (response T, err error) { func (stream *streamReader[T]) Recv() (response T, err error) {
if stream.isFinished { rawLine, err := stream.RecvRaw()
err = io.EOF if err != nil {
return return
} }
response, err = stream.processLines() err = stream.unmarshaler.Unmarshal(rawLine, &response)
if err != nil {
return return
}
return response, nil
}
func (stream *streamReader[T]) RecvRaw() ([]byte, error) {
if stream.isFinished {
return nil, io.EOF
}
return stream.processLines()
} }
//nolint:gocognit //nolint:gocognit
func (stream *streamReader[T]) processLines() (T, error) { func (stream *streamReader[T]) processLines() ([]byte, error) {
var ( var (
emptyMessagesCount uint emptyMessagesCount uint
hasErrorPrefix bool hasErrorPrefix bool
@@ -53,9 +64,9 @@ func (stream *streamReader[T]) processLines() (T, error) {
if readErr != nil || hasErrorPrefix { if readErr != nil || hasErrorPrefix {
respErr := stream.unmarshalError() respErr := stream.unmarshalError()
if respErr != nil { if respErr != nil {
return *new(T), fmt.Errorf("error, %w", respErr.Error) return nil, fmt.Errorf("error, %w", respErr.Error)
} }
return *new(T), readErr return nil, readErr
} }
noSpaceLine := bytes.TrimSpace(rawLine) noSpaceLine := bytes.TrimSpace(rawLine)
@@ -68,11 +79,11 @@ func (stream *streamReader[T]) processLines() (T, error) {
} }
writeErr := stream.errAccumulator.Write(noSpaceLine) writeErr := stream.errAccumulator.Write(noSpaceLine)
if writeErr != nil { if writeErr != nil {
return *new(T), writeErr return nil, writeErr
} }
emptyMessagesCount++ emptyMessagesCount++
if emptyMessagesCount > stream.emptyMessagesLimit { if emptyMessagesCount > stream.emptyMessagesLimit {
return *new(T), ErrTooManyEmptyStreamMessages return nil, ErrTooManyEmptyStreamMessages
} }
continue continue
@@ -81,16 +92,10 @@ func (stream *streamReader[T]) processLines() (T, error) {
noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData) noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData)
if string(noPrefixLine) == "[DONE]" { if string(noPrefixLine) == "[DONE]" {
stream.isFinished = true stream.isFinished = true
return *new(T), io.EOF return nil, io.EOF
} }
var response T return noPrefixLine, nil
unmarshalErr := stream.unmarshaler.Unmarshal(noPrefixLine, &response)
if unmarshalErr != nil {
return *new(T), unmarshalErr
}
return response, nil
} }
} }

View File

@@ -63,3 +63,16 @@ func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) {
_, err := stream.Recv() _, err := stream.Recv()
checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error()) checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
} }
func TestStreamReaderRecvRaw(t *testing.T) {
stream := &streamReader[ChatCompletionStreamResponse]{
reader: bufio.NewReader(bytes.NewReader([]byte("data: {\"key\": \"value\"}\n"))),
}
rawLine, err := stream.RecvRaw()
if err != nil {
t.Fatalf("Did not return raw line: %v", err)
}
if !bytes.Equal(rawLine, []byte("{\"key\": \"value\"}")) {
t.Fatalf("Did not return raw line: %v", string(rawLine))
}
}