package openai_test

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"strings"
	"testing"
	"time"

	. "github.com/sashabaranov/go-openai"
	"github.com/sashabaranov/go-openai/internal/test"
)

// TestCompletions tests the completions endpoint of the API using the mocked server.
func TestCompletions(t *testing.T) {
	server := test.NewTestServer()
	server.RegisterHandler("/v1/completions", handleCompletionEndpoint)
	// create the test server
	ts := server.OpenAITestServer()
	ts.Start()
	defer ts.Close()

	config := DefaultConfig(test.GetTestToken())
	config.BaseURL = ts.URL + "/v1"
	client := NewClientWithConfig(config)
	ctx := context.Background()

	req := CompletionRequest{
		MaxTokens: 5,
		Model:     "ada",
	}
	req.Prompt = "Lorem ipsum"
	if _, err := client.CreateCompletion(ctx, req); err != nil {
		t.Fatalf("CreateCompletion error: %v", err)
	}
}

// handleCompletionEndpoint handles the completion endpoint for the test server.
//
// For a POST request it decodes the CompletionRequest body and replies with a
// CompletionResponse containing completionReq.N choices. Each choice is
// MaxTokens "a" characters, prefixed with the prompt when Echo is set. Usage
// figures are derived from the request. Non-POST requests get 405, an
// unreadable body or negative MaxTokens gets an error status, and the handler
// stops after any error response.
func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
	// completions only accepts POST requests
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		// Without this return the handler would keep going and write a second
		// (success) body after the 405 response.
		return
	}

	completionReq, err := getCompletionBody(r)
	if err != nil {
		http.Error(w, "could not read request", http.StatusInternalServerError)
		return
	}
	// strings.Repeat panics on a negative count; reject such input instead of
	// crashing the handler.
	if completionReq.MaxTokens < 0 {
		http.Error(w, "max_tokens must not be negative", http.StatusBadRequest)
		return
	}

	// Read the clock once so ID and Created always agree.
	now := time.Now().Unix()
	res := CompletionResponse{
		ID:      strconv.FormatInt(now, 10),
		Object:  "test-object",
		Created: now,
		// would be nice to validate Model during testing, but
		// this may not be possible with how much upkeep
		// would be required / wouldn't make much sense
		Model: completionReq.Model,
	}

	// Create completionReq.N deterministic placeholder completions of
	// MaxTokens "a" characters each, echoing the prompt first when requested.
	for i := 0; i < completionReq.N; i++ {
		completionStr := strings.Repeat("a", completionReq.MaxTokens)
		if completionReq.Echo {
			completionStr = completionReq.Prompt + completionStr
		}
		res.Choices = append(res.Choices, CompletionChoice{
			Text:  completionStr,
			Index: i,
		})
	}

	inputTokens := numTokens(completionReq.Prompt) * completionReq.N
	completionTokens := completionReq.MaxTokens * completionReq.N
	res.Usage = Usage{
		PromptTokens:     inputTokens,
		CompletionTokens: completionTokens,
		TotalTokens:      inputTokens + completionTokens,
	}

	resBytes, err := json.Marshal(res)
	if err != nil {
		http.Error(w, "could not encode response", http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	fmt.Fprintln(w, string(resBytes))
}

// getCompletionBody reads the whole request body and decodes it as JSON into
// a CompletionRequest. On failure it returns the zero CompletionRequest and
// an error describing which step failed.
func getCompletionBody(r *http.Request) (CompletionRequest, error) {
	reqBody, err := io.ReadAll(r.Body)
	if err != nil {
		return CompletionRequest{}, fmt.Errorf("reading request body: %w", err)
	}

	var completion CompletionRequest
	if err := json.Unmarshal(reqBody, &completion); err != nil {
		return CompletionRequest{}, fmt.Errorf("decoding completion request: %w", err)
	}
	return completion, nil
}