Skip to content

Commit

Permalink
Temp
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Aug 23, 2023
1 parent dcba038 commit 66d8081
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 20 deletions.
56 changes: 53 additions & 3 deletions python/tests/test_divmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,10 +427,14 @@ def check_divmat(
np.testing.assert_allclose(D1, D2)
assert D1.shape == D2.shape
if compare_lib:
ids = None
if sample_sets is not None:
ids = sample_sets
if samples is not None:
ids = samples
D3 = ts.divergence_matrix(
sample_sets=sample_sets,
ids,
windows=windows,
samples=samples,
mode=mode,
span_normalise=span_normalise,
)
Expand Down Expand Up @@ -915,8 +919,8 @@ def check(
mode="branch",
):
D1 = ts.divergence_matrix(
sample_sets,
windows=windows,
sample_sets=sample_sets,
num_threads=num_threads,
mode=mode,
span_normalise=span_normalise,
Expand Down Expand Up @@ -1251,3 +1255,49 @@ def test_simple_simulation(self):
for j in range(var.num_alleles):
a = A[offsets[j] : offsets[j + 1]]
assert list(a) == list(allele_samples[j])


from tskit import util
import collections

def parse_ids(ids):
"""
Returns a flattened list of sets of IDs. If ids is a 1D list,
interpret as n one-element sets. Otherwise, it must be a sequence
of ID lists.
"""
ids = util.safe_np_int_cast(ids, np.int32)
print(ids)
if len(ids) == 0:
return [], []

# if len(ids) == 0:
# return [], []
flat = util.safe_np_int_cast(ids, np.int32)
sizes = np.ones(len(flat), dtype=np.uint32)
return flat, sizes



class TestIdArgument:
@pytest.mark.parametrize(["arg", "flattened", "sizes"], [
([], [], []),
([1], [1], [1]),
([1, 2], [1, 2], [1, 1]),
(np.array([1, 2]), [1, 2], [1, 1]),
(np.array([1, 2], dtype=np.uint32), [1, 2], [1, 1]),
([[1], [2]], [1, 2], [1, 1]),
([[1, 1], [2]], [1, 1, 2], [2, 1]),
])
def test_good_args(self, arg, flattened, sizes):
f, s = parse_ids(arg)
print(f, s)
np.testing.assert_array_equal(f, flattened)
np.testing.assert_array_equal(s, sizes)

@pytest.mark.parametrize("arg", [
"", {},
])
def test_bad_args(self, arg):
with pytest.raises(TypeError):
parse_ids(arg)
41 changes: 24 additions & 17 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -7868,10 +7868,9 @@ def worker(sub_windows):

def divergence_matrix(
self,
ids,
*,
windows=None,
samples=None,
sample_sets=None,
num_threads=0,
mode=None,
span_normalise=True,
Expand All @@ -7880,23 +7879,31 @@ def divergence_matrix(
windows = self.parse_windows(windows)
mode = "site" if mode is None else mode

sample_set_sizes = None
flattened_samples = None
if samples is not None:
assert sample_sets is None
flattened_samples = samples
sample_set_sizes = np.ones(len(samples), dtype=np.uint32)
elif sample_sets is not None:
assert samples is None
sample_set_sizes = np.array(
[len(sample_set) for sample_set in sample_sets], dtype=np.uint32
)
if np.sum(sample_set_sizes) == 0:
flattened_samples = []
if ids is None:
ids = self.samples()
flattened_samples = self.samples()
sample_set_sizes = np.ones(len(ids), dtype=np.uint32)
else:
x = np.hstack(ids)
print(x)
ids = util.safe_np_int_cast(ids, np.int32)

ids = np.array(ids)
# print(ids)
if len(ids.shape) == 1:
flattened_samples = util.safe_np_int_cast(ids, np.int32)
sample_set_sizes = np.ones(len(ids), dtype=np.uint32)
else:
flattened_samples = util.safe_np_int_cast(
np.hstack(sample_sets), np.int32

sample_set_sizes = np.array(
[len(sample_set) for sample_set in ids], dtype=np.uint32
)
if np.sum(sample_set_sizes) == 0:
flattened_samples = []
else:
flattened_samples = util.safe_np_int_cast(
np.hstack(ids), np.int32
)

# FIXME this logic should be merged into __run_windowed_stat if
# we generalise the num_threads argument to all stats.
Expand Down

0 comments on commit 66d8081

Please sign in to comment.