Check if the model param is valid for moderations endpoint (#437)
* chore: check for models before sending moderation requests to openai endpoint * chore: table driven tests to include more model cases for moderations endpoint
This commit is contained in:
@@ -2,6 +2,7 @@ package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@@ -15,9 +16,19 @@ import (
|
||||
// Moderation model identifiers accepted by the /moderations endpoint.
const (
	// ModerationTextStable pins the current stable text-moderation model.
	ModerationTextStable = "text-moderation-stable"
	// ModerationTextLatest always resolves to the newest text-moderation model.
	ModerationTextLatest = "text-moderation-latest"
	// Deprecated: use ModerationTextStable and ModerationTextLatest instead.
	ModerationText001 = "text-moderation-001"
)
|
||||
|
||||
// ErrModerationInvalidModel is returned by Moderations when the request names
// a model that the moderation endpoint does not support.
var ErrModerationInvalidModel = errors.New("this model is not supported with moderation, please use text-moderation-stable or text-moderation-latest instead") //nolint:lll
|
||||
|
||||
var validModerationModel = map[string]struct{}{
|
||||
ModerationTextStable: {},
|
||||
ModerationTextLatest: {},
|
||||
}
|
||||
|
||||
// ModerationRequest represents a request structure for moderation API.
|
||||
type ModerationRequest struct {
|
||||
Input string `json:"input,omitempty"`
|
||||
@@ -63,6 +74,10 @@ type ModerationResponse struct {
|
||||
// Moderations — perform a moderation api call over a string.
|
||||
// Input can be an array or slice but a string will reduce the complexity.
|
||||
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
|
||||
if _, ok := validModerationModel[request.Model]; len(request.Model) > 0 && !ok {
|
||||
err = ErrModerationInvalidModel
|
||||
return
|
||||
}
|
||||
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), withBody(&request))
|
||||
if err != nil {
|
||||
return
|
||||
|
||||
@@ -27,6 +27,41 @@ func TestModerations(t *testing.T) {
|
||||
checks.NoError(t, err, "Moderation error")
|
||||
}
|
||||
|
||||
// TestModerationsWithIncorrectModel Tests passing valid and invalid models to moderations endpoint.
|
||||
func TestModerationsWithDifferentModelOptions(t *testing.T) {
|
||||
var modelOptions []struct {
|
||||
model string
|
||||
expect error
|
||||
}
|
||||
modelOptions = append(modelOptions,
|
||||
getModerationModelTestOption(GPT3Dot5Turbo, ErrModerationInvalidModel),
|
||||
getModerationModelTestOption(ModerationTextStable, nil),
|
||||
getModerationModelTestOption(ModerationTextLatest, nil),
|
||||
getModerationModelTestOption("", nil),
|
||||
)
|
||||
client, server, teardown := setupOpenAITestServer()
|
||||
defer teardown()
|
||||
server.RegisterHandler("/v1/moderations", handleModerationEndpoint)
|
||||
for _, modelTest := range modelOptions {
|
||||
_, err := client.Moderations(context.Background(), ModerationRequest{
|
||||
Model: modelTest.model,
|
||||
Input: "I want to kill them.",
|
||||
})
|
||||
checks.ErrorIs(t, err, modelTest.expect,
|
||||
fmt.Sprintf("Moderations(..) expects err: %v, actual err:%v", modelTest.expect, err))
|
||||
}
|
||||
}
|
||||
|
||||
// getModerationModelTestOption builds a single model/expected-error table
// entry for the moderation model tests.
func getModerationModelTestOption(model string, expect error) struct {
	model  string
	expect error
} {
	var option struct {
		model  string
		expect error
	}
	option.model = model
	option.expect = expect
	return option
}
|
||||
|
||||
// handleModerationEndpoint Handles the moderation endpoint by the test server.
|
||||
func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
var err error
|
||||
|
||||
Reference in New Issue
Block a user