From 2179652c4f11d9a7d42fc4d6d3ec3bbdd58355be Mon Sep 17 00:00:00 2001 From: Quint Daenen Date: Mon, 6 May 2024 11:39:41 +0200 Subject: [PATCH] Properly handle call/query states. --- agent.go | 15 +++++- cmd/goic/main.go | 19 ++++++-- ic/testdata/gen.go | 2 +- pocketic/README.md | 2 +- pocketic/agent_test.go | 18 +++++++ pocketic/blobstore.go | 1 - pocketic/do.go | 99 +++++++++++++++++++++++++++++++++++++++ pocketic/gen.go | 2 +- pocketic/http.go | 26 +++++----- pocketic/instances.go | 10 ---- pocketic/management.go | 1 - pocketic/pocketic.go | 36 ++++++++++---- pocketic/pocketic_test.go | 31 ++++++++++++ pocketic/request.go | 28 ----------- 14 files changed, 220 insertions(+), 70 deletions(-) create mode 100644 pocketic/do.go diff --git a/agent.go b/agent.go index de92173..76b38c9 100644 --- a/agent.go +++ b/agent.go @@ -443,12 +443,23 @@ type Call struct { data []byte } +// Call calls a method on a canister, it does not wait for the result. +func (c Call) Call() error { + c.a.logger.Printf("[AGENT] CALL %s %s (%x)", c.effectiveCanisterID, c.methodName, c.requestID) + _, err := c.a.call(c.effectiveCanisterID, c.data) + return err +} + // CallAndWait calls a method on a canister and waits for the result. func (c Call) CallAndWait(values ...any) error { - c.a.logger.Printf("[AGENT] CALL %s %s (%x)", c.effectiveCanisterID, c.methodName, c.requestID) - if _, err := c.a.call(c.effectiveCanisterID, c.data); err != nil { + if err := c.Call(); err != nil { return err } + return c.Wait(values...) +} + +// Wait waits for the result of the call and unmarshals it into the given values. +func (c Call) Wait(values ...any) error { raw, err := c.a.poll(c.effectiveCanisterID, c.requestID) if err != nil { return err diff --git a/cmd/goic/main.go b/cmd/goic/main.go index 203cd81..0200efc 100644 --- a/cmd/goic/main.go +++ b/cmd/goic/main.go @@ -70,6 +70,10 @@ var root = cmd.NewCommandFork( Name: "packageName", HasValue: true, }, + { + Name: "indirect", + HasValue: false, + }, }, func(args []string, options map[string]string) error { inputPath := args[0] @@ -89,7 +93,8 @@ var root = cmd.NewCommandFork( packageName = p } - return writeDID(canisterName, packageName, path, rawDID) + _, indirect := options["indirect"] + return writeDID(canisterName, packageName, path, rawDID, indirect) }, ), cmd.NewCommand( @@ -105,6 +110,10 @@ var root = cmd.NewCommandFork( Name: "packageName", HasValue: true, }, + { + Name: "indirect", + HasValue: false, + }, }, func(args []string, options map[string]string) error { id := args[0] @@ -128,7 +137,8 @@ var root = cmd.NewCommandFork( packageName = p } - return writeDID(canisterName, packageName, path, rawDID) + _, indirect := options["indirect"] + return writeDID(canisterName, packageName, path, rawDID, indirect) }, ), ), @@ -156,11 +166,14 @@ func main() { } } -func writeDID(canisterName, packageName, outputPath string, rawDID []byte) error { +func writeDID(canisterName, packageName, outputPath string, rawDID []byte, indirect bool) error { g, err := gen.NewGenerator("", canisterName, packageName, rawDID) if err != nil { return err } + if indirect { + g.Indirect() + } raw, err := g.Generate() if err != nil { return err diff --git a/ic/testdata/gen.go b/ic/testdata/gen.go index 2def76f..dc0d596 100644 --- a/ic/testdata/gen.go +++ b/ic/testdata/gen.go @@ -107,7 +107,7 @@ func main() { log.Panic(err) } if name == "ic" { - g = g.Indirect() + g.Indirect() } raw, err := g.Generate() if err != nil { diff --git a/pocketic/README.md b/pocketic/README.md index f30566f..36398f7 100644 --- a/pocketic/README.md +++ b/pocketic/README.md @@ -16,7 +16,7 @@ The client is not yet stable and is subject to change. | ✅ | POST | /blobstore | | ✅ | GET | /blobstore/{id} | | ✅ | POST | /verify_signature | -| ❌ | GET | /read_graph/{state_label}/{op_id} | +| ✳️ | GET | /read_graph/{state_label}/{op_id} | | ✅ | GET | /instances/ | | ✅ | POST | /instances/ | | ✅ | DELETE | /instances/{id} | diff --git a/pocketic/agent_test.go b/pocketic/agent_test.go index 3e9d0f3..f735194 100755 --- a/pocketic/agent_test.go +++ b/pocketic/agent_test.go @@ -40,6 +40,15 @@ func (a Agent) HelloQuery(arg0 string) (*string, error) { return &r0, nil } +// HelloQueryQuery creates an indirect representation of the "helloQuery" method on the "hello" canister. +func (a Agent) HelloQueryQuery(arg0 string) (*agent.Query, error) { + return a.a.CreateQuery( + a.canisterId, + "helloQuery", + arg0, + ) +} + // HelloUpdate calls the "helloUpdate" method on the "hello" canister. func (a Agent) HelloUpdate(arg0 string) (*string, error) { var r0 string @@ -53,3 +62,12 @@ func (a Agent) HelloUpdate(arg0 string) (*string, error) { } return &r0, nil } + +// HelloUpdateCall creates an indirect representation of the "helloUpdate" method on the "hello" canister. +func (a Agent) HelloUpdateCall(arg0 string) (*agent.Call, error) { + return a.a.CreateCall( + a.canisterId, + "helloUpdate", + arg0, + ) +} diff --git a/pocketic/blobstore.go b/pocketic/blobstore.go index de64580..646db8e 100644 --- a/pocketic/blobstore.go +++ b/pocketic/blobstore.go @@ -13,7 +13,6 @@ func (pic PocketIC) GetBlob(blobID []byte) ([]byte, error) { if err := pic.do( http.MethodGet, fmt.Sprintf("%s/blobstore/%s", pic.server.URL(), hex.EncodeToString(blobID)), - http.StatusOK, nil, &bytes, ); err != nil { diff --git a/pocketic/do.go b/pocketic/do.go new file mode 100644 index 0000000..8142548 --- /dev/null +++ b/pocketic/do.go @@ -0,0 +1,99 @@ +package pocketic + +import ( + "encoding/json" + "fmt" + "net/http" + "time" +) + +func (pic PocketIC) do(method, url string, input, output any) error { + start := time.Now() + for { + if pic.timeout < time.Since(start) { + return fmt.Errorf("timeout exceeded") + } + + pic.logger.Printf("[POCKETIC] %s %s %+v", method, url, input) + req, err := newRequest(method, url, input) + if err != nil { + return err + } + resp, err := pic.client.Do(req) + if err != nil { + return err + } + switch resp.StatusCode { + case http.StatusOK, http.StatusCreated: + if resp.Body == nil || output == nil { + // No need to decode the response body. + return nil + } + return json.NewDecoder(resp.Body).Decode(output) + case http.StatusAccepted: + var response startedOrBusyResponse + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + return err + } + if method == http.MethodGet { + continue + } + for { + if pic.timeout < time.Since(start) { + return fmt.Errorf("timeout exceeded") + } + + req, err := newRequest( + http.MethodGet, + fmt.Sprintf( + "%s/read_graph/%s/%s", + pic.server.URL(), + response.StateLabel, + response.OpID, + ), + nil, + ) + if err != nil { + return err + } + resp, err := pic.client.Do(req) + if err != nil { + return err + } + switch resp.StatusCode { + case http.StatusOK: + if resp.Body == nil || output == nil { + // No need to decode the response body. + return nil + } + return json.NewDecoder(resp.Body).Decode(output) + case http.StatusAccepted, http.StatusConflict: + default: + var errResp ErrorMessage + if err := json.NewDecoder(resp.Body).Decode(&errResp); err != nil { + return err + } + return errResp + } + } + case http.StatusConflict: + var response startedOrBusyResponse + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + return err + } + time.Sleep(pic.delay) // Retry after a short delay. + continue + default: + var errResp ErrorMessage + if err := json.NewDecoder(resp.Body).Decode(&errResp); err != nil { + return err + } + return errResp + } + } +} + +type startedOrBusyResponse struct { + StateLabel string `json:"state_label"` + OpID string `json:"op_id"` +} diff --git a/pocketic/gen.go b/pocketic/gen.go index 0030c0e..96b8be4 100644 --- a/pocketic/gen.go +++ b/pocketic/gen.go @@ -1,3 +1,3 @@ package pocketic -//go:generate go run ../cmd/goic/main.go generate did testdata/main.did hello --output=agent_test.go --packageName=pocketic_test +//go:generate go run ../cmd/goic/main.go generate did testdata/main.did hello --output=agent_test.go --packageName=pocketic_test --indirect diff --git a/pocketic/http.go b/pocketic/http.go index 6a187aa..5d7af5f 100644 --- a/pocketic/http.go +++ b/pocketic/http.go @@ -79,7 +79,6 @@ func (pic PocketIC) AutoProgress() error { return pic.do( http.MethodPost, fmt.Sprintf("%s/auto_progress", pic.instanceURL()), - http.StatusOK, nil, nil, ) @@ -105,25 +104,31 @@ func (pic PocketIC) MakeLive(port *int) (string, error) { if pic.httpGateway != nil { return fmt.Sprintf("http://127.0.0.1:%d", pic.httpGateway.Port), nil } - var resp CreateHttpGatewayResponse - if err := pic.do( + req, err := newRequest( http.MethodPost, fmt.Sprintf("%s/http_gateway", pic.server.URL()), - http.StatusCreated, HttpGatewayConfig{ ListenAt: port, ForwardTo: HttpGatewayBackendPocketICInstance{ PocketIcInstance: pic.InstanceID, }, }, - &resp, - ); err != nil { + ) + if err != nil { + return "", err + } + resp, err := pic.client.Do(req) + if err != nil { + return "", err + } + var gatewayResp CreateHttpGatewayResponse + if err := json.NewDecoder(resp.Body).Decode(&gatewayResp); err != nil { return "", err } - if resp.Error != nil { - return "", resp.Error + if gatewayResp.Error != nil { + return "", gatewayResp.Error } - return fmt.Sprintf("http://127.0.0.1:%d", resp.Created.Port), nil + return fmt.Sprintf("http://127.0.0.1:%d", gatewayResp.Created.Port), nil } // SetTime sets the current time of the IC, on all subnets. @@ -131,7 +136,6 @@ func (pic PocketIC) SetTime(time time.Time) error { return pic.do( http.MethodPost, fmt.Sprintf("%s/update/set_time", pic.instanceURL()), - http.StatusOK, RawTime{ NanosSinceEpoch: time.UnixNano(), }, @@ -144,7 +148,6 @@ func (pic PocketIC) StopProgress() error { return pic.do( http.MethodPost, fmt.Sprintf("%s/stop_progress", pic.instanceURL()), - http.StatusOK, nil, nil, ) @@ -155,7 +158,6 @@ func (pic *PocketIC) stopHttpGateway() error { if err := pic.do( http.MethodPost, fmt.Sprintf("%s/http_gateway/%d/stop", pic.server.URL(), pic.httpGateway.InstanceID), - http.StatusOK, nil, nil, ); err != nil { diff --git a/pocketic/instances.go b/pocketic/instances.go index 53121e0..a1c9f5c 100644 --- a/pocketic/instances.go +++ b/pocketic/instances.go @@ -13,7 +13,6 @@ func (pic PocketIC) CreateInstance(config SubnetConfigSet) (*InstanceConfig, err if err := pic.do( http.MethodPost, fmt.Sprintf("%s/instances", pic.server.URL()), - http.StatusCreated, config, &a, ); err != nil { @@ -30,7 +29,6 @@ func (pic PocketIC) DeleteInstance(instanceID int) error { return pic.do( http.MethodDelete, fmt.Sprintf("%s/instances/%d", pic.server.URL(), instanceID), - http.StatusOK, nil, nil, ) @@ -42,7 +40,6 @@ func (pic PocketIC) GetCycles(canisterID principal.Principal) (int, error) { if err := pic.do( http.MethodPost, fmt.Sprintf("%s/read/get_cycles", pic.instanceURL()), - http.StatusOK, &RawCanisterID{CanisterID: canisterID.Raw}, &cycles, ); err != nil { @@ -57,7 +54,6 @@ func (pic PocketIC) GetInstances() ([]string, error) { if err := pic.do( http.MethodGet, fmt.Sprintf("%s/instances", pic.server.URL()), - http.StatusOK, nil, &instances, ); err != nil { @@ -73,7 +69,6 @@ func (pic PocketIC) GetStableMemory(canisterID principal.Principal) ([]byte, err if err := pic.do( http.MethodPost, fmt.Sprintf("%s/read/get_stable_memory", pic.instanceURL()), - http.StatusOK, &RawCanisterID{CanisterID: canisterID.Raw}, &data, ); err != nil { @@ -88,7 +83,6 @@ func (pic PocketIC) GetSubnet(canisterID principal.Principal) (*principal.Princi if err := pic.do( http.MethodPost, fmt.Sprintf("%s/read/get_subnet", pic.instanceURL()), - http.StatusOK, &RawCanisterID{CanisterID: canisterID.Raw}, &subnetID, ); err != nil { @@ -103,7 +97,6 @@ func (pic PocketIC) GetTime() (*time.Time, error) { if err := pic.do( http.MethodGet, fmt.Sprintf("%s/read/get_time", pic.instanceURL()), - http.StatusOK, nil, &t, ); err != nil { @@ -133,7 +126,6 @@ func (pic PocketIC) RootKey() ([]byte, error) { if err := pic.do( http.MethodPost, fmt.Sprintf("%s/read/pub_key", pic.instanceURL()), - http.StatusOK, &RawSubnetID{SubnetID: subnetID.Raw}, &key, ); err != nil { @@ -151,7 +143,6 @@ func (pic PocketIC) SetStableMemory(canisterID principal.Principal, data []byte, return pic.do( http.MethodPost, fmt.Sprintf("%s/update/set_stable_memory", pic.instanceURL()), - http.StatusOK, RawSetStableMemory{ CanisterID: canisterID.Raw, BlobID: blobID, @@ -164,7 +155,6 @@ func (pic PocketIC) Tick() error { return pic.do( http.MethodPost, fmt.Sprintf("%s/update/tick", pic.instanceURL()), - http.StatusOK, nil, nil, ) diff --git a/pocketic/management.go b/pocketic/management.go index cd6ddd0..1abeddb 100644 --- a/pocketic/management.go +++ b/pocketic/management.go @@ -18,7 +18,6 @@ func (pic PocketIC) AddCycles(canisterID principal.Principal, amount int) (int, if err := pic.do( http.MethodPost, fmt.Sprintf("%s/update/add_cycles", pic.instanceURL()), - http.StatusOK, RawAddCycles{ Amount: amount, CanisterID: canisterID.Raw, diff --git a/pocketic/pocketic.go b/pocketic/pocketic.go index 0ef1ad6..3d12ffe 100644 --- a/pocketic/pocketic.go +++ b/pocketic/pocketic.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "net/http" + "time" "github.com/aviate-labs/agent-go" "github.com/aviate-labs/agent-go/principal" @@ -55,10 +56,11 @@ func (c *CanisterIDRange) UnmarshalJSON(bytes []byte) error { } type Config struct { - subnetConfig SubnetConfigSet - serverConfig []serverOption - client *http.Client - logger agent.Logger + subnetConfig SubnetConfigSet + serverConfig []serverOption + client *http.Client + logger agent.Logger + delay, timeout time.Duration } type DTSFlag bool @@ -150,6 +152,13 @@ func WithNNSSubnet() Option { } } +func WithPollingDelay(delay, timeout time.Duration) Option { + return func(p *Config) { + p.delay = delay + p.timeout = timeout + } +} + // WithSNSSubnet adds an empty SNS subnet. func WithSNSSubnet() Option { return func(p *Config) { @@ -184,9 +193,10 @@ type PocketIC struct { httpGateway *HttpGatewayInfo topology map[string]Topology - logger agent.Logger - client *http.Client - server *server + logger agent.Logger + client *http.Client + delay, timeout time.Duration + server *server } // New creates a new PocketIC client. @@ -196,6 +206,8 @@ func New(opts ...Option) (*PocketIC, error) { subnetConfig: DefaultSubnetConfig, client: http.DefaultClient, logger: new(agent.NoopLogger), + delay: 10 * time.Millisecond, + timeout: 1 * time.Second, } for _, fn := range opts { fn(&config) @@ -219,7 +231,10 @@ func New(opts ...Option) (*PocketIC, error) { if respBody.Error != nil { return nil, respBody.Error } - if err := checkResponse(resp, http.StatusCreated, &respBody); err != nil { + if resp.StatusCode != http.StatusCreated { + return nil, fmt.Errorf("failed to create instance: %s", resp.Status) + } + if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { return nil, fmt.Errorf("failed to create instance: %v", err) } @@ -229,6 +244,8 @@ func New(opts ...Option) (*PocketIC, error) { topology: respBody.Created.Topology, logger: config.logger, client: config.client, + delay: config.delay, + timeout: config.timeout, server: s, }, nil } @@ -243,7 +260,6 @@ func (pic PocketIC) Status() error { return pic.do( http.MethodGet, fmt.Sprintf("%s/status", pic.server.URL()), - http.StatusOK, nil, nil, ) @@ -254,11 +270,11 @@ func (pic PocketIC) Topology() map[string]Topology { return pic.topology } +// VerifySignature verifies a signature. func (pic PocketIC) VerifySignature(sig RawVerifyCanisterSigArg) error { return pic.do( http.MethodPost, fmt.Sprintf("%s/verify_signature", pic.server.URL()), - http.StatusOK, sig, nil, ) diff --git a/pocketic/pocketic_test.go b/pocketic/pocketic_test.go index a92e4f5..973500b 100644 --- a/pocketic/pocketic_test.go +++ b/pocketic/pocketic_test.go @@ -13,9 +13,37 @@ import ( "os/exec" "path" "strings" + "sync" "testing" ) +func ConcurrentCalls(t *testing.T) *pocketic.PocketIC { + pic, err := pocketic.New() + if err != nil { + t.Fatal(err) + } + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer func() { + wg.Done() + }() + + canisterID, err := pic.CreateCanister() + if err != nil { + t.Error(err) + return + } + if _, err := pic.AddCycles(*canisterID, 2_000_000_000_000); err != nil { + t.Error(err) + } + }() + } + wg.Wait() + return pic +} + func CreateCanister(t *testing.T) *pocketic.PocketIC { pic, err := pocketic.New(pocketic.WithLogger(new(testLogger))) if err != nil { @@ -117,6 +145,9 @@ func TestPocketIC(t *testing.T) { t.Run("HttpGateway", func(t *testing.T) { instances = append(instances, HttpGateway(t)) }) + t.Run("ConcurrentCalls", func(t *testing.T) { + instances = append(instances, ConcurrentCalls(t)) + }) t.Run("Endpoints", func(t *testing.T) { instances = append(instances, Endpoints(t)) }) diff --git a/pocketic/request.go b/pocketic/request.go index 608608a..eb75d35 100644 --- a/pocketic/request.go +++ b/pocketic/request.go @@ -17,17 +17,6 @@ var headers = func() http.Header { } } -func checkResponse(resp *http.Response, statusCode int, v any) error { - if resp.StatusCode != statusCode { - return fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } - if v == nil { - // No need to decode the response body. - return nil - } - return json.NewDecoder(resp.Body).Decode(v) -} - func newRequest(method, url string, body any) (*http.Request, error) { var bodyBytes io.Reader if body != nil { @@ -51,7 +40,6 @@ func (pic PocketIC) AwaitCall(messageID RawMessageID) ([]byte, error) { if err := pic.do( http.MethodPost, fmt.Sprintf("%s/update/await_ingress_message", pic.instanceURL()), - http.StatusOK, messageID, &resp, ); err != nil { @@ -78,7 +66,6 @@ func (pic PocketIC) ExecuteCall( if err := pic.do( http.MethodPost, fmt.Sprintf("%s/update/execute_ingress_message", pic.instanceURL()), - http.StatusOK, RawCanisterCall{ CanisterID: canisterID.Raw, EffectivePrincipal: effectivePrincipal, @@ -143,7 +130,6 @@ func (pic PocketIC) SubmitCallWithEP( if err := pic.do( http.MethodPost, fmt.Sprintf("%s/update/submit_ingress_message", pic.instanceURL()), - http.StatusOK, RawCanisterCall{ CanisterID: canisterID.Raw, EffectivePrincipal: effectivePrincipal, @@ -172,7 +158,6 @@ func (pic PocketIC) canisterCall(endpoint string, canisterID principal.Principal if err := pic.do( http.MethodPost, fmt.Sprintf("%s/%s", pic.instanceURL(), endpoint), - http.StatusOK, RawCanisterCall{ CanisterID: canisterID.Raw, EffectivePrincipal: effectivePrincipal, @@ -193,19 +178,6 @@ func (pic PocketIC) canisterCall(endpoint string, canisterID principal.Principal return resp.Ok.Reply, nil } -func (pic PocketIC) do(method, url string, statusCode int, input, output any) error { - pic.logger.Printf("[POCKETIC] %s %s %+v", method, url, input) - req, err := newRequest(method, url, input) - if err != nil { - return err - } - resp, err := pic.client.Do(req) - if err != nil { - return err - } - return checkResponse(resp, statusCode, output) -} - // updateCallWithEP calls SubmitCallWithEP and AwaitCall in sequence. func (pic PocketIC) updateCallWithEP(canisterID principal.Principal, effectivePrincipal RawEffectivePrincipal, sender principal.Principal, method string, payload []byte) ([]byte, error) { messageID, err := pic.SubmitCallWithEP(canisterID, effectivePrincipal, sender, method, payload)