feat: add Anthropic API support with custom version header (#934)

* feat: add Anthropic API support with custom version header

* refactor: use switch statement for API type header handling

* refactor: add OpenAI & AzureAD types to be exhaustive

* Update client.go

need explicit fallthrough in empty case statements

* constant for APIVersion; addtl tests
This commit is contained in:
Dan Ackerson
2025-02-25 12:03:38 +01:00
committed by GitHub
parent 85f578b865
commit be2e2387d4
4 changed files with 89 additions and 6 deletions

View File

@@ -182,13 +182,21 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream
func (c *Client) setCommonHeaders(req *http.Request) { func (c *Client) setCommonHeaders(req *http.Request) {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
// Azure API Key authentication switch c.config.APIType {
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure { case APITypeAzure, APITypeCloudflareAzure:
// Azure API Key authentication
req.Header.Set(AzureAPIKeyHeader, c.config.authToken) req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
} else if c.config.authToken != "" { case APITypeAnthropic:
// OpenAI or Azure AD authentication // https://docs.anthropic.com/en/api/versioning
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) req.Header.Set("anthropic-version", c.config.APIVersion)
case APITypeOpenAI, APITypeAzureAD:
fallthrough
default:
if c.config.authToken != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
}
} }
if c.config.OrgID != "" { if c.config.OrgID != "" {
req.Header.Set("OpenAI-Organization", c.config.OrgID) req.Header.Set("OpenAI-Organization", c.config.OrgID)
} }

View File

@@ -39,6 +39,21 @@ func TestClient(t *testing.T) {
} }
} }
func TestSetCommonHeadersAnthropic(t *testing.T) {
config := DefaultAnthropicConfig("mock-token", "")
client := NewClientWithConfig(config)
req, err := http.NewRequest("GET", "http://example.com", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
client.setCommonHeaders(req)
if got := req.Header.Get("anthropic-version"); got != AnthropicAPIVersion {
t.Errorf("Expected anthropic-version header to be %q, got %q", AnthropicAPIVersion, got)
}
}
func TestDecodeResponse(t *testing.T) { func TestDecodeResponse(t *testing.T) {
stringInput := "" stringInput := ""

View File

@@ -11,6 +11,8 @@ const (
azureAPIPrefix = "openai" azureAPIPrefix = "openai"
azureDeploymentsPrefix = "deployments" azureDeploymentsPrefix = "deployments"
AnthropicAPIVersion = "2023-06-01"
) )
type APIType string type APIType string
@@ -20,6 +22,7 @@ const (
APITypeAzure APIType = "AZURE" APITypeAzure APIType = "AZURE"
APITypeAzureAD APIType = "AZURE_AD" APITypeAzureAD APIType = "AZURE_AD"
APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE" APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE"
APITypeAnthropic APIType = "ANTHROPIC"
) )
const AzureAPIKeyHeader = "api-key" const AzureAPIKeyHeader = "api-key"
@@ -37,7 +40,7 @@ type ClientConfig struct {
BaseURL string BaseURL string
OrgID string OrgID string
APIType APIType APIType APIType
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD or APITypeAnthropic
AssistantVersion string AssistantVersion string
AzureModelMapperFunc func(model string) string // replace model to azure deployment name func AzureModelMapperFunc func(model string) string // replace model to azure deployment name func
HTTPClient HTTPDoer HTTPClient HTTPDoer
@@ -76,6 +79,23 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig {
} }
} }
func DefaultAnthropicConfig(apiKey, baseURL string) ClientConfig {
if baseURL == "" {
baseURL = "https://api.anthropic.com/v1"
}
return ClientConfig{
authToken: apiKey,
BaseURL: baseURL,
OrgID: "",
APIType: APITypeAnthropic,
APIVersion: AnthropicAPIVersion,
HTTPClient: &http.Client{},
EmptyMessagesLimit: defaultEmptyMessagesLimit,
}
}
func (ClientConfig) String() string { func (ClientConfig) String() string {
return "<OpenAI API ClientConfig>" return "<OpenAI API ClientConfig>"
} }

View File

@@ -60,3 +60,43 @@ func TestGetAzureDeploymentByModel(t *testing.T) {
}) })
} }
} }
func TestDefaultAnthropicConfig(t *testing.T) {
apiKey := "test-key"
baseURL := "https://api.anthropic.com/v1"
config := openai.DefaultAnthropicConfig(apiKey, baseURL)
if config.APIType != openai.APITypeAnthropic {
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
}
if config.APIVersion != openai.AnthropicAPIVersion {
t.Errorf("Expected APIVersion to be 2023-06-01, got %v", config.APIVersion)
}
if config.BaseURL != baseURL {
t.Errorf("Expected BaseURL to be %v, got %v", baseURL, config.BaseURL)
}
if config.EmptyMessagesLimit != 300 {
t.Errorf("Expected EmptyMessagesLimit to be 300, got %v", config.EmptyMessagesLimit)
}
}
func TestDefaultAnthropicConfigWithEmptyValues(t *testing.T) {
config := openai.DefaultAnthropicConfig("", "")
if config.APIType != openai.APITypeAnthropic {
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
}
if config.APIVersion != openai.AnthropicAPIVersion {
t.Errorf("Expected APIVersion to be %s, got %v", openai.AnthropicAPIVersion, config.APIVersion)
}
expectedBaseURL := "https://api.anthropic.com/v1"
if config.BaseURL != expectedBaseURL {
t.Errorf("Expected BaseURL to be %v, got %v", expectedBaseURL, config.BaseURL)
}
}