simplify unmarshal (#191)

* simplify unmarshal

* simplify unmarshalError

* rename errorAccumulate -> defaultErrorAccumulator

* update converage
This commit is contained in:
sashabaranov
2023-03-22 09:56:05 +04:00
committed by GitHub
parent a5a945ad14
commit eb68a72bcc
3 changed files with 29 additions and 22 deletions

View File

@@ -8,7 +8,7 @@ import (
type errorAccumulator interface { type errorAccumulator interface {
write(p []byte) error write(p []byte) error
unmarshalError() (*ErrorResponse, error) unmarshalError() *ErrorResponse
} }
type errorBuffer interface { type errorBuffer interface {
@@ -17,19 +17,19 @@ type errorBuffer interface {
Bytes() []byte Bytes() []byte
} }
type errorAccumulate struct { type defaultErrorAccumulator struct {
buffer errorBuffer buffer errorBuffer
unmarshaler unmarshaler unmarshaler unmarshaler
} }
func newErrorAccumulator() errorAccumulator { func newErrorAccumulator() errorAccumulator {
return &errorAccumulate{ return &defaultErrorAccumulator{
buffer: &bytes.Buffer{}, buffer: &bytes.Buffer{},
unmarshaler: &jsonUnmarshaler{}, unmarshaler: &jsonUnmarshaler{},
} }
} }
func (e *errorAccumulate) write(p []byte) error { func (e *defaultErrorAccumulator) write(p []byte) error {
_, err := e.buffer.Write(p) _, err := e.buffer.Write(p)
if err != nil { if err != nil {
return fmt.Errorf("error accumulator write error, %w", err) return fmt.Errorf("error accumulator write error, %w", err)
@@ -37,15 +37,15 @@ func (e *errorAccumulate) write(p []byte) error {
return nil return nil
} }
func (e *errorAccumulate) unmarshalError() (*ErrorResponse, error) { func (e *defaultErrorAccumulator) unmarshalError() (errResp *ErrorResponse) {
var err error if e.buffer.Len() == 0 {
if e.buffer.Len() > 0 { return
var errRes ErrorResponse }
err = e.unmarshaler.unmarshal(e.buffer.Bytes(), &errRes)
err := e.unmarshaler.unmarshal(e.buffer.Bytes(), &errResp)
if err != nil { if err != nil {
return nil, err errResp = nil
} }
return &errRes, nil
} return
return nil, err
} }

View File

@@ -36,23 +36,29 @@ func (*failingUnMarshaller) unmarshal(_ []byte, _ any) error {
} }
func TestErrorAccumulatorReturnsUnmarshalerErrors(t *testing.T) { func TestErrorAccumulatorReturnsUnmarshalerErrors(t *testing.T) {
accumulator := &errorAccumulate{ accumulator := &defaultErrorAccumulator{
buffer: &bytes.Buffer{}, buffer: &bytes.Buffer{},
unmarshaler: &failingUnMarshaller{}, unmarshaler: &failingUnMarshaller{},
} }
respErr := accumulator.unmarshalError()
if respErr != nil {
t.Fatalf("Did not return nil with empty buffer: %v", respErr)
}
err := accumulator.write([]byte("{")) err := accumulator.write([]byte("{"))
if err != nil { if err != nil {
t.Fatalf("%+v", err) t.Fatalf("%+v", err)
} }
_, err = accumulator.unmarshalError()
if !errors.Is(err, errTestUnmarshalerFailed) { respErr = accumulator.unmarshalError()
t.Fatalf("Did not return error when unmarshaler failed: %v", err) if respErr != nil {
t.Fatalf("Did not return nil when unmarshaler failed: %v", respErr)
} }
} }
func TestErrorByteWriteErrors(t *testing.T) { func TestErrorByteWriteErrors(t *testing.T) {
accumulator := &errorAccumulate{ accumulator := &defaultErrorAccumulator{
buffer: &failingErrorBuffer{}, buffer: &failingErrorBuffer{},
unmarshaler: &jsonUnmarshaler{}, unmarshaler: &jsonUnmarshaler{},
} }
@@ -78,7 +84,7 @@ func TestErrorAccumulatorWriteErrors(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
stream.errAccumulator = &errorAccumulate{ stream.errAccumulator = &defaultErrorAccumulator{
buffer: &failingErrorBuffer{}, buffer: &failingErrorBuffer{},
unmarshaler: &jsonUnmarshaler{}, unmarshaler: &jsonUnmarshaler{},
} }

View File

@@ -33,8 +33,9 @@ func (stream *streamReader[T]) Recv() (response T, err error) {
waitForData: waitForData:
line, err := stream.reader.ReadBytes('\n') line, err := stream.reader.ReadBytes('\n')
if err != nil { if err != nil {
if errRes, _ := stream.errAccumulator.unmarshalError(); errRes != nil { respErr := stream.errAccumulator.unmarshalError()
err = fmt.Errorf("error, %w", errRes.Error) if respErr != nil {
err = fmt.Errorf("error, %w", respErr.Error)
} }
return return
} }