Files
go-openai/completion_test.go
Chris Hua f22da8a7ed feat: allow more input types to functions, fix tests (#377)
* feat: use json.rawMessage, test functions

* chore: lint

* fix: tests

Otherwise the ChatCompletion mock server doesn't actually run: N=0
is the request default, but the server treats it as n=1.

* fix: tests should default to n=1 completions

* chore: add back removed interfaces, custom marshal

* chore: lint

* chore: lint

* chore: add some tests

* chore: appease lint

* clean up JSON schema + tests

* chore: lint

* feat: remove backwards compatible functions

for illustrative purposes

* fix: revert params change

* chore: use interface{}

* chore: add test

* chore: add back FunctionDefine

* chore: s/interface{}/any/

* chore: add back jsonschemadefinition

* chore: testcov

* chore: lint

* chore: remove pointers

* chore: update comment

* chore: address CR

added test for compatibility as well

---------

Co-authored-by: James <jmacwhyte@MacBooger-II.local>
2023-06-21 16:58:27 +04:00

126 lines
3.5 KiB
Go

package openai_test
import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"testing"
"time"
)
// TestCompletionsWrongModel checks that CreateCompletion, when asked to use
// GPT3Dot5Turbo, fails with ErrCompletionUnsupportedModel.
func TestCompletionsWrongModel(t *testing.T) {
	cfg := DefaultConfig("whatever")
	cfg.BaseURL = "http://localhost/v1"
	c := NewClientWithConfig(cfg)

	badModelReq := CompletionRequest{
		MaxTokens: 5,
		Model:     GPT3Dot5Turbo,
	}
	_, gotErr := c.CreateCompletion(context.Background(), badModelReq)
	if errors.Is(gotErr, ErrCompletionUnsupportedModel) {
		return
	}
	t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", gotErr)
}
func TestCompletionWithStream(t *testing.T) {
config := DefaultConfig("whatever")
client := NewClientWithConfig(config)
ctx := context.Background()
req := CompletionRequest{Stream: true}
_, err := client.CreateCompletion(ctx, req)
if !errors.Is(err, ErrCompletionStreamNotSupported) {
t.Fatalf("CreateCompletion didn't return ErrCompletionStreamNotSupported")
}
}
// TestCompletions runs a basic completion request against the mocked
// OpenAI test server, with handleCompletionEndpoint serving
// /v1/completions, and expects CreateCompletion to return no error.
func TestCompletions(t *testing.T) {
	mockClient, mockServer, stop := setupOpenAITestServer()
	defer stop()
	mockServer.RegisterHandler("/v1/completions", handleCompletionEndpoint)

	_, callErr := mockClient.CreateCompletion(context.Background(), CompletionRequest{
		Model:     "ada",
		Prompt:    "Lorem ipsum",
		MaxTokens: 5,
	})
	checks.NoError(t, callErr, "CreateCompletion error")
}
// handleCompletionEndpoint Handles the completion endpoint by the test server.
func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte
// completions only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
var completionReq CompletionRequest
if completionReq, err = getCompletionBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}
res := CompletionResponse{
ID: strconv.Itoa(int(time.Now().Unix())),
Object: "test-object",
Created: time.Now().Unix(),
// would be nice to validate Model during testing, but
// this may not be possible with how much upkeep
// would be required / wouldn't make much sense
Model: completionReq.Model,
}
// create completions
n := completionReq.N
if n == 0 {
n = 1
}
for i := 0; i < n; i++ {
// generate a random string of length completionReq.Length
completionStr := strings.Repeat("a", completionReq.MaxTokens)
if completionReq.Echo {
completionStr = completionReq.Prompt.(string) + completionStr
}
res.Choices = append(res.Choices, CompletionChoice{
Text: completionStr,
Index: i,
})
}
inputTokens := numTokens(completionReq.Prompt.(string)) * n
completionTokens := completionReq.MaxTokens * n
res.Usage = Usage{
PromptTokens: inputTokens,
CompletionTokens: completionTokens,
TotalTokens: inputTokens + completionTokens,
}
resBytes, _ = json.Marshal(res)
fmt.Fprintln(w, string(resBytes))
}
// getCompletionBody reads the entire request body and JSON-decodes it into a
// CompletionRequest. On any read or decode failure it returns the zero
// CompletionRequest together with the underlying error.
func getCompletionBody(r *http.Request) (CompletionRequest, error) {
	raw, readErr := io.ReadAll(r.Body)
	if readErr != nil {
		return CompletionRequest{}, readErr
	}

	var parsed CompletionRequest
	if decodeErr := json.Unmarshal(raw, &parsed); decodeErr != nil {
		return CompletionRequest{}, decodeErr
	}
	return parsed, nil
}