From e4087b099f94b6c651a3e1bbd98991cc941ae65a Mon Sep 17 00:00:00 2001
From: Spyros
Date: Thu, 25 Jan 2024 19:01:52 +0000
Subject: [PATCH] parallelize main

---
 application/cmd/cre_main.py       | 51 ++++++++++++++++++++++++++-----
 application/utils/gap_analysis.py |  1 -
 application/web/web_main.py       |  2 +-
 3 files changed, 44 insertions(+), 10 deletions(-)

diff --git a/application/cmd/cre_main.py b/application/cmd/cre_main.py
index 45e6c7ca..303c055b 100644
--- a/application/cmd/cre_main.py
+++ b/application/cmd/cre_main.py
@@ -5,7 +5,8 @@
 import shutil
 import tempfile
 from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
-
+from application.utils.hash import make_cache_key
+from rq import Worker, Queue, Connection, job, exceptions
 import yaml
 from application import create_app  # type: ignore
 from application.config import CMDConfig
@@ -13,6 +14,7 @@
 from application.defs import cre_defs as defs
 from application.defs import osib_defs as odefs
 from application.utils import spreadsheet as sheet_utils
+from application.utils import redis
 from application.utils import spreadsheet_parsers
 from application.utils.external_project_parsers import (
     capec_parser,
@@ -199,6 +201,35 @@ def parse_file(
     return resulting_objects
 
 
+def send_job_to_worker(job_info_hash: str, job: Callable, kwargs: Dict):
+    conn = redis.connect()
+    if conn.get(job_info_hash):
+        logger.debug(
+            f"Job with info-hash {job_info_hash} has already returned, skipping running {job.__name__} with args {kwargs}"
+        )
+        return
+
+    q = Queue(connection=conn)
+    job = q.enqueue_call(func=job, kwargs=kwargs, timeout="10m")
+    return job
+
+
+def register_standard(
+    standard_entries: List[defs.Standard],
+    collection: db.Node_collection,
+    prompt_client: prompt_client.PromptHandler,
+):
+    if not standard_entries:
+        return
+    for node in standard_entries:
+        register_node(node, collection)
+    prompt_client.generate_embeddings_for(node.name)
+    populate_neo4j_db(collection)
+    conn = redis.connect()
+    conn.set(make_cache_key(standards=standard_entries, key=""), value="")
+
+
+# TODO (spyros): test, mock send_job_to_worker
 def parse_standards_from_spreadsheeet(
     cre_file: List[Dict[str, Any]], collection: db.Node_collection
 ) -> None:
@@ -210,14 +241,18 @@
         for _, cres in cres.pop(defs.Credoctypes.CRE.value):
             for cre in cres:
                 register_cre(cre, collection)
-        # TODO(notrhdpole): sync GraphDB
+        populate_neo4j_db(collection)
         pc.generate_embeddings_for(defs.Credoctypes.CRE.value)
-        for standard_name, standards in documents:
-            # TODO (northdpole): parallelise, send each element of this array to a different worker and move the following to a worker method
-            for node in standards:
-                register_node(node, collection)
-            pc.generate_embeddings_for(standard_name)
-            # TODO(notrhdpole): sync GraphDB
+        for _, standard_entries in documents:
+            send_job_to_worker(
+                job_info_hash=make_cache_key(standard_entries, ""),
+                job=register_standard,
+                kwargs={
+                    "standard_entries": standard_entries,
+                    "collection": collection,
+                    "prompt_client": pc,
+                },
+            )
         # TODO(notrhdpole): calculate gap analysis
 
     elif any(key.startswith("CRE hierarchy") for key in cre_file[0].keys()):
diff --git a/application/utils/gap_analysis.py b/application/utils/gap_analysis.py
index e917beb5..5b72f19b 100644
--- a/application/utils/gap_analysis.py
+++ b/application/utils/gap_analysis.py
@@ -72,6 +72,5 @@ def preload(target_url: str):
             if f"{sb}->{sa}" in waiting:
                 waiting.remove(f"{sb}->{sa}")
         print(f"calculating {len(waiting)} gap analyses")
-        # print(waiting)
         time.sleep(30)
     print("map analysis preloaded successfully")
diff --git a/application/web/web_main.py b/application/web/web_main.py
index 49e43bb5..114d449a 100644
--- a/application/web/web_main.py
+++ b/application/web/web_main.py
@@ -9,7 +9,7 @@
 from typing import Any
 
 from application.utils import oscal_utils, redis
-from rq import Worker, Queue, Connection, job, exceptions
+from rq import Queue, job, exceptions
 
 from application.database import db
 from application.defs import cre_defs as defs
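
Note on the consumer side: the patch enqueues register_standard jobs via
send_job_to_worker() but does not include the worker process that executes
them. Below is a minimal sketch of such a worker, assuming the stock RQ
Worker API and a REDIS_URL that points at the same Redis instance
application.utils.redis.connect() targets; it listens on RQ's "default"
queue because send_job_to_worker() constructs Queue() without a name. The
file name and environment variable are illustrative, not part of the patch.

    # worker_sketch.py -- hypothetical companion to this patch
    import os

    import redis
    from rq import Queue, Worker

    # Assumption: same Redis instance that send_job_to_worker() connects to.
    conn = redis.from_url(os.environ.get("REDIS_URL", "redis://localhost:6379"))

    # send_job_to_worker() enqueues on the unnamed ("default") queue, so
    # listen on that queue and block, executing jobs as they arrive.
    worker = Worker([Queue(connection=conn)], connection=conn)
    worker.work()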