change azure engine config to modelMapper (#306)

* change azure engine config to azure modelMapper config

* Update go.mod

* Revert "Update go.mod"

This reverts commit 78d14c58f2a9ce668da43f6adbe20b60afcfe0d7.

* lint fix

* add test

* lint fix

* lint fix

* lint fix

* opt

* opt

* opt

* opt

Authored by GargantuaX on 2023-05-11 05:30:24 +08:00, committed by GitHub
parent 5f4ff3ebfa
commit be253c2d63
14 changed files with 119 additions and 32 deletions
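
For orientation, a minimal sketch of how a caller sets up the Azure client after this change. The environment variable names follow the example updated in this diff; the helper name newAzureClient and the "my-gpt35-deployment" mapping are hypothetical, not part of the commit:

    import (
        "os"

        openai "github.com/sashabaranov/go-openai"
    )

    func newAzureClient() *openai.Client {
        azureKey := os.Getenv("AZURE_OPENAI_API_KEY")       // your Azure API key
        azureEndpoint := os.Getenv("AZURE_OPENAI_ENDPOINT") // your Azure OpenAI endpoint

        // DefaultAzureConfig no longer takes an engine; by default the model name is
        // mapped to a deployment name by stripping '.' and ':' ("gpt-3.5-turbo" -> "gpt-35-turbo").
        config := openai.DefaultAzureConfig(azureKey, azureEndpoint)

        // Optional: map models to explicitly named deployments instead.
        config.AzureModelMapperFunc = func(model string) string {
            deployments := map[string]string{
                "gpt-3.5-turbo": "my-gpt35-deployment", // hypothetical deployment name
            }
            if d, ok := deployments[model]; ok {
                return d
            }
            return model
        }

        return openai.NewClientWithConfig(config)
    }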

View File

@@ -94,7 +94,7 @@ func TestRequestAuthHeader(t *testing.T) {
         az.OrgID = c.OrgID
         cli := NewClientWithConfig(az)
-        req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil)
+        req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil, "")
         if err != nil {
             t.Errorf("Failed to create request: %v", err)
         }
@@ -111,12 +111,14 @@ func TestAzureFullURL(t *testing.T) {
     cases := []struct {
         Name string
         BaseURL string
-        Engine string
+        AzureModelMapper map[string]string
+        Model string
         Expect string
     }{
         {
             "AzureBaseURLWithSlashAutoStrip",
             "https://httpbin.org/",
+            nil,
             "chatgpt-demo",
             "https://httpbin.org/" +
                 "openai/deployments/chatgpt-demo" +
@@ -125,6 +127,7 @@ func TestAzureFullURL(t *testing.T) {
         {
             "AzureBaseURLWithoutSlashOK",
             "https://httpbin.org",
+            nil,
             "chatgpt-demo",
             "https://httpbin.org/" +
                 "openai/deployments/chatgpt-demo" +
@@ -134,10 +137,10 @@ func TestAzureFullURL(t *testing.T) {
     for _, c := range cases {
         t.Run(c.Name, func(t *testing.T) {
-            az := DefaultAzureConfig("dummy", c.BaseURL, c.Engine)
+            az := DefaultAzureConfig("dummy", c.BaseURL)
             cli := NewClientWithConfig(az)
             // /openai/deployments/{engine}/chat/completions?api-version={api_version}
-            actual := cli.fullURL("/chat/completions")
+            actual := cli.fullURL("/chat/completions", c.Model)
             if actual != c.Expect {
                 t.Errorf("Expected %s, got %s", c.Expect, actual)
             }

View File

@@ -68,7 +68,7 @@ func (c *Client) callAudioAPI(
     }

     urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix)
-    req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), &formBody)
+    req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), &formBody)
     if err != nil {
         return AudioResponse{}, err
     }

View File

@@ -77,7 +77,7 @@ func (c *Client) CreateChatCompletion(
         return
     }

-    req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
+    req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
     if err != nil {
         return
     }

View File

@@ -46,7 +46,7 @@ func (c *Client) CreateChatCompletionStream(
     }

     request.Stream = true
-    req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request)
+    req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model)
     if err != nil {
         return
     }

View File

@@ -98,8 +98,10 @@ func decodeString(body io.Reader, output *string) error {
     return nil
 }

-func (c *Client) fullURL(suffix string) string {
-    // /openai/deployments/{engine}/chat/completions?api-version={api_version}
+// fullURL returns full URL for request.
+// args[0] is model name, if API type is Azure, model name is required to get deployment name.
+func (c *Client) fullURL(suffix string, args ...any) string {
+    // /openai/deployments/{model}/chat/completions?api-version={api_version}
     if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD {
         baseURL := c.config.BaseURL
         baseURL = strings.TrimRight(baseURL, "/")
@@ -108,8 +110,17 @@ func (c *Client) fullURL(suffix string) string {
         if strings.Contains(suffix, "/models") {
             return fmt.Sprintf("%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion)
         }
+        azureDeploymentName := "UNKNOWN"
+        if len(args) > 0 {
+            model, ok := args[0].(string)
+            if ok {
+                azureDeploymentName = c.config.GetAzureDeploymentByModel(model)
+            }
+        }
         return fmt.Sprintf("%s/%s/%s/%s%s?api-version=%s",
-            baseURL, azureAPIPrefix, azureDeploymentsPrefix, c.config.Engine, suffix, c.config.APIVersion)
+            baseURL, azureAPIPrefix, azureDeploymentsPrefix,
+            azureDeploymentName, suffix, c.config.APIVersion,
+        )
     }

     // c.config.APIType == APITypeOpenAI || c.config.APIType == ""
@@ -120,8 +131,9 @@ func (c *Client) newStreamRequest(
     ctx context.Context,
     method string,
     urlSuffix string,
-    body any) (*http.Request, error) {
-    req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix), body)
+    body any,
+    model string) (*http.Request, error) {
+    req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix, model), body)
     if err != nil {
         return nil, err
     }
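
To make the resolution above concrete: with the defaults from DefaultAzureConfig (API version "2023-03-15-preview" and the dot/colon-stripping mapper) and an illustrative BaseURL of https://example.openai.azure.com, a call like fullURL("/chat/completions", "gpt-3.5-turbo") resolves the deployment name via GetAzureDeploymentByModel ("gpt-3.5-turbo" -> "gpt-35-turbo") and returns

    https://example.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2023-03-15-preview

whereas a call that passes no model falls back to the "UNKNOWN" deployment segment.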

View File

@@ -155,7 +155,7 @@ func (c *Client) CreateCompletion(
         return
     }

-    req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
+    req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
     if err != nil {
         return
     }

View File

@@ -2,6 +2,7 @@ package openai

 import (
     "net/http"
+    "regexp"
 )

 const (
@@ -30,8 +31,7 @@ type ClientConfig struct {
     OrgID string
     APIType APIType
     APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD
-    Engine string // required when APIType is APITypeAzure or APITypeAzureAD
-
+    AzureModelMapperFunc func(model string) string // replace model to azure deployment name func
     HTTPClient *http.Client

     EmptyMessagesLimit uint
@@ -50,14 +50,16 @@ func DefaultConfig(authToken string) ClientConfig {
     }
 }

-func DefaultAzureConfig(apiKey, baseURL, engine string) ClientConfig {
+func DefaultAzureConfig(apiKey, baseURL string) ClientConfig {
     return ClientConfig{
         authToken: apiKey,
         BaseURL: baseURL,
         OrgID: "",
         APIType: APITypeAzure,
         APIVersion: "2023-03-15-preview",
-        Engine: engine,
+        AzureModelMapperFunc: func(model string) string {
+            return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
+        },

         HTTPClient: &http.Client{},
@@ -68,3 +70,11 @@ func DefaultAzureConfig(apiKey, baseURL, engine string) ClientConfig {
 func (ClientConfig) String() string {
     return "<OpenAI API ClientConfig>"
 }
+
+func (c ClientConfig) GetAzureDeploymentByModel(model string) string {
+    if c.AzureModelMapperFunc != nil {
+        return c.AzureModelMapperFunc(model)
+    }
+
+    return model
+}
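
As a quick sketch of what the default mapper added here produces (the expected outputs match the test cases below):

    mapper := regexp.MustCompile(`[.:]`)
    mapper.ReplaceAllString("gpt-3.5-turbo", "")          // "gpt-35-turbo"
    mapper.ReplaceAllString("gpt-3.5-turbo-0301", "")     // "gpt-35-turbo-0301"
    mapper.ReplaceAllString("text-embedding-ada-002", "") // unchanged: "text-embedding-ada-002"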

config_test.go (new file, 62 additions)
View File

@@ -0,0 +1,62 @@
+package openai_test
+
+import (
+    "testing"
+
+    . "github.com/sashabaranov/go-openai"
+)
+
+func TestGetAzureDeploymentByModel(t *testing.T) {
+    cases := []struct {
+        Model                string
+        AzureModelMapperFunc func(model string) string
+        Expect               string
+    }{
+        {
+            Model:  "gpt-3.5-turbo",
+            Expect: "gpt-35-turbo",
+        },
+        {
+            Model:  "gpt-3.5-turbo-0301",
+            Expect: "gpt-35-turbo-0301",
+        },
+        {
+            Model:  "text-embedding-ada-002",
+            Expect: "text-embedding-ada-002",
+        },
+        {
+            Model:  "",
+            Expect: "",
+        },
+        {
+            Model:  "models",
+            Expect: "models",
+        },
+        {
+            Model:  "gpt-3.5-turbo",
+            Expect: "my-gpt35",
+            AzureModelMapperFunc: func(model string) string {
+                modelmapper := map[string]string{
+                    "gpt-3.5-turbo": "my-gpt35",
+                }
+                if val, ok := modelmapper[model]; ok {
+                    return val
+                }
+                return model
+            },
+        },
+    }
+
+    for _, c := range cases {
+        t.Run(c.Model, func(t *testing.T) {
+            conf := DefaultAzureConfig("", "https://test.openai.azure.com/")
+            if c.AzureModelMapperFunc != nil {
+                conf.AzureModelMapperFunc = c.AzureModelMapperFunc
+            }
+            actual := conf.GetAzureDeploymentByModel(c.Model)
+            if actual != c.Expect {
+                t.Errorf("Expected %s, got %s", c.Expect, actual)
+            }
+        })
+    }
+}

View File

@@ -2,6 +2,7 @@ package openai

 import (
     "context"
+    "fmt"
     "net/http"
 )
@@ -31,7 +32,7 @@ type EditsResponse struct {

 // Perform an API call to the Edits endpoint.
 func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
-    req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits"), request)
+    req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request)
     if err != nil {
         return
     }

View File

@@ -132,7 +132,7 @@ type EmbeddingRequest struct {
 // CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
 // https://beta.openai.com/docs/api-reference/embeddings/create
 func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
-    req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings"), request)
+    req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request)
     if err != nil {
         return
     }

View File

@@ -305,8 +305,7 @@ func Example_chatbot() {
 func ExampleDefaultAzureConfig() {
     azureKey := os.Getenv("AZURE_OPENAI_API_KEY")       // Your azure API key
     azureEndpoint := os.Getenv("AZURE_OPENAI_ENDPOINT") // Your azure OpenAI endpoint
-    azureModel := os.Getenv("AZURE_OPENAI_MODEL")       // Your model deployment name
-    config := openai.DefaultAzureConfig(azureKey, azureEndpoint, azureModel)
+    config := openai.DefaultAzureConfig(azureKey, azureEndpoint)
     client := openai.NewClientWithConfig(config)
     resp, err := client.CreateChatCompletion(
         context.Background(),

View File

@@ -40,7 +40,7 @@ func TestAzureListModels(t *testing.T) {
     ts.Start()
     defer ts.Close()

-    config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/", "dummyengine")
+    config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/")
     config.BaseURL = ts.URL
     client := NewClientWithConfig(config)
     ctx := context.Background()

View File

@@ -63,7 +63,7 @@ type ModerationResponse struct {
 // Moderations — perform a moderation api call over a string.
 // Input can be an array or slice but a string will reduce the complexity.
 func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
-    req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations"), request)
+    req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), request)
     if err != nil {
         return
     }

View File

@@ -35,7 +35,7 @@ func (c *Client) CreateCompletionStream(
     }

     request.Stream = true
-    req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request)
+    req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model)
     if err != nil {
         return
     }