diff --git a/httplib.go b/httplib.go index e820f41..8f85c0b 100644 --- a/httplib.go +++ b/httplib.go @@ -107,12 +107,11 @@ func unmarshalError(err error, responseBody []byte) error { } var raw RawTrace if err2 := json.Unmarshal(responseBody, &raw); err2 != nil { - return err + return errorOnInvalidJSON(err, responseBody) } if len(raw.Err) != 0 { - err2 := json.Unmarshal(raw.Err, err) - if err2 != nil { - return err + if err2 := json.Unmarshal(raw.Err, err); err2 != nil { + return errorOnInvalidJSON(err, responseBody) } return &TraceErr{ Traces: raw.Traces, @@ -122,6 +121,19 @@ func unmarshalError(err error, responseBody []byte) error { Fields: raw.Fields, } } - json.Unmarshal(responseBody, err) + if err2 := json.Unmarshal(responseBody, err); err2 != nil { + return errorOnInvalidJSON(err, responseBody) + } return err } + +// errorOnInvalidJSON is used to construct a TraceErr with the +// input error as Err and the responseBody as Messages. +// This function is used when the responseBody is not valid +// JSON or it contains an unexpected JSON. +func errorOnInvalidJSON(err error, responseBody []byte) error { + return &TraceErr{ + Err: err, + Messages: []string{string(responseBody)}, + } +} diff --git a/httplib_test.go b/httplib_test.go index 3f12e0a..56b2865 100644 --- a/httplib_test.go +++ b/httplib_test.go @@ -26,27 +26,35 @@ import ( func TestReplyJSON(t *testing.T) { t.Parallel() - var ( - errCode = 400 - errText = "test error" - expectedErrorResponse = "" + - "{\n" + - " \"error\": {\n" + - " \"message\": \"" + errText + "\"\n" + - " }\n" + - "}" - ) - for _, tc := range []struct { + var expectedErrorResponse = `{ + "error": { + "message": "test error" + } + }` + + tests := []struct { desc string err error }{ - {"plain error", errors.New("test error")}, - {"trace error", &TraceErr{Err: errors.New("test error")}}, - {"trace error with stacktrace", &TraceErr{Err: errors.New("test error"), Traces: Traces{{Path: "A", Func: "B", Line: 1}}}}, - } { + { + desc: "plain error", + err: errors.New("test error"), + }, + { + desc: "trace error", + err: &TraceErr{Err: errors.New("test error")}, + }, + { + desc: "trace error with stacktrace", + err: &TraceErr{Err: errors.New("test error"), Traces: Traces{{Path: "A", Func: "B", Line: 1}}}, + }, + } + + for _, tc := range tests { t.Run(tc.desc, func(t *testing.T) { recorder := httptest.NewRecorder() + const errCode = 400 replyJSON(recorder, errCode, tc.err) require.JSONEq(t, expectedErrorResponse, recorder.Body.String()) }) @@ -55,12 +63,63 @@ func TestReplyJSON(t *testing.T) { func TestUnmarshalError(t *testing.T) { t.Parallel() - testCase := func(t *testing.T, err error, response string, isExpectedErr func(error) bool, expectedMsg string) { - readErr := unmarshalError(err, []byte(response)) - require.True(t, isExpectedErr(readErr)) - require.EqualError(t, readErr, expectedMsg) + + tests := []struct { + desc string + inputErr error + inputResponse string + assertErr func(error) bool + expectedMsg string + }{ + { + desc: "unmarshal not found error", + inputErr: &NotFoundError{}, + inputResponse: `{"error": {"message": "ABC"}}`, + assertErr: IsNotFound, + expectedMsg: "ABC", + }, + { + desc: "unmarshal access denied error", + inputErr: &AccessDeniedError{}, + inputResponse: `{"error": {"message": "ABC"}}`, + assertErr: IsAccessDenied, + expectedMsg: "ABC", + }, + { + desc: "unmarshal error without error JSON key", + inputErr: &AccessDeniedError{}, + inputResponse: `{"message": "ABC"}`, + assertErr: IsAccessDenied, + expectedMsg: "ABC", + }, + { + desc: "unmarshal invalid error", + inputErr: &AccessDeniedError{}, + inputResponse: `{"error": "message ABC"}`, + assertErr: IsAccessDenied, + expectedMsg: "{\"error\": \"message ABC\"}\n\taccess denied", + }, + { + desc: "unmarshal invalid error without error JSON key", + inputErr: &AccessDeniedError{}, + inputResponse: `["error message ABC"]`, + assertErr: IsAccessDenied, + expectedMsg: "[\"error message ABC\"]\n\taccess denied", + }, + { + desc: "unmarshal error with non-JSON body", + inputErr: &AccessDeniedError{}, + inputResponse: "error message ABC", + assertErr: IsAccessDenied, + expectedMsg: "error message ABC\n\taccess denied", + }, } - testCase(t, &NotFoundError{}, `{"error": {"message": "ABC"}}`, IsNotFound, "ABC") - testCase(t, &AccessDeniedError{}, `{"error": {"message": "ABC"}}`, IsAccessDenied, "ABC") + for _, tc := range tests { + t.Run(tc.desc, func(t *testing.T) { + readErr := unmarshalError(tc.inputErr, []byte(tc.inputResponse)) + require.True(t, tc.assertErr(readErr)) + require.EqualError(t, readErr, tc.expectedMsg) + }) + } }