handle stream completion (#86)

* handle stream completion

* fix tests
This commit is contained in:
sashabaranov
2023-02-22 12:33:25 +04:00
committed by GitHub
parent 1eb5d625f8
commit ae05ed976f
2 changed files with 16 additions and 1 deletions

View File

@@ -7,6 +7,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"net/http" "net/http"
) )
@@ -16,12 +17,18 @@ var (
type CompletionStream struct { type CompletionStream struct {
emptyMessagesLimit uint emptyMessagesLimit uint
isFinished bool
reader *bufio.Reader reader *bufio.Reader
response *http.Response response *http.Response
} }
func (stream *CompletionStream) Recv() (response CompletionResponse, err error) { func (stream *CompletionStream) Recv() (response CompletionResponse, err error) {
if stream.isFinished {
err = io.EOF
return
}
var emptyMessagesCount uint var emptyMessagesCount uint
waitForData: waitForData:
@@ -44,6 +51,8 @@ waitForData:
line = bytes.TrimPrefix(line, headerData) line = bytes.TrimPrefix(line, headerData)
if string(line) == "[DONE]" { if string(line) == "[DONE]" {
stream.isFinished = true
err = io.EOF
return return
} }

View File

@@ -5,6 +5,8 @@ import (
"github.com/sashabaranov/go-gpt3/internal/test" "github.com/sashabaranov/go-gpt3/internal/test"
"context" "context"
"errors"
"io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@@ -75,7 +77,6 @@ func TestCreateCompletionStream(t *testing.T) {
Model: "text-davinci-002", Model: "text-davinci-002",
Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}}, Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}},
}, },
{},
} }
for ix, expectedResponse := range expectedResponses { for ix, expectedResponse := range expectedResponses {
@@ -87,6 +88,11 @@ func TestCreateCompletionStream(t *testing.T) {
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
} }
} }
_, streamErr := stream.Recv()
if !errors.Is(streamErr, io.EOF) {
t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
}
} }
// A "tokenRoundTripper" is a struct that implements the RoundTripper // A "tokenRoundTripper" is a struct that implements the RoundTripper