diff --git a/api.go b/api.go index 00d6d35..2c978bc 100644 --- a/api.go +++ b/api.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "net/http" + "strings" ) // Client is OpenAI GPT-3 API client. @@ -39,7 +40,13 @@ func NewOrgClient(authToken, org string) *Client { func (c *Client) sendRequest(req *http.Request, v interface{}) error { req.Header.Set("Accept", "application/json; charset=utf-8") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + // Azure API Key authentication + if c.config.APIType == APITypeAzure { + req.Header.Set(AzureAPIKeyHeader, c.config.authToken) + } else { + // OpenAI or Azure AD authentication + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + } // Check whether Content-Type is already set, Upload Files API requires // Content-Type == multipart/form-data @@ -83,6 +90,15 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error { } func (c *Client) fullURL(suffix string) string { + // /openai/deployments/{engine}/chat/completions?api-version={api_version} + if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { + baseURL := c.config.BaseURL + baseURL = strings.TrimRight(baseURL, "/") + return fmt.Sprintf("%s/%s/%s/%s%s?api-version=%s", + baseURL, azureAPIPrefix, azureDeploymentsPrefix, c.config.Engine, suffix, c.config.APIVersion) + } + + // c.config.APIType == APITypeOpenAI || c.config.APIType == "" return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) } @@ -100,7 +116,14 @@ func (c *Client) newStreamRequest( req.Header.Set("Accept", "text/event-stream") req.Header.Set("Cache-Control", "no-cache") req.Header.Set("Connection", "keep-alive") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication + // Azure API Key authentication + if c.config.APIType == APITypeAzure { + req.Header.Set(AzureAPIKeyHeader, c.config.authToken) + } else { + // OpenAI or Azure AD authentication + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + } return req, nil } diff --git a/api_internal_test.go b/api_internal_test.go new file mode 100644 index 0000000..83dcafc --- /dev/null +++ b/api_internal_test.go @@ -0,0 +1,133 @@ +package openai + +import ( + "context" + "testing" +) + +func TestOpenAIFullURL(t *testing.T) { + cases := []struct { + Name string + Suffix string + Expect string + }{ + { + "ChatCompletionsURL", + "/chat/completions", + "https://api.openai.com/v1/chat/completions", + }, + { + "CompletionsURL", + "/completions", + "https://api.openai.com/v1/completions", + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + az := DefaultConfig("dummy") + cli := NewClientWithConfig(az) + actual := cli.fullURL(c.Suffix) + if actual != c.Expect { + t.Errorf("Expected %s, got %s", c.Expect, actual) + } + t.Logf("Full URL: %s", actual) + }) + } +} + +func TestRequestAuthHeader(t *testing.T) { + cases := []struct { + Name string + APIType APIType + HeaderKey string + Token string + Expect string + }{ + { + "OpenAIDefault", + "", + "Authorization", + "dummy-token-openai", + "Bearer dummy-token-openai", + }, + { + "OpenAI", + APITypeOpenAI, + "Authorization", + "dummy-token-openai", + "Bearer dummy-token-openai", + }, + { + "AzureAD", + APITypeAzureAD, + "Authorization", + "dummy-token-azure", + "Bearer dummy-token-azure", + }, + { + "Azure", + APITypeAzure, + AzureAPIKeyHeader, + "dummy-api-key-here", + "dummy-api-key-here", + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + az := DefaultConfig(c.Token) + az.APIType = c.APIType + + cli := NewClientWithConfig(az) + req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil) + if err != nil { + t.Errorf("Failed to create request: %v", err) + } + actual := req.Header.Get(c.HeaderKey) + if actual != c.Expect { + t.Errorf("Expected %s, got %s", c.Expect, actual) + } + t.Logf("%s: %s", c.HeaderKey, actual) + }) + } +} + +func TestAzureFullURL(t *testing.T) { + cases := []struct { + Name string + BaseURL string + Engine string + Expect string + }{ + { + "AzureBaseURLWithSlashAutoStrip", + "https://httpbin.org/", + "chatgpt-demo", + "https://httpbin.org/" + + "openai/deployments/chatgpt-demo" + + "/chat/completions?api-version=2023-03-15-preview", + }, + { + "AzureBaseURLWithoutSlashOK", + "https://httpbin.org", + "chatgpt-demo", + "https://httpbin.org/" + + "openai/deployments/chatgpt-demo" + + "/chat/completions?api-version=2023-03-15-preview", + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + az := DefaultAzureConfig("dummy", c.BaseURL, c.Engine) + cli := NewClientWithConfig(az) + // /openai/deployments/{engine}/chat/completions?api-version={api_version} + actual := cli.fullURL("/chat/completions") + if actual != c.Expect { + t.Errorf("Expected %s, got %s", c.Expect, actual) + } + t.Logf("Full URL: %s", actual) + }) + } +} diff --git a/config.go b/config.go index e09c256..52e1efc 100644 --- a/config.go +++ b/config.go @@ -5,28 +5,61 @@ import ( ) const ( - apiURLv1 = "https://api.openai.com/v1" + openaiAPIURLv1 = "https://api.openai.com/v1" defaultEmptyMessagesLimit uint = 300 + + azureAPIPrefix = "openai" + azureDeploymentsPrefix = "deployments" ) +type APIType string + +const ( + APITypeOpenAI APIType = "OPEN_AI" + APITypeAzure APIType = "AZURE" + APITypeAzureAD APIType = "AZURE_AD" +) + +const AzureAPIKeyHeader = "api-key" + // ClientConfig is a configuration of a client. type ClientConfig struct { authToken string - HTTPClient *http.Client + BaseURL string + OrgID string + APIType APIType + APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD + Engine string // required when APIType is APITypeAzure or APITypeAzureAD - BaseURL string - OrgID string + HTTPClient *http.Client EmptyMessagesLimit uint } func DefaultConfig(authToken string) ClientConfig { return ClientConfig{ + authToken: authToken, + BaseURL: openaiAPIURLv1, + APIType: APITypeOpenAI, + OrgID: "", + + HTTPClient: &http.Client{}, + + EmptyMessagesLimit: defaultEmptyMessagesLimit, + } +} + +func DefaultAzureConfig(apiKey, baseURL, engine string) ClientConfig { + return ClientConfig{ + authToken: apiKey, + BaseURL: baseURL, + OrgID: "", + APIType: APITypeAzure, + APIVersion: "2023-03-15-preview", + Engine: engine, + HTTPClient: &http.Client{}, - BaseURL: apiURLv1, - OrgID: "", - authToken: authToken, EmptyMessagesLimit: defaultEmptyMessagesLimit, }