From 5191ea6f554ce8c26694646895642cfab3b21317 Mon Sep 17 00:00:00 2001 From: Rascal0814 <80660375+Rascal0814@users.noreply.github.com> Date: Sun, 12 Feb 2023 02:51:53 +0800 Subject: [PATCH] Modify the test module, add the file upload test, and add the image edit api (#63) * Modify the test module, add the file upload test, and add the image editing api * fix golangci-lint * fix golangci-lint * Static file deletion, file directory name modification * fix * test-server-related logic encapsulated in a single tidy struct --------- Co-authored-by: julian_huang --- .gitignore | 3 +- api_test.go | 419 +--------------------------------------- completion_test.go | 102 ++++++++++ edits_test.go | 103 ++++++++++ embeddings_test.go | 48 +++++ files.go | 9 +- files_test.go | 80 ++++++++ image.go | 63 ++++++ image_test.go | 162 ++++++++++++++++ internal/test/server.go | 46 +++++ moderation_test.go | 101 ++++++++++ stream_test.go | 9 +- 12 files changed, 719 insertions(+), 426 deletions(-) create mode 100644 completion_test.go create mode 100644 edits_test.go create mode 100644 embeddings_test.go create mode 100644 files_test.go create mode 100644 image_test.go create mode 100644 internal/test/server.go create mode 100644 moderation_test.go diff --git a/.gitignore b/.gitignore index 42385aa..99b40bf 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,5 @@ # vendor/ # Auth token for tests -.openai-token \ No newline at end of file +.openai-token +.idea \ No newline at end of file diff --git a/api_test.go b/api_test.go index d0b4d52..7843bef 100644 --- a/api_test.go +++ b/api_test.go @@ -1,26 +1,13 @@ package gogpt_test import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "log" - "net/http" - "net/http/httptest" - "os" - "strconv" - "strings" - "testing" - "time" - . "github.com/sashabaranov/go-gpt3" -) -const ( - testAPIToken = "this-is-my-secure-token-do-not-steal!" + "context" + "errors" + "io" + "os" + "testing" ) func TestAPI(t *testing.T) { @@ -94,370 +81,6 @@ func TestAPI(t *testing.T) { } } -// TestCompletions Tests the completions endpoint of the API using the mocked server. -func TestCompletions(t *testing.T) { - // create the test server - var err error - ts := OpenAITestServer() - ts.Start() - defer ts.Close() - - client := NewClient(testAPIToken) - ctx := context.Background() - client.BaseURL = ts.URL + "/v1" - - req := CompletionRequest{ - MaxTokens: 5, - Model: "ada", - } - req.Prompt = "Lorem ipsum" - _, err = client.CreateCompletion(ctx, req) - if err != nil { - t.Fatalf("CreateCompletion error: %v", err) - } -} - -// TestEdits Tests the edits endpoint of the API using the mocked server. -func TestEdits(t *testing.T) { - // create the test server - var err error - ts := OpenAITestServer() - ts.Start() - defer ts.Close() - - client := NewClient(testAPIToken) - ctx := context.Background() - client.BaseURL = ts.URL + "/v1" - - // create an edit request - model := "ada" - editReq := EditsRequest{ - Model: &model, - Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " + - "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" + - " ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip" + - " ex ea commodo consequat. Duis aute irure dolor in reprehe", - Instruction: "test instruction", - N: 3, - } - response, err := client.Edits(ctx, editReq) - if err != nil { - t.Fatalf("Edits error: %v", err) - } - if len(response.Choices) != editReq.N { - t.Fatalf("edits does not properly return the correct number of choices") - } -} - -// TestModeration Tests the moderations endpoint of the API using the mocked server. -func TestModerations(t *testing.T) { - // create the test server - var err error - ts := OpenAITestServer() - ts.Start() - defer ts.Close() - - client := NewClient(testAPIToken) - ctx := context.Background() - client.BaseURL = ts.URL + "/v1" - - // create an edit request - model := "text-moderation-stable" - moderationReq := ModerationRequest{ - Model: &model, - Input: "I want to kill them.", - } - _, err = client.Moderations(ctx, moderationReq) - if err != nil { - t.Fatalf("Moderation error: %v", err) - } -} - -func TestEmbedding(t *testing.T) { - embeddedModels := []EmbeddingModel{ - AdaSimilarity, - BabbageSimilarity, - CurieSimilarity, - DavinciSimilarity, - AdaSearchDocument, - AdaSearchQuery, - BabbageSearchDocument, - BabbageSearchQuery, - CurieSearchDocument, - CurieSearchQuery, - DavinciSearchDocument, - DavinciSearchQuery, - AdaCodeSearchCode, - AdaCodeSearchText, - BabbageCodeSearchCode, - BabbageCodeSearchText, - } - for _, model := range embeddedModels { - embeddingReq := EmbeddingRequest{ - Input: []string{ - "The food was delicious and the waiter", - "Other examples of embedding request", - }, - Model: model, - } - // marshal embeddingReq to JSON and confirm that the model field equals - // the AdaSearchQuery type - marshaled, err := json.Marshal(embeddingReq) - if err != nil { - t.Fatalf("Could not marshal embedding request: %v", err) - } - if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { - t.Fatalf("Expected embedding request to contain model field") - } - } -} - -func TestImages(t *testing.T) { - // create the test server - var err error - ts := OpenAITestServer() - ts.Start() - defer ts.Close() - - client := NewClient(testAPIToken) - ctx := context.Background() - client.BaseURL = ts.URL + "/v1" - - req := ImageRequest{} - req.Prompt = "Lorem ipsum" - _, err = client.CreateImage(ctx, req) - if err != nil { - t.Fatalf("CreateImage error: %v", err) - } -} - -// getEditBody Returns the body of the request to create an edit. -func getEditBody(r *http.Request) (EditsRequest, error) { - edit := EditsRequest{} - // read the request body - reqBody, err := io.ReadAll(r.Body) - if err != nil { - return EditsRequest{}, err - } - err = json.Unmarshal(reqBody, &edit) - if err != nil { - return EditsRequest{}, err - } - return edit, nil -} - -// handleEditEndpoint Handles the edit endpoint by the test server. -func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { - var err error - var resBytes []byte - - // edits only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - var editReq EditsRequest - editReq, err = getEditBody(r) - if err != nil { - http.Error(w, "could not read request", http.StatusInternalServerError) - return - } - // create a response - res := EditsResponse{ - Object: "test-object", - Created: time.Now().Unix(), - } - // edit and calculate token usage - editString := "edited by mocked OpenAI server :)" - inputTokens := numTokens(editReq.Input+editReq.Instruction) * editReq.N - completionTokens := int(float32(len(editString))/4) * editReq.N - for i := 0; i < editReq.N; i++ { - // instruction will be hidden and only seen by OpenAI - res.Choices = append(res.Choices, EditsChoice{ - Text: editReq.Input + editString, - Index: i, - }) - } - res.Usage = Usage{ - PromptTokens: inputTokens, - CompletionTokens: completionTokens, - TotalTokens: inputTokens + completionTokens, - } - resBytes, _ = json.Marshal(res) - fmt.Fprint(w, string(resBytes)) -} - -// handleCompletionEndpoint Handles the completion endpoint by the test server. -func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { - var err error - var resBytes []byte - - // completions only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - var completionReq CompletionRequest - if completionReq, err = getCompletionBody(r); err != nil { - http.Error(w, "could not read request", http.StatusInternalServerError) - return - } - - res := CompletionResponse{ - ID: strconv.Itoa(int(time.Now().Unix())), - Object: "test-object", - Created: time.Now().Unix(), - // would be nice to validate Model during testing, but - // this may not be possible with how much upkeep - // would be required / wouldn't make much sense - Model: completionReq.Model, - } - - // create completions - for i := 0; i < completionReq.N; i++ { - // generate a random string of length completionReq.Length - completionStr := strings.Repeat("a", completionReq.MaxTokens) - if completionReq.Echo { - completionStr = completionReq.Prompt + completionStr - } - res.Choices = append(res.Choices, CompletionChoice{ - Text: completionStr, - Index: i, - }) - } - inputTokens := numTokens(completionReq.Prompt) * completionReq.N - completionTokens := completionReq.MaxTokens * completionReq.N - res.Usage = Usage{ - PromptTokens: inputTokens, - CompletionTokens: completionTokens, - TotalTokens: inputTokens + completionTokens, - } - resBytes, _ = json.Marshal(res) - fmt.Fprintln(w, string(resBytes)) -} - -// getCompletionBody Returns the body of the request to create a completion. -func getCompletionBody(r *http.Request) (CompletionRequest, error) { - completion := CompletionRequest{} - // read the request body - reqBody, err := io.ReadAll(r.Body) - if err != nil { - return CompletionRequest{}, err - } - err = json.Unmarshal(reqBody, &completion) - if err != nil { - return CompletionRequest{}, err - } - return completion, nil -} - -// handleImageEndpoint Handles the images endpoint by the test server. -func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { - var err error - var resBytes []byte - - // imagess only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - var imageReq ImageRequest - if imageReq, err = getImageBody(r); err != nil { - http.Error(w, "could not read request", http.StatusInternalServerError) - return - } - res := ImageResponse{ - Created: time.Now().Unix(), - } - for i := 0; i < imageReq.N; i++ { - imageData := ImageResponseDataInner{} - switch imageReq.ResponseFormat { - case CreateImageResponseFormatURL, "": - imageData.URL = "https://example.com/image.png" - case CreateImageResponseFormatB64JSON: - // This decodes to "{}" in base64. - imageData.B64JSON = "e30K" - default: - http.Error(w, "invalid response format", http.StatusBadRequest) - return - } - res.Data = append(res.Data, imageData) - } - resBytes, _ = json.Marshal(res) - fmt.Fprintln(w, string(resBytes)) -} - -// getImageBody Returns the body of the request to create a image. -func getImageBody(r *http.Request) (ImageRequest, error) { - image := ImageRequest{} - // read the request body - reqBody, err := io.ReadAll(r.Body) - if err != nil { - return ImageRequest{}, err - } - err = json.Unmarshal(reqBody, &image) - if err != nil { - return ImageRequest{}, err - } - return image, nil -} - -// handleModerationEndpoint Handles the moderation endpoint by the test server. -func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { - var err error - var resBytes []byte - - // completions only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - var moderationReq ModerationRequest - if moderationReq, err = getModerationBody(r); err != nil { - http.Error(w, "could not read request", http.StatusInternalServerError) - return - } - - resCat := ResultCategories{} - resCatScore := ResultCategoryScores{} - switch { - case strings.Contains(moderationReq.Input, "kill"): - resCat = ResultCategories{Violence: true} - resCatScore = ResultCategoryScores{Violence: 1} - case strings.Contains(moderationReq.Input, "hate"): - resCat = ResultCategories{Hate: true} - resCatScore = ResultCategoryScores{Hate: 1} - case strings.Contains(moderationReq.Input, "suicide"): - resCat = ResultCategories{SelfHarm: true} - resCatScore = ResultCategoryScores{SelfHarm: 1} - case strings.Contains(moderationReq.Input, "porn"): - resCat = ResultCategories{Sexual: true} - resCatScore = ResultCategoryScores{Sexual: 1} - } - - result := Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} - - res := ModerationResponse{ - ID: strconv.Itoa(int(time.Now().Unix())), - Model: *moderationReq.Model, - } - res.Results = append(res.Results, result) - - resBytes, _ = json.Marshal(res) - fmt.Fprintln(w, string(resBytes)) -} - -// getModerationBody Returns the body of the request to do a moderation. -func getModerationBody(r *http.Request) (ModerationRequest, error) { - moderation := ModerationRequest{} - // read the request body - reqBody, err := io.ReadAll(r.Body) - if err != nil { - return ModerationRequest{}, err - } - err = json.Unmarshal(reqBody, &moderation) - if err != nil { - return ModerationRequest{}, err - } - return moderation, nil -} - // numTokens Returns the number of GPT-3 encoded tokens in the given text. // This function approximates based on the rule of thumb stated by OpenAI: // https://beta.openai.com/tokenizer @@ -466,35 +89,3 @@ func getModerationBody(r *http.Request) (ModerationRequest, error) { func numTokens(s string) int { return int(float32(len(s)) / 4) } - -// OpenAITestServer Creates a mocked OpenAI server which can pretend to handle requests during testing. -func OpenAITestServer() *httptest.Server { - return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Printf("received request at path %q\n", r.URL.Path) - - // check auth - if r.Header.Get("Authorization") != "Bearer "+testAPIToken { - w.WriteHeader(http.StatusUnauthorized) - return - } - - // OPTIMIZE: create separate handler functions for these - switch r.URL.Path { - case "/v1/edits": - handleEditEndpoint(w, r) - return - case "/v1/completions": - handleCompletionEndpoint(w, r) - return - case "/v1/moderations": - handleModerationEndpoint(w, r) - case "/v1/images/generations": - handleImageEndpoint(w, r) - // TODO: implement the other endpoints - default: - // the endpoint doesn't exist - http.Error(w, "the resource path doesn't exist", http.StatusNotFound) - return - } - })) -} diff --git a/completion_test.go b/completion_test.go new file mode 100644 index 0000000..c96df1a --- /dev/null +++ b/completion_test.go @@ -0,0 +1,102 @@ +package gogpt_test + +import ( + . "github.com/sashabaranov/go-gpt3" + "github.com/sashabaranov/go-gpt3/internal/test" + + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "testing" + "time" +) + +// TestCompletions Tests the completions endpoint of the API using the mocked server. +func TestCompletions(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/completions", handleCompletionEndpoint) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + client := NewClient(test.GetTestToken()) + ctx := context.Background() + client.BaseURL = ts.URL + "/v1" + + req := CompletionRequest{ + MaxTokens: 5, + Model: "ada", + } + req.Prompt = "Lorem ipsum" + _, err = client.CreateCompletion(ctx, req) + if err != nil { + t.Fatalf("CreateCompletion error: %v", err) + } +} + +// handleCompletionEndpoint Handles the completion endpoint by the test server. +func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // completions only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + var completionReq CompletionRequest + if completionReq, err = getCompletionBody(r); err != nil { + http.Error(w, "could not read request", http.StatusInternalServerError) + return + } + res := CompletionResponse{ + ID: strconv.Itoa(int(time.Now().Unix())), + Object: "test-object", + Created: time.Now().Unix(), + // would be nice to validate Model during testing, but + // this may not be possible with how much upkeep + // would be required / wouldn't make much sense + Model: completionReq.Model, + } + // create completions + for i := 0; i < completionReq.N; i++ { + // generate a random string of length completionReq.Length + completionStr := strings.Repeat("a", completionReq.MaxTokens) + if completionReq.Echo { + completionStr = completionReq.Prompt + completionStr + } + res.Choices = append(res.Choices, CompletionChoice{ + Text: completionStr, + Index: i, + }) + } + inputTokens := numTokens(completionReq.Prompt) * completionReq.N + completionTokens := completionReq.MaxTokens * completionReq.N + res.Usage = Usage{ + PromptTokens: inputTokens, + CompletionTokens: completionTokens, + TotalTokens: inputTokens + completionTokens, + } + resBytes, _ = json.Marshal(res) + fmt.Fprintln(w, string(resBytes)) +} + +// getCompletionBody Returns the body of the request to create a completion. +func getCompletionBody(r *http.Request) (CompletionRequest, error) { + completion := CompletionRequest{} + // read the request body + reqBody, err := io.ReadAll(r.Body) + if err != nil { + return CompletionRequest{}, err + } + err = json.Unmarshal(reqBody, &completion) + if err != nil { + return CompletionRequest{}, err + } + return completion, nil +} diff --git a/edits_test.go b/edits_test.go new file mode 100644 index 0000000..499d098 --- /dev/null +++ b/edits_test.go @@ -0,0 +1,103 @@ +package gogpt_test + +import ( + . "github.com/sashabaranov/go-gpt3" + "github.com/sashabaranov/go-gpt3/internal/test" + + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "testing" + "time" +) + +// TestEdits Tests the edits endpoint of the API using the mocked server. +func TestEdits(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/edits", handleEditEndpoint) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + client := NewClient(test.GetTestToken()) + ctx := context.Background() + client.BaseURL = ts.URL + "/v1" + + // create an edit request + model := "ada" + editReq := EditsRequest{ + Model: &model, + Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " + + "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" + + " ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip" + + " ex ea commodo consequat. Duis aute irure dolor in reprehe", + Instruction: "test instruction", + N: 3, + } + response, err := client.Edits(ctx, editReq) + if err != nil { + t.Fatalf("Edits error: %v", err) + } + if len(response.Choices) != editReq.N { + t.Fatalf("edits does not properly return the correct number of choices") + } +} + +// handleEditEndpoint Handles the edit endpoint by the test server. +func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // edits only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + var editReq EditsRequest + editReq, err = getEditBody(r) + if err != nil { + http.Error(w, "could not read request", http.StatusInternalServerError) + return + } + // create a response + res := EditsResponse{ + Object: "test-object", + Created: time.Now().Unix(), + } + // edit and calculate token usage + editString := "edited by mocked OpenAI server :)" + inputTokens := numTokens(editReq.Input+editReq.Instruction) * editReq.N + completionTokens := int(float32(len(editString))/4) * editReq.N + for i := 0; i < editReq.N; i++ { + // instruction will be hidden and only seen by OpenAI + res.Choices = append(res.Choices, EditsChoice{ + Text: editReq.Input + editString, + Index: i, + }) + } + res.Usage = Usage{ + PromptTokens: inputTokens, + CompletionTokens: completionTokens, + TotalTokens: inputTokens + completionTokens, + } + resBytes, _ = json.Marshal(res) + fmt.Fprint(w, string(resBytes)) +} + +// getEditBody Returns the body of the request to create an edit. +func getEditBody(r *http.Request) (EditsRequest, error) { + edit := EditsRequest{} + // read the request body + reqBody, err := io.ReadAll(r.Body) + if err != nil { + return EditsRequest{}, err + } + err = json.Unmarshal(reqBody, &edit) + if err != nil { + return EditsRequest{}, err + } + return edit, nil +} diff --git a/embeddings_test.go b/embeddings_test.go new file mode 100644 index 0000000..daa74e2 --- /dev/null +++ b/embeddings_test.go @@ -0,0 +1,48 @@ +package gogpt_test + +import ( + . "github.com/sashabaranov/go-gpt3" + + "bytes" + "encoding/json" + "testing" +) + +func TestEmbedding(t *testing.T) { + embeddedModels := []EmbeddingModel{ + AdaSimilarity, + BabbageSimilarity, + CurieSimilarity, + DavinciSimilarity, + AdaSearchDocument, + AdaSearchQuery, + BabbageSearchDocument, + BabbageSearchQuery, + CurieSearchDocument, + CurieSearchQuery, + DavinciSearchDocument, + DavinciSearchQuery, + AdaCodeSearchCode, + AdaCodeSearchText, + BabbageCodeSearchCode, + BabbageCodeSearchText, + } + for _, model := range embeddedModels { + embeddingReq := EmbeddingRequest{ + Input: []string{ + "The food was delicious and the waiter", + "Other examples of embedding request", + }, + Model: model, + } + // marshal embeddingReq to JSON and confirm that the model field equals + // the AdaSearchQuery type + marshaled, err := json.Marshal(embeddingReq) + if err != nil { + t.Fatalf("Could not marshal embedding request: %v", err) + } + if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { + t.Fatalf("Expected embedding request to contain model field") + } + } +} diff --git a/files.go b/files.go index bc27e65..0b106fa 100644 --- a/files.go +++ b/files.go @@ -9,7 +9,6 @@ import ( "net/http" "net/url" "os" - "strings" ) type FileRequest struct { @@ -56,13 +55,9 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File var b bytes.Buffer w := multipart.NewWriter(&b) - var fw, pw io.Writer - pw, err = w.CreateFormField("purpose") - if err != nil { - return - } + var fw io.Writer - _, err = io.Copy(pw, strings.NewReader(request.Purpose)) + err = w.WriteField("purpose", request.Purpose) if err != nil { return } diff --git a/files_test.go b/files_test.go new file mode 100644 index 0000000..94c8904 --- /dev/null +++ b/files_test.go @@ -0,0 +1,80 @@ +package gogpt_test + +import ( + . "github.com/sashabaranov/go-gpt3" + "github.com/sashabaranov/go-gpt3/internal/test" + + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "testing" + "time" +) + +func TestFileUpload(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/files", handleCreateFile) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + client := NewClient(test.GetTestToken()) + ctx := context.Background() + client.BaseURL = ts.URL + "/v1" + + req := FileRequest{ + FileName: "test.go", + FilePath: "api.go", + Purpose: "fine-tune", + } + _, err = client.CreateFile(ctx, req) + if err != nil { + t.Fatalf("CreateFile error: %v", err) + } +} + +// handleCreateFile Handles the images endpoint by the test server. +func handleCreateFile(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // edits only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + err = r.ParseMultipartForm(1024 * 1024 * 1024) + if err != nil { + http.Error(w, "file is more than 1GB", http.StatusInternalServerError) + return + } + + values := r.Form + var purpose string + for key, value := range values { + if key == "purpose" { + purpose = value[0] + } + } + file, header, err := r.FormFile("file") + if err != nil { + return + } + defer file.Close() + + var fileReq = File{ + Bytes: int(header.Size), + ID: strconv.Itoa(int(time.Now().Unix())), + FileName: header.Filename, + Purpose: purpose, + CreatedAt: time.Now().Unix(), + Object: "test-objecct", + Owner: "test-owner", + } + + resBytes, _ = json.Marshal(fileReq) + fmt.Fprint(w, string(resBytes)) +} diff --git a/image.go b/image.go index e71e0a1..4368e99 100644 --- a/image.go +++ b/image.go @@ -4,7 +4,11 @@ import ( "bytes" "context" "encoding/json" + "io" + "mime/multipart" "net/http" + "os" + "strconv" ) // Image sizes defined by the OpenAI API. @@ -58,3 +62,62 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons err = c.sendRequest(req, &response) return } + +// ImageEditRequest represents the request structure for the image API. +type ImageEditRequest struct { + Image *os.File `json:"image,omitempty"` + Mask *os.File `json:"mask,omitempty"` + Prompt string `json:"prompt,omitempty"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` +} + +// CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API. +func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) (response ImageResponse, err error) { + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + // image + image, err := writer.CreateFormFile("image", request.Image.Name()) + if err != nil { + return + } + _, err = io.Copy(image, request.Image) + if err != nil { + return + } + + // mask + mask, err := writer.CreateFormFile("mask", request.Mask.Name()) + if err != nil { + return + } + _, err = io.Copy(mask, request.Mask) + if err != nil { + return + } + + err = writer.WriteField("prompt", request.Prompt) + if err != nil { + return + } + err = writer.WriteField("n", strconv.Itoa(request.N)) + if err != nil { + return + } + err = writer.WriteField("size", request.Size) + if err != nil { + return + } + writer.Close() + urlSuffix := "/images/edits" + req, err := http.NewRequest(http.MethodPost, c.fullURL(urlSuffix), body) + if err != nil { + return + } + + req = req.WithContext(ctx) + req.Header.Set("Content-Type", writer.FormDataContentType()) + err = c.sendRequest(req, &response) + return +} diff --git a/image_test.go b/image_test.go new file mode 100644 index 0000000..6eaf182 --- /dev/null +++ b/image_test.go @@ -0,0 +1,162 @@ +package gogpt_test + +import ( + . "github.com/sashabaranov/go-gpt3" + "github.com/sashabaranov/go-gpt3/internal/test" + + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "testing" + "time" +) + +func TestImages(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/images/generations", handleImageEndpoint) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + client := NewClient(test.GetTestToken()) + ctx := context.Background() + client.BaseURL = ts.URL + "/v1" + + req := ImageRequest{} + req.Prompt = "Lorem ipsum" + _, err = client.CreateImage(ctx, req) + if err != nil { + t.Fatalf("CreateImage error: %v", err) + } +} + +// handleImageEndpoint Handles the images endpoint by the test server. +func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // imagess only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + var imageReq ImageRequest + if imageReq, err = getImageBody(r); err != nil { + http.Error(w, "could not read request", http.StatusInternalServerError) + return + } + res := ImageResponse{ + Created: time.Now().Unix(), + } + for i := 0; i < imageReq.N; i++ { + imageData := ImageResponseDataInner{} + switch imageReq.ResponseFormat { + case CreateImageResponseFormatURL, "": + imageData.URL = "https://example.com/image.png" + case CreateImageResponseFormatB64JSON: + // This decodes to "{}" in base64. + imageData.B64JSON = "e30K" + default: + http.Error(w, "invalid response format", http.StatusBadRequest) + return + } + res.Data = append(res.Data, imageData) + } + resBytes, _ = json.Marshal(res) + fmt.Fprintln(w, string(resBytes)) +} + +// getImageBody Returns the body of the request to create a image. +func getImageBody(r *http.Request) (ImageRequest, error) { + image := ImageRequest{} + // read the request body + reqBody, err := io.ReadAll(r.Body) + if err != nil { + return ImageRequest{}, err + } + err = json.Unmarshal(reqBody, &image) + if err != nil { + return ImageRequest{}, err + } + return image, nil +} + +func TestImageEdit(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + client := NewClient(test.GetTestToken()) + ctx := context.Background() + client.BaseURL = ts.URL + "/v1" + + origin, err := os.Create("image.png") + if err != nil { + t.Error("open origin file error") + return + } + + mask, err := os.Create("mask.png") + if err != nil { + t.Error("open mask file error") + return + } + + defer func() { + mask.Close() + origin.Close() + os.Remove("mask.png") + os.Remove("image.png") + }() + + req := ImageEditRequest{ + Image: origin, + Mask: mask, + Prompt: "There is a turtle in the pool", + N: 3, + Size: CreateImageSize1024x1024, + } + _, err = client.CreateEditImage(ctx, req) + if err != nil { + t.Fatalf("CreateImage error: %v", err) + } +} + +// handleEditImageEndpoint Handles the images endpoint by the test server. +func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + + // imagess only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + + responses := ImageResponse{ + Created: time.Now().Unix(), + Data: []ImageResponseDataInner{ + { + URL: "test-url1", + B64JSON: "", + }, + { + URL: "test-url2", + B64JSON: "", + }, + { + URL: "test-url3", + B64JSON: "", + }, + }, + } + + resBytes, _ = json.Marshal(responses) + fmt.Fprintln(w, string(resBytes)) +} diff --git a/internal/test/server.go b/internal/test/server.go new file mode 100644 index 0000000..0c6f67d --- /dev/null +++ b/internal/test/server.go @@ -0,0 +1,46 @@ +package test + +import ( + "log" + "net/http" + "net/http/httptest" +) + +const testAPI = "this-is-my-secure-token-do-not-steal!!" + +func GetTestToken() string { + return testAPI +} + +type ServerTest struct { + handlers map[string]handler +} +type handler func(w http.ResponseWriter, r *http.Request) + +func NewTestServer() *ServerTest { + return &ServerTest{handlers: make(map[string]handler)} +} + +func (ts *ServerTest) RegisterHandler(path string, handler handler) { + ts.handlers[path] = handler +} + +// OpenAITestServer Creates a mocked OpenAI server which can pretend to handle requests during testing. +func (ts *ServerTest) OpenAITestServer() *httptest.Server { + return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf("received request at path %q\n", r.URL.Path) + + // check auth + if r.Header.Get("Authorization") != "Bearer "+GetTestToken() { + w.WriteHeader(http.StatusUnauthorized) + return + } + + handlerCall, ok := ts.handlers[r.URL.Path] + if !ok { + http.Error(w, "the resource path doesn't exist", http.StatusNotFound) + return + } + handlerCall(w, r) + })) +} diff --git a/moderation_test.go b/moderation_test.go new file mode 100644 index 0000000..3198cb6 --- /dev/null +++ b/moderation_test.go @@ -0,0 +1,101 @@ +package gogpt_test + +import ( + . "github.com/sashabaranov/go-gpt3" + "github.com/sashabaranov/go-gpt3/internal/test" + + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "testing" + "time" +) + +// TestModeration Tests the moderations endpoint of the API using the mocked server. +func TestModerations(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/moderations", handleModerationEndpoint) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + client := NewClient(test.GetTestToken()) + ctx := context.Background() + client.BaseURL = ts.URL + "/v1" + + // create an edit request + model := "text-moderation-stable" + moderationReq := ModerationRequest{ + Model: &model, + Input: "I want to kill them.", + } + _, err = client.Moderations(ctx, moderationReq) + if err != nil { + t.Fatalf("Moderation error: %v", err) + } +} + +// handleModerationEndpoint Handles the moderation endpoint by the test server. +func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // completions only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + var moderationReq ModerationRequest + if moderationReq, err = getModerationBody(r); err != nil { + http.Error(w, "could not read request", http.StatusInternalServerError) + return + } + + resCat := ResultCategories{} + resCatScore := ResultCategoryScores{} + switch { + case strings.Contains(moderationReq.Input, "kill"): + resCat = ResultCategories{Violence: true} + resCatScore = ResultCategoryScores{Violence: 1} + case strings.Contains(moderationReq.Input, "hate"): + resCat = ResultCategories{Hate: true} + resCatScore = ResultCategoryScores{Hate: 1} + case strings.Contains(moderationReq.Input, "suicide"): + resCat = ResultCategories{SelfHarm: true} + resCatScore = ResultCategoryScores{SelfHarm: 1} + case strings.Contains(moderationReq.Input, "porn"): + resCat = ResultCategories{Sexual: true} + resCatScore = ResultCategoryScores{Sexual: 1} + } + + result := Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} + + res := ModerationResponse{ + ID: strconv.Itoa(int(time.Now().Unix())), + Model: *moderationReq.Model, + } + res.Results = append(res.Results, result) + + resBytes, _ = json.Marshal(res) + fmt.Fprintln(w, string(resBytes)) +} + +// getModerationBody Returns the body of the request to do a moderation. +func getModerationBody(r *http.Request) (ModerationRequest, error) { + moderation := ModerationRequest{} + // read the request body + reqBody, err := io.ReadAll(r.Body) + if err != nil { + return ModerationRequest{}, err + } + err = json.Unmarshal(reqBody, &moderation) + if err != nil { + return ModerationRequest{}, err + } + return moderation, nil +} diff --git a/stream_test.go b/stream_test.go index bd7ddf7..c19e534 100644 --- a/stream_test.go +++ b/stream_test.go @@ -1,12 +1,13 @@ package gogpt_test import ( + . "github.com/sashabaranov/go-gpt3" + "github.com/sashabaranov/go-gpt3/internal/test" + "context" "net/http" "net/http/httptest" "testing" - - . "github.com/sashabaranov/go-gpt3" ) func TestCreateCompletionStream(t *testing.T) { @@ -36,7 +37,7 @@ func TestCreateCompletionStream(t *testing.T) { defer server.Close() // Client portion of the test - client := NewClient(testAPIToken) + client := NewClient(test.GetTestToken()) ctx := context.Background() client.BaseURL = server.URL + "/v1" @@ -48,7 +49,7 @@ func TestCreateCompletionStream(t *testing.T) { } client.HTTPClient.Transport = &tokenRoundTripper{ - testAPIToken, + test.GetTestToken(), http.DefaultTransport, }