This commit is contained in:
committed by
GitHub
parent
f0770cfe1d
commit
e49d771fff
109
api_test.go
109
api_test.go
@@ -137,6 +137,108 @@ func TestAPIError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIErrorUnmarshalJSONMessageField(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
response string
|
||||
hasError bool
|
||||
checkFn func(t *testing.T, apiErr APIError)
|
||||
}
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "parse succeeds when the message is string",
|
||||
response: `{"message":"foo","type":"invalid_request_error","param":null,"code":null}`,
|
||||
hasError: false,
|
||||
checkFn: func(t *testing.T, apiErr APIError) {
|
||||
expected := "foo"
|
||||
if apiErr.Message != expected {
|
||||
t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "parse succeeds when the message is array with single item",
|
||||
response: `{"message":["foo"],"type":"invalid_request_error","param":null,"code":null}`,
|
||||
hasError: false,
|
||||
checkFn: func(t *testing.T, apiErr APIError) {
|
||||
expected := "foo"
|
||||
if apiErr.Message != expected {
|
||||
t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "parse succeeds when the message is array with multiple items",
|
||||
response: `{"message":["foo", "bar", "baz"],"type":"invalid_request_error","param":null,"code":null}`,
|
||||
hasError: false,
|
||||
checkFn: func(t *testing.T, apiErr APIError) {
|
||||
expected := "foo, bar, baz"
|
||||
if apiErr.Message != expected {
|
||||
t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "parse succeeds when the message is empty array",
|
||||
response: `{"message":[],"type":"invalid_request_error","param":null,"code":null}`,
|
||||
hasError: false,
|
||||
checkFn: func(t *testing.T, apiErr APIError) {
|
||||
if apiErr.Message != "" {
|
||||
t.Fatalf("Unexpected API message: %v; expected: empty", apiErr)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "parse succeeds when the message is null",
|
||||
response: `{"message":null,"type":"invalid_request_error","param":null,"code":null}`,
|
||||
hasError: false,
|
||||
checkFn: func(t *testing.T, apiErr APIError) {
|
||||
if apiErr.Message != "" {
|
||||
t.Fatalf("Unexpected API message: %v; expected: empty", apiErr)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "parse failed when the message is object",
|
||||
response: `{"message":{},"type":"invalid_request_error","param":null,"code":null}`,
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "parse failed when the message is int",
|
||||
response: `{"message":1,"type":"invalid_request_error","param":null,"code":null}`,
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "parse failed when the message is float",
|
||||
response: `{"message":0.1,"type":"invalid_request_error","param":null,"code":null}`,
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "parse failed when the message is bool",
|
||||
response: `{"message":true,"type":"invalid_request_error","param":null,"code":null}`,
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "parse failed when the message is not exists",
|
||||
response: `{"type":"invalid_request_error","param":null,"code":null}`,
|
||||
hasError: true,
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var apiErr APIError
|
||||
err := json.Unmarshal([]byte(tc.response), &apiErr)
|
||||
if (err != nil) != tc.hasError {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
if tc.checkFn != nil {
|
||||
tc.checkFn(t, apiErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIErrorUnmarshalJSONInteger(t *testing.T) {
|
||||
var apiErr APIError
|
||||
response := `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`
|
||||
@@ -217,13 +319,6 @@ func TestAPIErrorUnmarshalJSONInvalidType(t *testing.T) {
|
||||
checks.HasError(t, err, "Type should be a string")
|
||||
}
|
||||
|
||||
func TestAPIErrorUnmarshalJSONInvalidMessage(t *testing.T) {
|
||||
var apiErr APIError
|
||||
response := `{"code":418,"message":false,"param":"prompt","type":"teapot_error"}`
|
||||
err := json.Unmarshal([]byte(response), &apiErr)
|
||||
checks.HasError(t, err, "Message should be a string")
|
||||
}
|
||||
|
||||
func TestRequestError(t *testing.T) {
|
||||
client, server, teardown := setupOpenAITestServer()
|
||||
defer teardown()
|
||||
|
||||
8
error.go
8
error.go
@@ -3,6 +3,7 @@ package openai
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// APIError provides error information returned by the OpenAI API.
|
||||
@@ -40,9 +41,16 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) {
|
||||
}
|
||||
|
||||
err = json.Unmarshal(rawMap["message"], &e.Message)
|
||||
if err != nil {
|
||||
// If the parameter field of a function call is invalid as a JSON schema
|
||||
// refs: https://github.com/sashabaranov/go-openai/issues/381
|
||||
var messages []string
|
||||
err = json.Unmarshal(rawMap["message"], &messages)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
e.Message = strings.Join(messages, ", ")
|
||||
}
|
||||
|
||||
// optional fields for azure openai
|
||||
// refs: https://github.com/sashabaranov/go-openai/issues/343
|
||||
|
||||
Reference in New Issue
Block a user