change azure engine config to modelMapper (#306)

* change azure engine config to azure modelMapper config

* Update go.mod

* Revert "Update go.mod"

This reverts commit 78d14c58f2a9ce668da43f6adbe20b60afcfe0d7.

* lint fix

* add test

* lint fix

* lint fix

* lint fix

* opt

* opt

* opt

* opt
This commit is contained in:
GargantuaX
2023-05-11 05:30:24 +08:00
committed by GitHub
parent 5f4ff3ebfa
commit be253c2d63
14 changed files with 119 additions and 32 deletions

View File

@@ -98,8 +98,10 @@ func decodeString(body io.Reader, output *string) error {
return nil
}
func (c *Client) fullURL(suffix string) string {
// /openai/deployments/{engine}/chat/completions?api-version={api_version}
// fullURL returns full URL for request.
// args[0] is model name, if API type is Azure, model name is required to get deployment name.
func (c *Client) fullURL(suffix string, args ...any) string {
// /openai/deployments/{model}/chat/completions?api-version={api_version}
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD {
baseURL := c.config.BaseURL
baseURL = strings.TrimRight(baseURL, "/")
@@ -108,8 +110,17 @@ func (c *Client) fullURL(suffix string) string {
if strings.Contains(suffix, "/models") {
return fmt.Sprintf("%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion)
}
azureDeploymentName := "UNKNOWN"
if len(args) > 0 {
model, ok := args[0].(string)
if ok {
azureDeploymentName = c.config.GetAzureDeploymentByModel(model)
}
}
return fmt.Sprintf("%s/%s/%s/%s%s?api-version=%s",
baseURL, azureAPIPrefix, azureDeploymentsPrefix, c.config.Engine, suffix, c.config.APIVersion)
baseURL, azureAPIPrefix, azureDeploymentsPrefix,
azureDeploymentName, suffix, c.config.APIVersion,
)
}
// c.config.APIType == APITypeOpenAI || c.config.APIType == ""
@@ -120,8 +131,9 @@ func (c *Client) newStreamRequest(
ctx context.Context,
method string,
urlSuffix string,
body any) (*http.Request, error) {
req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix), body)
body any,
model string) (*http.Request, error) {
req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix, model), body)
if err != nil {
return nil, err
}