Files
go-openai/image_test.go
JoyShi 21eef5bc8d Move form_builder into internal pkg. (#311)
* Move form_uilder into internal pkg.

* Fix import of audio.go

* Reorganize.

* Fix import.

* Fix

---------

Co-authored-by: JoyShi <joy.shi@sap.com>
2023-05-16 13:38:09 -07:00

406 lines
10 KiB
Go

package openai //nolint:testpackage // testing private field
import (
utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"testing"
"time"
)
func TestImages(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/images/generations", handleImageEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()
req := ImageRequest{}
req.Prompt = "Lorem ipsum"
_, err = client.CreateImage(ctx, req)
checks.NoError(t, err, "CreateImage error")
}
// handleImageEndpoint Handles the images endpoint by the test server.
func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte
// imagess only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
var imageReq ImageRequest
if imageReq, err = getImageBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}
res := ImageResponse{
Created: time.Now().Unix(),
}
for i := 0; i < imageReq.N; i++ {
imageData := ImageResponseDataInner{}
switch imageReq.ResponseFormat {
case CreateImageResponseFormatURL, "":
imageData.URL = "https://example.com/image.png"
case CreateImageResponseFormatB64JSON:
// This decodes to "{}" in base64.
imageData.B64JSON = "e30K"
default:
http.Error(w, "invalid response format", http.StatusBadRequest)
return
}
res.Data = append(res.Data, imageData)
}
resBytes, _ = json.Marshal(res)
fmt.Fprintln(w, string(resBytes))
}
// getImageBody Returns the body of the request to create a image.
func getImageBody(r *http.Request) (ImageRequest, error) {
image := ImageRequest{}
// read the request body
reqBody, err := io.ReadAll(r.Body)
if err != nil {
return ImageRequest{}, err
}
err = json.Unmarshal(reqBody, &image)
if err != nil {
return ImageRequest{}, err
}
return image, nil
}
func TestImageEdit(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()
origin, err := os.Create("image.png")
if err != nil {
t.Error("open origin file error")
return
}
mask, err := os.Create("mask.png")
if err != nil {
t.Error("open mask file error")
return
}
defer func() {
mask.Close()
origin.Close()
os.Remove("mask.png")
os.Remove("image.png")
}()
req := ImageEditRequest{
Image: origin,
Mask: mask,
Prompt: "There is a turtle in the pool",
N: 3,
Size: CreateImageSize1024x1024,
ResponseFormat: CreateImageResponseFormatURL,
}
_, err = client.CreateEditImage(ctx, req)
checks.NoError(t, err, "CreateImage error")
}
func TestImageEditWithoutMask(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()
origin, err := os.Create("image.png")
if err != nil {
t.Error("open origin file error")
return
}
defer func() {
origin.Close()
os.Remove("image.png")
}()
req := ImageEditRequest{
Image: origin,
Prompt: "There is a turtle in the pool",
N: 3,
Size: CreateImageSize1024x1024,
ResponseFormat: CreateImageResponseFormatURL,
}
_, err = client.CreateEditImage(ctx, req)
checks.NoError(t, err, "CreateImage error")
}
// handleEditImageEndpoint Handles the images endpoint by the test server.
func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) {
var resBytes []byte
// imagess only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
responses := ImageResponse{
Created: time.Now().Unix(),
Data: []ImageResponseDataInner{
{
URL: "test-url1",
B64JSON: "",
},
{
URL: "test-url2",
B64JSON: "",
},
{
URL: "test-url3",
B64JSON: "",
},
},
}
resBytes, _ = json.Marshal(responses)
fmt.Fprintln(w, string(resBytes))
}
func TestImageVariation(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint)
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()
origin, err := os.Create("image.png")
if err != nil {
t.Error("open origin file error")
return
}
defer func() {
origin.Close()
os.Remove("image.png")
}()
req := ImageVariRequest{
Image: origin,
N: 3,
Size: CreateImageSize1024x1024,
ResponseFormat: CreateImageResponseFormatURL,
}
_, err = client.CreateVariImage(ctx, req)
checks.NoError(t, err, "CreateImage error")
}
// handleVariateImageEndpoint Handles the images endpoint by the test server.
func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) {
var resBytes []byte
// imagess only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
responses := ImageResponse{
Created: time.Now().Unix(),
Data: []ImageResponseDataInner{
{
URL: "test-url1",
B64JSON: "",
},
{
URL: "test-url2",
B64JSON: "",
},
{
URL: "test-url3",
B64JSON: "",
},
},
}
resBytes, _ = json.Marshal(responses)
fmt.Fprintln(w, string(resBytes))
}
type mockFormBuilder struct {
mockCreateFormFile func(string, *os.File) error
mockWriteField func(string, string) error
mockClose func() error
}
func (fb *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
return fb.mockCreateFormFile(fieldname, file)
}
func (fb *mockFormBuilder) WriteField(fieldname, value string) error {
return fb.mockWriteField(fieldname, value)
}
func (fb *mockFormBuilder) Close() error {
return fb.mockClose()
}
func (fb *mockFormBuilder) FormDataContentType() string {
return ""
}
func TestImageFormBuilderFailures(t *testing.T) {
config := DefaultConfig("")
config.BaseURL = ""
client := NewClientWithConfig(config)
mockBuilder := &mockFormBuilder{}
client.createFormBuilder = func(io.Writer) utils.FormBuilder {
return mockBuilder
}
ctx := context.Background()
req := ImageEditRequest{
Mask: &os.File{},
}
mockFailedErr := fmt.Errorf("mock form builder fail")
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
return mockFailedErr
}
_, err := client.CreateEditImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
mockBuilder.mockCreateFormFile = func(name string, file *os.File) error {
if name == "mask" {
return mockFailedErr
}
return nil
}
_, err = client.CreateEditImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
mockBuilder.mockCreateFormFile = func(name string, file *os.File) error {
return nil
}
var failForField string
mockBuilder.mockWriteField = func(fieldname, value string) error {
if fieldname == failForField {
return mockFailedErr
}
return nil
}
failForField = "prompt"
_, err = client.CreateEditImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
failForField = "n"
_, err = client.CreateEditImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
failForField = "size"
_, err = client.CreateEditImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
failForField = "response_format"
_, err = client.CreateEditImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
failForField = ""
mockBuilder.mockClose = func() error {
return mockFailedErr
}
_, err = client.CreateEditImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
}
func TestVariImageFormBuilderFailures(t *testing.T) {
config := DefaultConfig("")
config.BaseURL = ""
client := NewClientWithConfig(config)
mockBuilder := &mockFormBuilder{}
client.createFormBuilder = func(io.Writer) utils.FormBuilder {
return mockBuilder
}
ctx := context.Background()
req := ImageVariRequest{}
mockFailedErr := fmt.Errorf("mock form builder fail")
mockBuilder.mockCreateFormFile = func(string, *os.File) error {
return mockFailedErr
}
_, err := client.CreateVariImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails")
mockBuilder.mockCreateFormFile = func(name string, file *os.File) error {
return nil
}
var failForField string
mockBuilder.mockWriteField = func(fieldname, value string) error {
if fieldname == failForField {
return mockFailedErr
}
return nil
}
failForField = "n"
_, err = client.CreateVariImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails")
failForField = "size"
_, err = client.CreateVariImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails")
failForField = "response_format"
_, err = client.CreateVariImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails")
failForField = ""
mockBuilder.mockClose = func() error {
return mockFailedErr
}
_, err = client.CreateVariImage(ctx, req)
checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
}