Support getting HTTP response headers and x-ratelimit-* headers (#507)

* feat: add headers to http response

* feat: support rate limit headers

* fix: go lint

* fix: test coverage

* refactor streamReader

* refactor streamReader

* refactor: NewRateLimitHeaders to newRateLimitHeaders

* refactor: RateLimitHeaders Resets field

* refactor: move RateLimitHeaders struct
Liu Shuang authored 2023-10-10 10:29:41 -05:00, committed by GitHub
parent 8e165dc9aa
commit b77d01edca
5 changed files with 191 additions and 5 deletions
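
The diff below adds a GetRateLimitHeaders() accessor on chat completion responses. A minimal usage sketch, assuming this is the go-openai client (only GetRateLimitHeaders, ResetRequests, and its String()/Time() accessors appear verbatim in the diff; the import path and NewClient call are the library's existing API):

package main

import (
	"context"
	"fmt"
	"log"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("your-token")
	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model: openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
	})
	if err != nil {
		log.Fatal(err)
	}
	headers := resp.GetRateLimitHeaders()
	// ResetRequests exposes both the raw header value (e.g. "1s") and an absolute time.
	fmt.Println("requests reset in:", headers.ResetRequests.String())
	fmt.Println("requests reset at:", headers.ResetRequests.Time())
}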


@@ -21,6 +21,17 @@ const (
	xCustomHeaderValue = "test"
)

var (
	rateLimitHeaders = map[string]any{
		"x-ratelimit-limit-requests":     60,
		"x-ratelimit-limit-tokens":       150000,
		"x-ratelimit-remaining-requests": 59,
		"x-ratelimit-remaining-tokens":   149984,
		"x-ratelimit-reset-requests":     "1s",
		"x-ratelimit-reset-tokens":       "6m0s",
	}
)

func TestChatCompletionsWrongModel(t *testing.T) {
	config := DefaultConfig("whatever")
	config.BaseURL = "http://localhost/v1"
@@ -97,6 +108,40 @@ func TestChatCompletionsWithHeaders(t *testing.T) {
	}
}

// TestChatCompletionsWithRateLimitHeaders tests that the rate limit headers returned
// by the mocked completions endpoint are parsed into RateLimitHeaders.
func TestChatCompletionsWithRateLimitHeaders(t *testing.T) {
	client, server, teardown := setupOpenAITestServer()
	defer teardown()
	server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
	resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
		MaxTokens: 5,
		Model:     GPT3Dot5Turbo,
		Messages: []ChatCompletionMessage{
			{
				Role:    ChatMessageRoleUser,
				Content: "Hello!",
			},
		},
	})
	checks.NoError(t, err, "CreateChatCompletion error")

	headers := resp.GetRateLimitHeaders()
	resetRequests := headers.ResetRequests.String()
	if resetRequests != rateLimitHeaders["x-ratelimit-reset-requests"] {
		t.Errorf("expected resetRequests %s to be %s", resetRequests, rateLimitHeaders["x-ratelimit-reset-requests"])
	}
	// The reset time should resolve to a point in the future.
	resetRequestsTime := headers.ResetRequests.Time()
	if resetRequestsTime.Before(time.Now()) {
		t.Errorf("unexpected reset requests: %v", resetRequestsTime)
	}
	bs1, _ := json.Marshal(headers)
	bs2, _ := json.Marshal(rateLimitHeaders)
	if string(bs1) != string(bs2) {
		t.Errorf("expected rate limit headers %s to be %s", bs1, bs2)
	}
}

// TestChatCompletionsFunctions tests including a function call.
func TestChatCompletionsFunctions(t *testing.T) {
	client, server, teardown := setupOpenAITestServer()
@@ -311,6 +356,14 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
	}
	resBytes, _ = json.Marshal(res)
	w.Header().Set(xCustomHeader, xCustomHeaderValue)
	for k, v := range rateLimitHeaders {
		switch val := v.(type) {
		case int:
			w.Header().Set(k, strconv.Itoa(val))
		default:
			w.Header().Set(k, fmt.Sprintf("%s", v))
		}
	}
	fmt.Fprintln(w, string(resBytes))
}
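
The handler above mocks the x-ratelimit-* headers that the client-side parsing (newRateLimitHeaders, not shown in this diff) consumes. A rough sketch of that parsing, assuming the integer headers are read with strconv.Atoi and the reset values ("1s", "6m0s") parse as Go durations; parseRateLimitHeaders is a hypothetical stand-in, not the commit's actual implementation:

package main

import (
	"fmt"
	"net/http"
	"strconv"
	"time"
)

// parseRateLimitHeaders illustrates how the mocked x-ratelimit-* headers could be
// read back on the client side; it is not the helper added by this commit.
func parseRateLimitHeaders(h http.Header) (remaining int, reset time.Duration, err error) {
	remaining, err = strconv.Atoi(h.Get("x-ratelimit-remaining-requests"))
	if err != nil {
		return 0, 0, err
	}
	// The mocked reset values ("1s", "6m0s") are valid Go duration strings.
	reset, err = time.ParseDuration(h.Get("x-ratelimit-reset-requests"))
	return remaining, reset, err
}

func main() {
	h := http.Header{}
	h.Set("x-ratelimit-remaining-requests", "59")
	h.Set("x-ratelimit-reset-requests", "1s")
	remaining, reset, err := parseRateLimitHeaders(h)
	if err != nil {
		fmt.Println("parse error:", err)
		return
	}
	fmt.Println(remaining, reset) // 59 1s
}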