PSM table generation #150
@@ -0,0 +1,271 @@

```python
# Module for generating Propensity Score matching cohorts

import json
import sys
from dataclasses import dataclass
from pathlib import PosixPath

import numpy as np
import pandas
import toml
from psmpy import PsmPy

from cumulus_library.cli import StudyBuilder
from cumulus_library.databases import DatabaseCursor
from cumulus_library.base_table_builder import BaseTableBuilder
from cumulus_library.template_sql.templates import (
    get_ctas_query_from_df,
    get_drop_view_table,
)
from cumulus_library.template_sql.statistics.psm_templates import (
    get_distinct_ids,
    get_create_covariate_table,
)


@dataclass
class PsmConfig:
    """Provides expected values for PSM execution

    These values should be read in from a toml configuration file.
    See docs/statistics/propensity-score-matching.md for an example with
    details about the expected values for these fields.
    """

    classification_json: str
    pos_source_table: str
    neg_source_table: str
    target_table: str
    primary_ref: str
    count_ref: str
    count_table: str
    dependent_variable: str
    pos_sample_size: int
    neg_sample_size: int
    join_cols_by_table: dict[str, dict]
    seed: int
```
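For illustration, here is a hypothetical config in the shape this dataclass expects, shown inline via `toml.loads`. Every value below (table names, column names, sizes) is invented; `join_cols_by_table` shows both the one-element and two-element `included_cols` forms that `generate_psm_analysis` handles later in this file.

```python
# A made-up PSM config for illustration only; the field names come from
# PsmConfig above, but all values here are hypothetical.
import toml

toml_text = """
classification_json = "symptoms.json"
pos_source_table = "study__pos_cohort"
neg_source_table = "study__all_patients"
target_table = "study__psm_covariates"
primary_ref = "subject_ref"
dependent_variable = "has_symptom"
pos_sample_size = 50
neg_sample_size = 500
seed = 1234

[join_cols_by_table.study__demographics]
included_cols = [["gender"], ["d.race_display", "race"]]
"""

config = toml.loads(toml_text)
assert config["pos_sample_size"] == 50
```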
```python
class PsmBuilder(BaseTableBuilder):
    """TableBuilder for creating PSM tables"""

    display_text = "Building PSM tables..."

    def __init__(self, toml_config_path: str):
        """Loads PSM job details from a PSM configuration file"""
        super().__init__()
        # We're stashing the toml path for error reporting later
        self.toml_path = toml_config_path
        try:
            with open(self.toml_path, encoding="UTF-8") as file:
                toml_config = toml.load(file)
        except OSError:
            sys.exit(f"PSM configuration not found at {self.toml_path}")
        try:
            self.config = PsmConfig(
                classification_json=(
                    f"{PosixPath(self.toml_path).parent}/"
                    f"{toml_config['classification_json']}"
                ),
                pos_source_table=toml_config["pos_source_table"],
                neg_source_table=toml_config["neg_source_table"],
                target_table=toml_config["target_table"],
                primary_ref=toml_config["primary_ref"],
                dependent_variable=toml_config["dependent_variable"],
                pos_sample_size=toml_config["pos_sample_size"],
                neg_sample_size=toml_config["neg_sample_size"],
                join_cols_by_table=toml_config.get("join_cols_by_table", {}),
                count_ref=toml_config.get("count_ref", None),
                count_table=toml_config.get("count_table", None),
                seed=toml_config.get("seed", 123),
            )
        except KeyError:
            # TODO: add link to docsite when you have network access
            sys.exit(
                f"PSM configuration at {toml_config_path} contains "
                "missing/invalid keys. Check the PSM documentation for an "
                "example config with more details."
            )

    def _get_symptoms_dict(self, path: str) -> dict:
        """Convenience function for loading a symptoms dictionary from a json file"""
        with open(path, encoding="UTF-8") as f:
            symptoms = json.load(f)
        return symptoms
```
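As a sketch of what that json might contain (condition names and codes below are invented for illustration): each top-level key becomes a dependent-variable column in `generate_psm_analysis`, set to 1 for rows whose `code` appears in that key's list.

```python
import json

# Hypothetical classification file contents; keys become 0/1 columns,
# values are the codes that flag a row as positive.
json_text = """
{
    "fever": ["386661006", "50177009"],
    "cough": ["49727002"]
}
"""
symptoms = json.loads(json_text)
assert symptoms["cough"] == ["49727002"]
```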
```python
    def _get_sampled_ids(
        self,
        cursor: DatabaseCursor,
        schema: str,
        query: str,
        sample_size: int,
        dependent_variable: str,
        is_positive: bool,
    ):
        """Creates a table containing randomly sampled patients for PSM analysis

        To use this, it is assumed you have already identified a cohort of
        positively IDed patients as a manual process.

        :param cursor: a valid DatabaseCursor
        :param schema: the schema/database name where the data exists
        :param query: a query generated from the psm_distinct_ids template
        :param sample_size: the number of records to include in the random sample.
            This should generally be >= 20.
        :param dependent_variable: the name to use for your filtering column
        :param is_positive: defines the value to be used for your filtering column
        """
        df = cursor.execute(query).as_pandas()
        df = (
            df.sort_values(by=[self.config.primary_ref])
            # .reset_index()
            # .drop("index", axis=1)
        )
        df = (
            # TODO: flip polarity of replace kwarg after increasing the size of the
            # unit testing data
            df.sample(n=sample_size, random_state=self.config.seed, replace=True)
            .sort_values(by=[self.config.primary_ref])
            .reset_index()
            .drop("index", axis=1)
        )
        df[dependent_variable] = is_positive
        return df
```
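A toy illustration of the sampling pattern above, assuming nothing beyond pandas itself: pinning `random_state` makes the draw reproducible across runs, and `replace=True` lets a row be drawn more than once (which, per the TODO, is only needed while the unit-test data is small).

```python
import pandas

df = pandas.DataFrame({"subject_ref": [f"p{i}" for i in range(5)]})
a = df.sample(n=3, random_state=123, replace=True)
b = df.sample(n=3, random_state=123, replace=True)
assert a.equals(b)  # same seed -> identical sample; rows may repeat
```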
```python
    def _create_covariate_table(self, cursor: DatabaseCursor, schema: str):
        """Creates a covariate table from the loaded toml config"""
        # checks for primary & link ref being the same
        source_refs = list({self.config.primary_ref, self.config.count_ref} - {None})
        pos_query = get_distinct_ids(source_refs, self.config.pos_source_table)
        pos = self._get_sampled_ids(
            cursor,
            schema,
            pos_query,
            self.config.pos_sample_size,
            self.config.dependent_variable,
            1,
        )
        neg_query = get_distinct_ids(
            source_refs,
            self.config.neg_source_table,
            join_id=self.config.primary_ref,
            filter_table=self.config.pos_source_table,
        )
        neg = self._get_sampled_ids(
            cursor,
            schema,
            neg_query,
            self.config.neg_sample_size,
            self.config.dependent_variable,
            0,
        )
        cohort = pandas.concat([pos, neg])

        # Replace table (if it exists)
        # TODO - replace with timestamp prepended table in future PR
        drop = get_drop_view_table(
            f"{self.config.pos_source_table}_sampled_ids", "TABLE"
        )
        cursor.execute(drop)
        ctas_query = get_ctas_query_from_df(
            schema,
            f"{self.config.pos_source_table}_sampled_ids",
            cohort,
        )
        self.queries.append(ctas_query)
        # TODO - replace with timestamp prepended table
        drop = get_drop_view_table(self.config.target_table, "TABLE")
        cursor.execute(drop)
        dataset_query = get_create_covariate_table(
            target_table=self.config.target_table,
            pos_source_table=self.config.pos_source_table,
            neg_source_table=self.config.neg_source_table,
            primary_ref=self.config.primary_ref,
            dependent_variable=self.config.dependent_variable,
            join_cols_by_table=self.config.join_cols_by_table,
            count_ref=self.config.count_ref,
            count_table=self.config.count_table,
        )
        self.queries.append(dataset_query)
```
```python
    def generate_psm_analysis(self, cursor: DatabaseCursor, schema: str):
        """Runs PSM statistics on generated tables"""
        df = cursor.execute(f"select * from {self.config.target_table}").as_pandas()
        symptoms_dict = self._get_symptoms_dict(self.config.classification_json)
        for dependent_variable, codes in symptoms_dict.items():
            df[dependent_variable] = df["code"].apply(lambda x: 1 if x in codes else 0)
        df = df.drop(columns="code")
        # instance_count is present but unused for PSM if the table contains a
        # count_ref input (it's intended for manual review)
        df = df.drop(columns="instance_count", errors="ignore")

        columns = []
        if self.config.join_cols_by_table is not None:
            for table_config in self.config.join_cols_by_table.values():
                for column in table_config["included_cols"]:
                    # If there are two elements, it's a SQL column that has been
                    # aliased, so we'll look for the alias name
                    if len(column) == 2:
                        columns.append(column[1])
                    # If there is one element, it's a straight SQL column we can
                    # use with no modification
                    elif len(column) == 1:
                        columns.append(column[0])
                    else:
                        sys.exit(
                            f"PSM config at {self.toml_path} contains an "
                            f"unexpected SQL column definition: {column}. "
                            "Check the PSM documentation for valid usages."
                        )

        # This block replaces each column that may contain several categories
        # (like male/female/other/unknown for AdministrativeGender) with a
        # series of one-hot columns, one per distinct value in that column.
        for column in columns:
            encoded_df = pandas.get_dummies(df[column])
            df = pandas.concat([df, encoded_df], axis=1)
            df = df.drop(column, axis=1)
        df = df.reset_index()
```
Comment on lines +224 to +228

> A block like this could use a comment about why it's doing what it's doing. My rough attempt: you're replacing each column with a dummy version of that column. But even after reading the pandas docs on `get_dummies`, I'm not sure what that means.

> yeah i'll add something - this is converting to a 1-hot encoding for all values of that column, basically pivoting the column values to new column headers.

> updated with some explanatory text
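For anyone else tripping over `get_dummies`, a small standalone demo of the pivot being described (column name and values are arbitrary):

```python
import pandas

df = pandas.DataFrame({"gender": ["male", "female", "other", "female"]})
encoded = pandas.get_dummies(df["gender"])  # one 0/1 column per distinct value
df = pandas.concat([df, encoded], axis=1).drop("gender", axis=1)
print(df.columns.tolist())  # ['female', 'male', 'other']
```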
```python
        try:
            psm = PsmPy(
                df,
                treatment=self.config.dependent_variable,
                indx=self.config.primary_ref,
                exclude=[],
            )
            # This function populates the psm.predicted_data element, which is
            # required for things like the knn_matched() function call
            # TODO: create graph from this data
            psm.logistic_ps(balance=True)
            # This function populates the psm.df_matched element
            # TODO: flip replacement to false after increasing sample data size
            # TODO: create graph from this data
            psm.knn_matched(
                matcher="propensity_logit",
                replacement=True,
                caliper=None,
                drop_unmatched=True,
            )
        except ZeroDivisionError:
            sys.exit(
                "Encountered a divide by zero error during statistical graph "
                "generation. Try increasing your sample size."
            )
        except ValueError:
            sys.exit(
                "Encountered a value error during KNN matching. Try increasing "
                "your sample size."
            )
```
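For reference, the same PsmPy call sequence run end to end on synthetic data. Column names and sizes here are arbitrary; this is a sketch of the sequence, not the study's actual invocation.

```python
import numpy as np
import pandas
from psmpy import PsmPy

rng = np.random.default_rng(123)
n = 200
df = pandas.DataFrame({
    "subject_ref": [f"p{i}" for i in range(n)],  # index column
    "has_symptom": rng.integers(0, 2, n),        # treatment indicator
    "age": rng.integers(18, 90, n),              # covariates
    "female": rng.integers(0, 2, n),
})

psm = PsmPy(df, treatment="has_symptom", indx="subject_ref", exclude=[])
psm.logistic_ps(balance=True)  # fit propensity scores
psm.knn_matched(
    matcher="propensity_logit",
    replacement=True,
    caliper=None,
    drop_unmatched=True,
)
print(psm.df_matched.head())
```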
```python
    def prepare_queries(self, cursor: object, schema: str):
        self._create_covariate_table(cursor, schema)

    def post_execution(
        self,
        cursor: object,
        schema: str,
        verbose: bool,
        drop_table: bool = False,
    ):
        # super().execute_queries(cursor, schema, verbose, drop_table)
        self.generate_psm_analysis(cursor, schema)
```
So this is the change to the DB class I was mentioning, and I'm hoping that this comment explains why it's in here the way it is, but to be a bit more verbose about this: pyathena has a method that dramatically improves query execution when it's looking to return a dataframe - something about how they handle chunking under the hood. So, in context, when I'm passing a cursor to a method, I sometimes elect to specifically hand one of these pandas cursors off.

I did this while testing the PSM code (where the cursor is the entrypoint - we *could* rewrite table builders to take a Connection rather than a Cursor, but that's a big refactor by itself and this is already pretty gross), and in the future manifest parsing hook for this (to come as a follow-on PR), I'm planning on specifying the pandas cursor for PSM invocation. The DuckDB version just returns a regular cursor.
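For readers unfamiliar with that pyathena path, a sketch of the fast path being described: requesting a `PandasCursor` up front means `execute()` results expose `as_pandas()`. Connection parameters below are placeholders.

```python
from pyathena import connect
from pyathena.pandas.cursor import PandasCursor

# Placeholder connection details
cursor = connect(
    s3_staging_dir="s3://example-bucket/athena-results/",
    region_name="us-east-1",
    cursor_class=PandasCursor,
).cursor()
df = cursor.execute("SELECT * FROM example_table").as_pandas()
```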
Yeah I'm fine with this change based on the constraint of "Cursor is the interface, not DatabaseBackend/Connection". Some thoughts around it though:

- `as_pandas` added to the Cursor protocol we have, so that consumers of Library know it's contractually available. (See below for some commentary on this.)
- `execute_as_pandas` dropped -- I only added that to avoid the need for extending cursors like this. But now we could simplify that interface.
- `as_pandas` in the duckdb returned cursor is fine, but gives me pause because clever monkey-patching can be taken too far. 😄 If this setup gets more complicated, I might vote for a DuckCursor wrapper object that does similar kinds of translations needed in future.
- […] cursors for which `as_pandas` is available and those for which it isn't. What happens on a PyAthena normal cursor if you call `as_pandas`?
- […] calling `as_pandas` on the wrong cursor object.
- […] `.get_database_backend()` or something to give access to parent objects without introducing two different kinds of Cursors. But that's a little clunky in its own way. But may feel less clunky.
honestly - i think i like the idea of refactoring one way or another to get these more in line, i'm just trying to not do it as part of this PR for complexity reasons - we can maybe natter about the shape? some options, pulling on some of these threads:

- `as_pandas` is, apparently, available as a util method that can be called on a pyathena cursor, so we could switch to that and keep the cursor space down to one per db. that might slot better into the `execute_as_pandas` paradigm
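If we went that route, a sketch of the util-method variant on a regular cursor (same placeholder connection details as above):

```python
from pyathena import connect
from pyathena.pandas.util import as_pandas

cursor = connect(
    s3_staging_dir="s3://example-bucket/athena-results/",
    region_name="us-east-1",
).cursor()
cursor.execute("SELECT * FROM example_table")
df = as_pandas(cursor)  # convert the finished result set to a DataFrame
```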