From c3b2451f7c7dc477d98e1baa10993ac55392c7dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Tue, 11 Jul 2023 20:48:15 +0900 Subject: [PATCH] fix: invalid schema for function 'func_name': None is not of type 'object' (#429)(#432) (#434) * fix: invalid schema for function 'func_name': None is not of type 'object' (#429)(#432) * test: add integration test for function call (#429)(#432) * style: remove duplicate import (#429)(#432) --- api_integration_test.go | 32 ++++++++++++++++++++++++++++++++ jsonschema/json.go | 23 +++++++---------------- jsonschema/json_test.go | 38 ++++++++++++++++++++++++-------------- 3 files changed, 63 insertions(+), 30 deletions(-) diff --git a/api_integration_test.go b/api_integration_test.go index d4e7328..254fbeb 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -11,6 +11,7 @@ import ( . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/sashabaranov/go-openai/jsonschema" ) func TestAPI(t *testing.T) { @@ -100,6 +101,37 @@ func TestAPI(t *testing.T) { if counter == 0 { t.Error("Stream did not return any responses") } + + _, err = c.CreateChatCompletion( + context.Background(), + ChatCompletionRequest{ + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "What is the weather like in Boston?", + }, + }, + Functions: []FunctionDefinition{{ + Name: "get_current_weather", + Parameters: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "location": { + Type: jsonschema.String, + Description: "The city and state, e.g. San Francisco, CA", + }, + "unit": { + Type: jsonschema.String, + Enum: []string{"celsius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + }, + }}, + }, + ) + checks.NoError(t, err, "CreateChatCompletion (with functions) returned error") } func TestAPIError(t *testing.T) { diff --git a/jsonschema/json.go b/jsonschema/json.go index e4eef98..cb941eb 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -36,23 +36,14 @@ type Definition struct { Items *Definition `json:"items,omitempty"` } -func (d *Definition) MarshalJSON() ([]byte, error) { - d.initializeProperties() - return json.Marshal(*d) -} - -func (d *Definition) initializeProperties() { +func (d Definition) MarshalJSON() ([]byte, error) { if d.Properties == nil { d.Properties = make(map[string]Definition) - return - } - - for k, v := range d.Properties { - if v.Properties == nil { - v.Properties = make(map[string]Definition) - } else { - v.initializeProperties() - } - d.Properties[k] = v } + type Alias Definition + return json.Marshal(struct { + Alias + }{ + Alias: (Alias)(d), + }) } diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index 0dc31a5..c8d0c1d 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -172,30 +172,40 @@ func TestDefinition_MarshalJSON(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotBytes, err := json.Marshal(&tt.def) - if err != nil { - t.Errorf("Failed to Marshal JSON: error = %v", err) - return - } - - var got map[string]interface{} - err = json.Unmarshal(gotBytes, &got) - if err != nil { - t.Errorf("Failed to Unmarshal JSON: error = %v", err) - return - } - wantBytes := []byte(tt.want) var want map[string]interface{} - err = json.Unmarshal(wantBytes, &want) + err := json.Unmarshal(wantBytes, &want) if err != nil { t.Errorf("Failed to Unmarshal JSON: error = %v", err) return } + got := structToMap(t, tt.def) + gotPtr := structToMap(t, &tt.def) + if !reflect.DeepEqual(got, want) { t.Errorf("MarshalJSON() got = %v, want %v", got, want) } + if !reflect.DeepEqual(gotPtr, want) { + t.Errorf("MarshalJSON() gotPtr = %v, want %v", gotPtr, want) + } }) } } + +func structToMap(t *testing.T, v any) map[string]any { + t.Helper() + gotBytes, err := json.Marshal(v) + if err != nil { + t.Errorf("Failed to Marshal JSON: error = %v", err) + return nil + } + + var got map[string]interface{} + err = json.Unmarshal(gotBytes, &got) + if err != nil { + t.Errorf("Failed to Unmarshal JSON: error = %v", err) + return nil + } + return got +}