Skip to content

Commit

Permalink
Add in quick version of GRM
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Aug 23, 2023
1 parent 11404c7 commit 88f26de
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 18 deletions.
74 changes: 57 additions & 17 deletions python/tests/test_divmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"""
import array
import collections
import functools

import msprime
import numpy as np
Expand Down Expand Up @@ -258,8 +259,23 @@ def divergence_matrix(
)


def stats_api_divergence_matrix(
ts, windows=None, samples=None, sample_sets=None, mode="site", span_normalise=True
def stats_api_divergence_matrix(ts, *args, **kwargs):
return stats_api_matrix_method(ts, ts.divergence, *args, **kwargs)


def stats_api_genetic_relatedness_matrix(ts, *args, **kwargs):
method = functools.partial(ts.genetic_relatedness, proportion=False)
return stats_api_matrix_method(ts, method, *args, **kwargs)


def stats_api_matrix_method(
ts,
method,
windows=None,
samples=None,
sample_sets=None,
mode="site",
span_normalise=True,
):
if samples is not None and sample_sets is not None:
raise ValueError("Cannot specify both")
Expand All @@ -282,19 +298,6 @@ def stats_api_divergence_matrix(
else:
return np.zeros(shape=(0, 0))

# # Make sure that all the specified samples have the sample flag set, otherwise
# # the library code will complain
# tables = ts.dump_tables()
# flags = tables.nodes.flags
# # NOTE: this is a shortcut, setting all flags unconditionally to zero, so don't
# # use this tree sequence outside this method.
# flags[:] = 0
# for sample_set in sample_sets:
# for u in sample_set:
# flags[u] = tskit.NODE_IS_SAMPLE
# tables.nodes.flags = flags
# ts = tables.tree_sequence()

# FIXME We have to go through this annoying rigmarole because windows must start and
# end with 0 and L. We should relax this requirement to just making the windows
# contiguous, so that we just look at specific sections of the genome.
Expand All @@ -308,7 +311,7 @@ def stats_api_divergence_matrix(

n = len(sample_sets)
indexes = [(i, j) for i in range(n) for j in range(n)]
X = ts.divergence(
X = method(
sample_sets,
indexes=indexes,
mode=mode,
Expand Down Expand Up @@ -1282,7 +1285,7 @@ def test_good_args(self, arg, flattened, sizes):
assert isinstance(f, np.ndarray)
assert f.dtype == np.int32
assert isinstance(s, np.ndarray)
assert s.dtype == np.uint32
assert s.dtype == np.uint64
np.testing.assert_array_equal(f, flattened)
np.testing.assert_array_equal(s, sizes)

Expand Down Expand Up @@ -1341,3 +1344,40 @@ def test_dict_args(self, arg):
def test_bad_arg_types(self, arg):
with pytest.raises(TypeError):
tskit.TreeSequence._parse_stat_matrix_ids_arg(arg)


class TestGeneticRelatednessMatrix:
def check(self, ts, mode, windows=None):
G1 = stats_api_genetic_relatedness_matrix(ts, mode=mode, windows=windows)
# Seem to be out by a factor of -2, quick hack just to check
G2 = ts.genetic_relatedness_matrix(mode=mode, windows=windows)
np.testing.assert_array_almost_equal(G1, G2)

@pytest.mark.parametrize("mode", DIVMAT_MODES)
def test_single_tree(self, mode):
# 2.00┊ 6 ┊
# ┊ ┏━┻━┓ ┊
# 1.00┊ 4 5 ┊
# ┊ ┏┻┓ ┏┻┓ ┊
# 0.00┊ 0 1 2 3 ┊
# 0 1
ts = tskit.Tree.generate_balanced(4).tree_sequence
ts = tsutil.insert_branch_sites(ts)
self.check(ts, mode)

@pytest.mark.parametrize("mode", DIVMAT_MODES)
def test_single_tree_windows(self, mode):
# 2.00┊ 6 ┊
# ┊ ┏━┻━┓ ┊
# 1.00┊ 4 5 ┊
# ┊ ┏┻┓ ┏┻┓ ┊
# 0.00┊ 0 1 2 3 ┊
# 0 1
ts = tskit.Tree.generate_balanced(4).tree_sequence
ts = tsutil.insert_branch_sites(ts)
self.check(ts, mode, windows=[0, 0.5, 1])

@pytest.mark.parametrize("ts", get_example_tree_sequences())
@pytest.mark.parametrize("mode", DIVMAT_MODES)
def test_suite_defaults(self, ts, mode):
self.check(ts, mode=mode)
41 changes: 40 additions & 1 deletion python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -7900,7 +7900,7 @@ def _parse_stat_matrix_ids_arg(ids):

def divergence_matrix(
self,
ids,
ids=None,
*,
windows=None,
num_threads=0,
Expand Down Expand Up @@ -8064,6 +8064,45 @@ def genetic_relatedness(

return out

def genetic_relatedness_matrix(
self,
ids=None,
*,
windows=None,
num_threads=0,
mode=None,
span_normalise=True,
):
D = self.divergence_matrix(
ids,
windows=windows,
num_threads=num_threads,
mode=mode,
span_normalise=span_normalise,
)

def _normalise(B):
if len(B) == 0:
return B
K = np.zeros_like(B)
N = K.shape[0]
B_mean = np.mean(B)
Bi_mean = np.mean(B, axis=0)
# TODO numpify - should be easy enough by creating full matrices
# for the row and column means.
for i in range(N):
for j in range(N):
K[i, j] = B[i, j] - Bi_mean[i] - Bi_mean[j] + B_mean
# FIXME I don't know what this factor -2 is about
return K / -2

if windows is None:
return _normalise(D)
else:
for j in range(D.shape[0]):
D[j] = _normalise(D[j])
return D

def genetic_relatedness_weighted(
self,
W,
Expand Down

0 comments on commit 88f26de

Please sign in to comment.