From b8c13e4c017ab031ede870da538c0b0e3cf4996b Mon Sep 17 00:00:00 2001 From: Yuki Bobier Koshimizu Date: Fri, 9 Jun 2023 00:31:25 +0900 Subject: [PATCH] Refactor streamReader: Replace goto Statement with Loop in Recv Method (#339) * test: Add tests for improved coverage before refactoring This commit adds tests to improve coverage before refactoring to ensure that the changes do not break anything. * refactor: replace goto statement with loop This commit introduces a refactor to improve the clarity of the control flow within the method. The goto statement can sometimes make the code hard to understand and maintain, hence this refactor aims to resolve that. * refactor: extract for-loop from Recv to another method This commit improves code readability and maintainability by making the Recv method simpler. --- stream_reader.go | 71 ++++++++++++--------- stream_test.go | 160 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 200 insertions(+), 31 deletions(-) diff --git a/stream_reader.go b/stream_reader.go index a9940b0..3416198 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -30,43 +30,52 @@ func (stream *streamReader[T]) Recv() (response T, err error) { return } + response, err = stream.processLines() + return +} + +func (stream *streamReader[T]) processLines() (T, error) { var emptyMessagesCount uint -waitForData: - line, err := stream.reader.ReadBytes('\n') - if err != nil { - respErr := stream.unmarshalError() - if respErr != nil { - err = fmt.Errorf("error, %w", respErr.Error) - } - return - } - - var headerData = []byte("data: ") - line = bytes.TrimSpace(line) - if !bytes.HasPrefix(line, headerData) { - if writeErr := stream.errAccumulator.Write(line); writeErr != nil { - err = writeErr - return - } - emptyMessagesCount++ - if emptyMessagesCount > stream.emptyMessagesLimit { - err = ErrTooManyEmptyStreamMessages - return + for { + rawLine, readErr := stream.reader.ReadBytes('\n') + if readErr != nil { + respErr := stream.unmarshalError() + 
if respErr != nil { + return *new(T), fmt.Errorf("error, %w", respErr.Error) + } + return *new(T), readErr } - goto waitForData - } + var headerData = []byte("data: ") + noSpaceLine := bytes.TrimSpace(rawLine) + if !bytes.HasPrefix(noSpaceLine, headerData) { + writeErr := stream.errAccumulator.Write(noSpaceLine) + if writeErr != nil { + return *new(T), writeErr + } + emptyMessagesCount++ + if emptyMessagesCount > stream.emptyMessagesLimit { + return *new(T), ErrTooManyEmptyStreamMessages + } - line = bytes.TrimPrefix(line, headerData) - if string(line) == "[DONE]" { - stream.isFinished = true - err = io.EOF - return - } + continue + } - err = stream.unmarshaler.Unmarshal(line, &response) - return + noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData) + if string(noPrefixLine) == "[DONE]" { + stream.isFinished = true + return *new(T), io.EOF + } + + var response T + unmarshalErr := stream.unmarshaler.Unmarshal(noPrefixLine, &response) + if unmarshalErr != nil { + return *new(T), unmarshalErr + } + + return response, nil + } } func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) { diff --git a/stream_test.go b/stream_test.go index 589fc9e..0faa212 100644 --- a/stream_test.go +++ b/stream_test.go @@ -2,6 +2,7 @@ package openai_test import ( "context" + "encoding/json" "errors" "io" "net/http" @@ -217,6 +218,165 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { t.Logf("%+v\n", apiErr) } +func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte{} + dataBytes = append(dataBytes, []byte("event: message\n")...) 
+ //nolint:lll + data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + // 301 non-data lines in total precede the next data line (300 is the limit) + for i := 0; i < 299; i++ { + dataBytes = append(dataBytes, '\n') + } + + dataBytes = append(dataBytes, []byte("event: message\n")...) + //nolint:lll + data = `{"id":"2","object":"completion","created":1598069255,"model":"text-davinci-002","choices":[{"text":"response2","finish_reason":"max_tokens"}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("event: done\n")...) + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + })) + defer server.Close() + + // Client portion of the test + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = server.URL + "/v1" + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, + } + + client := NewClientWithConfig(config) + ctx := context.Background() + + request := CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: "text-davinci-002", + MaxTokens: 10, + Stream: true, + } + + stream, err := client.CreateCompletionStream(ctx, request) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + _, _ = stream.Recv() + _, streamErr := stream.Recv() + if !errors.Is(streamErr, ErrTooManyEmptyStreamMessages) { + t.Errorf("TestCreateCompletionStreamTooManyEmptyStreamMessagesError did not return ErrTooManyEmptyStreamMessages") + } +} + +func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := 
[]byte{} + dataBytes = append(dataBytes, []byte("event: message\n")...) + //nolint:lll + data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + // Stream is terminated without sending "done" message + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + })) + defer server.Close() + + // Client portion of the test + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = server.URL + "/v1" + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, + } + + client := NewClientWithConfig(config) + ctx := context.Background() + + request := CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: "text-davinci-002", + MaxTokens: 10, + Stream: true, + } + + stream, err := client.CreateCompletionStream(ctx, request) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + _, _ = stream.Recv() + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("TestCreateCompletionStreamUnexpectedTerminatedError did not return io.EOF") + } +} + +func TestCreateCompletionStreamBrokenJSONError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte{} + dataBytes = append(dataBytes, []byte("event: message\n")...) + //nolint:lll + data := `{"id":"1","object":"completion","created":1598069254,"model":"text-davinci-002","choices":[{"text":"response1","finish_reason":"max_tokens"}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + // Send broken json + dataBytes = append(dataBytes, []byte("event: message\n")...) 
+ data = `{"id":"2","object":"completion","created":1598069255,"model":` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("event: done\n")...) + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + })) + defer server.Close() + + // Client portion of the test + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = server.URL + "/v1" + config.HTTPClient.Transport = &test.TokenRoundTripper{ + Token: test.GetTestToken(), + Fallback: http.DefaultTransport, + } + + client := NewClientWithConfig(config) + ctx := context.Background() + + request := CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: "text-davinci-002", + MaxTokens: 10, + Stream: true, + } + + stream, err := client.CreateCompletionStream(ctx, request) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + _, _ = stream.Recv() + _, streamErr := stream.Recv() + var syntaxError *json.SyntaxError + if !errors.As(streamErr, &syntaxError) { + t.Errorf("TestCreateCompletionStreamBrokenJSONError did not return json.SyntaxError") + } +} + // Helper funcs. func compareResponses(r1, r2 CompletionResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {