maintain underlying error structs to allow for type conversion (#293)
* maintain underlying error structs to allow for type conversion and defensive error checking * allow Error.Is for Azure responses * update readme, add tests to ensure type conversion * fix whitespacing * read me * add import to readme example
This commit is contained in:
34
README.md
34
README.md
@@ -10,13 +10,13 @@ This library provides Go clients for [OpenAI API](https://platform.openai.com/).
|
|||||||
* DALL·E 2
|
* DALL·E 2
|
||||||
* Whisper
|
* Whisper
|
||||||
|
|
||||||
Installation:
|
### Installation:
|
||||||
```
|
```
|
||||||
go get github.com/sashabaranov/go-openai
|
go get github.com/sashabaranov/go-openai
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
ChatGPT example usage:
|
### ChatGPT example usage:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
package main
|
package main
|
||||||
@@ -52,9 +52,7 @@ func main() {
|
|||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Other examples:
|
||||||
|
|
||||||
Other examples:
|
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary>ChatGPT streaming completion</summary>
|
<summary>ChatGPT streaming completion</summary>
|
||||||
@@ -462,3 +460,29 @@ func main() {
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Error handling</summary>
|
||||||
|
|
||||||
|
Open-AI maintains clear documentation on how to [handle API errors](https://platform.openai.com/docs/guides/error-codes/api-errors)
|
||||||
|
|
||||||
|
example:
|
||||||
|
```
|
||||||
|
e := &openai.APIError{}
|
||||||
|
if errors.As(err, &e) {
|
||||||
|
switch e.HTTPStatusCode {
|
||||||
|
case 401:
|
||||||
|
// invalid auth or key (do not retry)
|
||||||
|
case 429:
|
||||||
|
// rate limiting or engine overload (wait and retry)
|
||||||
|
case 500:
|
||||||
|
// openai server error (retry)
|
||||||
|
default:
|
||||||
|
// unhandled
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -149,15 +149,16 @@ func (c *Client) handleErrorResp(resp *http.Response) error {
|
|||||||
var errRes ErrorResponse
|
var errRes ErrorResponse
|
||||||
err := json.NewDecoder(resp.Body).Decode(&errRes)
|
err := json.NewDecoder(resp.Body).Decode(&errRes)
|
||||||
if err != nil || errRes.Error == nil {
|
if err != nil || errRes.Error == nil {
|
||||||
reqErr := RequestError{
|
reqErr := &RequestError{
|
||||||
HTTPStatusCode: resp.StatusCode,
|
HTTPStatusCode: resp.StatusCode,
|
||||||
Err: err,
|
Err: err,
|
||||||
}
|
}
|
||||||
if errRes.Error != nil {
|
if errRes.Error != nil {
|
||||||
reqErr.Err = errRes.Error
|
reqErr.Err = errRes.Error
|
||||||
}
|
}
|
||||||
return fmt.Errorf("error, %w", &reqErr)
|
return reqErr
|
||||||
}
|
}
|
||||||
|
|
||||||
errRes.Error.HTTPStatusCode = resp.StatusCode
|
errRes.Error.HTTPStatusCode = resp.StatusCode
|
||||||
return fmt.Errorf("error, status code: %d, message: %w", resp.StatusCode, errRes.Error)
|
return errRes.Error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package openai //nolint:testpackage // testing private field
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -106,7 +107,7 @@ func TestHandleErrorResp(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}`,
|
}`,
|
||||||
)),
|
)),
|
||||||
expected: "error, status code 401, message: Access denied due to Virtual Network/Firewall rules.",
|
expected: "error, status code: 401, message: Access denied due to Virtual Network/Firewall rules.",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "503 Model Overloaded",
|
name: "503 Model Overloaded",
|
||||||
@@ -135,6 +136,12 @@ func TestHandleErrorResp(t *testing.T) {
|
|||||||
t.Errorf("Unexpected error: %v , expected: %s", err, tc.expected)
|
t.Errorf("Unexpected error: %v , expected: %s", err, tc.expected)
|
||||||
t.Fail()
|
t.Fail()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
e := &APIError{}
|
||||||
|
if !errors.As(err, &e) {
|
||||||
|
t.Errorf("(%s) Expected error to be of type APIError", tc.name)
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
6
error.go
6
error.go
@@ -25,6 +25,10 @@ type ErrorResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *APIError) Error() string {
|
func (e *APIError) Error() string {
|
||||||
|
if e.HTTPStatusCode > 0 {
|
||||||
|
return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Message)
|
||||||
|
}
|
||||||
|
|
||||||
return e.Message
|
return e.Message
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,7 +74,7 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *RequestError) Error() string {
|
func (e *RequestError) Error() string {
|
||||||
return fmt.Sprintf("status code %d, message: %s", e.HTTPStatusCode, e.Err)
|
return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *RequestError) Unwrap() error {
|
func (e *RequestError) Unwrap() error {
|
||||||
|
|||||||
Reference in New Issue
Block a user