use db to cache gap analysis results instead of redis (#435)
* use db to cache gap analysis results instead of redis

* lint

* typo

* fix 1 test, fix bugs, make cache key into its own function

* lint

* fix-mock

* migration

* Fix: bad ref

* fix tests

* lint

* minor changes

* lint

---------

Co-authored-by: john681611 <[email protected]>
northdpole and john681611 authored Oct 24, 2023
1 parent cda463f commit 8111c40
Showing 7 changed files with 192 additions and 76 deletions.
50 changes: 38 additions & 12 deletions application/database/db.py
@@ -28,7 +28,7 @@
import uuid

from application.utils.gap_analysis import get_path_score
from application.utils.hash import make_array_hash
from application.utils.hash import make_array_hash, make_cache_key


from .. import sqla # type: ignore
@@ -172,6 +172,13 @@ class Embeddings(BaseModel):  # type: ignore
)


class GapAnalysisResults(BaseModel):
__tablename__ = "gap_analysis_results"
cache_key = sqla.Column(sqla.String, primary_key=True)
ga_object = sqla.Column(sqla.String)
__table_args__ = (sqla.UniqueConstraint(cache_key, name="unique_cache_key_field"),)


class RelatedRel(StructuredRel):
pass

@@ -425,7 +432,6 @@ def link_CRE_to_Node(self, CRE_id, node_id, link_type):
def gap_analysis(self, name_1, name_2):
base_standard = NeoStandard.nodes.filter(name=name_1)
denylist = ["Cross-cutting concerns"]
from pprint import pprint
from datetime import datetime

t1 = datetime.now()
@@ -442,8 +448,6 @@ def gap_analysis(self, name_1, name_2):
resolve_objects=True,
)
t2 = datetime.now()
pprint(f"path records all took {t2-t1}")
pprint(path_records_all.__len__())
path_records, _ = db.cypher_query(
"""
OPTIONAL MATCH (BaseStandard:NeoStandard {name: $name1})
@@ -485,9 +489,6 @@ def format_path_record(rec):
"path": [format_segment(seg, rec.nodes) for seg in rec.relationships],
}

pprint(
f"path records all took {t2-t1} path records took {t3 - t2}, total: {t3 - t1}"
)
return [NEO_DB.parse_node(rec) for rec in base_standard], [
format_path_record(rec[0]) for rec in (path_records + path_records_all)
]
@@ -1635,6 +1636,22 @@ def add_embedding(

return existing

def get_gap_analysis_result(self, cache_key) -> str:
res = (
self.session.query(GapAnalysisResults)
.filter(GapAnalysisResults.cache_key == cache_key)
.first()
)
if res:
return res.ga_object

def add_gap_analysis_result(self, cache_key: str, ga_object: str):
existing = self.get_gap_analysis_result(cache_key)
if not existing:
res = GapAnalysisResults(cache_key=cache_key, ga_object=ga_object)
self.session.add(res)
self.session.commit()


def dbNodeFromNode(doc: cre_defs.Node) -> Optional[Node]:
if doc.doctype == cre_defs.Credoctypes.Standard:
@@ -1767,6 +1784,7 @@ def gap_analysis(
store_in_cache: bool = False,
cache_key: str = "",
):
cre_db = Node_collection()
base_standard, paths = neo_db.gap_analysis(node_names[0], node_names[1])
if base_standard is None:
return None
@@ -1809,16 +1827,24 @@
): # lightweight memory option to not return potentially huge object and instead store in a cache,
# in case this is called via worker, we save both this and the caller memory by avoiding duplicate object in mem

conn = redis.connect()
# conn = redis.connect()
if cache_key == "":
cache_key = make_array_hash(node_names)

conn.set(cache_key, flask_json.dumps({"result": grouped_paths}))
# conn.set(cache_key, flask_json.dumps({"result": grouped_paths}))
cre_db.add_gap_analysis_result(
cache_key=cache_key, ga_object=flask_json.dumps({"result": grouped_paths})
)

for key in extra_paths_dict:
conn.set(
cache_key + "->" + key,
flask_json.dumps({"result": extra_paths_dict[key]}),
cre_db.add_gap_analysis_result(
cache_key=make_cache_key(node_names, key),
ga_object=flask_json.dumps({"result": extra_paths_dict[key]}),
)
# conn.set(
# cache_key + "->" + key,
# flask_json.dumps({"result": extra_paths_dict[key]}),
# )
return (node_names, {}, {})

return (node_names, grouped_paths, extra_paths_dict)
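
Taken together, the db.py changes replace the Redis result cache with a write-once table: GapAnalysisResults stores the serialized gap analysis JSON under an MD5-derived cache key, and add_gap_analysis_result() only inserts when the key is absent. A minimal round-trip sketch, assuming an initialized Node_collection with a configured SQLAlchemy session (the standards pair and payload below are illustrative, not from the commit):

import json

from application.database import db
from application.utils.hash import make_array_hash, make_cache_key

collection = db.Node_collection()  # assumes the sqla session is already configured

standards = ["aaa", "bbb"]               # illustrative standard names
cache_key = make_array_hash(standards)   # md5 of "aaa:bbb"

# First write wins: add_gap_analysis_result() is a no-op when the key already exists.
collection.add_gap_analysis_result(
    cache_key=cache_key,
    ga_object=json.dumps({"result": {"paths": []}}),  # illustrative payload
)

cached = collection.get_gap_analysis_result(cache_key)  # JSON string, or None on a miss
if cached:
    print(json.loads(cached)["result"])

# Weak-link extras live beside the main entry under "<md5>-><key>":
extra = collection.get_gap_analysis_result(make_cache_key(standards, "some-node"))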
22 changes: 8 additions & 14 deletions application/tests/db_test.py
@@ -1511,9 +1511,8 @@ def test_gap_analysis_duplicate_link_path_existing_higher_and_in_extras(
)
self.assertEqual(db.gap_analysis(collection.neo_db, ["a", "b"]), expected)

@patch.object(redis, "from_url")
@patch.object(db.NEO_DB, "gap_analysis")
def test_gap_analysis_dump_to_cache(self, gap_mock, redis_conn_mock):
def test_gap_analysis_dump_to_cache(self, gap_mock):
collection = db.Node_collection()
collection.neo_db.connected = True
path = [
@@ -1567,18 +1566,13 @@ def test_gap_analysis_dump_to_cache(self, gap_mock, redis_conn_mock):
response = db.gap_analysis(collection.neo_db, ["a", "b"], True)

self.assertEqual(response, (expected_response[0], {}, {}))

redis_conn_mock.return_value.set.assert_has_calls(
[
mock.call(
"d8160c9b3dc20d4e931aeb4f45262155",
flask_json.dumps({"result": expected_response[1]}),
),
mock.call(
"d8160c9b3dc20d4e931aeb4f45262155->a",
flask_json.dumps({"result": expected_response[2]["a"]}),
),
]
self.assertEqual(
collection.get_gap_analysis_result("d8160c9b3dc20d4e931aeb4f45262155"),
flask_json.dumps({"result": expected_response[1]}),
)
self.assertEqual(
collection.get_gap_analysis_result("d8160c9b3dc20d4e931aeb4f45262155->a"),
flask_json.dumps({"result": expected_response[2]["a"]}),
)

def test_neo_db_parse_node_code(self):
24 changes: 17 additions & 7 deletions application/tests/web_main_test.py
@@ -11,13 +11,17 @@
from application.defs import cre_defs as defs
from application.defs import osib_defs
from application.web import web_main
from application.utils.hash import make_array_hash, make_cache_key


class MockJob:
@property
def id(self):
return "ABC"

def get_status(self):
return rq.job.JobStatus.STARTED


class TestMain(unittest.TestCase):
def tearDown(self) -> None:
@@ -574,10 +578,14 @@ def test_smartlink(self) -> None:
self.assertEqual(404, response.status_code)

@patch.object(redis, "from_url")
def test_gap_analysis_from_cache_full_response(self, redis_conn_mock) -> None:
@patch.object(db, "Node_collection")
def test_gap_analysis_from_cache_full_response(
self, db_mock, redis_conn_mock
) -> None:
expected = {"result": "hello"}
redis_conn_mock.return_value.exists.return_value = True
redis_conn_mock.return_value.get.return_value = json.dumps(expected)
db_mock.return_value.get_gap_analysis_result.return_value = json.dumps(expected)
with self.app.test_client() as client:
response = client.get(
"/rest/v1/map_analysis?standard=aaa&standard=bbb",
@@ -586,14 +594,16 @@ def test_gap_analysis_from_cache_full_response(self, redis_conn_mock) -> None:
self.assertEqual(200, response.status_code)
self.assertEqual(expected, json.loads(response.data))

@patch.object(rq.job.Job, "fetch")
@patch.object(rq.Queue, "enqueue_call")
@patch.object(redis, "from_url")
def test_gap_analysis_from_cache_job_id(
self, redis_conn_mock, enqueue_call_mock
self, redis_conn_mock, enqueue_call_mock, fetch_mock
) -> None:
expected = {"job_id": "hello"}
redis_conn_mock.return_value.exists.return_value = True
redis_conn_mock.return_value.get.return_value = json.dumps(expected)
fetch_mock.return_value = MockJob()
with self.app.test_client() as client:
response = client.get(
"/rest/v1/map_analysis?standard=aaa&standard=bbb",
@@ -610,8 +620,9 @@ def test_gap_analysis_create_job_id(
self, redis_conn_mock, enqueue_call_mock, db_mock
) -> None:
expected = {"job_id": "ABC"}
redis_conn_mock.return_value.exists.return_value = False
redis_conn_mock.return_value.get.return_value = None
enqueue_call_mock.return_value = MockJob()
db_mock.return_value.get_gap_analysis_result.return_value = None
with self.app.test_client() as client:
response = client.get(
"/rest/v1/map_analysis?standard=aaa&standard=bbb",
@@ -683,11 +694,10 @@ def test_gap_analysis_weak_links_no_cache(self) -> None:
)
self.assertEqual(404, response.status_code)

@patch.object(redis, "from_url")
def test_gap_analysis_weak_links_response(self, redis_conn_mock) -> None:
@patch.object(db, "Node_collection")
def test_gap_analysis_weak_links_response(self, db_mock) -> None:
expected = {"result": "hello"}
redis_conn_mock.return_value.exists.return_value = True
redis_conn_mock.return_value.get.return_value = json.dumps(expected)
db_mock.return_value.get_gap_analysis_result.return_value = json.dumps(expected)
with self.app.test_client() as client:
response = client.get(
"/rest/v1/map_analysis_weak_links?standard=aaa&standard=bbb&key=ccc`",
4 changes: 4 additions & 0 deletions application/utils/hash.py
@@ -1,5 +1,9 @@
import hashlib


def make_cache_key(standards: list, key: str) -> str:
return make_array_hash(standards) + "->" + key


def make_array_hash(array: list):
return hashlib.md5(":".join(array).encode("utf-8")).hexdigest()
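
make_cache_key() just namespaces the array hash, so a pair's weak-link entries sort next to its main result. With the values asserted in db_test.py above:

from application.utils.hash import make_array_hash, make_cache_key

make_array_hash(["a", "b"])      # "d8160c9b3dc20d4e931aeb4f45262155" (md5 of "a:b")
make_cache_key(["a", "b"], "a")  # "d8160c9b3dc20d4e931aeb4f45262155->a"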
21 changes: 11 additions & 10 deletions application/utils/redis.py
@@ -5,13 +5,14 @@

def connect():
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379")

url = urlparse(redis_url)
r = redis.Redis(
host=url.hostname,
port=url.port,
password=url.password,
ssl=True,
ssl_cert_reqs=None,
)
return r
if redis_url == "redis://localhost:6379":
return redis.from_url(redis_url)
else:
url = urlparse(redis_url)
return redis.Redis(
host=url.hostname,
port=url.port,
password=url.password,
ssl=True,
ssl_cert_reqs=None,
)
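
The connect() rewrite keeps two environments working: with REDIS_URL unset (or left at the localhost default) it uses a plain, non-TLS connection for local development, while any other URL, such as a hosted Redis with credentials, takes the TLS branch with certificate verification disabled. A usage sketch, with an illustrative hosted URL:

import os

from application.utils import redis as redis_utils

# Local development: default URL, plain TCP via redis.from_url().
conn = redis_utils.connect()

# Hosted deployment: any non-default REDIS_URL selects the TLS branch,
# which parses hostname, port, and password out of the URL.
os.environ["REDIS_URL"] = "rediss://:[email protected]:6380"  # illustrative
conn = redis_utils.connect()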
75 changes: 42 additions & 33 deletions application/web/web_main.py
@@ -36,7 +36,7 @@
from application.utils.spreadsheet import write_csv
import oauthlib
import google.auth.transport.requests
from application.utils.hash import make_array_hash
from application.utils.hash import make_array_hash, make_cache_key

ITEMS_PER_PAGE = 20

@@ -226,26 +226,27 @@ def gap_analysis() -> Any:
standards = request.args.getlist("standard")
conn = redis.connect()
standards_hash = make_array_hash(standards)
if conn.exists(standards_hash):
gap_analysis_results = conn.get(standards_hash)
if gap_analysis_results:
gap_analysis_dict = json.loads(gap_analysis_results)
if gap_analysis_dict.get("result"):
return jsonify({"result": gap_analysis_dict.get("result")})
elif gap_analysis_dict.get("job_id"):
try:
res = job.Job.fetch(
id=gap_analysis_dict.get("job_id"), connection=conn
)
except exceptions.NoSuchJobError as nje:
abort(404, "No such job")
if (
res.get_status() != job.JobStatus.FAILED
and res.get_status() == job.JobStatus.STOPPED
and res.get_status() == job.JobStatus.CANCELED
):
logger.info("gap analysis job id already exists, returning early")
return jsonify({"job_id": gap_analysis_dict.get("job_id")})
result = database.get_gap_analysis_result(standards_hash)
if result:
gap_analysis_dict = flask_json.loads(result)
if gap_analysis_dict.get("result"):
return jsonify(gap_analysis_dict)

gap_analysis_results = conn.get(standards_hash)
if gap_analysis_results:
gap_analysis_dict = json.loads(gap_analysis_results)
if gap_analysis_dict.get("job_id"):
try:
res = job.Job.fetch(id=gap_analysis_dict.get("job_id"), connection=conn)
except exceptions.NoSuchJobError as nje:
abort(404, "No such job")
if (
res.get_status() != job.JobStatus.FAILED
and res.get_status() != job.JobStatus.STOPPED
and res.get_status() != job.JobStatus.CANCELED
):
logger.info("gap analysis job id already exists, returning early")
return jsonify({"job_id": gap_analysis_dict.get("job_id")})
q = Queue(connection=conn)
gap_analysis_job = q.enqueue_call(
db.gap_analysis,
@@ -266,15 +267,21 @@
def gap_analysis_weak_links() -> Any:
standards = request.args.getlist("standard")
key = request.args.get("key")
conn = redis.connect()
standards_hash = make_array_hash(standards)
cache_key = standards_hash + "->" + key
if conn.exists(cache_key):
gap_analysis_results = conn.get(cache_key)
if gap_analysis_results:
gap_analysis_dict = json.loads(gap_analysis_results)
if gap_analysis_dict.get("result"):
return jsonify({"result": gap_analysis_dict.get("result")})
cache_key = make_cache_key(standards=standards, key=key)

database = db.Node_collection()
gap_analysis_results = database.get_gap_analysis_result(cache_key=cache_key)
if gap_analysis_results:
gap_analysis_dict = json.loads(gap_analysis_results)
if gap_analysis_dict.get("result"):
return jsonify({"result": gap_analysis_dict.get("result")})

# if conn.exists(cache_key):
# gap_analysis_results = conn.get(cache_key)
# if gap_analysis_results:
# gap_analysis_dict = json.loads(gap_analysis_results)
# if gap_analysis_dict.get("result"):
# return jsonify({"result": gap_analysis_dict.get("result")})
abort(404, "No such Cache")


@@ -315,12 +322,14 @@ def fetch_job() -> Any:

if conn.exists(standards_hash):
logger.info("and hash is already in cache")
ga = conn.get(standards_hash)
# ga = conn.get(standards_hash)
database = db.Node_collection()
ga = database.get_gap_analysis_result(standards_hash)
if ga:
logger.info("and results in cache")
ga = json.loads(ga)
ga = flask_json.loads(ga)
if ga.get("result"):
return jsonify({"result": ga.get("result")})
return jsonify(ga)
else:
logger.error(
"Finished job does not have a result object, this is a bug!"
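
End to end, /rest/v1/map_analysis now resolves in order: a finished result in the GapAnalysisResults table, an in-flight job_id in Redis, and otherwise a freshly enqueued rq job. A client-side polling sketch using the Flask test client, as the tests above do (how the app object is imported is an assumption; the poll interval is illustrative):

import json
import time

from application.web.web_main import app  # assumption: the Flask app is exposed here

client = app.test_client()
url = "/rest/v1/map_analysis?standard=aaa&standard=bbb"

while True:
    body = json.loads(client.get(url).data)
    if "result" in body:      # DB cache hit: a worker finished and stored the JSON
        print(body["result"])
        break
    elif "job_id" in body:    # still computing; Redis holds the in-flight job handle
        time.sleep(5)         # illustrative poll interval
    else:
        break                 # unexpected response shape; bail out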