diff --git a/errors.go b/errors.go index 4101c69..ab92b7a 100644 --- a/errors.go +++ b/errors.go @@ -24,6 +24,8 @@ import ( "net" "net/url" "os" + + "github.com/gravitational/trace/internal" ) // NotFound returns new instance of not found error @@ -71,27 +73,15 @@ func (e *NotFoundError) Is(target error) bool { } // IsNotFound returns true if `e` contains a [NotFoundError] in its chain. -func IsNotFound(e error) bool { - for e != nil { - switch e := e.(type) { - case *NotFoundError: - return true - - // Aggregates and other errors. - case interface{ As(interface{}) bool }: - nfe := &NotFoundError{} - if e.As(&nfe) { - return true - } - } - - if os.IsNotExist(e) { +func IsNotFound(err error) bool { + return internal.TraverseErr(err, func(err error) (ok bool) { + if os.IsNotExist(err) { return true } - e = errors.Unwrap(e) - } - return false + _, ok = err.(*NotFoundError) + return ok + }) } // AlreadyExists returns a new instance of AlreadyExists error diff --git a/internal/traverse.go b/internal/traverse.go new file mode 100644 index 0000000..2df44a6 --- /dev/null +++ b/internal/traverse.go @@ -0,0 +1,28 @@ +package internal + +// TraverseErr traverses the err error chain until fn returns true. +// Traversal stops on nil errors, fn(nil) is never called. +// Returns true if fn matched, false otherwise. +func TraverseErr(err error, fn func(error) (ok bool)) (ok bool) { + if err == nil { + return false + } + + if fn(err) { + return true + } + + switch err := err.(type) { + case interface{ Unwrap() error }: + return TraverseErr(err.Unwrap(), fn) + + case interface{ Unwrap() []error }: + for _, err2 := range err.Unwrap() { + if TraverseErr(err2, fn) { + return true + } + } + } + + return false +} diff --git a/trace.go b/trace.go index d8307cd..d2a4956 100644 --- a/trace.go +++ b/trace.go @@ -495,6 +495,7 @@ func (r aggregate) Error() string { // Is implements the `Is` interface, by iterating through each error in the // aggregate and invoking `errors.Is`. +// Required for Go versions < 1.20 (newer releases support "Unwrap() []error"). func (r aggregate) Is(t error) bool { for _, err := range r { if errors.Is(err, t) { @@ -506,6 +507,7 @@ func (r aggregate) Is(t error) bool { // As implements the `As` interface, by iterating through each error in the // aggregate and invoking `errors.As`. +// Required for Go versions < 1.20 (newer releases support "Unwrap() []error"). func (r aggregate) As(t interface{}) bool { for _, err := range r { if errors.As(err, t) { @@ -522,6 +524,11 @@ func (r aggregate) Errors() []error { return cp } +// Unwrap returns the underlying aggregated errors. +func (r aggregate) Unwrap() []error { + return r.Errors() +} + // IsAggregate returns true if `err` contains an [Aggregate] error in its // chain. func IsAggregate(err error) bool { diff --git a/trace_test.go b/trace_test.go index 8c88b7b..48e9233 100644 --- a/trace_test.go +++ b/trace_test.go @@ -747,11 +747,11 @@ func TestAggregate_StdlibCompat(t *testing.T) { assert.NotErrorIs(t, agg, randomErr) var badParamErrTarget *BadParameterError - assert.ErrorAs(t, agg, &badParamErrTarget) + require.ErrorAs(t, agg, &badParamErrTarget) assert.Equal(t, bpMsg, badParamErrTarget.Message, "BadParameter message mismatch") var notFoundTarget *NotFoundError - assert.False(t, errors.As(agg, ¬FoundTarget), "Aggregate does not contain a NotFoundError") + require.False(t, errors.As(agg, ¬FoundTarget), "Aggregate does not contain a NotFoundError") } func TestIsAggregate(t *testing.T) { diff --git a/trail/trail.go b/trail/trail.go index ac81988..b377c72 100644 --- a/trail/trail.go +++ b/trail/trail.go @@ -42,8 +42,8 @@ package trail import ( "encoding/base64" "encoding/json" - "errors" "io" + "os" "github.com/gravitational/trace" "github.com/gravitational/trace/internal" @@ -91,43 +91,58 @@ func ToGRPC(originalErr error) error { return originalErr } - for e := originalErr; e != nil; { - if e == io.EOF { + code := codes.Unknown + returnOriginal := false + internal.TraverseErr(originalErr, func(err error) (ok bool) { + if err == io.EOF { // Keep legacy semantics and return the original error. - return originalErr + returnOriginal = true + return true } - if s, ok := status.FromError(e); ok { - return status.Errorf(s.Code(), trace.UserMessage(originalErr)) + if s, ok := status.FromError(err); ok { + code = s.Code() + return true } - switch e.(type) { + // Duplicate check from trace.IsNotFound. + if os.IsNotExist(err) { + code = codes.NotFound + return true + } + + ok = true // Assume match + switch err.(type) { case *trace.AccessDeniedError: - return status.Errorf(codes.PermissionDenied, trace.UserMessage(originalErr)) + code = codes.PermissionDenied case *trace.AlreadyExistsError: - return status.Errorf(codes.AlreadyExists, trace.UserMessage(originalErr)) + code = codes.AlreadyExists case *trace.BadParameterError: - return status.Errorf(codes.InvalidArgument, trace.UserMessage(originalErr)) + code = codes.InvalidArgument case *trace.CompareFailedError: - return status.Errorf(codes.FailedPrecondition, trace.UserMessage(originalErr)) + code = codes.FailedPrecondition case *trace.ConnectionProblemError: - return status.Errorf(codes.Unavailable, trace.UserMessage(originalErr)) + code = codes.Unavailable case *trace.LimitExceededError: - return status.Errorf(codes.ResourceExhausted, trace.UserMessage(originalErr)) + code = codes.ResourceExhausted case *trace.NotFoundError: - return status.Errorf(codes.NotFound, trace.UserMessage(originalErr)) + code = codes.NotFound case *trace.NotImplementedError: - return status.Errorf(codes.Unimplemented, trace.UserMessage(originalErr)) + code = codes.Unimplemented case *trace.OAuth2Error: - return status.Errorf(codes.InvalidArgument, trace.UserMessage(originalErr)) - case *trace.RetryError: // Not mapped. - case *trace.TrustError: // Not mapped. + code = codes.InvalidArgument + // *trace.RetryError not mapped. + // *trace.TrustError not mapped. + default: + ok = false } - - e = errors.Unwrap(e) + return ok + }) + if returnOriginal { + return originalErr } - return status.Errorf(codes.Unknown, trace.UserMessage(originalErr)) + return status.Error(code, trace.UserMessage(originalErr)) } // FromGRPC converts error from GRPC error back to trace.Error diff --git a/trail/trail_test.go b/trail/trail_test.go index c2bdc7b..63f4faa 100644 --- a/trail/trail_test.go +++ b/trail/trail_test.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "io" + "os" "strings" "testing" @@ -34,58 +35,78 @@ import ( // TestConversion makes sure we convert all trace supported errors // to and back from GRPC codes func TestConversion(t *testing.T) { - testCases := []struct { - Error error - Message string - Predicate func(error) bool + tests := []struct { + name string + err error + fn func(error) bool }{ { - Error: io.EOF, - Predicate: func(err error) bool { - return errors.Is(err, io.EOF) - }, + name: "io.EOF", + err: io.EOF, + fn: func(err error) bool { return errors.Is(err, io.EOF) }, }, { - Error: trace.AccessDenied("access denied"), - Predicate: trace.IsAccessDenied, + name: "os.ErrNotExist", + err: os.ErrNotExist, + fn: trace.IsNotFound, }, { - Error: trace.AlreadyExists("already exists"), - Predicate: trace.IsAlreadyExists, + name: "AccessDenied", + err: trace.AccessDenied("access denied"), + fn: trace.IsAccessDenied, }, { - Error: trace.BadParameter("bad parameter"), - Predicate: trace.IsBadParameter, + name: "AlreadyExists", + err: trace.AlreadyExists("already exists"), + fn: trace.IsAlreadyExists, }, { - Error: trace.CompareFailed("compare failed"), - Predicate: trace.IsCompareFailed, + name: "BadParameter", + err: trace.BadParameter("bad parameter"), + fn: trace.IsBadParameter, }, { - Error: trace.ConnectionProblem(nil, "problem"), - Predicate: trace.IsConnectionProblem, + name: "CompareFailed", + err: trace.CompareFailed("compare failed"), + fn: trace.IsCompareFailed, }, { - Error: trace.LimitExceeded("exceeded"), - Predicate: trace.IsLimitExceeded, + name: "ConnectionProblem", + err: trace.ConnectionProblem(nil, "problem"), + fn: trace.IsConnectionProblem, }, { - Error: trace.NotFound("not found"), - Predicate: trace.IsNotFound, + name: "LimitExceeded", + err: trace.LimitExceeded("exceeded"), + fn: trace.IsLimitExceeded, }, { - Error: trace.NotImplemented("not implemented"), - Predicate: trace.IsNotImplemented, + name: "NotFound", + err: trace.NotFound("not found"), + fn: trace.IsNotFound, + }, + { + name: "NotImplemented", + err: trace.NotImplemented("not implemented"), + fn: trace.IsNotImplemented, + }, + { + name: "Aggregated BadParameter", + err: trace.NewAggregate(trace.BadParameter("bad parameter")), + fn: trace.IsBadParameter, }, } - for i, tc := range testCases { - grpcError := ToGRPC(tc.Error) - assert.Equal(t, tc.Error.Error(), status.Convert(grpcError).Message(), "test case %v", i+1) - out := FromGRPC(grpcError) - assert.True(t, tc.Predicate(out), "test case %v", i+1) - assert.Regexp(t, ".*trail_test.go.*", line(trace.DebugReport(out))) - assert.NotRegexp(t, ".*trail.go.*", line(trace.DebugReport(out))) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + grpcError := ToGRPC(test.err) + assert.Equal(t, test.err.Error(), status.Convert(grpcError).Message(), "Error message mismatch") + + out := FromGRPC(grpcError) + assert.True(t, test.fn(out), "Predicate failed") + assert.Regexp(t, ".*trail_test.go.*", line(trace.DebugReport(out))) + assert.NotRegexp(t, ".*trail.go.*", line(trace.DebugReport(out))) + }) } }