Migrate From Old Completions + Embedding Endpoint (#28)

* migrate away from deprecated OpenAI endpoints

Signed-off-by: Oleg <97077423+RobotSail@users.noreply.github.com>

* test embedding correctness

Signed-off-by: Oleg <97077423+RobotSail@users.noreply.github.com>
This commit is contained in:
Oleg
2022-08-02 04:16:54 -04:00
committed by GitHub
parent 51f94a6ab3
commit 53212c71df
5 changed files with 69 additions and 37 deletions

View File

@@ -27,10 +27,11 @@ func main() {
ctx := context.Background()
req := gogpt.CompletionRequest{
Model: "ada",
MaxTokens: 5,
Prompt: "Lorem ipsum",
}
resp, err := c.CreateCompletion(ctx, "ada", req)
resp, err := c.CreateCompletion(ctx, req)
if err != nil {
return
}

View File

@@ -1,7 +1,9 @@
package gogpt
import (
"bytes"
"context"
"encoding/json"
"io/ioutil"
"testing"
)
@@ -36,9 +38,12 @@ func TestAPI(t *testing.T) {
}
} // else skip
req := CompletionRequest{MaxTokens: 5}
req := CompletionRequest{
MaxTokens: 5,
Model: "ada",
}
req.Prompt = "Lorem ipsum"
_, err = c.CreateCompletion(ctx, "ada", req)
_, err = c.CreateCompletion(ctx, req)
if err != nil {
t.Fatalf("CreateCompletion error: %v", err)
}
@@ -57,9 +62,49 @@ func TestAPI(t *testing.T) {
"The food was delicious and the waiter",
"Other examples of embedding request",
},
Model: AdaSearchQuery,
}
_, err = c.CreateEmbeddings(ctx, embeddingReq, AdaSearchQuery)
_, err = c.CreateEmbeddings(ctx, embeddingReq)
if err != nil {
t.Fatalf("Embedding error: %v", err)
}
}
func TestEmbedding(t *testing.T) {
embeddedModels := []EmbeddingModel{
AdaSimilarity,
BabbageSimilarity,
CurieSimilarity,
DavinciSimilarity,
AdaSearchDocument,
AdaSearchQuery,
BabbageSearchDocument,
BabbageSearchQuery,
CurieSearchDocument,
CurieSearchQuery,
DavinciSearchDocument,
DavinciSearchQuery,
AdaCodeSearchCode,
AdaCodeSearchText,
BabbageCodeSearchCode,
BabbageCodeSearchText,
}
for _, model := range embeddedModels {
embeddingReq := EmbeddingRequest{
Input: []string{
"The food was delicious and the waiter",
"Other examples of embedding request",
},
Model: model,
}
// marshal embeddingReq to JSON and confirm that the model field equals
// the AdaSearchQuery type
marshaled, err := json.Marshal(embeddingReq)
if err != nil {
t.Fatalf("Could not marshal embedding request: %v", err)
}
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
t.Fatalf("Expected embedding request to contain model field")
}
}
}

View File

@@ -4,13 +4,12 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
)
// CompletionRequest represents a request structure for completion API
type CompletionRequest struct {
Model *string `json:"model,omitempty"`
Model string `json:"model"`
Prompt string `json:"prompt,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
@@ -60,29 +59,12 @@ type CompletionResponse struct {
Usage CompletionUsage `json:"usage"`
}
// CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well as, if requested, the probabilities over each alternative token at each position.
func (c *Client) CreateCompletion(ctx context.Context, engineID string, request CompletionRequest) (response CompletionResponse, err error) {
var reqBytes []byte
reqBytes, err = json.Marshal(request)
if err != nil {
return
}
urlSuffix := fmt.Sprintf("/engines/%s/completions", engineID)
req, err := http.NewRequest("POST", c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
if err != nil {
return
}
req = req.WithContext(ctx)
err = c.sendRequest(req, &response)
return
}
// CreateCompletionWithFineTunedModel - API call to create a completion with a fine tuned model
// See https://beta.openai.com/docs/guides/fine-tuning/use-a-fine-tuned-model
// In this case, the model is specified in the CompletionRequest object.
func (c *Client) CreateCompletionWithFineTunedModel(ctx context.Context, request CompletionRequest) (response CompletionResponse, err error) {
// CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well
// as, if requested, the probabilities over each alternative token at each position.
//
// If using a fine-tuned model, simply provide the model's ID in the CompletionRequest object,
// and the server will use the model's parameters to generate the completion.
func (c *Client) CreateCompletion(ctx context.Context, request CompletionRequest) (response CompletionResponse, err error) {
var reqBytes []byte
reqBytes, err = json.Marshal(request)
if err != nil {

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
)
@@ -120,18 +119,23 @@ type EmbeddingRequest struct {
// E.g.
// "The food was delicious and the waiter..."
Input []string `json:"input"`
// ID of the model to use. You can use the List models API to see all of your available models,
// or see our Model overview for descriptions of them.
Model EmbeddingModel `json:"model"`
// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
User string `json:"user"`
}
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
// https://beta.openai.com/docs/api-reference/embeddings/create
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest, model EmbeddingModel) (resp EmbeddingResponse, err error) {
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
var reqBytes []byte
reqBytes, err = json.Marshal(request)
if err != nil {
return
}
urlSuffix := fmt.Sprintf("/engines/%s/embeddings", model)
urlSuffix := "/embeddings"
req, err := http.NewRequest(http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
if err != nil {
return

View File

@@ -17,11 +17,11 @@ import (
*/
type SearchRequest struct {
Query string `json:"query"`
Documents []string `json:"documents"` // 1*
FileID string `json:"file"` // 1*
MaxRerank int `json:"max_rerank"` // 2*
ReturnMetadata bool `json:"return_metadata"`
User string `json:"user"`
Documents []string `json:"documents"` // 1*
FileID string `json:"file,omitempty"` // 1*
MaxRerank int `json:"max_rerank,omitempty"` // 2*
ReturnMetadata bool `json:"return_metadata,omitempty"`
User string `json:"user,omitempty"`
}
// SearchResult represents single result from search API