fix: chat stream resp error (#259)

Liu Shuang
2023-04-19 20:05:00 +08:00
committed by GitHub
parent 3b10c032b6
commit d6ab1b3a4f
8 changed files with 146 additions and 33 deletions
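The change below makes both streaming constructors (chat and completion) check the HTTP status before building a stream, decoding non-2xx bodies through a shared handleErrorResp helper. A minimal caller-side sketch of the intended behavior, assuming a placeholder token (illustrative only; a real request with it would simply return 401):

package main

import (
	"context"
	"errors"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("placeholder-token")
	_, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
		Model:  openai.GPT3Dot5Turbo,
		Stream: true,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
	})

	// With this fix, a 4xx/5xx response (for example a 429 rate limit) is
	// returned up front as an *APIError, so rate limits and auth failures
	// can be detected before reading from the stream.
	var apiErr *openai.APIError
	if errors.As(err, &apiErr) {
		fmt.Println(apiErr.HTTPStatusCode, apiErr.Message)
	}
}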

View File

@@ -1,16 +1,15 @@
 package openai_test

 import (
-	"encoding/json"
-
-	. "github.com/sashabaranov/go-openai"
-	"github.com/sashabaranov/go-openai/internal/test/checks"
-
 	"context"
+	"encoding/json"
 	"errors"
 	"io"
 	"os"
 	"testing"
+
+	. "github.com/sashabaranov/go-openai"
+	"github.com/sashabaranov/go-openai/internal/test/checks"
 )

 func TestAPI(t *testing.T) {
@@ -119,8 +118,8 @@ func TestAPIError(t *testing.T) {
t.Fatalf("Error is not an APIError: %+v", err) t.Fatalf("Error is not an APIError: %+v", err)
} }
if apiErr.StatusCode != 401 { if apiErr.HTTPStatusCode != 401 {
t.Fatalf("Unexpected API error status code: %d", apiErr.StatusCode) t.Fatalf("Unexpected API error status code: %d", apiErr.HTTPStatusCode)
} }
switch v := apiErr.Code.(type) { switch v := apiErr.Code.(type) {
@@ -239,8 +238,8 @@ func TestRequestError(t *testing.T) {
t.Fatalf("Error is not a RequestError: %+v", err) t.Fatalf("Error is not a RequestError: %+v", err)
} }
if reqErr.StatusCode != 418 { if reqErr.HTTPStatusCode != 418 {
t.Fatalf("Unexpected request error status code: %d", reqErr.StatusCode) t.Fatalf("Unexpected request error status code: %d", reqErr.HTTPStatusCode)
} }
if reqErr.Unwrap() == nil { if reqErr.Unwrap() == nil {

View File

@@ -3,6 +3,7 @@ package openai
 import (
 	"bufio"
 	"context"
+	"net/http"
 )

 type ChatCompletionStreamChoiceDelta struct {
@@ -53,6 +54,9 @@ func (c *Client) CreateChatCompletionStream(
 	if err != nil {
 		return
 	}
+	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest {
+		return nil, c.handleErrorResp(resp)
+	}

 	stream = &ChatCompletionStream{
 		streamReader: &streamReader[ChatCompletionStreamResponse]{

View File

@@ -204,6 +204,57 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
t.Logf("%+v\n", apiErr) t.Logf("%+v\n", apiErr)
} }
func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(429)
// Send test responses
dataBytes := []byte(`{"error":{` +
`"message": "You are sending requests too quickly.",` +
`"type":"rate_limit_reached",` +
`"param":null,` +
`"code":"rate_limit_reached"}}`)
_, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error")
})
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
}
client := NewClientWithConfig(config)
ctx := context.Background()
request := ChatCompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
Stream: true,
}
var apiErr *APIError
_, err := client.CreateChatCompletionStream(ctx, request)
if !errors.As(err, &apiErr) {
t.Errorf("TestCreateChatCompletionStreamRateLimitError did not return APIError")
}
t.Logf("%+v\n", apiErr)
}
// Helper funcs. // Helper funcs.
func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {

View File

@@ -72,17 +72,7 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
 	defer res.Body.Close()

 	if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
-		var errRes ErrorResponse
-		err = json.NewDecoder(res.Body).Decode(&errRes)
-		if err != nil || errRes.Error == nil {
-			reqErr := RequestError{
-				StatusCode: res.StatusCode,
-				Err:        err,
-			}
-			return fmt.Errorf("error, %w", &reqErr)
-		}
-		errRes.Error.StatusCode = res.StatusCode
-		return fmt.Errorf("error, status code: %d, message: %w", res.StatusCode, errRes.Error)
+		return c.handleErrorResp(res)
 	}

 	if v != nil {
@@ -132,3 +122,17 @@ func (c *Client) newStreamRequest(
 	}
 	return req, nil
 }
+
+func (c *Client) handleErrorResp(resp *http.Response) error {
+	var errRes ErrorResponse
+	err := json.NewDecoder(resp.Body).Decode(&errRes)
+	if err != nil || errRes.Error == nil {
+		reqErr := RequestError{
+			HTTPStatusCode: resp.StatusCode,
+			Err:            err,
+		}
+		return fmt.Errorf("error, %w", &reqErr)
+	}
+	errRes.Error.HTTPStatusCode = resp.StatusCode
+	return fmt.Errorf("error, status code: %d, message: %w", resp.StatusCode, errRes.Error)
+}
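Because handleErrorResp wraps both the decoded APIError and the fallback RequestError with the %w verb, callers can recover the concrete type with errors.As. A small caller-side sketch under that assumption (isRateLimited is a hypothetical helper, not part of the library):

package example

import (
	"errors"
	"net/http"

	openai "github.com/sashabaranov/go-openai"
)

// isRateLimited reports whether err wraps an *openai.APIError carrying a
// 429 status; the %w wrapping in handleErrorResp is what lets errors.As
// reach the concrete error value.
func isRateLimited(err error) bool {
	var apiErr *openai.APIError
	return errors.As(err, &apiErr) && apiErr.HTTPStatusCode == http.StatusTooManyRequests
}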

View File

@@ -11,12 +11,12 @@ type APIError struct {
 	Message string  `json:"message"`
 	Param   *string `json:"param,omitempty"`
 	Type    string  `json:"type"`
-	StatusCode int `json:"-"`
+	HTTPStatusCode int `json:"-"`
 }

 // RequestError provides informations about generic request errors.
 type RequestError struct {
-	StatusCode int
+	HTTPStatusCode int
 	Err error
 }
@@ -73,7 +73,7 @@ func (e *RequestError) Error() string {
 	if e.Err != nil {
 		return e.Err.Error()
 	}
-	return fmt.Sprintf("status code %d", e.StatusCode)
+	return fmt.Sprintf("status code %d", e.HTTPStatusCode)
 }

 func (e *RequestError) Unwrap() error {
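After the rename, both error types expose the status as HTTPStatusCode. A brief illustrative sketch of reading it from either type (logStatus is a hypothetical caller-side helper, not part of the library):

package example

import (
	"errors"
	"log"

	openai "github.com/sashabaranov/go-openai"
)

// logStatus distinguishes a decoded API error body (*APIError) from a
// response whose body could not be decoded (*RequestError); both now
// carry HTTPStatusCode.
func logStatus(err error) {
	var apiErr *openai.APIError
	var reqErr *openai.RequestError
	switch {
	case errors.As(err, &apiErr):
		log.Printf("API error %d: %s", apiErr.HTTPStatusCode, apiErr.Message)
	case errors.As(err, &reqErr):
		log.Printf("request error %d: %v", reqErr.HTTPStatusCode, reqErr.Err)
	}
}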

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"errors" "errors"
"net/http"
"testing" "testing"
"github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test"
@@ -71,7 +72,11 @@ func TestErrorByteWriteErrors(t *testing.T) {
 func TestErrorAccumulatorWriteErrors(t *testing.T) {
 	var err error
-	ts := test.NewTestServer().OpenAITestServer()
+	server := test.NewTestServer()
+	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
+		http.Error(w, "error", 200)
+	})
+	ts := server.OpenAITestServer()
 	ts.Start()
 	defer ts.Close()

View File

@@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"context" "context"
"errors" "errors"
"net/http"
) )
var ( var (
@@ -43,6 +44,9 @@ func (c *Client) CreateCompletionStream(
 	if err != nil {
 		return
 	}
+	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest {
+		return nil, c.handleErrorResp(resp)
+	}

 	stream = &CompletionStream{
 		streamReader: &streamReader[CompletionResponse]{

View File

@@ -1,16 +1,16 @@
 package openai_test

 import (
-	. "github.com/sashabaranov/go-openai"
-	"github.com/sashabaranov/go-openai/internal/test"
-	"github.com/sashabaranov/go-openai/internal/test/checks"
-
 	"context"
 	"errors"
 	"io"
 	"net/http"
 	"net/http/httptest"
 	"testing"
+
+	. "github.com/sashabaranov/go-openai"
+	"github.com/sashabaranov/go-openai/internal/test"
+	"github.com/sashabaranov/go-openai/internal/test/checks"
 )

 func TestCompletionsStreamWrongModel(t *testing.T) {
@@ -171,6 +171,52 @@ func TestCreateCompletionStreamError(t *testing.T) {
t.Logf("%+v\n", apiErr) t.Logf("%+v\n", apiErr)
} }
func TestCreateCompletionStreamRateLimitError(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(429)
// Send test responses
dataBytes := []byte(`{"error":{` +
`"message": "You are sending requests too quickly.",` +
`"type":"rate_limit_reached",` +
`"param":null,` +
`"code":"rate_limit_reached"}}`)
_, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error")
})
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()
// Client portion of the test
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{
test.GetTestToken(),
http.DefaultTransport,
}
client := NewClientWithConfig(config)
ctx := context.Background()
request := CompletionRequest{
MaxTokens: 5,
Model: GPT3Ada,
Prompt: "Hello!",
Stream: true,
}
var apiErr *APIError
_, err := client.CreateCompletionStream(ctx, request)
if !errors.As(err, &apiErr) {
t.Errorf("TestCreateCompletionStreamRateLimitError did not return APIError")
}
t.Logf("%+v\n", apiErr)
}
// A "tokenRoundTripper" is a struct that implements the RoundTripper // A "tokenRoundTripper" is a struct that implements the RoundTripper
// interface, specifically to handle the authentication token by adding a token // interface, specifically to handle the authentication token by adding a token
// to the request header. We need this because the API requires that each // to the request header. We need this because the API requires that each