diff --git a/stream_reader.go b/stream_reader.go index ecfa268..6faefe0 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -6,13 +6,14 @@ import ( "fmt" "io" "net/http" + "regexp" utils "github.com/sashabaranov/go-openai/internal" ) var ( - headerData = []byte("data: ") - errorPrefix = []byte(`data: {"error":`) + headerData = regexp.MustCompile(`^data:\s*`) + errorPrefix = regexp.MustCompile(`^data:\s*{"error":`) ) type streamable interface { @@ -70,12 +71,12 @@ func (stream *streamReader[T]) processLines() ([]byte, error) { } noSpaceLine := bytes.TrimSpace(rawLine) - if bytes.HasPrefix(noSpaceLine, errorPrefix) { + if errorPrefix.Match(noSpaceLine) { hasErrorPrefix = true } - if !bytes.HasPrefix(noSpaceLine, headerData) || hasErrorPrefix { + if !headerData.Match(noSpaceLine) || hasErrorPrefix { if hasErrorPrefix { - noSpaceLine = bytes.TrimPrefix(noSpaceLine, headerData) + noSpaceLine = headerData.ReplaceAll(noSpaceLine, nil) } writeErr := stream.errAccumulator.Write(noSpaceLine) if writeErr != nil { @@ -89,7 +90,7 @@ func (stream *streamReader[T]) processLines() ([]byte, error) { continue } - noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData) + noPrefixLine := headerData.ReplaceAll(noSpaceLine, nil) if string(noPrefixLine) == "[DONE]" { stream.isFinished = true return nil, io.EOF