Skip to content

Commit

Permalink
fix: Handle aggregates correctly in trail.ToGRPC (#98)
Browse files Browse the repository at this point in the history
Handle aggregates (and `os.IsNotExist`) in conversion to status errors.

Aggregates now implement `Unwrap() []error`, the same method provided by
`errors.Join`. This makes it unnecessary to directly implement As() and Is(),
since the corresponding errors methods can use the sliced-Unwrap to traverse the
errors.

Both trail and IsNotFound now check for Unwrap() []error in their loops, making
them correctly support aggregates (and std-joined errors).
  • Loading branch information
codingllama authored Aug 7, 2023
1 parent 1cff453 commit 6b0832f
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 72 deletions.
26 changes: 8 additions & 18 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"net"
"net/url"
"os"

"github.com/gravitational/trace/internal"
)

// NotFound returns new instance of not found error
Expand Down Expand Up @@ -71,27 +73,15 @@ func (e *NotFoundError) Is(target error) bool {
}

// IsNotFound returns true if `e` contains a [NotFoundError] in its chain.
func IsNotFound(e error) bool {
for e != nil {
switch e := e.(type) {
case *NotFoundError:
return true

// Aggregates and other errors.
case interface{ As(interface{}) bool }:
nfe := &NotFoundError{}
if e.As(&nfe) {
return true
}
}

if os.IsNotExist(e) {
func IsNotFound(err error) bool {
return internal.TraverseErr(err, func(err error) (ok bool) {
if os.IsNotExist(err) {
return true
}

e = errors.Unwrap(e)
}
return false
_, ok = err.(*NotFoundError)
return ok
})
}

// AlreadyExists returns a new instance of AlreadyExists error
Expand Down
28 changes: 28 additions & 0 deletions internal/traverse.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package internal

// TraverseErr traverses the err error chain until fn returns true.
// Traversal stops on nil errors, fn(nil) is never called.
// Returns true if fn matched, false otherwise.
func TraverseErr(err error, fn func(error) (ok bool)) (ok bool) {
if err == nil {
return false
}

if fn(err) {
return true
}

switch err := err.(type) {
case interface{ Unwrap() error }:
return TraverseErr(err.Unwrap(), fn)

case interface{ Unwrap() []error }:
for _, err2 := range err.Unwrap() {
if TraverseErr(err2, fn) {
return true
}
}
}

return false
}
7 changes: 7 additions & 0 deletions trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,7 @@ func (r aggregate) Error() string {

// Is implements the `Is` interface, by iterating through each error in the
// aggregate and invoking `errors.Is`.
// Required for Go versions < 1.20 (newer releases support "Unwrap() []error").
func (r aggregate) Is(t error) bool {
for _, err := range r {
if errors.Is(err, t) {
Expand All @@ -506,6 +507,7 @@ func (r aggregate) Is(t error) bool {

// As implements the `As` interface, by iterating through each error in the
// aggregate and invoking `errors.As`.
// Required for Go versions < 1.20 (newer releases support "Unwrap() []error").
func (r aggregate) As(t interface{}) bool {
for _, err := range r {
if errors.As(err, t) {
Expand All @@ -522,6 +524,11 @@ func (r aggregate) Errors() []error {
return cp
}

// Unwrap returns the underlying aggregated errors.
func (r aggregate) Unwrap() []error {
return r.Errors()
}

// IsAggregate returns true if `err` contains an [Aggregate] error in its
// chain.
func IsAggregate(err error) bool {
Expand Down
4 changes: 2 additions & 2 deletions trace_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -747,11 +747,11 @@ func TestAggregate_StdlibCompat(t *testing.T) {
assert.NotErrorIs(t, agg, randomErr)

var badParamErrTarget *BadParameterError
assert.ErrorAs(t, agg, &badParamErrTarget)
require.ErrorAs(t, agg, &badParamErrTarget)
assert.Equal(t, bpMsg, badParamErrTarget.Message, "BadParameter message mismatch")

var notFoundTarget *NotFoundError
assert.False(t, errors.As(agg, &notFoundTarget), "Aggregate does not contain a NotFoundError")
require.False(t, errors.As(agg, &notFoundTarget), "Aggregate does not contain a NotFoundError")
}

func TestIsAggregate(t *testing.T) {
Expand Down
57 changes: 36 additions & 21 deletions trail/trail.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ package trail
import (
"encoding/base64"
"encoding/json"
"errors"
"io"
"os"

"github.com/gravitational/trace"
"github.com/gravitational/trace/internal"
Expand Down Expand Up @@ -91,43 +91,58 @@ func ToGRPC(originalErr error) error {
return originalErr
}

for e := originalErr; e != nil; {
if e == io.EOF {
code := codes.Unknown
returnOriginal := false
internal.TraverseErr(originalErr, func(err error) (ok bool) {
if err == io.EOF {
// Keep legacy semantics and return the original error.
return originalErr
returnOriginal = true
return true
}

if s, ok := status.FromError(e); ok {
return status.Errorf(s.Code(), trace.UserMessage(originalErr))
if s, ok := status.FromError(err); ok {
code = s.Code()
return true
}

switch e.(type) {
// Duplicate check from trace.IsNotFound.
if os.IsNotExist(err) {
code = codes.NotFound
return true
}

ok = true // Assume match
switch err.(type) {
case *trace.AccessDeniedError:
return status.Errorf(codes.PermissionDenied, trace.UserMessage(originalErr))
code = codes.PermissionDenied
case *trace.AlreadyExistsError:
return status.Errorf(codes.AlreadyExists, trace.UserMessage(originalErr))
code = codes.AlreadyExists
case *trace.BadParameterError:
return status.Errorf(codes.InvalidArgument, trace.UserMessage(originalErr))
code = codes.InvalidArgument
case *trace.CompareFailedError:
return status.Errorf(codes.FailedPrecondition, trace.UserMessage(originalErr))
code = codes.FailedPrecondition
case *trace.ConnectionProblemError:
return status.Errorf(codes.Unavailable, trace.UserMessage(originalErr))
code = codes.Unavailable
case *trace.LimitExceededError:
return status.Errorf(codes.ResourceExhausted, trace.UserMessage(originalErr))
code = codes.ResourceExhausted
case *trace.NotFoundError:
return status.Errorf(codes.NotFound, trace.UserMessage(originalErr))
code = codes.NotFound
case *trace.NotImplementedError:
return status.Errorf(codes.Unimplemented, trace.UserMessage(originalErr))
code = codes.Unimplemented
case *trace.OAuth2Error:
return status.Errorf(codes.InvalidArgument, trace.UserMessage(originalErr))
case *trace.RetryError: // Not mapped.
case *trace.TrustError: // Not mapped.
code = codes.InvalidArgument
// *trace.RetryError not mapped.
// *trace.TrustError not mapped.
default:
ok = false
}

e = errors.Unwrap(e)
return ok
})
if returnOriginal {
return originalErr
}

return status.Errorf(codes.Unknown, trace.UserMessage(originalErr))
return status.Error(code, trace.UserMessage(originalErr))
}

// FromGRPC converts error from GRPC error back to trace.Error
Expand Down
83 changes: 52 additions & 31 deletions trail/trail_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"errors"
"fmt"
"io"
"os"
"strings"
"testing"

Expand All @@ -34,58 +35,78 @@ import (
// TestConversion makes sure we convert all trace supported errors
// to and back from GRPC codes
func TestConversion(t *testing.T) {
testCases := []struct {
Error error
Message string
Predicate func(error) bool
tests := []struct {
name string
err error
fn func(error) bool
}{
{
Error: io.EOF,
Predicate: func(err error) bool {
return errors.Is(err, io.EOF)
},
name: "io.EOF",
err: io.EOF,
fn: func(err error) bool { return errors.Is(err, io.EOF) },
},
{
Error: trace.AccessDenied("access denied"),
Predicate: trace.IsAccessDenied,
name: "os.ErrNotExist",
err: os.ErrNotExist,
fn: trace.IsNotFound,
},
{
Error: trace.AlreadyExists("already exists"),
Predicate: trace.IsAlreadyExists,
name: "AccessDenied",
err: trace.AccessDenied("access denied"),
fn: trace.IsAccessDenied,
},
{
Error: trace.BadParameter("bad parameter"),
Predicate: trace.IsBadParameter,
name: "AlreadyExists",
err: trace.AlreadyExists("already exists"),
fn: trace.IsAlreadyExists,
},
{
Error: trace.CompareFailed("compare failed"),
Predicate: trace.IsCompareFailed,
name: "BadParameter",
err: trace.BadParameter("bad parameter"),
fn: trace.IsBadParameter,
},
{
Error: trace.ConnectionProblem(nil, "problem"),
Predicate: trace.IsConnectionProblem,
name: "CompareFailed",
err: trace.CompareFailed("compare failed"),
fn: trace.IsCompareFailed,
},
{
Error: trace.LimitExceeded("exceeded"),
Predicate: trace.IsLimitExceeded,
name: "ConnectionProblem",
err: trace.ConnectionProblem(nil, "problem"),
fn: trace.IsConnectionProblem,
},
{
Error: trace.NotFound("not found"),
Predicate: trace.IsNotFound,
name: "LimitExceeded",
err: trace.LimitExceeded("exceeded"),
fn: trace.IsLimitExceeded,
},
{
Error: trace.NotImplemented("not implemented"),
Predicate: trace.IsNotImplemented,
name: "NotFound",
err: trace.NotFound("not found"),
fn: trace.IsNotFound,
},
{
name: "NotImplemented",
err: trace.NotImplemented("not implemented"),
fn: trace.IsNotImplemented,
},
{
name: "Aggregated BadParameter",
err: trace.NewAggregate(trace.BadParameter("bad parameter")),
fn: trace.IsBadParameter,
},
}

for i, tc := range testCases {
grpcError := ToGRPC(tc.Error)
assert.Equal(t, tc.Error.Error(), status.Convert(grpcError).Message(), "test case %v", i+1)
out := FromGRPC(grpcError)
assert.True(t, tc.Predicate(out), "test case %v", i+1)
assert.Regexp(t, ".*trail_test.go.*", line(trace.DebugReport(out)))
assert.NotRegexp(t, ".*trail.go.*", line(trace.DebugReport(out)))
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
grpcError := ToGRPC(test.err)
assert.Equal(t, test.err.Error(), status.Convert(grpcError).Message(), "Error message mismatch")

out := FromGRPC(grpcError)
assert.True(t, test.fn(out), "Predicate failed")
assert.Regexp(t, ".*trail_test.go.*", line(trace.DebugReport(out)))
assert.NotRegexp(t, ".*trail.go.*", line(trace.DebugReport(out)))
})
}
}

Expand Down

0 comments on commit 6b0832f

Please sign in to comment.