diff --git a/client.go b/client.go index ed8595e..cef3753 100644 --- a/client.go +++ b/client.go @@ -182,13 +182,21 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream func (c *Client) setCommonHeaders(req *http.Request) { // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication - // Azure API Key authentication - if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure { + switch c.config.APIType { + case APITypeAzure, APITypeCloudflareAzure: + // Azure API Key authentication req.Header.Set(AzureAPIKeyHeader, c.config.authToken) - } else if c.config.authToken != "" { - // OpenAI or Azure AD authentication - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + case APITypeAnthropic: + // https://docs.anthropic.com/en/api/versioning + req.Header.Set("anthropic-version", c.config.APIVersion) + case APITypeOpenAI, APITypeAzureAD: + fallthrough + default: + if c.config.authToken != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + } } + if c.config.OrgID != "" { req.Header.Set("OpenAI-Organization", c.config.OrgID) } diff --git a/client_test.go b/client_test.go index 2ed82f1..3219714 100644 --- a/client_test.go +++ b/client_test.go @@ -39,6 +39,21 @@ func TestClient(t *testing.T) { } } +func TestSetCommonHeadersAnthropic(t *testing.T) { + config := DefaultAnthropicConfig("mock-token", "") + client := NewClientWithConfig(config) + req, err := http.NewRequest("GET", "http://example.com", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + client.setCommonHeaders(req) + + if got := req.Header.Get("anthropic-version"); got != AnthropicAPIVersion { + t.Errorf("Expected anthropic-version header to be %q, got %q", AnthropicAPIVersion, got) + } +} + func TestDecodeResponse(t *testing.T) { stringInput := "" diff --git a/config.go b/config.go index 8a91835..4788ba6 100644 --- a/config.go +++ b/config.go @@ -11,6 +11,8 @@ const ( azureAPIPrefix = "openai" azureDeploymentsPrefix = "deployments" + + AnthropicAPIVersion = "2023-06-01" ) type APIType string @@ -20,6 +22,7 @@ const ( APITypeAzure APIType = "AZURE" APITypeAzureAD APIType = "AZURE_AD" APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE" + APITypeAnthropic APIType = "ANTHROPIC" ) const AzureAPIKeyHeader = "api-key" @@ -37,7 +40,7 @@ type ClientConfig struct { BaseURL string OrgID string APIType APIType - APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD + APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD or APITypeAnthropic AssistantVersion string AzureModelMapperFunc func(model string) string // replace model to azure deployment name func HTTPClient HTTPDoer @@ -76,6 +79,23 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig { } } +func DefaultAnthropicConfig(apiKey, baseURL string) ClientConfig { + if baseURL == "" { + baseURL = "https://api.anthropic.com/v1" + } + return ClientConfig{ + authToken: apiKey, + BaseURL: baseURL, + OrgID: "", + APIType: APITypeAnthropic, + APIVersion: AnthropicAPIVersion, + + HTTPClient: &http.Client{}, + + EmptyMessagesLimit: defaultEmptyMessagesLimit, + } +} + func (ClientConfig) String() string { return "" } diff --git a/config_test.go b/config_test.go index 3e528c3..145c260 100644 --- a/config_test.go +++ b/config_test.go @@ -60,3 +60,43 @@ func TestGetAzureDeploymentByModel(t *testing.T) { }) } } + +func TestDefaultAnthropicConfig(t *testing.T) { + apiKey := "test-key" + baseURL := "https://api.anthropic.com/v1" + + config := openai.DefaultAnthropicConfig(apiKey, baseURL) + + if config.APIType != openai.APITypeAnthropic { + t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType) + } + + if config.APIVersion != openai.AnthropicAPIVersion { + t.Errorf("Expected APIVersion to be 2023-06-01, got %v", config.APIVersion) + } + + if config.BaseURL != baseURL { + t.Errorf("Expected BaseURL to be %v, got %v", baseURL, config.BaseURL) + } + + if config.EmptyMessagesLimit != 300 { + t.Errorf("Expected EmptyMessagesLimit to be 300, got %v", config.EmptyMessagesLimit) + } +} + +func TestDefaultAnthropicConfigWithEmptyValues(t *testing.T) { + config := openai.DefaultAnthropicConfig("", "") + + if config.APIType != openai.APITypeAnthropic { + t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType) + } + + if config.APIVersion != openai.AnthropicAPIVersion { + t.Errorf("Expected APIVersion to be %s, got %v", openai.AnthropicAPIVersion, config.APIVersion) + } + + expectedBaseURL := "https://api.anthropic.com/v1" + if config.BaseURL != expectedBaseURL { + t.Errorf("Expected BaseURL to be %v, got %v", expectedBaseURL, config.BaseURL) + } +}