move error_accumulator into internal pkg (#304) (#335)

* move error_accumulator into internal pkg (#304)

* move error_accumulator into internal pkg (#304)

* add a test for ErrTooManyEmptyStreamMessages in stream_reader (#304)
This commit is contained in:
渡邉祐一 / Yuichi Watanabe
2023-06-05 23:35:46 +09:00
committed by GitHub
parent fa694c61c2
commit 1394329e44
12 changed files with 249 additions and 201 deletions

View File

@@ -66,7 +66,7 @@ func (c *Client) CreateChatCompletionStream(
emptyMessagesLimit: c.config.EmptyMessagesLimit, emptyMessagesLimit: c.config.EmptyMessagesLimit,
reader: bufio.NewReader(resp.Body), reader: bufio.NewReader(resp.Body),
response: resp, response: resp,
errAccumulator: newErrorAccumulator(), errAccumulator: utils.NewErrorAccumulator(),
unmarshaler: &utils.JSONUnmarshaler{}, unmarshaler: &utils.JSONUnmarshaler{},
}, },
} }

View File

@@ -1,7 +1,7 @@
package openai_test package openai //nolint:testpackage // testing private field
import ( import (
. "github.com/sashabaranov/go-openai" utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/internal/test/checks"
@@ -63,9 +63,9 @@ func TestCreateChatCompletionStream(t *testing.T) {
// Client portion of the test // Client portion of the test
config := DefaultConfig(test.GetTestToken()) config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1" config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{ config.HTTPClient.Transport = &test.TokenRoundTripper{
test.GetTestToken(), Token: test.GetTestToken(),
http.DefaultTransport, Fallback: http.DefaultTransport,
} }
client := NewClientWithConfig(config) client := NewClientWithConfig(config)
@@ -170,9 +170,9 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
// Client portion of the test // Client portion of the test
config := DefaultConfig(test.GetTestToken()) config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1" config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{ config.HTTPClient.Transport = &test.TokenRoundTripper{
test.GetTestToken(), Token: test.GetTestToken(),
http.DefaultTransport, Fallback: http.DefaultTransport,
} }
client := NewClientWithConfig(config) client := NewClientWithConfig(config)
@@ -227,9 +227,9 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
// Client portion of the test // Client portion of the test
config := DefaultConfig(test.GetTestToken()) config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1" config.BaseURL = ts.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{ config.HTTPClient.Transport = &test.TokenRoundTripper{
test.GetTestToken(), Token: test.GetTestToken(),
http.DefaultTransport, Fallback: http.DefaultTransport,
} }
client := NewClientWithConfig(config) client := NewClientWithConfig(config)
@@ -255,6 +255,33 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
t.Logf("%+v\n", apiErr) t.Logf("%+v\n", apiErr)
} }
func TestCreateChatCompletionStreamErrorAccumulatorWriteErrors(t *testing.T) {
var err error
server := test.NewTestServer()
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "error", 200)
})
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()
stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
checks.NoError(t, err)
stream.errAccumulator = &utils.DefaultErrorAccumulator{
Buffer: &test.FailingErrorBuffer{},
}
_, err = stream.Recv()
checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when Write failed", err.Error())
}
// Helper funcs. // Helper funcs.
func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {

View File

@@ -1,53 +0,0 @@
package openai
import (
"bytes"
"fmt"
"io"
utils "github.com/sashabaranov/go-openai/internal"
)
type errorAccumulator interface {
write(p []byte) error
unmarshalError() *ErrorResponse
}
type errorBuffer interface {
io.Writer
Len() int
Bytes() []byte
}
type defaultErrorAccumulator struct {
buffer errorBuffer
unmarshaler utils.Unmarshaler
}
func newErrorAccumulator() errorAccumulator {
return &defaultErrorAccumulator{
buffer: &bytes.Buffer{},
unmarshaler: &utils.JSONUnmarshaler{},
}
}
func (e *defaultErrorAccumulator) write(p []byte) error {
_, err := e.buffer.Write(p)
if err != nil {
return fmt.Errorf("error accumulator write error, %w", err)
}
return nil
}
func (e *defaultErrorAccumulator) unmarshalError() (errResp *ErrorResponse) {
if e.buffer.Len() == 0 {
return
}
err := e.unmarshaler.Unmarshal(e.buffer.Bytes(), &errResp)
if err != nil {
errResp = nil
}
return
}

View File

@@ -1,100 +0,0 @@
package openai //nolint:testpackage // testing private field
import (
"bytes"
"context"
"errors"
"net/http"
"testing"
utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
)
var (
errTestUnmarshalerFailed = errors.New("test unmarshaler failed")
errTestErrorAccumulatorWriteFailed = errors.New("test error accumulator failed")
)
type (
failingUnMarshaller struct{}
failingErrorBuffer struct{}
)
func (b *failingErrorBuffer) Write(_ []byte) (n int, err error) {
return 0, errTestErrorAccumulatorWriteFailed
}
func (b *failingErrorBuffer) Len() int {
return 0
}
func (b *failingErrorBuffer) Bytes() []byte {
return []byte{}
}
func (*failingUnMarshaller) Unmarshal(_ []byte, _ any) error {
return errTestUnmarshalerFailed
}
func TestErrorAccumulatorReturnsUnmarshalerErrors(t *testing.T) {
accumulator := &defaultErrorAccumulator{
buffer: &bytes.Buffer{},
unmarshaler: &failingUnMarshaller{},
}
respErr := accumulator.unmarshalError()
if respErr != nil {
t.Fatalf("Did not return nil with empty buffer: %v", respErr)
}
err := accumulator.write([]byte("{"))
if err != nil {
t.Fatalf("%+v", err)
}
respErr = accumulator.unmarshalError()
if respErr != nil {
t.Fatalf("Did not return nil when unmarshaler failed: %v", respErr)
}
}
func TestErrorByteWriteErrors(t *testing.T) {
accumulator := &defaultErrorAccumulator{
buffer: &failingErrorBuffer{},
unmarshaler: &utils.JSONUnmarshaler{},
}
err := accumulator.write([]byte("{"))
if !errors.Is(err, errTestErrorAccumulatorWriteFailed) {
t.Fatalf("Did not return error when write failed: %v", err)
}
}
func TestErrorAccumulatorWriteErrors(t *testing.T) {
var err error
server := test.NewTestServer()
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "error", 200)
})
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()
config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()
stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
checks.NoError(t, err)
stream.errAccumulator = &defaultErrorAccumulator{
buffer: &failingErrorBuffer{},
unmarshaler: &utils.JSONUnmarshaler{},
}
_, err = stream.Recv()
checks.ErrorIs(t, err, errTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
}

View File

@@ -0,0 +1,44 @@
package openai
import (
"bytes"
"fmt"
"io"
)
type ErrorAccumulator interface {
Write(p []byte) error
Bytes() []byte
}
type errorBuffer interface {
io.Writer
Len() int
Bytes() []byte
}
type DefaultErrorAccumulator struct {
Buffer errorBuffer
}
func NewErrorAccumulator() ErrorAccumulator {
return &DefaultErrorAccumulator{
Buffer: &bytes.Buffer{},
}
}
func (e *DefaultErrorAccumulator) Write(p []byte) error {
_, err := e.Buffer.Write(p)
if err != nil {
return fmt.Errorf("error accumulator write error, %w", err)
}
return nil
}
func (e *DefaultErrorAccumulator) Bytes() (errBytes []byte) {
if e.Buffer.Len() == 0 {
return
}
errBytes = e.Buffer.Bytes()
return
}

View File

@@ -0,0 +1,41 @@
package openai_test
import (
"bytes"
"errors"
"testing"
utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
)
func TestErrorAccumulatorBytes(t *testing.T) {
accumulator := &utils.DefaultErrorAccumulator{
Buffer: &bytes.Buffer{},
}
errBytes := accumulator.Bytes()
if len(errBytes) != 0 {
t.Fatalf("Did not return nil with empty bytes: %s", string(errBytes))
}
err := accumulator.Write([]byte("{}"))
if err != nil {
t.Fatalf("%+v", err)
}
errBytes = accumulator.Bytes()
if len(errBytes) == 0 {
t.Fatalf("Did not return error bytes when has error: %s", string(errBytes))
}
}
func TestErrorByteWriteErrors(t *testing.T) {
accumulator := &utils.DefaultErrorAccumulator{
Buffer: &test.FailingErrorBuffer{},
}
err := accumulator.Write([]byte("{"))
if !errors.Is(err, test.ErrTestErrorAccumulatorWriteFailed) {
t.Fatalf("Did not return error when write failed: %v", err)
}
}

21
internal/test/failer.go Normal file
View File

@@ -0,0 +1,21 @@
package test
import "errors"
var (
ErrTestErrorAccumulatorWriteFailed = errors.New("test error accumulator failed")
)
type FailingErrorBuffer struct{}
func (b *FailingErrorBuffer) Write(_ []byte) (n int, err error) {
return 0, ErrTestErrorAccumulatorWriteFailed
}
func (b *FailingErrorBuffer) Len() int {
return 0
}
func (b *FailingErrorBuffer) Bytes() []byte {
return []byte{}
}

View File

@@ -3,6 +3,7 @@ package test
import ( import (
"github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/internal/test/checks"
"net/http"
"os" "os"
"testing" "testing"
) )
@@ -27,3 +28,26 @@ func CreateTestDirectory(t *testing.T) (path string, cleanup func()) {
return path, func() { os.RemoveAll(path) } return path, func() { os.RemoveAll(path) }
} }
// TokenRoundTripper is a struct that implements the RoundTripper
// interface, specifically to handle the authentication token by adding a token
// to the request header. We need this because the API requires that each
// request include a valid API token in the headers for authentication and
// authorization.
type TokenRoundTripper struct {
Token string
Fallback http.RoundTripper
}
// RoundTrip takes an *http.Request as input and returns an
// *http.Response and an error.
//
// It is expected to use the provided request to create a connection to an HTTP
// server and return the response, or an error if one occurred. The returned
// Response should have its Body closed. If the RoundTrip method returns an
// error, the Client's Get, Head, Post, and PostForm methods return the same
// error.
func (t *TokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Set("Authorization", "Bearer "+t.Token)
return t.Fallback.RoundTrip(req)
}

View File

@@ -55,7 +55,7 @@ func (c *Client) CreateCompletionStream(
emptyMessagesLimit: c.config.EmptyMessagesLimit, emptyMessagesLimit: c.config.EmptyMessagesLimit,
reader: bufio.NewReader(resp.Body), reader: bufio.NewReader(resp.Body),
response: resp, response: resp,
errAccumulator: newErrorAccumulator(), errAccumulator: utils.NewErrorAccumulator(),
unmarshaler: &utils.JSONUnmarshaler{}, unmarshaler: &utils.JSONUnmarshaler{},
}, },
} }

View File

@@ -20,7 +20,7 @@ type streamReader[T streamable] struct {
reader *bufio.Reader reader *bufio.Reader
response *http.Response response *http.Response
errAccumulator errorAccumulator errAccumulator utils.ErrorAccumulator
unmarshaler utils.Unmarshaler unmarshaler utils.Unmarshaler
} }
@@ -35,7 +35,7 @@ func (stream *streamReader[T]) Recv() (response T, err error) {
waitForData: waitForData:
line, err := stream.reader.ReadBytes('\n') line, err := stream.reader.ReadBytes('\n')
if err != nil { if err != nil {
respErr := stream.errAccumulator.unmarshalError() respErr := stream.unmarshalError()
if respErr != nil { if respErr != nil {
err = fmt.Errorf("error, %w", respErr.Error) err = fmt.Errorf("error, %w", respErr.Error)
} }
@@ -45,7 +45,7 @@ waitForData:
var headerData = []byte("data: ") var headerData = []byte("data: ")
line = bytes.TrimSpace(line) line = bytes.TrimSpace(line)
if !bytes.HasPrefix(line, headerData) { if !bytes.HasPrefix(line, headerData) {
if writeErr := stream.errAccumulator.write(line); writeErr != nil { if writeErr := stream.errAccumulator.Write(line); writeErr != nil {
err = writeErr err = writeErr
return return
} }
@@ -69,6 +69,20 @@ waitForData:
return return
} }
func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) {
errBytes := stream.errAccumulator.Bytes()
if len(errBytes) == 0 {
return
}
err := stream.unmarshaler.Unmarshal(errBytes, &errResp)
if err != nil {
errResp = nil
}
return
}
func (stream *streamReader[T]) Close() { func (stream *streamReader[T]) Close() {
stream.response.Body.Close() stream.response.Body.Close()
} }

53
stream_reader_test.go Normal file
View File

@@ -0,0 +1,53 @@
package openai //nolint:testpackage // testing private field
import (
"bufio"
"bytes"
"errors"
"testing"
utils "github.com/sashabaranov/go-openai/internal"
)
var errTestUnmarshalerFailed = errors.New("test unmarshaler failed")
type failingUnMarshaller struct{}
func (*failingUnMarshaller) Unmarshal(_ []byte, _ any) error {
return errTestUnmarshalerFailed
}
func TestStreamReaderReturnsUnmarshalerErrors(t *testing.T) {
stream := &streamReader[ChatCompletionStreamResponse]{
errAccumulator: utils.NewErrorAccumulator(),
unmarshaler: &failingUnMarshaller{},
}
respErr := stream.unmarshalError()
if respErr != nil {
t.Fatalf("Did not return nil with empty buffer: %v", respErr)
}
err := stream.errAccumulator.Write([]byte("{"))
if err != nil {
t.Fatalf("%+v", err)
}
respErr = stream.unmarshalError()
if respErr != nil {
t.Fatalf("Did not return nil when unmarshaler failed: %v", respErr)
}
}
func TestStreamReaderReturnsErrTooManyEmptyStreamMessages(t *testing.T) {
stream := &streamReader[ChatCompletionStreamResponse]{
emptyMessagesLimit: 3,
reader: bufio.NewReader(bytes.NewReader([]byte("\n\n\n\n"))),
errAccumulator: utils.NewErrorAccumulator(),
unmarshaler: &utils.JSONUnmarshaler{},
}
_, err := stream.Recv()
if !errors.Is(err, ErrTooManyEmptyStreamMessages) {
t.Fatalf("Did not return error when recv failed: %v", err)
}
}

View File

@@ -57,9 +57,9 @@ func TestCreateCompletionStream(t *testing.T) {
// Client portion of the test // Client portion of the test
config := DefaultConfig(test.GetTestToken()) config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1" config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{ config.HTTPClient.Transport = &test.TokenRoundTripper{
test.GetTestToken(), Token: test.GetTestToken(),
http.DefaultTransport, Fallback: http.DefaultTransport,
} }
client := NewClientWithConfig(config) client := NewClientWithConfig(config)
@@ -142,9 +142,9 @@ func TestCreateCompletionStreamError(t *testing.T) {
// Client portion of the test // Client portion of the test
config := DefaultConfig(test.GetTestToken()) config := DefaultConfig(test.GetTestToken())
config.BaseURL = server.URL + "/v1" config.BaseURL = server.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{ config.HTTPClient.Transport = &test.TokenRoundTripper{
test.GetTestToken(), Token: test.GetTestToken(),
http.DefaultTransport, Fallback: http.DefaultTransport,
} }
client := NewClientWithConfig(config) client := NewClientWithConfig(config)
@@ -194,9 +194,9 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) {
// Client portion of the test // Client portion of the test
config := DefaultConfig(test.GetTestToken()) config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1" config.BaseURL = ts.URL + "/v1"
config.HTTPClient.Transport = &tokenRoundTripper{ config.HTTPClient.Transport = &test.TokenRoundTripper{
test.GetTestToken(), Token: test.GetTestToken(),
http.DefaultTransport, Fallback: http.DefaultTransport,
} }
client := NewClientWithConfig(config) client := NewClientWithConfig(config)
@@ -217,29 +217,6 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) {
t.Logf("%+v\n", apiErr) t.Logf("%+v\n", apiErr)
} }
// A "tokenRoundTripper" is a struct that implements the RoundTripper
// interface, specifically to handle the authentication token by adding a token
// to the request header. We need this because the API requires that each
// request include a valid API token in the headers for authentication and
// authorization.
type tokenRoundTripper struct {
token string
fallback http.RoundTripper
}
// RoundTrip takes an *http.Request as input and returns an
// *http.Response and an error.
//
// It is expected to use the provided request to create a connection to an HTTP
// server and return the response, or an error if one occurred. The returned
// Response should have its Body closed. If the RoundTrip method returns an
// error, the Client's Get, Head, Post, and PostForm methods return the same
// error.
func (t *tokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Set("Authorization", "Bearer "+t.token)
return t.fallback.RoundTrip(req)
}
// Helper funcs. // Helper funcs.
func compareResponses(r1, r2 CompletionResponse) bool { func compareResponses(r1, r2 CompletionResponse) bool {
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {