Add api client code
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -13,3 +13,6 @@
|
|||||||
|
|
||||||
# Dependency directories (remove the comment below to include it)
|
# Dependency directories (remove the comment below to include it)
|
||||||
# vendor/
|
# vendor/
|
||||||
|
|
||||||
|
# Auth token for tests
|
||||||
|
.openai-token
|
||||||
55
api.go
Normal file
55
api.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package gogpt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const apiURLv1 = "https://api.openai.com/v1"
|
||||||
|
|
||||||
|
// Client is OpenAI GPT-3 API client
|
||||||
|
type Client struct {
|
||||||
|
BaseURL string
|
||||||
|
authToken string
|
||||||
|
HTTPClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient creates new OpenAI API client
|
||||||
|
func NewClient(authToken string) *Client {
|
||||||
|
return &Client{
|
||||||
|
BaseURL: apiURLv1,
|
||||||
|
authToken: authToken,
|
||||||
|
HTTPClient: &http.Client{
|
||||||
|
Timeout: time.Minute,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) sendRequest(req *http.Request, v interface{}) error {
|
||||||
|
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||||||
|
req.Header.Set("Accept", "application/json; charset=utf-8")
|
||||||
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.authToken))
|
||||||
|
|
||||||
|
res, err := c.HTTPClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
|
||||||
|
return fmt.Errorf("error, status code: %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = json.NewDecoder(res.Body).Decode(&v); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) fullURL(suffix string) string {
|
||||||
|
return fmt.Sprintf("%s%s", c.BaseURL, suffix)
|
||||||
|
}
|
||||||
42
api_test.go
Normal file
42
api_test.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package gogpt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io/ioutil"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAPI(t *testing.T) {
|
||||||
|
tokenBytes, err := ioutil.ReadFile(".openai-token")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Could not load auth token from .openai-token file")
|
||||||
|
}
|
||||||
|
|
||||||
|
c := NewClient(string(tokenBytes))
|
||||||
|
ctx := context.Background()
|
||||||
|
_, err = c.ListEngines(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListEngines error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = c.GetEngine(ctx, "davinci")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetEngine error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := CompletionRequest{MaxTokens: 5}
|
||||||
|
req.Prompt = "Lorem ipsum"
|
||||||
|
_, err = c.CreateCompletion(ctx, "ada", req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateCompletion error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
searchReq := SearchRequest{
|
||||||
|
Documents: []string{"White House", "hospital", "school"},
|
||||||
|
Query: "the president",
|
||||||
|
}
|
||||||
|
_, err = c.Search(ctx, "ada", searchReq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Search error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
60
completion.go
Normal file
60
completion.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
package gogpt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CompletionRequest struct {
|
||||||
|
Prompt string `json:"prompt,omitempty"`
|
||||||
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
|
|
||||||
|
Temperature float32 `json:"temperature,omitempty"`
|
||||||
|
TopP float32 `json:"top_p,omitempty"`
|
||||||
|
|
||||||
|
N int `json:"n,omitempty"`
|
||||||
|
|
||||||
|
LogProbs int `json:"logobs,omitempty"`
|
||||||
|
|
||||||
|
Echo bool `json:"echo,omitempty"`
|
||||||
|
Stop string `json:"stop,omitempty"`
|
||||||
|
|
||||||
|
PresencePenalty float32 `json:"presence_penalty,omitempty"`
|
||||||
|
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Choice struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CompletionResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created uint64 `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Сhoices []Choice `json:"choices"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well as, if requested, the probabilities over each alternative token at each position.
|
||||||
|
func (c *Client) CreateCompletion(ctx context.Context, engineID string, request CompletionRequest) (response CompletionResponse, err error) {
|
||||||
|
var reqBytes []byte
|
||||||
|
reqBytes, err = json.Marshal(request)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
urlSuffix := fmt.Sprintf("/engines/%s/completions", engineID)
|
||||||
|
req, err := http.NewRequest("POST", c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
err = c.sendRequest(req, &response)
|
||||||
|
return
|
||||||
|
}
|
||||||
43
engines.go
Normal file
43
engines.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package gogpt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Engine struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Owner string `json:"owner"`
|
||||||
|
Ready bool `json:"ready"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EnginesList struct {
|
||||||
|
Engines []Engine `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListEngines Lists the currently available engines, and provides basic information about each option such as the owner and availability.
|
||||||
|
func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) {
|
||||||
|
req, err := http.NewRequest("GET", c.fullURL("/engines"), nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
err = c.sendRequest(req, &engines)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEngine Retrieves an engine instance, providing basic information about the engine such as the owner and availability.
|
||||||
|
func (c *Client) GetEngine(ctx context.Context, engineID string) (engine Engine, err error) {
|
||||||
|
urlSuffix := fmt.Sprintf("/engines/%s", engineID)
|
||||||
|
req, err := http.NewRequest("GET", c.fullURL(urlSuffix), nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
err = c.sendRequest(req, &engine)
|
||||||
|
return
|
||||||
|
}
|
||||||
42
search.go
Normal file
42
search.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package gogpt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SearchRequest struct {
|
||||||
|
Documents []string `json:"documents"`
|
||||||
|
Query string `json:"query"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SearchResult struct {
|
||||||
|
Document int `json:"document"`
|
||||||
|
Score float32 `json:"score"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SearchResponse struct {
|
||||||
|
SearchResults []SearchResult `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search — perform a semantic search api call over a list of documents.
|
||||||
|
func (c *Client) Search(ctx context.Context, engineID string, request SearchRequest) (response SearchResponse, err error) {
|
||||||
|
var reqBytes []byte
|
||||||
|
reqBytes, err = json.Marshal(request)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
urlSuffix := fmt.Sprintf("/engines/%s/search", engineID)
|
||||||
|
req, err := http.NewRequest("POST", c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
err = c.sendRequest(req, &response)
|
||||||
|
return
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user