diff --git a/errors.go b/errors.go index 0ffdcc3..bc05eef 100644 --- a/errors.go +++ b/errors.go @@ -26,11 +26,14 @@ import ( "os" ) +// traceDepth is the depth to be used by error constructors. +const traceDepth = 2 + // NotFound returns new instance of not found error func NotFound(message string, args ...interface{}) Error { return newTrace(&NotFoundError{ Message: fmt.Sprintf(message, args...), - }, 2) + }, traceDepth) } // NotFoundError indicates that object has not been found @@ -88,7 +91,7 @@ func IsNotFound(e error) bool { func AlreadyExists(message string, args ...interface{}) Error { return newTrace(&AlreadyExistsError{ Message: fmt.Sprintf(message, args...), - }, 2) + }, traceDepth) } // AlreadyExistsError indicates that there's a duplicate object that already @@ -136,7 +139,7 @@ func IsAlreadyExists(e error) bool { func BadParameter(message string, args ...interface{}) Error { return newTrace(&BadParameterError{ Message: fmt.Sprintf(message, args...), - }, 2) + }, traceDepth) } // BadParameterError indicates that something is wrong with passed @@ -181,7 +184,7 @@ func IsBadParameter(e error) bool { func NotImplemented(message string, args ...interface{}) Error { return newTrace(&NotImplementedError{ Message: fmt.Sprintf(message, args...), - }, 2) + }, traceDepth) } // NotImplementedError defines an error condition to describe the result @@ -226,7 +229,7 @@ func IsNotImplemented(e error) bool { func CompareFailed(message string, args ...interface{}) Error { return newTrace(&CompareFailedError{ Message: fmt.Sprintf(message, args...), - }, 2) + }, traceDepth) } // CompareFailedError indicates a failed comparison (e.g. bad password or hash) @@ -274,7 +277,7 @@ func IsCompareFailed(e error) bool { func AccessDenied(message string, args ...interface{}) Error { return newTrace(&AccessDeniedError{ Message: fmt.Sprintf(message, args...), - }, 2) + }, traceDepth) } // AccessDeniedError indicates denied access @@ -325,35 +328,35 @@ func ConvertSystemError(err error) error { if os.IsExist(innerError) { return newTrace(&AlreadyExistsError{ Message: innerError.Error(), - }, 2) + }, traceDepth) } if os.IsNotExist(innerError) { return newTrace(&NotFoundError{ Message: innerError.Error(), - }, 2) + }, traceDepth) } if os.IsPermission(innerError) { return newTrace(&AccessDeniedError{ Message: innerError.Error(), - }, 2) + }, traceDepth) } switch realErr := innerError.(type) { case *net.OpError: return newTrace(&ConnectionProblemError{ Err: realErr, - }, 2) + }, traceDepth) case *os.PathError: message := fmt.Sprintf("failed to execute command %v error: %v", realErr.Path, realErr.Err) return newTrace(&AccessDeniedError{ Message: message, - }, 2) + }, traceDepth) case x509.SystemRootsError, x509.UnknownAuthorityError: - return newTrace(&TrustError{Err: innerError}, 2) + return newTrace(&TrustError{Err: innerError}, traceDepth) } if _, ok := innerError.(net.Error); ok { return newTrace(&ConnectionProblemError{ Err: innerError, - }, 2) + }, traceDepth) } return err } @@ -363,7 +366,7 @@ func ConnectionProblem(err error, message string, args ...interface{}) Error { return newTrace(&ConnectionProblemError{ Message: fmt.Sprintf(message, args...), Err: err, - }, 2) + }, traceDepth) } // ConnectionProblemError indicates a network related problem @@ -419,7 +422,7 @@ func IsConnectionProblem(e error) bool { func LimitExceeded(message string, args ...interface{}) Error { return newTrace(&LimitExceededError{ Message: fmt.Sprintf(message, args...), - }, 2) + }, traceDepth) } // LimitExceededError indicates rate limit or connection limit problem @@ -464,7 +467,7 @@ func Trust(err error, message string, args ...interface{}) Error { return newTrace(&TrustError{ Message: fmt.Sprintf(message, args...), Err: err, - }, 2) + }, traceDepth) } // TrustError indicates trust-related validation error (e.g. untrusted cert) @@ -522,7 +525,7 @@ func OAuth2(code, message string, query url.Values) Error { Code: code, Message: message, Query: query, - }, 2) + }, traceDepth) } // OAuth2Error defined an error used in OpenID Connect Flow (OIDC) @@ -589,7 +592,7 @@ func Retry(err error, message string, args ...interface{}) Error { return newTrace(&RetryError{ Message: fmt.Sprintf(message, args...), Err: err, - }, 2) + }, traceDepth) } // RetryError indicates a transient error type diff --git a/errors_test.go b/errors_test.go index 90fc342..c7d8c67 100644 --- a/errors_test.go +++ b/errors_test.go @@ -461,7 +461,7 @@ func TestGoErrorWrap_IsError_allTypes(t *testing.T) { }, { name: "RetryError", - instance: Retry(errors.New("underyling error"), "message"), + instance: Retry(errors.New("underlying error"), "message"), isError: IsRetryError, }, } diff --git a/trace_test.go b/trace_test.go index 4ddf073..8621e78 100644 --- a/trace_test.go +++ b/trace_test.go @@ -29,91 +29,83 @@ import ( "testing" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" ) -func TestTrace(t *testing.T) { - suite.Run(t, new(TraceSuite)) +func TestEmpty(t *testing.T) { + assert.Equal(t, "", DebugReport(nil)) + assert.Equal(t, "", UserMessage(nil)) + assert.Equal(t, "", UserMessageWithFields(nil)) + assert.Equal(t, map[string]interface{}{}, GetFields(nil)) } -type TraceSuite struct { - suite.Suite -} - -func (s *TraceSuite) TestEmpty() { - s.Equal("", DebugReport(nil)) - s.Equal("", UserMessage(nil)) - s.Equal("", UserMessageWithFields(nil)) - s.Equal(map[string]interface{}{}, GetFields(nil)) -} - -func (s *TraceSuite) TestWrap() { +func TestWrap(t *testing.T) { testErr := &testError{Param: "param"} err := Wrap(Wrap(testErr)) - s.Regexp(".*trace_test.go.*", line(DebugReport(err))) - s.NotRegexp(".*trace.go.*", line(DebugReport(err))) - s.NotRegexp(".*trace_test.go.*", line(UserMessage(err))) - s.Regexp(".*param.*", line(UserMessage(err))) + assert.Regexp(t, ".*trace_test.go.*", line(DebugReport(err))) + assert.NotRegexp(t, ".*trace.go.*", line(DebugReport(err))) + assert.NotRegexp(t, ".*trace_test.go.*", line(UserMessage(err))) + assert.Regexp(t, ".*param.*", line(UserMessage(err))) } -func (s *TraceSuite) TestOrigError() { +func TestOrigError(t *testing.T) { testErr := fmt.Errorf("some error") err := Wrap(Wrap(testErr)) - s.Equal(testErr, err.OrigError()) + assert.Equal(t, testErr, err.OrigError()) } -func (s *TraceSuite) TestIsEOF() { - s.True(IsEOF(io.EOF)) - s.True(IsEOF(Wrap(io.EOF))) +func TestIsEOF(t *testing.T) { + assert.True(t, IsEOF(io.EOF)) + assert.True(t, IsEOF(Wrap(io.EOF))) } -func (s *TraceSuite) TestWrapUserMessage() { +func TestWrapUserMessage(t *testing.T) { testErr := fmt.Errorf("description") err := Wrap(testErr, "user message") - s.Regexp(".*trace_test.go.*", line(DebugReport(err))) - s.NotRegexp(".*trace.go.*", line(DebugReport(err))) - s.Equal("user message\tdescription", line(UserMessage(err))) + assert.Regexp(t, ".*trace_test.go.*", line(DebugReport(err))) + assert.NotRegexp(t, ".*trace.go.*", line(DebugReport(err))) + assert.Equal(t, "user message\tdescription", line(UserMessage(err))) err = Wrap(err, "user message 2") - s.Equal("user message 2\tuser message\t\tdescription", line(UserMessage(err))) + assert.Equal(t, "user message 2\tuser message\t\tdescription", line(UserMessage(err))) } -func (s *TraceSuite) TestWrapWithMessage() { +func TestWrapWithMessage(t *testing.T) { testErr := fmt.Errorf("description") err := WrapWithMessage(testErr, "user message") - s.Equal("user message\tdescription", line(UserMessage(err))) - s.Regexp(".*trace_test.go.*", line(DebugReport(err))) - s.NotRegexp(".*trace.go.*", line(DebugReport(err))) + assert.Equal(t, "user message\tdescription", line(UserMessage(err))) + assert.Regexp(t, ".*trace_test.go.*", line(DebugReport(err))) + assert.NotRegexp(t, ".*trace.go.*", line(DebugReport(err))) } -func (s *TraceSuite) TestUserMessageWithFields() { +func TestUserMessageWithFields(t *testing.T) { testErr := fmt.Errorf("description") - s.Equal(testErr.Error(), UserMessageWithFields(testErr)) + assert.Equal(t, testErr.Error(), UserMessageWithFields(testErr)) err := Wrap(testErr, "user message") - s.Equal("user message\tdescription", line(UserMessageWithFields(err))) + assert.Equal(t, "user message\tdescription", line(UserMessageWithFields(err))) err = WithField(err, "test_key", "test_value") - s.Equal("test_key=\"test_value\" user message\tdescription", line(UserMessageWithFields(err))) + assert.Equal(t, "test_key=\"test_value\" user message\tdescription", line(UserMessageWithFields(err))) } -func (s *TraceSuite) TestGetFields() { +func TestGetFields(t *testing.T) { testErr := fmt.Errorf("description") - s.Equal(map[string]interface{}{}, GetFields(testErr)) + assert.Equal(t, map[string]interface{}{}, GetFields(testErr)) fields := map[string]interface{}{ "test_key": "test_value", } err := WithFields(Wrap(testErr), fields) - s.Equal(fields, GetFields(err)) + assert.Equal(t, fields, GetFields(err)) // ensure that you can get fields from a proxyError e := roundtripError(err) - s.Equal(fields, GetFields(e)) + assert.Equal(t, fields, GetFields(e)) } func roundtripError(err error) error { @@ -124,18 +116,18 @@ func roundtripError(err error) error { return outErr } -func (s *TraceSuite) TestWrapNil() { +func TestWrapNil(t *testing.T) { err1 := Wrap(nil, "message: %v", "extra") - s.Nil(err1) + assert.Nil(t, err1) var err2 error err2 = nil err3 := Wrap(err2) - s.Nil(err3) + assert.Nil(t, err3) err4 := Wrap(err3) - s.Nil(err4) + assert.Nil(t, err4) } func TestRaceErrorWrap(t *testing.T) { @@ -173,11 +165,11 @@ func TestRaceErrorWrap(t *testing.T) { wg.Wait() } -func (s *TraceSuite) TestWrapStdlibErrors() { - s.True(IsNotFound(os.ErrNotExist)) +func TestWrapStdlibErrors(t *testing.T) { + assert.True(t, IsNotFound(os.ErrNotExist)) } -func (s *TraceSuite) TestLogFormatter() { +func TestLogFormatter(t *testing.T) { for _, f := range []log.Formatter{&TextFormatter{}, &JSONFormatter{}} { log.SetFormatter(f) @@ -185,12 +177,12 @@ func (s *TraceSuite) TestLogFormatter() { var buf bytes.Buffer log.SetOutput(&buf) log.Infof("hello") - s.Regexp(".*trace_test.go.*", line(buf.String())) + assert.Regexp(t, ".*trace_test.go.*", line(buf.String())) // check case with embedded Infof buf.Reset() log.WithFields(log.Fields{"a": "b"}).Infof("hello") - s.Regexp(".*trace_test.go.*", line(buf.String())) + assert.Regexp(t, ".*trace_test.go.*", line(buf.String())) } } @@ -200,7 +192,7 @@ func (p panicker) String() string { panic(p) } -func (s *TraceSuite) TestTextFormatter() { +func TestTextFormatter(t *testing.T) { padding := 6 f := &TextFormatter{ DisableTimestamp: true, @@ -296,11 +288,11 @@ func (s *TraceSuite) TestTextFormatter() { buf := &bytes.Buffer{} log.SetOutput(buf) tc.log() - s.Regexp(tc.match, line(buf.String()), "test case %v %v, expected match: %v", i+1, tc.comment, tc.match) + assert.Regexp(t, tc.match, line(buf.String()), "test case %v %v, expected match: %v", i+1, tc.comment, tc.match) } } -func (s *TraceSuite) TestTextFormatterWithColors() { +func TestTextFormatterWithColors(t *testing.T) { padding := 6 f := &TextFormatter{ DisableTimestamp: true, @@ -368,11 +360,11 @@ func (s *TraceSuite) TestTextFormatterWithColors() { log.SetOutput(buf) log.SetLevel(log.DebugLevel) tc.log() - s.Regexpf(tc.match, line(buf.String()), "test case %v %v, expected match: %v", i+1, tc.comment, tc.match) + assert.Regexpf(t, tc.match, line(buf.String()), "test case %v %v, expected match: %v", i+1, tc.comment, tc.match) } } -func (s *TraceSuite) TestGenericErrors() { +func TestGenericErrors(t *testing.T) { testCases := []struct { Err Error Predicate func(error) bool @@ -436,53 +428,53 @@ func (s *TraceSuite) TestGenericErrors() { var traceErr *TraceErr var ok bool if traceErr, ok = err.(*TraceErr); !ok { - s.Fail("Expected error to be of type *TraceErr") + t.Fatalf("Expected error to be of type *TraceErr: %#v", err) } - s.NotEmpty(traceErr.Traces, testCase.comment) - s.Regexp(".*.trace_test\\.go.*", line(DebugReport(err)), testCase.comment) - s.NotRegexp(".*.errors\\.go.*", line(DebugReport(err)), testCase.comment) - s.NotRegexp(".*.trace\\.go.*", line(DebugReport(err)), testCase.comment) - s.True(testCase.Predicate(err), testCase.comment) + assert.NotEmpty(t, traceErr.Traces, testCase.comment) + assert.Regexp(t, ".*.trace_test\\.go.*", line(DebugReport(err)), testCase.comment) + assert.NotRegexp(t, ".*.errors\\.go.*", line(DebugReport(err)), testCase.comment) + assert.NotRegexp(t, ".*.trace\\.go.*", line(DebugReport(err)), testCase.comment) + assert.True(t, testCase.Predicate(err), testCase.comment) w := newTestWriter() WriteError(w, err) outErr := ReadError(w.StatusCode, w.Body) if _, ok := outErr.(proxyError); !ok { - s.Fail("Expected error to be of type proxyError") + t.Fatalf("Expected error to be of type proxyError: %#v", outErr) } - s.True(testCase.Predicate(outErr), testCase.comment) + assert.True(t, testCase.Predicate(outErr), testCase.comment) SetDebug(false) w = newTestWriter() WriteError(w, err) outErr = ReadError(w.StatusCode, w.Body) - s.True(testCase.Predicate(outErr), testCase.comment) + assert.True(t, testCase.Predicate(outErr), testCase.comment) } } // Make sure we write some output produced by standard errors -func (s *TraceSuite) TestWriteExternalErrors() { +func TestWriteExternalErrors(t *testing.T) { err := Wrap(fmt.Errorf("snap!")) SetDebug(true) w := newTestWriter() WriteError(w, err) extErr := ReadError(w.StatusCode, w.Body) - s.Equal(http.StatusInternalServerError, w.StatusCode) - s.Regexp(".*.snap.*", strings.Replace(string(w.Body), "\n", "", -1)) - s.Require().NotNil(extErr) - s.EqualError(err, extErr.Error()) + assert.Equal(t, http.StatusInternalServerError, w.StatusCode) + assert.Regexp(t, ".*.snap.*", strings.Replace(string(w.Body), "\n", "", -1)) + require.NotNil(t, extErr) + assert.EqualError(t, err, extErr.Error()) SetDebug(false) w = newTestWriter() WriteError(w, err) extErr = ReadError(w.StatusCode, w.Body) - s.Equal(http.StatusInternalServerError, w.StatusCode) - s.Regexp(".*.snap.*", strings.Replace(string(w.Body), "\n", "", -1)) - s.Require().NotNil(extErr) - s.EqualError(err, extErr.Error()) + assert.Equal(t, http.StatusInternalServerError, w.StatusCode) + assert.Regexp(t, ".*.snap.*", strings.Replace(string(w.Body), "\n", "", -1)) + require.NotNil(t, extErr) + assert.EqualError(t, err, extErr.Error()) } type netError struct{} @@ -491,51 +483,51 @@ func (e *netError) Error() string { return "net" } func (e *netError) Timeout() bool { return true } func (e *netError) Temporary() bool { return true } -func (s *TraceSuite) TestConvert() { +func TestConvert(t *testing.T) { err := ConvertSystemError(&netError{}) - s.True(IsConnectionProblem(err), "failed to detect network error") + assert.True(t, IsConnectionProblem(err), "failed to detect network error") - dir := s.T().TempDir() + dir := t.TempDir() err = os.Mkdir(dir, 0o770) err = ConvertSystemError(err) - s.True(IsAlreadyExists(err), "expected AlreadyExists error, got %T", err) + assert.True(t, IsAlreadyExists(err), "expected AlreadyExists error, got %T", err) } -func (s *TraceSuite) TestAggregates() { +func TestAggregates(t *testing.T) { err1 := Errorf("failed one") err2 := Errorf("failed two") err := NewAggregate(err1, err2) - s.True(IsAggregate(err)) + assert.True(t, IsAggregate(err)) agg := Unwrap(err).(Aggregate) - s.Equal([]error{err1, err2}, agg.Errors()) - s.Equal("failed one, failed two", err.Error()) + assert.Equal(t, []error{err1, err2}, agg.Errors()) + assert.Equal(t, "failed one, failed two", err.Error()) } -func (s *TraceSuite) TestErrorf() { +func TestErrorf(t *testing.T) { err := Errorf("error") - s.Regexp(".*.trace_test.go.*", line(DebugReport(err))) - s.NotRegexp(".*.Fields.*", line(DebugReport(err))) - s.Equal([]string(nil), err.(*TraceErr).Messages) + assert.Regexp(t, ".*.trace_test.go.*", line(DebugReport(err))) + assert.NotRegexp(t, ".*.Fields.*", line(DebugReport(err))) + assert.Equal(t, []string(nil), err.(*TraceErr).Messages) } -func (s *TraceSuite) TestWithField() { +func TestWithField(t *testing.T) { err := WithField(Wrap(Errorf("error")), "testfield", true) - s.Regexp(".*.testfield.*", line(DebugReport(err))) + assert.Regexp(t, ".*.testfield.*", line(DebugReport(err))) } -func (s *TraceSuite) TestWithFields() { +func TestWithFields(t *testing.T) { err := WithFields(Wrap(Errorf("error")), map[string]interface{}{ "testfield1": true, "testfield2": "value2", }) - s.Regexp(".*.Fields.*", line(DebugReport(err))) - s.Regexp(".*.testfield1: true.*", line(DebugReport(err))) - s.Regexp(".*.testfield2: value2.*", line(DebugReport(err))) + assert.Regexp(t, ".*.Fields.*", line(DebugReport(err))) + assert.Regexp(t, ".*.testfield1: true.*", line(DebugReport(err))) + assert.Regexp(t, ".*.testfield2: value2.*", line(DebugReport(err))) } -func (s *TraceSuite) TestAggregateConvertsToCommonErrors() { +func TestAggregateConvertsToCommonErrors(t *testing.T) { testCases := []struct { Err error Predicate func(error) bool @@ -570,33 +562,33 @@ func (s *TraceSuite) TestAggregateConvertsToCommonErrors() { SetDebug(true) err := testCase.Err - s.Regexp(".*.trace_test.go.*", line(DebugReport(err)), testCase.comment) - s.True(testCase.Predicate(err), testCase.comment) + assert.Regexp(t, ".*.trace_test.go.*", line(DebugReport(err)), testCase.comment) + assert.True(t, testCase.Predicate(err), testCase.comment) w := newTestWriter() WriteError(w, err) outErr := ReadError(w.StatusCode, w.Body) - s.True(testCase.RoundtripPredicate(outErr), testCase.comment) + assert.True(t, testCase.RoundtripPredicate(outErr), testCase.comment) SetDebug(false) w = newTestWriter() WriteError(w, err) outErr = ReadError(w.StatusCode, w.Body) - s.True(testCase.RoundtripPredicate(outErr), testCase.comment) + assert.True(t, testCase.RoundtripPredicate(outErr), testCase.comment) } } -func (s *TraceSuite) TestAggregateThrowAwayNils() { +func TestAggregateThrowAwayNils(t *testing.T) { err := NewAggregate(fmt.Errorf("error1"), nil, fmt.Errorf("error2")) - s.Require().NotNil(err) - s.NotRegexp(".*nil.*", err.Error()) + require.NotNil(t, err) + assert.NotRegexp(t, ".*nil.*", err.Error()) } -func (s *TraceSuite) TestAggregateAllNils() { - s.Nil(NewAggregate(nil, nil, nil)) +func TestAggregateAllNils(t *testing.T) { + assert.Nil(t, NewAggregate(nil, nil, nil)) } -func (s *TraceSuite) TestAggregateFromChannel() { +func TestAggregateFromChannel(t *testing.T) { errCh := make(chan error, 3) errCh <- fmt.Errorf("Snap!") errCh <- fmt.Errorf("BAM") @@ -604,13 +596,13 @@ func (s *TraceSuite) TestAggregateFromChannel() { close(errCh) err := NewAggregateFromChannel(errCh, context.Background()) - s.Require().NotNil(err) - s.Regexp(".*Snap!.*", err.Error()) - s.Regexp(".*BAM.*", err.Error()) - s.Regexp(".*omg.*", err.Error()) + require.NotNil(t, err) + assert.Regexp(t, ".*Snap!.*", err.Error()) + assert.Regexp(t, ".*BAM.*", err.Error()) + assert.Regexp(t, ".*omg.*", err.Error()) } -func (s *TraceSuite) TestAggregateFromChannelCancel() { +func TestAggregateFromChannelCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) errCh := make(chan error) outCh := make(chan error) @@ -625,10 +617,10 @@ func (s *TraceSuite) TestAggregateFromChannelCancel() { cancel() err := <-outCh - s.Error(err) + assert.Error(t, err) } -func (s *TraceSuite) TestCompositeErrorsCanProperlyUnwrap() { +func TestCompositeErrorsCanProperlyUnwrap(t *testing.T) { testCases := []struct { err error message string @@ -652,9 +644,9 @@ func (s *TraceSuite) TestCompositeErrorsCanProperlyUnwrap() { } var wrapper ErrorWrapper for _, tt := range testCases { - s.Equal(tt.message, tt.err.Error()) - s.Implements(&wrapper, Unwrap(tt.err)) - s.Equal(tt.wrappedMessage, Unwrap(tt.err).(ErrorWrapper).OrigError().Error()) + assert.Equal(t, tt.message, tt.err.Error()) + assert.Implements(t, &wrapper, Unwrap(tt.err)) + assert.Equal(t, tt.wrappedMessage, Unwrap(tt.err).(ErrorWrapper).OrigError().Error()) } } @@ -696,7 +688,7 @@ func (tw *testWriter) WriteHeader(code int) { } func line(s string) string { - return strings.Replace(s, "\n", "", -1) + return strings.ReplaceAll(s, "\n", "") } func TestStdlibCompat(t *testing.T) { diff --git a/trail/trail.go b/trail/trail.go index a190130..ac81988 100644 --- a/trail/trail.go +++ b/trail/trail.go @@ -81,46 +81,53 @@ func Send(ctx context.Context, err error) error { const DebugReportMetadata = "trace-debug-report" // ToGRPC converts error to GRPC-compatible error -func ToGRPC(err error) error { - if err == nil { +func ToGRPC(originalErr error) error { + if originalErr == nil { return nil } - if errors.Is(err, io.EOF) { - return err + // Avoid modifying top-level gRPC errors. + if _, ok := status.FromError(originalErr); ok { + return originalErr } - // If err is already a gRPC error, don't modify it. - if _, ok := status.FromError(err); ok { - return err - } + for e := originalErr; e != nil; { + if e == io.EOF { + // Keep legacy semantics and return the original error. + return originalErr + } - userMessage := trace.UserMessage(err) - if trace.IsNotFound(err) { - return status.Errorf(codes.NotFound, userMessage) - } - if trace.IsAlreadyExists(err) { - return status.Errorf(codes.AlreadyExists, userMessage) - } - if trace.IsAccessDenied(err) { - return status.Errorf(codes.PermissionDenied, userMessage) - } - if trace.IsCompareFailed(err) { - return status.Errorf(codes.FailedPrecondition, userMessage) - } - if trace.IsBadParameter(err) || trace.IsOAuth2(err) { - return status.Errorf(codes.InvalidArgument, userMessage) - } - if trace.IsLimitExceeded(err) { - return status.Errorf(codes.ResourceExhausted, userMessage) - } - if trace.IsConnectionProblem(err) { - return status.Errorf(codes.Unavailable, userMessage) - } - if trace.IsNotImplemented(err) { - return status.Errorf(codes.Unimplemented, userMessage) + if s, ok := status.FromError(e); ok { + return status.Errorf(s.Code(), trace.UserMessage(originalErr)) + } + + switch e.(type) { + case *trace.AccessDeniedError: + return status.Errorf(codes.PermissionDenied, trace.UserMessage(originalErr)) + case *trace.AlreadyExistsError: + return status.Errorf(codes.AlreadyExists, trace.UserMessage(originalErr)) + case *trace.BadParameterError: + return status.Errorf(codes.InvalidArgument, trace.UserMessage(originalErr)) + case *trace.CompareFailedError: + return status.Errorf(codes.FailedPrecondition, trace.UserMessage(originalErr)) + case *trace.ConnectionProblemError: + return status.Errorf(codes.Unavailable, trace.UserMessage(originalErr)) + case *trace.LimitExceededError: + return status.Errorf(codes.ResourceExhausted, trace.UserMessage(originalErr)) + case *trace.NotFoundError: + return status.Errorf(codes.NotFound, trace.UserMessage(originalErr)) + case *trace.NotImplementedError: + return status.Errorf(codes.Unimplemented, trace.UserMessage(originalErr)) + case *trace.OAuth2Error: + return status.Errorf(codes.InvalidArgument, trace.UserMessage(originalErr)) + case *trace.RetryError: // Not mapped. + case *trace.TrustError: // Not mapped. + } + + e = errors.Unwrap(e) } - return status.Errorf(codes.Unknown, userMessage) + + return status.Errorf(codes.Unknown, trace.UserMessage(originalErr)) } // FromGRPC converts error from GRPC error back to trace.Error diff --git a/trail/trail_test.go b/trail/trail_test.go index 4c46204..c2bdc7b 100644 --- a/trail/trail_test.go +++ b/trail/trail_test.go @@ -18,30 +18,22 @@ package trail import ( "errors" + "fmt" "io" "strings" "testing" "github.com/gravitational/trace" + "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/suite" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) -func TestTrail(t *testing.T) { - suite.Run(t, new(TrailSuite)) -} - -type TrailSuite struct { - suite.Suite -} - // TestConversion makes sure we convert all trace supported errors // to and back from GRPC codes -func (s *TrailSuite) TestConversion() { +func TestConversion(t *testing.T) { testCases := []struct { Error error Message string @@ -58,12 +50,8 @@ func (s *TrailSuite) TestConversion() { Predicate: trace.IsAccessDenied, }, { - Error: trace.ConnectionProblem(nil, "problem"), - Predicate: trace.IsConnectionProblem, - }, - { - Error: trace.NotFound("not found"), - Predicate: trace.IsNotFound, + Error: trace.AlreadyExists("already exists"), + Predicate: trace.IsAlreadyExists, }, { Error: trace.BadParameter("bad parameter"), @@ -74,13 +62,17 @@ func (s *TrailSuite) TestConversion() { Predicate: trace.IsCompareFailed, }, { - Error: trace.AccessDenied("denied"), - Predicate: trace.IsAccessDenied, + Error: trace.ConnectionProblem(nil, "problem"), + Predicate: trace.IsConnectionProblem, }, { Error: trace.LimitExceeded("exceeded"), Predicate: trace.IsLimitExceeded, }, + { + Error: trace.NotFound("not found"), + Predicate: trace.IsNotFound, + }, { Error: trace.NotImplemented("not implemented"), Predicate: trace.IsNotImplemented, @@ -89,38 +81,38 @@ func (s *TrailSuite) TestConversion() { for i, tc := range testCases { grpcError := ToGRPC(tc.Error) - s.Equal(tc.Error.Error(), grpc.ErrorDesc(grpcError), "test case %v", i+1) + assert.Equal(t, tc.Error.Error(), status.Convert(grpcError).Message(), "test case %v", i+1) out := FromGRPC(grpcError) - s.True(tc.Predicate(out), "test case %v", i+1) - s.Regexp(".*trail_test.go.*", line(trace.DebugReport(out))) - s.NotRegexp(".*trail.go.*", line(trace.DebugReport(out))) + assert.True(t, tc.Predicate(out), "test case %v", i+1) + assert.Regexp(t, ".*trail_test.go.*", line(trace.DebugReport(out))) + assert.NotRegexp(t, ".*trail.go.*", line(trace.DebugReport(out))) } } // TestNil makes sure conversions of nil to and from GRPC are no-op -func (s *TrailSuite) TestNil() { +func TestNil(t *testing.T) { out := FromGRPC(ToGRPC(nil)) - s.Nil(out) + assert.Nil(t, out) } // TestFromEOF makes sure that non-grpc error such as io.EOF is preserved well. -func (s *TrailSuite) TestFromEOF() { +func TestFromEOF(t *testing.T) { out := FromGRPC(trace.Wrap(io.EOF)) - s.True(trace.IsEOF(out)) + assert.True(t, trace.IsEOF(out)) } // TestTraces makes sure we pass traces via metadata and can decode it back -func (s *TrailSuite) TestTraces() { +func TestTraces(t *testing.T) { err := trace.BadParameter("param") meta := metadata.New(nil) SetDebugInfo(err, meta) err2 := FromGRPC(ToGRPC(err), meta) - s.Regexp(".*trail_test.go.*", line(trace.DebugReport(err))) - s.Regexp(".*trail_test.go.*", line(trace.DebugReport(err2))) + assert.Regexp(t, ".*trail_test.go.*", line(trace.DebugReport(err))) + assert.Regexp(t, ".*trail_test.go.*", line(trace.DebugReport(err2))) } func line(s string) string { - return strings.Replace(s, "\n", "", -1) + return strings.ReplaceAll(s, "\n", "") } func TestToGRPCKeepCode(t *testing.T) { @@ -134,3 +126,43 @@ func TestToGRPCKeepCode(t *testing.T) { t.Errorf("after FromGRPC, trace.IsAccessDenied is false, want true, error: %v", err) } } + +func TestToGRPC_statusError(t *testing.T) { + err1 := status.Errorf(codes.NotFound, "not found") + err2 := fmt.Errorf("go wrap: %w", trace.Wrap(err1)) + + tests := []struct { + name string + err error + want error + }{ + { + name: "unwrapped status", + err: err1, + want: err1, // Exact same error. + }, + { + name: "wrapped status", + err: err2, + want: status.Errorf(codes.NotFound, err2.Error()), + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := ToGRPC(test.err) + + got, ok := status.FromError(err) + if !ok { + t.Fatalf("Failed to convert `got` to a status.Status: %#v", err) + } + want, ok := status.FromError(test.want) + if !ok { + t.Fatalf("Failed to convert `want` to a status.Status: %#v", err) + } + + if got.Code() != want.Code() || got.Message() != want.Message() { + t.Errorf("ToGRPC = %#v, want %#v", got, test.want) + } + }) + } +} diff --git a/udphook_test.go b/udphook_test.go index 9f76076..53dada0 100644 --- a/udphook_test.go +++ b/udphook_test.go @@ -17,7 +17,7 @@ limitations under the License. package trace import ( - "io/ioutil" + "io" "testing" "github.com/jonboulle/clockwork" @@ -35,7 +35,7 @@ func TestHooks(t *testing.T) { func (s *HooksSuite) TestSafeForConcurrentAccess() { logger := log.New() - logger.Out = ioutil.Discard + logger.Out = io.Discard entry := logger.WithFields(log.Fields{"foo": "bar"}) logger.Hooks.Add(&UDPHook{Clock: clockwork.NewFakeClock()}) for i := 0; i < 3; i++ {