move marshaller and unmarshaler into internal pkg (#304) (#325)

This commit is contained in:
渡邉祐一 / Yuichi Watanabe
2023-05-28 10:51:07 +09:00
committed by GitHub
parent 980504b47e
commit 62eb4beed2
12 changed files with 57 additions and 46 deletions

View File

@@ -4,6 +4,8 @@ import (
"bufio"
"context"
"net/http"
utils "github.com/sashabaranov/go-openai/internal"
)
type ChatCompletionStreamChoiceDelta struct {
@@ -65,7 +67,7 @@ func (c *Client) CreateChatCompletionStream(
reader: bufio.NewReader(resp.Body),
response: resp,
errAccumulator: newErrorAccumulator(),
unmarshaler: &jsonUnmarshaler{},
unmarshaler: &utils.JSONUnmarshaler{},
},
}
return

View File

@@ -4,6 +4,8 @@ import (
"bytes"
"fmt"
"io"
utils "github.com/sashabaranov/go-openai/internal"
)
type errorAccumulator interface {
@@ -19,13 +21,13 @@ type errorBuffer interface {
type defaultErrorAccumulator struct {
buffer errorBuffer
unmarshaler unmarshaler
unmarshaler utils.Unmarshaler
}
func newErrorAccumulator() errorAccumulator {
return &defaultErrorAccumulator{
buffer: &bytes.Buffer{},
unmarshaler: &jsonUnmarshaler{},
unmarshaler: &utils.JSONUnmarshaler{},
}
}
@@ -42,7 +44,7 @@ func (e *defaultErrorAccumulator) unmarshalError() (errResp *ErrorResponse) {
return
}
err := e.unmarshaler.unmarshal(e.buffer.Bytes(), &errResp)
err := e.unmarshaler.Unmarshal(e.buffer.Bytes(), &errResp)
if err != nil {
errResp = nil
}

View File

@@ -7,6 +7,7 @@ import (
"net/http"
"testing"
utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
)
@@ -33,7 +34,7 @@ func (b *failingErrorBuffer) Bytes() []byte {
return []byte{}
}
func (*failingUnMarshaller) unmarshal(_ []byte, _ any) error {
func (*failingUnMarshaller) Unmarshal(_ []byte, _ any) error {
return errTestUnmarshalerFailed
}
@@ -62,7 +63,7 @@ func TestErrorAccumulatorReturnsUnmarshalerErrors(t *testing.T) {
func TestErrorByteWriteErrors(t *testing.T) {
accumulator := &defaultErrorAccumulator{
buffer: &failingErrorBuffer{},
unmarshaler: &jsonUnmarshaler{},
unmarshaler: &utils.JSONUnmarshaler{},
}
err := accumulator.write([]byte("{"))
if !errors.Is(err, errTestErrorAccumulatorWriteFailed) {
@@ -91,7 +92,7 @@ func TestErrorAccumulatorWriteErrors(t *testing.T) {
stream.errAccumulator = &defaultErrorAccumulator{
buffer: &failingErrorBuffer{},
unmarshaler: &jsonUnmarshaler{},
unmarshaler: &utils.JSONUnmarshaler{},
}
_, err = stream.Recv()

View File

@@ -1,7 +1,7 @@
package openai //nolint:testpackage // testing private field
import (
. "github.com/sashabaranov/go-openai/internal"
utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
@@ -86,7 +86,7 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) {
config.BaseURL = ""
client := NewClientWithConfig(config)
mockBuilder := &mockFormBuilder{}
client.createFormBuilder = func(io.Writer) FormBuilder {
client.createFormBuilder = func(io.Writer) utils.FormBuilder {
return mockBuilder
}

15
internal/marshaller.go Normal file
View File

@@ -0,0 +1,15 @@
package openai
import (
"encoding/json"
)
type Marshaller interface {
Marshal(value any) ([]byte, error)
}
type JSONMarshaller struct{}
func (jm *JSONMarshaller) Marshal(value any) ([]byte, error) {
return json.Marshal(value)
}

15
internal/unmarshaler.go Normal file
View File

@@ -0,0 +1,15 @@
package openai
import (
"encoding/json"
)
type Unmarshaler interface {
Unmarshal(data []byte, v any) error
}
type JSONUnmarshaler struct{}
func (jm *JSONUnmarshaler) Unmarshal(data []byte, v any) error {
return json.Unmarshal(data, v)
}

View File

@@ -1,15 +0,0 @@
package openai
import (
"encoding/json"
)
type marshaller interface {
marshal(value any) ([]byte, error)
}
type jsonMarshaller struct{}
func (jm *jsonMarshaller) marshal(value any) ([]byte, error) {
return json.Marshal(value)
}

View File

@@ -4,6 +4,8 @@ import (
"bytes"
"context"
"net/http"
utils "github.com/sashabaranov/go-openai/internal"
)
type requestBuilder interface {
@@ -11,12 +13,12 @@ type requestBuilder interface {
}
type httpRequestBuilder struct {
marshaller marshaller
marshaller utils.Marshaller
}
func newRequestBuilder() *httpRequestBuilder {
return &httpRequestBuilder{
marshaller: &jsonMarshaller{},
marshaller: &utils.JSONMarshaller{},
}
}
@@ -26,7 +28,7 @@ func (b *httpRequestBuilder) build(ctx context.Context, method, url string, requ
}
var reqBytes []byte
reqBytes, err := b.marshaller.marshal(request)
reqBytes, err := b.marshaller.Marshal(request)
if err != nil {
return nil, err
}

View File

@@ -19,7 +19,7 @@ type (
failingMarshaller struct{}
)
func (*failingMarshaller) marshal(_ any) ([]byte, error) {
func (*failingMarshaller) Marshal(_ any) ([]byte, error) {
return []byte{}, errTestMarshallerFailed
}

View File

@@ -5,6 +5,8 @@ import (
"context"
"errors"
"net/http"
utils "github.com/sashabaranov/go-openai/internal"
)
var (
@@ -54,7 +56,7 @@ func (c *Client) CreateCompletionStream(
reader: bufio.NewReader(resp.Body),
response: resp,
errAccumulator: newErrorAccumulator(),
unmarshaler: &jsonUnmarshaler{},
unmarshaler: &utils.JSONUnmarshaler{},
},
}
return

View File

@@ -6,6 +6,8 @@ import (
"fmt"
"io"
"net/http"
utils "github.com/sashabaranov/go-openai/internal"
)
type streamable interface {
@@ -19,7 +21,7 @@ type streamReader[T streamable] struct {
reader *bufio.Reader
response *http.Response
errAccumulator errorAccumulator
unmarshaler unmarshaler
unmarshaler utils.Unmarshaler
}
func (stream *streamReader[T]) Recv() (response T, err error) {
@@ -63,7 +65,7 @@ waitForData:
return
}
err = stream.unmarshaler.unmarshal(line, &response)
err = stream.unmarshaler.Unmarshal(line, &response)
return
}

View File

@@ -1,15 +0,0 @@
package openai
import (
"encoding/json"
)
type unmarshaler interface {
unmarshal(data []byte, v any) error
}
type jsonUnmarshaler struct{}
func (jm *jsonUnmarshaler) unmarshal(data []byte, v any) error {
return json.Unmarshal(data, v)
}