From b063c32bce3910f4d91fb2ef0031ae3c28a9d067 Mon Sep 17 00:00:00 2001 From: Savita Karthikeyan Date: Thu, 18 Apr 2024 16:08:16 +0100 Subject: [PATCH] attempt at switch for calculating population frequencies --- tests/test_data_model.py | 4 ++-- tsqc/__main__.py | 22 ++++++++++++++++++---- tsqc/model.py | 5 +++-- 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/tests/test_data_model.py b/tests/test_data_model.py index 6ab9b9a..5908d71 100644 --- a/tests/test_data_model.py +++ b/tests/test_data_model.py @@ -211,7 +211,7 @@ def check_ts(self, ts): C1 = compute_mutation_counts(ts) C2 = model.compute_population_mutation_counts(ts) nt.assert_array_equal(C1, C2) - tsm = model.TSModel(ts) + tsm = model.TSModel(ts, calc_population_frequencies=True) df = tsm.mutations_df nt.assert_array_equal(df["pop_A_freq"], C1[0] / ts.num_samples) nt.assert_array_equal(df["pop_B_freq"], C1[1] / ts.num_samples) @@ -242,7 +242,7 @@ def test_no_metadata_schema(self): def test_no_populations(self): tables = single_tree_example_ts().dump_tables() tables.populations.add_row(b"{}") - tsm = model.TSModel(tables.tree_sequence()) + tsm = model.TSModel(tables.tree_sequence(), calc_population_frequencies=True) with pytest.raises(ValueError, match="must be assigned to populations"): tsm.mutations_df diff --git a/tsqc/__main__.py b/tsqc/__main__.py index c32dea1..9b91924 100644 --- a/tsqc/__main__.py +++ b/tsqc/__main__.py @@ -21,14 +21,15 @@ logger = daiquiri.getLogger("tsqc") -def load_data(path): +def load_data(path, calc_population_frequencies): logger.info(f"Loading {path}") try: ts = tskit.load(path) + except tskit.FileFormatError: ts = tszip.decompress(path) - tsm = model.TSModel(ts, path.name) + tsm = model.TSModel(ts, calc_population_frequencies, path.name) return tsm @@ -152,13 +153,26 @@ def setup_logging(log_level, no_log_filter): is_flag=True, help="Do not filter the output log (advanced debugging only)", ) -def main(path, port, show, log_level, no_log_filter, annotations_file): +@click.option( + "--calc-population-frequencies", + default=False, + help="Calculate population frequencies for sample nodes", +) +def main( + path, + port, + show, + log_level, + no_log_filter, + annotations_file, + calc_population_frequencies, +): """ Run the tsqc server. """ setup_logging(log_level, no_log_filter) - tsm = load_data(pathlib.Path(path)) + tsm = load_data(pathlib.Path(path), calc_population_frequencies) if annotations_file: config.ANNOTATIONS_FILE = annotations_file diff --git a/tsqc/model.py b/tsqc/model.py index 079b265..4535ff4 100644 --- a/tsqc/model.py +++ b/tsqc/model.py @@ -356,7 +356,7 @@ class TSModel: convenience methods for analysing the tree sequence. """ - def __init__(self, ts, name=None): + def __init__(self, ts, calc_population_frequencies=False, name=None): self.ts = ts self.name = name @@ -366,6 +366,7 @@ def __init__(self, ts, name=None): self.nodes_num_mutations = np.bincount( self.ts.mutations_node, minlength=self.ts.num_nodes ) + self.calc_population_frequencies = calc_population_frequencies @property def file_uuid(self): @@ -456,7 +457,7 @@ def mutations_df(self): self.mutations_inherited_state = inherited_state population_data = {} - if ts.num_populations > 0: + if ts.num_populations > 0 and self.calc_population_frequencies: pop_mutation_count = compute_population_mutation_counts(ts) for pop in ts.populations(): name = f"pop{pop.id}"