committed by
GitHub
parent
980504b47e
commit
62eb4beed2
@@ -4,6 +4,8 @@ import (
|
||||
"bufio"
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
utils "github.com/sashabaranov/go-openai/internal"
|
||||
)
|
||||
|
||||
type ChatCompletionStreamChoiceDelta struct {
|
||||
@@ -65,7 +67,7 @@ func (c *Client) CreateChatCompletionStream(
|
||||
reader: bufio.NewReader(resp.Body),
|
||||
response: resp,
|
||||
errAccumulator: newErrorAccumulator(),
|
||||
unmarshaler: &jsonUnmarshaler{},
|
||||
unmarshaler: &utils.JSONUnmarshaler{},
|
||||
},
|
||||
}
|
||||
return
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
utils "github.com/sashabaranov/go-openai/internal"
|
||||
)
|
||||
|
||||
type errorAccumulator interface {
|
||||
@@ -19,13 +21,13 @@ type errorBuffer interface {
|
||||
|
||||
type defaultErrorAccumulator struct {
|
||||
buffer errorBuffer
|
||||
unmarshaler unmarshaler
|
||||
unmarshaler utils.Unmarshaler
|
||||
}
|
||||
|
||||
func newErrorAccumulator() errorAccumulator {
|
||||
return &defaultErrorAccumulator{
|
||||
buffer: &bytes.Buffer{},
|
||||
unmarshaler: &jsonUnmarshaler{},
|
||||
unmarshaler: &utils.JSONUnmarshaler{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,7 +44,7 @@ func (e *defaultErrorAccumulator) unmarshalError() (errResp *ErrorResponse) {
|
||||
return
|
||||
}
|
||||
|
||||
err := e.unmarshaler.unmarshal(e.buffer.Bytes(), &errResp)
|
||||
err := e.unmarshaler.Unmarshal(e.buffer.Bytes(), &errResp)
|
||||
if err != nil {
|
||||
errResp = nil
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
utils "github.com/sashabaranov/go-openai/internal"
|
||||
"github.com/sashabaranov/go-openai/internal/test"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
)
|
||||
@@ -33,7 +34,7 @@ func (b *failingErrorBuffer) Bytes() []byte {
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
func (*failingUnMarshaller) unmarshal(_ []byte, _ any) error {
|
||||
func (*failingUnMarshaller) Unmarshal(_ []byte, _ any) error {
|
||||
return errTestUnmarshalerFailed
|
||||
}
|
||||
|
||||
@@ -62,7 +63,7 @@ func TestErrorAccumulatorReturnsUnmarshalerErrors(t *testing.T) {
|
||||
func TestErrorByteWriteErrors(t *testing.T) {
|
||||
accumulator := &defaultErrorAccumulator{
|
||||
buffer: &failingErrorBuffer{},
|
||||
unmarshaler: &jsonUnmarshaler{},
|
||||
unmarshaler: &utils.JSONUnmarshaler{},
|
||||
}
|
||||
err := accumulator.write([]byte("{"))
|
||||
if !errors.Is(err, errTestErrorAccumulatorWriteFailed) {
|
||||
@@ -91,7 +92,7 @@ func TestErrorAccumulatorWriteErrors(t *testing.T) {
|
||||
|
||||
stream.errAccumulator = &defaultErrorAccumulator{
|
||||
buffer: &failingErrorBuffer{},
|
||||
unmarshaler: &jsonUnmarshaler{},
|
||||
unmarshaler: &utils.JSONUnmarshaler{},
|
||||
}
|
||||
|
||||
_, err = stream.Recv()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package openai //nolint:testpackage // testing private field
|
||||
|
||||
import (
|
||||
. "github.com/sashabaranov/go-openai/internal"
|
||||
utils "github.com/sashabaranov/go-openai/internal"
|
||||
"github.com/sashabaranov/go-openai/internal/test"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
|
||||
@@ -86,7 +86,7 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) {
|
||||
config.BaseURL = ""
|
||||
client := NewClientWithConfig(config)
|
||||
mockBuilder := &mockFormBuilder{}
|
||||
client.createFormBuilder = func(io.Writer) FormBuilder {
|
||||
client.createFormBuilder = func(io.Writer) utils.FormBuilder {
|
||||
return mockBuilder
|
||||
}
|
||||
|
||||
|
||||
15
internal/marshaller.go
Normal file
15
internal/marshaller.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type Marshaller interface {
|
||||
Marshal(value any) ([]byte, error)
|
||||
}
|
||||
|
||||
type JSONMarshaller struct{}
|
||||
|
||||
func (jm *JSONMarshaller) Marshal(value any) ([]byte, error) {
|
||||
return json.Marshal(value)
|
||||
}
|
||||
15
internal/unmarshaler.go
Normal file
15
internal/unmarshaler.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type Unmarshaler interface {
|
||||
Unmarshal(data []byte, v any) error
|
||||
}
|
||||
|
||||
type JSONUnmarshaler struct{}
|
||||
|
||||
func (jm *JSONUnmarshaler) Unmarshal(data []byte, v any) error {
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type marshaller interface {
|
||||
marshal(value any) ([]byte, error)
|
||||
}
|
||||
|
||||
type jsonMarshaller struct{}
|
||||
|
||||
func (jm *jsonMarshaller) marshal(value any) ([]byte, error) {
|
||||
return json.Marshal(value)
|
||||
}
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
utils "github.com/sashabaranov/go-openai/internal"
|
||||
)
|
||||
|
||||
type requestBuilder interface {
|
||||
@@ -11,12 +13,12 @@ type requestBuilder interface {
|
||||
}
|
||||
|
||||
type httpRequestBuilder struct {
|
||||
marshaller marshaller
|
||||
marshaller utils.Marshaller
|
||||
}
|
||||
|
||||
func newRequestBuilder() *httpRequestBuilder {
|
||||
return &httpRequestBuilder{
|
||||
marshaller: &jsonMarshaller{},
|
||||
marshaller: &utils.JSONMarshaller{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +28,7 @@ func (b *httpRequestBuilder) build(ctx context.Context, method, url string, requ
|
||||
}
|
||||
|
||||
var reqBytes []byte
|
||||
reqBytes, err := b.marshaller.marshal(request)
|
||||
reqBytes, err := b.marshaller.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ type (
|
||||
failingMarshaller struct{}
|
||||
)
|
||||
|
||||
func (*failingMarshaller) marshal(_ any) ([]byte, error) {
|
||||
func (*failingMarshaller) Marshal(_ any) ([]byte, error) {
|
||||
return []byte{}, errTestMarshallerFailed
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
utils "github.com/sashabaranov/go-openai/internal"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -54,7 +56,7 @@ func (c *Client) CreateCompletionStream(
|
||||
reader: bufio.NewReader(resp.Body),
|
||||
response: resp,
|
||||
errAccumulator: newErrorAccumulator(),
|
||||
unmarshaler: &jsonUnmarshaler{},
|
||||
unmarshaler: &utils.JSONUnmarshaler{},
|
||||
},
|
||||
}
|
||||
return
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
utils "github.com/sashabaranov/go-openai/internal"
|
||||
)
|
||||
|
||||
type streamable interface {
|
||||
@@ -19,7 +21,7 @@ type streamReader[T streamable] struct {
|
||||
reader *bufio.Reader
|
||||
response *http.Response
|
||||
errAccumulator errorAccumulator
|
||||
unmarshaler unmarshaler
|
||||
unmarshaler utils.Unmarshaler
|
||||
}
|
||||
|
||||
func (stream *streamReader[T]) Recv() (response T, err error) {
|
||||
@@ -63,7 +65,7 @@ waitForData:
|
||||
return
|
||||
}
|
||||
|
||||
err = stream.unmarshaler.unmarshal(line, &response)
|
||||
err = stream.unmarshaler.Unmarshal(line, &response)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type unmarshaler interface {
|
||||
unmarshal(data []byte, v any) error
|
||||
}
|
||||
|
||||
type jsonUnmarshaler struct{}
|
||||
|
||||
func (jm *jsonUnmarshaler) unmarshal(data []byte, v any) error {
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
Reference in New Issue
Block a user