fix: stream return EOF when openai return error (#184)

* fix: stream return EOF when openai return error

* perf: add error accumulator

* fix: golangci-lint

* fix: unmarshal error possibly null

* fix: error accumulator

* test: error accumulator use interface and add test code

* test: error accumulator add test code

* refactor: use stream reader to re-use stream code

* refactor: stream reader use generics
This commit is contained in:
Liu Shuang
2023-03-22 13:32:47 +08:00
committed by GitHub
parent aa149c1bf8
commit a5a945ad14
8 changed files with 372 additions and 107 deletions

View File

@@ -100,6 +100,68 @@ func TestCreateCompletionStream(t *testing.T) {
}
}
func TestCreateCompletionStreamError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
// Send test responses
dataBytes := []byte{}
dataStr := []string{
`{`,
`"error": {`,
`"message": "Incorrect API key provided: sk-***************************************",`,
`"type": "invalid_request_error",`,
`"param": null,`,
`"code": "invalid_api_key"`,
`}`,
`}`,
}
for _, str := range dataStr {
dataBytes = append(dataBytes, []byte(str+"\n")...)
}
_, err := w.Write(dataBytes)
if err != nil {
t.Errorf("Write error: %s", err)
}
}))
defer server.Close()
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
}
client := NewClientWithConfig(config)
ctx := context.Background()
request := CompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo,
Prompt: "Hello!",
Stream: true,
}
stream, err := client.CreateCompletionStream(ctx, request)
if err != nil {
t.Errorf("CreateCompletionStream returned error: %v", err)
}
defer stream.Close()
_, streamErr := stream.Recv()
if streamErr == nil {
t.Errorf("stream.Recv() did not return error")
}
var apiErr *APIError
if !errors.As(streamErr, &apiErr) {
t.Errorf("stream.Recv() did not return APIError")
}
t.Logf("%+v\n", apiErr)
}
// A "tokenRoundTripper" is a struct that implements the RoundTripper
// interface, specifically to handle the authentication token by adding a token
// to the request header. We need this because the API requires that each