From b1e282365281c0b88c4467f603702f95a64f9401 Mon Sep 17 00:00:00 2001 From: Michael Terry Date: Mon, 29 Jul 2024 12:47:22 -0400 Subject: [PATCH 1/3] style: switch to ruff --- .github/workflows/ci.yaml | 23 +----- .pre-commit-config.yaml | 17 ++-- cumulus_etl/cli_utils.py | 2 +- cumulus_etl/common.py | 7 +- cumulus_etl/completion/__init__.py | 2 +- cumulus_etl/completion/schema.py | 1 - cumulus_etl/deid/codebook.py | 4 +- cumulus_etl/deid/scrubber.py | 2 +- cumulus_etl/errors.py | 1 - cumulus_etl/etl/cli.py | 1 - cumulus_etl/etl/config.py | 14 ++-- cumulus_etl/etl/convert/cli.py | 2 +- .../etl/studies/covid_symptom/covid_tasks.py | 16 ++-- cumulus_etl/etl/tasks/base.py | 19 +++-- cumulus_etl/etl/tasks/basic_tasks.py | 77 ++++++++++--------- cumulus_etl/etl/tasks/nlp_task.py | 18 +++-- cumulus_etl/etl/tasks/task_factory.py | 9 ++- cumulus_etl/fhir/fhir_auth.py | 4 +- cumulus_etl/fhir/fhir_client.py | 16 ++-- cumulus_etl/formats/base.py | 4 +- cumulus_etl/formats/batch.py | 2 +- cumulus_etl/loaders/fhir/bulk_export.py | 16 ++-- cumulus_etl/loaders/fhir/export_log.py | 2 +- cumulus_etl/loaders/fhir/ndjson_loader.py | 6 +- cumulus_etl/loaders/i2b2/loader.py | 2 +- cumulus_etl/loaders/i2b2/oracle/extract.py | 9 ++- cumulus_etl/loaders/i2b2/oracle/query.py | 11 ++- cumulus_etl/loaders/i2b2/transform.py | 5 +- cumulus_etl/nlp/__init__.py | 4 +- cumulus_etl/nlp/utils.py | 2 +- cumulus_etl/upload_notes/downloader.py | 10 +-- cumulus_etl/upload_notes/labelstudio.py | 1 - cumulus_etl/upload_notes/selector.py | 8 +- pyproject.toml | 53 ++++++++----- tests/convert/test_convert_cli.py | 2 +- tests/covid_symptom/test_covid_results.py | 3 +- tests/etl/base.py | 4 +- tests/etl/test_batching.py | 1 - tests/etl/test_etl_cli.py | 5 +- tests/formats/test_deltalake.py | 2 +- tests/hftest/test_hftask.py | 3 +- tests/loaders/i2b2/test_i2b2_transform.py | 1 - tests/loaders/ndjson/test_bulk_export.py | 4 +- tests/nlp/test_watcher.py | 5 -- tests/test_cli.py | 2 +- tests/upload_notes/test_upload_cli.py | 15 ++-- tests/upload_notes/test_upload_labelstudio.py | 1 - tests/utils.py | 6 +- 48 files changed, 209 insertions(+), 215 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8b5a186a..0ef2a0ea 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -118,24 +118,5 @@ jobs: python -m pip install --upgrade pip pip install .[dev] - - name: Run pycodestyle - # E203: pycodestyle is a little too rigid about slices & whitespace - # See https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html#slices - # W503: a default ignore that we are restoring - run: | - pycodestyle --max-line-length=120 --ignore=E203,W503 . - - - name: Run pylint - if: success() || failure() # still run pylint if above checks fail - run: | - pylint cumulus_etl tests - - - name: Run bandit - if: success() || failure() # still run bandit if above checks fail - run: | - bandit -c pyproject.toml -r . - - - name: Run black - if: success() || failure() # still run black if above checks fails - run: | - black --check --verbose --line-length 120 . + - name: Run ruff + run: ruff check --output-format=github . 
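The single `ruff check` step above replaces the four separate pycodestyle/pylint/bandit/black steps; which rules it enforces comes from the `[tool.ruff.lint]` selection added to `pyproject.toml` further down in this patch (`A`, `E`, `F`, `I`, `PLE`, `RUF`, `S`, `UP`). Two of those rule families drive most of the mechanical churn in the file list above: ruff's own check for bare mutable class attributes (hence the `ClassVar` annotations in the task modules) and the implicit-Optional check (hence `str = None` becoming `str | None = None`). A minimal sketch of both patterns, using a hypothetical class rather than any real task from the codebase:

```python
from typing import ClassVar


class ExampleTask:  # hypothetical class, for illustration only
    # Before: `tags = {"cpu"}`. A bare mutable class attribute is shared
    # state across all instances, so ruff asks for an explicit ClassVar marker.
    tags: ClassVar[set[str]] = {"cpu"}

    # Before: `comment: str = None`. A None default on a plain `str`
    # annotation is an implicit Optional, which PEP 484 disallows;
    # the patch spells it out as `str | None`.
    def __init__(self, comment: str | None = None):
        self.comment = comment
```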
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c67f6ecf..afb9945a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,11 +1,10 @@ repos: - - repo: https://github.com/psf/black - rev: 24.4.2 # keep in rough sync with pyproject.toml + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.5.5 # keep in rough sync with pyproject.toml hooks: - - id: black - entry: bash -c 'black "$@"; git add -u' -- - # It is recommended to specify the latest version of Python - # supported by your project here, or alternatively use - # pre-commit's default_language_version, see - # https://pre-commit.com/#top_level-default_language_version - language_version: python3.12 + - name: Ruff formatting + id: ruff-format + entry: bash -c 'ruff format --force-exclude "$@"; git add -u' -- + - name: Ruff linting + id: ruff + stages: [pre-push] diff --git a/cumulus_etl/cli_utils.py b/cumulus_etl/cli_utils.py index e7455269..72873370 100644 --- a/cumulus_etl/cli_utils.py +++ b/cumulus_etl/cli_utils.py @@ -46,7 +46,7 @@ def add_debugging(parser: argparse.ArgumentParser): return group -def make_export_dir(export_to: str = None) -> common.Directory: +def make_export_dir(export_to: str | None = None) -> common.Directory: """Makes a temporary directory to drop exported ndjson files into""" # Handle the easy case -- just a random temp dir if not export_to: diff --git a/cumulus_etl/common.py b/cumulus_etl/common.py index 22ee589f..30c71059 100644 --- a/cumulus_etl/common.py +++ b/cumulus_etl/common.py @@ -16,7 +16,6 @@ from cumulus_etl import store - ############################################################################### # # Types @@ -151,7 +150,7 @@ def read_json(path: str) -> Any: return json.load(f) -def write_json(path: str, data: Any, indent: int = None) -> None: +def write_json(path: str, data: Any, indent: int | None = None) -> None: """ Writes data to the given path, in json format :param path: filesystem path @@ -354,7 +353,7 @@ def datetime_now(local: bool = False) -> datetime.datetime: return now -def timestamp_datetime(time: datetime.datetime = None) -> str: +def timestamp_datetime(time: datetime.datetime | None = None) -> str: """ Human-readable UTC date and time :return: MMMM-DD-YYY hh:mm:ss @@ -363,7 +362,7 @@ def timestamp_datetime(time: datetime.datetime = None) -> str: return time.strftime("%Y-%m-%d %H:%M:%S") -def timestamp_filename(time: datetime.datetime = None) -> str: +def timestamp_filename(time: datetime.datetime | None = None) -> str: """ Human-readable UTC date and time suitable for a filesystem path diff --git a/cumulus_etl/completion/__init__.py b/cumulus_etl/completion/__init__.py index 174a266b..985db897 100644 --- a/cumulus_etl/completion/__init__.py +++ b/cumulus_etl/completion/__init__.py @@ -12,8 +12,8 @@ """ from .schema import ( - COMPLETION_TABLE, COMPLETION_ENCOUNTERS_TABLE, + COMPLETION_TABLE, completion_encounters_output_args, completion_encounters_schema, completion_format_args, diff --git a/cumulus_etl/completion/schema.py b/cumulus_etl/completion/schema.py index e2c18daf..d1a66f9c 100644 --- a/cumulus_etl/completion/schema.py +++ b/cumulus_etl/completion/schema.py @@ -2,7 +2,6 @@ import pyarrow - COMPLETION_TABLE = "etl__completion" COMPLETION_ENCOUNTERS_TABLE = "etl__completion_encounters" diff --git a/cumulus_etl/deid/codebook.py b/cumulus_etl/deid/codebook.py index d1af4430..7e8cfcc6 100644 --- a/cumulus_etl/deid/codebook.py +++ b/cumulus_etl/deid/codebook.py @@ -19,7 +19,7 @@ class Codebook: Some IDs may be 
cryptographically hashed versions of the real ID, some may be entirely random. """ - def __init__(self, codebook_dir: str = None): + def __init__(self, codebook_dir: str | None = None): """ :param codebook_dir: saved codebook path or None (initialize empty) """ @@ -83,7 +83,7 @@ def real_ids(self, resource_type: str, fake_ids: Iterable[str]) -> Iterator[str] class CodebookDB: """Class to hold codebook data and read/write it to storage""" - def __init__(self, codebook_dir: str = None): + def __init__(self, codebook_dir: str | None = None): """ Create a codebook database. diff --git a/cumulus_etl/deid/scrubber.py b/cumulus_etl/deid/scrubber.py index 0f986d51..ac4fa662 100644 --- a/cumulus_etl/deid/scrubber.py +++ b/cumulus_etl/deid/scrubber.py @@ -30,7 +30,7 @@ class Scrubber: the resource is fully de-identified. """ - def __init__(self, codebook_dir: str = None, use_philter: bool = False): + def __init__(self, codebook_dir: str | None = None, use_philter: bool = False): self.codebook = codebook.Codebook(codebook_dir) self.codebook_dir = codebook_dir self.philter = philter.Philter() if use_philter else None diff --git a/cumulus_etl/errors.py b/cumulus_etl/errors.py index 652731d7..994c3eeb 100644 --- a/cumulus_etl/errors.py +++ b/cumulus_etl/errors.py @@ -6,7 +6,6 @@ import httpx import rich.console - # Error return codes, mostly just distinguished for the benefit of tests. # These start at 10 just to leave some room for future use. SQL_USER_MISSING = 10 diff --git a/cumulus_etl/etl/cli.py b/cumulus_etl/etl/cli.py index fb3be80b..bc836ae7 100644 --- a/cumulus_etl/etl/cli.py +++ b/cumulus_etl/etl/cli.py @@ -16,7 +16,6 @@ from cumulus_etl.etl.config import JobConfig, JobSummary from cumulus_etl.etl.tasks import task_factory - ############################################################################### # # Main Pipeline (run all tasks) diff --git a/cumulus_etl/etl/config.py b/cumulus_etl/etl/config.py index c468c54d..9195ae2d 100644 --- a/cumulus_etl/etl/config.py +++ b/cumulus_etl/etl/config.py @@ -25,14 +25,14 @@ def __init__( input_format: str, output_format: str, client: fhir.FhirClient, - timestamp: datetime.datetime = None, - comment: str = None, + timestamp: datetime.datetime | None = None, + comment: str | None = None, batch_size: int = 1, # this default is never really used - overridden by command line args - ctakes_overrides: str = None, - dir_errors: str = None, - tasks: list[str] = None, - export_group_name: str = None, - export_datetime: datetime.datetime = None, + ctakes_overrides: str | None = None, + dir_errors: str | None = None, + tasks: list[str] | None = None, + export_group_name: str | None = None, + export_datetime: datetime.datetime | None = None, ): self._dir_input_orig = dir_input_orig self.dir_input = dir_input_deid diff --git a/cumulus_etl/etl/convert/cli.py b/cumulus_etl/etl/convert/cli.py index d4b7036f..5f63707f 100644 --- a/cumulus_etl/etl/convert/cli.py +++ b/cumulus_etl/etl/convert/cli.py @@ -7,8 +7,8 @@ import argparse import os import tempfile +from collections.abc import Callable from functools import partial -from typing import Callable import pyarrow import rich.progress diff --git a/cumulus_etl/etl/studies/covid_symptom/covid_tasks.py b/cumulus_etl/etl/studies/covid_symptom/covid_tasks.py index 48162460..31574623 100644 --- a/cumulus_etl/etl/studies/covid_symptom/covid_tasks.py +++ b/cumulus_etl/etl/studies/covid_symptom/covid_tasks.py @@ -1,6 +1,7 @@ """Define tasks for the covid_symptom study""" import itertools +from typing import ClassVar import 
ctakesclient import pyarrow @@ -11,7 +12,6 @@ from cumulus_etl.etl import tasks from cumulus_etl.etl.studies.covid_symptom import covid_ctakes - # List of recognized emergency department note types. We'll add more as we discover them in use. ED_CODES = { "http://loinc.org": { @@ -109,7 +109,7 @@ class BaseCovidSymptomNlpResultsTask(tasks.BaseNlpTask): # cNLP: smartonfhir/cnlp-transformers:negation-0.4.0 # ctakesclient: 3.0 - outputs = [tasks.OutputTable(resource_type=None, group_field="docref_id")] + outputs: ClassVar = [tasks.OutputTable(resource_type=None, group_field="docref_id")] async def prepare_task(self) -> bool: bsv_path = ctakesclient.filesystem.covid_symptoms_path() @@ -196,9 +196,9 @@ def get_schema(cls, resource_type: str | None, rows: list[dict]) -> pyarrow.Sche class CovidSymptomNlpResultsTask(BaseCovidSymptomNlpResultsTask): """Covid Symptom study task, to generate symptom lists from ED notes using cTAKES and cnlpt negation""" - name = "covid_symptom__nlp_results" - tags = {"covid_symptom", "gpu"} - polarity_model = TransformerModel.NEGATION + name: ClassVar = "covid_symptom__nlp_results" + tags: ClassVar = {"covid_symptom", "gpu"} + polarity_model: ClassVar = TransformerModel.NEGATION @classmethod async def init_check(cls) -> None: @@ -209,12 +209,12 @@ async def init_check(cls) -> None: class CovidSymptomNlpResultsTermExistsTask(BaseCovidSymptomNlpResultsTask): """Covid Symptom study task, to generate symptom lists from ED notes using cTAKES and cnlpt termexists""" - name = "covid_symptom__nlp_results_term_exists" - polarity_model = TransformerModel.TERM_EXISTS + name: ClassVar = "covid_symptom__nlp_results_term_exists" + polarity_model: ClassVar = TransformerModel.TERM_EXISTS # Explicitly don't use any tags because this is really a "hidden" task that is mostly for comparing # polarity model performance more than running a study. So we don't want it to be accidentally run. 
- tags = {} + tags: ClassVar = {} @classmethod async def init_check(cls) -> None: diff --git a/cumulus_etl/etl/tasks/base.py b/cumulus_etl/etl/tasks/base.py index eac6f7b4..9d93204c 100644 --- a/cumulus_etl/etl/tasks/base.py +++ b/cumulus_etl/etl/tasks/base.py @@ -4,6 +4,7 @@ import dataclasses import os from collections.abc import AsyncIterator, Iterator +from typing import ClassVar import cumulus_fhir_support import pyarrow @@ -86,12 +87,14 @@ class EtlTask: """ # Properties: - name: str = None # task & table name - resource: str = None # incoming resource that this task operates on (will be included in bulk exports etc) - tags: set[str] = [] - needs_bulk_deid = True # whether this task needs bulk MS tool de-id run on its inputs (NLP tasks usually don't) + name: ClassVar[str] = None # task & table name + # incoming resource that this task operates on (will be included in bulk exports etc) + resource: ClassVar[str] = None + tags: ClassVar[set[str]] = [] + # whether this task needs bulk MS tool de-id run on its inputs (NLP tasks usually don't) + needs_bulk_deid: ClassVar[bool] = True - outputs: list[OutputTable] = [OutputTable()] + outputs: ClassVar[list[OutputTable]] = [OutputTable()] ########################################################################################## # @@ -100,8 +103,8 @@ class EtlTask: ########################################################################################## def __init__(self, task_config: config.JobConfig, scrubber: deid.Scrubber): - assert self.name # nosec - assert self.resource # nosec + assert self.name # noqa: S101 + assert self.resource # noqa: S101 self.task_config = task_config self.scrubber = scrubber self.formatters: list[formats.Format | None] = [None] * len(self.outputs) # create format placeholders @@ -152,7 +155,7 @@ async def run(self) -> list[config.JobSummary]: return self.summaries @classmethod - def make_batch_from_rows(cls, resource_type: str | None, rows: list[dict], groups: set[str] = None): + def make_batch_from_rows(cls, resource_type: str | None, rows: list[dict], groups: set[str] | None = None): schema = cls.get_schema(resource_type, rows) return formats.Batch(rows, groups=groups, schema=schema) diff --git a/cumulus_etl/etl/tasks/basic_tasks.py b/cumulus_etl/etl/tasks/basic_tasks.py index 064793f9..d69af758 100644 --- a/cumulus_etl/etl/tasks/basic_tasks.py +++ b/cumulus_etl/etl/tasks/basic_tasks.py @@ -3,6 +3,7 @@ import copy import logging import os +from typing import ClassVar import pyarrow import rich.progress @@ -12,46 +13,46 @@ class AllergyIntoleranceTask(tasks.EtlTask): - name = "allergyintolerance" - resource = "AllergyIntolerance" - tags = {"cpu"} + name: ClassVar = "allergyintolerance" + resource: ClassVar = "AllergyIntolerance" + tags: ClassVar = {"cpu"} class ConditionTask(tasks.EtlTask): - name = "condition" - resource = "Condition" - tags = {"cpu"} + name: ClassVar = "condition" + resource: ClassVar = "Condition" + tags: ClassVar = {"cpu"} class DeviceTask(tasks.EtlTask): - name = "device" - resource = "Device" - tags = {"cpu"} + name: ClassVar = "device" + resource: ClassVar = "Device" + tags: ClassVar = {"cpu"} class DiagnosticReportTask(tasks.EtlTask): - name = "diagnosticreport" - resource = "DiagnosticReport" - tags = {"cpu"} + name: ClassVar = "diagnosticreport" + resource: ClassVar = "DiagnosticReport" + tags: ClassVar = {"cpu"} class DocumentReferenceTask(tasks.EtlTask): - name = "documentreference" - resource = "DocumentReference" - tags = {"cpu"} + name: ClassVar = "documentreference" + resource: 
ClassVar = "DocumentReference" + tags: ClassVar = {"cpu"} class EncounterTask(tasks.EtlTask): """Processes Encounter FHIR resources""" - name = "encounter" - resource = "Encounter" - tags = {"cpu"} + name: ClassVar = "encounter" + resource: ClassVar = "Encounter" + tags: ClassVar = {"cpu"} # Encounters are a little more complicated than normal FHIR resources. # We also write out a table tying Encounters to a group name, for completion tracking. - outputs = [ + outputs: ClassVar = [ # Write completion data out first, so that if an encounter is being completion-tracked, # there's never a gap where it doesn't have an entry. This will help downstream users # know if an Encounter is tracked or not - by simply looking at this table. @@ -81,17 +82,17 @@ def get_schema(cls, resource_type: str | None, rows: list[dict]) -> pyarrow.Sche class ImmunizationTask(tasks.EtlTask): - name = "immunization" - resource = "Immunization" - tags = {"cpu"} + name: ClassVar = "immunization" + resource: ClassVar = "Immunization" + tags: ClassVar = {"cpu"} class MedicationRequestTask(tasks.EtlTask): """Write MedicationRequest resources and associated Medication resources""" - name = "medicationrequest" - resource = "MedicationRequest" - tags = {"cpu"} + name: ClassVar = "medicationrequest" + resource: ClassVar = "MedicationRequest" + tags: ClassVar = {"cpu"} # We may write to a second Medication table as we go. # MedicationRequest can have inline medications via CodeableConcepts, or external Medication references. @@ -99,7 +100,7 @@ class MedicationRequestTask(tasks.EtlTask): # We do all this special business logic because Medication is a special, "reference" resource, # and many EHRs don't let you simply bulk export them. - outputs = [ + outputs: ClassVar = [ # Write medication out first, to avoid a moment where links are broken tasks.OutputTable(name="medication", resource_type="Medication"), tasks.OutputTable(), @@ -171,24 +172,24 @@ async def read_entries(self, *, progress: rich.progress.Progress = None) -> task class ObservationTask(tasks.EtlTask): - name = "observation" - resource = "Observation" - tags = {"cpu"} + name: ClassVar = "observation" + resource: ClassVar = "Observation" + tags: ClassVar = {"cpu"} class PatientTask(tasks.EtlTask): - name = "patient" - resource = "Patient" - tags = {"cpu"} + name: ClassVar = "patient" + resource: ClassVar = "Patient" + tags: ClassVar = {"cpu"} class ProcedureTask(tasks.EtlTask): - name = "procedure" - resource = "Procedure" - tags = {"cpu"} + name: ClassVar = "procedure" + resource: ClassVar = "Procedure" + tags: ClassVar = {"cpu"} class ServiceRequestTask(tasks.EtlTask): - name = "servicerequest" - resource = "ServiceRequest" - tags = {"cpu"} + name: ClassVar = "servicerequest" + resource: ClassVar = "ServiceRequest" + tags: ClassVar = {"cpu"} diff --git a/cumulus_etl/etl/tasks/nlp_task.py b/cumulus_etl/etl/tasks/nlp_task.py index 1ad5014f..82406d82 100644 --- a/cumulus_etl/etl/tasks/nlp_task.py +++ b/cumulus_etl/etl/tasks/nlp_task.py @@ -4,7 +4,8 @@ import logging import os import sys -from typing import Callable +from collections.abc import Callable +from typing import ClassVar import rich.progress @@ -15,12 +16,15 @@ class BaseNlpTask(EtlTask): """Base class for any clinical-notes-based NLP task.""" - resource = "DocumentReference" - needs_bulk_deid = False + resource: ClassVar = "DocumentReference" + needs_bulk_deid: ClassVar = False # You may want to override these in your subclass - outputs = [OutputTable(resource_type=None)] # maybe a group_field? 
(remember to call self.seen_docrefs.add() if so) - tags = {"gpu"} # maybe a study identifier? + outputs: ClassVar = [ + # maybe add a group_field? (remember to call self.seen_docrefs.add() if so) + OutputTable(resource_type=None) + ] + tags: ClassVar = {"gpu"} # maybe a study identifier? # Task Version # The "task_version" field is a simple integer that gets incremented any time an NLP-relevant parameter is changed. @@ -32,7 +36,7 @@ class BaseNlpTask(EtlTask): # - Record the new bundle of metadata in your class documentation # - Update any safety checks in prepare_task() or elsewhere that check the NLP versioning # - Be aware that your caching will be reset - task_version = 1 + task_version: ClassVar = 1 # Task Version History: # ** 1 (20xx-xx): First version ** # CHANGE ME @@ -57,7 +61,7 @@ def add_error(self, docref: dict) -> None: writer.write(docref) async def read_notes( - self, *, doc_check: Callable[[dict], bool] = None, progress: rich.progress.Progress = None + self, *, doc_check: Callable[[dict], bool] | None = None, progress: rich.progress.Progress = None ) -> (dict, dict, str): """ Iterate through clinical notes. diff --git a/cumulus_etl/etl/tasks/task_factory.py b/cumulus_etl/etl/tasks/task_factory.py index d1a5ced8..579c2bc0 100644 --- a/cumulus_etl/etl/tasks/task_factory.py +++ b/cumulus_etl/etl/tasks/task_factory.py @@ -8,7 +8,7 @@ from cumulus_etl.etl.studies import covid_symptom, hftest from cumulus_etl.etl.tasks import basic_tasks -AnyTask = TypeVar("AnyTask", bound="EtlTask") +AnyTask = TypeVar("AnyTask", bound="EtlTask") # noqa: F821 def get_all_tasks() -> list[type[AnyTask]]: @@ -19,7 +19,8 @@ def get_all_tasks() -> list[type[AnyTask]]: """ # Right now, just hard-code these. One day we might allow plugins or something similarly dynamic. # Note: tasks will be run in the order listed here. - return get_default_tasks() + [ + return [ + *get_default_tasks(), covid_symptom.CovidSymptomNlpResultsTask, covid_symptom.CovidSymptomNlpResultsTermExistsTask, hftest.HuggingFaceTestTask, @@ -53,7 +54,9 @@ def get_default_tasks() -> list[type[AnyTask]]: ] -def get_selected_tasks(names: Iterable[str] = None, filter_tags: Iterable[str] = None) -> list[type[AnyTask]]: +def get_selected_tasks( + names: Iterable[str] | None = None, filter_tags: Iterable[str] | None = None +) -> list[type[AnyTask]]: """ Returns classes for every selected task. diff --git a/cumulus_etl/fhir/fhir_auth.py b/cumulus_etl/fhir/fhir_auth.py index 1bb375fe..a1ce2215 100644 --- a/cumulus_etl/fhir/fhir_auth.py +++ b/cumulus_etl/fhir/fhir_auth.py @@ -5,8 +5,8 @@ import time import urllib.parse import uuid -from json import JSONDecodeError from collections.abc import Iterable +from json import JSONDecodeError import httpx from jwcrypto import jwk, jwt @@ -166,7 +166,7 @@ def __init__(self, user: str, password: str): super().__init__() # Assume utf8 is acceptable -- we should in theory also run these through Unicode normalization, in case they # have interesting Unicode characters. But we can always add that in the future. 
- combo_bytes = f"{user}:{password}".encode("utf8") + combo_bytes = f"{user}:{password}".encode() self._basic_token = base64.standard_b64encode(combo_bytes).decode("ascii") async def authorize(self, session: httpx.AsyncClient, reauthorize=False) -> None: diff --git a/cumulus_etl/fhir/fhir_client.py b/cumulus_etl/fhir/fhir_client.py index 925135a9..278c0c85 100644 --- a/cumulus_etl/fhir/fhir_client.py +++ b/cumulus_etl/fhir/fhir_client.py @@ -4,8 +4,8 @@ import enum import re import sys -from json import JSONDecodeError from collections.abc import Iterable +from json import JSONDecodeError import httpx @@ -34,11 +34,11 @@ def __init__( self, url: str | None, resources: Iterable[str], - basic_user: str = None, - basic_password: str = None, - bearer_token: str = None, - smart_client_id: str = None, - smart_jwks: dict = None, + basic_user: str | None = None, + basic_password: str | None = None, + bearer_token: str | None = None, + smart_client_id: str | None = None, + smart_jwks: dict | None = None, ): """ Initialize and authorize a BackendServiceServer context manager. @@ -78,7 +78,9 @@ async def __aexit__(self, exc_type, exc_value, traceback): if self._session: await self._session.aclose() - async def request(self, method: str, path: str, headers: dict = None, stream: bool = False) -> httpx.Response: + async def request( + self, method: str, path: str, headers: dict | None = None, stream: bool = False + ) -> httpx.Response: """ Issues an HTTP request. diff --git a/cumulus_etl/formats/base.py b/cumulus_etl/formats/base.py index ee72ae9d..20ed5115 100644 --- a/cumulus_etl/formats/base.py +++ b/cumulus_etl/formats/base.py @@ -27,8 +27,8 @@ def __init__( self, root: store.Root, dbname: str, - group_field: str = None, - uniqueness_fields: Collection[str] = None, + group_field: str | None = None, + uniqueness_fields: Collection[str] | None = None, update_existing: bool = True, ): """ diff --git a/cumulus_etl/formats/batch.py b/cumulus_etl/formats/batch.py index 96fada27..c7d7027a 100644 --- a/cumulus_etl/formats/batch.py +++ b/cumulus_etl/formats/batch.py @@ -15,7 +15,7 @@ class Batch: - Written to the target location as one piece (e.g. one ndjson file or one Delta Lake update chunk) """ - def __init__(self, rows: list[dict], groups: set[str] = None, schema: pyarrow.Schema = None): + def __init__(self, rows: list[dict], groups: set[str] | None = None, schema: pyarrow.Schema = None): self.rows = rows # `groups` is the set of the values of the format's `group_field` represented by `rows`. # We can't just get this from rows directly because there might be groups that now have zero entries. diff --git a/cumulus_etl/loaders/fhir/bulk_export.py b/cumulus_etl/loaders/fhir/bulk_export.py index 6e675675..4676ea8d 100644 --- a/cumulus_etl/loaders/fhir/bulk_export.py +++ b/cumulus_etl/loaders/fhir/bulk_export.py @@ -38,8 +38,8 @@ def __init__( resources: list[str], url: str, destination: str, - since: str = None, - until: str = None, + since: str | None = None, + until: str | None = None, ): """ Initialize a bulk exporter (but does not start an export). @@ -130,7 +130,7 @@ async def export(self) -> None: # The spec acknowledges that "error" is perhaps misleading for an array that can contain info messages. 
error_texts, warning_texts = await self._gather_all_messages(response_json.get("error", [])) if warning_texts: - print("\n - ".join(["Messages from server:"] + warning_texts)) + print("\n - ".join(["Messages from server:", *warning_texts])) # Download all the files print("Bulk FHIR export finished, now downloading resources…") @@ -149,7 +149,7 @@ async def export(self) -> None: # the server DID give us. Servers may have lots of ignorable errors that need human review, # before passing back to us as input ndjson. if error_texts: - raise errors.FatalError("\n - ".join(["Errors occurred during export:"] + error_texts)) + raise errors.FatalError("\n - ".join(["Errors occurred during export:", *error_texts])) ################################################################################################################### # @@ -168,10 +168,10 @@ async def _delete_export(self, poll_url: str) -> None: async def _request_with_delay( self, path: str, - headers: dict = None, + headers: dict | None = None, target_status_code: int = 200, method: str = "GET", - log_progress: Callable[[httpx.Response], None] = None, + log_progress: Callable[[httpx.Response], None] | None = None, ) -> httpx.Response: """ Requests a file, while respecting any requests to wait longer. @@ -228,8 +228,8 @@ async def _request_with_delay( async def _request_with_logging( self, *args, - log_begin: Callable[[], None] = None, - log_error: Callable[[Exception], None] = None, + log_begin: Callable[[], None] | None = None, + log_error: Callable[[Exception], None] | None = None, **kwargs, ) -> httpx.Response: if log_begin: diff --git a/cumulus_etl/loaders/fhir/export_log.py b/cumulus_etl/loaders/fhir/export_log.py index a5aca8a8..02cc1ccd 100644 --- a/cumulus_etl/loaders/fhir/export_log.py +++ b/cumulus_etl/loaders/fhir/export_log.py @@ -114,7 +114,7 @@ def __init__(self, root: store.Root): self._num_bytes = 0 self._start_time = None - def _event(self, event_id: str, detail: dict, *, timestamp: datetime.datetime = None) -> None: + def _event(self, event_id: str, detail: dict, *, timestamp: datetime.datetime | None = None) -> None: timestamp = timestamp or common.datetime_now(local=True) if self._start_time is None: self._start_time = timestamp diff --git a/cumulus_etl/loaders/fhir/ndjson_loader.py b/cumulus_etl/loaders/fhir/ndjson_loader.py index 1c5837aa..3e4cbcc2 100644 --- a/cumulus_etl/loaders/fhir/ndjson_loader.py +++ b/cumulus_etl/loaders/fhir/ndjson_loader.py @@ -17,9 +17,9 @@ def __init__( self, root: store.Root, client: fhir.FhirClient = None, - export_to: str = None, - since: str = None, - until: str = None, + export_to: str | None = None, + since: str | None = None, + until: str | None = None, ): """ :param root: location to load ndjson from diff --git a/cumulus_etl/loaders/i2b2/loader.py b/cumulus_etl/loaders/i2b2/loader.py index aa5f0e93..100131fc 100644 --- a/cumulus_etl/loaders/i2b2/loader.py +++ b/cumulus_etl/loaders/i2b2/loader.py @@ -25,7 +25,7 @@ class I2b2Loader(Loader): Expected format is either a tcp:// URL pointing at an Oracle server or a local folder. 
""" - def __init__(self, root: store.Root, export_to: str = None): + def __init__(self, root: store.Root, export_to: str | None = None): """ Initialize a new I2b2Loader class :param root: the base location to read data from diff --git a/cumulus_etl/loaders/i2b2/oracle/extract.py b/cumulus_etl/loaders/i2b2/oracle/extract.py index 50d5608d..f63cce3b 100644 --- a/cumulus_etl/loaders/i2b2/oracle/extract.py +++ b/cumulus_etl/loaders/i2b2/oracle/extract.py @@ -4,9 +4,14 @@ from collections.abc import Iterable from cumulus_etl import common -from cumulus_etl.loaders.i2b2.schema import ObservationFact, PatientDimension, VisitDimension -from cumulus_etl.loaders.i2b2.schema import ConceptDimension, ProviderDimension from cumulus_etl.loaders.i2b2.oracle import connect, query +from cumulus_etl.loaders.i2b2.schema import ( + ConceptDimension, + ObservationFact, + PatientDimension, + ProviderDimension, + VisitDimension, +) def execute(dsn: str, desc: str, sql_statement: str) -> Iterable[dict]: diff --git a/cumulus_etl/loaders/i2b2/oracle/query.py b/cumulus_etl/loaders/i2b2/oracle/query.py index 9c33e9b4..43252de8 100644 --- a/cumulus_etl/loaders/i2b2/oracle/query.py +++ b/cumulus_etl/loaders/i2b2/oracle/query.py @@ -2,7 +2,6 @@ from cumulus_etl.loaders.i2b2.schema import Table, ValueType - ############################################################################### # Table.patient_dimension ############################################################################### @@ -12,7 +11,7 @@ def sql_patient() -> str: birth_date = format_date("BIRTH_DATE") death_date = format_date("DEATH_DATE") cols = f"PATIENT_NUM, {birth_date}, {death_date}, SEX_CD, RACE_CD, ZIP_CD" - return f"select {cols} \n from {Table.patient.value}" # nosec + return f"select {cols} \n from {Table.patient.value}" # noqa: S608 ############################################################################### @@ -23,7 +22,7 @@ def sql_patient() -> str: def sql_provider() -> str: cols_dates = format_date("IMPORT_DATE") cols = f"PROVIDER_ID, PROVIDER_PATH, NAME_CHAR, {cols_dates}" - return f"select {cols} \n from {Table.provider.value}" # nosec + return f"select {cols} \n from {Table.provider.value}" # noqa: S608 ############################################################################### @@ -41,7 +40,7 @@ def sql_visit() -> str: cols_dates = f"{start_date}, {end_date}, {import_date}, LENGTH_OF_STAY" cols = "ENCOUNTER_NUM, PATIENT_NUM, LOCATION_CD, INOUT_CD, " f"{cols_dates}" - return f"select {cols} \n from {Table.visit.value}" # nosec + return f"select {cols} \n from {Table.visit.value}" # noqa: S608 def after_start_date(start_date: str) -> str: @@ -73,7 +72,7 @@ def sql_concept() -> str: """ cols_dates = format_date("IMPORT_DATE") cols = f"CONCEPT_CD, NAME_CHAR, SOURCESYSTEM_CD, CONCEPT_BLOB, {cols_dates}" - return f"select {cols} \n from {Table.concept.value}" # nosec + return f"select {cols} \n from {Table.concept.value}" # noqa: S608 ############################################################################### @@ -102,7 +101,7 @@ def sql_observation_fact(categories: list[str]) -> str: matchers = [f"(concept_cd like '{category}:%')" for category in categories] - return f"select {cols} \n from {Table.observation_fact.value} O " f"where {' or '.join(matchers)}" # nosec + return f"select {cols} \n from {Table.observation_fact.value} O " f"where {' or '.join(matchers)}" # noqa: S608 def eq_val_type(val_type: ValueType) -> str: diff --git a/cumulus_etl/loaders/i2b2/transform.py b/cumulus_etl/loaders/i2b2/transform.py index 
98e6ef09..9b3c6c1b 100644 --- a/cumulus_etl/loaders/i2b2/transform.py +++ b/cumulus_etl/loaders/i2b2/transform.py @@ -5,8 +5,7 @@ from cumulus_etl import fhir from cumulus_etl.loaders.i2b2 import external_mappings -from cumulus_etl.loaders.i2b2.schema import PatientDimension, VisitDimension, ObservationFact - +from cumulus_etl.loaders.i2b2.schema import ObservationFact, PatientDimension, VisitDimension ############################################################################### # @@ -341,7 +340,7 @@ def get_observation_value(obsfact: ObservationFact) -> dict: return {"valueQuantity": quantity} -def make_concept(code: str, system: str | None, display: str = None, display_codes: dict = None) -> dict: +def make_concept(code: str, system: str | None, display: str | None = None, display_codes: dict | None = None) -> dict: """Syntactic sugar to make a codeable concept""" coding = {"code": code, "system": system} if display: diff --git a/cumulus_etl/nlp/__init__.py b/cumulus_etl/nlp/__init__.py index 783a16a7..9431411a 100644 --- a/cumulus_etl/nlp/__init__.py +++ b/cumulus_etl/nlp/__init__.py @@ -1,6 +1,6 @@ """Support code for NLP servers""" from .extract import TransformerModel, ctakes_extract, ctakes_httpx_client, list_polarity -from .huggingface import hf_prompt, hf_info, llama2_prompt +from .huggingface import hf_info, hf_prompt, llama2_prompt from .utils import cache_wrapper, is_docref_valid -from .watcher import check_negation_cnlpt, check_term_exists_cnlpt, check_ctakes, restart_ctakes_with_bsv +from .watcher import check_ctakes, check_negation_cnlpt, check_term_exists_cnlpt, restart_ctakes_with_bsv diff --git a/cumulus_etl/nlp/utils.py b/cumulus_etl/nlp/utils.py index 9e1ee5d0..81c50403 100644 --- a/cumulus_etl/nlp/utils.py +++ b/cumulus_etl/nlp/utils.py @@ -2,7 +2,7 @@ import hashlib import os -from typing import Callable +from collections.abc import Callable from cumulus_etl import common, store diff --git a/cumulus_etl/upload_notes/downloader.py b/cumulus_etl/upload_notes/downloader.py index 27f3cb90..a69723bf 100644 --- a/cumulus_etl/upload_notes/downloader.py +++ b/cumulus_etl/upload_notes/downloader.py @@ -14,9 +14,9 @@ async def download_docrefs_from_fhir_server( client: fhir.FhirClient, root_input: store.Root, codebook: deid.Codebook, - docrefs: str = None, - anon_docrefs: str = None, - export_to: str = None, + docrefs: str | None = None, + anon_docrefs: str | None = None, + export_to: str | None = None, ): if docrefs: return await _download_docrefs_from_real_ids(client, docrefs, export_to=export_to) @@ -32,7 +32,7 @@ async def _download_docrefs_from_fake_ids( client: fhir.FhirClient, codebook: deid.Codebook, docref_csv: str, - export_to: str = None, + export_to: str | None = None, ) -> common.Directory: """Download DocumentReference resources for the given patient and docref identifiers""" output_folder = cli_utils.make_export_dir(export_to) @@ -61,7 +61,7 @@ async def _download_docrefs_from_fake_ids( async def _download_docrefs_from_real_ids( client: fhir.FhirClient, docref_csv: str, - export_to: str = None, + export_to: str | None = None, ) -> common.Directory: """Download DocumentReference resources for the given patient and docref identifiers""" output_folder = cli_utils.make_export_dir(export_to) diff --git a/cumulus_etl/upload_notes/labelstudio.py b/cumulus_etl/upload_notes/labelstudio.py index d766e2b7..dabcaa70 100644 --- a/cumulus_etl/upload_notes/labelstudio.py +++ b/cumulus_etl/upload_notes/labelstudio.py @@ -11,7 +11,6 @@ from cumulus_etl import errors - 
############################################################################### # # LabelStudio : Document Annotation diff --git a/cumulus_etl/upload_notes/selector.py b/cumulus_etl/upload_notes/selector.py index f4bdd2ce..bf2e1473 100644 --- a/cumulus_etl/upload_notes/selector.py +++ b/cumulus_etl/upload_notes/selector.py @@ -10,9 +10,9 @@ def select_docrefs_from_files( root_input: store.Root, codebook: deid.Codebook, - docrefs: str = None, - anon_docrefs: str = None, - export_to: str = None, + docrefs: str | None = None, + anon_docrefs: str | None = None, + export_to: str | None = None, ) -> common.Directory: """Takes an input folder of ndjson and exports just the chosen docrefs to a new ndjson folder""" # Get an appropriate filter method, for the given docrefs @@ -32,7 +32,7 @@ def select_docrefs_from_files( def _create_docref_filter( - codebook: deid.Codebook, docrefs: str = None, anon_docrefs: str = None + codebook: deid.Codebook, docrefs: str | None = None, anon_docrefs: str | None = None ) -> Callable[[Iterable[dict]], Iterator[dict]]: """This returns a method that will can an iterator of docrefs and returns an iterator of fewer docrefs""" # Decide how we're filtering the input files (by real or fake ID, or no filtering at all!) diff --git a/pyproject.toml b/pyproject.toml index 8163e186..b5cb555f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,23 @@ classifiers = [ ] dynamic = ["version"] +[project.optional-dependencies] +tests = [ + "coverage", + "ddt", + "moto[server,s3] >= 5.0", + "pytest", + "pytest-cov", + "respx", + "time-machine", +] +dev = [ + "pre-commit", + # Ruff is using minor versions for breaking changes until their 1.0 release. + # See https://docs.astral.sh/ruff/versioning/ + "ruff < 0.6", # keep in rough sync with pre-commit-config.yaml +] + [project.urls] "Homepage" = "https://github.com/smart-on-fhir/cumulus-etl" @@ -59,26 +76,22 @@ exclude = [ "**/.pytest_cache", ] -[tool.bandit] -exclude_dirs = ["tests"] - -[tool.black] +[tool.ruff] line-length = 120 -[project.optional-dependencies] -tests = [ - "coverage", - "ddt", - "moto[server,s3] >= 5.0", - "pytest", - "pytest-cov", - "respx", - "time-machine", -] -dev = [ - "bandit[toml]", - "black >= 24, < 25", # keep in rough sync with .pre-commit-config.yaml - "pre-commit", - "pycodestyle", - "pylint", +[tool.ruff.lint] +allowed-confusables = ["’"] # allow proper apostrophes +select = [ + "A", # prevent using keywords that clobber python builtins + "E", # pycodestyle + "F", # pyflakes + "I", # isort + "PLE", # pylint errors + "RUF", # the ruff developer's own rules + "S", # bandit security warnings + "UP", # alert you when better syntax is available in your python version ] + +[tool.ruff.lint.per-file-ignores] +"**/__init__.py" = ["F401"] # init files hold API, so not using imports is intentional +"tests/**" = ["S"] # tests do suspicious stuff that's fine, actually diff --git a/tests/convert/test_convert_cli.py b/tests/convert/test_convert_cli.py index b825958e..51a66410 100644 --- a/tests/convert/test_convert_cli.py +++ b/tests/convert/test_convert_cli.py @@ -46,7 +46,7 @@ def prepare_original_dir(self) -> str: return job_timestamp - async def run_convert(self, input_path: str = None, output_path: str = None) -> None: + async def run_convert(self, input_path: str | None = None, output_path: str | None = None) -> None: args = [ "convert", input_path or self.original_path, diff --git a/tests/covid_symptom/test_covid_results.py b/tests/covid_symptom/test_covid_results.py index 5211a88c..443fbbe4 100644 --- 
a/tests/covid_symptom/test_covid_results.py +++ b/tests/covid_symptom/test_covid_results.py @@ -7,9 +7,8 @@ import respx from cumulus_etl.etl.studies import covid_symptom - -from tests.ctakesmock import CtakesMixin from tests import i2b2_mock_data +from tests.ctakesmock import CtakesMixin from tests.etl import BaseEtlSimple, TaskTestCase diff --git a/tests/etl/base.py b/tests/etl/base.py index c19addac..d10480d9 100644 --- a/tests/etl/base.py +++ b/tests/etl/base.py @@ -50,10 +50,10 @@ async def run_etl( comment=None, batch_size=None, tasks=None, - tags: list[str] = None, + tags: list[str] | None = None, philter=True, errors_to=None, - export_to: str = None, + export_to: str | None = None, input_format: str = "ndjson", export_group: str = "test-group", export_timestamp: str = "2020-10-13T12:00:20-05:00", diff --git a/tests/etl/test_batching.py b/tests/etl/test_batching.py index 08f463d6..9d995185 100644 --- a/tests/etl/test_batching.py +++ b/tests/etl/test_batching.py @@ -5,7 +5,6 @@ import ddt from cumulus_etl.etl.tasks import batching - from tests.utils import AsyncTestCase diff --git a/tests/etl/test_etl_cli.py b/tests/etl/test_etl_cli.py index da5149cf..55c6fb23 100644 --- a/tests/etl/test_etl_cli.py +++ b/tests/etl/test_etl_cli.py @@ -15,7 +15,6 @@ from cumulus_etl import common, errors, loaders, store from cumulus_etl.etl import context - from tests.ctakesmock import fake_ctakes_extract from tests.etl import BaseEtlSimple from tests.s3mock import S3Mixin @@ -254,7 +253,7 @@ def setUp(self): def read_config_file(self, name: str) -> dict: full_path = os.path.join(self.job_config_path, name) - with open(full_path, "r", encoding="utf8") as f: + with open(full_path, encoding="utf8") as f: return json.load(f) async def test_serialization(self): @@ -418,7 +417,7 @@ def path_for_checksum(self, prefix, checksum): def read_symptoms(self): """Loads the output symptoms ndjson from disk""" path = os.path.join(self.output_path, "covid_symptom__nlp_results", "covid_symptom__nlp_results.000.ndjson") - with open(path, "r", encoding="utf8") as f: + with open(path, encoding="utf8") as f: lines = f.readlines() return [json.loads(line) for line in lines] diff --git a/tests/formats/test_deltalake.py b/tests/formats/test_deltalake.py index 39a11805..d62dcf1e 100644 --- a/tests/formats/test_deltalake.py +++ b/tests/formats/test_deltalake.py @@ -58,7 +58,7 @@ def store( self, rows: list[dict], schema: pyarrow.Schema = None, - groups: set[str] = None, + groups: set[str] | None = None, **kwargs, ) -> bool: """ diff --git a/tests/hftest/test_hftask.py b/tests/hftest/test_hftask.py index 2240fe18..7a78258e 100644 --- a/tests/hftest/test_hftask.py +++ b/tests/hftest/test_hftask.py @@ -6,7 +6,6 @@ from cumulus_etl import common, errors from cumulus_etl.etl.studies import hftest - from tests import i2b2_mock_data from tests.etl import BaseEtlSimple, TaskTestCase @@ -32,7 +31,7 @@ def mock_prompt(respx_mock: respx.MockRouter, text: str, url: str = "http://loca def mock_info( - respx_mock: respx.MockRouter, url: str = "http://localhost:8086/info", override: dict = None + respx_mock: respx.MockRouter, url: str = "http://localhost:8086/info", override: dict | None = None ) -> respx.Route: response = { "model_id": "meta-llama/Llama-2-13b-chat-hf", diff --git a/tests/loaders/i2b2/test_i2b2_transform.py b/tests/loaders/i2b2/test_i2b2_transform.py index 7b75abab..a45fcfc7 100644 --- a/tests/loaders/i2b2/test_i2b2_transform.py +++ b/tests/loaders/i2b2/test_i2b2_transform.py @@ -3,7 +3,6 @@ import ddt from 
cumulus_etl.loaders.i2b2 import schema, transform - from tests import i2b2_mock_data from tests.utils import AsyncTestCase diff --git a/tests/loaders/ndjson/test_bulk_export.py b/tests/loaders/ndjson/test_bulk_export.py index 655062b9..5be99d49 100644 --- a/tests/loaders/ndjson/test_bulk_export.py +++ b/tests/loaders/ndjson/test_bulk_export.py @@ -90,7 +90,9 @@ def assert_log_equals(self, *rows) -> None: if row[1] is not None: self.assertEqual(reordered_details[index][1], row[1]) - def mock_kickoff(self, params: str = "?_type=Condition%2CPatient", side_effect: list = None, **kwargs) -> None: + def mock_kickoff( + self, params: str = "?_type=Condition%2CPatient", side_effect: list | None = None, **kwargs + ) -> None: kwargs.setdefault("status_code", 202) route = self.respx_mock.get( f"{self.fhir_url}/$export{params}", diff --git a/tests/nlp/test_watcher.py b/tests/nlp/test_watcher.py index 2b2da7f2..3280ca17 100644 --- a/tests/nlp/test_watcher.py +++ b/tests/nlp/test_watcher.py @@ -1,14 +1,9 @@ """Tests for nlp/watcher.py""" -import os import tempfile from unittest import mock -import ddt -import respx - from cumulus_etl import common, errors, nlp - from tests.ctakesmock import CtakesMixin from tests.utils import AsyncTestCase diff --git a/tests/test_cli.py b/tests/test_cli.py index b6a41f9b..15798c8b 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -30,7 +30,7 @@ async def test_usage(self, argv, expected_usage): stdout = io.StringIO() with contextlib.redirect_stdout(stdout): with self.assertRaises(SystemExit): - await cli.main(argv + ["--help"]) + await cli.main([*argv, "--help"]) self.assertTrue(stdout.getvalue().startswith(expected_usage), stdout.getvalue()) diff --git a/tests/upload_notes/test_upload_cli.py b/tests/upload_notes/test_upload_cli.py index e638e3d8..5778bcfc 100644 --- a/tests/upload_notes/test_upload_cli.py +++ b/tests/upload_notes/test_upload_cli.py @@ -14,7 +14,6 @@ from cumulus_etl import cli, common, errors from cumulus_etl.upload_notes.labelstudio import LabelStudioNote - from tests.ctakesmock import CtakesMixin from tests.utils import AsyncTestCase @@ -113,11 +112,11 @@ async def run_upload_notes( @staticmethod def make_docref( doc_id: str, - text: str = None, - content: list[dict] = None, - enc_id: str = None, - date: str = None, - period_start: str = None, + text: str | None = None, + content: list[dict] | None = None, + enc_id: str | None = None, + date: str | None = None, + period_start: str | None = None, ) -> dict: docref = { "resourceType": "DocumentReference", @@ -163,7 +162,7 @@ def mock_search_url(respx_mock: respx.MockRouter, patient: str, doc_ids: Iterabl @staticmethod def mock_read_url( - respx_mock: respx.MockRouter, doc_id: str, code: int = 200, docref: dict = None, **kwargs + respx_mock: respx.MockRouter, doc_id: str, code: int = 200, docref: dict | None = None, **kwargs ) -> None: docref = docref or TestUploadNotes.make_docref(doc_id, **kwargs) respx_mock.get(f"https://localhost/DocumentReference/{doc_id}").respond(status_code=code, json=docref) @@ -178,7 +177,7 @@ def write_anon_docrefs(path: str, ids: list[tuple[str, str]]) -> None: @staticmethod def write_real_docrefs(path: str, ids: list[str]) -> None: """Fills a file with the provided docref ids""" - lines = ["docref_id"] + ids + lines = ["docref_id", *ids] with open(path, "w", encoding="utf8") as f: f.write("\n".join(lines)) diff --git a/tests/upload_notes/test_upload_labelstudio.py b/tests/upload_notes/test_upload_labelstudio.py index 89532e8d..d72aa71b 100644 --- 
a/tests/upload_notes/test_upload_labelstudio.py +++ b/tests/upload_notes/test_upload_labelstudio.py @@ -6,7 +6,6 @@ from ctakesclient.typesystem import Polarity from cumulus_etl.upload_notes.labelstudio import LabelStudioClient, LabelStudioNote - from tests import ctakesmock from tests.utils import AsyncTestCase diff --git a/tests/utils.py b/tests/utils.py index fb99f7e0..cdb8e63b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -268,7 +268,7 @@ def make_response(status_code=200, json_payload=None, text=None, reason=None, he ) -def read_delta_lake(lake_path: str, *, version: int = None) -> list[dict]: +def read_delta_lake(lake_path: str, *, version: int | None = None) -> list[dict]: """ Reads in a delta lake folder at a certain time, sorted by id. @@ -294,7 +294,7 @@ def read_delta_lake(lake_path: str, *, version: int = None) -> list[dict]: @contextlib.contextmanager -def time_it(desc: str = None): +def time_it(desc: str | None = None): """Tiny little timer context manager that is useful when debugging""" start = time.perf_counter() yield @@ -304,7 +304,7 @@ def time_it(desc: str = None): @contextlib.contextmanager -def mem_it(desc: str = None): +def mem_it(desc: str | None = None): """Tiny little context manager to measure memory usage""" start_tracing = not tracemalloc.is_tracing() if start_tracing: From e436e6ad07860e18045647b7565606db21e6de1b Mon Sep 17 00:00:00 2001 From: Michael Terry Date: Mon, 29 Jul 2024 13:23:15 -0400 Subject: [PATCH 2/3] style: drop from 120 to 100 cols wide --- cumulus_etl/cli.py | 5 +- cumulus_etl/cli_utils.py | 27 ++++- cumulus_etl/common.py | 12 +- cumulus_etl/deid/codebook.py | 20 +++- cumulus_etl/deid/mstool.py | 14 ++- cumulus_etl/deid/scrubber.py | 29 ++++- cumulus_etl/errors.py | 4 +- cumulus_etl/etl/cli.py | 54 +++++++-- cumulus_etl/etl/context.py | 3 +- cumulus_etl/etl/convert/cli.py | 7 +- .../etl/studies/covid_symptom/covid_ctakes.py | 19 ++- .../etl/studies/covid_symptom/covid_tasks.py | 8 +- cumulus_etl/etl/tasks/base.py | 25 ++-- cumulus_etl/etl/tasks/basic_tasks.py | 8 +- cumulus_etl/etl/tasks/batching.py | 7 +- cumulus_etl/etl/tasks/nlp_task.py | 5 +- cumulus_etl/etl/tasks/task_factory.py | 5 +- cumulus_etl/fhir/fhir_auth.py | 6 +- cumulus_etl/fhir/fhir_client.py | 27 ++++- cumulus_etl/formats/batch.py | 4 +- cumulus_etl/formats/deltalake.py | 21 +++- cumulus_etl/loaders/fhir/bulk_export.py | 13 +- cumulus_etl/loaders/fhir/export_log.py | 4 +- cumulus_etl/loaders/fhir/ndjson_loader.py | 3 +- cumulus_etl/loaders/i2b2/loader.py | 29 +++-- cumulus_etl/loaders/i2b2/oracle/connect.py | 3 +- cumulus_etl/loaders/i2b2/oracle/query.py | 5 +- cumulus_etl/loaders/i2b2/transform.py | 56 +++++++-- cumulus_etl/nlp/__init__.py | 7 +- cumulus_etl/nlp/extract.py | 4 +- cumulus_etl/nlp/huggingface.py | 7 +- cumulus_etl/nlp/utils.py | 9 +- cumulus_etl/nlp/watcher.py | 3 +- cumulus_etl/upload_notes/cli.py | 81 ++++++++++--- cumulus_etl/upload_notes/downloader.py | 16 ++- cumulus_etl/upload_notes/labelstudio.py | 36 ++++-- cumulus_etl/upload_notes/selector.py | 7 +- pyproject.toml | 11 +- tests/convert/test_convert_cli.py | 54 ++++++--- tests/covid_symptom/test_covid_results.py | 43 +++++-- tests/ctakesmock.py | 66 +++++++++-- tests/deid/test_deid_codebook.py | 12 +- tests/deid/test_deid_philter.py | 20 +++- tests/deid/test_deid_scrubber.py | 22 +++- tests/etl/base.py | 7 +- tests/etl/test_etl_cli.py | 82 +++++++++---- tests/etl/test_etl_context.py | 8 +- tests/etl/test_tasks.py | 89 ++++++++++---- tests/fhir/test_fhir_client.py | 81 ++++++++++--- 
tests/fhir/test_fhir_utils.py | 27 ++++- tests/formats/test_deltalake.py | 34 +++++- tests/hftest/test_hftask.py | 8 +- tests/loaders/i2b2/test_i2b2_etl.py | 4 +- .../loaders/i2b2/test_i2b2_oracle_connect.py | 4 +- .../loaders/i2b2/test_i2b2_oracle_extract.py | 18 ++- tests/loaders/i2b2/test_i2b2_oracle_query.py | 4 +- tests/loaders/i2b2/test_i2b2_transform.py | 28 ++++- tests/loaders/ndjson/test_bulk_export.py | 111 ++++++++++++++---- tests/loaders/ndjson/test_ndjson_loader.py | 27 +++-- tests/test_common.py | 8 +- tests/upload_notes/test_upload_cli.py | 60 ++++++++-- tests/upload_notes/test_upload_labelstudio.py | 14 ++- tests/utils.py | 53 +++++---- 63 files changed, 1140 insertions(+), 348 deletions(-) diff --git a/cumulus_etl/cli.py b/cumulus_etl/cli.py index c5a1654b..56b6c264 100644 --- a/cumulus_etl/cli.py +++ b/cumulus_etl/cli.py @@ -43,7 +43,8 @@ def get_subcommand(argv: list[str]) -> str | None: if arg in Command.values(): return argv.pop(i) # remove it to make later parsers' jobs easier elif not arg.startswith("-"): - return None # first positional arg did not match a known command, assume default command + # first positional arg did not match a known command, assume default command + return None async def main(argv: list[str]) -> None: @@ -71,7 +72,7 @@ async def main(argv: list[str]) -> None: if not subcommand: # Add a note about other subcommands we offer, and tell argparse not to wrap our formatting parser.formatter_class = argparse.RawDescriptionHelpFormatter - parser.description += "\n\n" "other commands available:\n" " convert\n" " upload-notes" + parser.description += "\n\nother commands available:\n convert\n upload-notes" run_method = etl.run_etl with tempfile.TemporaryDirectory() as tempdir: diff --git a/cumulus_etl/cli_utils.py b/cumulus_etl/cli_utils.py index 72873370..1ba5bea2 100644 --- a/cumulus_etl/cli_utils.py +++ b/cumulus_etl/cli_utils.py @@ -16,16 +16,28 @@ def add_auth(parser: argparse.ArgumentParser) -> None: group.add_argument("--smart-client-id", metavar="ID", help="Client ID for SMART authentication") group.add_argument("--smart-jwks", metavar="PATH", help="JWKS file for SMART authentication") group.add_argument("--basic-user", metavar="USER", help="Username for Basic authentication") - group.add_argument("--basic-passwd", metavar="PATH", help="Password file for Basic authentication") - group.add_argument("--bearer-token", metavar="PATH", help="Token file for Bearer authentication") - group.add_argument("--fhir-url", metavar="URL", help="FHIR server base URL, only needed if you exported separately") + group.add_argument( + "--basic-passwd", metavar="PATH", help="Password file for Basic authentication" + ) + group.add_argument( + "--bearer-token", metavar="PATH", help="Token file for Bearer authentication" + ) + group.add_argument( + "--fhir-url", + metavar="URL", + help="FHIR server base URL, only needed if you exported separately", + ) def add_aws(parser: argparse.ArgumentParser) -> None: group = parser.add_argument_group("AWS") - group.add_argument("--s3-region", metavar="REGION", help="If using S3 paths (s3://...), this is their region") group.add_argument( - "--s3-kms-key", metavar="KEY", help="If using S3 paths (s3://...), this is the KMS key ID to use" + "--s3-region", metavar="REGION", help="If using S3 paths (s3://...), this is their region" + ) + group.add_argument( + "--s3-kms-key", + metavar="KEY", + help="If using S3 paths (s3://...), this is the KMS key ID to use", ) @@ -57,7 +69,10 @@ def make_export_dir(export_to: str | None = None) -> 
common.Directory: if urllib.parse.urlparse(export_to).netloc: # We require a local folder because that's all that the MS deid tool can operate on. # If we were to relax this requirement, we'd want to copy the exported files over to a local dir. - errors.fatal(f"The target export folder '{export_to}' must be local. ", errors.BULK_EXPORT_FOLDER_NOT_LOCAL) + errors.fatal( + f"The target export folder '{export_to}' must be local. ", + errors.BULK_EXPORT_FOLDER_NOT_LOCAL, + ) confirm_dir_is_empty(store.Root(export_to, create=True)) diff --git a/cumulus_etl/common.py b/cumulus_etl/common.py index 30c71059..999e8eb7 100644 --- a/cumulus_etl/common.py +++ b/cumulus_etl/common.py @@ -76,7 +76,9 @@ def get_temp_dir(subdir: str) -> str: def ls_resources(root: store.Root, resources: set[str], warn_if_empty: bool = False) -> list[str]: - found_files = cumulus_fhir_support.list_multiline_json_in_dir(root.path, resources, fsspec_fs=root.fs) + found_files = cumulus_fhir_support.list_multiline_json_in_dir( + root.path, resources, fsspec_fs=root.fs + ) if warn_if_empty: # Invert the {path: type} found_files dictionary into {type: [paths...]} @@ -175,7 +177,9 @@ def read_ndjson(root: store.Root, path: str) -> Iterator[dict]: yield from cumulus_fhir_support.read_multiline_json(path, fsspec_fs=root.fs) -def read_resource_ndjson(root: store.Root, resource: str, warn_if_empty: bool = False) -> Iterator[dict]: +def read_resource_ndjson( + root: store.Root, resource: str, warn_if_empty: bool = False +) -> Iterator[dict]: """ Grabs all ndjson files from a folder, of a particular resource type. """ @@ -239,7 +243,9 @@ def read_local_line_count(path) -> int: count = 0 buf = None with open(path, "rb") as f: - bufgen = itertools.takewhile(lambda x: x, (f.raw.read(1024 * 1024) for _ in itertools.repeat(None))) + bufgen = itertools.takewhile( + lambda x: x, (f.raw.read(1024 * 1024) for _ in itertools.repeat(None)) + ) for buf in bufgen: count += buf.count(b"\n") if buf and buf[-1] != ord("\n"): # catch a final line without a trailing newline diff --git a/cumulus_etl/deid/codebook.py b/cumulus_etl/deid/codebook.py index 7e8cfcc6..2dafbbd8 100644 --- a/cumulus_etl/deid/codebook.py +++ b/cumulus_etl/deid/codebook.py @@ -70,7 +70,9 @@ def real_ids(self, resource_type: str, fake_ids: Iterable[str]) -> Iterator[str] if real_id: yield real_id else: - logging.warning("Real ID not found for anonymous %s ID %s. Ignoring.", resource_type, fake_id) + logging.warning( + "Real ID not found for anonymous %s ID %s. Ignoring.", resource_type, fake_id + ) ############################################################################### @@ -110,7 +112,9 @@ def __init__(self, codebook_dir: str | None = None): if codebook_dir: self._load_saved_settings(common.read_json(os.path.join(codebook_dir, "codebook.json"))) try: - self.cached_mapping = common.read_json(os.path.join(codebook_dir, "codebook-cached-mappings.json")) + self.cached_mapping = common.read_json( + os.path.join(codebook_dir, "codebook-cached-mappings.json") + ) except (FileNotFoundError, PermissionError): pass @@ -145,7 +149,9 @@ def encounter(self, real_id: str, cache_mapping: bool = True) -> str: """ return self._preserved_resource_hash("Encounter", real_id, cache_mapping) - def _preserved_resource_hash(self, resource_type: str, real_id: str, cache_mapping: bool) -> str: + def _preserved_resource_hash( + self, resource_type: str, real_id: str, cache_mapping: bool + ) -> str: """ Get a hashed ID and preserve the mapping. 
@@ -170,7 +176,10 @@ def _preserved_resource_hash(self, resource_type: str, real_id: str, cache_mappi # Save this generated ID mapping so that we can store it for debugging purposes later. # Only save if we don't have a legacy mapping, so that we don't have both in memory at the same time. - if cache_mapping and self.cached_mapping.setdefault(resource_type, {}).get(real_id) != fake_id: + if ( + cache_mapping + and self.cached_mapping.setdefault(resource_type, {}).get(real_id) != fake_id + ): # We expect the IDs to always be identical. The above check is mostly concerned with None != fake_id, # but is written defensively in case a bad mapping got saved for some reason. self.cached_mapping[resource_type][real_id] = fake_id @@ -206,7 +215,8 @@ def resource_hash(self, real_id: str) -> str: def _id_salt(self) -> bytes: """Returns the saved salt or creates and saves one if needed""" salt = self.settings["id_salt"] - return binascii.unhexlify(salt) # revert from doubled hex 64-char string representation back to just 32 bytes + # revert from doubled hex 64-char string representation back to just 32 bytes + return binascii.unhexlify(salt) def _load_saved_settings(self, saved: dict) -> None: """ diff --git a/cumulus_etl/deid/mstool.py b/cumulus_etl/deid/mstool.py index a6d34fdf..e7a1d10d 100644 --- a/cumulus_etl/deid/mstool.py +++ b/cumulus_etl/deid/mstool.py @@ -38,12 +38,15 @@ async def run_mstool(input_dir: str, output_dir: str) -> None: if process.returncode != 0: print( - f"An error occurred while de-identifying the input resources:\n\n{stderr.decode('utf8')}", file=sys.stderr + f"An error occurred while de-identifying the input resources:\n\n{stderr.decode('utf8')}", + file=sys.stderr, ) raise SystemExit(errors.MSTOOL_FAILED) -async def _wait_for_completion(process: asyncio.subprocess.Process, input_dir: str, output_dir: str) -> (str, str): +async def _wait_for_completion( + process: asyncio.subprocess.Process, input_dir: str, output_dir: str +) -> (str, str): """Waits for the MS tool to finish, with a nice little progress bar, returns stdout and stderr""" stdout, stderr = None, None @@ -74,7 +77,8 @@ def _compare_file_sizes(target: dict[str, int], current: dict[str, int]) -> floa total_current = 0 for filename, size in current.items(): if filename in target: - total_current += target[filename] # use target size, because current (de-identified) files will be smaller + # use target size, because current (de-identified) files will be smaller + total_current += target[filename] else: # an in-progress file is being written out total_current += size return total_current / total_expected @@ -93,4 +97,6 @@ def _get_file_size_safe(path: str) -> int: def _count_file_sizes(pattern: str) -> dict[str, int]: """Returns all files that match the given pattern and their sizes""" - return {os.path.basename(filename): _get_file_size_safe(filename) for filename in glob.glob(pattern)} + return { + os.path.basename(filename): _get_file_size_safe(filename) for filename in glob.glob(pattern) + } diff --git a/cumulus_etl/deid/scrubber.py b/cumulus_etl/deid/scrubber.py index ac4fa662..6a75b795 100644 --- a/cumulus_etl/deid/scrubber.py +++ b/cumulus_etl/deid/scrubber.py @@ -60,7 +60,9 @@ def scrub_resource(self, node: dict, scrub_attachments: bool = True) -> bool: :returns: whether this resource is allowed to be used """ try: - self._scrub_node(node.get("resourceType"), "root", node, scrub_attachments=scrub_attachments) + self._scrub_node( + node.get("resourceType"), "root", node, scrub_attachments=scrub_attachments + 
) except SkipResource as exc: logging.warning("Ignoring resource of type %s: %s", node.__class__.__name__, exc) return False @@ -90,7 +92,9 @@ def save(self) -> None: # ############################################################################### - def _scrub_node(self, resource_type: str, node_path: str, node: dict, scrub_attachments: bool) -> None: + def _scrub_node( + self, resource_type: str, node_path: str, node: dict, scrub_attachments: bool + ) -> None: """Examines all properties of a node""" for key, values in list(node.items()): if values is None: @@ -106,7 +110,13 @@ def _scrub_node(self, resource_type: str, node_path: str, node: dict, scrub_atta ) def _scrub_single_value( - self, resource_type: str, node_path: str, node: dict, key: str, value: Any, scrub_attachments: bool + self, + resource_type: str, + node_path: str, + node: dict, + key: str, + value: Any, + scrub_attachments: bool, ) -> None: """Examines one single property of a node""" # For now, just manually run each operation. If this grows further, we can abstract it more. @@ -119,7 +129,9 @@ def _scrub_single_value( # Recurse if we are holding another FHIR object (i.e. a dict instead of a string) if isinstance(value, dict): - self._scrub_node(resource_type, f"{node_path}.{key}", value, scrub_attachments=scrub_attachments) + self._scrub_node( + resource_type, f"{node_path}.{key}", value, scrub_attachments=scrub_attachments + ) ############################################################################### # @@ -188,7 +200,11 @@ def _check_text(self, node: dict, key: str, value: Any): @staticmethod def _check_attachments(resource_type: str, node_path: str, node: dict, key: str) -> None: """Strip any attachment data""" - if resource_type == "DocumentReference" and node_path == "root.content.attachment" and key in {"data", "url"}: + if ( + resource_type == "DocumentReference" + and node_path == "root.content.attachment" + and key in {"data", "url"} + ): del node[key] @staticmethod @@ -200,7 +216,8 @@ def _check_security(node_path: str, node: dict, key: str, value: Any) -> None: """ if node_path == "root" and key == "meta": if "security" in value: - del value["security"] # maybe too aggressive -- is there data we care about in meta.security? + # maybe too aggressive -- is there data we care about in meta.security? + del value["security"] # If we wiped out the only content in Meta, remove it so as not to confuse downstream bits like parquet # writers which try to infer values from an empty struct and fail. 
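For readers skimming the scrubber hunks above, the calls being rewrapped all feed one recursive walk: `_scrub_node` descends through a FHIR resource dict and `_check_attachments` drops inline attachment data when the walk reaches `root.content.attachment`. Below is a minimal standalone sketch of that pattern, assuming a hypothetical helper named `strip_attachments` and a toy DocumentReference; it is illustrative only and is not the ETL's actual Scrubber API.

def strip_attachments(node: dict, path: str = "root") -> None:
    """Recursively remove inline attachment data/url fields (sketch of the walk above)."""
    for key, value in list(node.items()):
        if path == "root.content.attachment" and key in {"data", "url"}:
            del node[key]  # same condition the real _check_attachments uses
        elif isinstance(value, dict):
            strip_attachments(value, f"{path}.{key}")
        elif isinstance(value, list):
            for item in value:
                if isinstance(item, dict):
                    strip_attachments(item, f"{path}.{key}")


docref = {
    "resourceType": "DocumentReference",
    "content": [{"attachment": {"contentType": "text/plain", "data": "aGVsbG8="}}],
}
strip_attachments(docref)
# docref["content"][0]["attachment"] is now just {"contentType": "text/plain"}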
diff --git a/cumulus_etl/errors.py b/cumulus_etl/errors.py index 994c3eeb..b3e6bf95 100644 --- a/cumulus_etl/errors.py +++ b/cumulus_etl/errors.py @@ -62,7 +62,9 @@ class FhirAuthMissing(FhirConnectionConfigError): """We needed to connect to a FHIR server but no authentication config was provided""" def __init__(self): - super().__init__("Could not download some files without authentication parameters (see --help)") + super().__init__( + "Could not download some files without authentication parameters (see --help)" + ) def fatal(message: str, status: int) -> NoReturn: diff --git a/cumulus_etl/etl/cli.py b/cumulus_etl/etl/cli.py index bc836ae7..e6d5af94 100644 --- a/cumulus_etl/etl/cli.py +++ b/cumulus_etl/etl/cli.py @@ -94,7 +94,10 @@ def define_etl_parser(parser: argparse.ArgumentParser) -> None: parser.add_argument("dir_output", metavar="/path/to/output") parser.add_argument("dir_phi", metavar="/path/to/phi") parser.add_argument( - "--input-format", default="ndjson", choices=["i2b2", "ndjson"], help="input format (default is ndjson)" + "--input-format", + default="ndjson", + choices=["i2b2", "ndjson"], + help="input format (default is ndjson)", ) parser.add_argument( "--output-format", @@ -107,18 +110,25 @@ def define_etl_parser(parser: argparse.ArgumentParser) -> None: type=int, metavar="SIZE", default=200000, - help="how many entries to process at once and thus " "how many to put in one output file (default is 200k)", + help="how many entries to process at once and thus " + "how many to put in one output file (default is 200k)", ) parser.add_argument("--comment", help="add the comment to the log file") - parser.add_argument("--philter", action="store_true", help="run philter on all freeform text fields") - parser.add_argument("--errors-to", metavar="DIR", help="where to put resources that could not be processed") + parser.add_argument( + "--philter", action="store_true", help="run philter on all freeform text fields" + ) + parser.add_argument( + "--errors-to", metavar="DIR", help="where to put resources that could not be processed" + ) cli_utils.add_aws(parser) cli_utils.add_auth(parser) export = parser.add_argument_group("bulk export") export.add_argument( - "--export-to", metavar="DIR", help="Where to put exported files (default is to delete after use)" + "--export-to", + metavar="DIR", + help="Where to put exported files (default is to delete after use)", ) export.add_argument("--since", help="Start date for export from the FHIR server") export.add_argument("--until", help="End date for export from the FHIR server") @@ -127,12 +137,16 @@ def define_etl_parser(parser: argparse.ArgumentParser) -> None: group.add_argument("--export-group", help=argparse.SUPPRESS) group.add_argument("--export-timestamp", help=argparse.SUPPRESS) # Temporary explicit opt-in flag during the development of the completion-tracking feature - group.add_argument("--write-completion", action="store_true", default=False, help=argparse.SUPPRESS) + group.add_argument( + "--write-completion", action="store_true", default=False, help=argparse.SUPPRESS + ) cli_utils.add_nlp(parser) task = parser.add_argument_group("task selection") - task.add_argument("--task", action="append", help="Only update the given output tables (comma separated)") + task.add_argument( + "--task", action="append", help="Only update the given output tables (comma separated)" + ) task.add_argument( "--task-filter", action="append", @@ -143,7 +157,9 @@ def define_etl_parser(parser: argparse.ArgumentParser) -> None: cli_utils.add_debugging(parser) 
-def print_config(args: argparse.Namespace, job_datetime: datetime.datetime, all_tasks: Iterable[tasks.EtlTask]) -> None: +def print_config( + args: argparse.Namespace, job_datetime: datetime.datetime, all_tasks: Iterable[tasks.EtlTask] +) -> None: """ Prints the ETL configuration to the console. @@ -185,12 +201,16 @@ def print_config(args: argparse.Namespace, job_datetime: datetime.datetime, all_ rich.get_console().print(table) -def handle_completion_args(args: argparse.Namespace, loader: loaders.Loader) -> (str, datetime.datetime): +def handle_completion_args( + args: argparse.Namespace, loader: loaders.Loader +) -> (str, datetime.datetime): """Returns (group_name, datetime)""" # Grab completion options from CLI or loader export_group_name = args.export_group or loader.group_name export_datetime = ( - datetime.datetime.fromisoformat(args.export_timestamp) if args.export_timestamp else loader.export_datetime + datetime.datetime.fromisoformat(args.export_timestamp) + if args.export_timestamp + else loader.export_datetime ) # Disable entirely if asked to @@ -211,7 +231,9 @@ def handle_completion_args(args: argparse.Namespace, loader: loaders.Loader) -> async def etl_main(args: argparse.Namespace) -> None: # Set up some common variables - store.set_user_fs_options(vars(args)) # record filesystem options like --s3-region before creating Roots + + # record filesystem options like --s3-region before creating Roots + store.set_user_fs_options(vars(args)) root_input = store.Root(args.dir_input) root_phi = store.Root(args.dir_phi, create=True) @@ -221,7 +243,9 @@ async def etl_main(args: argparse.Namespace) -> None: # Check which tasks are being run, allowing comma-separated values task_names = args.task and set(itertools.chain.from_iterable(t.split(",") for t in args.task)) - task_filters = args.task_filter and list(itertools.chain.from_iterable(t.split(",") for t in args.task_filter)) + task_filters = args.task_filter and list( + itertools.chain.from_iterable(t.split(",") for t in args.task_filter) + ) selected_tasks = task_factory.get_selected_tasks(task_names, task_filters) # Print configuration @@ -249,7 +273,11 @@ async def etl_main(args: argparse.Namespace) -> None: config_loader = loaders.I2b2Loader(root_input, export_to=args.export_to) else: config_loader = loaders.FhirNdjsonLoader( - root_input, client=client, export_to=args.export_to, since=args.since, until=args.until + root_input, + client=client, + export_to=args.export_to, + since=args.since, + until=args.until, ) # Pull down resources from any remote location (like s3), convert from i2b2, or do a bulk export diff --git a/cumulus_etl/etl/context.py b/cumulus_etl/etl/context.py index c1b146e0..7258ff46 100644 --- a/cumulus_etl/etl/context.py +++ b/cumulus_etl/etl/context.py @@ -64,7 +64,8 @@ def last_successful_output_dir(self, value: str) -> None: self._data[self._LAST_SUCCESSFUL_OUTPUT_DIR] = value def save(self) -> None: - common.write_json(self._path, self.as_json(), indent=4) # pretty-print this since it isn't large + # pretty-print this since it isn't large + common.write_json(self._path, self.as_json(), indent=4) def as_json(self) -> dict: return dict(self._data) diff --git a/cumulus_etl/etl/convert/cli.py b/cumulus_etl/etl/convert/cli.py index 5f63707f..ffc2957f 100644 --- a/cumulus_etl/etl/convert/cli.py +++ b/cumulus_etl/etl/convert/cli.py @@ -127,7 +127,9 @@ def copy_job_configs(input_root: store.Root, output_root: store.Root) -> None: output_root.put(job_config_path, output_root.path, recursive=True) -def 
walk_tree(input_root: store.Root, output_root: store.Root, formatter_class: type[formats.Format]) -> None: +def walk_tree( + input_root: store.Root, output_root: store.Root, formatter_class: type[formats.Format] +) -> None: all_tasks = task_factory.get_all_tasks() with cli_utils.make_progress_bar() as progress: @@ -182,7 +184,8 @@ def define_convert_parser(parser: argparse.ArgumentParser) -> None: async def convert_main(args: argparse.Namespace) -> None: """Main logic for converting""" - store.set_user_fs_options(vars(args)) # record filesystem options like --s3-region before creating Roots + # record filesystem options like --s3-region before creating Roots + store.set_user_fs_options(vars(args)) input_root = store.Root(args.input_dir) validate_input_dir(input_root) diff --git a/cumulus_etl/etl/studies/covid_symptom/covid_ctakes.py b/cumulus_etl/etl/studies/covid_symptom/covid_ctakes.py index 0e618fc8..7fe4dd18 100644 --- a/cumulus_etl/etl/studies/covid_symptom/covid_ctakes.py +++ b/cumulus_etl/etl/studies/covid_symptom/covid_ctakes.py @@ -60,9 +60,13 @@ async def covid_symptoms_extract( timestamp = common.datetime_now().isoformat() try: - ctakes_json = await nlp.ctakes_extract(cache, ctakes_namespace, clinical_note, client=ctakes_http_client) + ctakes_json = await nlp.ctakes_extract( + cache, ctakes_namespace, clinical_note, client=ctakes_http_client + ) except Exception as exc: # pylint: disable=broad-except - logging.warning("Could not extract symptoms for docref %s (%s): %s", docref_id, type(exc).__name__, exc) + logging.warning( + "Could not extract symptoms for docref %s (%s): %s", docref_id, type(exc).__name__, exc + ) return None matches = ctakes_json.list_sign_symptom(ctakesclient.typesystem.Polarity.pos) @@ -85,10 +89,17 @@ def is_covid_match(m: ctakesclient.typesystem.MatchText): try: spans = ctakes_json.list_spans(matches) polarities_cnlp = await nlp.list_polarity( - cache, cnlp_namespace, clinical_note, spans, model=polarity_model, client=cnlp_http_client + cache, + cnlp_namespace, + clinical_note, + spans, + model=polarity_model, + client=cnlp_http_client, ) except Exception as exc: # pylint: disable=broad-except - logging.warning("Could not check polarity for docref %s (%s): %s", docref_id, type(exc).__name__, exc) + logging.warning( + "Could not check polarity for docref %s (%s): %s", docref_id, type(exc).__name__, exc + ) return None # Helper to make a single row (match_value is None if there were no found symptoms at all) diff --git a/cumulus_etl/etl/studies/covid_symptom/covid_tasks.py b/cumulus_etl/etl/studies/covid_symptom/covid_tasks.py index 31574623..166b3de6 100644 --- a/cumulus_etl/etl/studies/covid_symptom/covid_tasks.py +++ b/cumulus_etl/etl/studies/covid_symptom/covid_tasks.py @@ -67,7 +67,9 @@ def is_ed_coding(coding): def is_ed_docref(docref): """Returns true if this is a coding for an emergency department note""" # We check both type and category for safety -- we aren't sure yet how EHRs are using these fields. 
- codings = list(itertools.chain.from_iterable([cat.get("coding", []) for cat in docref.get("category", [])])) + codings = list( + itertools.chain.from_iterable([cat.get("coding", []) for cat in docref.get("category", [])]) + ) codings += docref.get("type", {}).get("coding", []) return any(is_ed_coding(x) for x in codings) @@ -126,7 +128,9 @@ async def read_entries(self, *, progress: rich.progress.Progress = None) -> task # one client for both NLP services for now -- no parallel requests yet, so no need to be fancy http_client = nlp.ctakes_httpx_client() - async for orig_docref, docref, clinical_note in self.read_notes(progress=progress, doc_check=is_ed_docref): + async for orig_docref, docref, clinical_note in self.read_notes( + progress=progress, doc_check=is_ed_docref + ): symptoms = await covid_ctakes.covid_symptoms_extract( phi_root, docref, diff --git a/cumulus_etl/etl/tasks/base.py b/cumulus_etl/etl/tasks/base.py index 9d93204c..f4b39aeb 100644 --- a/cumulus_etl/etl/tasks/base.py +++ b/cumulus_etl/etl/tasks/base.py @@ -107,8 +107,11 @@ def __init__(self, task_config: config.JobConfig, scrubber: deid.Scrubber): assert self.resource # noqa: S101 self.task_config = task_config self.scrubber = scrubber - self.formatters: list[formats.Format | None] = [None] * len(self.outputs) # create format placeholders - self.summaries: list[config.JobSummary] = [config.JobSummary(output.get_name(self)) for output in self.outputs] + # create format placeholders + self.formatters: list[formats.Format | None] = [None] * len(self.outputs) + self.summaries: list[config.JobSummary] = [ + config.JobSummary(output.get_name(self)) for output in self.outputs + ] self.completion_tracking_enabled = ( self.task_config.export_group_name is not None and self.task_config.export_datetime ) @@ -155,7 +158,9 @@ async def run(self) -> list[config.JobSummary]: return self.summaries @classmethod - def make_batch_from_rows(cls, resource_type: str | None, rows: list[dict], groups: set[str] | None = None): + def make_batch_from_rows( + cls, resource_type: str | None, rows: list[dict], groups: set[str] | None = None + ): schema = cls.get_schema(resource_type, rows) return formats.Batch(rows, groups=groups, schema=schema) @@ -178,7 +183,9 @@ async def _write_tables_in_batches( def update_status(): status.plain = "\n".join( - f"{x.success:,} written to {x.label}" for i, x in enumerate(self.summaries) if self.outputs[i].visible + f"{x.success:,} written to {x.label}" + for i, x in enumerate(self.summaries) + if self.outputs[i].visible ) batch_index = 0 @@ -187,8 +194,11 @@ def update_status(): async for batches in batching.batch_iterate(entries, self.task_config.batch_size): if format_progress_task is not None: - progress.update(format_progress_task, visible=False) # hide old batches, to save screen space - format_progress_task = progress.add_task(f"Writing batch {batch_index + 1:,}", total=None) + # hide old batches, to save screen space + progress.update(format_progress_task, visible=False) + format_progress_task = progress.add_task( + f"Writing batch {batch_index + 1:,}", total=None + ) # Batches is a tuple of lists of resources - the tuple almost never matters, but it is there in case the # task is generating multiple types of resources. Like MedicationRequest creating Medications as it goes. 
@@ -215,7 +225,8 @@ def _touch_remaining_tables(self): """Writes empty dataframe to any table we haven't written to yet""" for table_index, formatter in enumerate(self.formatters): if formatter is None: # No data got written yet - self._write_one_table_batch([], table_index, 0) # just write an empty dataframe (should be fast) + # just write an empty dataframe (should be fast) + self._write_one_table_batch([], table_index, 0) def _update_completion_table(self) -> None: # TODO: what about empty sets - do we assume the export gave 0 results or skip it? diff --git a/cumulus_etl/etl/tasks/basic_tasks.py b/cumulus_etl/etl/tasks/basic_tasks.py index d69af758..009b5029 100644 --- a/cumulus_etl/etl/tasks/basic_tasks.py +++ b/cumulus_etl/etl/tasks/basic_tasks.py @@ -124,7 +124,9 @@ def scrub_medication(self, medication: dict | None) -> bool: # Since Medications are not patient-specific, we don't need the full MS treatment. # But still, we should probably drop some bits that might more easily identify the *institution*. # This is a poor-man's MS config tool (and a blocklist rather than allow-list, but it's a very simple resource) - medication.pop("extension", None) # *should* remove at all layers, but this will catch 99% of them + + # *should* remove extensions at all layers, but this will catch 99% of them + medication.pop("extension", None) medication.pop("identifier", None) medication.pop("text", None) # Leave batch.lotNumber freeform text in place, it might be useful for quality control @@ -152,7 +154,9 @@ async def fetch_medication(self, resource: dict) -> dict | None: self.summaries[1].had_errors = True if self.task_config.dir_errors: - error_root = store.Root(os.path.join(self.task_config.dir_errors, self.name), create=True) + error_root = store.Root( + os.path.join(self.task_config.dir_errors, self.name), create=True + ) error_path = error_root.joinpath("medication-fetch-errors.ndjson") with common.NdjsonWriter(error_path, "a") as writer: writer.write(resource) diff --git a/cumulus_etl/etl/tasks/batching.py b/cumulus_etl/etl/tasks/batching.py index d0abf5c2..ac9d6585 100644 --- a/cumulus_etl/etl/tasks/batching.py +++ b/cumulus_etl/etl/tasks/batching.py @@ -41,7 +41,9 @@ async def _batch_slice(iterable: AsyncIterable[AtomStreams], n: int) -> ItemBatc return slices -async def batch_iterate(iterable: AsyncIterable[AtomStreams], size: int) -> AsyncIterator[ItemBatches]: +async def batch_iterate( + iterable: AsyncIterable[AtomStreams], size: int +) -> AsyncIterator[ItemBatches]: """ Yields sub-iterators, each roughly {size} elements from iterable. 
@@ -64,6 +66,7 @@ async def batch_iterate(iterable: AsyncIterable[AtomStreams], size: int) -> Asyn if size < 1: raise ValueError("Must iterate by at least a batch of 1") - true_iterable = aiter(iterable) # get a real once-through iterable (we want to iterate only once) + # get a real once-through iterable (we want to iterate only once) + true_iterable = aiter(iterable) while batches := await _batch_slice(true_iterable, size): yield batches diff --git a/cumulus_etl/etl/tasks/nlp_task.py b/cumulus_etl/etl/tasks/nlp_task.py index 82406d82..ae0fdb05 100644 --- a/cumulus_etl/etl/tasks/nlp_task.py +++ b/cumulus_etl/etl/tasks/nlp_task.py @@ -61,7 +61,10 @@ def add_error(self, docref: dict) -> None: writer.write(docref) async def read_notes( - self, *, doc_check: Callable[[dict], bool] | None = None, progress: rich.progress.Progress = None + self, + *, + doc_check: Callable[[dict], bool] | None = None, + progress: rich.progress.Progress = None, ) -> (dict, dict, str): """ Iterate through clinical notes. diff --git a/cumulus_etl/etl/tasks/task_factory.py b/cumulus_etl/etl/tasks/task_factory.py index 579c2bc0..5b0db1bc 100644 --- a/cumulus_etl/etl/tasks/task_factory.py +++ b/cumulus_etl/etl/tasks/task_factory.py @@ -89,7 +89,10 @@ def get_selected_tasks( all_task_names = {t.name for t in all_tasks} if unknown_names := names - all_task_names: print_names = "\n".join(sorted(f" {key}" for key in all_task_names)) - print(f"Unknown task '{unknown_names.pop()}' requested. Valid task names:\n{print_names}", file=sys.stderr) + print( + f"Unknown task '{unknown_names.pop()}' requested. Valid task names:\n{print_names}", + file=sys.stderr, + ) raise SystemExit(errors.TASK_UNKNOWN) # Check for names that conflict with the chosen filters diff --git a/cumulus_etl/fhir/fhir_auth.py b/cumulus_etl/fhir/fhir_auth.py index a1ce2215..a26f2925 100644 --- a/cumulus_etl/fhir/fhir_auth.py +++ b/cumulus_etl/fhir/fhir_auth.py @@ -85,7 +85,9 @@ async def authorize(self, session: httpx.AsyncClient, reauthorize=False) -> None if not message: message = str(exc) - errors.fatal(f"Could not authenticate with the FHIR server: {message}", errors.FHIR_AUTH_FAILED) + errors.fatal( + f"Could not authenticate with the FHIR server: {message}", errors.FHIR_AUTH_FAILED + ) def sign_headers(self, headers: dict) -> dict: """Add signature token to request headers""" @@ -202,7 +204,7 @@ def create_auth( smart_jwks: dict | None, ) -> Auth: """Determine which auth method to use based on user provided arguments""" - valid_smart_jwks = smart_jwks is not None # compared to a falsy (but technically usable) empty dict for example + valid_smart_jwks = smart_jwks is not None # Check if the user tried to specify multiple types of auth, and help them out has_basic_args = bool(basic_user or basic_password) diff --git a/cumulus_etl/fhir/fhir_client.py b/cumulus_etl/fhir/fhir_client.py index 278c0c85..088bc093 100644 --- a/cumulus_etl/fhir/fhir_client.py +++ b/cumulus_etl/fhir/fhir_client.py @@ -53,12 +53,19 @@ def __init__( """ self._server_root = url # all requests are relative to this URL if self._server_root and not self._server_root.endswith("/"): - self._server_root += "/" # This will ensure the last segment does not get chopped off by urljoin + # This will ensure the last segment does not get chopped off by urljoin + self._server_root += "/" self._client_id = smart_client_id self._server_type = ServerType.UNKNOWN self._auth = fhir_auth.create_auth( - self._server_root, resources, basic_user, basic_password, bearer_token, smart_client_id, smart_jwks + 
self._server_root, + resources, + basic_user, + basic_password, + bearer_token, + smart_client_id, + smart_jwks, ) self._session: httpx.AsyncClient | None = None self._capabilities: dict = {} @@ -106,14 +113,18 @@ async def request( # merge in user headers with defaults final_headers.update(headers or {}) - response = await self._request_with_signed_headers(method, url, final_headers, stream=stream) + response = await self._request_with_signed_headers( + method, url, final_headers, stream=stream + ) # Check if our access token expired and thus needs to be refreshed if response.status_code == 401: await self._auth.authorize(self._session, reauthorize=True) if stream: await response.aclose() - response = await self._request_with_signed_headers(method, url, final_headers, stream=stream) + response = await self._request_with_signed_headers( + method, url, final_headers, stream=stream + ) try: response.raise_for_status() @@ -201,7 +212,9 @@ async def _read_capabilities(self) -> None: self._capabilities = capabilities - async def _request_with_signed_headers(self, method: str, url: str, headers: dict, **kwargs) -> httpx.Response: + async def _request_with_signed_headers( + self, method: str, url: str, headers: dict, **kwargs + ) -> httpx.Response: """ Issues a GET request and sign the headers with the current access token. @@ -252,7 +265,9 @@ def create_fhir_client_for_cli( try: try: # Try to load client ID from file first (some servers use crazy long ones, like SMART's bulk-data-server) - smart_client_id = common.read_text(args.smart_client_id).strip() if args.smart_client_id else None + smart_client_id = ( + common.read_text(args.smart_client_id).strip() if args.smart_client_id else None + ) except FileNotFoundError: smart_client_id = args.smart_client_id diff --git a/cumulus_etl/formats/batch.py b/cumulus_etl/formats/batch.py index c7d7027a..a2eb2ed7 100644 --- a/cumulus_etl/formats/batch.py +++ b/cumulus_etl/formats/batch.py @@ -15,7 +15,9 @@ class Batch: - Written to the target location as one piece (e.g. one ndjson file or one Delta Lake update chunk) """ - def __init__(self, rows: list[dict], groups: set[str] | None = None, schema: pyarrow.Schema = None): + def __init__( + self, rows: list[dict], groups: set[str] | None = None, schema: pyarrow.Schema = None + ): self.rows = rows # `groups` is the set of the values of the format's `group_field` represented by `rows`. # We can't just get this from rows directly because there might be groups that now have zero entries. 
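The formats/batch.py hunk above only re-wraps the `Batch` constructor, but the signature hints at the data it carries: a list of row dicts, an optional set of group values (which can include groups that currently have zero rows), and an optional pyarrow schema. A small hedged sketch of that trio, using made-up row values and assuming pyarrow's standard `schema`/`Table.from_pylist` helpers, not the project's own classes:

import pyarrow

rows = [
    {"id": "obs1", "encounter": "enc1"},
    {"id": "obs2", "encounter": "enc1"},
]
# "enc2" has no rows left in this batch, but group tracking still needs to mention it
groups = {"enc1", "enc2"}
schema = pyarrow.schema([("id", pyarrow.string()), ("encounter", pyarrow.string())])

# A formatter would typically turn rows + schema into an Arrow table before writing
table = pyarrow.Table.from_pylist(rows, schema=schema)
print(table.num_rows)  # 2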
diff --git a/cumulus_etl/formats/deltalake.py b/cumulus_etl/formats/deltalake.py index 3e7c752a..459dfbff 100644 --- a/cumulus_etl/formats/deltalake.py +++ b/cumulus_etl/formats/deltalake.py @@ -69,7 +69,10 @@ def initialize_class(cls, root: store.Root) -> None: pyspark.sql.SparkSession.builder.appName("cumulus-etl") .config("spark.databricks.delta.schema.autoMerge.enabled", "true") .config("spark.driver.memory", "4g") - .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") + .config( + "spark.sql.catalog.spark_catalog", + "org.apache.spark.sql.delta.catalog.DeltaCatalog", + ) .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") ) @@ -94,7 +97,9 @@ def _write_one_batch(self, batch: Batch) -> None: delta_table.generate("symlink_format_manifest") - def update_delta_table(self, updates: pyspark.sql.DataFrame, groups: set[str]) -> delta.DeltaTable: + def update_delta_table( + self, updates: pyspark.sql.DataFrame, groups: set[str] + ) -> delta.DeltaTable: table = ( delta.DeltaTable.createIfNotExists(self.spark) .addColumns(updates.schema) @@ -108,7 +113,9 @@ def update_delta_table(self, updates: pyspark.sql.DataFrame, groups: set[str]) - # Merge in new data merge = ( - table.alias("table").merge(source=updates.alias("updates"), condition=condition).whenNotMatchedInsertAll() + table.alias("table") + .merge(source=updates.alias("updates"), condition=condition) + .whenNotMatchedInsertAll() ) if self.update_existing: update_condition = self._get_update_condition(updates.schema) @@ -144,7 +151,8 @@ def finalize(self) -> None: logging.exception("Could not finalize Delta Lake table %s", self.dbname) def _table_path(self, dbname: str) -> str: - return self.root.joinpath(dbname).replace("s3://", "s3a://") # hadoop uses the s3a: scheme instead of s3: + # hadoop uses the s3a: scheme instead of s3: + return self.root.joinpath(dbname).replace("s3://", "s3a://") @staticmethod def _get_update_condition(schema: pyspark.sql.types.StructType) -> str | None: @@ -192,7 +200,10 @@ def _configure_fs(root: store.Root, spark: pyspark.sql.SparkSession): # This credentials.provider option enables usage of the AWS credentials default priority list (i.e. it will # cause a check for a ~/.aws/credentials file to happen instead of just looking for env vars). 
# See http://wrschneider.github.io/2019/02/02/spark-credentials-file.html for details - spark.conf.set("fs.s3a.aws.credentials.provider", "com.amazonaws.auth.DefaultAWSCredentialsProviderChain") + spark.conf.set( + "fs.s3a.aws.credentials.provider", + "com.amazonaws.auth.DefaultAWSCredentialsProviderChain", + ) spark.conf.set("fs.s3a.sse.enabled", "true") spark.conf.set("fs.s3a.server-side-encryption-algorithm", "SSE-KMS") kms_key = fsspec_options.get("s3_additional_kwargs", {}).get("SSEKMSKeyId") diff --git a/cumulus_etl/loaders/fhir/bulk_export.py b/cumulus_etl/loaders/fhir/bulk_export.py index 4676ea8d..b46fa001 100644 --- a/cumulus_etl/loaders/fhir/bulk_export.py +++ b/cumulus_etl/loaders/fhir/bulk_export.py @@ -56,7 +56,8 @@ def __init__( self._resources = resources self._url = url if not self._url.endswith("/"): - self._url += "/" # This will ensure the last segment does not get chopped off by urljoin + # This will ensure the last segment does not get chopped off by urljoin + self._url += "/" self._destination = destination self._total_wait_time = 0 # in seconds, across all our requests self._since = since @@ -190,7 +191,9 @@ async def _request_with_delay( if response.status_code == target_status_code: if status_box.plain: status.stop() - print(f" Waited for a total of {common.human_time_offset(self._total_wait_time)}") + print( + f" Waited for a total of {common.human_time_offset(self._total_wait_time)}" + ) return response # 202 == server is still working on it, 429 == server is busy -- in both cases, we wait @@ -251,7 +254,8 @@ async def _gather_all_messages(self, error_list: list[dict]) -> (list[str], list """ coroutines = [] for error in error_list: - if error.get("type") == "OperationOutcome": # per spec as of writing, the only allowed type + # per spec as of writing, OperationOutcome is the only allowed type + if error.get("type") == "OperationOutcome": coroutines.append( self._request_with_logging( error["url"], @@ -270,7 +274,8 @@ async def _gather_all_messages(self, error_list: list[dict]) -> (list[str], list fatal_messages = [] info_messages = [] for response in responses: - outcomes = [json.loads(x) for x in response.text.split("\n") if x] # a list of OperationOutcomes + # Create a list of OperationOutcomes + outcomes = [json.loads(x) for x in response.text.split("\n") if x] self._log.download_complete(response.url, len(outcomes), len(response.text)) for outcome in outcomes: for issue in outcome.get("issue", []): diff --git a/cumulus_etl/loaders/fhir/export_log.py b/cumulus_etl/loaders/fhir/export_log.py index 02cc1ccd..c6d4fa3a 100644 --- a/cumulus_etl/loaders/fhir/export_log.py +++ b/cumulus_etl/loaders/fhir/export_log.py @@ -114,7 +114,9 @@ def __init__(self, root: store.Root): self._num_bytes = 0 self._start_time = None - def _event(self, event_id: str, detail: dict, *, timestamp: datetime.datetime | None = None) -> None: + def _event( + self, event_id: str, detail: dict, *, timestamp: datetime.datetime | None = None + ) -> None: timestamp = timestamp or common.datetime_now(local=True) if self._start_time is None: self._start_time = timestamp diff --git a/cumulus_etl/loaders/fhir/ndjson_loader.py b/cumulus_etl/loaders/fhir/ndjson_loader.py index 3e4cbcc2..f8767eae 100644 --- a/cumulus_etl/loaders/fhir/ndjson_loader.py +++ b/cumulus_etl/loaders/fhir/ndjson_loader.py @@ -42,7 +42,8 @@ async def load_all(self, resources: list[str]) -> common.Directory: else: if self.export_to or self.since or self.until: errors.fatal( - "You provided FHIR bulk export parameters but did 
not provide a FHIR server", errors.ARGS_CONFLICT + "You provided FHIR bulk export parameters but did not provide a FHIR server", + errors.ARGS_CONFLICT, ) input_root = self.root diff --git a/cumulus_etl/loaders/i2b2/loader.py b/cumulus_etl/loaders/i2b2/loader.py index 100131fc..dd98617a 100644 --- a/cumulus_etl/loaders/i2b2/loader.py +++ b/cumulus_etl/loaders/i2b2/loader.py @@ -59,7 +59,9 @@ def _load_all_with_extractors( tmpdir = cli_utils.make_export_dir(self.export_to) if "Condition" in resources: - with open(Path(Path(__file__).resolve().parent, "icd.json"), encoding="utf-8") as code_json: + with open( + Path(Path(__file__).resolve().parent, "icd.json"), encoding="utf-8" + ) as code_json: code_dict = json.load(code_json) self._loop( conditions(), @@ -109,11 +111,17 @@ def _load_all_with_extractors( return tmpdir - def _loop(self, i2b2_entries: Iterable[schema.Dimension], to_fhir: I2b2ToFhirCallable, output_path: str) -> None: + def _loop( + self, + i2b2_entries: Iterable[schema.Dimension], + to_fhir: I2b2ToFhirCallable, + output_path: str, + ) -> None: """Takes one kind of i2b2 resource, loads them all up, and writes out a FHIR ndjson file""" fhir_resources = (to_fhir(x) for x in i2b2_entries) - ids = set() # keep track of every ID we've seen so far, because sometimes i2b2 can have duplicates + # keep track of every ID we've seen so far, because sometimes i2b2 can have duplicates + ids = set() with common.NdjsonWriter(output_path) as output_file: # Now write each FHIR resource line by line to the output @@ -151,10 +159,15 @@ def _load_all_from_csv(self, resources: list[str]) -> common.Directory: os.path.join(path, "observation_fact_vitals.csv"), ), documentreferences=partial( - extract.extract_csv_observation_facts, os.path.join(path, "observation_fact_notes.csv") + extract.extract_csv_observation_facts, + os.path.join(path, "observation_fact_notes.csv"), + ), + patients=partial( + extract.extract_csv_patients, os.path.join(path, "patient_dimension.csv") + ), + encounters=partial( + extract.extract_csv_visits, os.path.join(path, "visit_dimension.csv") ), - patients=partial(extract.extract_csv_patients, os.path.join(path, "patient_dimension.csv")), - encounters=partial(extract.extract_csv_visits, os.path.join(path, "visit_dimension.csv")), ) ################################################################################################################### @@ -169,7 +182,9 @@ def _load_all_from_oracle(self, resources: list[str]) -> common.Directory: resources, conditions=partial(oracle_extract.list_observation_fact, path, ["ICD9", "ICD10"]), lab_views=partial(oracle_extract.list_observation_fact, path, ["LAB"]), - medicationrequests=partial(oracle_extract.list_observation_fact, path, ["ADMINMED", "HOMEMED"]), + medicationrequests=partial( + oracle_extract.list_observation_fact, path, ["ADMINMED", "HOMEMED"] + ), vitals=partial(oracle_extract.list_observation_fact, path, ["VITAL"]), documentreferences=partial(oracle_extract.list_observation_fact, path, ["NOTE"]), patients=partial(oracle_extract.list_patient, path), diff --git a/cumulus_etl/loaders/i2b2/oracle/connect.py b/cumulus_etl/loaders/i2b2/oracle/connect.py index 5ebd2688..d889757b 100644 --- a/cumulus_etl/loaders/i2b2/oracle/connect.py +++ b/cumulus_etl/loaders/i2b2/oracle/connect.py @@ -16,7 +16,8 @@ def _get_user() -> str: user = os.environ.get("CUMULUS_SQL_USER") if not user: print( - "To connect to an Oracle SQL server, please set the environment variable CUMULUS_SQL_USER", file=sys.stderr + "To connect to an Oracle SQL 
server, please set the environment variable CUMULUS_SQL_USER", + file=sys.stderr, ) raise SystemExit(errors.SQL_USER_MISSING) return user diff --git a/cumulus_etl/loaders/i2b2/oracle/query.py b/cumulus_etl/loaders/i2b2/oracle/query.py index 43252de8..952ed46d 100644 --- a/cumulus_etl/loaders/i2b2/oracle/query.py +++ b/cumulus_etl/loaders/i2b2/oracle/query.py @@ -101,7 +101,10 @@ def sql_observation_fact(categories: list[str]) -> str: matchers = [f"(concept_cd like '{category}:%')" for category in categories] - return f"select {cols} \n from {Table.observation_fact.value} O " f"where {' or '.join(matchers)}" # noqa: S608 + return ( + f"select {cols} \n from {Table.observation_fact.value} O " # noqa: S608 + f"where {' or '.join(matchers)}" + ) def eq_val_type(val_type: ValueType) -> str: diff --git a/cumulus_etl/loaders/i2b2/transform.py b/cumulus_etl/loaders/i2b2/transform.py index 9b3c6c1b..cdfdafc7 100644 --- a/cumulus_etl/loaders/i2b2/transform.py +++ b/cumulus_etl/loaders/i2b2/transform.py @@ -92,11 +92,16 @@ def to_fhir_encounter(visit: VisitDimension) -> dict: "status": "unknown", "period": {"start": chop_to_date(visit.start_date), "end": chop_to_date(visit.end_date)}, # Most generic encounter type possible, only included because the 'type' field is required in us-core - "type": [make_concept("308335008", "http://snomed.info/sct", "Patient encounter procedure")], + "type": [ + make_concept("308335008", "http://snomed.info/sct", "Patient encounter procedure") + ], } if visit.length_of_stay: # days - encounter["length"] = {"unit": "d", "value": visit.length_of_stay and float(visit.length_of_stay)} + encounter["length"] = { + "unit": "d", + "value": visit.length_of_stay and float(visit.length_of_stay), + } class_fhir = external_mappings.SNOMED_ADMISSION.get(visit.inout_cd) if not class_fhir: @@ -125,7 +130,9 @@ def to_fhir_observation_lab(obsfact: ObservationFact) -> dict: "id": str(obsfact.instance_num), "subject": fhir.ref_resource("Patient", obsfact.patient_num), "encounter": fhir.ref_resource("Encounter", obsfact.encounter_num), - "category": [make_concept("laboratory", "http://terminology.hl7.org/CodeSystem/observation-category")], + "category": [ + make_concept("laboratory", "http://terminology.hl7.org/CodeSystem/observation-category") + ], "effectiveDateTime": chop_to_date(obsfact.start_date), "status": "unknown", } @@ -144,7 +151,9 @@ def to_fhir_observation_lab(obsfact: ObservationFact) -> dict: else: lab_result = obsfact.tval_char lab_result_system = "http://cumulus.smarthealthit.org/i2b2" - observation["valueCodeableConcept"] = make_concept(lab_result, lab_result_system, display=obsfact.tval_char) + observation["valueCodeableConcept"] = make_concept( + lab_result, lab_result_system, display=obsfact.tval_char + ) return observation @@ -162,7 +171,11 @@ def to_fhir_observation_vitals(obsfact: ObservationFact) -> dict: "resourceType": "Observation", "id": str(obsfact.instance_num), "status": "unknown", - "category": [make_concept("vital-signs", "http://terminology.hl7.org/CodeSystem/observation-category")], + "category": [ + make_concept( + "vital-signs", "http://terminology.hl7.org/CodeSystem/observation-category" + ) + ], "code": make_concept(obsfact.concept_cd, "http://cumulus.smarthealthit.org/i2b2"), "subject": fhir.ref_resource("Patient", obsfact.patient_num), "encounter": fhir.ref_resource("Encounter", obsfact.encounter_num), @@ -194,8 +207,12 @@ def to_fhir_condition(obsfact: ObservationFact, display_codes: dict) -> dict: ) ], "recordedDate": 
chop_to_date(obsfact.start_date), - "clinicalStatus": make_concept("active", "http://terminology.hl7.org/CodeSystem/condition-clinical"), - "verificationStatus": make_concept("unconfirmed", "http://terminology.hl7.org/CodeSystem/condition-ver-status"), + "clinicalStatus": make_concept( + "active", "http://terminology.hl7.org/CodeSystem/condition-clinical" + ), + "verificationStatus": make_concept( + "unconfirmed", "http://terminology.hl7.org/CodeSystem/condition-ver-status" + ), } # Code @@ -260,11 +277,16 @@ def to_fhir_documentreference(obsfact: ObservationFact) -> dict: "subject": fhir.ref_resource("Patient", obsfact.patient_num), "context": { "encounter": [fhir.ref_resource("Encounter", obsfact.encounter_num)], - "period": {"start": chop_to_date(obsfact.start_date), "end": chop_to_date(obsfact.end_date)}, + "period": { + "start": chop_to_date(obsfact.start_date), + "end": chop_to_date(obsfact.end_date), + }, }, # It would be nice to get a real mapping for the "NOTE:" concept CD types to a real system. # But for now, use this custom (and the URL isn't even valid) system to note these i2b2 concepts. - "type": make_concept(obsfact.concept_cd, "http://cumulus.smarthealthit.org/i2b2", obsfact.tval_char), + "type": make_concept( + obsfact.concept_cd, "http://cumulus.smarthealthit.org/i2b2", obsfact.tval_char + ), "status": "current", "content": [ { @@ -305,9 +327,17 @@ def get_observation_value(obsfact: ObservationFact) -> dict: http://hl7.org/fhir/R4/datatypes.html#Quantity """ if obsfact.valtype_cd == "T": - return {"valueCodeableConcept": make_concept(obsfact.tval_char, "http://cumulus.smarthealthit.org/i2b2")} + return { + "valueCodeableConcept": make_concept( + obsfact.tval_char, "http://cumulus.smarthealthit.org/i2b2" + ) + } elif obsfact.valtype_cd == "B": - return {"valueCodeableConcept": make_concept(obsfact.observation_blob, "http://cumulus.smarthealthit.org/i2b2")} + return { + "valueCodeableConcept": make_concept( + obsfact.observation_blob, "http://cumulus.smarthealthit.org/i2b2" + ) + } elif obsfact.valtype_cd == "@": # no value return {} elif obsfact.valtype_cd != "N": @@ -340,7 +370,9 @@ def get_observation_value(obsfact: ObservationFact) -> dict: return {"valueQuantity": quantity} -def make_concept(code: str, system: str | None, display: str | None = None, display_codes: dict | None = None) -> dict: +def make_concept( + code: str, system: str | None, display: str | None = None, display_codes: dict | None = None +) -> dict: """Syntactic sugar to make a codeable concept""" coding = {"code": code, "system": system} if display: diff --git a/cumulus_etl/nlp/__init__.py b/cumulus_etl/nlp/__init__.py index 9431411a..65d80676 100644 --- a/cumulus_etl/nlp/__init__.py +++ b/cumulus_etl/nlp/__init__.py @@ -3,4 +3,9 @@ from .extract import TransformerModel, ctakes_extract, ctakes_httpx_client, list_polarity from .huggingface import hf_info, hf_prompt, llama2_prompt from .utils import cache_wrapper, is_docref_valid -from .watcher import check_ctakes, check_negation_cnlpt, check_term_exists_cnlpt, restart_ctakes_with_bsv +from .watcher import ( + check_ctakes, + check_negation_cnlpt, + check_term_exists_cnlpt, + restart_ctakes_with_bsv, +) diff --git a/cumulus_etl/nlp/extract.py b/cumulus_etl/nlp/extract.py index 08ae9ab8..e23b9fbd 100644 --- a/cumulus_etl/nlp/extract.py +++ b/cumulus_etl/nlp/extract.py @@ -59,7 +59,9 @@ async def list_polarity( try: result = [ctakesclient.typesystem.Polarity(x) for x in common.read_json(full_path)] except Exception: # pylint: disable=broad-except - 
result = await ctakesclient.transformer.list_polarity(sentence, spans, client=client, model=model) + result = await ctakesclient.transformer.list_polarity( + sentence, spans, client=client, model=model + ) cache.makedirs(os.path.dirname(full_path)) common.write_json(full_path, [x.value for x in result]) diff --git a/cumulus_etl/nlp/huggingface.py b/cumulus_etl/nlp/huggingface.py index e854acf6..5d5e665b 100644 --- a/cumulus_etl/nlp/huggingface.py +++ b/cumulus_etl/nlp/huggingface.py @@ -11,7 +11,9 @@ def get_hugging_face_url() -> str: return os.environ.get("CUMULUS_HUGGING_FACE_URL") or "http://localhost:8086/" -async def llama2_prompt(system_prompt: str, user_prompt: str, *, client: httpx.AsyncClient = None) -> str: +async def llama2_prompt( + system_prompt: str, user_prompt: str, *, client: httpx.AsyncClient = None +) -> str: """ Prompts a llama2 chat model and provides its response. @@ -35,7 +37,8 @@ async def llama2_prompt(system_prompt: str, user_prompt: str, *, client: httpx.A response = await hf_prompt(whole_prompt, client=client) text = response[0]["generated_text"] - text = text.removeprefix(whole_prompt).strip() # llama2 gives back the prompt too, but we don't need it + # llama2 gives back the prompt too, but we don't need it + text = text.removeprefix(whole_prompt).strip() return text diff --git a/cumulus_etl/nlp/utils.py b/cumulus_etl/nlp/utils.py index 81c50403..ef1b10e0 100644 --- a/cumulus_etl/nlp/utils.py +++ b/cumulus_etl/nlp/utils.py @@ -9,12 +9,15 @@ def is_docref_valid(docref: dict) -> bool: """Returns True if this docref is not a draft or entered-in-error resource and could be considered for NLP""" - good_status = docref.get("status") in ("current", None) # status of DocRef itself - good_doc_status = docref.get("docStatus") in ("final", "amended", None) # status of clinical note + good_status = docref.get("status") in {"current", None} # status of DocRef itself + # docStatus is status of clinical note attachments + good_doc_status = docref.get("docStatus") in {"final", "amended", None} return good_status and good_doc_status -async def cache_wrapper(cache_dir: str, namespace: str, content: str, method: Callable, *args, **kwargs) -> str: +async def cache_wrapper( + cache_dir: str, namespace: str, content: str, method: Callable, *args, **kwargs +) -> str: """Looks up an NLP result in the cache first, falling back to actually calling NLP.""" # First, what is our target path for a possible cache file cache_dir = store.Root(cache_dir, create=True) diff --git a/cumulus_etl/nlp/watcher.py b/cumulus_etl/nlp/watcher.py index 21c7ab77..487e029b 100644 --- a/cumulus_etl/nlp/watcher.py +++ b/cumulus_etl/nlp/watcher.py @@ -72,7 +72,8 @@ def wait_for_ctakes_restart(): # *** Acquire socket connection with cTAKES (cTAKES is required to exist already) *** connection = socket.create_connection((url.hostname, url.port)) poller = select.poll() - poller.register(connection, select.POLLRDHUP) # will watch for remote disconnect (death or remote timeout) + # Poll for RDHUP to watch for remote disconnect (death or remote timeout) + poller.register(connection, select.POLLRDHUP) # *** Yield to caller *** yield diff --git a/cumulus_etl/upload_notes/cli.py b/cumulus_etl/upload_notes/cli.py index d99952d2..550a4db7 100644 --- a/cumulus_etl/upload_notes/cli.py +++ b/cumulus_etl/upload_notes/cli.py @@ -29,12 +29,16 @@ def init_checks(args: argparse.Namespace): if not cli_utils.is_url_available(args.label_studio_url, retry=False): errors.fatal( - f"A running Label Studio server was not found at:\n 
{args.label_studio_url}", errors.LABEL_STUDIO_MISSING + f"A running Label Studio server was not found at:\n {args.label_studio_url}", + errors.LABEL_STUDIO_MISSING, ) async def gather_docrefs( - client: fhir.FhirClient, root_input: store.Root, codebook: deid.Codebook, args: argparse.Namespace + client: fhir.FhirClient, + root_input: store.Root, + codebook: deid.Codebook, + args: argparse.Namespace, ) -> common.Directory: """Selects and downloads just the docrefs we need to an export folder.""" common.print_header("Gathering documents...") @@ -42,15 +46,27 @@ async def gather_docrefs( # There are three possibilities: we have real IDs, fake IDs, or neither. # Note that we don't support providing both real & fake IDs right now. It's not clear that would be useful. if args.docrefs and args.anon_docrefs: - errors.fatal("You cannot use both --docrefs and --anon-docrefs at the same time.", errors.ARGS_CONFLICT) + errors.fatal( + "You cannot use both --docrefs and --anon-docrefs at the same time.", + errors.ARGS_CONFLICT, + ) if root_input.protocol == "https": # is this a FHIR server? return await downloader.download_docrefs_from_fhir_server( - client, root_input, codebook, docrefs=args.docrefs, anon_docrefs=args.anon_docrefs, export_to=args.export_to + client, + root_input, + codebook, + docrefs=args.docrefs, + anon_docrefs=args.anon_docrefs, + export_to=args.export_to, ) else: return selector.select_docrefs_from_files( - root_input, codebook, docrefs=args.docrefs, anon_docrefs=args.anon_docrefs, export_to=args.export_to + root_input, + codebook, + docrefs=args.docrefs, + anon_docrefs=args.anon_docrefs, + export_to=args.export_to, ) @@ -133,7 +149,9 @@ async def run_nlp(notes: Collection[LabelStudioNote], args: argparse.Namespace) client=http_client, model=nlp.TransformerModel.NEGATION, ) - note.ctakes_matches = [match for i, match in enumerate(matches) if cnlpt_results[i] == Polarity.pos] + note.ctakes_matches = [ + match for i, match in enumerate(matches) if cnlpt_results[i] == Polarity.pos + ] def philter_notes(notes: Collection[LabelStudioNote], args: argparse.Namespace) -> None: @@ -194,7 +212,9 @@ def group_notes_by_encounter(notes: Collection[LabelStudioNote]) -> list[LabelSt offset = len(grouped_text) grouped_text += note.text - offset_doc_spans = {k: (v[0] + offset, v[1] + offset) for k, v in note.doc_spans.items()} + offset_doc_spans = { + k: (v[0] + offset, v[1] + offset) for k, v in note.doc_spans.items() + } grouped_doc_spans.update(offset_doc_spans) for match in note.ctakes_matches: @@ -243,7 +263,9 @@ def define_upload_notes_parser(parser: argparse.ArgumentParser) -> None: parser.add_argument("dir_phi", metavar="/path/to/phi") parser.add_argument( - "--export-to", metavar="PATH", help="Where to put exported documents (default is to delete after use)" + "--export-to", + metavar="PATH", + help="Where to put exported documents (default is to delete after use)", ) parser.add_argument( @@ -254,15 +276,25 @@ def define_upload_notes_parser(parser: argparse.ArgumentParser) -> None: ) # Old, simpler version of the above (feel free to remove after May 2024) parser.add_argument( - "--no-philter", action="store_const", const=PHILTER_DISABLE, dest="philter", help=argparse.SUPPRESS + "--no-philter", + action="store_const", + const=PHILTER_DISABLE, + dest="philter", + help=argparse.SUPPRESS, ) cli_utils.add_aws(parser) cli_utils.add_auth(parser) docs = parser.add_argument_group("document selection") - docs.add_argument("--anon-docrefs", metavar="PATH", help="CSV file with anonymized 
patient_id,docref_id columns") - docs.add_argument("--docrefs", metavar="PATH", help="CSV file with a docref_id column of original IDs") + docs.add_argument( + "--anon-docrefs", + metavar="PATH", + help="CSV file with anonymized patient_id,docref_id columns", + ) + docs.add_argument( + "--docrefs", metavar="PATH", help="CSV file with a docref_id column of original IDs" + ) group = cli_utils.add_nlp(parser) group.add_argument( @@ -271,12 +303,24 @@ def define_upload_notes_parser(parser: argparse.ArgumentParser) -> None: help="BSV file with concept CUIs (defaults to Covid)", default=ctakesclient.filesystem.covid_symptoms_path(), ) - group.add_argument("--no-nlp", action="store_false", dest="nlp", default=True, help="Don’t run NLP on notes") + group.add_argument( + "--no-nlp", action="store_false", dest="nlp", default=True, help="Don’t run NLP on notes" + ) group = parser.add_argument_group("Label Studio") - group.add_argument("--ls-token", metavar="PATH", help="Token file for Label Studio access", required=True) - group.add_argument("--ls-project", metavar="ID", type=int, help="Label Studio project ID to update", required=True) - group.add_argument("--overwrite", action="store_true", help="Whether to overwrite an existing task for a note") + group.add_argument( + "--ls-token", metavar="PATH", help="Token file for Label Studio access", required=True + ) + group.add_argument( + "--ls-project", + metavar="ID", + type=int, + help="Label Studio project ID to update", + required=True, + ) + group.add_argument( + "--overwrite", action="store_true", help="Whether to overwrite an existing task for a note" + ) cli_utils.add_debugging(parser) @@ -293,7 +337,9 @@ async def upload_notes_main(args: argparse.Namespace) -> None: """ init_checks(args) - store.set_user_fs_options(vars(args)) # record filesystem options like --s3-region before creating Roots + # record filesystem options like --s3-region before creating Roots + store.set_user_fs_options(vars(args)) + root_input = store.Root(args.dir_input) codebook = deid.Codebook(args.dir_phi) @@ -307,7 +353,8 @@ async def upload_notes_main(args: argparse.Namespace) -> None: notes = await read_notes_from_ndjson(client, ndjson_folder.name, codebook) await run_nlp(notes, args) - philter_notes(notes, args) # safe to do after NLP because philter does not change character counts + # It's safe to philter notes after NLP because philter does not change character counts + philter_notes(notes, args) notes = group_notes_by_encounter(notes) push_to_label_studio(notes, access_token, labels, args) diff --git a/cumulus_etl/upload_notes/downloader.py b/cumulus_etl/upload_notes/downloader.py index a69723bf..96d2a14c 100644 --- a/cumulus_etl/upload_notes/downloader.py +++ b/cumulus_etl/upload_notes/downloader.py @@ -21,7 +21,9 @@ async def download_docrefs_from_fhir_server( if docrefs: return await _download_docrefs_from_real_ids(client, docrefs, export_to=export_to) elif anon_docrefs: - return await _download_docrefs_from_fake_ids(client, codebook, anon_docrefs, export_to=export_to) + return await _download_docrefs_from_fake_ids( + client, codebook, anon_docrefs, export_to=export_to + ) else: # else we'll download the entire target path as a bulk export (presumably the user has scoped a Group) ndjson_loader = loaders.FhirNdjsonLoader(root_input, client, export_to=export_to) @@ -49,12 +51,15 @@ async def _download_docrefs_from_fake_ids( # Kick off a bunch of requests to the FHIR server for any documents for these patients # (filtered to only the given fake IDs) coroutines = [ 
- _request_docrefs_for_patient(client, patient_id, codebook, fake_docref_ids) for patient_id in patient_ids + _request_docrefs_for_patient(client, patient_id, codebook, fake_docref_ids) + for patient_id in patient_ids ] docrefs_per_patient = await asyncio.gather(*coroutines) # And write them all out - _write_docrefs_to_output_folder(itertools.chain.from_iterable(docrefs_per_patient), output_folder.name) + _write_docrefs_to_output_folder( + itertools.chain.from_iterable(docrefs_per_patient), output_folder.name + ) return output_folder @@ -91,7 +96,10 @@ def _write_docrefs_to_output_folder(docrefs: Iterable[dict], output_folder: str) async def _request_docrefs_for_patient( - client: fhir.FhirClient, patient_id: str, codebook: deid.Codebook, fake_docref_ids: Container[str] + client: fhir.FhirClient, + patient_id: str, + codebook: deid.Codebook, + fake_docref_ids: Container[str], ) -> list[dict]: """Returns all DocumentReferences for a given patient""" params = { diff --git a/cumulus_etl/upload_notes/labelstudio.py b/cumulus_etl/upload_notes/labelstudio.py index dabcaa70..55148e86 100644 --- a/cumulus_etl/upload_notes/labelstudio.py +++ b/cumulus_etl/upload_notes/labelstudio.py @@ -38,7 +38,9 @@ class LabelStudioNote: doc_spans: dict[str, tuple[int, int]] = dataclasses.field(default_factory=dict) # Matches found by cTAKES - ctakes_matches: list[ctakesclient.typesystem.MatchText] = dataclasses.field(default_factory=list) + ctakes_matches: list[ctakesclient.typesystem.MatchText] = dataclasses.field( + default_factory=list + ) # Matches found by Philter philter_map: dict[int, int] = dataclasses.field(default_factory=dict) @@ -59,7 +61,11 @@ def push_tasks(self, notes: Collection[LabelStudioNote], *, overwrite: bool = Fa enc_ids = [note.enc_id for note in notes] enc_id_filter = lsdm.Filters.create( lsdm.Filters.AND, - [lsdm.Filters.item(lsdm.Column.data("enc_id"), lsdm.Operator.IN_LIST, lsdm.Type.List, enc_ids)], + [ + lsdm.Filters.item( + lsdm.Column.data("enc_id"), lsdm.Operator.IN_LIST, lsdm.Type.List, enc_ids + ) + ], ) existing_tasks = self._project.get_tasks(filters=enc_id_filter) new_task_count = len(notes) - len(existing_tasks) @@ -108,7 +114,8 @@ def _format_task_for_note(self, note: LabelStudioNote) -> dict: "enc_id": note.enc_id, "anon_id": note.anon_id, "docref_mappings": note.doc_mappings, - "docref_spans": {k: list(v) for k, v in note.doc_spans.items()}, # json doesn't natively have tuples + # json doesn't natively have tuples, so convert spans to lists + "docref_spans": {k: list(v) for k, v in note.doc_spans.items()}, }, "predictions": [], } @@ -134,8 +141,11 @@ def _format_ctakes_predictions(self, task: dict, note: LabelStudioNote) -> None: results = [] count = 0 for match in note.ctakes_matches: - matched_labels = {self._cui_labels.get(concept.cui) for concept in match.conceptAttributes} - matched_labels.discard(None) # drop the result of a concept not being in our bsv label set + matched_labels = { + self._cui_labels.get(concept.cui) for concept in match.conceptAttributes + } + # drop the result of a concept not being in our bsv label set + matched_labels.discard(None) if matched_labels: results.append(self._format_ctakes_match(count, match, matched_labels)) used_labels.update(matched_labels) @@ -151,7 +161,13 @@ def _format_ctakes_match(self, count: int, match: MatchText, labels: Iterable[st "from_name": self._labels_name, "to_name": self._labels_config["to_name"][0], "type": "labels", - "value": {"start": match.begin, "end": match.end, "score": 1.0, "text": match.text, 
"labels": list(labels)}, + "value": { + "start": match.begin, + "end": match.end, + "score": 1.0, + "text": match.text, + "labels": list(labels), + }, } def _format_philter_predictions(self, task: dict, note: LabelStudioNote) -> None: @@ -188,7 +204,13 @@ def _format_philter_span(self, count: int, start: int, end: int, note: LabelStud "type": "labels", # We hardcode the label "_philter" - Label Studio will still highlight unknown labels, # and this is unlikely to collide with existing labels. - "value": {"start": start, "end": end, "score": 1.0, "text": text, "labels": ["_philter"]}, + "value": { + "start": start, + "end": end, + "score": 1.0, + "text": text, + "labels": ["_philter"], + }, } def _update_used_labels(self, task: dict, used_labels: Iterable[str]) -> None: diff --git a/cumulus_etl/upload_notes/selector.py b/cumulus_etl/upload_notes/selector.py index bf2e1473..2d3c6ca5 100644 --- a/cumulus_etl/upload_notes/selector.py +++ b/cumulus_etl/upload_notes/selector.py @@ -60,10 +60,13 @@ def _filter_real_docrefs(docrefs_csv: str, docrefs: Iterable[dict]) -> Iterator[ break -def _filter_fake_docrefs(codebook: deid.Codebook, anon_docrefs_csv: str, docrefs: Iterable[dict]) -> Iterator[dict]: +def _filter_fake_docrefs( + codebook: deid.Codebook, anon_docrefs_csv: str, docrefs: Iterable[dict] +) -> Iterator[dict]: """Calculates the fake ID for all docrefs found, and keeps any that match the csv list""" with common.read_csv(anon_docrefs_csv) as reader: - fake_docref_ids = {row["docref_id"] for row in reader} # ignore the patient_id column, not needed + # ignore the patient_id column, not needed + fake_docref_ids = {row["docref_id"] for row in reader} for docref in docrefs: fake_id = codebook.fake_id("DocumentReference", docref["id"], caching_allowed=False) diff --git a/pyproject.toml b/pyproject.toml index b5cb555f..e4678aec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,7 +77,7 @@ exclude = [ ] [tool.ruff] -line-length = 120 +line-length = 100 [tool.ruff.lint] allowed-confusables = ["’"] # allow proper apostrophes @@ -91,6 +91,15 @@ select = [ "S", # bandit security warnings "UP", # alert you when better syntax is available in your python version ] +ignore = [ + # E501 is the line-too-long check. + # Ruff formatting will generally control the length of Python lines for us. + # But it leaves comments alone. And since we used to have a longer line length (120), + # we have a lot of legacy comments over 100 width. + # Just disable the check for now, rather than manually fixing all 300+ lines. + # Hopefully we can address them slowly over time. 
+ "E501", +] [tool.ruff.lint.per-file-ignores] "**/__init__.py" = ["F401"] # init files hold API, so not using imports is intentional diff --git a/tests/convert/test_convert_cli.py b/tests/convert/test_convert_cli.py index 51a66410..c3662677 100644 --- a/tests/convert/test_convert_cli.py +++ b/tests/convert/test_convert_cli.py @@ -37,7 +37,8 @@ def prepare_original_dir(self) -> str: f"{self.datadir}/covid/term-exists/etl__completion/etl__completion.000.ndjson", f"{self.original_path}/etl__completion/etl__completion.covid.ndjson", ) - os.makedirs(f"{self.original_path}/ignored") # just to confirm we only copy what we understand + # just to confirm we only copy what we understand, add an ignored folder + os.makedirs(f"{self.original_path}/ignored") job_timestamp = "2023-02-28__19.53.08" config_dir = f"{self.original_path}/JobConfig/{job_timestamp}" @@ -46,7 +47,9 @@ def prepare_original_dir(self) -> str: return job_timestamp - async def run_convert(self, input_path: str | None = None, output_path: str | None = None) -> None: + async def run_convert( + self, input_path: str | None = None, output_path: str | None = None + ) -> None: args = [ "convert", input_path or self.original_path, @@ -81,7 +84,8 @@ async def test_happy_path(self): expected_tables = set(os.listdir(self.original_path)) - {"ignored"} self.assertEqual(expected_tables, set(os.listdir(self.target_path))) self.assertEqual( - {"test": True}, common.read_json(f"{self.target_path}/JobConfig/{job_timestamp}/job_config.json") + {"test": True}, + common.read_json(f"{self.target_path}/JobConfig/{job_timestamp}/job_config.json"), ) patients = utils.read_delta_lake(f"{self.target_path}/patient") # spot check some patients self.assertEqual(2, len(patients)) @@ -91,7 +95,10 @@ async def test_happy_path(self): conditions = utils.read_delta_lake(f"{self.target_path}/condition") # and conditions self.assertEqual(2, len(conditions)) self.assertEqual("2010-03-02", conditions[0]["recordedDate"]) - symptoms = utils.read_delta_lake(f"{self.target_path}/covid_symptom__nlp_results_term_exists") # and covid + # and a non-default study table + symptoms = utils.read_delta_lake( + f"{self.target_path}/covid_symptom__nlp_results_term_exists" + ) self.assertEqual(2, len(symptoms)) self.assertEqual("for", symptoms[0]["match"]["text"]) completion = utils.read_delta_lake(f"{self.target_path}/etl__completion") # and completion @@ -107,7 +114,13 @@ async def test_happy_path(self): delta_path = os.path.join(self.tmpdir, "delta") os.makedirs(f"{delta_path}/patient") with common.NdjsonWriter(f"{delta_path}/patient/new.ndjson") as writer: - writer.write({"resourceType": "Patient", "id": "1de9ea66-70d3-da1f-c735-df5ef7697fb9", "birthDate": "1800"}) + writer.write( + { + "resourceType": "Patient", + "id": "1de9ea66-70d3-da1f-c735-df5ef7697fb9", + "birthDate": "1800", + } + ) writer.write({"resourceType": "Patient", "id": "z-gen", "birthDate": "2005"}) os.makedirs(f"{delta_path}/etl__completion_encounters") with common.NdjsonWriter(f"{delta_path}/etl__completion_encounters/new.ndjson") as writer: @@ -120,7 +133,13 @@ async def test_happy_path(self): } ) # Totally new encounter - writer.write({"encounter_id": "NEW", "group_name": "NEW", "export_time": "2021-12-12T17:00:20+00:00"}) + writer.write( + { + "encounter_id": "NEW", + "group_name": "NEW", + "export_time": "2021-12-12T17:00:20+00:00", + } + ) delta_config_dir = f"{delta_path}/JobConfig/{delta_timestamp}" os.makedirs(delta_config_dir) common.write_json(f"{delta_config_dir}/job_config.json", {"delta": "yup"}) @@ 
-128,25 +147,29 @@ async def test_happy_path(self): # How did that change the delta lake dir? Hopefully we only interwove the new data self.assertEqual( # confirm this is still here - {"test": True}, common.read_json(f"{self.target_path}/JobConfig/{job_timestamp}/job_config.json") + {"test": True}, + common.read_json(f"{self.target_path}/JobConfig/{job_timestamp}/job_config.json"), ) self.assertEqual({"delta": "yup"}, common.read_json(f"{delta_config_dir}/job_config.json")) patients = utils.read_delta_lake(f"{self.target_path}/patient") # re-check the patients self.assertEqual(3, len(patients)) - self.assertEqual("1800", patients[0]["birthDate"]) # these rows are sorted by id, so these are reliable indexes + # these rows are sorted by id, so these are reliable indexes + self.assertEqual("1800", patients[0]["birthDate"]) self.assertEqual("1983", patients[1]["birthDate"]) self.assertEqual("2005", patients[2]["birthDate"]) - conditions = utils.read_delta_lake(f"{self.target_path}/condition") # and conditions shouldn't change at all + # and conditions shouldn't change at all + conditions = utils.read_delta_lake(f"{self.target_path}/condition") self.assertEqual(2, len(conditions)) self.assertEqual("2010-03-02", conditions[0]["recordedDate"]) - comp_enc = utils.read_delta_lake( - f"{self.target_path}/etl__completion_encounters" - ) # and *some* enc mappings did + # but *some* enc mappings did + comp_enc = utils.read_delta_lake(f"{self.target_path}/etl__completion_encounters") self.assertEqual(3, len(comp_enc)) self.assertEqual("08f0ebd4-950c-ddd9-ce97-b5bdf073eed1", comp_enc[0]["encounter_id"]) - self.assertEqual("2020-10-13T12:00:20-05:00", comp_enc[0]["export_time"]) # confirm this *didn't* get updated + # confirm export_time *didn't* get updated + self.assertEqual("2020-10-13T12:00:20-05:00", comp_enc[0]["export_time"]) self.assertEqual("NEW", comp_enc[1]["encounter_id"]) - self.assertEqual("2021-12-12T17:00:20+00:00", comp_enc[1]["export_time"]) # but the new row did get inserted + # but the new row did get inserted + self.assertEqual("2021-12-12T17:00:20+00:00", comp_enc[1]["export_time"]) @mock.patch("cumulus_etl.formats.Format.write_records") async def test_batch_metadata(self, mock_write): @@ -186,4 +209,5 @@ async def test_batch_metadata(self, mock_write): }, mock_write.call_args_list[1][0][0].groups, # first (actual) covid batch ) - self.assertEqual({"nonexistent"}, mock_write.call_args_list[2][0][0].groups) # second (faked) covid batch + # second (faked) covid batch + self.assertEqual({"nonexistent"}, mock_write.call_args_list[2][0][0].groups) diff --git a/tests/covid_symptom/test_covid_results.py b/tests/covid_symptom/test_covid_results.py index 443fbbe4..4cb488ff 100644 --- a/tests/covid_symptom/test_covid_results.py +++ b/tests/covid_symptom/test_covid_results.py @@ -43,17 +43,33 @@ async def test_unknown_modifier_extensions_skipped_for_nlp_symptoms(self): # Invalid codes ([], False), ([{"system": "http://cumulus.smarthealthit.org/i2b2", "code": "NOTE:0"}], False), - ([{"system": "https://fhir.cerner.com/96976f07-eccb-424c-9825-e0d0b887148b/codeSet/72", "code": "0"}], False), + ( + [ + { + "system": "https://fhir.cerner.com/96976f07-eccb-424c-9825-e0d0b887148b/codeSet/72", + "code": "0", + } + ], + False, + ), ([{"system": "http://loinc.org", "code": "00000-0"}], False), ([{"system": "http://example.org", "code": "nope"}], False), # Valid codes ([{"system": "http://cumulus.smarthealthit.org/i2b2", "code": "NOTE:3710480"}], True), ( - [{"system": 
"https://fhir.cerner.com/96976f07-eccb-424c-9825-e0d0b887148b/codeSet/72", "code": "3710480"}], + [ + { + "system": "https://fhir.cerner.com/96976f07-eccb-424c-9825-e0d0b887148b/codeSet/72", + "code": "3710480", + } + ], True, ), ([{"system": "http://loinc.org", "code": "57053-1"}], True), - ([{"system": "nope", "code": "nope"}, {"system": "http://loinc.org", "code": "57053-1"}], True), + ( + [{"system": "nope", "code": "nope"}, {"system": "http://loinc.org", "code": "57053-1"}], + True, + ), ) @ddt.unpack async def test_ed_note_filtering_for_nlp(self, codings, expected): @@ -98,10 +114,13 @@ async def test_non_ed_visit_is_skipped_for_covid_symptoms(self): ({"docStatus": "entered-in-error"}, False), ({"docStatus": "final"}, True), ({"docStatus": "amended"}, True), - ({}, True), # without any docStatus, we still run NLP on it ("status" is required and can't be skipped) + # without any docStatus, we still run NLP on it ("status" is required and can't be skipped) + ({}, True), ) @ddt.unpack - async def test_bad_doc_status_is_skipped_for_covid_symptoms(self, status: dict, should_process: bool): + async def test_bad_doc_status_is_skipped_for_covid_symptoms( + self, status: dict, should_process: bool + ): """Verify we ignore certain docStatus codes for the covid symptoms NLP""" docref = i2b2_mock_data.documentreference() docref.update(status) @@ -117,8 +136,10 @@ async def test_bad_doc_status_is_skipped_for_covid_symptoms(self, status: dict, ([("http://localhost/file-cough", "text/plain")], "cough"), # handles absolute URL ([("file-cough", "text/html")], "cough"), # handles html ([("file-cough", "application/xhtml+xml")], "cough"), # handles xhtml - ([("file-cough", "text/html"), ("file-fever", "text/plain")], "fever"), # prefers text/plain to html - ([("file-cough", "application/xhtml+xml"), ("file-fever", "text/html")], "fever"), # prefers html to xhtml + # prefers text/plain to html + ([("file-cough", "text/html"), ("file-fever", "text/plain")], "fever"), + # prefers html to xhtml + ([("file-cough", "application/xhtml+xml"), ("file-fever", "text/html")], "fever"), ([("file-cough", "text/nope")], None), # ignores unsupported mimetypes ) @ddt.unpack @@ -131,7 +152,9 @@ async def test_note_urls_downloaded(self, attachments, expected_text, respx_mock respx_mock.post(os.environ["URL_CTAKES_REST"]).pass_through() # ignore cTAKES docref0 = i2b2_mock_data.documentreference() - docref0["content"] = [{"attachment": {"url": a[0], "contentType": a[1]}} for a in attachments] + docref0["content"] = [ + {"attachment": {"url": a[0], "contentType": a[1]}} for a in attachments + ] self.make_json("DocumentReference", "doc0", **docref0) async with self.job_config.client: @@ -153,7 +176,9 @@ async def test_nlp_errors_saved(self): await covid_symptom.CovidSymptomNlpResultsTask(self.job_config, self.scrubber).run() - self.assertEqual(["nlp-errors.ndjson"], os.listdir(f"{self.errors_dir}/covid_symptom__nlp_results")) + self.assertEqual( + ["nlp-errors.ndjson"], os.listdir(f"{self.errors_dir}/covid_symptom__nlp_results") + ) self.assertEqual( ["A", "C"], # pre-scrubbed versions of the docrefs are stored, for easier debugging [ diff --git a/tests/ctakesmock.py b/tests/ctakesmock.py index b68cdbb5..39bb661e 100644 --- a/tests/ctakesmock.py +++ b/tests/ctakesmock.py @@ -27,7 +27,8 @@ class CtakesMixin(unittest.TestCase): def setUp(self): super().setUp() - version_patcher = mock.patch("ctakesclient.__version__", new="1.2.0") # just freeze this in place + # just freeze the version in place + version_patcher = 
mock.patch("ctakesclient.__version__", new="1.2.0") self.addCleanup(version_patcher.stop) version_patcher.start() @@ -37,7 +38,8 @@ def setUp(self): self._run_fake_ctakes_server(f"{self.ctakes_overrides.name}/symptoms.bsv") cnlp_patcher = mock.patch( - "cumulus_etl.nlp.extract.ctakesclient.transformer.list_polarity", side_effect=fake_transformer_list_polarity + "cumulus_etl.nlp.extract.ctakesclient.transformer.list_polarity", + side_effect=fake_transformer_list_polarity, ) self.addCleanup(cnlp_patcher.stop) self.cnlp_mock = cnlp_patcher.start() @@ -59,7 +61,11 @@ def _run_fake_ctakes_server(self, overrides_path: str) -> None: self._ctakes_called.value = 0 self.ctakes_server = multiprocessing.Process( target=partial( - _serve_with_restarts, overrides_path, CtakesMixin.ctakes_port, self._ctakes_called, has_started + _serve_with_restarts, + overrides_path, + CtakesMixin.ctakes_port, + self._ctakes_called, + has_started, ), daemon=True, ) @@ -76,7 +82,10 @@ def _get_mtime(path) -> float | None: def _serve_with_restarts( - overrides_path: str, port: int, was_called: multiprocessing.Value, has_started: multiprocessing.Event + overrides_path: str, + port: int, + was_called: multiprocessing.Value, + has_started: multiprocessing.Event, ) -> None: server_address = ("", port) mtime = None @@ -174,8 +183,18 @@ def fake_ctakes_extract(sentence: str) -> typesystem.CtakesJSON: "text": fever_word, "polarity": 0, "conceptAttributes": [ - {"code": "386661006", "cui": "C0015967", "codingScheme": "SNOMEDCT_US", "tui": "T184"}, - {"code": "50177009", "cui": "C0015967", "codingScheme": "SNOMEDCT_US", "tui": "T184"}, + { + "code": "386661006", + "cui": "C0015967", + "codingScheme": "SNOMEDCT_US", + "tui": "T184", + }, + { + "code": "50177009", + "cui": "C0015967", + "codingScheme": "SNOMEDCT_US", + "tui": "T184", + }, ], "type": "SignSymptomMention", }, @@ -186,7 +205,12 @@ def fake_ctakes_extract(sentence: str) -> typesystem.CtakesJSON: "text": fever_word, "polarity": 0, "conceptAttributes": [ - {"code": "422587007", "cui": "C0027497", "codingScheme": "SNOMEDCT_US", "tui": "T184"}, + { + "code": "422587007", + "cui": "C0027497", + "codingScheme": "SNOMEDCT_US", + "tui": "T184", + }, ], "type": "SignSymptomMention", }, @@ -197,10 +221,30 @@ def fake_ctakes_extract(sentence: str) -> typesystem.CtakesJSON: "text": itch_word, "polarity": 0, "conceptAttributes": [ - {"code": "418290006", "cui": "C0033774", "codingScheme": "SNOMEDCT_US", "tui": "T184"}, - {"code": "279333002", "cui": "C0033774", "codingScheme": "SNOMEDCT_US", "tui": "T184"}, - {"code": "424492005", "cui": "C0033774", "codingScheme": "SNOMEDCT_US", "tui": "T184"}, - {"code": "418363000", "cui": "C0033774", "codingScheme": "SNOMEDCT_US", "tui": "T184"}, + { + "code": "418290006", + "cui": "C0033774", + "codingScheme": "SNOMEDCT_US", + "tui": "T184", + }, + { + "code": "279333002", + "cui": "C0033774", + "codingScheme": "SNOMEDCT_US", + "tui": "T184", + }, + { + "code": "424492005", + "cui": "C0033774", + "codingScheme": "SNOMEDCT_US", + "tui": "T184", + }, + { + "code": "418363000", + "cui": "C0033774", + "codingScheme": "SNOMEDCT_US", + "tui": "T184", + }, ], "type": "SignSymptomMention", }, diff --git a/tests/deid/test_deid_codebook.py b/tests/deid/test_deid_codebook.py index 3abab25d..fe7dc635 100644 --- a/tests/deid/test_deid_codebook.py +++ b/tests/deid/test_deid_codebook.py @@ -37,8 +37,11 @@ def test_hashed_type(self): fake_id = cb.fake_id("Condition", "1") self.assertEqual(fake_id, cb.fake_id("Condition", "1")) self.assertNotEqual(fake_id, 
cb.fake_id("Condition", "2")) - self.assertEqual(fake_id, cb.fake_id("Observation", "1")) # '1' hashes the same across types - self.assertEqual("ee1b8555df1476e7512bc31940148a7821edae6e152e92037e6e8d7e948800a4", fake_id) + # '1' hashes the same across types + self.assertEqual(fake_id, cb.fake_id("Observation", "1")) + self.assertEqual( + "ee1b8555df1476e7512bc31940148a7821edae6e152e92037e6e8d7e948800a4", fake_id + ) self.assertEqual("31323334", cb.db.settings.get("id_salt")) def test_missing_db_file(self): @@ -111,7 +114,10 @@ def test_save_and_load(self): db.save(tmpdir) # Verify that we saved the cached mapping to disk too - self.assertEqual(expected_mapping, common.read_json(os.path.join(tmpdir, "codebook-cached-mappings.json"))) + self.assertEqual( + expected_mapping, + common.read_json(os.path.join(tmpdir, "codebook-cached-mappings.json")), + ) db2 = CodebookDB(tmpdir) diff --git a/tests/deid/test_deid_philter.py b/tests/deid/test_deid_philter.py index 2959503e..a9ea1a72 100644 --- a/tests/deid/test_deid_philter.py +++ b/tests/deid/test_deid_philter.py @@ -15,16 +15,28 @@ def setUp(self): self.scrubber = deid.Scrubber(use_philter=True) @ddt.data( - ({"CodeableConcept": {"text": "Fever at 123 Main St"}}, {"CodeableConcept": {"text": "Fever at *** **** **"}}), - ({"Coding": {"display": "Patient 012-34-5678"}}, {"Coding": {"display": "Patient ***-**-****"}}), + ( + {"CodeableConcept": {"text": "Fever at 123 Main St"}}, + {"CodeableConcept": {"text": "Fever at *** **** **"}}, + ), + ( + {"Coding": {"display": "Patient 012-34-5678"}}, + {"Coding": {"display": "Patient ***-**-****"}}, + ), ( # philter catches the month for some reason, but correctly leaves the date numbers alone {"resourceType": "Observation", "valueString": "Born on december 12 2012"}, {"resourceType": "Observation", "valueString": "Born on ******** 12 2012"}, ), ( - {"resourceType": "Observation", "component": [{"valueString": "Contact at foo@bar.com"}]}, - {"resourceType": "Observation", "component": [{"valueString": "Contact at ***@***.***"}]}, + { + "resourceType": "Observation", + "component": [{"valueString": "Contact at foo@bar.com"}], + }, + { + "resourceType": "Observation", + "component": [{"valueString": "Contact at ***@***.***"}], + }, ), ) @ddt.unpack diff --git a/tests/deid/test_deid_scrubber.py b/tests/deid/test_deid_scrubber.py index c1bd1a2f..3f044515 100644 --- a/tests/deid/test_deid_scrubber.py +++ b/tests/deid/test_deid_scrubber.py @@ -32,7 +32,10 @@ def test_encounter(self): scrubber = Scrubber() self.assertTrue(scrubber.scrub_resource(encounter)) self.assertEqual(encounter["id"], scrubber.codebook.fake_id("Encounter", "67890")) - self.assertEqual(encounter["subject"]["reference"], f"Patient/{scrubber.codebook.fake_id('Patient', '12345')}") + self.assertEqual( + encounter["subject"]["reference"], + f"Patient/{scrubber.codebook.fake_id('Patient', '12345')}", + ) def test_condition(self): """Verify a basic condition (hashed ids)""" @@ -44,9 +47,13 @@ def test_condition(self): scrubber = Scrubber() self.assertTrue(scrubber.scrub_resource(condition)) self.assertEqual(condition["id"], scrubber.codebook.fake_id("Condition", "4567")) - self.assertEqual(condition["subject"]["reference"], f"Patient/{scrubber.codebook.fake_id('Patient', '12345')}") self.assertEqual( - condition["encounter"]["reference"], f"Encounter/{scrubber.codebook.fake_id('Encounter', '67890')}" + condition["subject"]["reference"], + f"Patient/{scrubber.codebook.fake_id('Patient', '12345')}", + ) + self.assertEqual( + 
condition["encounter"]["reference"], + f"Encounter/{scrubber.codebook.fake_id('Encounter', '67890')}", ) def test_documentreference(self): @@ -62,7 +69,10 @@ def test_documentreference(self): scrubber = Scrubber() self.assertTrue(scrubber.scrub_resource(docref)) self.assertEqual(docref["id"], scrubber.codebook.fake_id("DocumentReference", "345")) - self.assertEqual(docref["subject"]["reference"], f"Patient/{scrubber.codebook.fake_id('Patient', '12345')}") + self.assertEqual( + docref["subject"]["reference"], + f"Patient/{scrubber.codebook.fake_id('Patient', '12345')}", + ) self.assertEqual( docref["context"]["encounter"][0]["reference"], f"Encounter/{scrubber.codebook.fake_id('Encounter', '67890')}", @@ -99,7 +109,9 @@ def test_unknown_modifier_extension(self): def test_nlp_extensions_allowed(self): """Confirm we that nlp-generated resources are allowed, with their modifier extensions""" - match = typesystem.MatchText({"begin": 0, "end": 1, "polarity": 0, "text": "f", "type": "SignSymptomMention"}) + match = typesystem.MatchText( + {"begin": 0, "end": 1, "polarity": 0, "text": "f", "type": "SignSymptomMention"} + ) observation = text2fhir.nlp_observation("1", "2", "3", match).as_json() scrubber = Scrubber() diff --git a/tests/etl/base.py b/tests/etl/base.py index d10480d9..3dedebf8 100644 --- a/tests/etl/base.py +++ b/tests/etl/base.py @@ -133,7 +133,9 @@ def setUp(self) -> None: batch_size=5, dir_errors=self.errors_dir, export_group_name="test-group", - export_datetime=datetime.datetime(2012, 10, 10, 5, 30, 12, tzinfo=datetime.timezone.utc), + export_datetime=datetime.datetime( + 2012, 10, 10, 5, 30, 12, tzinfo=datetime.timezone.utc + ), ) def make_formatter(dbname: str, **kwargs): @@ -168,5 +170,6 @@ def make_json(self, resource_type, resource_id, **kwargs): self.json_file_count += 1 filename = f"{self.json_file_count}.ndjson" common.write_json( - os.path.join(self.input_dir, filename), {"resourceType": resource_type, **kwargs, "id": resource_id} + os.path.join(self.input_dir, filename), + {"resourceType": resource_type, **kwargs, "id": resource_id}, ) diff --git a/tests/etl/test_etl_cli.py b/tests/etl/test_etl_cli.py index 55c6fb23..c82c93b6 100644 --- a/tests/etl/test_etl_cli.py +++ b/tests/etl/test_etl_cli.py @@ -66,7 +66,9 @@ def fake_scrub(phi_dir: str): # Run a couple checks to ensure that we do indeed have PHI in this dir self.assertIn("Patient.ndjson", os.listdir(phi_dir)) - patients = list(cumulus_fhir_support.read_multiline_json(os.path.join(phi_dir, "Patient.ndjson"))) + patients = list( + cumulus_fhir_support.read_multiline_json(os.path.join(phi_dir, "Patient.ndjson")) + ) first = patients[0] self.assertEqual("02139", first["address"][0]["postalCode"]) @@ -87,14 +89,18 @@ async def test_unknown_task(self): async def test_failed_task(self): # Make it so any writes will fail - with mock.patch("cumulus_etl.formats.ndjson.NdjsonFormat.write_format", side_effect=Exception): + with mock.patch( + "cumulus_etl.formats.ndjson.NdjsonFormat.write_format", side_effect=Exception + ): with self.assertRaises(SystemExit) as cm: await self.run_etl() self.assertEqual(errors.TASK_FAILED, cm.exception.code) async def test_single_task(self): # Grab all observations before we mock anything - observations = loaders.FhirNdjsonLoader(store.Root(self.input_path)).load_all(["Observation"]) + observations = loaders.FhirNdjsonLoader(store.Root(self.input_path)).load_all( + ["Observation"] + ) def fake_load_all(internal_self, resources): del internal_self @@ -106,12 +112,18 @@ def 
fake_load_all(internal_self, resources): await self.run_etl(tasks=["observation"]) # Confirm we only wrote the one resource - self.assertEqual({"etl__completion", "observation", "JobConfig"}, set(os.listdir(self.output_path))) - self.assertEqual(["observation.000.ndjson"], os.listdir(os.path.join(self.output_path, "observation"))) + self.assertEqual( + {"etl__completion", "observation", "JobConfig"}, set(os.listdir(self.output_path)) + ) + self.assertEqual( + ["observation.000.ndjson"], os.listdir(os.path.join(self.output_path, "observation")) + ) async def test_multiple_tasks(self): # Grab all observations before we mock anything - loaded = loaders.FhirNdjsonLoader(store.Root(self.input_path)).load_all(["Observation", "Patient"]) + loaded = loaders.FhirNdjsonLoader(store.Root(self.input_path)).load_all( + ["Observation", "Patient"] + ) def fake_load_all(internal_self, resources): del internal_self @@ -123,9 +135,16 @@ def fake_load_all(internal_self, resources): await self.run_etl(tasks=["observation", "patient"]) # Confirm we only wrote the two resources - self.assertEqual({"etl__completion", "observation", "patient", "JobConfig"}, set(os.listdir(self.output_path))) - self.assertEqual(["observation.000.ndjson"], os.listdir(os.path.join(self.output_path, "observation"))) - self.assertEqual(["patient.000.ndjson"], os.listdir(os.path.join(self.output_path, "patient"))) + self.assertEqual( + {"etl__completion", "observation", "patient", "JobConfig"}, + set(os.listdir(self.output_path)), + ) + self.assertEqual( + ["observation.000.ndjson"], os.listdir(os.path.join(self.output_path, "observation")) + ) + self.assertEqual( + ["patient.000.ndjson"], os.listdir(os.path.join(self.output_path, "patient")) + ) async def test_codebook_is_saved_during(self): """Verify that we are saving the codebook as we go""" @@ -136,7 +155,9 @@ async def test_codebook_is_saved_during(self): # Cause a system exit as soon as we try to write a file. # The goal is that the codebook is already in place by this time. 
with self.assertRaises(SystemExit): - with mock.patch("cumulus_etl.formats.ndjson.NdjsonFormat.write_format", side_effect=SystemExit): + with mock.patch( + "cumulus_etl.formats.ndjson.NdjsonFormat.write_format", side_effect=SystemExit + ): await self.run_etl(tasks=["patient"]) # Ensure we wrote a valid codebook out @@ -221,7 +242,9 @@ async def test_task_init_checks(self, mock_check): async def test_completion_args(self, etl_args, loader_vals, expected_vals): """Verify that we parse completion args with the correct fallbacks and checks.""" # Grab all observations before we mock anything - observations = loaders.FhirNdjsonLoader(store.Root(self.input_path)).load_all(["Observation"]) + observations = loaders.FhirNdjsonLoader(store.Root(self.input_path)).load_all( + ["Observation"] + ) def fake_load_all(internal_self, resources): del resources @@ -412,11 +435,15 @@ def setUp(self): ] def path_for_checksum(self, prefix, checksum): - return os.path.join(self.phi_path, "ctakes-cache", prefix, checksum[0:4], f"sha256-{checksum}.json") + return os.path.join( + self.phi_path, "ctakes-cache", prefix, checksum[0:4], f"sha256-{checksum}.json" + ) def read_symptoms(self): """Loads the output symptoms ndjson from disk""" - path = os.path.join(self.output_path, "covid_symptom__nlp_results", "covid_symptom__nlp_results.000.ndjson") + path = os.path.join( + self.output_path, "covid_symptom__nlp_results", "covid_symptom__nlp_results.000.ndjson" + ) with open(path, encoding="utf8") as f: lines = f.readlines() return [json.loads(line) for line in lines] @@ -433,8 +460,13 @@ async def test_stores_cached_json(self): for index, checksum in enumerate(self.expected_checksums): ner = fake_ctakes_extract(facts[index]) - self.assertEqual(ner.as_json(), common.read_json(self.path_for_checksum(self.CACHE_FOLDER, checksum))) - self.assertEqual([0, 0], common.read_json(self.path_for_checksum(f"{self.CACHE_FOLDER}-cnlp_v2", checksum))) + self.assertEqual( + ner.as_json(), common.read_json(self.path_for_checksum(self.CACHE_FOLDER, checksum)) + ) + self.assertEqual( + [0, 0], + common.read_json(self.path_for_checksum(f"{self.CACHE_FOLDER}-cnlp_v2", checksum)), + ) async def test_does_not_hit_server_if_cache_exists(self): for index, checksum in enumerate(self.expected_checksums): @@ -452,7 +484,12 @@ async def test_does_not_hit_server_if_cache_exists(self): "polarity": 0, "type": "SignSymptomMention", "conceptAttributes": [ - {"code": "68235000", "cui": "C0027424", "codingScheme": "SNOMEDCT_US", "tui": "T184"}, + { + "code": "68235000", + "cui": "C0027424", + "codingScheme": "SNOMEDCT_US", + "tui": "T184", + }, ], } ], @@ -475,7 +512,8 @@ async def test_does_not_hit_server_if_cache_exists(self): self.assertEqual({"foobar0", "foobar1"}, {x["match"]["text"] for x in symptoms}) for symptom in symptoms: self.assertEqual( - {("68235000", "C0027424")}, {(x["code"], x["cui"]) for x in symptom["match"]["conceptAttributes"]} + {("68235000", "C0027424")}, + {(x["code"], x["cui"]) for x in symptom["match"]["conceptAttributes"]}, ) @respx.mock @@ -522,7 +560,8 @@ async def test_cnlp_rejects(self): self.assertEqual(2, len(symptoms)) # Confirm that the only symptom to survive was the second nausea one self.assertEqual( - {("422587007", "C0027497")}, {(x["code"], x["cui"]) for x in symptoms[0]["match"]["conceptAttributes"]} + {("422587007", "C0027497")}, + {(x["code"], x["cui"]) for x in symptoms[0]["match"]["conceptAttributes"]}, ) async def test_non_covid_symptoms_skipped(self): @@ -530,8 +569,11 @@ async def 
test_non_covid_symptoms_skipped(self): await self.run_etl(tasks=["covid_symptom__nlp_results"]) symptoms = self.read_symptoms() - self.assertEqual({"for"}, {x["match"]["text"] for x in symptoms}) # the second word ("for") is the fever word - attributes = itertools.chain.from_iterable(symptom["match"]["conceptAttributes"] for symptom in symptoms) + # the second word ("for") is the fever word + self.assertEqual({"for"}, {x["match"]["text"] for x in symptoms}) + attributes = itertools.chain.from_iterable( + symptom["match"]["conceptAttributes"] for symptom in symptoms + ) cuis = {x["cui"] for x in attributes} self.assertEqual({"C0027497", "C0015967"}, cuis) # notably, no C0033774 itch CUI diff --git a/tests/etl/test_etl_context.py b/tests/etl/test_etl_context.py index 06e217a2..8730d12d 100644 --- a/tests/etl/test_etl_context.py +++ b/tests/etl/test_etl_context.py @@ -33,7 +33,9 @@ def test_save_and_load(self): context.as_json(), ) - context.last_successful_datetime = datetime.datetime(2008, 5, 1, 14, 30, 30, tzinfo=datetime.timezone.utc) + context.last_successful_datetime = datetime.datetime( + 2008, 5, 1, 14, 30, 30, tzinfo=datetime.timezone.utc + ) self.assertEqual( { "last_successful_datetime": "2008-05-01T14:30:30+00:00", @@ -48,7 +50,9 @@ def test_save_and_load(self): def test_last_successful_props(self): context = JobContext("nope") - context.last_successful_datetime = datetime.datetime(2008, 5, 1, 14, 30, 30, tzinfo=datetime.timezone.utc) + context.last_successful_datetime = datetime.datetime( + 2008, 5, 1, 14, 30, 30, tzinfo=datetime.timezone.utc + ) context.last_successful_input_dir = "/input" context.last_successful_output_dir = "/output" self.assertEqual( diff --git a/tests/etl/test_tasks.py b/tests/etl/test_tasks.py index 663a0988..b417d499 100644 --- a/tests/etl/test_tasks.py +++ b/tests/etl/test_tasks.py @@ -63,8 +63,10 @@ def test_filtered_but_named_task(self): (None, "default"), ([], "default"), (filter(None, []), "default"), # iterable, not list - (["observation", "condition", "procedure"], ["condition", "observation", "procedure"]), # re-ordered - (["condition", "patient", "encounter"], ["encounter", "patient", "condition"]), # encounter and patient first + # re-ordered + (["observation", "condition", "procedure"], ["condition", "observation", "procedure"]), + # encounter and patient first + (["condition", "patient", "encounter"], ["encounter", "patient", "condition"]), ) @ddt.unpack def test_task_selection_ordering(self, user_tasks, expected_tasks): @@ -88,7 +90,8 @@ async def test_drop_duplicates(self): batch = self.format.write_records.call_args[0][0] self.assertEqual(2, len(batch.rows)) self.assertEqual( - {self.codebook.db.patient("A"), self.codebook.db.patient("B")}, {row["id"] for row in batch.rows} + {self.codebook.db.patient("A"), self.codebook.db.patient("B")}, + {row["id"] for row in batch.rows}, ) async def test_batch_write_errors_saved(self): @@ -102,14 +105,21 @@ async def test_batch_write_errors_saved(self): await basic_tasks.PatientTask(self.job_config, self.scrubber).run() self.assertEqual( - ["write-error.000.ndjson", "write-error.002.ndjson"], list(sorted(os.listdir(f"{self.errors_dir}/patient"))) + ["write-error.000.ndjson", "write-error.002.ndjson"], + list(sorted(os.listdir(f"{self.errors_dir}/patient"))), ) self.assertEqual( - {"resourceType": "Patient", "id": "30d95f17d9f51f3a151c51bf0a7fcb1717363f3a87d2dbace7d594ee68d3a82f"}, + { + "resourceType": "Patient", + "id": "30d95f17d9f51f3a151c51bf0a7fcb1717363f3a87d2dbace7d594ee68d3a82f", + }, 
common.read_json(f"{self.errors_dir}/patient/write-error.000.ndjson"), ) self.assertEqual( - {"resourceType": "Patient", "id": "ed9ab553005a7c9bdb26ecf9f612ea996ad99b1a96a34bf88c260f1c901d8289"}, + { + "resourceType": "Patient", + "id": "ed9ab553005a7c9bdb26ecf9f612ea996ad99b1a96a34bf88c260f1c901d8289", + }, common.read_json(f"{self.errors_dir}/patient/write-error.002.ndjson"), ) @@ -283,7 +293,11 @@ class TestMedicationRequestTask(TaskTestCase): async def test_inline_codes(self): """Verify that we handle basic normal inline codes (no external fetching) as a baseline""" - self.make_json("MedicationRequest", "InlineCode", medicationCodeableConcept={"text": "Old but checks out"}) + self.make_json( + "MedicationRequest", + "InlineCode", + medicationCodeableConcept={"text": "Old but checks out"}, + ) self.make_json("MedicationRequest", "NoCode") await basic_tasks.MedicationRequestTask(self.job_config, self.scrubber).run() @@ -318,7 +332,10 @@ async def test_contained_medications(self): # Confirm we wrote the basic MedicationRequest self.assertEqual(1, med_req_format.write_records.call_count) batch = med_req_format.write_records.call_args[0][0] - self.assertEqual(f'#{self.codebook.db.resource_hash("123")}', batch.rows[0]["medicationReference"]["reference"]) + self.assertEqual( + f'#{self.codebook.db.resource_hash("123")}', + batch.rows[0]["medicationReference"]["reference"], + ) # Confirm we wrote an empty dataframe to the medication table self.assertEqual(1, med_format.write_records.call_count) @@ -328,7 +345,9 @@ async def test_contained_medications(self): @mock.patch("cumulus_etl.fhir.download_reference") async def test_external_medications(self, mock_download): """Verify that we download referenced medications""" - self.make_json("MedicationRequest", "A", medicationReference={"reference": "Medication/123"}) + self.make_json( + "MedicationRequest", "A", medicationReference={"reference": "Medication/123"} + ) mock_download.return_value = {"resourceType": "Medication", "id": "med1"} await basic_tasks.MedicationRequestTask(self.job_config, self.scrubber).run() @@ -376,7 +395,9 @@ async def test_external_medications(self, mock_download): @mock.patch("cumulus_etl.fhir.download_reference") async def test_external_medication_scrubbed(self, mock_download): """Verify that we scrub referenced medications as we download them""" - self.make_json("MedicationRequest", "A", medicationReference={"reference": "Medication/123"}) + self.make_json( + "MedicationRequest", "A", medicationReference={"reference": "Medication/123"} + ) mock_download.return_value = { "resourceType": "Medication", "id": "med1", @@ -405,9 +426,15 @@ async def test_external_medication_scrubbed(self, mock_download): @mock.patch("cumulus_etl.fhir.download_reference") async def test_external_medications_with_error(self, mock_download): """Verify that we record/save download errors""" - self.make_json("MedicationRequest", "A", medicationReference={"reference": "Medication/123"}) - self.make_json("MedicationRequest", "B", medicationReference={"reference": "Medication/456"}) - self.make_json("MedicationRequest", "C", medicationReference={"reference": "Medication/789"}) + self.make_json( + "MedicationRequest", "A", medicationReference={"reference": "Medication/123"} + ) + self.make_json( + "MedicationRequest", "B", medicationReference={"reference": "Medication/456"} + ) + self.make_json( + "MedicationRequest", "C", medicationReference={"reference": "Medication/789"} + ) mock_download.side_effect = [ # Fail on first and third ValueError("bad 
hostname"), {"resourceType": "Medication", "id": "medB"}, @@ -427,25 +454,37 @@ async def test_external_medications_with_error(self, mock_download): # Confirm we still wrote out the medication for B self.assertEqual(1, med_format.write_records.call_count) batch = med_format.write_records.call_args[0][0] - self.assertEqual([self.codebook.db.resource_hash("medB")], [row["id"] for row in batch.rows]) + self.assertEqual( + [self.codebook.db.resource_hash("medB")], [row["id"] for row in batch.rows] + ) # And we saved the error? med_error_dir = f"{self.errors_dir}/medicationrequest" - self.assertEqual(["medication-fetch-errors.ndjson"], list(sorted(os.listdir(med_error_dir)))) + self.assertEqual( + ["medication-fetch-errors.ndjson"], list(sorted(os.listdir(med_error_dir))) + ) self.assertEqual( ["A", "C"], # pre-scrubbed versions of the resources are stored, for easier debugging [ x["id"] - for x in cumulus_fhir_support.read_multiline_json(f"{med_error_dir}/medication-fetch-errors.ndjson") + for x in cumulus_fhir_support.read_multiline_json( + f"{med_error_dir}/medication-fetch-errors.ndjson" + ) ], ) @mock.patch("cumulus_etl.fhir.download_reference") async def test_external_medications_skips_duplicates(self, mock_download): """Verify that we skip medications that are repeated""" - self.make_json("MedicationRequest", "A", medicationReference={"reference": "Medication/dup"}) - self.make_json("MedicationRequest", "B", medicationReference={"reference": "Medication/dup"}) - self.make_json("MedicationRequest", "C", medicationReference={"reference": "Medication/new"}) + self.make_json( + "MedicationRequest", "A", medicationReference={"reference": "Medication/dup"} + ) + self.make_json( + "MedicationRequest", "B", medicationReference={"reference": "Medication/dup"} + ) + self.make_json( + "MedicationRequest", "C", medicationReference={"reference": "Medication/new"} + ) self.job_config.batch_size = 1 # to confirm we detect duplicates even across batches mock_download.side_effect = [ {"resourceType": "Medication", "id": "dup"}, @@ -475,8 +514,12 @@ async def test_external_medications_skips_duplicates(self, mock_download): @mock.patch("cumulus_etl.fhir.download_reference") async def test_external_medications_skips_unknown_modifiers(self, mock_download): """Verify that we skip medications with unknown modifier extensions (unlikely, but still)""" - self.make_json("MedicationRequest", "A", medicationReference={"reference": "Medication/odd"}) - self.make_json("MedicationRequest", "B", medicationReference={"reference": "Medication/good"}) + self.make_json( + "MedicationRequest", "A", medicationReference={"reference": "Medication/odd"} + ) + self.make_json( + "MedicationRequest", "B", medicationReference={"reference": "Medication/good"} + ) mock_download.side_effect = [ { "resourceType": "Medication", @@ -495,4 +538,6 @@ async def test_external_medications_skips_unknown_modifiers(self, mock_download) self.assertEqual(1, med_format.write_records.call_count) batch = med_format.write_records.call_args[0][0] - self.assertEqual([self.codebook.db.resource_hash("good")], [row["id"] for row in batch.rows]) # no "odd" + self.assertEqual( # no "odd" + [self.codebook.db.resource_hash("good")], [row["id"] for row in batch.rows] + ) diff --git a/tests/fhir/test_fhir_client.py b/tests/fhir/test_fhir_client.py index e8e7d34d..0a3092e1 100644 --- a/tests/fhir/test_fhir_client.py +++ b/tests/fhir/test_fhir_client.py @@ -27,7 +27,9 @@ def setUp(self): # By default, set up a working server and auth. 
Tests can break things as needed. self.client_id = "my-client-id" - self.jwk = jwk.JWK.generate(kty="RSA", alg="RS384", kid="a", key_ops=["sign", "verify"]).export(as_dict=True) + self.jwk = jwk.JWK.generate( + kty="RSA", alg="RS384", kid="a", key_ops=["sign", "verify"] + ).export(as_dict=True) self.jwks = {"keys": [self.jwk]} self.server_url = "https://example.com/fhir" self.token_url = "https://auth.example.com/token" @@ -120,7 +122,10 @@ async def test_auth_with_jwks(self): ) async with fhir.FhirClient( - self.server_url, ["Condition", "Patient"], smart_client_id=self.client_id, smart_jwks=self.jwks + self.server_url, + ["Condition", "Patient"], + smart_client_id=self.client_id, + smart_jwks=self.jwks, ) as client: await client.request("GET", "foo") @@ -145,7 +150,9 @@ async def test_auth_with_bearer_token(self): headers={"Authorization": "Bearer fob"}, ) - async with fhir.FhirClient(self.server_url, ["Condition", "Patient"], bearer_token="fob") as server: + async with fhir.FhirClient( + self.server_url, ["Condition", "Patient"], bearer_token="fob" + ) as server: await server.request("GET", "foo") async def test_auth_with_basic_auth(self): @@ -155,7 +162,9 @@ async def test_auth_with_basic_auth(self): headers={"Authorization": "Basic VXNlcjpwNHNzdzByZA=="}, ) - async with fhir.FhirClient(self.server_url, [], basic_user="User", basic_password="p4ssw0rd") as server: + async with fhir.FhirClient( + self.server_url, [], basic_user="User", basic_password="p4ssw0rd" + ) as server: await server.request("GET", "foo") async def test_get_with_new_header(self): @@ -168,7 +177,9 @@ async def test_get_with_new_header(self): }, ) - async with fhir.FhirClient(self.server_url, [], smart_client_id=self.client_id, smart_jwks=self.jwks) as server: + async with fhir.FhirClient( + self.server_url, [], smart_client_id=self.client_id, smart_jwks=self.jwks + ) as server: # With new header and stream await server.request("GET", "foo", headers={"Test": "Value"}, stream=True) @@ -181,7 +192,9 @@ async def test_get_with_overriden_header(self): }, ) - async with fhir.FhirClient(self.server_url, [], smart_client_id=self.client_id, smart_jwks=self.jwks) as server: + async with fhir.FhirClient( + self.server_url, [], smart_client_id=self.client_id, smart_jwks=self.jwks + ) as server: # With overriding a header and default stream (False) await server.request("GET", "bar", headers={"Accept": "text/plain"}) @@ -194,7 +207,9 @@ async def test_get_with_overriden_header(self): ) async def test_jwks_without_suitable_key(self, bad_jwks): with self.assertRaisesRegex(errors.FatalError, "No valid private key found"): - async with fhir.FhirClient(self.server_url, [], smart_client_id=self.client_id, smart_jwks=bad_jwks): + async with fhir.FhirClient( + self.server_url, [], smart_client_id=self.client_id, smart_jwks=bad_jwks + ): pass @ddt.data( @@ -218,12 +233,17 @@ async def test_bad_smart_config(self, bad_config_override): ) with self.assertRaisesRegex(SystemExit, str(errors.FHIR_AUTH_FAILED)): - async with fhir.FhirClient(self.server_url, [], smart_client_id=self.client_id, smart_jwks=self.jwks): + async with fhir.FhirClient( + self.server_url, [], smart_client_id=self.client_id, smart_jwks=self.jwks + ): pass @ddt.data( ({"json": {"error_description": "Ouch!"}}, "Ouch!"), - ({"json": {"error_uri": "http://ouch.com/sadface"}}, 'visit "http://ouch.com/sadface" for more details'), + ( + {"json": {"error_uri": "http://ouch.com/sadface"}}, + 'visit "http://ouch.com/sadface" for more details', + ), # If nothing comes back, we use 
the default httpx error message ( {}, @@ -239,11 +259,16 @@ async def test_authorize_error(self, response_params, expected_error): self.respx_mock["token"].respond(400, **response_params) with mock.patch("cumulus_etl.errors.fatal") as mock_fatal: - async with fhir.FhirClient(self.server_url, [], smart_client_id=self.client_id, smart_jwks=self.jwks): + async with fhir.FhirClient( + self.server_url, [], smart_client_id=self.client_id, smart_jwks=self.jwks + ): pass self.assertEqual( - mock.call(f"Could not authenticate with the FHIR server: {expected_error}", errors.FHIR_AUTH_FAILED), + mock.call( + f"Could not authenticate with the FHIR server: {expected_error}", + errors.FHIR_AUTH_FAILED, + ), mock_fatal.call_args, ) @@ -252,7 +277,9 @@ async def test_get_error_401(self): route = self.respx_mock.get(f"{self.server_url}/foo") route.side_effect = [make_response(status_code=401), make_response()] - async with fhir.FhirClient(self.server_url, [], smart_client_id=self.client_id, smart_jwks=self.jwks) as server: + async with fhir.FhirClient( + self.server_url, [], smart_client_id=self.client_id, smart_jwks=self.jwks + ) as server: self.assertEqual(1, self.respx_mock["token"].call_count) # Check that we correctly tried to re-authenticate @@ -266,7 +293,9 @@ async def test_get_error_429(self): self.respx_mock.get(f"{self.server_url}/retry-me").respond(429) self.respx_mock.get(f"{self.server_url}/nope").respond(430) - async with fhir.FhirClient(self.server_url, [], smart_client_id=self.client_id, smart_jwks=self.jwks) as server: + async with fhir.FhirClient( + self.server_url, [], smart_client_id=self.client_id, smart_jwks=self.jwks + ) as server: # Confirm 429 passes response = await server.request("GET", "retry-me") self.assertEqual(429, response.status_code) @@ -276,10 +305,15 @@ async def test_get_error_429(self): await server.request("GET", "nope") @ddt.data( + # OperationOutcome { - "json_payload": {"resourceType": "OperationOutcome", "issue": [{"diagnostics": "testmsg"}]} - }, # OperationOutcome - {"json_payload": {"issue": [{"diagnostics": "msg"}]}, "reason": "testmsg"}, # non-OperationOutcome json + "json_payload": { + "resourceType": "OperationOutcome", + "issue": [{"diagnostics": "testmsg"}], + } + }, + # non-OperationOutcome json + {"json_payload": {"issue": [{"diagnostics": "msg"}]}, "reason": "testmsg"}, {"text": "testmsg"}, # just pure text content {"reason": "testmsg"}, ) @@ -291,7 +325,9 @@ async def test_get_error_other(self, response_args): return_value=make_response(status_code=500, **response_args), ) - async with fhir.FhirClient(self.server_url, [], smart_client_id=self.client_id, smart_jwks=self.jwks) as server: + async with fhir.FhirClient( + self.server_url, [], smart_client_id=self.client_id, smart_jwks=self.jwks + ) as server: with self.assertRaisesRegex(errors.FatalError, "testmsg"): await server.request("GET", "foo") @@ -304,7 +340,12 @@ async def test_get_error_other(self, response_args): def test_added_binary_scope(self, resources_in, expected_resources_out, mock_client): """Verify that we add a Binary scope if DocumentReference is requested""" args = argparse.Namespace( - fhir_url=None, smart_client_id=None, smart_jwks=None, basic_user=None, basic_passwd=None, bearer_token=None + fhir_url=None, + smart_client_id=None, + smart_jwks=None, + basic_user=None, + basic_passwd=None, + bearer_token=None, ) fhir.create_fhir_client_for_cli(args, store.Root("/tmp"), resources_in) self.assertEqual(mock_client.call_args[0][1], expected_resources_out) @@ -351,6 +392,8 @@ async def 
test_client_id_in_header(self, server_type, expected_text): ) self.mock_as_server_type(server_type) - async with fhir.FhirClient(self.server_url, [], bearer_token="foo", smart_client_id="my-id") as server: + async with fhir.FhirClient( + self.server_url, [], bearer_token="foo", smart_client_id="my-id" + ) as server: response = await server.request("GET", "file") self.assertEqual(expected_text, response.text) diff --git a/tests/fhir/test_fhir_utils.py b/tests/fhir/test_fhir_utils.py index d8638e91..9c432bb6 100644 --- a/tests/fhir/test_fhir_utils.py +++ b/tests/fhir/test_fhir_utils.py @@ -20,7 +20,11 @@ class TestReferenceHandlers(utils.AsyncTestCase): ({"reference": "123", "type": "Patient"}, "Patient", "123"), ({"reference": "#123"}, None, "#123"), # Synthea style reference - ({"reference": "Patient?identifier=http://example.com|123"}, "Patient", "identifier=http://example.com|123"), + ( + {"reference": "Patient?identifier=http://example.com|123"}, + "Patient", + "identifier=http://example.com|123", + ), ) @ddt.unpack def test_unref_successes(self, full_reference, expected_type, expected_id): @@ -62,9 +66,21 @@ class TestDateParsing(utils.AsyncTestCase): ("1992-11-06", datetime.datetime(1992, 11, 6)), # naive ( "1992-11-06T13:28:17.239+02:00", - datetime.datetime(1992, 11, 6, 13, 28, 17, 239000, tzinfo=datetime.timezone(datetime.timedelta(hours=2))), + datetime.datetime( + 1992, + 11, + 6, + 13, + 28, + 17, + 239000, + tzinfo=datetime.timezone(datetime.timedelta(hours=2)), + ), + ), + ( + "1992-11-06T13:28:17.239Z", + datetime.datetime(1992, 11, 6, 13, 28, 17, 239000, tzinfo=datetime.timezone.utc), ), - ("1992-11-06T13:28:17.239Z", datetime.datetime(1992, 11, 6, 13, 28, 17, 239000, tzinfo=datetime.timezone.utc)), ) @ddt.unpack def test_parse_datetime(self, input_value, expected_value): @@ -167,4 +183,7 @@ async def test_download_reference(self, reference, expected_result): result = await fhir.download_reference(mock_client, reference) self.assertEqual(expected_result, result) - self.assertEqual([mock.call("GET", reference)] if expected_result else [], mock_client.request.call_args_list) + self.assertEqual( + [mock.call("GET", reference)] if expected_result else [], + mock_client.request.call_args_list, + ) diff --git a/tests/formats/test_deltalake.py b/tests/formats/test_deltalake.py index d62dcf1e..aca0e8fb 100644 --- a/tests/formats/test_deltalake.py +++ b/tests/formats/test_deltalake.py @@ -118,7 +118,11 @@ def test_last_updated_support(self): {"id": "now", "meta": {"lastUpdated": now}, "value": 1}, {"id": "now-without-zed", "meta": {"lastUpdated": now_without_zed}, "value": 1}, {"id": "future", "meta": {"lastUpdated": future}, "value": 1}, - {"id": "future-with-offset", "meta": {"lastUpdated": future_with_offset}, "value": 1}, + { + "id": "future-with-offset", + "meta": {"lastUpdated": future_with_offset}, + "value": 1, + }, # this next one is off-spec (lastUpdated must provide at least seconds), but still {"id": "future-partial", "meta": {"lastUpdated": "3000-01-01"}, "value": 1}, {"id": "missing-date-table", "meta": {}, "value": 1}, @@ -155,7 +159,11 @@ def test_last_updated_support(self): {"id": "now", "meta": {"lastUpdated": now}, "value": 1}, {"id": "now-without-zed", "meta": {"lastUpdated": now_without_zed}, "value": 1}, {"id": "future", "meta": {"lastUpdated": future}, "value": 1}, - {"id": "future-with-offset", "meta": {"lastUpdated": future_with_offset}, "value": 1}, + { + "id": "future-with-offset", + "meta": {"lastUpdated": future_with_offset}, + "value": 1, + }, {"id": 
"future-partial", "meta": {"lastUpdated": "3000-01-01"}, "value": 1}, {"id": "missing-date-table", "meta": {"lastUpdated": now}, "value": 2}, {"id": "missing-date-update", "meta": {}, "value": 2}, @@ -275,7 +283,9 @@ def test_merged_schema_for_resource(self): (pyarrow.int32(), 2000, pyarrow.int64(), 3000000000, False, "integer", 2000), ) @ddt.unpack - def test_column_type_merges(self, type1, val1, type2, val2, expected_success, expected_type, expected_value): + def test_column_type_merges( + self, type1, val1, type2, val2, expected_success, expected_type, expected_value + ): """Verify that if we write a slightly different, but compatible field to the delta lake, it works""" schema1 = pyarrow.schema( [ @@ -312,8 +322,16 @@ def test_group_field(self): self.df( aa={"group": "A", "val": 5}, # will be deleted as stale group member ab={"group": "A", "val": 10}, # will be updated - b={"group": "B", "val": 1}, # will be ignored because 2nd batch won't have group B in it - c={"group": "C", "val": 2}, # will be deleted as group with zero members in new batch + # will be ignored because 2nd batch won't have group B in it + b={ + "group": "B", + "val": 1, + }, + # will be deleted as group with zero members in new batch + c={ + "group": "C", + "val": 2, + }, ), group_field="value.group", groups={"A", "B", "C"}, @@ -337,7 +355,11 @@ def test_group_field(self): d={"group": 'D"', "val": 3}, # whole new group ), group_field="value.group", - groups={"A", "C", 'D"'}, # C is present but with no rows (existing rows will be deleted) + groups={ + "A", + "C", # C is present but with no rows (existing rows will be deleted) + 'D"', + }, ) self.assert_lake_equal( self.df( diff --git a/tests/hftest/test_hftask.py b/tests/hftest/test_hftask.py index 7a78258e..ddaa74b8 100644 --- a/tests/hftest/test_hftask.py +++ b/tests/hftest/test_hftask.py @@ -10,7 +10,9 @@ from tests.etl import BaseEtlSimple, TaskTestCase -def mock_prompt(respx_mock: respx.MockRouter, text: str, url: str = "http://localhost:8086/") -> respx.Route: +def mock_prompt( + respx_mock: respx.MockRouter, text: str, url: str = "http://localhost:8086/" +) -> respx.Route: full_prompt = f"""[INST] <> You will be given a clinical note, and you should reply with a short summary of that note. 
<</SYS>>
@@ -31,7 +33,9 @@ def mock_prompt(respx_mock: respx.MockRouter, text: str, url: str = "http://loca


 def mock_info(
-    respx_mock: respx.MockRouter, url: str = "http://localhost:8086/info", override: dict | None = None
+    respx_mock: respx.MockRouter,
+    url: str = "http://localhost:8086/info",
+    override: dict | None = None,
 ) -> respx.Route:
     response = {
         "model_id": "meta-llama/Llama-2-13b-chat-hf",
diff --git a/tests/loaders/i2b2/test_i2b2_etl.py b/tests/loaders/i2b2/test_i2b2_etl.py
index 5554d00a..2fa80a42 100644
--- a/tests/loaders/i2b2/test_i2b2_etl.py
+++ b/tests/loaders/i2b2/test_i2b2_etl.py
@@ -33,7 +33,9 @@ async def test_full_etl(self):
     async def test_export(self):
         with tempfile.TemporaryDirectory() as export_path:
             # Only run patient task to make the test faster and confirm we don't export unnecessary files
-            await self.run_etl(input_format="i2b2", export_to=export_path, tasks=["patient"], philter=False)
+            await self.run_etl(
+                input_format="i2b2", export_to=export_path, tasks=["patient"], philter=False
+            )

             expected_export_path = os.path.join(self.datadir, self.DATA_ROOT, "export")
             dircmp = filecmp.dircmp(export_path, expected_export_path)
diff --git a/tests/loaders/i2b2/test_i2b2_oracle_connect.py b/tests/loaders/i2b2/test_i2b2_oracle_connect.py
index 4dff598b..ad0fdf6b 100644
--- a/tests/loaders/i2b2/test_i2b2_oracle_connect.py
+++ b/tests/loaders/i2b2/test_i2b2_oracle_connect.py
@@ -40,7 +40,9 @@ def test_password_required(self, password):
         self.assertEqual(errors.SQL_PASSWORD_MISSING, cm.exception.code)

     @mock.patch("cumulus_etl.loaders.i2b2.oracle.connect.oracledb")
-    @mock.patch.dict(os.environ, {"CUMULUS_SQL_USER": "test-user", "CUMULUS_SQL_PASSWORD": "p4sswd"})
+    @mock.patch.dict(
+        os.environ, {"CUMULUS_SQL_USER": "test-user", "CUMULUS_SQL_PASSWORD": "p4sswd"}
+    )
     def test_connect(self, mock_oracledb):
         """Verify that we pass all the right parameters to Oracle when connecting"""
         connect.connect("tcp://localhost/foo")
diff --git a/tests/loaders/i2b2/test_i2b2_oracle_extract.py b/tests/loaders/i2b2/test_i2b2_oracle_extract.py
index 7bdb586c..0f3d3496 100644
--- a/tests/loaders/i2b2/test_i2b2_oracle_extract.py
+++ b/tests/loaders/i2b2/test_i2b2_oracle_extract.py
@@ -36,7 +36,9 @@ def test_list_observation_fact(self):
         results = extract.list_observation_fact("localhost", "Diagnosis")
         self.assertEqual(1, len(results))
         self.assertEqual("notes", results[0].observation_blob)
-        self.assertEqual([mock.call(query.sql_observation_fact("Diagnosis"))], self.mock_execute.call_args_list)
+        self.assertEqual(
+            [mock.call(query.sql_observation_fact("Diagnosis"))], self.mock_execute.call_args_list
+        )

     def test_list_patient(self):
         self.mock_cursor.__iter__.return_value = [
@@ -103,6 +105,14 @@ async def test_loader(self, mock_extract):
             set(os.listdir(tmpdir.name)),
         )

-        self.assertEqual(i2b2_mock_data.condition(), common.read_json(os.path.join(tmpdir.name, "Condition.ndjson")))
-        self.assertEqual(i2b2_mock_data.encounter(), common.read_json(os.path.join(tmpdir.name, "Encounter.ndjson")))
-        self.assertEqual(i2b2_mock_data.patient(), common.read_json(os.path.join(tmpdir.name, "Patient.ndjson")))
+        self.assertEqual(
+            i2b2_mock_data.condition(),
+            common.read_json(os.path.join(tmpdir.name, "Condition.ndjson")),
+        )
+        self.assertEqual(
+            i2b2_mock_data.encounter(),
+            common.read_json(os.path.join(tmpdir.name, "Encounter.ndjson")),
+        )
+        self.assertEqual(
+            i2b2_mock_data.patient(), common.read_json(os.path.join(tmpdir.name, "Patient.ndjson"))
+        )
diff --git
a/tests/loaders/i2b2/test_i2b2_oracle_query.py b/tests/loaders/i2b2/test_i2b2_oracle_query.py index 40f2dc8d..e0df9d4f 100644 --- a/tests/loaders/i2b2/test_i2b2_oracle_query.py +++ b/tests/loaders/i2b2/test_i2b2_oracle_query.py @@ -13,7 +13,9 @@ def pretty(text): print(text) -def count_by_date(column: str, column_alias: str, count="*", count_alias="cnt", frmt="YYYY-MM-DD") -> str: +def count_by_date( + column: str, column_alias: str, count="*", count_alias="cnt", frmt="YYYY-MM-DD" +) -> str: sql_count = f"count({count}) as {count_alias}" sql_group = query.format_date(column, column_alias, frmt) return sql_count + "," + sql_group diff --git a/tests/loaders/i2b2/test_i2b2_transform.py b/tests/loaders/i2b2/test_i2b2_transform.py index a45fcfc7..87eac137 100644 --- a/tests/loaders/i2b2/test_i2b2_transform.py +++ b/tests/loaders/i2b2/test_i2b2_transform.py @@ -24,7 +24,12 @@ def test_to_fhir_patient(self): @ddt.data( ("Black or African American", "race", "urn:oid:2.16.840.1.113883.6.238", "2054-5"), ("Hispanic or Latino", "ethnicity", "urn:oid:2.16.840.1.113883.6.238", "2135-2"), - ("Declined to Answer", "ethnicity", "http://terminology.hl7.org/CodeSystem/v3-NullFlavor", "ASKU"), + ( + "Declined to Answer", + "ethnicity", + "http://terminology.hl7.org/CodeSystem/v3-NullFlavor", + "ASKU", + ), ) @ddt.unpack def test_patient_race_vs_ethnicity(self, race_cd, url, system, code): @@ -66,7 +71,9 @@ def test_to_fhir_condition(self): self.assertEqual("Patient/12345", condition["subject"]["reference"]) self.assertEqual("Encounter/67890", condition["encounter"]["reference"]) self.assertEqual("U07.1", condition["code"]["coding"][0]["code"]) - self.assertEqual("http://hl7.org/fhir/sid/icd-10-cm", condition["code"]["coding"][0]["system"]) + self.assertEqual( + "http://hl7.org/fhir/sid/icd-10-cm", condition["code"]["coding"][0]["system"] + ) self.assertEqual("COVID-19", condition["code"]["coding"][0]["display"]) def test_to_fhir_documentreference(self): @@ -101,7 +108,9 @@ def test_to_fhir_observation_lab_case_does_not_matter(self): self.assertEqual("10828004", lab_fhir["valueCodeableConcept"]["coding"][0]["code"]) self.assertEqual("POSitiVE", lab_fhir["valueCodeableConcept"]["coding"][0]["display"]) - self.assertEqual("http://snomed.info/sct", lab_fhir["valueCodeableConcept"]["coding"][0]["system"]) + self.assertEqual( + "http://snomed.info/sct", lab_fhir["valueCodeableConcept"]["coding"][0]["system"] + ) def test_to_fhir_observation_lab_unknown_tval(self): dim = i2b2_mock_data.observation_dim() @@ -111,7 +120,8 @@ def test_to_fhir_observation_lab_unknown_tval(self): self.assertEqual("Nope", lab_fhir["valueCodeableConcept"]["coding"][0]["code"]) self.assertEqual("Nope", lab_fhir["valueCodeableConcept"]["coding"][0]["display"]) self.assertEqual( - "http://cumulus.smarthealthit.org/i2b2", lab_fhir["valueCodeableConcept"]["coding"][0]["system"] + "http://cumulus.smarthealthit.org/i2b2", + lab_fhir["valueCodeableConcept"]["coding"][0]["system"], ) def test_to_fhir_medicationrequest(self): @@ -169,9 +179,15 @@ def test_to_fhir_observation_vitals(self): ] } ], - "code": {"coding": [{"code": "VITAL:1234", "system": "http://cumulus.smarthealthit.org/i2b2"}]}, + "code": { + "coding": [ + {"code": "VITAL:1234", "system": "http://cumulus.smarthealthit.org/i2b2"} + ] + }, "valueCodeableConcept": { - "coding": [{"code": "Left Leg", "system": "http://cumulus.smarthealthit.org/i2b2"}] + "coding": [ + {"code": "Left Leg", "system": "http://cumulus.smarthealthit.org/i2b2"} + ] }, "effectiveDateTime": "2020-10-30", "status": 
"unknown", diff --git a/tests/loaders/ndjson/test_bulk_export.py b/tests/loaders/ndjson/test_bulk_export.py index 5be99d49..50f91fc6 100644 --- a/tests/loaders/ndjson/test_bulk_export.py +++ b/tests/loaders/ndjson/test_bulk_export.py @@ -146,7 +146,9 @@ async def test_happy_path(self): await self.export() self.assertEqual("MyGroup", self.exporter.group_name) - self.assertEqual("2015-02-07T13:28:17.239000+02:00", self.exporter.export_datetime.isoformat()) + self.assertEqual( + "2015-02-07T13:28:17.239000+02:00", self.exporter.export_datetime.isoformat() + ) # Ensure we can read back our own log and parse the above values too parser = BulkExportLogParser(store.Root(self.tmpdir)) @@ -154,12 +156,17 @@ async def test_happy_path(self): self.assertEqual("2015-02-07T13:28:17.239000+02:00", parser.export_datetime.isoformat()) self.assertEqual( - {"resourceType": "Condition", "id": "1"}, common.read_json(f"{self.tmpdir}/Condition.000.ndjson") + {"resourceType": "Condition", "id": "1"}, + common.read_json(f"{self.tmpdir}/Condition.000.ndjson"), ) self.assertEqual( - {"resourceType": "Condition", "id": "2"}, common.read_json(f"{self.tmpdir}/Condition.001.ndjson") + {"resourceType": "Condition", "id": "2"}, + common.read_json(f"{self.tmpdir}/Condition.001.ndjson"), + ) + self.assertEqual( + {"resourceType": "Patient", "id": "P"}, + common.read_json(f"{self.tmpdir}/Patient.000.ndjson"), ) - self.assertEqual({"resourceType": "Patient", "id": "P"}, common.read_json(f"{self.tmpdir}/Patient.000.ndjson")) self.assert_log_equals( ( @@ -190,20 +197,44 @@ async def test_happy_path(self): ), ( "download_request", - {"fileUrl": "https://example.com/con1", "itemType": "output", "resourceType": "Condition"}, + { + "fileUrl": "https://example.com/con1", + "itemType": "output", + "resourceType": "Condition", + }, + ), + ( + "download_complete", + {"fileSize": 40, "fileUrl": "https://example.com/con1", "resourceCount": 1}, ), - ("download_complete", {"fileSize": 40, "fileUrl": "https://example.com/con1", "resourceCount": 1}), ( "download_request", - {"fileUrl": "https://example.com/con2", "itemType": "output", "resourceType": "Condition"}, + { + "fileUrl": "https://example.com/con2", + "itemType": "output", + "resourceType": "Condition", + }, + ), + ( + "download_complete", + {"fileSize": 40, "fileUrl": "https://example.com/con2", "resourceCount": 1}, ), - ("download_complete", {"fileSize": 40, "fileUrl": "https://example.com/con2", "resourceCount": 1}), ( "download_request", - {"fileUrl": "https://example.com/pat1", "itemType": "output", "resourceType": "Patient"}, + { + "fileUrl": "https://example.com/pat1", + "itemType": "output", + "resourceType": "Patient", + }, + ), + ( + "download_complete", + {"fileSize": 38, "fileUrl": "https://example.com/pat1", "resourceCount": 1}, + ), + ( + "export_complete", + {"attachments": None, "bytes": 118, "duration": 0, "files": 3, "resources": 3}, ), - ("download_complete", {"fileSize": 38, "fileUrl": "https://example.com/pat1", "resourceCount": 1}), - ("export_complete", {"attachments": None, "bytes": 118, "duration": 0, "files": 3, "resources": 3}), ) async def test_since_until(self): @@ -285,20 +316,44 @@ async def test_export_error(self): ), ( "download_request", - {"fileUrl": "https://example.com/con1", "itemType": "output", "resourceType": "Condition"}, + { + "fileUrl": "https://example.com/con1", + "itemType": "output", + "resourceType": "Condition", + }, + ), + ( + "download_complete", + {"fileSize": 29, "fileUrl": "https://example.com/con1", "resourceCount": 1}, ), - 
("download_complete", {"fileSize": 29, "fileUrl": "https://example.com/con1", "resourceCount": 1}), ( "download_request", - {"fileUrl": "https://example.com/err1", "itemType": "error", "resourceType": "OperationOutcome"}, + { + "fileUrl": "https://example.com/err1", + "itemType": "error", + "resourceType": "OperationOutcome", + }, + ), + ( + "download_complete", + {"fileSize": 93, "fileUrl": "https://example.com/err1", "resourceCount": 1}, ), - ("download_complete", {"fileSize": 93, "fileUrl": "https://example.com/err1", "resourceCount": 1}), ( "download_request", - {"fileUrl": "https://example.com/err2", "itemType": "error", "resourceType": "OperationOutcome"}, + { + "fileUrl": "https://example.com/err2", + "itemType": "error", + "resourceType": "OperationOutcome", + }, + ), + ( + "download_complete", + {"fileSize": 322, "fileUrl": "https://example.com/err2", "resourceCount": 3}, + ), + ( + "export_complete", + {"attachments": None, "bytes": 444, "duration": 0, "files": 3, "resources": 5}, ), - ("download_complete", {"fileSize": 322, "fileUrl": "https://example.com/err2", "resourceCount": 3}), - ("export_complete", {"attachments": None, "bytes": 444, "duration": 0, "files": 3, "resources": 5}), ) async def test_export_warning(self): @@ -337,7 +392,9 @@ async def test_file_download_error(self): ], }, ) - self.respx_mock.get("https://example.com/con1").respond(status_code=501, content=b'["error"]') + self.respx_mock.get("https://example.com/con1").respond( + status_code=501, content=b'["error"]' + ) with self.assertRaisesRegex( errors.FatalError, @@ -395,7 +452,9 @@ async def test_delay(self, mock_sleep): side_effect=[ # Before returning a successful kickoff, pause for an hour respx.MockResponse(status_code=429, headers={"Retry-After": "3600"}), - respx.MockResponse(status_code=202, headers={"Content-Location": "https://example.com/poll"}), + respx.MockResponse( + status_code=202, headers={"Content-Location": "https://example.com/poll"} + ), ] ) self.respx_mock.get("https://example.com/poll").side_effect = [ @@ -404,7 +463,9 @@ async def test_delay(self, mock_sleep): # five hours (though 202 responses will get limited to five min) respx.MockResponse(status_code=202, headers={"Retry-After": "18000"}, content=b"..."), # 23 hours (putting us over a day) - respx.MockResponse(status_code=429, headers={"Retry-After": "82800", "X-Progress": "plz wait"}), + respx.MockResponse( + status_code=429, headers={"Retry-After": "82800", "X-Progress": "plz wait"} + ), ] with self.assertRaisesRegex(errors.FatalError, "Timed out waiting"): @@ -458,7 +519,8 @@ async def test_no_delete_if_interrupted(self): "body": "Test Status Call Failed", "code": 500, "message": ( - 'An error occurred when connecting to "https://example.com/poll": ' "Test Status Call Failed" + 'An error occurred when connecting to "https://example.com/poll": ' + "Test Status Call Failed" ), "responseHeaders": {"content-length": "23"}, }, @@ -577,7 +639,10 @@ async def test_successful_bulk_export(self): ) self.assertEqual( - {"id": "4342abf315cf6f243e11f4d460303e36c6c3663a25c91cc6b1a8002476c850dd", "resourceType": "Patient"}, + { + "id": "4342abf315cf6f243e11f4d460303e36c6c3663a25c91cc6b1a8002476c850dd", + "resourceType": "Patient", + }, common.read_json(f"{tmpdir}/output/patient/patient.000.ndjson"), ) diff --git a/tests/loaders/ndjson/test_ndjson_loader.py b/tests/loaders/ndjson/test_ndjson_loader.py index df3ae6a3..4c2c7186 100644 --- a/tests/loaders/ndjson/test_ndjson_loader.py +++ b/tests/loaders/ndjson/test_ndjson_loader.py @@ -28,7 
+28,9 @@ def setUp(self): # Mock out the bulk export code by default. We don't care about actually doing any # bulk work in this test case, just confirming the flow. - exporter_patcher = mock.patch("cumulus_etl.loaders.fhir.ndjson_loader.BulkExporter", spec=BulkExporter) + exporter_patcher = mock.patch( + "cumulus_etl.loaders.fhir.ndjson_loader.BulkExporter", spec=BulkExporter + ) self.addCleanup(exporter_patcher.stop) self.mock_exporter_class = exporter_patcher.start() self.mock_exporter = mock.AsyncMock() @@ -65,7 +67,9 @@ async def test_local_happy_path(self): self.assertEqual(["Patient.ndjson"], os.listdir(loaded_dir.name)) self.assertEqual(patient, common.read_json(f"{loaded_dir.name}/Patient.ndjson")) self.assertEqual("G", loader.group_name) - self.assertEqual(datetime.datetime.fromisoformat("1999-03-14T14:12:10"), loader.export_datetime) + self.assertEqual( + datetime.datetime.fromisoformat("1999-03-14T14:12:10"), loader.export_datetime + ) # At some point, we do want to make this fatal. # But not while this feature is still optional. @@ -247,7 +251,8 @@ async def test_export_flow(self, mock_client): """ Verify that we make the right calls down as far as the bulk export helper classes, with the right resources. """ - self.mock_exporter.export.side_effect = ValueError # stop us when we get this far, but also confirm we call it + # stop us when we get to the exporting step, but also confirm we call it + self.mock_exporter.export.side_effect = ValueError with self.assertRaises(ValueError): await cli.main( @@ -271,7 +276,9 @@ async def test_fatal_errors_are_fatal(self): self.mock_exporter.export.side_effect = errors.FatalError with self.assertRaises(SystemExit) as cm: - await loaders.FhirNdjsonLoader(store.Root("http://localhost:9999"), mock.AsyncMock()).load_all(["Patient"]) + await loaders.FhirNdjsonLoader( + store.Root("http://localhost:9999"), mock.AsyncMock() + ).load_all(["Patient"]) self.assertEqual(1, self.mock_exporter.export.call_count) self.assertEqual(errors.BULK_EXPORT_FAILED, cm.exception.code) @@ -289,7 +296,9 @@ async def fake_export() -> None: self.mock_exporter.export.side_effect = fake_export target = f"{tmpdir}/target" - loader = loaders.FhirNdjsonLoader(store.Root("http://localhost:9999"), mock.AsyncMock(), export_to=target) + loader = loaders.FhirNdjsonLoader( + store.Root("http://localhost:9999"), mock.AsyncMock(), export_to=target + ) folder = await loader.load_all(["Patient"]) # Confirm export folder still has the data (and log) we created above in the mock @@ -326,14 +335,18 @@ async def test_export_to_folder_has_contents(self): """Verify we fail if an export folder already has contents""" with tempfile.TemporaryDirectory() as tmpdir: os.mkdir(f"{tmpdir}/stuff") - loader = loaders.FhirNdjsonLoader(store.Root("http://localhost:9999"), mock.AsyncMock(), export_to=tmpdir) + loader = loaders.FhirNdjsonLoader( + store.Root("http://localhost:9999"), mock.AsyncMock(), export_to=tmpdir + ) with self.assertRaises(SystemExit) as cm: await loader.load_all([]) self.assertEqual(cm.exception.code, errors.FOLDER_NOT_EMPTY) async def test_export_to_folder_not_local(self): """Verify we fail if an export folder is not local""" - loader = loaders.FhirNdjsonLoader(store.Root("http://localhost:9999"), mock.AsyncMock(), export_to="http://foo") + loader = loaders.FhirNdjsonLoader( + store.Root("http://localhost:9999"), mock.AsyncMock(), export_to="http://foo" + ) with self.assertRaises(SystemExit) as cm: await loader.load_all([]) self.assertEqual(cm.exception.code, 
errors.BULK_EXPORT_FOLDER_NOT_LOCAL) diff --git a/tests/test_common.py b/tests/test_common.py index 75221828..07652c3f 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -74,9 +74,13 @@ def test_writes_are_atomic(self): # By default, fsspec writes are not atomic - just sanity check that text _can_ get through exploding_text with self.exploding_text() as text: - with fsspec.open(f"{self.bucket_url}/partial.txt", "w", endpoint_url=s3mock.S3Mixin.ENDPOINT_URL) as f: + with fsspec.open( + f"{self.bucket_url}/partial.txt", "w", endpoint_url=s3mock.S3Mixin.ENDPOINT_URL + ) as f: f.write(text) - self.assertEqual([f"{self.bucket}/partial.txt"], self.s3fs.ls(self.bucket_url, detail=False)) + self.assertEqual( + [f"{self.bucket}/partial.txt"], self.s3fs.ls(self.bucket_url, detail=False) + ) @ddt.idata( # Every combination of these sizes, backends, and data formats: diff --git a/tests/upload_notes/test_upload_cli.py b/tests/upload_notes/test_upload_cli.py index 5778bcfc..628e77da 100644 --- a/tests/upload_notes/test_upload_cli.py +++ b/tests/upload_notes/test_upload_cli.py @@ -158,14 +158,22 @@ def mock_search_url(respx_mock: respx.MockRouter, patient: str, doc_ids: Iterabl ], } - respx_mock.get(f"https://localhost/DocumentReference?patient={patient}&_elements=content").respond(json=bundle) + respx_mock.get( + f"https://localhost/DocumentReference?patient={patient}&_elements=content" + ).respond(json=bundle) @staticmethod def mock_read_url( - respx_mock: respx.MockRouter, doc_id: str, code: int = 200, docref: dict | None = None, **kwargs + respx_mock: respx.MockRouter, + doc_id: str, + code: int = 200, + docref: dict | None = None, + **kwargs, ) -> None: docref = docref or TestUploadNotes.make_docref(doc_id, **kwargs) - respx_mock.get(f"https://localhost/DocumentReference/{doc_id}").respond(status_code=code, json=docref) + respx_mock.get(f"https://localhost/DocumentReference/{doc_id}").respond( + status_code=code, json=docref + ) @staticmethod def write_anon_docrefs(path: str, ids: list[tuple[str, str]]) -> None: @@ -182,7 +190,9 @@ def write_real_docrefs(path: str, ids: list[str]) -> None: f.write("\n".join(lines)) def get_exported_ids(self) -> set[str]: - rows = cumulus_fhir_support.read_multiline_json(f"{self.export_path}/DocumentReference.ndjson") + rows = cumulus_fhir_support.read_multiline_json( + f"{self.export_path}/DocumentReference.ndjson" + ) return {row["id"] for row in rows} def get_pushed_ids(self) -> set[str]: @@ -195,10 +205,12 @@ def wrap_note(title: str, text: str, first: bool = True, date: str | None = None finalized = "" if not first: finalized += "\n\n\n" - finalized += "########################################\n########################################\n" + finalized += "########################################\n" + finalized += "########################################\n" finalized += f"{title}\n" finalized += f"{date or 'Unknown time'}\n" - finalized += "########################################\n########################################\n\n\n" + finalized += "########################################\n" + finalized += "########################################\n\n\n" finalized += text.strip() return finalized @@ -333,8 +345,18 @@ async def test_successful_push_to_label_studio(self): "text": "for", "polarity": 0, "conceptAttributes": [ - {"code": "386661006", "cui": "C0015967", "codingScheme": "SNOMEDCT_US", "tui": "T184"}, - {"code": "50177009", "cui": "C0015967", "codingScheme": "SNOMEDCT_US", "tui": "T184"}, + { + "code": "386661006", + "cui": "C0015967", + 
"codingScheme": "SNOMEDCT_US", + "tui": "T184", + }, + { + "code": "50177009", + "cui": "C0015967", + "codingScheme": "SNOMEDCT_US", + "tui": "T184", + }, ], "type": "SignSymptomMention", }, @@ -363,7 +385,11 @@ async def test_disabled_nlp(self): ) @ddt.unpack async def test_philter_redact(self, upload_args, expect_redacted): - notes = [LabelStudioNote("EncID", "EncAnon", title="My Title", text="John Smith called on 10/13/2010")] + notes = [ + LabelStudioNote( + "EncID", "EncAnon", title="My Title", text="John Smith called on 10/13/2010" + ) + ] with mock.patch("cumulus_etl.upload_notes.cli.read_notes_from_ndjson", return_value=notes): await self.run_upload_notes(**upload_args) @@ -381,7 +407,11 @@ async def test_philter_redact(self, upload_args, expect_redacted): self.assertEqual(self.wrap_note("My Title", expected_text), task.text) async def test_philter_label(self): - notes = [LabelStudioNote("EncID", "EncAnon", title="My Title", text="John Smith called on 10/13/2010")] + notes = [ + LabelStudioNote( + "EncID", "EncAnon", title="My Title", text="John Smith called on 10/13/2010" + ) + ] with mock.patch("cumulus_etl.upload_notes.cli.read_notes_from_ndjson", return_value=notes): await self.run_upload_notes(philter="label") @@ -397,11 +427,17 @@ async def test_grouped_datetime(self): with common.NdjsonWriter(f"{tmpdir}/DocumentReference.ndjson") as writer: writer.write(TestUploadNotes.make_docref("D1", enc_id="E1", text="DocRef 1")) writer.write( - TestUploadNotes.make_docref("D2", enc_id="E1", text="DocRef 2", date="2018-01-03T13:10:10+01:00") + TestUploadNotes.make_docref( + "D2", enc_id="E1", text="DocRef 2", date="2018-01-03T13:10:10+01:00" + ) ) writer.write( TestUploadNotes.make_docref( - "D3", enc_id="E1", text="DocRef 3", date="2018-01-03T13:10:20Z", period_start="2018" + "D3", + enc_id="E1", + text="DocRef 3", + date="2018-01-03T13:10:20Z", + period_start="2018", ) ) await self.run_upload_notes(input_path=tmpdir, philter="disable") diff --git a/tests/upload_notes/test_upload_labelstudio.py b/tests/upload_notes/test_upload_labelstudio.py index d72aa71b..24601959 100644 --- a/tests/upload_notes/test_upload_labelstudio.py +++ b/tests/upload_notes/test_upload_labelstudio.py @@ -24,7 +24,9 @@ def setUp(self): self.ls_project.parsed_label_config = {"mylabel": {"type": "Labels", "to_name": ["mytext"]}} @staticmethod - def make_note(*, enc_id: str = "enc", ctakes: bool = True, philter_label: bool = True) -> LabelStudioNote: + def make_note( + *, enc_id: str = "enc", ctakes: bool = True, philter_label: bool = True + ) -> LabelStudioNote: text = "Normal note text" note = LabelStudioNote( enc_id, @@ -34,7 +36,9 @@ def make_note(*, enc_id: str = "enc", ctakes: bool = True, philter_label: bool = text=text, ) if ctakes: - note.ctakes_matches = ctakesmock.fake_ctakes_extract(note.text).list_match(polarity=Polarity.pos) + note.ctakes_matches = ctakesmock.fake_ctakes_extract(note.text).list_match( + polarity=Polarity.pos + ) if philter_label: matches = ctakesmock.fake_ctakes_extract(note.text).list_match(polarity=Polarity.pos) note.philter_map = {m.begin: m.end for m in matches} @@ -50,7 +54,8 @@ def push_tasks(*notes, **kwargs) -> None: # These two CUIs are in our standard mock cTAKES response "C0033774": "Itch", "C0027497": "Nausea", - "C0028081": "Night Sweats", # to demonstrate that unmatched CUIs are not generally pushed + # The third is demonstrates that unmatched CUIs are not generally pushed + "C0028081": "Night Sweats", }, ) client.push_tasks(notes, **kwargs) @@ -76,7 +81,8 @@ def 
test_basic_push(self): { "model_version": "Cumulus cTAKES", "result": [ - # Note that fever does not show up, as it was not in our initial CUI mapping (in push_tasks) + # Note that fever does not show up, + # as it was not in our initial CUI mapping (in push_tasks) { "from_name": "mylabel", "id": "ctakes0", diff --git a/tests/utils.py b/tests/utils.py index cdb8e63b..74dc97bf 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -22,8 +22,11 @@ from cumulus_etl.formats.deltalake import DeltaLakeFormat # Pass a non-UTC time to time-machine to help notice any bad timezone handling. -# But only bother exposing the UTC version to other test code, since that's what will be most useful/common. -_FROZEN_TIME = datetime.datetime(2021, 9, 15, 1, 23, 45, tzinfo=datetime.timezone(datetime.timedelta(hours=4))) +# But only bother exposing the UTC version to other test code, +# since that's what will be most useful/common. +_FROZEN_TIME = datetime.datetime( + 2021, 9, 15, 1, 23, 45, tzinfo=datetime.timezone(datetime.timedelta(hours=4)) +) FROZEN_TIME_UTC = _FROZEN_TIME.astimezone(datetime.timezone.utc) @@ -36,12 +39,13 @@ class AsyncTestCase(unittest.IsolatedAsyncioTestCase): def setUp(self): super().setUp() - self.patch( - "cumulus_etl.deid.codebook.secrets.token_hex", new=lambda x: "1234" - ) # keep all codebook IDs consistent - # It's so common to want to see more than the tiny default fragment -- just enable this across the board. - self.maxDiff = None # pylint: disable=invalid-name + # keep all codebook IDs consistent + self.patch("cumulus_etl.deid.codebook.secrets.token_hex", new=lambda x: "1234") + + # It's so common to want to see more than the tiny default fragment. + # So we just enable this across the board. + self.maxDiff = None # Make it easy to grab test data, regardless of where the test is self.datadir = os.path.join(os.path.dirname(__file__), "data") @@ -64,13 +68,13 @@ def patch(self, *args, **kwargs) -> mock.Mock: return patcher.start() def patch_dict(self, *args, **kwargs) -> mock.Mock: - """Syntactic sugar to ease making a dictionary mock over a test's lifecycle, without decorators""" + """Syntactic sugar for making a dict mock over a test's lifecycle, without decorators""" patcher = mock.patch.dict(*args, **kwargs) self.addCleanup(patcher.stop) return patcher.start() def patch_object(self, *args, **kwargs) -> mock.Mock: - """Syntactic sugar to ease making an object mock over a test's lifecycle, without decorators""" + """Syntactic sugar for making an object mock over a test's lifecycle, without decorators""" patcher = mock.patch.object(*args, **kwargs) self.addCleanup(patcher.stop) return patcher.start() @@ -88,17 +92,20 @@ def _callTestMethod(self, method): """ Works around an async test case bug in python 3.10 and below. - This seems to be some version of https://github.com/python/cpython/issues/83282 but fixed & never backported. + This seems to be some version of https://github.com/python/cpython/issues/83282 + but fixed & never backported. I was not able to find a separate bug report for this specific issue. - Given the following two test methods (for Pythons before 3.11), only the second one will hang: + Given the following two test methods (for Pythons before 3.11), + only the second one will hang: async def test_fails_correctly(self): raise BaseException("OK") async def test_hangs_forever(self): raise SystemExit("Nope") - This class works around that by wrapping all test methods and translating uncaught SystemExits into failures. 
+ This class works around that by wrapping all test methods and translating uncaught + SystemExits into failures. _callTestMethod() can be deleted once we no longer use python 3.10 in our testing suite. """ return super()._callTestMethod(functools.partial(self._catch_system_exit, method)) @@ -176,9 +183,9 @@ def setUp(self): self.fhir_client_id = "test-client-id" self.fhir_bearer = "1234567890" # the provided oauth bearer token - jwk_token = jwk.JWK.generate(kty="EC", alg="ES384", curve="P-384", kid="a", key_ops=["sign", "verify"]).export( - as_dict=True - ) + jwk_token = jwk.JWK.generate( + kty="EC", alg="ES384", curve="P-384", kid="a", key_ops=["sign", "verify"] + ).export(as_dict=True) self.fhir_jwks = {"keys": [jwk_token]} self._fhir_jwks_file = tempfile.NamedTemporaryFile() # pylint: disable=consider-using-with @@ -239,19 +246,23 @@ def fhir_client(self, resources: list[str]) -> fhir.FhirClient: ) -def make_response(status_code=200, json_payload=None, text=None, reason=None, headers=None, stream=False): +def make_response( + status_code=200, json_payload=None, text=None, reason=None, headers=None, stream=False +): """ Makes a fake respx response for ease of testing. Usually you'll want to use respx.get(...) etc directly. - But if you want to mock out the client <-> server interaction entirely, you can use this method to fake a - Response object from a method that returns one. + But if you want to mock out the client <-> server interaction entirely, + you can use this method to fake a Response object from a method that returns one. Example: server.request.return_value = make_response() """ headers = dict(headers or {}) - headers.setdefault("Content-Type", "application/json" if json_payload else "text/plain; charset=utf-8") + headers.setdefault( + "Content-Type", "application/json" if json_payload else "text/plain; charset=utf-8" + ) json_payload = json.dumps(json_payload) if json_payload else None body = (json_payload or text or "").encode("utf8") stream_contents = None @@ -284,8 +295,8 @@ def read_delta_lake(lake_path: str, *, version: int | None = None) -> list[dict] table_spark = reader.format("delta").load(lake_path) # Convert the spark table to Python primitives. - # Going to rdd or pandas and then to Python keeps inserting spark-specific constructs like Row(). - # So instead, convert to a JSON string and then back to Python. + # Going to rdd or pandas and then to Python keeps inserting spark-specific constructs like + # Row(). So instead, convert to a JSON string and then back to Python. 
rows = [json.loads(row) for row in table_spark.toJSON().collect()] # Try to sort by id, but if that doesn't exist (which happens for some completion tables), From 5be339e31db2db0f3e889fad220a82c1e11d6cbb Mon Sep 17 00:00:00 2001 From: Michael Terry Date: Mon, 29 Jul 2024 13:49:08 -0400 Subject: [PATCH 3/3] style: drop all pylint disables --- cumulus_etl/cli_utils.py | 2 +- cumulus_etl/deid/scrubber.py | 2 +- cumulus_etl/etl/studies/covid_symptom/covid_ctakes.py | 4 ++-- cumulus_etl/etl/tasks/basic_tasks.py | 2 +- cumulus_etl/etl/tasks/nlp_task.py | 2 +- cumulus_etl/formats/base.py | 2 +- cumulus_etl/formats/deltalake.py | 4 ++-- cumulus_etl/loaders/fhir/ndjson_loader.py | 2 +- cumulus_etl/loaders/i2b2/oracle/query.py | 4 ++-- cumulus_etl/nlp/extract.py | 4 ++-- cumulus_etl/upload_notes/selector.py | 2 +- tests/covid_symptom/test_covid_results.py | 4 ++-- tests/ctakesmock.py | 4 ++-- tests/formats/test_deltalake.py | 2 +- tests/loaders/i2b2/test_i2b2_oracle_extract.py | 2 +- tests/loaders/i2b2/test_i2b2_oracle_query.py | 2 +- tests/loaders/i2b2/test_i2b2_transform.py | 1 - tests/loaders/ndjson/test_bulk_export.py | 2 +- tests/loaders/ndjson/test_ndjson_loader.py | 2 +- tests/s3mock.py | 2 +- tests/utils.py | 6 +++--- 21 files changed, 28 insertions(+), 29 deletions(-) diff --git a/cumulus_etl/cli_utils.py b/cumulus_etl/cli_utils.py index 1ba5bea2..41513618 100644 --- a/cumulus_etl/cli_utils.py +++ b/cumulus_etl/cli_utils.py @@ -62,7 +62,7 @@ def make_export_dir(export_to: str | None = None) -> common.Directory: """Makes a temporary directory to drop exported ndjson files into""" # Handle the easy case -- just a random temp dir if not export_to: - return tempfile.TemporaryDirectory() # pylint: disable=consider-using-with + return tempfile.TemporaryDirectory() # OK the user has a specific spot in mind. Let's do some quality checks. It must be local and empty. 
diff --git a/cumulus_etl/deid/scrubber.py b/cumulus_etl/deid/scrubber.py index 6a75b795..c13037f8 100644 --- a/cumulus_etl/deid/scrubber.py +++ b/cumulus_etl/deid/scrubber.py @@ -44,7 +44,7 @@ async def scrub_bulk_data(input_dir: str) -> tempfile.TemporaryDirectory: :returns: a temporary directory holding the de-identified results, in FHIR ndjson format """ - tmpdir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with + tmpdir = tempfile.TemporaryDirectory() await mstool.run_mstool(input_dir, tmpdir.name) return tmpdir diff --git a/cumulus_etl/etl/studies/covid_symptom/covid_ctakes.py b/cumulus_etl/etl/studies/covid_symptom/covid_ctakes.py index 7fe4dd18..da5e110e 100644 --- a/cumulus_etl/etl/studies/covid_symptom/covid_ctakes.py +++ b/cumulus_etl/etl/studies/covid_symptom/covid_ctakes.py @@ -63,7 +63,7 @@ async def covid_symptoms_extract( ctakes_json = await nlp.ctakes_extract( cache, ctakes_namespace, clinical_note, client=ctakes_http_client ) - except Exception as exc: # pylint: disable=broad-except + except Exception as exc: logging.warning( "Could not extract symptoms for docref %s (%s): %s", docref_id, type(exc).__name__, exc ) @@ -96,7 +96,7 @@ def is_covid_match(m: ctakesclient.typesystem.MatchText): model=polarity_model, client=cnlp_http_client, ) - except Exception as exc: # pylint: disable=broad-except + except Exception as exc: logging.warning( "Could not check polarity for docref %s (%s): %s", docref_id, type(exc).__name__, exc ) diff --git a/cumulus_etl/etl/tasks/basic_tasks.py b/cumulus_etl/etl/tasks/basic_tasks.py index 009b5029..2dacc971 100644 --- a/cumulus_etl/etl/tasks/basic_tasks.py +++ b/cumulus_etl/etl/tasks/basic_tasks.py @@ -149,7 +149,7 @@ async def fetch_medication(self, resource: dict) -> dict | None: try: medication = await fhir.download_reference(self.task_config.client, reference) - except Exception as exc: # pylint: disable=broad-except + except Exception as exc: logging.warning("Could not download Medication reference: %s", exc) self.summaries[1].had_errors = True diff --git a/cumulus_etl/etl/tasks/nlp_task.py b/cumulus_etl/etl/tasks/nlp_task.py index ae0fdb05..0150e33a 100644 --- a/cumulus_etl/etl/tasks/nlp_task.py +++ b/cumulus_etl/etl/tasks/nlp_task.py @@ -94,7 +94,7 @@ async def read_notes( warned_connection_error = True self.add_error(orig_docref) continue - except Exception as exc: # pylint: disable=broad-except + except Exception as exc: logging.warning("Error getting text for docref %s: %s", docref["id"], exc) self.add_error(orig_docref) continue diff --git a/cumulus_etl/formats/base.py b/cumulus_etl/formats/base.py index 20ed5115..8f56dd46 100644 --- a/cumulus_etl/formats/base.py +++ b/cumulus_etl/formats/base.py @@ -62,7 +62,7 @@ def write_records(self, batch: Batch) -> bool: try: self._write_one_batch(batch) return True - except Exception: # pylint: disable=broad-except + except Exception: logging.exception("Could not process data records") return False diff --git a/cumulus_etl/formats/deltalake.py b/cumulus_etl/formats/deltalake.py index 459dfbff..1e8fc60d 100644 --- a/cumulus_etl/formats/deltalake.py +++ b/cumulus_etl/formats/deltalake.py @@ -139,7 +139,7 @@ def finalize(self) -> None: table = delta.DeltaTable.forPath(self.spark, full_path) except AnalysisException: return # if the table doesn't exist because we didn't write anything, that's fine - just bail - except Exception: # pylint: disable=broad-except + except Exception: logging.exception("Could not finalize Delta Lake table %s", self.dbname) return @@ -147,7 +147,7 @@ 
def finalize(self) -> None: table.optimize().executeCompaction() # pool small files for better query performance table.generate("symlink_format_manifest") table.vacuum() # Clean up unused data files older than retention policy (default 7 days) - except Exception: # pylint: disable=broad-except + except Exception: logging.exception("Could not finalize Delta Lake table %s", self.dbname) def _table_path(self, dbname: str) -> str: diff --git a/cumulus_etl/loaders/fhir/ndjson_loader.py b/cumulus_etl/loaders/fhir/ndjson_loader.py index f8767eae..d3be266e 100644 --- a/cumulus_etl/loaders/fhir/ndjson_loader.py +++ b/cumulus_etl/loaders/fhir/ndjson_loader.py @@ -67,7 +67,7 @@ async def load_all(self, resources: list[str]) -> common.Directory: # This uses more disk space temporarily (copied files will get deleted once the MS tool is done and this # TemporaryDirectory gets discarded), but that seems reasonable. print("Copying ndjson input files…") - tmpdir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with + tmpdir = tempfile.TemporaryDirectory() filenames = common.ls_resources(input_root, set(resources), warn_if_empty=True) for filename in filenames: input_root.get(filename, f"{tmpdir.name}/") diff --git a/cumulus_etl/loaders/i2b2/oracle/query.py b/cumulus_etl/loaders/i2b2/oracle/query.py index 952ed46d..0394b064 100644 --- a/cumulus_etl/loaders/i2b2/oracle/query.py +++ b/cumulus_etl/loaders/i2b2/oracle/query.py @@ -120,11 +120,11 @@ def where(expression=None) -> str: return "\n WHERE " + expression if expression else "" -def AND(expression: str) -> str: # pylint: disable=invalid-name +def AND(expression: str) -> str: return f"\n AND ({expression})" -def OR(expression: str) -> str: # pylint: disable=invalid-name +def OR(expression: str) -> str: return f"\n OR ({expression})" diff --git a/cumulus_etl/nlp/extract.py b/cumulus_etl/nlp/extract.py index e23b9fbd..fdc0f616 100644 --- a/cumulus_etl/nlp/extract.py +++ b/cumulus_etl/nlp/extract.py @@ -29,7 +29,7 @@ async def ctakes_extract( try: cached_response = common.read_json(full_path) result = ctakesclient.typesystem.CtakesJSON(source=cached_response) - except Exception: # pylint: disable=broad-except + except Exception: result = await ctakesclient.client.extract(sentence, client=client) cache.makedirs(os.path.dirname(full_path)) common.write_json(full_path, result.as_json()) @@ -58,7 +58,7 @@ async def list_polarity( try: result = [ctakesclient.typesystem.Polarity(x) for x in common.read_json(full_path)] - except Exception: # pylint: disable=broad-except + except Exception: result = await ctakesclient.transformer.list_polarity( sentence, spans, client=client, model=model ) diff --git a/cumulus_etl/upload_notes/selector.py b/cumulus_etl/upload_notes/selector.py index 2d3c6ca5..8dee1966 100644 --- a/cumulus_etl/upload_notes/selector.py +++ b/cumulus_etl/upload_notes/selector.py @@ -43,7 +43,7 @@ def _create_docref_filter( else: # Just accept everything (we still want to read them though, to copy them to a possible export folder). # So this lambda just returns an iterator over its input. 
- return lambda x: iter(x) # pylint: disable=unnecessary-lambda + return lambda x: iter(x) def _filter_real_docrefs(docrefs_csv: str, docrefs: Iterable[dict]) -> Iterator[dict]: diff --git a/tests/covid_symptom/test_covid_results.py b/tests/covid_symptom/test_covid_results.py index 4cb488ff..c2bfe8cb 100644 --- a/tests/covid_symptom/test_covid_results.py +++ b/tests/covid_symptom/test_covid_results.py @@ -92,10 +92,10 @@ async def test_ed_note_filtering_for_nlp(self, codings, expected): async def test_non_ed_visit_is_skipped_for_covid_symptoms(self): """Verify we ignore non ED visits for the covid symptoms NLP""" docref0 = i2b2_mock_data.documentreference() - docref0["type"]["coding"][0]["code"] = "NOTE:nope" # pylint: disable=unsubscriptable-object + docref0["type"]["coding"][0]["code"] = "NOTE:nope" self.make_json("DocumentReference", "skipped", **docref0) docref1 = i2b2_mock_data.documentreference() - docref1["type"]["coding"][0]["code"] = "NOTE:149798455" # pylint: disable=unsubscriptable-object + docref1["type"]["coding"][0]["code"] = "NOTE:149798455" self.make_json("DocumentReference", "present", **docref1) await covid_symptom.CovidSymptomNlpResultsTask(self.job_config, self.scrubber).run() diff --git a/tests/ctakesmock.py b/tests/ctakesmock.py index 39bb661e..381dba52 100644 --- a/tests/ctakesmock.py +++ b/tests/ctakesmock.py @@ -34,7 +34,7 @@ def setUp(self): CtakesMixin.ctakes_port += 1 os.environ["URL_CTAKES_REST"] = f"http://localhost:{CtakesMixin.ctakes_port}/" - self.ctakes_overrides = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with + self.ctakes_overrides = tempfile.TemporaryDirectory() self._run_fake_ctakes_server(f"{self.ctakes_overrides.name}/symptoms.bsv") cnlp_patcher = mock.patch( @@ -131,7 +131,7 @@ class FakeCTakesHandler(http.server.BaseHTTPRequestHandler): # We don't want that behavior, because if our mock code is misbehaving and not detecting an overrides file change, # if we time out a request on our end, it will look like a successful file change detection and mask a testing bug. 
- def do_POST(self): # pylint: disable=invalid-name + def do_POST(self): """Serve a POST request.""" self.server.was_called.value = 1 # signal to test framework that we were actually called diff --git a/tests/formats/test_deltalake.py b/tests/formats/test_deltalake.py index aca0e8fb..e1cfb89b 100644 --- a/tests/formats/test_deltalake.py +++ b/tests/formats/test_deltalake.py @@ -27,7 +27,7 @@ class TestDeltaLake(utils.AsyncTestCase): @classmethod def setUpClass(cls): super().setUpClass() - output_tempdir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with + output_tempdir = tempfile.TemporaryDirectory() cls.output_tempdir = output_tempdir cls.output_dir = output_tempdir.name cls.root = store.Root(output_tempdir.name) diff --git a/tests/loaders/i2b2/test_i2b2_oracle_extract.py b/tests/loaders/i2b2/test_i2b2_oracle_extract.py index 0f3d3496..75f2d34a 100644 --- a/tests/loaders/i2b2/test_i2b2_oracle_extract.py +++ b/tests/loaders/i2b2/test_i2b2_oracle_extract.py @@ -15,7 +15,7 @@ class TestOracleExtraction(AsyncTestCase): def setUp(self) -> None: super().setUp() - self.maxDiff = None # pylint: disable=invalid-name + self.maxDiff = None # Mock all the sql connection/cursor/execution stuff connect_patcher = mock.patch("cumulus_etl.loaders.i2b2.oracle.extract.connect") diff --git a/tests/loaders/i2b2/test_i2b2_oracle_query.py b/tests/loaders/i2b2/test_i2b2_oracle_query.py index e0df9d4f..edb3bf7b 100644 --- a/tests/loaders/i2b2/test_i2b2_oracle_query.py +++ b/tests/loaders/i2b2/test_i2b2_oracle_query.py @@ -50,7 +50,7 @@ class TestOracleQueries(utils.AsyncTestCase): def setUp(self) -> None: super().setUp() - self.maxDiff = None # pylint: disable=invalid-name + self.maxDiff = None def test_list_patient(self): common.print_header("# patient") diff --git a/tests/loaders/i2b2/test_i2b2_transform.py b/tests/loaders/i2b2/test_i2b2_transform.py index 87eac137..d6065aee 100644 --- a/tests/loaders/i2b2/test_i2b2_transform.py +++ b/tests/loaders/i2b2/test_i2b2_transform.py @@ -18,7 +18,6 @@ def test_to_fhir_patient(self): self.assertEqual(str(12345), subject["id"]) self.assertEqual("2005-06-07", subject["birthDate"]) self.assertEqual("female", subject["gender"]) - # pylint: disable-next=unsubscriptable-object self.assertEqual("02115", subject["address"][0]["postalCode"]) @ddt.data( diff --git a/tests/loaders/ndjson/test_bulk_export.py b/tests/loaders/ndjson/test_bulk_export.py index 50f91fc6..6f13fadd 100644 --- a/tests/loaders/ndjson/test_bulk_export.py +++ b/tests/loaders/ndjson/test_bulk_export.py @@ -472,7 +472,7 @@ async def test_delay(self, mock_sleep): await self.export() # 86760 == 24 hours + six minutes - self.assertEqual(86760, self.exporter._total_wait_time) # pylint: disable=protected-access + self.assertEqual(86760, self.exporter._total_wait_time) self.assertListEqual( [ diff --git a/tests/loaders/ndjson/test_ndjson_loader.py b/tests/loaders/ndjson/test_ndjson_loader.py index 4c2c7186..a44b8770 100644 --- a/tests/loaders/ndjson/test_ndjson_loader.py +++ b/tests/loaders/ndjson/test_ndjson_loader.py @@ -21,7 +21,7 @@ class TestNdjsonLoader(AsyncTestCase): def setUp(self): super().setUp() - self.jwks_file = tempfile.NamedTemporaryFile() # pylint: disable=consider-using-with + self.jwks_file = tempfile.NamedTemporaryFile() self.jwks_path = self.jwks_file.name self.jwks_file.write(b'{"fake":"jwks"}') self.jwks_file.flush() diff --git a/tests/s3mock.py b/tests/s3mock.py index befdc934..2c33c750 100644 --- a/tests/s3mock.py +++ b/tests/s3mock.py @@ -54,7 +54,7 @@ def setUp(self): 
try: self.s3fs.mkdir(self.bucket) # create the bucket as a quickstart - except Exception: # pylint: disable=broad-except + except Exception: self._kill_moto_server() self.fail("Stale moto server") diff --git a/tests/utils.py b/tests/utils.py index 74dc97bf..f62372cc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -57,7 +57,7 @@ def setUp(self): def make_tempdir(self) -> str: """Creates a temporary dir that will be automatically cleaned up""" - tempdir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with + tempdir = tempfile.TemporaryDirectory() self.addCleanup(tempdir.cleanup) return tempdir.name @@ -120,7 +120,7 @@ def setUp(self): filecmp.clear_cache() # you'll always want this when debugging - self.maxDiff = None # pylint: disable=invalid-name + self.maxDiff = None def assert_etl_output_equal(self, left: str, right: str): """Compares the etl output with the expected json structure""" @@ -188,7 +188,7 @@ def setUp(self): ).export(as_dict=True) self.fhir_jwks = {"keys": [jwk_token]} - self._fhir_jwks_file = tempfile.NamedTemporaryFile() # pylint: disable=consider-using-with + self._fhir_jwks_file = tempfile.NamedTemporaryFile() self._fhir_jwks_file.write(json.dumps(self.fhir_jwks).encode("utf8")) self._fhir_jwks_file.flush() self.addCleanup(self._fhir_jwks_file.close)