Skip to content

Commit

Permalink
attempt at switch for calculating population frequencies
Browse files Browse the repository at this point in the history
  • Loading branch information
savitakartik committed Apr 18, 2024
1 parent ba65e4f commit b063c32
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 8 deletions.
4 changes: 2 additions & 2 deletions tests/test_data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def check_ts(self, ts):
C1 = compute_mutation_counts(ts)
C2 = model.compute_population_mutation_counts(ts)
nt.assert_array_equal(C1, C2)
tsm = model.TSModel(ts)
tsm = model.TSModel(ts, calc_population_frequencies=True)
df = tsm.mutations_df
nt.assert_array_equal(df["pop_A_freq"], C1[0] / ts.num_samples)
nt.assert_array_equal(df["pop_B_freq"], C1[1] / ts.num_samples)
Expand Down Expand Up @@ -242,7 +242,7 @@ def test_no_metadata_schema(self):
def test_no_populations(self):
    """
    Building ``mutations_df`` with population frequencies enabled must raise
    a ValueError whose message matches "must be assigned to populations".
    """
    tables = single_tree_example_ts().dump_tables()
    # Add a population row (metadata bytes ``{}``) to the example tables.
    # NOTE(review): presumably no node references this population, which is
    # what triggers the error below -- confirm against the example fixture.
    tables.populations.add_row(b"{}")
    # Frequencies are opt-in, so the flag must be set for the check to run;
    # a model built without it would never reach the failing code path.
    tsm = model.TSModel(tables.tree_sequence(), calc_population_frequencies=True)
    with pytest.raises(ValueError, match="must be assigned to populations"):
        tsm.mutations_df

Expand Down
22 changes: 18 additions & 4 deletions tsqc/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@
logger = daiquiri.getLogger("tsqc")


def load_data(path, calc_population_frequencies=False):
    """
    Load the tree sequence at ``path`` and wrap it in a ``model.TSModel``.

    The file is first read with ``tskit.load``; if that raises
    ``tskit.FileFormatError`` it is read with ``tszip.decompress`` instead.

    :param path: Location of the tree sequence file. Must expose a ``.name``
        attribute (the caller passes a ``pathlib.Path``), which is used as
        the model's name.
    :param calc_population_frequencies: Forwarded to ``model.TSModel``.
        Defaults to ``False`` so that existing ``load_data(path)`` calls
        keep working.
    :return: The ``model.TSModel`` built from the loaded tree sequence.
    """
    logger.info(f"Loading {path}")
    try:
        ts = tskit.load(path)
    except tskit.FileFormatError:
        ts = tszip.decompress(path)

    # Pass by keyword: TSModel's positional order after ``ts`` has changed
    # (the flag was inserted before ``name``), so positional arguments here
    # would silently bind to the wrong parameter if it changes again.
    return model.TSModel(
        ts,
        calc_population_frequencies=calc_population_frequencies,
        name=path.name,
    )


Expand Down Expand Up @@ -152,13 +153,26 @@ def setup_logging(log_level, no_log_filter):
is_flag=True,
help="Do not filter the output log (advanced debugging only)",
)
def main(path, port, show, log_level, no_log_filter, annotations_file):
@click.option(
"--calc-population-frequencies",
default=False,
help="Calculate population frequencies for sample nodes",
)
def main(
path,
port,
show,
log_level,
no_log_filter,
annotations_file,
calc_population_frequencies,
):
"""
Run the tsqc server.
"""
setup_logging(log_level, no_log_filter)

tsm = load_data(pathlib.Path(path))
tsm = load_data(pathlib.Path(path), calc_population_frequencies)
if annotations_file:
config.ANNOTATIONS_FILE = annotations_file

Expand Down
5 changes: 3 additions & 2 deletions tsqc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ class TSModel:
convenience methods for analysing the tree sequence.
"""

def __init__(self, ts, name=None):
def __init__(self, ts, calc_population_frequencies=False, name=None):
self.ts = ts
self.name = name

Expand All @@ -366,6 +366,7 @@ def __init__(self, ts, name=None):
self.nodes_num_mutations = np.bincount(
self.ts.mutations_node, minlength=self.ts.num_nodes
)
self.calc_population_frequencies = calc_population_frequencies

@property
def file_uuid(self):
Expand Down Expand Up @@ -456,7 +457,7 @@ def mutations_df(self):
self.mutations_inherited_state = inherited_state

population_data = {}
if ts.num_populations > 0:
if ts.num_populations > 0 and self.calc_population_frequencies:
pop_mutation_count = compute_population_mutation_counts(ts)
for pop in ts.populations():
name = f"pop{pop.id}"
Expand Down

0 comments on commit b063c32

Please sign in to comment.