Skip to content

Commit e113848

Browse files
authored
Merge pull request #837 from aws/fix-unexpected-panic
`rest-json`: updates rest-json error code retriever util
2 parents 5550687 + 3abce74 commit e113848

File tree

2 files changed

+104
-41
lines changed

2 files changed

+104
-41
lines changed

aws/protocol/restjson/decoder_util.go

Lines changed: 21 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,68 +2,48 @@ package restjson
22

33
import (
44
"encoding/json"
5-
"fmt"
65
"io"
76
"strings"
87

98
"github.com/awslabs/smithy-go"
10-
smithyjson "github.com/awslabs/smithy-go/json"
119
)
1210

1311
// GetErrorInfo util looks for code, __type, and message members in the
1412
// json body. These members are optionally available, and the function
1513
// returns the value of member if it is available. This function is useful to
1614
// identify the error code, msg in a REST JSON error response.
1715
func GetErrorInfo(decoder *json.Decoder) (errorType string, message string, err error) {
18-
startToken, err := decoder.Token()
19-
if err == io.EOF {
20-
return "", "", nil
16+
var errInfo struct {
17+
Code string
18+
Type string `json:"__type"`
19+
Message string
2120
}
22-
if err != nil {
23-
return "", "", err
24-
}
25-
26-
if t, ok := startToken.(json.Delim); !ok || t.String() != "{" {
27-
return "", "", fmt.Errorf("expected start token to be {")
28-
}
29-
30-
for decoder.More() {
31-
var target *string
32-
t, err := decoder.Token()
33-
if err != nil {
34-
return "", "", err
35-
}
3621

37-
switch st := t.(string); {
38-
case strings.EqualFold(st, "code"):
39-
fallthrough
40-
case strings.EqualFold(st, "__type"):
41-
target = &errorType
42-
case strings.EqualFold(st, "message"):
43-
target = &message
44-
default:
45-
smithyjson.DiscardUnknownField(decoder)
46-
continue
47-
}
48-
49-
v, err := decoder.Token()
50-
if err != nil {
51-
return errorType, message, err
22+
err = decoder.Decode(&errInfo)
23+
if err != nil {
24+
if err == io.EOF {
25+
return errorType, message, nil
5226
}
53-
*target = v.(string)
27+
return errorType, message, err
5428
}
5529

56-
endToken, err := decoder.Token()
57-
if err != nil {
58-
return "", "", err
30+
// assign error type
31+
if len(errInfo.Code) != 0 {
32+
errorType = errInfo.Code
33+
} else if len(errInfo.Type) != 0 {
34+
errorType = errInfo.Type
5935
}
6036

61-
if t, ok := endToken.(json.Delim); !ok || t.String() != "}" {
62-
return "", "", fmt.Errorf("expected end token to be }")
37+
// assign error message
38+
if len(errInfo.Message) != 0 {
39+
message = errInfo.Message
6340
}
6441

6542
// sanitize error
66-
errorType = SanitizeErrorCode(errorType)
43+
if len(errorType) != 0 {
44+
errorType = SanitizeErrorCode(errorType)
45+
}
46+
6747
return errorType, message, nil
6848
}
6949

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package restjson
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"io"
7+
"strings"
8+
"testing"
9+
)
10+
11+
func TestGetErrorInfo(t *testing.T) {
12+
cases := map[string]struct {
13+
errorResponse []byte
14+
expectedErrorType string
15+
expectedErrorMsg string
16+
expectedDeserializationError string
17+
}{
18+
"error with code": {
19+
errorResponse: []byte(`{"code": "errorCode", "message": "message for errorCode"}`),
20+
expectedErrorType: "errorCode",
21+
expectedErrorMsg: "message for errorCode",
22+
},
23+
"error with type": {
24+
errorResponse: []byte(`{"__type": "errorCode", "message": "message for errorCode"}`),
25+
expectedErrorType: "errorCode",
26+
expectedErrorMsg: "message for errorCode",
27+
},
28+
29+
"error with only message": {
30+
errorResponse: []byte(`{"message": "message for errorCode"}`),
31+
expectedErrorMsg: "message for errorCode",
32+
},
33+
34+
"error with only code": {
35+
errorResponse: []byte(`{"code": "errorCode"}`),
36+
expectedErrorType: "errorCode",
37+
},
38+
39+
"empty": {
40+
errorResponse: []byte(``),
41+
},
42+
43+
"unknownField": {
44+
errorResponse: []byte(`{"xyz":"abc", "code": "errorCode"}`),
45+
expectedErrorType: "errorCode",
46+
},
47+
48+
"unexpectedEOF": {
49+
errorResponse: []byte(`{"xyz":"abc"`),
50+
expectedDeserializationError: io.ErrUnexpectedEOF.Error(),
51+
},
52+
53+
"caseless compare": {
54+
errorResponse: []byte(`{"Code": "errorCode", "Message": "errorMessage", "xyz": "abc"}`),
55+
expectedErrorType: "errorCode",
56+
expectedErrorMsg: "errorMessage",
57+
},
58+
}
59+
60+
for name, c := range cases {
61+
t.Run(name, func(t *testing.T) {
62+
decoder := json.NewDecoder(bytes.NewReader(c.errorResponse))
63+
actualType, actualMsg, err := GetErrorInfo(decoder)
64+
if err != nil {
65+
if len(c.expectedDeserializationError) == 0 {
66+
t.Fatalf("expected no error, got %v", err.Error())
67+
}
68+
69+
if e, a := c.expectedDeserializationError, err.Error(); !strings.Contains(a, e) {
70+
t.Fatalf("expected error to be %v, got %v", e, a)
71+
}
72+
}
73+
74+
if e, a := c.expectedErrorType, actualType; !strings.EqualFold(e, a) {
75+
t.Fatalf("expected error type to be %v, got %v", e, a)
76+
}
77+
78+
if e, a := c.expectedErrorMsg, actualMsg; !strings.EqualFold(e, a) {
79+
t.Fatalf("expected error message to be %v, got %v", e, a)
80+
}
81+
})
82+
}
83+
}

0 commit comments

Comments
 (0)