feat: support cloudflare AI Gateway flavored azure openai (#715)
* feat: support cloudflare AI Gateway flavored azure openai Signed-off-by: STRRL <im@strrl.dev> * test: add test for cloudflare azure fullURL --------- Signed-off-by: STRRL <im@strrl.dev> Co-authored-by: STRRL <im@strrl.dev>
This commit is contained in:
@@ -148,3 +148,39 @@ func TestAzureFullURL(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCloudflareAzureFullURL(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
Name string
|
||||||
|
BaseURL string
|
||||||
|
Expect string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"CloudflareAzureBaseURLWithSlashAutoStrip",
|
||||||
|
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/",
|
||||||
|
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" +
|
||||||
|
"chat/completions?api-version=2023-05-15",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"CloudflareAzureBaseURLWithoutSlashOK",
|
||||||
|
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo",
|
||||||
|
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" +
|
||||||
|
"chat/completions?api-version=2023-05-15",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.Name, func(t *testing.T) {
|
||||||
|
az := DefaultAzureConfig("dummy", c.BaseURL)
|
||||||
|
az.APIType = APITypeCloudflareAzure
|
||||||
|
|
||||||
|
cli := NewClientWithConfig(az)
|
||||||
|
|
||||||
|
actual := cli.fullURL("/chat/completions")
|
||||||
|
if actual != c.Expect {
|
||||||
|
t.Errorf("Expected %s, got %s", c.Expect, actual)
|
||||||
|
}
|
||||||
|
t.Logf("Full URL: %s", actual)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
10
client.go
10
client.go
@@ -182,7 +182,7 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream
|
|||||||
func (c *Client) setCommonHeaders(req *http.Request) {
|
func (c *Client) setCommonHeaders(req *http.Request) {
|
||||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
|
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
|
||||||
// Azure API Key authentication
|
// Azure API Key authentication
|
||||||
if c.config.APIType == APITypeAzure {
|
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure {
|
||||||
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
|
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
|
||||||
} else if c.config.authToken != "" {
|
} else if c.config.authToken != "" {
|
||||||
// OpenAI or Azure AD authentication
|
// OpenAI or Azure AD authentication
|
||||||
@@ -246,7 +246,13 @@ func (c *Client) fullURL(suffix string, args ...any) string {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// c.config.APIType == APITypeOpenAI || c.config.APIType == ""
|
// https://developers.cloudflare.com/ai-gateway/providers/azureopenai/
|
||||||
|
if c.config.APIType == APITypeCloudflareAzure {
|
||||||
|
baseURL := c.config.BaseURL
|
||||||
|
baseURL = strings.TrimRight(baseURL, "/")
|
||||||
|
return fmt.Sprintf("%s%s?api-version=%s", baseURL, suffix, c.config.APIVersion)
|
||||||
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
|
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ const (
|
|||||||
APITypeOpenAI APIType = "OPEN_AI"
|
APITypeOpenAI APIType = "OPEN_AI"
|
||||||
APITypeAzure APIType = "AZURE"
|
APITypeAzure APIType = "AZURE"
|
||||||
APITypeAzureAD APIType = "AZURE_AD"
|
APITypeAzureAD APIType = "AZURE_AD"
|
||||||
|
APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE"
|
||||||
)
|
)
|
||||||
|
|
||||||
const AzureAPIKeyHeader = "api-key"
|
const AzureAPIKeyHeader = "api-key"
|
||||||
|
|||||||
Reference in New Issue
Block a user