feat: implement new fine tuning job API (#479)

* feat: implement new fine tuning job API

* fix: export ListFineTuningJobEventsParameter

* fix: lint errors

* fix: test errors

* fix: code test coverage

* fix: code test coverage

* fix: use any

* chore: use url.Values
This commit is contained in:
Simone Vellei
2023-08-29 14:04:27 +02:00
committed by GitHub
parent a14bc103f4
commit a2ca01bb6d
3 changed files with 255 additions and 0 deletions

View File

@@ -223,6 +223,18 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
{"ListFineTuneEvents", func() (any, error) { {"ListFineTuneEvents", func() (any, error) {
return client.ListFineTuneEvents(ctx, "") return client.ListFineTuneEvents(ctx, "")
}}, }},
{"CreateFineTuningJob", func() (any, error) {
return client.CreateFineTuningJob(ctx, FineTuningJobRequest{})
}},
{"CancelFineTuningJob", func() (any, error) {
return client.CancelFineTuningJob(ctx, "")
}},
{"RetrieveFineTuningJob", func() (any, error) {
return client.RetrieveFineTuningJob(ctx, "")
}},
{"ListFineTuningJobEvents", func() (any, error) {
return client.ListFineTuningJobEvents(ctx, "")
}},
{"Moderations", func() (any, error) { {"Moderations", func() (any, error) {
return client.Moderations(ctx, ModerationRequest{}) return client.Moderations(ctx, ModerationRequest{})
}}, }},

153
fine_tuning_job.go Normal file
View File

@@ -0,0 +1,153 @@
package openai
import (
"context"
"fmt"
"net/http"
"net/url"
)
type FineTuningJob struct {
ID string `json:"id"`
Object string `json:"object"`
CreatedAt int64 `json:"created_at"`
FinishedAt int64 `json:"finished_at"`
Model string `json:"model"`
FineTunedModel string `json:"fine_tuned_model,omitempty"`
OrganizationID string `json:"organization_id"`
Status string `json:"status"`
Hyperparameters Hyperparameters `json:"hyperparameters"`
TrainingFile string `json:"training_file"`
ValidationFile string `json:"validation_file,omitempty"`
ResultFiles []string `json:"result_files"`
TrainedTokens int `json:"trained_tokens"`
}
type Hyperparameters struct {
Epochs int `json:"n_epochs"`
}
type FineTuningJobRequest struct {
TrainingFile string `json:"training_file"`
ValidationFile string `json:"validation_file,omitempty"`
Model string `json:"model,omitempty"`
Hyperparameters *Hyperparameters `json:"hyperparameters,omitempty"`
Suffix string `json:"suffix,omitempty"`
}
type FineTuningJobEventList struct {
Object string `json:"object"`
Data []FineTuneEvent `json:"data"`
HasMore bool `json:"has_more"`
}
type FineTuningJobEvent struct {
Object string `json:"object"`
ID string `json:"id"`
CreatedAt int `json:"created_at"`
Level string `json:"level"`
Message string `json:"message"`
Data any `json:"data"`
Type string `json:"type"`
}
// CreateFineTuningJob create a fine tuning job.
func (c *Client) CreateFineTuningJob(
ctx context.Context,
request FineTuningJobRequest,
) (response FineTuningJob, err error) {
urlSuffix := "/fine_tuning/jobs"
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request))
if err != nil {
return
}
err = c.sendRequest(req, &response)
return
}
// CancelFineTuningJob cancel a fine tuning job.
func (c *Client) CancelFineTuningJob(ctx context.Context, fineTuningJobID string) (response FineTuningJob, err error) {
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/cancel"))
if err != nil {
return
}
err = c.sendRequest(req, &response)
return
}
// RetrieveFineTuningJob retrieve a fine tuning job.
func (c *Client) RetrieveFineTuningJob(
ctx context.Context,
fineTuningJobID string,
) (response FineTuningJob, err error) {
urlSuffix := fmt.Sprintf("/fine_tuning/jobs/%s", fineTuningJobID)
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
if err != nil {
return
}
err = c.sendRequest(req, &response)
return
}
type listFineTuningJobEventsParameters struct {
after *string
limit *int
}
type ListFineTuningJobEventsParameter func(*listFineTuningJobEventsParameters)
func ListFineTuningJobEventsWithAfter(after string) ListFineTuningJobEventsParameter {
return func(args *listFineTuningJobEventsParameters) {
args.after = &after
}
}
func ListFineTuningJobEventsWithLimit(limit int) ListFineTuningJobEventsParameter {
return func(args *listFineTuningJobEventsParameters) {
args.limit = &limit
}
}
// ListFineTuningJobs list fine tuning jobs events.
func (c *Client) ListFineTuningJobEvents(
ctx context.Context,
fineTuningJobID string,
setters ...ListFineTuningJobEventsParameter,
) (response FineTuningJobEventList, err error) {
parameters := &listFineTuningJobEventsParameters{
after: nil,
limit: nil,
}
for _, setter := range setters {
setter(parameters)
}
urlValues := url.Values{}
if parameters.after != nil {
urlValues.Add("after", *parameters.after)
}
if parameters.limit != nil {
urlValues.Add("limit", fmt.Sprintf("%d", *parameters.limit))
}
encodedValues := ""
if len(urlValues) > 0 {
encodedValues = "?" + urlValues.Encode()
}
req, err := c.newRequest(
ctx,
http.MethodGet,
c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/events"+encodedValues),
)
if err != nil {
return
}
err = c.sendRequest(req, &response)
return
}

90
fine_tuning_job_test.go Normal file
View File

@@ -0,0 +1,90 @@
package openai_test
import (
"context"
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
"encoding/json"
"fmt"
"net/http"
"testing"
)
const testFineTuninigJobID = "fine-tuning-job-id"
// TestFineTuningJob Tests the fine tuning job endpoint of the API using the mocked server.
func TestFineTuningJob(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler(
"/v1/fine_tuning/jobs",
func(w http.ResponseWriter, r *http.Request) {
var resBytes []byte
resBytes, _ = json.Marshal(FineTuningJob{})
fmt.Fprintln(w, string(resBytes))
},
)
server.RegisterHandler(
"/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel",
func(w http.ResponseWriter, r *http.Request) {
resBytes, _ := json.Marshal(FineTuningJob{})
fmt.Fprintln(w, string(resBytes))
},
)
server.RegisterHandler(
"/v1/fine_tuning/jobs/"+testFineTuninigJobID,
func(w http.ResponseWriter, r *http.Request) {
var resBytes []byte
resBytes, _ = json.Marshal(FineTuningJob{})
fmt.Fprintln(w, string(resBytes))
},
)
server.RegisterHandler(
"/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/events",
func(w http.ResponseWriter, r *http.Request) {
resBytes, _ := json.Marshal(FineTuningJobEventList{})
fmt.Fprintln(w, string(resBytes))
},
)
ctx := context.Background()
_, err := client.CreateFineTuningJob(ctx, FineTuningJobRequest{})
checks.NoError(t, err, "CreateFineTuningJob error")
_, err = client.CancelFineTuningJob(ctx, testFineTuninigJobID)
checks.NoError(t, err, "CancelFineTuningJob error")
_, err = client.RetrieveFineTuningJob(ctx, testFineTuninigJobID)
checks.NoError(t, err, "RetrieveFineTuningJob error")
_, err = client.ListFineTuningJobEvents(ctx, testFineTuninigJobID)
checks.NoError(t, err, "ListFineTuningJobEvents error")
_, err = client.ListFineTuningJobEvents(
ctx,
testFineTuninigJobID,
ListFineTuningJobEventsWithAfter("last-event-id"),
)
checks.NoError(t, err, "ListFineTuningJobEvents error")
_, err = client.ListFineTuningJobEvents(
ctx,
testFineTuninigJobID,
ListFineTuningJobEventsWithLimit(10),
)
checks.NoError(t, err, "ListFineTuningJobEvents error")
_, err = client.ListFineTuningJobEvents(
ctx,
testFineTuninigJobID,
ListFineTuningJobEventsWithAfter("last-event-id"),
ListFineTuningJobEventsWithLimit(10),
)
checks.NoError(t, err, "ListFineTuningJobEvents error")
}