diff --git a/chat.go b/chat.go index 514aaee..8d29b32 100644 --- a/chat.go +++ b/chat.go @@ -114,6 +114,13 @@ const ( FinishReasonNull FinishReason = "null" ) +func (r FinishReason) MarshalJSON() ([]byte, error) { + if r == FinishReasonNull || r == "" { + return []byte("null"), nil + } + return []byte(`"` + string(r) + `"`), nil // best effort to not break future API changes +} + type ChatCompletionChoice struct { Index int `json:"index"` Message ChatCompletionMessage `json:"message"` diff --git a/chat_test.go b/chat_test.go index 5723d6c..38d66fa 100644 --- a/chat_test.go +++ b/chat_test.go @@ -298,3 +298,34 @@ func getChatCompletionBody(r *http.Request) (ChatCompletionRequest, error) { } return completion, nil } + +func TestFinishReason(t *testing.T) { + c := &ChatCompletionChoice{ + FinishReason: FinishReasonNull, + } + resBytes, _ := json.Marshal(c) + if !strings.Contains(string(resBytes), `"finish_reason":null`) { + t.Error("null should not be quoted") + } + + c.FinishReason = "" + + resBytes, _ = json.Marshal(c) + if !strings.Contains(string(resBytes), `"finish_reason":null`) { + t.Error("null should not be quoted") + } + + otherReasons := []FinishReason{ + FinishReasonStop, + FinishReasonLength, + FinishReasonFunctionCall, + FinishReasonContentFilter, + } + for _, r := range otherReasons { + c.FinishReason = r + resBytes, _ = json.Marshal(c) + if !strings.Contains(string(resBytes), fmt.Sprintf(`"finish_reason":"%s"`, r)) { + t.Errorf("%s should be quoted", r) + } + } +}