Switch the ETL to ruff #331

Merged
merged 3 commits on Jul 29, 2024
23 changes: 2 additions & 21 deletions .github/workflows/ci.yaml
@@ -118,24 +118,5 @@ jobs:
python -m pip install --upgrade pip
pip install .[dev]

- name: Run pycodestyle
# E203: pycodestyle is a little too rigid about slices & whitespace
# See https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html#slices
# W503: a default ignore that we are restoring
run: |
pycodestyle --max-line-length=120 --ignore=E203,W503 .

- name: Run pylint
if: success() || failure() # still run pylint if above checks fail
run: |
pylint cumulus_etl tests

- name: Run bandit
if: success() || failure() # still run bandit if above checks fail
run: |
bandit -c pyproject.toml -r .

- name: Run black
if: success() || failure() # still run black if above checks fails
run: |
black --check --verbose --line-length 120 .
- name: Run ruff
run: ruff check --output-format=github .
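
The four separate lint steps (pycodestyle, pylint, bandit, black) collapse into a single `ruff check` invocation. The matching ruff settings are not part of this diff — the pre-commit comment below only notes that they live in pyproject.toml — so the following is a plausible sketch of what such a `[tool.ruff]` section could look like, not the project's actual configuration:

```toml
# Hypothetical pyproject.toml excerpt -- the project's real ruff settings
# are not shown in this pull request.
[tool.ruff]
line-length = 120  # matches the old pycodestyle --max-line-length

[tool.ruff.lint]
select = [
    "E4", "E7", "E9",  # pycodestyle errors (ruff's defaults)
    "F",               # pyflakes
    "S",               # flake8-bandit, covering the old bandit step
    "PL",              # pylint rules, covering the old pylint step
]
```

With settings along these lines, the single `ruff check --output-format=github .` run reports any violation as an inline annotation on the pull request.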
17 changes: 8 additions & 9 deletions .pre-commit-config.yaml
@@ -1,11 +1,10 @@
repos:
- repo: https://github.com/psf/black
rev: 24.4.2 # keep in rough sync with pyproject.toml
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.5 # keep in rough sync with pyproject.toml
hooks:
- id: black
entry: bash -c 'black "$@"; git add -u' --
# It is recommended to specify the latest version of Python
# supported by your project here, or alternatively use
# pre-commit's default_language_version, see
# https://pre-commit.com/#top_level-default_language_version
language_version: python3.12
- name: Ruff formatting
id: ruff-format
entry: bash -c 'ruff format --force-exclude "$@"; git add -u' --
- name: Ruff linting
id: ruff
stages: [pre-push]
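
Because the diff view interleaves the removed black hook with the new entries, the fragments above are easier to read assembled. A sketch of the resulting hook list, reconstructed from the lines shown rather than copied from the final file:

```yaml
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.5.5  # keep in rough sync with pyproject.toml
    hooks:
      - name: Ruff formatting
        id: ruff-format
        # Reformat the staged files, then re-stage whatever changed.
        entry: bash -c 'ruff format --force-exclude "$@"; git add -u' --
      - name: Ruff linting
        id: ruff
        stages: [pre-push]  # lint on push rather than on every commit
```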
5 changes: 3 additions & 2 deletions cumulus_etl/cli.py
@@ -43,7 +43,8 @@ def get_subcommand(argv: list[str]) -> str | None:
if arg in Command.values():
return argv.pop(i) # remove it to make later parsers' jobs easier
elif not arg.startswith("-"):
return None # first positional arg did not match a known command, assume default command
# first positional arg did not match a known command, assume default command
return None


async def main(argv: list[str]) -> None:
@@ -71,7 +72,7 @@ async def main(argv: list[str]) -> None:
if not subcommand:
# Add a note about other subcommands we offer, and tell argparse not to wrap our formatting
parser.formatter_class = argparse.RawDescriptionHelpFormatter
parser.description += "\n\n" "other commands available:\n" " convert\n" " upload-notes"
parser.description += "\n\nother commands available:\n convert\n upload-notes"
run_method = etl.run_etl

with tempfile.TemporaryDirectory() as tempdir:
31 changes: 23 additions & 8 deletions cumulus_etl/cli_utils.py
@@ -16,16 +16,28 @@ def add_auth(parser: argparse.ArgumentParser) -> None:
group.add_argument("--smart-client-id", metavar="ID", help="Client ID for SMART authentication")
group.add_argument("--smart-jwks", metavar="PATH", help="JWKS file for SMART authentication")
group.add_argument("--basic-user", metavar="USER", help="Username for Basic authentication")
group.add_argument("--basic-passwd", metavar="PATH", help="Password file for Basic authentication")
group.add_argument("--bearer-token", metavar="PATH", help="Token file for Bearer authentication")
group.add_argument("--fhir-url", metavar="URL", help="FHIR server base URL, only needed if you exported separately")
group.add_argument(
"--basic-passwd", metavar="PATH", help="Password file for Basic authentication"
)
group.add_argument(
"--bearer-token", metavar="PATH", help="Token file for Bearer authentication"
)
group.add_argument(
"--fhir-url",
metavar="URL",
help="FHIR server base URL, only needed if you exported separately",
)


def add_aws(parser: argparse.ArgumentParser) -> None:
group = parser.add_argument_group("AWS")
group.add_argument("--s3-region", metavar="REGION", help="If using S3 paths (s3://...), this is their region")
group.add_argument(
"--s3-kms-key", metavar="KEY", help="If using S3 paths (s3://...), this is the KMS key ID to use"
"--s3-region", metavar="REGION", help="If using S3 paths (s3://...), this is their region"
)
group.add_argument(
"--s3-kms-key",
metavar="KEY",
help="If using S3 paths (s3://...), this is the KMS key ID to use",
)


@@ -46,18 +58,21 @@ def add_debugging(parser: argparse.ArgumentParser):
return group


def make_export_dir(export_to: str = None) -> common.Directory:
def make_export_dir(export_to: str | None = None) -> common.Directory:
"""Makes a temporary directory to drop exported ndjson files into"""
# Handle the easy case -- just a random temp dir
if not export_to:
return tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
return tempfile.TemporaryDirectory()

# OK the user has a specific spot in mind. Let's do some quality checks. It must be local and empty.

if urllib.parse.urlparse(export_to).netloc:
# We require a local folder because that's all that the MS deid tool can operate on.
# If we were to relax this requirement, we'd want to copy the exported files over to a local dir.
errors.fatal(f"The target export folder '{export_to}' must be local. ", errors.BULK_EXPORT_FOLDER_NOT_LOCAL)
errors.fatal(
f"The target export folder '{export_to}' must be local. ",
errors.BULK_EXPORT_FOLDER_NOT_LOCAL,
)

confirm_dir_is_empty(store.Root(export_to, create=True))

19 changes: 12 additions & 7 deletions cumulus_etl/common.py
@@ -16,7 +16,6 @@

from cumulus_etl import store


###############################################################################
#
# Types
@@ -77,7 +76,9 @@ def get_temp_dir(subdir: str) -> str:


def ls_resources(root: store.Root, resources: set[str], warn_if_empty: bool = False) -> list[str]:
found_files = cumulus_fhir_support.list_multiline_json_in_dir(root.path, resources, fsspec_fs=root.fs)
found_files = cumulus_fhir_support.list_multiline_json_in_dir(
root.path, resources, fsspec_fs=root.fs
)

if warn_if_empty:
# Invert the {path: type} found_files dictionary into {type: [paths...]}
@@ -151,7 +152,7 @@ def read_json(path: str) -> Any:
return json.load(f)


def write_json(path: str, data: Any, indent: int = None) -> None:
def write_json(path: str, data: Any, indent: int | None = None) -> None:
"""
Writes data to the given path, in json format
:param path: filesystem path
@@ -176,7 +177,9 @@ def read_ndjson(root: store.Root, path: str) -> Iterator[dict]:
yield from cumulus_fhir_support.read_multiline_json(path, fsspec_fs=root.fs)


def read_resource_ndjson(root: store.Root, resource: str, warn_if_empty: bool = False) -> Iterator[dict]:
def read_resource_ndjson(
root: store.Root, resource: str, warn_if_empty: bool = False
) -> Iterator[dict]:
"""
Grabs all ndjson files from a folder, of a particular resource type.
"""
@@ -240,7 +243,9 @@ def read_local_line_count(path) -> int:
count = 0
buf = None
with open(path, "rb") as f:
bufgen = itertools.takewhile(lambda x: x, (f.raw.read(1024 * 1024) for _ in itertools.repeat(None)))
bufgen = itertools.takewhile(
lambda x: x, (f.raw.read(1024 * 1024) for _ in itertools.repeat(None))
)
for buf in bufgen:
count += buf.count(b"\n")
if buf and buf[-1] != ord("\n"): # catch a final line without a trailing newline
@@ -354,7 +359,7 @@ def datetime_now(local: bool = False) -> datetime.datetime:
return now


def timestamp_datetime(time: datetime.datetime = None) -> str:
def timestamp_datetime(time: datetime.datetime | None = None) -> str:
"""
Human-readable UTC date and time
:return: MMMM-DD-YYY hh:mm:ss
@@ -363,7 +368,7 @@ def timestamp_datetime(time: datetime.datetime = None) -> str:
return time.strftime("%Y-%m-%d %H:%M:%S")


def timestamp_filename(time: datetime.datetime = None) -> str:
def timestamp_filename(time: datetime.datetime | None = None) -> str:
"""
Human-readable UTC date and time suitable for a filesystem path

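
A recurring change in this file (and in cli_utils.py and codebook.py) is replacing `arg: str = None` with `arg: str | None = None`. Ruff can flag such implicitly Optional defaults (the RUF013 rule); a minimal illustration of the pattern, using made-up helper names that are not part of cumulus-etl:

```python
# Illustrative only -- these helpers are made up, not project code.


# Implicit Optional: the annotation claims `str`, but the default is None.
def describe_implicit(label: str = None) -> str:
    return label or "<unset>"


# Explicit Optional, the style adopted throughout this pull request.
def describe_explicit(label: str | None = None) -> str:
    return label or "<unset>"
```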
2 changes: 1 addition & 1 deletion cumulus_etl/completion/__init__.py
@@ -12,8 +12,8 @@
"""

from .schema import (
COMPLETION_TABLE,
COMPLETION_ENCOUNTERS_TABLE,
COMPLETION_TABLE,
completion_encounters_output_args,
completion_encounters_schema,
completion_format_args,
1 change: 0 additions & 1 deletion cumulus_etl/completion/schema.py
@@ -2,7 +2,6 @@

import pyarrow


COMPLETION_TABLE = "etl__completion"
COMPLETION_ENCOUNTERS_TABLE = "etl__completion_encounters"

24 changes: 17 additions & 7 deletions cumulus_etl/deid/codebook.py
@@ -19,7 +19,7 @@ class Codebook:
Some IDs may be cryptographically hashed versions of the real ID, some may be entirely random.
"""

def __init__(self, codebook_dir: str = None):
def __init__(self, codebook_dir: str | None = None):
"""
:param codebook_dir: saved codebook path or None (initialize empty)
"""
@@ -70,7 +70,9 @@ def real_ids(self, resource_type: str, fake_ids: Iterable[str]) -> Iterator[str]
if real_id:
yield real_id
else:
logging.warning("Real ID not found for anonymous %s ID %s. Ignoring.", resource_type, fake_id)
logging.warning(
"Real ID not found for anonymous %s ID %s. Ignoring.", resource_type, fake_id
)


###############################################################################
@@ -83,7 +85,7 @@ def real_ids(self, resource_type: str, fake_ids: Iterable[str]) -> Iterator[str]
class CodebookDB:
"""Class to hold codebook data and read/write it to storage"""

def __init__(self, codebook_dir: str = None):
def __init__(self, codebook_dir: str | None = None):
"""
Create a codebook database.

@@ -110,7 +112,9 @@ def __init__(self, codebook_dir: str = None):
if codebook_dir:
self._load_saved_settings(common.read_json(os.path.join(codebook_dir, "codebook.json")))
try:
self.cached_mapping = common.read_json(os.path.join(codebook_dir, "codebook-cached-mappings.json"))
self.cached_mapping = common.read_json(
os.path.join(codebook_dir, "codebook-cached-mappings.json")
)
except (FileNotFoundError, PermissionError):
pass

@@ -145,7 +149,9 @@ def encounter(self, real_id: str, cache_mapping: bool = True) -> str:
"""
return self._preserved_resource_hash("Encounter", real_id, cache_mapping)

def _preserved_resource_hash(self, resource_type: str, real_id: str, cache_mapping: bool) -> str:
def _preserved_resource_hash(
self, resource_type: str, real_id: str, cache_mapping: bool
) -> str:
"""
Get a hashed ID and preserve the mapping.

@@ -170,7 +176,10 @@ def _preserved_resource_hash(self, resource_type: str, real_id: str, cache_mappi

# Save this generated ID mapping so that we can store it for debugging purposes later.
# Only save if we don't have a legacy mapping, so that we don't have both in memory at the same time.
if cache_mapping and self.cached_mapping.setdefault(resource_type, {}).get(real_id) != fake_id:
if (
cache_mapping
and self.cached_mapping.setdefault(resource_type, {}).get(real_id) != fake_id
):
# We expect the IDs to always be identical. The above check is mostly concerned with None != fake_id,
# but is written defensively in case a bad mapping got saved for some reason.
self.cached_mapping[resource_type][real_id] = fake_id
@@ -206,7 +215,8 @@ def resource_hash(self, real_id: str) -> str:
def _id_salt(self) -> bytes:
"""Returns the saved salt or creates and saves one if needed"""
salt = self.settings["id_salt"]
return binascii.unhexlify(salt) # revert from doubled hex 64-char string representation back to just 32 bytes
# revert from doubled hex 64-char string representation back to just 32 bytes
return binascii.unhexlify(salt)

def _load_saved_settings(self, saved: dict) -> None:
"""
14 changes: 10 additions & 4 deletions cumulus_etl/deid/mstool.py
@@ -38,12 +38,15 @@ async def run_mstool(input_dir: str, output_dir: str) -> None:

if process.returncode != 0:
print(
f"An error occurred while de-identifying the input resources:\n\n{stderr.decode('utf8')}", file=sys.stderr
f"An error occurred while de-identifying the input resources:\n\n{stderr.decode('utf8')}",
file=sys.stderr,
)
raise SystemExit(errors.MSTOOL_FAILED)


async def _wait_for_completion(process: asyncio.subprocess.Process, input_dir: str, output_dir: str) -> (str, str):
async def _wait_for_completion(
process: asyncio.subprocess.Process, input_dir: str, output_dir: str
) -> (str, str):
"""Waits for the MS tool to finish, with a nice little progress bar, returns stdout and stderr"""
stdout, stderr = None, None

@@ -74,7 +77,8 @@ def _compare_file_sizes(target: dict[str, int], current: dict[str, int]) -> floa
total_current = 0
for filename, size in current.items():
if filename in target:
total_current += target[filename] # use target size, because current (de-identified) files will be smaller
# use target size, because current (de-identified) files will be smaller
total_current += target[filename]
else: # an in-progress file is being written out
total_current += size
return total_current / total_expected
@@ -93,4 +97,6 @@

def _count_file_sizes(pattern: str) -> dict[str, int]:
"""Returns all files that match the given pattern and their sizes"""
return {os.path.basename(filename): _get_file_size_safe(filename) for filename in glob.glob(pattern)}
return {
os.path.basename(filename): _get_file_size_safe(filename) for filename in glob.glob(pattern)
}