diff --git a/cmd/sign/main.go b/cmd/sign/main.go index 1be6dff5..d4ac5eb7 100644 --- a/cmd/sign/main.go +++ b/cmd/sign/main.go @@ -20,7 +20,7 @@ func main() { req := &api.CreateTaskReq{ ProjectID: "912", ProjectVersion: "v1.0.0", - Payloads: []string{"{\"private_input\":\"14\", \"public_input\":\"3,34\", \"receipt_type\":\"Snark\"}"}, + Payload: "{\"private_input\":\"14\", \"public_input\":\"3,34\", \"receipt_type\":\"Snark\"}", } reqJson, _ := json.Marshal(req) fmt.Println(string(reqJson)) diff --git a/datasource/clickhouse.go b/datasource/clickhouse.go index 18cd7f66..80eb756d 100644 --- a/datasource/clickhouse.go +++ b/datasource/clickhouse.go @@ -2,7 +2,6 @@ package datasource import ( "context" - "encoding/json" "math/big" "github.com/ClickHouse/clickhouse-go/v2" @@ -33,19 +32,15 @@ func (p *Clickhouse) Retrieve(taskIDs []common.Hash) ([]*task.Task, error) { res := []*task.Task{} for i := range ts { - ps := [][]byte{} - if err := json.Unmarshal(ts[i].Payloads, &ps); err != nil { - return nil, errors.Wrapf(err, "failed to unmarshal task payloads, task_id %v", ts[i].TaskID) - } - pid := new(big.Int) - if _, ok := pid.SetString(ts[i].ProjectID, 10); !ok { + pid, ok := new(big.Int).SetString(ts[i].ProjectID, 10) + if !ok { return nil, errors.New("failed to decode project id string") } res = append(res, &task.Task{ ID: common.BytesToHash(ts[i].TaskID), ProjectID: pid, ProjectVersion: ts[i].ProjectVersion, - Payloads: ps, + Payload: ts[i].Payload, DeviceID: common.BytesToAddress(ts[i].DeviceID), Signature: ts[i].Signature, }) diff --git a/e2e/util.go b/e2e/util.go index 42e650ac..06ee9097 100644 --- a/e2e/util.go +++ b/e2e/util.go @@ -27,10 +27,9 @@ import ( func signMesssage(data []byte, projectID uint64, key *ecdsa.PrivateKey) ([]byte, error) { req := &api.CreateTaskReq{ - DeviceID: "did:io:" + crypto.PubkeyToAddress(key.PublicKey).String(), Nonce: uint64(time.Now().Unix()), ProjectID: strconv.Itoa(int(projectID)), - Payloads: []string{hexutil.Encode(data)}, + Payload: hexutil.Encode(data), } reqJson, err := json.Marshal(req) diff --git a/service/apinode/api/http.go b/service/apinode/api/http.go index dfedd7e0..c2b42be2 100644 --- a/service/apinode/api/http.go +++ b/service/apinode/api/http.go @@ -2,7 +2,6 @@ package api import ( "bytes" - "crypto/ecdsa" "crypto/sha256" "encoding/json" "fmt" @@ -34,13 +33,12 @@ func newErrResp(err error) *errResp { } type CreateTaskReq struct { - Nonce uint64 `json:"nonce" binding:"required"` - DeviceID string `json:"deviceID" binding:"required"` - ProjectID string `json:"projectID" binding:"required"` - ProjectVersion string `json:"projectVersion,omitempty"` - Payloads []string `json:"payloads" binding:"required"` - Algorithm string `json:"algorithm,omitempty"` // Refer to the constants defined in JWT (JSON Web Token) https://jwt.io/ - Signature string `json:"signature,omitempty" binding:"required"` + Nonce uint64 `json:"nonce" binding:"required"` + ProjectID string `json:"projectID" binding:"required"` + ProjectVersion string `json:"projectVersion,omitempty"` + Payload string `json:"payload" binding:"required"` + Algorithm string `json:"algorithm,omitempty"` // Refer to the constants defined in JWT (JSON Web Token) https://jwt.io/ + Signature string `json:"signature,omitempty" binding:"required"` } type CreateTaskResp struct { @@ -69,6 +67,11 @@ type httpServer struct { proverAddr string } +type recoverRes struct { + addr common.Address + sig []byte +} + func (s *httpServer) createTask(c *gin.Context) { req := &CreateTaskReq{} if err := c.ShouldBindJSON(req); err != nil { @@ -77,8 +80,8 @@ func (s *httpServer) createTask(c *gin.Context) { return } - pid := new(big.Int) - if _, ok := pid.SetString(req.ProjectID, 10); !ok { + pid, ok := new(big.Int).SetString(req.ProjectID, 10) + if !ok { slog.Error("failed to decode project id string", "project_id", req.ProjectID) c.JSON(http.StatusBadRequest, newErrResp(errors.New("failed to decode project id string"))) return @@ -89,42 +92,42 @@ func (s *httpServer) createTask(c *gin.Context) { c.JSON(http.StatusBadRequest, newErrResp(errors.Wrap(err, "failed to decode signature from hex format"))) return } - deviceAddr := common.HexToAddress(strings.TrimPrefix(req.DeviceID, "did:io:")) - addr, sig, alg, err := recoverAddr(*req, sig, deviceAddr) + + recovered, alg, err := recover(*req, sig) if err != nil { slog.Error("failed to recover public key", "error", err) c.JSON(http.StatusBadRequest, newErrResp(errors.Wrap(err, "invalid signature; could not recover public key"))) return } - - ok, err := s.db.IsDeviceApproved(pid, addr) - if err != nil { - slog.Error("failed to check device permission", "error", err) - c.JSON(http.StatusInternalServerError, newErrResp(errors.Wrap(err, "failed to check device permission"))) - return + var addr common.Address + var approved bool + for _, r := range recovered { + ok, err := s.db.IsDeviceApproved(pid, r.addr) + if err != nil { + slog.Error("failed to check device permission", "error", err) + c.JSON(http.StatusInternalServerError, newErrResp(errors.Wrap(err, "failed to check device permission"))) + return + } + if ok { + approved = true + addr = r.addr + sig = r.sig + break + } } - if !ok { - slog.Error("device does not have permission", "project_id", pid.String(), "device_address", addr.String()) + if !approved { + slog.Error("device does not have permission", "project_id", pid.String()) c.JSON(http.StatusForbidden, newErrResp(errors.New("device does not have permission"))) return } - payloadsB := make([][]byte, 0, len(req.Payloads)) - for _, p := range req.Payloads { - d, err := hexutil.Decode(p) - if err != nil { - slog.Error("failed to decode payload from hex format", "error", err) - c.JSON(http.StatusBadRequest, newErrResp(errors.Wrap(err, "failed to decode payload from hex format"))) - return - } - payloadsB = append(payloadsB, d) - } - payloadsJ, err := json.Marshal(payloadsB) + payload, err := hexutil.Decode(req.Payload) if err != nil { - slog.Error("failed to marshal payloads", "error", err) - c.JSON(http.StatusInternalServerError, newErrResp(errors.Wrap(err, "failed to marshal payloads"))) + slog.Error("failed to decode payload from hex format", "error", err) + c.JSON(http.StatusBadRequest, newErrResp(errors.Wrap(err, "failed to decode payload from hex format"))) return } + taskID := crypto.Keccak256Hash(sig) if err := s.db.CreateTask( @@ -134,7 +137,7 @@ func (s *httpServer) createTask(c *gin.Context) { Nonce: req.Nonce, ProjectID: pid.String(), ProjectVersion: req.ProjectVersion, - Payloads: payloadsJ, + Payload: payload, Signature: sig, Algorithm: alg, }, @@ -175,40 +178,27 @@ func (s *httpServer) createTask(c *gin.Context) { }) } -func recoverAddr(req CreateTaskReq, sig []byte, deviceAddr common.Address) (common.Address, []byte, string, error) { +func recover(req CreateTaskReq, sig []byte) ([]*recoverRes, string, error) { req.Signature = "" reqJson, err := json.Marshal(req) if err != nil { - return common.Address{}, nil, "", errors.Wrap(err, "failed to marshal request into json format") + return nil, "", errors.Wrap(err, "failed to marshal request into json format") } switch req.Algorithm { default: h := sha256.Sum256(reqJson) - res := []struct { - pk *ecdsa.PublicKey - sig []byte - }{} + res := []*recoverRes{} rID := []uint8{0, 1} for _, id := range rID { ns := append(sig, byte(id)) if pk, err := crypto.SigToPub(h[:], ns); err != nil { - slog.Debug("failed to recover public key from signature", "error", err, "recover_id", id, "signature", hexutil.Encode(sig)) + return nil, "", errors.Wrapf(err, "failed to recover public key from signature, recover_id %d", id) } else { - res = append(res, struct { - pk *ecdsa.PublicKey - sig []byte - }{pk: pk, sig: ns}) - } - } - - for _, r := range res { - addr := crypto.PubkeyToAddress(*r.pk) - if bytes.Equal(addr.Bytes(), deviceAddr.Bytes()) { - return addr, r.sig, "ES256", nil + res = append(res, &recoverRes{addr: crypto.PubkeyToAddress(*pk), sig: ns}) } } - return common.Address{}, nil, "", errors.New("failed to recover public key from signature") + return res, "ES256", nil } } diff --git a/service/apinode/db/clickhouse.go b/service/apinode/db/clickhouse.go index 055d9ab8..95723a12 100644 --- a/service/apinode/db/clickhouse.go +++ b/service/apinode/db/clickhouse.go @@ -17,7 +17,7 @@ type Task struct { Nonce uint64 `ch:"nonce"` ProjectID string `ch:"project_id"` ProjectVersion string `ch:"project_version"` - Payloads []byte `ch:"payloads"` + Payload []byte `ch:"payload"` Signature []byte `ch:"signature"` Algorithm string `ch:"algorithm"` CreatedAt time.Time `ch:"create_at"` @@ -52,7 +52,7 @@ func migrateCH(conn driver.Conn) error { nonce UInt64 NOT NULL, project_id String NOT NULL, project_version String NOT NULL, - payloads Array(UInt8) NOT NULL, + payload Array(UInt8) NOT NULL, signature Array(UInt8) NOT NULL, algorithm String NOT NULL, create_at DateTime NOT NULL diff --git a/task/task.go b/task/task.go index e4c0e65b..51a2a55d 100644 --- a/task/task.go +++ b/task/task.go @@ -16,7 +16,7 @@ type Task struct { ProjectID *big.Int `json:"projectID"` ProjectVersion string `json:"projectVersion,omitempty"` DeviceID common.Address `json:"deviceID"` - Payloads [][]byte `json:"payloads"` + Payload []byte `json:"payload"` Signature []byte `json:"signature,omitempty"` } diff --git a/vm/vm.go b/vm/vm.go index 0d27382b..0534eb10 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -45,11 +45,11 @@ func (r *Handler) Handle(task *task.Task, vmTypeID uint64, code string, expParam resp, err := cli.ExecuteTask(context.Background(), &proto.ExecuteTaskRequest{ ProjectID: task.ProjectID.Uint64(), TaskID: task.ID[:], - Payloads: task.Payloads, + Payloads: [][]byte{task.Payload}, }) if err != nil { slog.Error("failed to execute task", "project_id", task.ProjectID, "vm_type", vmTypeID, - "task_id", task.ID, "binary", code, "payloads", task.Payloads, "err", err) + "task_id", task.ID, "binary", code, "payloads", task.Payload, "err", err) return nil, errors.Wrap(err, "failed to execute vm instance") }