From 8677fb4bb4bc6c74c4eb4de439460a5ec496d432 Mon Sep 17 00:00:00 2001 From: ttys3 <41882455+ttys3@users.noreply.github.com> Date: Tue, 4 Apr 2023 16:05:20 +0800 Subject: [PATCH] feat: add azure openai support (#214) * feat: add azure openai support * chore: refine config * chore: make config options like the python one * chore: adjust config struct field order * test: fix tests * style: make the linter happy * fix: support Azure API Key authentication in sendRequest * chore: check error in CreateChatCompletionStream * chore: pass tests * chore: try pass tests again * chore: change ClientConfig back due to this lib does not like WithXxx config style * chore: revert fix to CreateChatCompletionStream() due to cause tests not pass * chore: at least add some comment about the required fields * chore: re order ClientConfig fields * chore: add DefaultAzure() * chore: set default api_version the same as py one "2023-03-15-preview" * style: fixup typo * test: add api_internal_test.go * style: make lint happy * chore: add constant AzureAPIKeyHeader * chore: use AzureAPIKeyHeader for api-key header, fix azure base url auto trim suffix / * test: add TestAzureFullURL, TestRequestAuthHeader and TestOpenAIFullURL * test: simplify TestRequestAuthHeader * test: refine TestOpenAIFullURL * chore: refine comments * feat: DefaultAzureConfig --- api.go | 27 ++++++++- api_internal_test.go | 133 +++++++++++++++++++++++++++++++++++++++++++ config.go | 47 ++++++++++++--- 3 files changed, 198 insertions(+), 9 deletions(-) create mode 100644 api_internal_test.go diff --git a/api.go b/api.go index 00d6d35..2c978bc 100644 --- a/api.go +++ b/api.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "net/http" + "strings" ) // Client is OpenAI GPT-3 API client. @@ -39,7 +40,13 @@ func NewOrgClient(authToken, org string) *Client { func (c *Client) sendRequest(req *http.Request, v interface{}) error { req.Header.Set("Accept", "application/json; charset=utf-8") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + // Azure API Key authentication + if c.config.APIType == APITypeAzure { + req.Header.Set(AzureAPIKeyHeader, c.config.authToken) + } else { + // OpenAI or Azure AD authentication + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + } // Check whether Content-Type is already set, Upload Files API requires // Content-Type == multipart/form-data @@ -83,6 +90,15 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error { } func (c *Client) fullURL(suffix string) string { + // /openai/deployments/{engine}/chat/completions?api-version={api_version} + if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { + baseURL := c.config.BaseURL + baseURL = strings.TrimRight(baseURL, "/") + return fmt.Sprintf("%s/%s/%s/%s%s?api-version=%s", + baseURL, azureAPIPrefix, azureDeploymentsPrefix, c.config.Engine, suffix, c.config.APIVersion) + } + + // c.config.APIType == APITypeOpenAI || c.config.APIType == "" return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) } @@ -100,7 +116,14 @@ func (c *Client) newStreamRequest( req.Header.Set("Accept", "text/event-stream") req.Header.Set("Cache-Control", "no-cache") req.Header.Set("Connection", "keep-alive") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication + // Azure API Key authentication + if c.config.APIType == APITypeAzure { + req.Header.Set(AzureAPIKeyHeader, c.config.authToken) + } else { + // OpenAI or Azure AD authentication + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + } return req, nil } diff --git a/api_internal_test.go b/api_internal_test.go new file mode 100644 index 0000000..83dcafc --- /dev/null +++ b/api_internal_test.go @@ -0,0 +1,133 @@ +package openai + +import ( + "context" + "testing" +) + +func TestOpenAIFullURL(t *testing.T) { + cases := []struct { + Name string + Suffix string + Expect string + }{ + { + "ChatCompletionsURL", + "/chat/completions", + "https://api.openai.com/v1/chat/completions", + }, + { + "CompletionsURL", + "/completions", + "https://api.openai.com/v1/completions", + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + az := DefaultConfig("dummy") + cli := NewClientWithConfig(az) + actual := cli.fullURL(c.Suffix) + if actual != c.Expect { + t.Errorf("Expected %s, got %s", c.Expect, actual) + } + t.Logf("Full URL: %s", actual) + }) + } +} + +func TestRequestAuthHeader(t *testing.T) { + cases := []struct { + Name string + APIType APIType + HeaderKey string + Token string + Expect string + }{ + { + "OpenAIDefault", + "", + "Authorization", + "dummy-token-openai", + "Bearer dummy-token-openai", + }, + { + "OpenAI", + APITypeOpenAI, + "Authorization", + "dummy-token-openai", + "Bearer dummy-token-openai", + }, + { + "AzureAD", + APITypeAzureAD, + "Authorization", + "dummy-token-azure", + "Bearer dummy-token-azure", + }, + { + "Azure", + APITypeAzure, + AzureAPIKeyHeader, + "dummy-api-key-here", + "dummy-api-key-here", + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + az := DefaultConfig(c.Token) + az.APIType = c.APIType + + cli := NewClientWithConfig(az) + req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil) + if err != nil { + t.Errorf("Failed to create request: %v", err) + } + actual := req.Header.Get(c.HeaderKey) + if actual != c.Expect { + t.Errorf("Expected %s, got %s", c.Expect, actual) + } + t.Logf("%s: %s", c.HeaderKey, actual) + }) + } +} + +func TestAzureFullURL(t *testing.T) { + cases := []struct { + Name string + BaseURL string + Engine string + Expect string + }{ + { + "AzureBaseURLWithSlashAutoStrip", + "https://httpbin.org/", + "chatgpt-demo", + "https://httpbin.org/" + + "openai/deployments/chatgpt-demo" + + "/chat/completions?api-version=2023-03-15-preview", + }, + { + "AzureBaseURLWithoutSlashOK", + "https://httpbin.org", + "chatgpt-demo", + "https://httpbin.org/" + + "openai/deployments/chatgpt-demo" + + "/chat/completions?api-version=2023-03-15-preview", + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + az := DefaultAzureConfig("dummy", c.BaseURL, c.Engine) + cli := NewClientWithConfig(az) + // /openai/deployments/{engine}/chat/completions?api-version={api_version} + actual := cli.fullURL("/chat/completions") + if actual != c.Expect { + t.Errorf("Expected %s, got %s", c.Expect, actual) + } + t.Logf("Full URL: %s", actual) + }) + } +} diff --git a/config.go b/config.go index e09c256..52e1efc 100644 --- a/config.go +++ b/config.go @@ -5,28 +5,61 @@ import ( ) const ( - apiURLv1 = "https://api.openai.com/v1" + openaiAPIURLv1 = "https://api.openai.com/v1" defaultEmptyMessagesLimit uint = 300 + + azureAPIPrefix = "openai" + azureDeploymentsPrefix = "deployments" ) +type APIType string + +const ( + APITypeOpenAI APIType = "OPEN_AI" + APITypeAzure APIType = "AZURE" + APITypeAzureAD APIType = "AZURE_AD" +) + +const AzureAPIKeyHeader = "api-key" + // ClientConfig is a configuration of a client. type ClientConfig struct { authToken string - HTTPClient *http.Client + BaseURL string + OrgID string + APIType APIType + APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD + Engine string // required when APIType is APITypeAzure or APITypeAzureAD - BaseURL string - OrgID string + HTTPClient *http.Client EmptyMessagesLimit uint } func DefaultConfig(authToken string) ClientConfig { return ClientConfig{ + authToken: authToken, + BaseURL: openaiAPIURLv1, + APIType: APITypeOpenAI, + OrgID: "", + + HTTPClient: &http.Client{}, + + EmptyMessagesLimit: defaultEmptyMessagesLimit, + } +} + +func DefaultAzureConfig(apiKey, baseURL, engine string) ClientConfig { + return ClientConfig{ + authToken: apiKey, + BaseURL: baseURL, + OrgID: "", + APIType: APITypeAzure, + APIVersion: "2023-03-15-preview", + Engine: engine, + HTTPClient: &http.Client{}, - BaseURL: apiURLv1, - OrgID: "", - authToken: authToken, EmptyMessagesLimit: defaultEmptyMessagesLimit, }