diff --git a/client_test.go b/client_test.go index 862cbe8..5e63539 100644 --- a/client_test.go +++ b/client_test.go @@ -170,104 +170,82 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { ctx := context.Background() - _, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) + type TestCase struct { + Name string + TestFunc func() (any, error) } - _, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) + testCases := []TestCase{ + {"CreateCompletion", func() (any, error) { + return client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"}) + }}, + {"CreateCompletionStream", func() (any, error) { + return client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""}) + }}, + {"CreateChatCompletion", func() (any, error) { + return client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) + }}, + {"CreateChatCompletionStream", func() (any, error) { + return client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) + }}, + {"CreateFineTune", func() (any, error) { + return client.CreateFineTune(ctx, FineTuneRequest{}) + }}, + {"ListFineTunes", func() (any, error) { + return client.ListFineTunes(ctx) + }}, + {"CancelFineTune", func() (any, error) { + return client.CancelFineTune(ctx, "") + }}, + {"GetFineTune", func() (any, error) { + return client.GetFineTune(ctx, "") + }}, + {"DeleteFineTune", func() (any, error) { + return client.DeleteFineTune(ctx, "") + }}, + {"ListFineTuneEvents", func() (any, error) { + return client.ListFineTuneEvents(ctx, "") + }}, + {"Moderations", func() (any, error) { + return client.Moderations(ctx, ModerationRequest{}) + }}, + {"Edits", func() (any, error) { + return client.Edits(ctx, EditsRequest{}) + }}, + {"CreateEmbeddings", func() (any, error) { + return client.CreateEmbeddings(ctx, EmbeddingRequest{}) + }}, + {"CreateImage", func() (any, error) { + return client.CreateImage(ctx, ImageRequest{}) + }}, + {"DeleteFile", func() (any, error) { + return nil, client.DeleteFile(ctx, "") + }}, + {"GetFile", func() (any, error) { + return client.GetFile(ctx, "") + }}, + {"ListFiles", func() (any, error) { + return client.ListFiles(ctx) + }}, + {"ListEngines", func() (any, error) { + return client.ListEngines(ctx) + }}, + {"GetEngine", func() (any, error) { + return client.GetEngine(ctx, "") + }}, + {"ListModels", func() (any, error) { + return client.ListModels(ctx) + }}, + {"GetModel", func() (any, error) { + return client.GetModel(ctx, "text-davinci-003") + }}, } - _, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateFineTune(ctx, FineTuneRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListFineTunes(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CancelFineTune(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.GetFineTune(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.DeleteFineTune(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListFineTuneEvents(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.Moderations(ctx, ModerationRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.Edits(ctx, EditsRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateEmbeddings(ctx, EmbeddingRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateImage(ctx, ImageRequest{}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - err = client.DeleteFile(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.GetFile(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListFiles(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListEngines(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.GetEngine(ctx, "") - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.ListModels(ctx) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) - } - - _, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""}) - if !errors.Is(err, errTestRequestBuilderFailed) { - t.Fatalf("Did not return error when request builder failed: %v", err) + for _, testCase := range testCases { + _, err = testCase.TestFunc() + if !errors.Is(err, errTestRequestBuilderFailed) { + t.Fatalf("%s did not return error when request builder failed: %v", testCase.Name, err) + } } } diff --git a/models.go b/models.go index 485433b..b3d4583 100644 --- a/models.go +++ b/models.go @@ -2,6 +2,7 @@ package openai import ( "context" + "fmt" "net/http" ) @@ -48,3 +49,16 @@ func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) err = c.sendRequest(req, &models) return } + +// GetModel Retrieves a model instance, providing basic information about +// the model such as the owner and permissioning. +func (c *Client) GetModel(ctx context.Context, modelID string) (model Model, err error) { + urlSuffix := fmt.Sprintf("/models/%s", modelID) + req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + if err != nil { + return + } + + err = c.sendRequest(req, &model) + return +} diff --git a/models_test.go b/models_test.go index b017800..834c849 100644 --- a/models_test.go +++ b/models_test.go @@ -54,3 +54,44 @@ func handleModelsEndpoint(w http.ResponseWriter, _ *http.Request) { resBytes, _ := json.Marshal(ModelsList{}) fmt.Fprintln(w, string(resBytes)) } + +// TestGetModel Tests the retrieve model endpoint of the API using the mocked server. +func TestGetModel(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/models/text-davinci-003", handleGetModelEndpoint) + // create the test server + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.GetModel(ctx, "text-davinci-003") + checks.NoError(t, err, "GetModel error") +} + +func TestAzureGetModel(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/openai/models/text-davinci-003", handleModelsEndpoint) + // create the test server + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/") + config.BaseURL = ts.URL + client := NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.GetModel(ctx, "text-davinci-003") + checks.NoError(t, err, "GetModel error") +} + +// handleModelsEndpoint Handles the models endpoint by the test server. +func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(Model{}) + fmt.Fprintln(w, string(resBytes)) +}