diff --git a/make/migrations/postgresql/0120_2.9.0_schema.up.sql b/make/migrations/postgresql/0120_2.9.0_schema.up.sql new file mode 100644 index 00000000000..16f6b7b71fd --- /dev/null +++ b/make/migrations/postgresql/0120_2.9.0_schema.up.sql @@ -0,0 +1 @@ +CREATE INDEX IF NOT EXISTS idx_task_extra_attrs_report_uuids ON task USING gin ((extra_attrs::jsonb->'report_uuids')); diff --git a/src/controller/scan/base_controller.go b/src/controller/scan/base_controller.go index 6ef9b3d00c3..41bf571ab57 100644 --- a/src/controller/scan/base_controller.go +++ b/src/controller/scan/base_controller.go @@ -964,15 +964,6 @@ func (bc *basicController) launchScanJob(ctx context.Context, param *launchScanJ reportUUIDsKey: reportUUIDs, } - // NOTE: due to the limitation of the beego's orm, the List method of the task manager not support ?! operator for the jsonb field, - // we cann't list the tasks for scan reports of uuid1, uuid2 by SQL `SELECT * FROM task WHERE (extra_attrs->'report_uuids')::jsonb ?| array['uuid1', 'uuid2']` - // or by `SELECT * FROM task WHERE id IN (SELECT id FROM task WHERE (extra_attrs->'report_uuids')::jsonb ?| array['uuid1', 'uuid2'])` - // so save {"report:uuid1": "1", "report:uuid2": "2"} in the extra_attrs of the task, and then list it with - // SQL `SELECT * FROM task WHERE extra_attrs->>'report:uuid1' = '1'` in loop - for _, reportUUID := range reportUUIDs { - extraAttrs["report:"+reportUUID] = "1" - } - _, err = bc.taskMgr.Create(ctx, param.ExecutionID, j, extraAttrs) return err } @@ -1022,11 +1013,12 @@ func (bc *basicController) listScanTasks(ctx context.Context, reportUUIDs []stri } func (bc *basicController) getScanTask(ctx context.Context, reportUUID string) (*task.Task, error) { - query := q.New(q.KeyWords{"extra_attrs." + "report:" + reportUUID: "1"}) - tasks, err := bc.taskMgr.List(bc.cloneCtx(ctx), query) + // NOTE: the method uses the postgres' unique operations and should consider here if support other database in the future. + tasks, err := bc.taskMgr.ListScanTasksByReportUUID(ctx, reportUUID) if err != nil { return nil, err } + if len(tasks) == 0 { return nil, errors.NotFoundError(nil).WithMessage("task for report %s not found", reportUUID) } diff --git a/src/controller/scan/base_controller_test.go b/src/controller/scan/base_controller_test.go index 84aab98d99b..8b781fc501c 100644 --- a/src/controller/scan/base_controller_test.go +++ b/src/controller/scan/base_controller_test.go @@ -321,7 +321,7 @@ func (suite *ControllerTestSuite) TestScanControllerScan() { walkFn(suite.artifact) }).Once() - mock.OnAnything(suite.taskMgr, "List").Return([]*task.Task{ + mock.OnAnything(suite.taskMgr, "ListScanTasksByReportUUID").Return([]*task.Task{ {ExtraAttrs: suite.makeExtraAttrs(int64(1), "rp-uuid-001"), Status: "Success"}, }, nil).Once() @@ -343,7 +343,7 @@ func (suite *ControllerTestSuite) TestScanControllerScan() { walkFn(suite.artifact) }).Once() - mock.OnAnything(suite.taskMgr, "List").Return([]*task.Task{ + mock.OnAnything(suite.taskMgr, "ListScanTasksByReportUUID").Return([]*task.Task{ {ExtraAttrs: suite.makeExtraAttrs(int64(1), "rp-uuid-001"), Status: "Success"}, }, nil).Once() @@ -360,7 +360,7 @@ func (suite *ControllerTestSuite) TestScanControllerScan() { walkFn(suite.artifact) }).Once() - mock.OnAnything(suite.taskMgr, "List").Return([]*task.Task{ + mock.OnAnything(suite.taskMgr, "ListScanTasksByReportUUID").Return([]*task.Task{ {ExtraAttrs: suite.makeExtraAttrs(int64(1), "rp-uuid-001"), Status: "Running"}, }, nil).Once() @@ -409,37 +409,40 @@ func (suite *ControllerTestSuite) TestScanControllerStop() { // TestScanControllerGetReport ... func (suite *ControllerTestSuite) TestScanControllerGetReport() { + ctx := orm.NewContext(nil, &ormtesting.FakeOrmer{}) mock.OnAnything(suite.ar, "Walk").Return(nil).Run(func(args mock.Arguments) { walkFn := args.Get(2).(func(*artifact.Artifact) error) walkFn(suite.artifact) }).Once() - mock.OnAnything(suite.taskMgr, "List").Return([]*task.Task{ + mock.OnAnything(suite.taskMgr, "ListScanTasksByReportUUID").Return([]*task.Task{ {ExtraAttrs: suite.makeExtraAttrs(int64(1), "rp-uuid-001")}, }, nil).Once() mock.OnAnything(suite.accessoryMgr, "List").Return(nil, nil) - rep, err := suite.c.GetReport(context.TODO(), suite.artifact, []string{v1.MimeTypeNativeReport}) + rep, err := suite.c.GetReport(ctx, suite.artifact, []string{v1.MimeTypeNativeReport}) require.NoError(suite.T(), err) assert.Equal(suite.T(), 1, len(rep)) } // TestScanControllerGetSummary ... func (suite *ControllerTestSuite) TestScanControllerGetSummary() { + ctx := orm.NewContext(nil, &ormtesting.FakeOrmer{}) mock.OnAnything(suite.accessoryMgr, "List").Return([]accessoryModel.Accessory{}, nil).Once() mock.OnAnything(suite.ar, "Walk").Return(nil).Run(func(args mock.Arguments) { walkFn := args.Get(2).(func(*artifact.Artifact) error) walkFn(suite.artifact) }).Once() - mock.OnAnything(suite.taskMgr, "List").Return(nil, nil).Once() + mock.OnAnything(suite.taskMgr, "ListScanTasksByReportUUID").Return(nil, nil).Once() - sum, err := suite.c.GetSummary(context.TODO(), suite.artifact, []string{v1.MimeTypeNativeReport}) + sum, err := suite.c.GetSummary(ctx, suite.artifact, []string{v1.MimeTypeNativeReport}) require.NoError(suite.T(), err) assert.Equal(suite.T(), 1, len(sum)) } // TestScanControllerGetScanLog ... func (suite *ControllerTestSuite) TestScanControllerGetScanLog() { - mock.OnAnything(suite.taskMgr, "List").Return([]*task.Task{ + ctx := orm.NewContext(nil, &ormtesting.FakeOrmer{}) + mock.OnAnything(suite.taskMgr, "ListScanTasksByReportUUID").Return([]*task.Task{ { ID: 1, ExtraAttrs: suite.makeExtraAttrs(int64(1), "rp-uuid-001"), @@ -448,7 +451,7 @@ func (suite *ControllerTestSuite) TestScanControllerGetScanLog() { mock.OnAnything(suite.taskMgr, "GetLog").Return([]byte("log"), nil).Once() - bytes, err := suite.c.GetScanLog(context.TODO(), &artifact.Artifact{Artifact: art.Artifact{ID: 1, ProjectID: 1}}, "rp-uuid-001") + bytes, err := suite.c.GetScanLog(ctx, &artifact.Artifact{Artifact: art.Artifact{ID: 1, ProjectID: 1}}, "rp-uuid-001") require.NoError(suite.T(), err) assert.Condition(suite.T(), func() (success bool) { success = len(bytes) > 0 @@ -457,8 +460,8 @@ func (suite *ControllerTestSuite) TestScanControllerGetScanLog() { } func (suite *ControllerTestSuite) TestScanControllerGetMultiScanLog() { - kw1 := q.KeyWords{"extra_attrs.report:rp-uuid-001": "1"} - suite.taskMgr.On("List", context.TODO(), q.New(kw1)).Return([]*task.Task{ + ctx := orm.NewContext(nil, &ormtesting.FakeOrmer{}) + suite.taskMgr.On("ListScanTasksByReportUUID", ctx, "rp-uuid-001").Return([]*task.Task{ { ID: 1, ExtraAttrs: suite.makeExtraAttrs(int64(1), "rp-uuid-001"), @@ -469,8 +472,7 @@ func (suite *ControllerTestSuite) TestScanControllerGetMultiScanLog() { walkFn(suite.artifact) }) mock.OnAnything(suite.accessoryMgr, "List").Return(nil, nil) - kw2 := q.KeyWords{"extra_attrs.report:rp-uuid-002": "1"} - suite.taskMgr.On("List", context.TODO(), q.New(kw2)).Return([]*task.Task{ + suite.taskMgr.On("ListScanTasksByReportUUID", ctx, "rp-uuid-002").Return([]*task.Task{ { ID: 2, ExtraAttrs: suite.makeExtraAttrs(int64(1), "rp-uuid-002"), @@ -480,7 +482,7 @@ func (suite *ControllerTestSuite) TestScanControllerGetMultiScanLog() { // Both success mock.OnAnything(suite.taskMgr, "GetLog").Return([]byte("log"), nil).Twice() - bytes, err := suite.c.GetScanLog(context.TODO(), &artifact.Artifact{Artifact: art.Artifact{ID: 1, ProjectID: 1}}, base64.StdEncoding.EncodeToString([]byte("rp-uuid-001|rp-uuid-002"))) + bytes, err := suite.c.GetScanLog(ctx, &artifact.Artifact{Artifact: art.Artifact{ID: 1, ProjectID: 1}}, base64.StdEncoding.EncodeToString([]byte("rp-uuid-001|rp-uuid-002"))) suite.Nil(err) suite.NotEmpty(bytes) suite.Contains(string(bytes), "Logs of report rp-uuid-001") @@ -489,10 +491,10 @@ func (suite *ControllerTestSuite) TestScanControllerGetMultiScanLog() { { // One successfully, one failed - suite.taskMgr.On("GetLog", context.TODO(), int64(1)).Return([]byte("log"), nil).Once() - suite.taskMgr.On("GetLog", context.TODO(), int64(2)).Return(nil, fmt.Errorf("failed")).Once() + suite.taskMgr.On("GetLog", ctx, int64(1)).Return([]byte("log"), nil).Once() + suite.taskMgr.On("GetLog", ctx, int64(2)).Return(nil, fmt.Errorf("failed")).Once() - bytes, err := suite.c.GetScanLog(context.TODO(), &artifact.Artifact{Artifact: art.Artifact{ID: 1, ProjectID: 1}}, base64.StdEncoding.EncodeToString([]byte("rp-uuid-001|rp-uuid-002"))) + bytes, err := suite.c.GetScanLog(ctx, &artifact.Artifact{Artifact: art.Artifact{ID: 1, ProjectID: 1}}, base64.StdEncoding.EncodeToString([]byte("rp-uuid-001|rp-uuid-002"))) suite.Nil(err) suite.NotEmpty(bytes) suite.NotContains(string(bytes), "Logs of report rp-uuid-001") @@ -502,7 +504,7 @@ func (suite *ControllerTestSuite) TestScanControllerGetMultiScanLog() { // Both failed mock.OnAnything(suite.taskMgr, "GetLog").Return(nil, fmt.Errorf("failed")).Twice() - bytes, err := suite.c.GetScanLog(context.TODO(), &artifact.Artifact{Artifact: art.Artifact{ID: 1, ProjectID: 1}}, base64.StdEncoding.EncodeToString([]byte("rp-uuid-001|rp-uuid-002"))) + bytes, err := suite.c.GetScanLog(ctx, &artifact.Artifact{Artifact: art.Artifact{ID: 1, ProjectID: 1}}, base64.StdEncoding.EncodeToString([]byte("rp-uuid-001|rp-uuid-002"))) suite.Error(err) suite.Empty(bytes) } @@ -511,7 +513,7 @@ func (suite *ControllerTestSuite) TestScanControllerGetMultiScanLog() { // Both empty mock.OnAnything(suite.taskMgr, "GetLog").Return(nil, nil).Twice() - bytes, err := suite.c.GetScanLog(context.TODO(), &artifact.Artifact{Artifact: art.Artifact{ID: 1, ProjectID: 1}}, base64.StdEncoding.EncodeToString([]byte("rp-uuid-001|rp-uuid-002"))) + bytes, err := suite.c.GetScanLog(ctx, &artifact.Artifact{Artifact: art.Artifact{ID: 1, ProjectID: 1}}, base64.StdEncoding.EncodeToString([]byte("rp-uuid-001|rp-uuid-002"))) suite.Nil(err) suite.Empty(bytes) } @@ -560,7 +562,7 @@ func (suite *ControllerTestSuite) TestScanAll() { walkFn(suite.artifact) }).Once() - mock.OnAnything(suite.taskMgr, "List").Return(nil, nil).Once() + mock.OnAnything(suite.taskMgr, "ListScanTasksByReportUUID").Return(nil, nil).Once() mock.OnAnything(suite.reportMgr, "Delete").Return(nil).Once() mock.OnAnything(suite.reportMgr, "Create").Return("uuid", nil).Once() diff --git a/src/pkg/task/dao/task.go b/src/pkg/task/dao/task.go index 0c4ff311c2b..de3c71bd22b 100644 --- a/src/pkg/task/dao/task.go +++ b/src/pkg/task/dao/task.go @@ -53,6 +53,9 @@ type TaskDAO interface { UpdateStatusInBatch(ctx context.Context, jobIDs []string, status string, batchSize int) (err error) // ExecutionIDsByVendorAndStatus retrieve the execution id by vendor status ExecutionIDsByVendorAndStatus(ctx context.Context, vendorType, status string) ([]int64, error) + // ListScanTasksByReportUUID lists scan tasks by report uuid, although it's a specific case but it will be + // more suitable to support multi database in the future. + ListScanTasksByReportUUID(ctx context.Context, uuid string) (tasks []*Task, err error) } // NewTaskDAO returns an instance of TaskDAO @@ -88,6 +91,25 @@ func (t *taskDAO) List(ctx context.Context, query *q.Query) ([]*Task, error) { return tasks, nil } +func (t *taskDAO) ListScanTasksByReportUUID(ctx context.Context, uuid string) ([]*Task, error) { + ormer, err := orm.FromContext(ctx) + if err != nil { + return nil, err + } + + tasks := []*Task{} + // Due to the limitation of the beego's orm, the SQL cannot be converted by orm framework, + // so we can only execute the query by raw SQL, the SQL filters the task contains the report uuid in the column extra_attrs, + // consider from performance side which can using indexes to speed up queries. + sql := fmt.Sprintf(`SELECT * FROM task WHERE extra_attrs::jsonb->'report_uuids' @> '["%s"]'`, uuid) + _, err = ormer.Raw(sql).QueryRows(&tasks) + if err != nil { + return nil, err + } + + return tasks, nil +} + func (t *taskDAO) Get(ctx context.Context, id int64) (*Task, error) { task := &Task{ ID: id, diff --git a/src/pkg/task/dao/task_test.go b/src/pkg/task/dao/task_test.go index f9440ec4416..aeca41a1ca9 100644 --- a/src/pkg/task/dao/task_test.go +++ b/src/pkg/task/dao/task_test.go @@ -112,6 +112,27 @@ func (t *taskDAOTestSuite) TestList() { t.Require().Len(tasks, 0) } +func (t *taskDAOTestSuite) TestListScanTasksByReportUUID() { + // should not exist if non set + tasks, err := t.taskDAO.ListScanTasksByReportUUID(t.ctx, "fake-report-uuid") + t.Require().Nil(err) + t.Require().Len(tasks, 0) + // create one with report uuid + taskID, err := t.taskDAO.Create(t.ctx, &Task{ + ExecutionID: t.executionID, + Status: "success", + StatusCode: 1, + ExtraAttrs: `{"report_uuids": ["fake-report-uuid"]}`, + }) + t.Require().Nil(err) + defer t.taskDAO.Delete(t.ctx, taskID) + // should exist as created + tasks, err = t.taskDAO.ListScanTasksByReportUUID(t.ctx, "fake-report-uuid") + t.Require().Nil(err) + t.Require().Len(tasks, 1) + t.Equal(taskID, tasks[0].ID) +} + func (t *taskDAOTestSuite) TestGet() { // not exist _, err := t.taskDAO.Get(t.ctx, 10000) diff --git a/src/pkg/task/mock_task_dao_test.go b/src/pkg/task/mock_task_dao_test.go index 2ba7e33c65c..6d436b2fd36 100644 --- a/src/pkg/task/mock_task_dao_test.go +++ b/src/pkg/task/mock_task_dao_test.go @@ -182,6 +182,32 @@ func (_m *mockTaskDAO) List(ctx context.Context, query *q.Query) ([]*dao.Task, e return r0, r1 } +// ListScanTasksByReportUUID provides a mock function with given fields: ctx, uuid +func (_m *mockTaskDAO) ListScanTasksByReportUUID(ctx context.Context, uuid string) ([]*dao.Task, error) { + ret := _m.Called(ctx, uuid) + + var r0 []*dao.Task + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) ([]*dao.Task, error)); ok { + return rf(ctx, uuid) + } + if rf, ok := ret.Get(0).(func(context.Context, string) []*dao.Task); ok { + r0 = rf(ctx, uuid) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*dao.Task) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, uuid) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // ListStatusCount provides a mock function with given fields: ctx, executionID func (_m *mockTaskDAO) ListStatusCount(ctx context.Context, executionID int64) ([]*dao.StatusCount, error) { ret := _m.Called(ctx, executionID) diff --git a/src/pkg/task/mock_task_manager_test.go b/src/pkg/task/mock_task_manager_test.go index 916d5ea0045..d32dcc7ba8a 100644 --- a/src/pkg/task/mock_task_manager_test.go +++ b/src/pkg/task/mock_task_manager_test.go @@ -199,6 +199,32 @@ func (_m *mockTaskManager) List(ctx context.Context, query *q.Query) ([]*Task, e return r0, r1 } +// ListScanTasksByReportUUID provides a mock function with given fields: ctx, uuid +func (_m *mockTaskManager) ListScanTasksByReportUUID(ctx context.Context, uuid string) ([]*Task, error) { + ret := _m.Called(ctx, uuid) + + var r0 []*Task + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) ([]*Task, error)); ok { + return rf(ctx, uuid) + } + if rf, ok := ret.Get(0).(func(context.Context, string) []*Task); ok { + r0 = rf(ctx, uuid) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*Task) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, uuid) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Stop provides a mock function with given fields: ctx, id func (_m *mockTaskManager) Stop(ctx context.Context, id int64) error { ret := _m.Called(ctx, id) diff --git a/src/pkg/task/task.go b/src/pkg/task/task.go index 88fae30929c..b2d0524faab 100644 --- a/src/pkg/task/task.go +++ b/src/pkg/task/task.go @@ -64,6 +64,9 @@ type Manager interface { UpdateStatusInBatch(ctx context.Context, jobIDs []string, status string, batchSize int) error // ExecutionIDsByVendorAndStatus retrieve execution id by vendor type and status ExecutionIDsByVendorAndStatus(ctx context.Context, vendorType, status string) ([]int64, error) + // ListScanTasksByReportUUID lists scan tasks by report uuid, although it's a specific case but it will be + // more suitable to support multi database in the future. + ListScanTasksByReportUUID(ctx context.Context, uuid string) (tasks []*Task, err error) } // NewManager creates an instance of the default task manager @@ -234,6 +237,20 @@ func (m *manager) List(ctx context.Context, query *q.Query) ([]*Task, error) { return ts, nil } +func (m *manager) ListScanTasksByReportUUID(ctx context.Context, uuid string) ([]*Task, error) { + tasks, err := m.dao.ListScanTasksByReportUUID(ctx, uuid) + if err != nil { + return nil, err + } + var ts []*Task + for _, task := range tasks { + t := &Task{} + t.From(task) + ts = append(ts, t) + } + return ts, nil +} + func (m *manager) UpdateExtraAttrs(ctx context.Context, id int64, extraAttrs map[string]interface{}) error { data, err := json.Marshal(extraAttrs) if err != nil { diff --git a/src/pkg/task/task_test.go b/src/pkg/task/task_test.go index 2c482ac16ba..7b00500bb6d 100644 --- a/src/pkg/task/task_test.go +++ b/src/pkg/task/task_test.go @@ -147,6 +147,19 @@ func (t *taskManagerTestSuite) TestList() { t.dao.AssertExpectations(t.T()) } +func (t *taskManagerTestSuite) TestListScanTasksByReportUUID() { + t.dao.On("ListScanTasksByReportUUID", mock.Anything, mock.Anything).Return([]*dao.Task{ + { + ID: 1, + }, + }, nil) + tasks, err := t.mgr.ListScanTasksByReportUUID(nil, "uuid") + t.Require().Nil(err) + t.Require().Len(tasks, 1) + t.Equal(int64(1), tasks[0].ID) + t.dao.AssertExpectations(t.T()) +} + func TestTaskManagerTestSuite(t *testing.T) { suite.Run(t, &taskManagerTestSuite{}) } diff --git a/src/testing/pkg/task/manager.go b/src/testing/pkg/task/manager.go index 53a5c07dbdb..be0e06316f9 100644 --- a/src/testing/pkg/task/manager.go +++ b/src/testing/pkg/task/manager.go @@ -201,6 +201,32 @@ func (_m *Manager) List(ctx context.Context, query *q.Query) ([]*task.Task, erro return r0, r1 } +// ListScanTasksByReportUUID provides a mock function with given fields: ctx, uuid +func (_m *Manager) ListScanTasksByReportUUID(ctx context.Context, uuid string) ([]*task.Task, error) { + ret := _m.Called(ctx, uuid) + + var r0 []*task.Task + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) ([]*task.Task, error)); ok { + return rf(ctx, uuid) + } + if rf, ok := ret.Get(0).(func(context.Context, string) []*task.Task); ok { + r0 = rf(ctx, uuid) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*task.Task) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, uuid) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Stop provides a mock function with given fields: ctx, id func (_m *Manager) Stop(ctx context.Context, id int64) error { ret := _m.Called(ctx, id)