From 3daf0a980a646f0bcfc48460baed682e9891cb6d Mon Sep 17 00:00:00 2001 From: Austin Burdine Date: Wed, 11 Dec 2019 15:55:26 -0500 Subject: [PATCH] feat: support context property in resolver --- invocation.go | 6 +++--- invocation_test.go | 4 ++-- payload.go | 6 +++--- payload_test.go | 11 +++++------ repository.go | 9 ++++++--- repository_test.go | 41 +++++++++++++++++++++++++++++++++-------- resolver.go | 31 +++++++++++++++++++++++-------- resolver_test.go | 13 +++++++++---- validate.go | 16 ++++++++++++---- 9 files changed, 96 insertions(+), 41 deletions(-) diff --git a/invocation.go b/invocation.go index 78b4515..0b77f3c 100644 --- a/invocation.go +++ b/invocation.go @@ -2,14 +2,14 @@ package resolvers import "encoding/json" -type context struct { +type appsyncContext struct { Arguments json.RawMessage `json:"arguments"` Source json.RawMessage `json:"source"` } type invocation struct { - Resolve string `json:"resolve"` - Context context `json:"context"` + Resolve string `json:"resolve"` + Context appsyncContext `json:"context"` } func (in invocation) isRoot() bool { diff --git a/invocation_test.go b/invocation_test.go index c7ceccc..95a6e3a 100644 --- a/invocation_test.go +++ b/invocation_test.go @@ -11,7 +11,7 @@ var _ = Describe("Invocation", func() { Context("With Arguments", func() { data := invocation{ Resolve: "exaple.resolver", - Context: context{ + Context: appsyncContext{ Arguments: json.RawMessage(`{ "foo": "bar" }`), }, } @@ -28,7 +28,7 @@ var _ = Describe("Invocation", func() { Context("With Source", func() { data := invocation{ Resolve: "exaple.resolver", - Context: context{ + Context: appsyncContext{ Source: json.RawMessage(`{ "bar": "foo" }`), }, } diff --git a/payload.go b/payload.go index cd8049e..d064b2c 100644 --- a/payload.go +++ b/payload.go @@ -10,12 +10,12 @@ type payload struct { Message json.RawMessage } -func (p payload) parse(argsType reflect.Type) ([]reflect.Value, error) { +func (p payload) parse(argsType reflect.Type) (reflect.Value, error) { args := reflect.New(argsType) if err := json.Unmarshal(p.Message, args.Interface()); err != nil { - return nil, fmt.Errorf("Unable to prepare payload: %s", err.Error()) + return args.Elem(), fmt.Errorf("Unable to prepare payload: %s", err.Error()) } - return append([]reflect.Value{}, args.Elem()), nil + return args.Elem(), nil } diff --git a/payload_test.go b/payload_test.go index 7972e5f..d8ffb02 100644 --- a/payload_test.go +++ b/payload_test.go @@ -23,8 +23,8 @@ var _ = Describe("Payload", func() { Expect(err).To(HaveOccurred()) }) - It("should return nil", func() { - Expect(args).To(BeNil()) + It("should not return nil", func() { + Expect(args).NotTo(BeNil()) }) }) @@ -36,19 +36,18 @@ var _ = Describe("Payload", func() { return nil }} - args, err := message.parse(reflect.TypeOf(example.function).In(0)) + arg, err := message.parse(reflect.TypeOf(example.function).In(0)) It("should not error", func() { Expect(err).NotTo(HaveOccurred()) }) It("should return struct", func() { - Expect(args).NotTo(BeNil()) + Expect(arg).NotTo(BeNil()) }) It("should parse data", func() { - Expect(args).To(HaveLen(1)) - Expect(args[0].FieldByName("Name").String()).To(Equal("example")) + Expect(arg.FieldByName("Name").String()).To(Equal("example")) }) }) }) diff --git a/repository.go b/repository.go index 397b862..28c23fd 100644 --- a/repository.go +++ b/repository.go @@ -1,6 +1,7 @@ package resolvers import ( + "context" "fmt" "reflect" ) @@ -13,18 +14,20 @@ func (r Repository) Add(resolve string, handler interface{}) error { err := validators.run(reflect.TypeOf(handler)) if err == nil { - r[resolve] = resolver{handler} + r[resolve] = resolver{ + function: handler, + } } return err } // Handle responds to the AppSync request -func (r Repository) Handle(in invocation) (interface{}, error) { +func (r Repository) Handle(ctx context.Context, in invocation) (interface{}, error) { handler, found := r[in.Resolve] if found { - return handler.call(in.payload()) + return handler.call(ctx, in.payload()) } return nil, fmt.Errorf("No resolver found: %s", in.Resolve) diff --git a/repository_test.go b/repository_test.go index 0e20fa7..f96ff9e 100644 --- a/repository_test.go +++ b/repository_test.go @@ -1,6 +1,7 @@ package resolvers import ( + "context" "encoding/json" "errors" @@ -15,14 +16,18 @@ var _ = Describe("Repository", func() { type response struct { Foo string } + type ctxKey string + const testCtxKey ctxKey = "test" + r := New() r.Add("example.resolver", func(arg arguments) (response, error) { return response{"bar"}, nil }) r.Add("example.resolver.with.error", func(arg arguments) (response, error) { return response{"bar"}, errors.New("Has Error") }) + r.Add("example.resolver.with.context", func(ctx context.Context, arg arguments) (interface{}, error) { return ctx.Value(testCtxKey), nil }) Context("Matching invocation", func() { - res, err := r.Handle(invocation{ + res, err := r.Handle(context.Background(), invocation{ Resolve: "example.resolver", - Context: context{ + Context: appsyncContext{ Arguments: json.RawMessage(`{"bar":"foo"}`), }, }) @@ -37,9 +42,9 @@ var _ = Describe("Repository", func() { }) Context("Matching invocation with error", func() { - _, err := r.Handle(invocation{ + _, err := r.Handle(context.Background(), invocation{ Resolve: "example.resolver.with.error", - Context: context{ + Context: appsyncContext{ Arguments: json.RawMessage(`{"bar":"foo"}`), }, }) @@ -50,9 +55,9 @@ var _ = Describe("Repository", func() { }) Context("Matching invocation with invalid payload", func() { - _, err := r.Handle(invocation{ + _, err := r.Handle(context.Background(), invocation{ Resolve: "example.resolver.with.error", - Context: context{ + Context: appsyncContext{ Arguments: json.RawMessage(`{"bar:foo"}`), }, }) @@ -62,10 +67,30 @@ var _ = Describe("Repository", func() { }) }) + Context("Matching invocation with context", func() { + ctx := context.Background() + ctx = context.WithValue(ctx, testCtxKey, "test string") + + res, err := r.Handle(ctx, invocation{ + Resolve: "example.resolver.with.context", + Context: appsyncContext{ + Arguments: json.RawMessage(`{"bar":"foo"}`), + }, + }) + + It("Should not error", func() { + Expect(err).ToNot(HaveOccurred()) + }) + + It("Should have data", func() { + Expect(res.(string)).To(Equal("test string")) + }) + }) + Context("Not matching invocation", func() { - res, err := r.Handle(invocation{ + res, err := r.Handle(context.Background(), invocation{ Resolve: "example.resolver.not.found", - Context: context{ + Context: appsyncContext{ Arguments: json.RawMessage(`{"bar":"foo"}`), }, }) diff --git a/resolver.go b/resolver.go index 440c849..079116f 100644 --- a/resolver.go +++ b/resolver.go @@ -1,6 +1,7 @@ package resolvers import ( + "context" "encoding/json" "reflect" ) @@ -9,21 +10,35 @@ type resolver struct { function interface{} } -func (r *resolver) hasArguments() bool { - return reflect.TypeOf(r.function).NumIn() == 1 +func (r *resolver) hasContext() bool { + return reflect.TypeOf(r.function).NumIn() == 2 } -func (r *resolver) call(p json.RawMessage) (interface{}, error) { - var args []reflect.Value - var err error +func (r *resolver) hasPayload() bool { + return reflect.TypeOf(r.function).NumIn() > 0 +} - if r.hasArguments() { - pld := payload{p} - args, err = pld.parse(reflect.TypeOf(r.function).In(0)) +func (r *resolver) call(ctx context.Context, p json.RawMessage) (interface{}, error) { + args := make([]reflect.Value, 0, 2) + hasContext := r.hasContext() + + if hasContext { + args = append(args, reflect.ValueOf(ctx)) + } + if r.hasPayload() { + var index int + if hasContext { + index = 1 + } + + pld := payload{p} + val, err := pld.parse(reflect.TypeOf(r.function).In(index)) if err != nil { return nil, err } + + args = append(args, val) } returnValues := reflect.ValueOf(r.function).Call(args) diff --git a/resolver_test.go b/resolver_test.go index 5d54692..89805ee 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -1,6 +1,7 @@ package resolvers import ( + "context" "reflect" . "github.com/onsi/ginkgo" @@ -21,8 +22,10 @@ var _ = Describe("Resolver", func() { Entry("Not a function, but integer", 1234, "Resolver is not a function, got int"), Entry("Not a function, but string", "123", "Resolver is not a function, got string"), - Entry("Parameter is string", func(args string) (interface{}, error) { return nil, nil }, "Resolver argument must be struct"), - Entry("Too many parameters", func(args struct{}, foo struct{}) (interface{}, error) { return nil, nil }, "Resolver must not have more than one argument, got 2"), + Entry("Parameter is string", func(args string) (interface{}, error) { return nil, nil }, "Resolver payload argument must be a struct"), + Entry("Too many parameters", func(args struct{}, foo struct{}, bar struct{}) (interface{}, error) { return nil, nil }, "Resolver must not have more than two arguments, got 3"), + Entry("Parameter is Context", func(ctx context.Context) (interface{}, error) { return nil, nil }, "Resolver payload argument must be a struct"), + Entry("Two parameters and first isn't context", func(args struct{}, foo struct{}) (interface{}, error) { return nil, nil }, "Resolver takes two arguments, but the first argument is not Context"), Entry("No return value", func() {}, "Resolver must have at least one return value"), Entry("Non-error return value", func(args struct{}) interface{} { return nil }, "Last return value must be an error"), @@ -35,8 +38,10 @@ var _ = Describe("Resolver", func() { Expect(validators.run(reflect.TypeOf(r))).NotTo(HaveOccurred()) }, - Entry("With parameter and multiple return values", func(args struct{}) (interface{}, error) { return nil, nil }), - Entry("With parameter and single return value", func(args struct{}) error { return nil }), + Entry("With payload and multiple return values", func(args struct{}) (interface{}, error) { return nil, nil }), + Entry("With payload and single return value", func(args struct{}) error { return nil }), + Entry("With payload and context and multiple return values", func(ctx context.Context, args struct{}) (interface{}, error) { return nil, nil }), + Entry("With payload and context and single return value", func(ctx context.Context, args struct{}) error { return nil }), Entry("Without parameter, but multiple return values", func() (interface{}, error) { return nil, nil }), Entry("Without parameter, but single return value", func() error { return nil }), ) diff --git a/validate.go b/validate.go index b866e47..7d821da 100644 --- a/validate.go +++ b/validate.go @@ -1,6 +1,7 @@ package resolvers import ( + "context" "errors" "fmt" "reflect" @@ -28,15 +29,22 @@ var validators = validateList{ return nil }, func(h reflect.Type) error { - if num := h.NumIn(); num > 1 { - return fmt.Errorf("Resolver must not have more than one argument, got %v", num) + if num := h.NumIn(); num > 2 { + return fmt.Errorf("Resolver must not have more than two arguments, got %v", num) } return nil }, func(h reflect.Type) error { - if h.NumIn() == 1 && h.In(0).Kind() != reflect.Struct { - return errors.New("Resolver argument must be struct") + if h.NumIn() == 2 && !h.In(0).Implements(reflect.TypeOf((*context.Context)(nil)).Elem()) { + return errors.New("Resolver takes two arguments, but the first argument is not Context") + } + + return nil + }, + func(h reflect.Type) error { + if h.NumIn() > 0 && h.In(h.NumIn()-1).Kind() != reflect.Struct { + return errors.New("Resolver payload argument must be a struct") } return nil