committed by
GitHub
parent
980504b47e
commit
62eb4beed2
@@ -4,6 +4,8 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
utils "github.com/sashabaranov/go-openai/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ChatCompletionStreamChoiceDelta struct {
|
type ChatCompletionStreamChoiceDelta struct {
|
||||||
@@ -65,7 +67,7 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
reader: bufio.NewReader(resp.Body),
|
reader: bufio.NewReader(resp.Body),
|
||||||
response: resp,
|
response: resp,
|
||||||
errAccumulator: newErrorAccumulator(),
|
errAccumulator: newErrorAccumulator(),
|
||||||
unmarshaler: &jsonUnmarshaler{},
|
unmarshaler: &utils.JSONUnmarshaler{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
utils "github.com/sashabaranov/go-openai/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
type errorAccumulator interface {
|
type errorAccumulator interface {
|
||||||
@@ -19,13 +21,13 @@ type errorBuffer interface {
|
|||||||
|
|
||||||
type defaultErrorAccumulator struct {
|
type defaultErrorAccumulator struct {
|
||||||
buffer errorBuffer
|
buffer errorBuffer
|
||||||
unmarshaler unmarshaler
|
unmarshaler utils.Unmarshaler
|
||||||
}
|
}
|
||||||
|
|
||||||
func newErrorAccumulator() errorAccumulator {
|
func newErrorAccumulator() errorAccumulator {
|
||||||
return &defaultErrorAccumulator{
|
return &defaultErrorAccumulator{
|
||||||
buffer: &bytes.Buffer{},
|
buffer: &bytes.Buffer{},
|
||||||
unmarshaler: &jsonUnmarshaler{},
|
unmarshaler: &utils.JSONUnmarshaler{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -42,7 +44,7 @@ func (e *defaultErrorAccumulator) unmarshalError() (errResp *ErrorResponse) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err := e.unmarshaler.unmarshal(e.buffer.Bytes(), &errResp)
|
err := e.unmarshaler.Unmarshal(e.buffer.Bytes(), &errResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errResp = nil
|
errResp = nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
utils "github.com/sashabaranov/go-openai/internal"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
)
|
)
|
||||||
@@ -33,7 +34,7 @@ func (b *failingErrorBuffer) Bytes() []byte {
|
|||||||
return []byte{}
|
return []byte{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*failingUnMarshaller) unmarshal(_ []byte, _ any) error {
|
func (*failingUnMarshaller) Unmarshal(_ []byte, _ any) error {
|
||||||
return errTestUnmarshalerFailed
|
return errTestUnmarshalerFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,7 +63,7 @@ func TestErrorAccumulatorReturnsUnmarshalerErrors(t *testing.T) {
|
|||||||
func TestErrorByteWriteErrors(t *testing.T) {
|
func TestErrorByteWriteErrors(t *testing.T) {
|
||||||
accumulator := &defaultErrorAccumulator{
|
accumulator := &defaultErrorAccumulator{
|
||||||
buffer: &failingErrorBuffer{},
|
buffer: &failingErrorBuffer{},
|
||||||
unmarshaler: &jsonUnmarshaler{},
|
unmarshaler: &utils.JSONUnmarshaler{},
|
||||||
}
|
}
|
||||||
err := accumulator.write([]byte("{"))
|
err := accumulator.write([]byte("{"))
|
||||||
if !errors.Is(err, errTestErrorAccumulatorWriteFailed) {
|
if !errors.Is(err, errTestErrorAccumulatorWriteFailed) {
|
||||||
@@ -91,7 +92,7 @@ func TestErrorAccumulatorWriteErrors(t *testing.T) {
|
|||||||
|
|
||||||
stream.errAccumulator = &defaultErrorAccumulator{
|
stream.errAccumulator = &defaultErrorAccumulator{
|
||||||
buffer: &failingErrorBuffer{},
|
buffer: &failingErrorBuffer{},
|
||||||
unmarshaler: &jsonUnmarshaler{},
|
unmarshaler: &utils.JSONUnmarshaler{},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = stream.Recv()
|
_, err = stream.Recv()
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package openai //nolint:testpackage // testing private field
|
package openai //nolint:testpackage // testing private field
|
||||||
|
|
||||||
import (
|
import (
|
||||||
. "github.com/sashabaranov/go-openai/internal"
|
utils "github.com/sashabaranov/go-openai/internal"
|
||||||
"github.com/sashabaranov/go-openai/internal/test"
|
"github.com/sashabaranov/go-openai/internal/test"
|
||||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||||
|
|
||||||
@@ -86,7 +86,7 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) {
|
|||||||
config.BaseURL = ""
|
config.BaseURL = ""
|
||||||
client := NewClientWithConfig(config)
|
client := NewClientWithConfig(config)
|
||||||
mockBuilder := &mockFormBuilder{}
|
mockBuilder := &mockFormBuilder{}
|
||||||
client.createFormBuilder = func(io.Writer) FormBuilder {
|
client.createFormBuilder = func(io.Writer) utils.FormBuilder {
|
||||||
return mockBuilder
|
return mockBuilder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
15
internal/marshaller.go
Normal file
15
internal/marshaller.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Marshaller interface {
|
||||||
|
Marshal(value any) ([]byte, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type JSONMarshaller struct{}
|
||||||
|
|
||||||
|
func (jm *JSONMarshaller) Marshal(value any) ([]byte, error) {
|
||||||
|
return json.Marshal(value)
|
||||||
|
}
|
||||||
15
internal/unmarshaler.go
Normal file
15
internal/unmarshaler.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Unmarshaler interface {
|
||||||
|
Unmarshal(data []byte, v any) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type JSONUnmarshaler struct{}
|
||||||
|
|
||||||
|
func (jm *JSONUnmarshaler) Unmarshal(data []byte, v any) error {
|
||||||
|
return json.Unmarshal(data, v)
|
||||||
|
}
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
package openai
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
type marshaller interface {
|
|
||||||
marshal(value any) ([]byte, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type jsonMarshaller struct{}
|
|
||||||
|
|
||||||
func (jm *jsonMarshaller) marshal(value any) ([]byte, error) {
|
|
||||||
return json.Marshal(value)
|
|
||||||
}
|
|
||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
utils "github.com/sashabaranov/go-openai/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
type requestBuilder interface {
|
type requestBuilder interface {
|
||||||
@@ -11,12 +13,12 @@ type requestBuilder interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type httpRequestBuilder struct {
|
type httpRequestBuilder struct {
|
||||||
marshaller marshaller
|
marshaller utils.Marshaller
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRequestBuilder() *httpRequestBuilder {
|
func newRequestBuilder() *httpRequestBuilder {
|
||||||
return &httpRequestBuilder{
|
return &httpRequestBuilder{
|
||||||
marshaller: &jsonMarshaller{},
|
marshaller: &utils.JSONMarshaller{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -26,7 +28,7 @@ func (b *httpRequestBuilder) build(ctx context.Context, method, url string, requ
|
|||||||
}
|
}
|
||||||
|
|
||||||
var reqBytes []byte
|
var reqBytes []byte
|
||||||
reqBytes, err := b.marshaller.marshal(request)
|
reqBytes, err := b.marshaller.Marshal(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ type (
|
|||||||
failingMarshaller struct{}
|
failingMarshaller struct{}
|
||||||
)
|
)
|
||||||
|
|
||||||
func (*failingMarshaller) marshal(_ any) ([]byte, error) {
|
func (*failingMarshaller) Marshal(_ any) ([]byte, error) {
|
||||||
return []byte{}, errTestMarshallerFailed
|
return []byte{}, errTestMarshallerFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
utils "github.com/sashabaranov/go-openai/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -54,7 +56,7 @@ func (c *Client) CreateCompletionStream(
|
|||||||
reader: bufio.NewReader(resp.Body),
|
reader: bufio.NewReader(resp.Body),
|
||||||
response: resp,
|
response: resp,
|
||||||
errAccumulator: newErrorAccumulator(),
|
errAccumulator: newErrorAccumulator(),
|
||||||
unmarshaler: &jsonUnmarshaler{},
|
unmarshaler: &utils.JSONUnmarshaler{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
utils "github.com/sashabaranov/go-openai/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
type streamable interface {
|
type streamable interface {
|
||||||
@@ -19,7 +21,7 @@ type streamReader[T streamable] struct {
|
|||||||
reader *bufio.Reader
|
reader *bufio.Reader
|
||||||
response *http.Response
|
response *http.Response
|
||||||
errAccumulator errorAccumulator
|
errAccumulator errorAccumulator
|
||||||
unmarshaler unmarshaler
|
unmarshaler utils.Unmarshaler
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stream *streamReader[T]) Recv() (response T, err error) {
|
func (stream *streamReader[T]) Recv() (response T, err error) {
|
||||||
@@ -63,7 +65,7 @@ waitForData:
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = stream.unmarshaler.unmarshal(line, &response)
|
err = stream.unmarshaler.Unmarshal(line, &response)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +0,0 @@
|
|||||||
package openai
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
type unmarshaler interface {
|
|
||||||
unmarshal(data []byte, v any) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type jsonUnmarshaler struct{}
|
|
||||||
|
|
||||||
func (jm *jsonUnmarshaler) unmarshal(data []byte, v any) error {
|
|
||||||
return json.Unmarshal(data, v)
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user