package openai_test import ( . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" "context" "encoding/json" "fmt" "io" "net/http" "os" "testing" "time" ) func TestImages(t *testing.T) { server := test.NewTestServer() server.RegisterHandler("/v1/images/generations", handleImageEndpoint) // create the test server var err error ts := server.OpenAITestServer() ts.Start() defer ts.Close() config := DefaultConfig(test.GetTestToken()) config.BaseURL = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() req := ImageRequest{} req.Prompt = "Lorem ipsum" _, err = client.CreateImage(ctx, req) if err != nil { t.Fatalf("CreateImage error: %v", err) } } // handleImageEndpoint Handles the images endpoint by the test server. func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { var err error var resBytes []byte // imagess only accepts POST requests if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } var imageReq ImageRequest if imageReq, err = getImageBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } res := ImageResponse{ Created: time.Now().Unix(), } for i := 0; i < imageReq.N; i++ { imageData := ImageResponseDataInner{} switch imageReq.ResponseFormat { case CreateImageResponseFormatURL, "": imageData.URL = "https://example.com/image.png" case CreateImageResponseFormatB64JSON: // This decodes to "{}" in base64. imageData.B64JSON = "e30K" default: http.Error(w, "invalid response format", http.StatusBadRequest) return } res.Data = append(res.Data, imageData) } resBytes, _ = json.Marshal(res) fmt.Fprintln(w, string(resBytes)) } // getImageBody Returns the body of the request to create a image. func getImageBody(r *http.Request) (ImageRequest, error) { image := ImageRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { return ImageRequest{}, err } err = json.Unmarshal(reqBody, &image) if err != nil { return ImageRequest{}, err } return image, nil } func TestImageEdit(t *testing.T) { server := test.NewTestServer() server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) // create the test server var err error ts := server.OpenAITestServer() ts.Start() defer ts.Close() config := DefaultConfig(test.GetTestToken()) config.BaseURL = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() origin, err := os.Create("image.png") if err != nil { t.Error("open origin file error") return } mask, err := os.Create("mask.png") if err != nil { t.Error("open mask file error") return } defer func() { mask.Close() origin.Close() os.Remove("mask.png") os.Remove("image.png") }() req := ImageEditRequest{ Image: origin, Mask: mask, Prompt: "There is a turtle in the pool", N: 3, Size: CreateImageSize1024x1024, } _, err = client.CreateEditImage(ctx, req) if err != nil { t.Fatalf("CreateImage error: %v", err) } } func TestImageEditWithoutMask(t *testing.T) { server := test.NewTestServer() server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) // create the test server var err error ts := server.OpenAITestServer() ts.Start() defer ts.Close() config := DefaultConfig(test.GetTestToken()) config.BaseURL = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() origin, err := os.Create("image.png") if err != nil { t.Error("open origin file error") return } defer func() { origin.Close() os.Remove("image.png") }() req := ImageEditRequest{ Image: origin, Prompt: "There is a turtle in the pool", N: 3, Size: CreateImageSize1024x1024, } _, err = client.CreateEditImage(ctx, req) if err != nil { t.Fatalf("CreateImage error: %v", err) } } // handleEditImageEndpoint Handles the images endpoint by the test server. func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { var resBytes []byte // imagess only accepts POST requests if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } responses := ImageResponse{ Created: time.Now().Unix(), Data: []ImageResponseDataInner{ { URL: "test-url1", B64JSON: "", }, { URL: "test-url2", B64JSON: "", }, { URL: "test-url3", B64JSON: "", }, }, } resBytes, _ = json.Marshal(responses) fmt.Fprintln(w, string(resBytes)) } func TestImageVariation(t *testing.T) { server := test.NewTestServer() server.RegisterHandler("/v1/images/edits", handleVariateImageEndpoint) // create the test server var err error ts := server.OpenAITestServer() ts.Start() defer ts.Close() config := DefaultConfig(test.GetTestToken()) config.BaseURL = ts.URL + "/v1" client := NewClientWithConfig(config) ctx := context.Background() origin, err := os.Create("image.png") if err != nil { t.Error("open origin file error") return } defer func() { origin.Close() os.Remove("image.png") }() req := ImageVariationRequest{ Image: origin, N: 3, Size: CreateImageSize1024x1024, } _, err = client.CreateVariateImage(ctx, req) if err != nil { t.Fatalf("CreateImage error: %v", err) } } // handleVariateImageEndpoint Handles the images endpoint by the test server. func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { var resBytes []byte // imagess only accepts POST requests if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } responses := ImageResponse{ Created: time.Now().Unix(), Data: []ImageResponseDataInner{ { URL: "test-url1", B64JSON: "", }, { URL: "test-url2", B64JSON: "", }, { URL: "test-url3", B64JSON: "", }, }, } resBytes, _ = json.Marshal(responses) fmt.Fprintln(w, string(resBytes)) }