Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tests: add more coverage #330

Merged
merged 1 commit into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/data/hftest/codebook.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"version": 1, "id_salt": "4688a4853dafc6a3d6934f0dd02205be0700d2ca64b636127a4436494dcaf88e"}
2 changes: 2 additions & 0 deletions tests/data/hftest/input/DocumentReference.ndjson
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
{"id":"43","content":[{"attachment":{"contentType":"text\/plain","data":"VGVzdCBub3RlIDE="}}],"context":{"encounter":[{"reference":"Encounter\/23"}],"period":{"end":"2021-06-24","start":"2021-06-23"}},"status":"current","subject":{"reference":"Patient\/334567"},"type":{"coding":[{"code":"NOTE:149798455","display":"Admission MD","system":"http://cumulus.smarthealthit.org/i2b2"}]},"resourceType":"DocumentReference"}
{"id":"44","content":[{"attachment":{"contentType":"text\/plain","data":"VGVzdCBub3RlIDI="}}],"context":{"encounter":[{"reference":"Encounter\/25"}],"period":{"end":"2021-06-25","start":"2021-06-24"}},"status":"current","subject":{"reference":"Patient\/323456"},"type":{"coding":[{"code":"NOTE:149798455","display":"Admission MD","system":"http://cumulus.smarthealthit.org/i2b2"}]},"resourceType":"DocumentReference"}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"table_name": "hftest__summary", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
{"id": "c31a3dbf188ed241b2c06b2475cd56159017fa1df1ea882d3fc4beab860fc24d", "docref_id": "c31a3dbf188ed241b2c06b2475cd56159017fa1df1ea882d3fc4beab860fc24d", "generated_on": "2021-09-14T21:23:45+00:00", "task_version": 0, "summary": "Patient has a fever."}
{"id": "eb30741bbb9395fc3da72d02fd29b96e2e4c0c2592c3ae997d80bf522c80070e", "docref_id": "eb30741bbb9395fc3da72d02fd29b96e2e4c0c2592c3ae997d80bf522c80070e", "generated_on": "2021-09-14T21:23:45+00:00", "task_version": 0, "summary": "Patient has a fever."}
68 changes: 68 additions & 0 deletions tests/deid/test_deid_mstool.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Tests for the mstool module"""

import asyncio
import filecmp
import os
import shutil
import tempfile
from unittest import mock

import pytest

Expand Down Expand Up @@ -74,3 +76,69 @@ async def test_bad_fhir(self):
common.write_json(os.path.join(input_dir, "Condition.ndjson"), {})
with self.assertRaises(SystemExit):
await run_mstool(input_dir, output_dir)


# Separate class here from the above, because this doesn't need the MS tool installed
class TestMicrosoftToolWrapper(AsyncTestCase):
"""Test case for the MS tool wrapper code"""

def setUp(self):
super().setUp()

self.process = mock.MagicMock()
self.process.returncode = None # process not yet finished

mock_exec = self.patch("asyncio.create_subprocess_exec")
mock_exec.return_value = self.process

async def test_progress(self):
"""Confirms that we poll for progress as we go"""
mock_progress = mock.MagicMock()
mock_wrapper = mock.MagicMock()
mock_wrapper.__enter__.return_value = mock_progress
self.patch("cumulus_etl.cli_utils.make_progress_bar", return_value=mock_wrapper)

# We are going to stage 3 different checkpoints:
# - a couple bytes written
# - first file in place, a couple bytes of second
# - both files in place, finished
self.patch(
"asyncio.wait_for",
side_effect=[
asyncio.TimeoutError,
asyncio.TimeoutError,
("Out", "Err"),
],
)

def fake_getsize(path: str) -> int:
match path:
case "first.ndjson":
return 10
case "second.ndjson":
return 10
case "tmp1.ndjson":
return 3
case "tmp2.ndjson":
self.process.returncode = 0 # mark the process as done
return 3
case "ghost.ndjson":
# Test that we gracefully handle files deleting underneath us
raise FileNotFoundError

self.patch(
"glob.glob",
side_effect=[
["first.ndjson", "second.ndjson"],
["tmp1.ndjson", "ghost.ndjson"],
["first.ndjson", "tmp2.ndjson"],
],
)
self.patch("os.path.getsize", side_effect=fake_getsize)

await run_mstool("/in", "/out")

self.assertEqual(mock_progress.update.call_count, 3)
self.assertEqual(mock_progress.update.call_args_list[0].kwargs, {"completed": 3 / 20})
self.assertEqual(mock_progress.update.call_args_list[1].kwargs, {"completed": 13 / 20})
self.assertEqual(mock_progress.update.call_args_list[2].kwargs, {"completed": 1})
152 changes: 152 additions & 0 deletions tests/hftest/test_hftask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""Tests for etl/studies/hftest/"""

import os

import respx

from cumulus_etl import common, errors
from cumulus_etl.etl.studies import hftest

from tests import i2b2_mock_data
from tests.etl import BaseEtlSimple, TaskTestCase


def mock_prompt(respx_mock: respx.MockRouter, text: str, url: str = "http://localhost:8086/") -> respx.Route:
full_prompt = f"""<s>[INST] <<SYS>>
You will be given a clinical note, and you should reply with a short summary of that note.
<</SYS>>

{text} [/INST]"""
return respx_mock.post(
url,
json={
"inputs": full_prompt,
"options": {
"wait_for_model": True,
},
"parameters": {
"max_new_tokens": 1000,
},
},
).respond(json=[{"generated_text": full_prompt + " Patient has a fever."}])


def mock_info(
respx_mock: respx.MockRouter, url: str = "http://localhost:8086/info", override: dict = None
) -> respx.Route:
response = {
"model_id": "meta-llama/Llama-2-13b-chat-hf",
"model_sha": "0ba94ac9b9e1d5a0037780667e8b219adde1908c",
"sha": "09eca6422788b1710c54ee0d05dd6746f16bb681",
}
response.update(override or {})
return respx_mock.get(url).respond(json=response)


class TestHuggingFaceTestTask(TaskTestCase):
"""Test case for HuggingFaceTestTask"""

@respx.mock(assert_all_called=True)
async def test_happy_path(self, respx_mock):
"""Verify we summarize a basic note properly"""
docref0 = i2b2_mock_data.documentreference()
self.make_json("DocumentReference", "0", **docref0)
mock_prompt(respx_mock, i2b2_mock_data.DOCREF_TEXT)

await hftest.HuggingFaceTestTask(self.job_config, self.scrubber).run()

self.assertEqual(1, self.format.write_records.call_count)
batch = self.format.write_records.call_args[0][0]
self.assertEqual(1, len(batch.rows))
expected_id = self.codebook.db.resource_hash("0")
self.assertEqual(
{
"id": expected_id,
"docref_id": expected_id,
"summary": "Patient has a fever.",
"generated_on": "2021-09-14T21:23:45+00:00",
"task_version": hftest.HuggingFaceTestTask.task_version,
},
batch.rows[0],
)

@respx.mock(assert_all_called=True)
async def test_env_url_override(self, respx_mock):
"""Verify we can override the hugging face default URL."""
docref0 = i2b2_mock_data.documentreference()
self.make_json("DocumentReference", "0", **docref0)

self.patch_dict(os.environ, {"CUMULUS_HUGGING_FACE_URL": "https://blarg/"})
mock_prompt(respx_mock, i2b2_mock_data.DOCREF_TEXT, url="https://blarg/")

await hftest.HuggingFaceTestTask(self.job_config, self.scrubber).run()
self.assertEqual(1, self.format.write_records.call_count)

@respx.mock(assert_all_called=True)
async def test_caching(self, respx_mock):
"""Verify we cache results"""
docref0 = i2b2_mock_data.documentreference()
self.make_json("DocumentReference", "0", **docref0)
route = mock_prompt(respx_mock, i2b2_mock_data.DOCREF_TEXT)

self.assertFalse(os.path.exists(f"{self.phi_dir}/ctakes-cache"))
await hftest.HuggingFaceTestTask(self.job_config, self.scrubber).run()

self.assertEqual(1, route.call_count)
cache_dir = f"{self.phi_dir}/ctakes-cache/hftest__summary_v0/06ee/"
cache_file = f"{cache_dir}/sha256-06ee538c626fbf4bdcec2199b7225c8034f26e2b46a7b5cb7ab385c8e8c00efa.json"
self.assertEqual("Patient has a fever.", common.read_text(cache_file))

await hftest.HuggingFaceTestTask(self.job_config, self.scrubber).run()
self.assertEqual(1, route.call_count)

# Confirm that if we remove the cache file, we call the endpoint again
os.remove(cache_file)
await hftest.HuggingFaceTestTask(self.job_config, self.scrubber).run()
self.assertEqual(2, route.call_count)

@respx.mock(assert_all_called=True)
async def test_init_check_unreachable(self, respx_mock):
"""Verify we bail if the server isn't reachable"""
respx_mock.get("http://localhost:8086/info").respond(status_code=500)
with self.assertRaises(SystemExit) as cm:
await hftest.HuggingFaceTestTask.init_check()
self.assertEqual(errors.SERVICE_MISSING, cm.exception.code)

@respx.mock(assert_all_called=True)
async def test_init_check_config(self, respx_mock):
"""Verify we check the server properties"""
# Happy path
mock_info(respx_mock)
await hftest.HuggingFaceTestTask.init_check()

# Bad model ID
mock_info(respx_mock, override={"model_id": "bogus/Llama-2-13b-chat-hf"})
with self.assertRaises(SystemExit) as cm:
await hftest.HuggingFaceTestTask.init_check()
self.assertEqual(errors.SERVICE_MISSING, cm.exception.code)

# Bad model SHA
mock_info(respx_mock, override={"model_sha": "bogus"})
with self.assertRaises(SystemExit) as cm:
await hftest.HuggingFaceTestTask.init_check()
self.assertEqual(errors.SERVICE_MISSING, cm.exception.code)

# Bad SHA
mock_info(respx_mock, override={"sha": "bogus"})
with self.assertRaises(SystemExit) as cm:
await hftest.HuggingFaceTestTask.init_check()
self.assertEqual(errors.SERVICE_MISSING, cm.exception.code)


class TestHuggingFaceETL(BaseEtlSimple):
"""Tests the end-to-end ETL of the hftest tasks."""

DATA_ROOT = "hftest"

@respx.mock(assert_all_called=True)
async def test_basic_etl(self, respx_mock):
mock_prompt(respx_mock, text="Test note 1")
mock_prompt(respx_mock, text="Test note 2")
await self.run_etl(tasks=["hftest__summary"])
self.assert_output_equal()
4 changes: 3 additions & 1 deletion tests/i2b2_mock_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from cumulus_etl.loaders.i2b2 import transform

DOCREF_TEXT = "Chief complaint: fever and chills. Denies cough."


def patient_dim() -> transform.PatientDimension:
return transform.PatientDimension(
Expand Down Expand Up @@ -63,7 +65,7 @@ def documentreference_dim() -> transform.ObservationFact:
"ENCOUNTER_NUM": 67890,
"CONCEPT_CD": "NOTE:149798455", # emergency room type
"START_DATE": "2016-01-01",
"OBSERVATION_BLOB": "Chief complaint: fever and chills. Denies cough.",
"OBSERVATION_BLOB": DOCREF_TEXT,
"TVAL_CHAR": "Emergency note",
}
)
Expand Down
66 changes: 66 additions & 0 deletions tests/nlp/test_watcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Tests for nlp/watcher.py"""

import os
import tempfile
from unittest import mock

import ddt
import respx

from cumulus_etl import common, errors, nlp

from tests.ctakesmock import CtakesMixin
from tests.utils import AsyncTestCase


class TestNLPWatcher(AsyncTestCase):
"""Generic test case for service watching code"""

@mock.patch("cumulus_etl.cli_utils.is_url_available", new=lambda x: False)
def test_ctakes_down(self):
"""Verify we report cTAKES being down correctly"""
with self.assertRaises(SystemExit) as cm:
nlp.check_ctakes()
self.assertEqual(errors.CTAKES_MISSING, cm.exception.code)

@mock.patch("cumulus_etl.cli_utils.is_url_available", new=lambda x: False)
def test_negation_cnlpt_down(self):
"""Verify we report negation being down correctly"""
with self.assertRaises(SystemExit) as cm:
nlp.check_negation_cnlpt()
self.assertEqual(errors.CNLPT_MISSING, cm.exception.code)

@mock.patch("cumulus_etl.cli_utils.is_url_available", new=lambda x: False)
def test_term_exists_cnlpt_down(self):
"""Verify we report term exists being down correctly"""
with self.assertRaises(SystemExit) as cm:
nlp.check_term_exists_cnlpt()
self.assertEqual(errors.CNLPT_MISSING, cm.exception.code)

def test_restart_ctakes_no_folder(self):
self.assertFalse(nlp.restart_ctakes_with_bsv("", ""))

def test_restart_ctakes_nonexistent_folder(self):
with tempfile.TemporaryDirectory() as tmpdir:
self.assertFalse(nlp.restart_ctakes_with_bsv(f"{tmpdir}/nope", ""))

def test_restart_ctakes_file_not_folder(self):
with tempfile.NamedTemporaryFile() as file:
self.assertFalse(nlp.restart_ctakes_with_bsv(file.name, ""))


class TestCTakesWatcher(CtakesMixin, AsyncTestCase):
"""Test case for cTAKES watching code that needs a real server"""

@mock.patch("select.poll")
@mock.patch("time.sleep", new=lambda x: None) # don't sleep during restart
def test_restart_timeout(self, mock_poll):
mock_poller = mock.MagicMock()
mock_poller.poll.return_value = False
mock_poll.return_value = mock_poller

with tempfile.NamedTemporaryFile() as file:
common.write_text(file.name, "C0028081|T184|night sweats|Sweats")
with self.assertRaises(SystemExit) as cm:
nlp.restart_ctakes_with_bsv(self.ctakes_overrides.name, file.name)
self.assertEqual(errors.CTAKES_RESTART_FAILED, cm.exception.code)
Loading