feat: add Anthropic API support with custom version header (#934)
* feat: add Anthropic API support with custom version header * refactor: use switch statement for API type header handling * refactor: add OpenAI & AzureAD types to be exhaustive * Update client.go need explicit fallthrough in empty case statements * constant for APIVersion; addtl tests
This commit is contained in:
14
client.go
14
client.go
@@ -182,13 +182,21 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream
|
|||||||
|
|
||||||
func (c *Client) setCommonHeaders(req *http.Request) {
|
func (c *Client) setCommonHeaders(req *http.Request) {
|
||||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
|
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
|
||||||
|
switch c.config.APIType {
|
||||||
|
case APITypeAzure, APITypeCloudflareAzure:
|
||||||
// Azure API Key authentication
|
// Azure API Key authentication
|
||||||
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure {
|
|
||||||
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
|
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
|
||||||
} else if c.config.authToken != "" {
|
case APITypeAnthropic:
|
||||||
// OpenAI or Azure AD authentication
|
// https://docs.anthropic.com/en/api/versioning
|
||||||
|
req.Header.Set("anthropic-version", c.config.APIVersion)
|
||||||
|
case APITypeOpenAI, APITypeAzureAD:
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
if c.config.authToken != "" {
|
||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if c.config.OrgID != "" {
|
if c.config.OrgID != "" {
|
||||||
req.Header.Set("OpenAI-Organization", c.config.OrgID)
|
req.Header.Set("OpenAI-Organization", c.config.OrgID)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,6 +39,21 @@ func TestClient(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSetCommonHeadersAnthropic(t *testing.T) {
|
||||||
|
config := DefaultAnthropicConfig("mock-token", "")
|
||||||
|
client := NewClientWithConfig(config)
|
||||||
|
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client.setCommonHeaders(req)
|
||||||
|
|
||||||
|
if got := req.Header.Get("anthropic-version"); got != AnthropicAPIVersion {
|
||||||
|
t.Errorf("Expected anthropic-version header to be %q, got %q", AnthropicAPIVersion, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDecodeResponse(t *testing.T) {
|
func TestDecodeResponse(t *testing.T) {
|
||||||
stringInput := ""
|
stringInput := ""
|
||||||
|
|
||||||
|
|||||||
22
config.go
22
config.go
@@ -11,6 +11,8 @@ const (
|
|||||||
|
|
||||||
azureAPIPrefix = "openai"
|
azureAPIPrefix = "openai"
|
||||||
azureDeploymentsPrefix = "deployments"
|
azureDeploymentsPrefix = "deployments"
|
||||||
|
|
||||||
|
AnthropicAPIVersion = "2023-06-01"
|
||||||
)
|
)
|
||||||
|
|
||||||
type APIType string
|
type APIType string
|
||||||
@@ -20,6 +22,7 @@ const (
|
|||||||
APITypeAzure APIType = "AZURE"
|
APITypeAzure APIType = "AZURE"
|
||||||
APITypeAzureAD APIType = "AZURE_AD"
|
APITypeAzureAD APIType = "AZURE_AD"
|
||||||
APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE"
|
APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE"
|
||||||
|
APITypeAnthropic APIType = "ANTHROPIC"
|
||||||
)
|
)
|
||||||
|
|
||||||
const AzureAPIKeyHeader = "api-key"
|
const AzureAPIKeyHeader = "api-key"
|
||||||
@@ -37,7 +40,7 @@ type ClientConfig struct {
|
|||||||
BaseURL string
|
BaseURL string
|
||||||
OrgID string
|
OrgID string
|
||||||
APIType APIType
|
APIType APIType
|
||||||
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD
|
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD or APITypeAnthropic
|
||||||
AssistantVersion string
|
AssistantVersion string
|
||||||
AzureModelMapperFunc func(model string) string // replace model to azure deployment name func
|
AzureModelMapperFunc func(model string) string // replace model to azure deployment name func
|
||||||
HTTPClient HTTPDoer
|
HTTPClient HTTPDoer
|
||||||
@@ -76,6 +79,23 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DefaultAnthropicConfig(apiKey, baseURL string) ClientConfig {
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = "https://api.anthropic.com/v1"
|
||||||
|
}
|
||||||
|
return ClientConfig{
|
||||||
|
authToken: apiKey,
|
||||||
|
BaseURL: baseURL,
|
||||||
|
OrgID: "",
|
||||||
|
APIType: APITypeAnthropic,
|
||||||
|
APIVersion: AnthropicAPIVersion,
|
||||||
|
|
||||||
|
HTTPClient: &http.Client{},
|
||||||
|
|
||||||
|
EmptyMessagesLimit: defaultEmptyMessagesLimit,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (ClientConfig) String() string {
|
func (ClientConfig) String() string {
|
||||||
return "<OpenAI API ClientConfig>"
|
return "<OpenAI API ClientConfig>"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -60,3 +60,43 @@ func TestGetAzureDeploymentByModel(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultAnthropicConfig(t *testing.T) {
|
||||||
|
apiKey := "test-key"
|
||||||
|
baseURL := "https://api.anthropic.com/v1"
|
||||||
|
|
||||||
|
config := openai.DefaultAnthropicConfig(apiKey, baseURL)
|
||||||
|
|
||||||
|
if config.APIType != openai.APITypeAnthropic {
|
||||||
|
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.APIVersion != openai.AnthropicAPIVersion {
|
||||||
|
t.Errorf("Expected APIVersion to be 2023-06-01, got %v", config.APIVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.BaseURL != baseURL {
|
||||||
|
t.Errorf("Expected BaseURL to be %v, got %v", baseURL, config.BaseURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.EmptyMessagesLimit != 300 {
|
||||||
|
t.Errorf("Expected EmptyMessagesLimit to be 300, got %v", config.EmptyMessagesLimit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultAnthropicConfigWithEmptyValues(t *testing.T) {
|
||||||
|
config := openai.DefaultAnthropicConfig("", "")
|
||||||
|
|
||||||
|
if config.APIType != openai.APITypeAnthropic {
|
||||||
|
t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.APIVersion != openai.AnthropicAPIVersion {
|
||||||
|
t.Errorf("Expected APIVersion to be %s, got %v", openai.AnthropicAPIVersion, config.APIVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedBaseURL := "https://api.anthropic.com/v1"
|
||||||
|
if config.BaseURL != expectedBaseURL {
|
||||||
|
t.Errorf("Expected BaseURL to be %v, got %v", expectedBaseURL, config.BaseURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user