Skip to content

Commit

Permalink
Python prototype of new branch algorithm for two-locus stats
Browse files Browse the repository at this point in the history
During the validation of the original algorithm, we realized that the LD
matrix could get "poisoned" with NaN values if we attempted to make an
adjustment to a node that did not contain any samples, a situation that
occurs with some frequency.

We tore things apart and simplified the algorithm so that we no longer
have to do adjustments as we're adding and removing edges. This new
version first removes the LD contribution of every modified node, then
adds the contribution of those same nodes back at the end of the routine,
once we know the final state of the samples under each node.
  • Loading branch information
lkirk authored and mergify[bot] committed Nov 5, 2024
1 parent 9acedd2 commit 73ef4cc
Showing 1 changed file with 66 additions and 105 deletions.
171 changes: 66 additions & 105 deletions python/tests/test_ld_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,28 @@ def add(self: "BitSet", row: int, bit: int) -> None:

def get_items(self: "BitSet", row: int) -> Generator[int, None, None]:
"""Get the items stored in the row of a bitset
Uses a de Bruijn sequence lookup table to determine the lowest bit set.
See the wikipedia article for more info: https://w.wiki/BYiF
:param row: Row from the array to list from.
:returns: A generator of integers stored in the array.
"""
lookup = [0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8, 31, 27,
13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9] # fmt: skip
m = np.uint32(125613361)
offset = row * self.row_len
for i in range(self.row_len):
for item in range(self.CHUNK_SIZE):
if self.data[i + offset] & (self.DTYPE(1) << item):
yield item + (i * self.CHUNK_SIZE)
v = self.data[i + offset]
if v == 0:
continue
else:
# v & -v operations rely on integer overflow
with np.errstate(over="ignore"):
lsb = v & -v # isolate the least significant bit
while lsb: # while there are bits remaining
yield lookup[(lsb * m) >> 27] + (i * self.CHUNK_SIZE)
v ^= lsb # unset the lsb
lsb = v & -v

def contains(self: "BitSet", row: int, bit: int) -> bool:
"""Test if a bit is contained within a bit array row
Expand Down Expand Up @@ -1561,7 +1574,6 @@ def advance(self, index):

def compute_branch_stat_update(
c,
child_samples,
A_state,
B_state,
state_dim,
Expand All @@ -1572,19 +1584,11 @@ def compute_branch_stat_update(
params,
):
"""Compute an update to the two-locus statistic for a single subset of the
tree being modified, relative to all subsets of the fixed tree. We perform
this operation for all samples edge being modified. For subsequent parent
nodes, we update the statistic by removing the existing contribution after
adding in the update contribution.
i.e. if we're adding two samples ({3, 4}) to a node, if the parent node
contains {1, 2}, we first add the statistic for {1, 2, 3, 4}, then
subtract the stat for {1, 2}.
tree being modified, relative to all subsets of the fixed tree.
:param c: Child node of the edge we're modifying
:param child_samples: Samples under the edge being added/removed
:param A_state: State for the tree contributing to the A samples (fixed)
:param A_state: State for the tree contributing to the B samples (modified)
:param B_state: State for the tree contributing to the B samples (modified)
:param state_dim: Number of sample sets.
:param sign: The sign of the update
:param stat_func: Function used to compute the two-locus statistic
Expand All @@ -1597,7 +1601,6 @@ def compute_branch_stat_update(
return result

AB_samples = BitSet(num_samples, 1)
node_samples_tmp = BitSet(num_samples, 1)
weights = np.zeros((3, state_dim), dtype=np.int64)
result_tmp = np.zeros(state_dim, np.float64)

Expand All @@ -1621,32 +1624,16 @@ def compute_branch_stat_update(
for k in range(state_dim):
result[k] += result_tmp[k] * a_len * b_len

# If we've begun our walk up the parents of the current edge removal, we
# must adjust the statistic for samples that were already present before
# addition or that remain after removal.
if child_samples is not None:
for k in range(state_dim):
row = (state_dim * n) + k
c_row = (state_dim * c) + k
node_samples_tmp.union(0, B_state.node_samples, c_row)
node_samples_tmp.difference(0, child_samples, k)
AB_samples.data[:] = 0 # Zero out the bitset so that we can reuse it
A_state.node_samples.intersect(row, node_samples_tmp, 0, AB_samples)

w_AB = AB_samples.count(0)
w_A = A_state.node_samples.count(row)
w_B = node_samples_tmp.count(0)

weights[0, k] = w_AB
weights[1, k] = w_A - w_AB # w_Ab
weights[2, k] = w_B - w_AB # w_aB

stat_func(state_dim, weights, result_tmp, params)
for k in range(state_dim):
result[k] -= result_tmp[k] * a_len * b_len


def compute_branch_stat(ts, stat_func, stat, params, state_dim, l_state, r_state):
def compute_branch_stat(
ts: tskit.TreeSequence,
stat_func,
stat,
params,
state_dim,
l_state: TreeState,
r_state: TreeState,
):
"""Step between trees in a tree sequence, updating our two-locus statistic
as we add or remove edges. Since we're computing statistics for two loci, we
have a focal tree that remains constant, and a tree that is updated to
Expand All @@ -1673,89 +1660,63 @@ def compute_branch_stat(ts, stat_func, stat, params, state_dim, l_state, r_state
:returns: A tuple containing the statistic between the two trees after
branch updates and the righthand tree state.
"""
num_samples = ts.num_samples
time = ts.tables.nodes.time
updates = BitSet(ts.num_nodes, 1)

child_samples = BitSet(ts.num_samples, state_dim)
for e in r_state.edges_out:
# Identify modified nodes
for e in r_state.edges_out + r_state.edges_in:
p = ts.edges_parent[e]
c = ts.edges_child[e]
child_samples.data[:] = 0
for k in range(state_dim):
c_row = (state_dim * c) + k
child_samples.union(k, r_state.node_samples, c_row)

# Remove the LD contributed by the samples under removed edges. When
# we walk up the tree to propagate these changes to parents of the
# removed edge, we need to add back in the LD contributed by samples
# that aren't removed. We remove samples from the parents of the removed
# branch as we propagate changes upward
in_parent = None
# identify affected nodes above child
while p != tskit.NULL:
compute_branch_stat_update(
c,
in_parent,
l_state,
r_state,
state_dim,
-1,
stat_func,
ts.num_samples,
stat,
params,
)
if in_parent is not None:
# remove samples from the parents of the branch being removed
# we remove the child node after the first iteration
for k in range(state_dim):
c_row = (state_dim * c) + k
r_state.node_samples.difference(c_row, child_samples, k)
in_parent = child_samples
updates.add(0, c)
c = p
p = r_state.parent[p]
for k in range(state_dim):
c_row = (state_dim * c) + k
r_state.node_samples.difference(c_row, child_samples, k)

# reset to the child of the edge being removed.
c = ts.edges_child[e]
r_state.branch_len[c] = 0
r_state.parent[c] = tskit.NULL
# Subtract the whole contribution from child node
for c in updates.get_items(0):
compute_branch_stat_update(
c, l_state, r_state, state_dim, -1, stat_func, num_samples, stat, params
)

# Sample Removal
for e in r_state.edges_out:
p = ts.edges_parent[e]
ec = ts.edges_child[e]
# update samples under nodes, propagate upwards
while p != tskit.NULL:
for k in range(state_dim):
r_state.node_samples.difference(
state_dim * p + k, r_state.node_samples, state_dim * ec + k
)
p = r_state.parent[p]
# set the parent to prevent upwards iteration
r_state.branch_len[ec] = 0
r_state.parent[ec] = tskit.NULL

# Sample Addition
for e in r_state.edges_in:
p = ts.edges_parent[e]
c = ts.edges_child[e]
child_samples.data[:] = 0
for k in range(state_dim):
c_row = (state_dim * c) + k
child_samples.union(k, r_state.node_samples, c_row)
ec = c = ts.edges_child[e]
r_state.branch_len[c] = time[p] - time[c]
r_state.parent[c] = p

# Add the LD contributed by the samples under added edges. When we walk
# up the tree to propagate these changes to parents of the removed edge,
# we need to remove the LD contributed by samples that were already
# there
in_parent = None
# update samples under nodes, store modified node, propagate upwards
while p != tskit.NULL:
updates.add(0, c)
for k in range(state_dim):
p_row = (state_dim * p) + k
r_state.node_samples.union(p_row, child_samples, k)
compute_branch_stat_update(
c,
in_parent,
l_state,
r_state,
state_dim,
+1,
stat_func,
ts.num_samples,
stat,
params,
)
in_parent = child_samples
r_state.node_samples.union(
state_dim * p + k, r_state.node_samples, state_dim * ec + k
)
c = p
p = r_state.parent[p]

# Add back the full contribution of every affected node: it was fully subtracted above, and re-adding was deferred until the sample sets were final
for c in updates.get_items(0):
compute_branch_stat_update(
c, l_state, r_state, state_dim, +1, stat_func, num_samples, stat, params
)

return stat, r_state


Expand Down

0 comments on commit 73ef4cc

Please sign in to comment.