PSM table generation #150

Merged
merged 17 commits into from
Dec 5, 2023
12 changes: 11 additions & 1 deletion cumulus_library/.sqlfluff
@@ -4,7 +4,7 @@ dialect = athena
sql_file_exts = .sql,.sql.jinja
# this rule overfires on athena nested arrays
exclude_rules=references.from,structure.column_order,aliasing.unused
- max_line_length = 88
+ max_line_length = 90

[sqlfluff:indentation]
template_blocks_indent = false
@@ -18,20 +18,30 @@ capitalisation_policy = upper
[sqlfluff:templater:jinja:context]
code_systems = ["http://snomed.info/sct", "http://hl7.org/fhir/sid/icd-10-cm"]
col_type_list = ["a string","b string"]
columns = ['a','b']
cc_columns = [{"name": "baz", "is_array": True}, {"name": "foobar", "is_array": False}]
cc_column = 'code'
code_system_tables = [{"table_name":"hasarray","column_name":"acol","is_bare_coding":False,"is_array":True, "has_data": True},{"table_name":"noarray","column_name":"col","is_bare_coding":False,"is_array":False, "has_data": True},{"table_name":"bare","column_name":"bcol","is_bare_coding":True,"is_array":False, "has_data": True},{"table_name":"empty","column_name":"empty","is_bare_coding":False,"is_array":False, "has_data": False}]
column_name = 'bar'
conditions = ["1 > 0", "1 < 2"]
count_ref = count_ref
count_table = count_table
dataset = [["foo","foo"],["bar","bar"]]
dependent_variable = is_flu
ext_systems = ["omb", "text"]
field = 'column_name'
filter_table = filter_table
fhir_extension = fhir_extension
fhir_resource = patient
id = 'id'
join_cols_by_table = { "join_table": { "join_id": "enc_ref","included_cols": [["a"], ["b", "c"]]}}
join_id = subject_ref
medication_datasources = {"by_contained_ref" : True, "by_external_ref" : True}
neg_source_table = neg_source_table
output_table_name = 'created_table'
prefix = Test
primary_ref = encounter_ref
pos_source_table = pos_source_table
schema_name = test_schema
source_table = source_table
source_id = source_id
11 changes: 11 additions & 0 deletions cumulus_library/base_table_builder.py
@@ -74,6 +74,17 @@ def execute_queries(
for query in self.queries:
query_console_output(verbose, query, progress, task)
cursor.execute(query)
self.post_execution(cursor, schema, verbose, drop_table)

def post_execution(
self,
cursor: DatabaseCursor,
schema: str,
verbose: bool,
drop_table: bool = False,
):
"""Hook for any additional actions to run after execute_queries"""
pass

def comment_queries(self):
"""Convenience method for annotating outputs of template generators to disk"""
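The `post_execution` hook added above is a no-op by default, so only builders that override it change behaviour. A minimal sketch of how a subclass might use it, assuming a stripped-down stand-in for `BaseTableBuilder` and a fake cursor (`RecordingCursor` and `RowCountBuilder` are hypothetical names, not part of this PR):

```python
class BaseTableBuilder:
    """Hypothetical, stripped-down stand-in for the real base class."""

    def __init__(self):
        self.queries = []

    def execute_queries(self, cursor, schema, verbose=False, drop_table=False):
        for query in self.queries:
            cursor.execute(query)
        # The hook runs once, after every queued query has been executed
        self.post_execution(cursor, schema, verbose, drop_table)

    def post_execution(self, cursor, schema, verbose, drop_table=False):
        """Default hook: do nothing"""


class RecordingCursor:
    """Fake cursor that records the SQL it is handed (illustration only)."""

    def __init__(self):
        self.executed = []

    def execute(self, sql):
        self.executed.append(sql)
        return self


class RowCountBuilder(BaseTableBuilder):
    """Example subclass that issues a follow-up query from the hook."""

    def post_execution(self, cursor, schema, verbose, drop_table=False):
        cursor.execute(f"SELECT count(*) FROM {schema}.created_table")


cursor = RecordingCursor()
builder = RowCountBuilder()
builder.queries.append("CREATE TABLE test_schema.created_table AS SELECT 1 AS id")
builder.execute_queries(cursor, "test_schema")
```

After the call, `cursor.executed` holds the queued CREATE statement followed by the query issued from the hook, which is the ordering `PsmBuilder` relies on further down.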
27 changes: 26 additions & 1 deletion cumulus_library/databases.py
@@ -1,4 +1,12 @@
"""Abstraction layers for supported database backends (e.g. AWS & DuckDB)"""
"""Abstraction layers for supported database backends (e.g. AWS & DuckDB)

By convention, to maintain this as a relatively light wrapper layer, if you have
to choose between a convenience function in a specific library (as an example, the
[pyathena to_sql function](https://github.com/laughingman7743/PyAthena/#to-sql))
or using raw sql directly in some form, you should do the latter. This is not a law;
if there's a compelling reason to do so, just make sure you add an appropriate
wrapper method in one of DatabaseCursor or DatabaseBackend.
"""

import abc
import datetime
@@ -47,6 +55,14 @@ def __init__(self, schema_name: str):
def cursor(self) -> DatabaseCursor:
"""Returns a connection to the backing database"""

@abc.abstractmethod
def pandas_cursor(self) -> DatabaseCursor:
"""Returns a connection to the backing database optimized for dataframes

If your database does not provide an optimized cursor, this should function the
same as a vanilla cursor.
"""

Comment on lines +58 to +65
Contributor Author

So this is the change to the DB class I was mentioning, and I'm hoping that this comment explains why it's in here the way it is, but to be a bit more verbose about this: pyathena has a method that dramatically improves query execution when it's looking to return a dataframe - something about how they handle chunking under the hood. So, in context, when I'm passing a cursor to a method, I sometimes elect to specifically hand one of these pandas cursors off.

I did this while testing the PSM code (where the cursor is the entrypoint - we :could: rewrite table builders to take a Connection rather than a Cursor, but that's a big refactor by itself and this is already pretty gross), and in the future manifest parsing hook for this, to come as a follow-on PR, I'm planning on specifying the pandas cursor for PSM invocation. The DuckDB version just returns a regular cursor.

Contributor

Yeah I'm fine with this change based on the constraint of "Cursor is the interface, not DatabaseBackend/Connection". Some thoughts around it though:

  • I'd like to see as_pandas added to the Cursor protocol we have, so that consumers of Library know it's contractually available. (See below for some commentary on this.)
  • I'd like to see execute_as_pandas dropped -- I only added that to avoid the need for extending cursors like this. But now we could simplify that interface.
  • The solution of creating an alias for as_pandas in the duckdb returned cursor is fine, but gives me pause because clever monkey-patching can be taken too far. 😄 If this setup gets more complicated, I might vote for a DuckCursor wrapper object that does the kinds of translation needed in future.
  • We really now have two kinds of Cursors - those for which as_pandas is available and those for which it isn't. What happens on a PyAthena normal cursor if you call as_pandas?
    • For our purposes, maybe AthenaDatabaseBackend should create a wrapper AthenaCursor object that throws an exception if you try to call as_pandas on the wrong cursor object.
    • Or even better probably, have two different Cursor protocols. One pandas-powered and one that isn't. That way method signatures would be clear about which cursor they expect to be handed. (if that is always clear?)
    • You could also add Cursor wrappers and a method like .get_database_backend() or something to give access to parent objects without introducing two different kinds of Cursors. But that's a little clunky in its own way. But may feel less clunky.
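The "two different Cursor protocols" option above could look roughly like the following. This is only a sketch of the idea: `PandasCapableCursor`, `PlainCursor`, `FramedCursor` and `run_analysis` are hypothetical names, and the return types are left as `Any` so the example runs without pandas installed.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class DatabaseCursor(Protocol):
    """Minimal cursor contract: anything that can execute SQL."""

    def execute(self, sql: str) -> Any: ...


@runtime_checkable
class PandasCapableCursor(DatabaseCursor, Protocol):
    """Cursor contract that additionally promises as_pandas()."""

    def as_pandas(self) -> Any: ...


class PlainCursor:
    """Fake cursor without dataframe support (illustration only)."""

    def execute(self, sql: str) -> "PlainCursor":
        return self


class FramedCursor(PlainCursor):
    """Fake cursor with dataframe support (illustration only)."""

    def as_pandas(self) -> list:
        # A real implementation would return a pandas.DataFrame
        return []


def run_analysis(cursor: PandasCapableCursor) -> Any:
    """The signature states that a pandas-capable cursor is expected."""
    if not isinstance(cursor, PandasCapableCursor):
        raise TypeError("this step needs a cursor that provides as_pandas()")
    return cursor.execute("SELECT 1").as_pandas()
```

With this shape a type checker can flag a plain cursor being passed to `run_analysis`, and the runtime check gives a clear error instead of an `AttributeError` deep inside the analysis code.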

Contributor Author

honestly - i think i like the idea of refactoring one way or another to get these more in line, i'm just trying to not do it as part of this PR for complexity reasons - we can maybe natter about the shape? some options, pulling on some of these threads:

  • I don't hate making a database connection the atomic unit, but it is probably going to touch the most things
  • as_pandas is, apparently, available as a util method that can be called on a pyathena cursor, so we could switch to that and keep the cursor space down to one per db. that might slot better into the execute_as_pandas paradigm
  • I think generally a PEP 249 cursor has a reference back to its connection, so maybe it's not the end of the world to have it get the database backend, though i think that's my least favorite of these.
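For comparison with the `setattr` alias used in this PR, the wrapper-object alternative raised in this thread might be sketched as below. `DuckCursorWrapper` is a hypothetical name, and `FakeDuckConnection` stands in for `duckdb.DuckDBPyConnection` so the example runs without DuckDB; the only DuckDB behaviour assumed is that the connection exposes `execute()` and `df()`.

```python
class FakeDuckConnection:
    """Stand-in for duckdb.DuckDBPyConnection so the sketch runs without DuckDB."""

    def execute(self, sql):
        self.last_sql = sql
        return self

    def df(self):
        # DuckDB's real df() returns a pandas.DataFrame; a list keeps this light
        return [("row for", self.last_sql)]


class DuckCursorWrapper:
    """Hypothetical wrapper exposing Athena-style as_pandas() over df()."""

    def __init__(self, connection):
        self._connection = connection

    def execute(self, sql):
        self._connection.execute(sql)
        return self  # keep chaining on the wrapper, not the raw connection

    def as_pandas(self):
        return self._connection.df()

    def __getattr__(self, name):
        # Anything not translated here is delegated to the wrapped connection
        return getattr(self._connection, name)


cursor = DuckCursorWrapper(FakeDuckConnection())
result = cursor.execute("select 1").as_pandas()
```

The trade-off is a little more code in exchange for leaving the third-party class untouched, which keeps any future Athena/DuckDB translations in one place.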

@abc.abstractmethod
def execute_as_pandas(self, sql: str) -> pandas.DataFrame:
"""Returns a pandas.DataFrame version of the results from the provided SQL"""
@@ -85,6 +101,9 @@ def __init__(self, region: str, workgroup: str, profile: str, schema_name: str):
def cursor(self) -> AthenaCursor:
return self.connection.cursor()

def pandas_cursor(self) -> AthenaPandasCursor:
return self.pandas_cursor

def execute_as_pandas(self, sql: str) -> pandas.DataFrame:
return self.pandas_cursor.execute(sql).as_pandas()
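One thing that may be worth checking in the Athena hunk above. `__init__` is not shown here, so this is an assumption: `execute_as_pandas` calls `self.pandas_cursor.execute(...)`, which suggests the cursor is stored as an instance attribute named `pandas_cursor`. If so, that attribute hides the new method of the same name. A minimal reduction of the pattern, using a hypothetical `Backend` class:

```python
class Backend:
    """Hypothetical reduction of the naming pattern, not the real backend."""

    def __init__(self):
        # Instance attribute with the same name as the method below
        self.pandas_cursor = "cursor-object"

    def pandas_cursor(self):
        return self.pandas_cursor


backend = Backend()
# Attribute lookup finds the instance attribute first, so the method is hidden
looked_up = backend.pandas_cursor
is_callable = callable(looked_up)
```

Here `backend.pandas_cursor` is the stored object rather than a bound method, so `backend.pandas_cursor()` would fail. Storing the cursor under a different attribute name would avoid the clash.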

@@ -95,6 +114,8 @@ class DuckDatabaseBackend(DatabaseBackend):
def __init__(self, db_file: str):
super().__init__("main")
self.connection = duckdb.connect(db_file)
# Aliasing Athena's as_pandas to duckDB's df cast
setattr(duckdb.DuckDBPyConnection, "as_pandas", duckdb.DuckDBPyConnection.df)
dogversioning marked this conversation as resolved.

# Paper over some syntax differences between Athena and DuckDB
self.connection.create_function(
@@ -150,6 +171,10 @@ def cursor(self) -> duckdb.DuckDBPyConnection:
# because then we'd have to re-register our json tables.
return self.connection

def pandas_cursor(self) -> duckdb.DuckDBPyConnection:
# Since this is not provided, return the vanilla cursor
return self.connection

def execute_as_pandas(self, sql: str) -> pandas.DataFrame:
# We call convert_dtypes here in case there are integer columns.
# Pandas will normally cast nullable-int as a float type unless
271 changes: 271 additions & 0 deletions cumulus_library/statistics/psm.py
@@ -0,0 +1,271 @@
# Module for generating Propensity Score matching cohorts

import numpy as np
import pandas
import sys
import toml

from psmpy import PsmPy


import json
from pathlib import PosixPath
from dataclasses import dataclass

from cumulus_library.cli import StudyBuilder
from cumulus_library.databases import DatabaseCursor
from cumulus_library.base_table_builder import BaseTableBuilder
from cumulus_library.template_sql.templates import (
get_ctas_query_from_df,
get_drop_view_table,
)
from cumulus_library.template_sql.statistics.psm_templates import (
get_distinct_ids,
get_create_covariate_table,
)


@dataclass
class PsmConfig:
"""Provides expected values for PSM execution

These values should be read in from a toml configuration file.
See docs/statistics/propensity-score-matching.md for an example with details about
the expected values for these fields.
"""

classification_json: str
pos_source_table: str
neg_source_table: str
target_table: str
primary_ref: str
count_ref: str
count_table: str
dependent_variable: str
pos_sample_size: int
neg_sample_size: int
join_cols_by_table: dict[str, dict]
seed: int


class PsmBuilder(BaseTableBuilder):
"""TableBuilder for creating PSM tables"""

display_text = "Building PSM tables..."

def __init__(self, toml_config_path: str):
"""Loads PSM job details from a PSM configuration file"""
super().__init__()
dogversioning marked this conversation as resolved.
# We're stashing the toml path for error reporting later
self.toml_path = toml_config_path
try:
with open(self.toml_path, encoding="UTF-8") as file:
toml_config = toml.load(file)

except OSError:
sys.exit(f"PSM configuration not found at {self.toml_path}")
try:
self.config = PsmConfig(
classification_json=f"{PosixPath(self.toml_path).parent}/{toml_config['classification_json']}",
pos_source_table=toml_config["pos_source_table"],
neg_source_table=toml_config["neg_source_table"],
target_table=toml_config["target_table"],
primary_ref=toml_config["primary_ref"],
dependent_variable=toml_config["dependent_variable"],
pos_sample_size=toml_config["pos_sample_size"],
neg_sample_size=toml_config["neg_sample_size"],
join_cols_by_table=toml_config.get("join_cols_by_table", {}),
count_ref=toml_config.get("count_ref", None),
count_table=toml_config.get("count_table", None),
seed=toml_config.get("seed", 123),
)
except KeyError:
# TODO: add link to docsite when you have network access
sys.exit(
f"PSM configuration at {toml_config_path} contains missing/invalid keys. "
"Check the PSM documentation for an example config with more details."
)

def _get_symptoms_dict(self, path: str) -> dict:
"""convenience function for loading symptoms dictionaries from a json file"""
with open(path, encoding="UTF-8") as f:
symptoms = json.load(f)
return symptoms

def _get_sampled_ids(
self,
cursor: DatabaseCursor,
schema: str,
query: str,
sample_size: int,
dependent_variable: str,
is_positive: bool,
):
"""Creates a table containing randomly sampled patients for PSM analysis

To use this, it is assumed you have already identified a cohort of positively
IDed patients as a manual process.
:param cursor: A valid DatabaseCursor
:param schema: the schema/database name where the data exists
:param query: a query generated from the psm_distinct_ids template
:param sample_size: the number of records to include in the random sample.
This should generally be >= 20.
:param dependent_variable: the name to use for your filtering column
:param is_positive: defines the value to be used for your filtering column
"""
df = cursor.execute(query).as_pandas()
df = (
df.sort_values(by=[self.config.primary_ref])
# .reset_index()
# .drop("index", axis=1)
)

df = (
# TODO: flip polarity of replace kwarg after increasing the size of the
# unit testing data
df.sample(n=sample_size, random_state=self.config.seed, replace=True)
.sort_values(by=[self.config.primary_ref])
.reset_index()
.drop("index", axis=1)
)

df[dependent_variable] = is_positive
return df
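On the `replace` TODO in `_get_sampled_ids` above: a small sketch of what the `random_state` and `replace` arguments to `DataFrame.sample` do, using made-up IDs. The column name and values are illustrative only.

```python
import pandas

ids = pandas.DataFrame({"encounter_ref": [f"Encounter/{i}" for i in range(5)]})

# Same seed, same draw: the sample is reproducible between runs
first = ids.sample(n=3, random_state=123, replace=False)
second = ids.sample(n=3, random_state=123, replace=False)

# With replace=True, rows may repeat, and n may exceed the number of rows
oversampled = ids.sample(n=8, random_state=123, replace=True)

# Without replacement, asking for more rows than exist raises ValueError
try:
    ids.sample(n=8, random_state=123, replace=False)
    too_large_failed = False
except ValueError:
    too_large_failed = True
```

This is presumably why the TODO ties flipping `replace` to growing the unit test data: with `replace=False`, the sample size can never exceed the number of distinct IDs returned by the query.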

def _create_covariate_table(self, cursor: DatabaseCursor, schema: str):
"""Creates a covariate table from the loaded toml config"""
# checks for primary & link ref being the same
source_refs = list({self.config.primary_ref, self.config.count_ref} - {None})
pos_query = get_distinct_ids(source_refs, self.config.pos_source_table)
pos = self._get_sampled_ids(
cursor,
schema,
pos_query,
self.config.pos_sample_size,
self.config.dependent_variable,
1,
)
neg_query = get_distinct_ids(
source_refs,
self.config.neg_source_table,
join_id=self.config.primary_ref,
filter_table=self.config.pos_source_table,
)
neg = self._get_sampled_ids(
cursor,
schema,
neg_query,
self.config.neg_sample_size,
self.config.dependent_variable,
0,
)
cohort = pandas.concat([pos, neg])

# Replace table (if it exists)
# TODO - replace with timestamp prepended table in future PR
drop = get_drop_view_table(
f"{self.config.pos_source_table}_sampled_ids", "TABLE"
)
cursor.execute(drop)
ctas_query = get_ctas_query_from_df(
schema,
f"{self.config.pos_source_table}_sampled_ids",
cohort,
)
self.queries.append(ctas_query)
# TODO - replace with timestamp prepended table
drop = get_drop_view_table(self.config.target_table, "TABLE")
cursor.execute(drop)
dataset_query = get_create_covariate_table(
target_table=self.config.target_table,
pos_source_table=self.config.pos_source_table,
neg_source_table=self.config.neg_source_table,
primary_ref=self.config.primary_ref,
dependent_variable=self.config.dependent_variable,
join_cols_by_table=self.config.join_cols_by_table,
count_ref=self.config.count_ref,
count_table=self.config.count_table,
)
self.queries.append(dataset_query)

def generate_psm_analysis(self, cursor: DatabaseCursor, schema: str):
"""Runs PSM statistics on generated tables"""
df = cursor.execute(f"select * from {self.config.target_table}").as_pandas()
symptoms_dict = self._get_symptoms_dict(self.config.classification_json)
for dependent_variable, codes in symptoms_dict.items():
df[dependent_variable] = df["code"].apply(lambda x: 1 if x in codes else 0)
df = df.drop(columns="code")
# instance_count present but unused for PSM if table contains a count_ref input
# (it's intended for manual review)
df = df.drop(columns="instance_count", errors="ignore")

columns = []
if self.config.join_cols_by_table is not None:
for table_config in self.config.join_cols_by_table.values():
for column in table_config["included_cols"]:
# If there are two elements, it's a SQL column that has been
# aliased, so we'll look for the alias name
if len(column) == 2:
columns.append(column[1])
# If there is one element, it's a straight SQL column we can
# use with no modification
elif len(column) == 1:
columns.append(column[0])
else:
sys.exit(
f"PSM config at {self.toml_path} contains an "
f"unexpected SQL column definition: {column}. "
"Check the PSM documentation for valid usages."
)

# This block replaces a column that may contain several categories
# (like male/female/other/unknown for AdministrativeGender) with a
# series of 1-hot columns, one for each distinct value in that column.
for column in columns:
encoded_df = pandas.get_dummies(df[column])
df = pandas.concat([df, encoded_df], axis=1)
df = df.drop(column, axis=1)
df = df.reset_index()
Comment on lines +224 to +228
Contributor

A block like this could use a comment about why it's doing what it's doing. My rough attempt: you're replacing each column with a dummy version of that column. But even after reading the pandas docs on get_dummies, I'm not 💯 on what that means.

Contributor Author

yeah i'll add something - this is converting to a 1-hot encoding for all values of that column, basically pivoting the column values to new column headers.

Contributor Author

updated with some explanatory text
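To make the `get_dummies` step discussed in this thread concrete, here is a small sketch with made-up data. The column names and values are illustrative and not taken from the study tables.

```python
import pandas

df = pandas.DataFrame(
    {
        "subject_ref": ["Patient/1", "Patient/2", "Patient/3"],
        "gender": ["female", "male", "female"],
    }
)

# One new column per distinct value; each row is flagged in exactly one of them
encoded = pandas.get_dummies(df["gender"])
df = pandas.concat([df, encoded], axis=1).drop("gender", axis=1)
```

The single `gender` column is replaced by `female` and `male` indicator columns, which is the "pivoting the column values to new column headers" described above.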


try:
psm = PsmPy(
df,
treatment=self.config.dependent_variable,
indx=self.config.primary_ref,
exclude=[],
)
# This function populates the psm.predicted_data element, which is required
# for things like the knn_matched() function call
# TODO: create graph from this data
psm.logistic_ps(balance=True)
# This function populates the psm.df_matched element
# TODO: flip replacement to false after increasing sample data size
# TODO: create graph from this data
psm.knn_matched(
matcher="propensity_logit",
replacement=True,
caliper=None,
drop_unmatched=True,
)

except ZeroDivisionError:
sys.exit(
"Encountered a divide by zero error during statistical graph generation. Try increasing your sample size."
)
except ValueError:
sys.exit(
"Encountered a value error during KNN matching. Try increasing your sample size."
)

def prepare_queries(self, cursor: object, schema: str):
self._create_covariate_table(cursor, schema)

def post_execution(
self,
cursor: object,
schema: str,
verbose: bool,
drop_table: bool = False,
):
# super().execute_queries(cursor, schema, verbose, drop_table)
self.generate_psm_analysis(cursor, schema)
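For readers unfamiliar with the `knn_matched(matcher="propensity_logit", replacement=True)` call in `generate_psm_analysis`, the general idea of nearest-neighbour matching on the logit of a propensity score can be sketched in plain Python. This is an illustration of the technique only, not PsmPy's implementation; the function names and scores are made up.

```python
import math


def logit(p: float) -> float:
    """Log-odds of a propensity score p, with 0 < p < 1."""
    return math.log(p / (1 - p))


def match_with_replacement(treated: dict, control: dict) -> dict:
    """Pair each treated id with the control id whose logit is closest.

    Both arguments map an id to a propensity score. Because matching is done
    with replacement, one control may be paired with several treated ids.
    """
    control_logits = {cid: logit(p) for cid, p in control.items()}
    matches = {}
    for tid, p in treated.items():
        target = logit(p)
        matches[tid] = min(
            control_logits, key=lambda cid: abs(control_logits[cid] - target)
        )
    return matches


# Made-up scores purely for illustration
treated = {"t1": 0.80, "t2": 0.75, "t3": 0.30}
control = {"c1": 0.78, "c2": 0.35, "c3": 0.10}
pairs = match_with_replacement(treated, control)
```

In this toy data `c1` is matched to both `t1` and `t2`, which is what matching with replacement permits. Matching without replacement would need at least as many suitable controls as treated rows, which fits the TODO about flipping `replacement` once the sample data grows.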