From 104c0c0b63e4915f4ff44faf07a153949e16ab9f Mon Sep 17 00:00:00 2001 From: Juan Date: Wed, 3 May 2023 19:02:35 +1000 Subject: [PATCH] Azure openai list models (#290) * feat(models): include flow for azure openai endpoint * feat(models): include flow for azure openai endpoint * feat(models): include flow for azure openai endpoint * chore(fullURL): update logic to run in fullURL function * chore(fullURL): update based on pr comments to use c.config.APIVersion --- client.go | 5 +++++ internal/test/server.go | 2 +- models_test.go | 18 ++++++++++++++++++ 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 500b3d5..b3d7595 100644 --- a/client.go +++ b/client.go @@ -103,6 +103,11 @@ func (c *Client) fullURL(suffix string) string { if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { baseURL := c.config.BaseURL baseURL = strings.TrimRight(baseURL, "/") + // if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01 + // https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP + if strings.Contains(suffix, "/models") { + return fmt.Sprintf("%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion) + } return fmt.Sprintf("%s/%s/%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, azureDeploymentsPrefix, c.config.Engine, suffix, c.config.APIVersion) } diff --git a/internal/test/server.go b/internal/test/server.go index 0c6f67d..79d55c4 100644 --- a/internal/test/server.go +++ b/internal/test/server.go @@ -31,7 +31,7 @@ func (ts *ServerTest) OpenAITestServer() *httptest.Server { log.Printf("received request at path %q\n", r.URL.Path) // check auth - if r.Header.Get("Authorization") != "Bearer "+GetTestToken() { + if r.Header.Get("Authorization") != "Bearer "+GetTestToken() && r.Header.Get("api-key") != GetTestToken() { w.WriteHeader(http.StatusUnauthorized) return } diff --git a/models_test.go b/models_test.go index dad59be..70d6d75 100644 --- a/models_test.go +++ b/models_test.go @@ -31,6 +31,24 @@ func TestListModels(t *testing.T) { checks.NoError(t, err, "ListModels error") } +func TestAzureListModels(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/openai/models", handleModelsEndpoint) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/", "dummyengine") + config.BaseURL = ts.URL + client := NewClientWithConfig(config) + ctx := context.Background() + + _, err = client.ListModels(ctx) + checks.NoError(t, err, "ListModels error") +} + // handleModelsEndpoint Handles the models endpoint by the test server. func handleModelsEndpoint(w http.ResponseWriter, _ *http.Request) { resBytes, _ := json.Marshal(ModelsList{})