Skip to content

Commit

Permalink
Merge pull request #2987 from nspope/pair-coalescence-quantiles-fix
Browse files Browse the repository at this point in the history
Minor fixes to error checking for pair coalescence stats
  • Loading branch information
jeromekelleher authored Sep 17, 2024
2 parents 024e042 + 9977d70 commit 600a18c
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 83 deletions.
37 changes: 17 additions & 20 deletions c/tests/test_stats.c
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,6 @@ verify_pair_coalescence_counts(tsk_treeseq_t *ts, tsk_flags_t options)
ret = tsk_treeseq_pair_coalescence_counts(ts, P, sample_set_sizes, sample_sets, I,
index_tuples, T, breakpoints, N, node_bin_map, options, C);
CU_ASSERT_EQUAL_FATAL(ret, 0);
/* TODO: compare against naive pairs per node per tree */

/* cover errors */
double bad_breakpoints[2] = { breakpoints[1], 0.0 };
Expand Down Expand Up @@ -428,12 +427,15 @@ verify_pair_coalescence_quantiles(tsk_treeseq_t *ts)
tsk_size_t sample_set_sizes[P];
tsk_id_t index_tuples[2 * I];
tsk_id_t node_bin_map[N];
tsk_id_t node_bin_map_empty[N];
tsk_id_t node_bin_map_shuff[N];
tsk_size_t dim = T * Q * I;
double C[dim];
tsk_size_t i, j, k;

for (i = 0; i < N; i++) {
node_bin_map[i] = TSK_NULL;
node_bin_map_empty[i] = TSK_NULL;
node_bin_map_shuff[i] = (tsk_id_t)(i % B);
for (j = 0; j < B; j++) {
if (nodes_time[i] >= epochs[j] && nodes_time[i] < epochs[j + 1]) {
node_bin_map[i] = (tsk_id_t) j;
Expand Down Expand Up @@ -463,14 +465,16 @@ verify_pair_coalescence_quantiles(tsk_treeseq_t *ts)
ret = tsk_treeseq_pair_coalescence_quantiles(ts, P, sample_set_sizes, sample_sets, I,
index_tuples, T, breakpoints, B, node_bin_map, Q, quantiles, 0, C);
CU_ASSERT_EQUAL_FATAL(ret, 0);
/* TODO: compare against naive quantiles per tree */

quantiles[Q - 1] = 0.9;
ret = tsk_treeseq_pair_coalescence_quantiles(ts, P, sample_set_sizes, sample_sets, I,
index_tuples, T, breakpoints, B, node_bin_map, Q, quantiles, 0, C);
CU_ASSERT_EQUAL_FATAL(ret, 0);
quantiles[Q - 1] = 1.0;
/* TODO: compare against naive quantiles per tree */

ret = tsk_treeseq_pair_coalescence_quantiles(ts, P, sample_set_sizes, sample_sets, I,
index_tuples, T, breakpoints, B, node_bin_map_empty, Q, quantiles, 0, C);
CU_ASSERT_EQUAL_FATAL(ret, 0);

/* cover errors */
quantiles[0] = -1.0;
Expand All @@ -493,23 +497,17 @@ verify_pair_coalescence_quantiles(tsk_treeseq_t *ts)
quantiles[0] = 0.0;
quantiles[1] = 0.25;

for (i = 0; i < N; i++) {
if (node_bin_map[i] == 0) {
node_bin_map[i] = 2;
} else if (node_bin_map[i] == 2) {
node_bin_map[i] = 0;
}
}
ts->tables->nodes.time[N - 1] = -1.0;
ret = tsk_treeseq_pair_coalescence_quantiles(ts, P, sample_set_sizes, sample_sets, I,
index_tuples, T, breakpoints, B, node_bin_map, Q, quantiles, 0, C);
index_tuples, T, breakpoints, B, node_bin_map_shuff, Q, quantiles, 0, C);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSORTED_TIMES);
for (i = 0; i < N; i++) {
if (node_bin_map[i] == 0) {
node_bin_map[i] = 2;
} else if (node_bin_map[i] == 2) {
node_bin_map[i] = 0;
}
}
ts->tables->nodes.time[N - 1] = max_time;

node_bin_map[0] = (tsk_id_t) B;
ret = tsk_treeseq_pair_coalescence_quantiles(ts, P, sample_set_sizes, sample_sets, I,
index_tuples, T, breakpoints, B, node_bin_map, Q, quantiles, 0, C);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NODE_BIN_MAP_DIM);
node_bin_map[0] = 0;
}

/* Check coalescence rates */
Expand Down Expand Up @@ -570,7 +568,6 @@ verify_pair_coalescence_rates(tsk_treeseq_t *ts)
ret = tsk_treeseq_pair_coalescence_rates(ts, P, sample_set_sizes, sample_sets, I,
index_tuples, T, breakpoints, B, node_bin_map, epochs, 0, C);
CU_ASSERT_EQUAL_FATAL(ret, 0);
/* TODO: compare against naive coalescence rates per tree */

node_bin_map[0] = TSK_NULL;
ret = tsk_treeseq_pair_coalescence_rates(ts, P, sample_set_sizes, sample_sets, I,
Expand Down
63 changes: 55 additions & 8 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -9255,15 +9255,10 @@ pair_coalescence_quantiles(tsk_size_t input_dim, const double *weight,
tsk_size_t i, j;
j = 0;
coalesced = 0.0;
timepoint = -INFINITY;
/* TODO: may be more efficient to use a binary search */
timepoint = TSK_UNKNOWN_TIME;
for (i = 0; i < input_dim; i++) {
if (weight[i] > 0) {
coalesced += weight[i];
if (values[i] <= timepoint) {
ret = TSK_ERR_UNSORTED_TIMES;
goto out;
}
timepoint = values[i];
while (j < output_dim && quantiles[j] <= coalesced) {
output[j] = timepoint;
Expand All @@ -9274,7 +9269,6 @@ pair_coalescence_quantiles(tsk_size_t input_dim, const double *weight,
if (quantiles[output_dim - 1] == 1.0) {
output[output_dim - 1] = timepoint;
}
out:
return ret;
}

Expand All @@ -9284,7 +9278,7 @@ check_quantiles(const tsk_size_t num_quantiles, const double *quantiles)
int ret = 0;
tsk_size_t i;
double last = -INFINITY;
for (i = 0; i < num_quantiles; ++i) {
for (i = 0; i < num_quantiles; i++) {
if (quantiles[i] <= last || quantiles[i] < 0.0 || quantiles[i] > 1.0) {
ret = TSK_ERR_BAD_QUANTILES;
goto out;
Expand All @@ -9295,6 +9289,55 @@ check_quantiles(const tsk_size_t num_quantiles, const double *quantiles)
return ret;
}

/*
 * Check that the time bins described by `node_bin_map` are ordered in time.
 *
 * For every bin index in [0, num_bins) the smallest and largest time of the
 * nodes mapped to that bin are recorded. Walking the bins in index order,
 * the smallest time in each non-empty bin must not be less than the largest
 * time in the previous non-empty bin (equality is accepted).
 *
 * Nodes whose map entry is negative or >= num_bins are ignored here, and
 * bins that receive no nodes are skipped.
 *
 * Returns 0 on success, TSK_ERR_NO_MEMORY if the scratch arrays cannot be
 * allocated, or TSK_ERR_UNSORTED_TIMES if two bins are out of order.
 */
static int
check_sorted_node_bin_map(
    const tsk_treeseq_t *self, tsk_size_t num_bins, const tsk_id_t *node_bin_map)
{
    int ret = 0;
    const tsk_size_t num_nodes = self->tables->nodes.num_rows;
    const double *nodes_time = self->tables->nodes.time;
    double *bin_lo = tsk_malloc(num_bins * sizeof(*bin_lo));
    double *bin_hi = tsk_malloc(num_bins * sizeof(*bin_hi));
    double prev_hi = -INFINITY;
    double t;
    tsk_id_t node, bin;

    if (bin_lo == NULL || bin_hi == NULL) {
        ret = TSK_ERR_NO_MEMORY;
        goto out;
    }

    /* TSK_UNKNOWN_TIME marks a bin that has not been assigned any node yet */
    for (bin = 0; bin < (tsk_id_t) num_bins; bin++) {
        bin_lo[bin] = TSK_UNKNOWN_TIME;
        bin_hi[bin] = TSK_UNKNOWN_TIME;
    }

    /* Accumulate the time range spanned by the nodes in each bin */
    for (node = 0; node < (tsk_id_t) num_nodes; node++) {
        bin = node_bin_map[node];
        if (bin >= 0 && bin < (tsk_id_t) num_bins) {
            t = nodes_time[node];
            if (tsk_is_unknown_time(bin_hi[bin]) || t > bin_hi[bin]) {
                bin_hi[bin] = t;
            }
            if (tsk_is_unknown_time(bin_lo[bin]) || t < bin_lo[bin]) {
                bin_lo[bin] = t;
            }
        }
    }

    /* Each non-empty bin must start no earlier than the previous one ended */
    for (bin = 0; bin < (tsk_id_t) num_bins; bin++) {
        if (!tsk_is_unknown_time(bin_lo[bin])) {
            if (bin_lo[bin] < prev_hi) {
                ret = TSK_ERR_UNSORTED_TIMES;
                goto out;
            }
            prev_hi = bin_hi[bin];
        }
    }
out:
    tsk_safe_free(bin_lo);
    tsk_safe_free(bin_hi);
    return ret;
}

int
tsk_treeseq_pair_coalescence_quantiles(const tsk_treeseq_t *self,
tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes,
Expand All @@ -9309,6 +9352,10 @@ tsk_treeseq_pair_coalescence_quantiles(const tsk_treeseq_t *self,
if (ret != 0) {
goto out;
}
ret = check_sorted_node_bin_map(self, num_bins, node_bin_map);
if (ret != 0) {
goto out;
}
options |= TSK_STAT_SPAN_NORMALISE | TSK_STAT_PAIR_NORMALISE;
ret = tsk_treeseq_pair_coalescence_stat(self, num_sample_sets, sample_set_sizes,
sample_sets, num_set_indexes, set_indexes, num_windows, windows, num_bins,
Expand Down
65 changes: 30 additions & 35 deletions python/tests/test_coalrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1395,46 +1395,32 @@ def example_ts(self):
random_seed=1024,
)

def test_oor_windows(self):
def test_bad_windows(self):
ts = self.example_ts()
with pytest.raises(ValueError, match="must be sequence boundary"):
with pytest.raises(ValueError, match="too small depth"):
ts.pair_coalescence_counts(windows="whatever")
with pytest.raises(ValueError, match="must have at least 2 elements"):
ts.pair_coalescence_counts(windows=[0.0])
with pytest.raises(tskit.LibraryError, match="must be increasing list"):
ts.pair_coalescence_counts(
windows=np.array([0.0, 2.0]) * ts.sequence_length
windows=np.array([0.0, 0.3, 0.2, 1.0]) * ts.sequence_length
)

def test_unsorted_windows(self):
ts = self.example_ts()
with pytest.raises(ValueError, match="must be strictly increasing"):
with pytest.raises(tskit.LibraryError, match="must be increasing list"):
ts.pair_coalescence_counts(
windows=np.array([0.0, 0.3, 0.2, 1.0]) * ts.sequence_length
windows=np.array([0.0, 2.0]) * ts.sequence_length
)

def test_bad_windows(self):
ts = self.example_ts()
with pytest.raises(ValueError, match="must be an array of breakpoints"):
ts.pair_coalescence_counts(windows="whatever")
with pytest.raises(ValueError, match="must be an array of breakpoints"):
ts.pair_coalescence_counts(windows=np.array([0.0]))

def test_empty_sample_sets(self):
def test_bad_sample_sets(self):
ts = self.example_ts()
with pytest.raises(ValueError, match="contain at least one element"):
ts.pair_coalescence_counts(sample_sets=[[0, 1, 2], []])

def test_oob_sample_sets(self):
ts = self.example_ts()
with pytest.raises(ValueError, match="is out of bounds"):
with pytest.raises(tskit.LibraryError, match="out of bounds"):
ts.pair_coalescence_counts(sample_sets=[[0, ts.num_nodes]])

def test_nonbinary_indexes(self):
ts = self.example_ts()
with pytest.raises(ValueError, match="must be length two"):
ts.pair_coalescence_counts(indexes=[(0, 0, 0)])

def test_oob_indexes(self):
def test_bad_indexes(self):
ts = self.example_ts()
with pytest.raises(ValueError, match="is out of bounds"):
with pytest.raises(tskit.LibraryError, match="out of bounds"):
ts.pair_coalescence_counts(indexes=[(0, 1)])
with pytest.raises(ValueError, match="must be a k x 2 array"):
ts.pair_coalescence_counts(indexes=[(0, 0, 0)])

def test_no_indexes(self):
ts = self.example_ts()
Expand All @@ -1455,17 +1441,16 @@ def test_uncalibrated_time(self):
with pytest.raises(ValueError, match="require calibrated node times"):
ts.pair_coalescence_counts(time_windows=np.array([0.0, np.inf]))

def test_bad_time_windows(self):
@pytest.mark.parametrize("time_windows", [[], [0.0], [[0.0, 1.0]], "whatever"])
def test_bad_time_windows(self, time_windows):
ts = self.example_ts()
with pytest.raises(ValueError, match="must be an array of breakpoints"):
ts.pair_coalescence_counts(time_windows="whatever")
with pytest.raises(ValueError, match="must be an array of breakpoints"):
ts.pair_coalescence_counts(time_windows=np.array([0.0]))
with pytest.raises(ValueError, match="too small depth"):
ts.pair_coalescence_counts(time_windows="time_windows")

def test_unsorted_time_windows(self):
ts = self.example_ts()
time_windows = np.array([0.0, 12.0, 6.0, np.inf])
with pytest.raises(ValueError, match="must be strictly increasing"):
with pytest.raises(ValueError, match="monotonically increasing or decreasing"):
ts.pair_coalescence_counts(time_windows=time_windows)

def test_empty_time_windows(self):
Expand Down Expand Up @@ -1580,6 +1565,10 @@ def test_errors(self):
quantiles = np.linspace(0, 1, 10)
with pytest.raises(ValueError, match="more than two sample sets"):
ts.pair_coalescence_quantiles(quantiles, sample_sets=sample_sets)
tables = ts.dump_tables()
tables.time_units = tskit.TIME_UNITS_UNCALIBRATED
with pytest.raises(ValueError, match="require calibrated node times"):
tables.tree_sequence().pair_coalescence_quantiles(quantiles=np.array([0.5]))


class TestPairCoalescenceRates:
Expand Down Expand Up @@ -1679,3 +1668,9 @@ def test_errors(self):
time_windows = np.array([0, np.inf])
with pytest.raises(ValueError, match="more than two sample sets"):
ts.pair_coalescence_rates(time_windows, sample_sets=sample_sets)
tables = ts.dump_tables()
tables.time_units = tskit.TIME_UNITS_UNCALIBRATED
with pytest.raises(ValueError, match="require calibrated node times"):
tables.tree_sequence().pair_coalescence_rates(
time_windows=np.array([0.0, np.inf])
)
26 changes: 6 additions & 20 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -9377,11 +9377,6 @@ def pair_coalescence_counts(

if sample_sets is None:
sample_sets = [list(self.samples())]
for s in sample_sets:
if len(s) == 0:
raise ValueError("Sample sets must contain at least one element")
if not (min(s) >= 0 and max(s) < self.num_nodes):
raise ValueError("Sample is out of bounds")

drop_middle_dimension = False
if indexes is None:
Expand All @@ -9394,30 +9389,15 @@ def pair_coalescence_counts(
raise ValueError(
"Must specify indexes if there are more than two sample sets"
)
for i in indexes:
if not len(i) == 2:
raise ValueError("Sample set indexes must be length two")
if not (min(i) >= 0 and max(i) < len(sample_sets)):
raise ValueError("Sample set index is out of bounds")

drop_left_dimension = False
if windows is None:
drop_left_dimension = True
windows = np.array([0.0, self.sequence_length])
if not (isinstance(windows, np.ndarray) and windows.size > 1):
raise ValueError("Windows must be an array of breakpoints")
if not (windows[0] == 0.0 and windows[-1] == self.sequence_length):
raise ValueError("First and last window breaks must be sequence boundary")
if not np.all(np.diff(windows) > 0):
raise ValueError("Window breaks must be strictly increasing")

if isinstance(time_windows, str) and time_windows == "nodes":
node_bin_map = np.arange(self.num_nodes, dtype=np.int32)
else:
if not (isinstance(time_windows, np.ndarray) and time_windows.size > 1):
raise ValueError("Time windows must be an array of breakpoints")
if not np.all(np.diff(time_windows) > 0):
raise ValueError("Time windows must be strictly increasing")
if self.time_units == tskit.TIME_UNITS_UNCALIBRATED:
raise ValueError("Time windows require calibrated node times")
node_bin_map = np.digitize(self.nodes_time, time_windows) - 1
Expand Down Expand Up @@ -9496,6 +9476,9 @@ def pair_coalescence_quantiles(
"Must specify indexes if there are more than two sample sets"
)

if self.time_units == tskit.TIME_UNITS_UNCALIBRATED:
raise ValueError("Pair coalescence quantiles require calibrated node times")

drop_left_dimension = False
if windows is None:
drop_left_dimension = True
Expand Down Expand Up @@ -9595,6 +9578,9 @@ def pair_coalescence_rates(
"Must specify indexes if there are more than two sample sets"
)

if self.time_units == tskit.TIME_UNITS_UNCALIBRATED:
raise ValueError("Pair coalescence rates require calibrated node times")

drop_left_dimension = False
if windows is None:
drop_left_dimension = True
Expand Down

0 comments on commit 600a18c

Please sign in to comment.