Skip to content

Commit

Permalink
Initial pass at computing LAA values
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Jan 9, 2025
1 parent 9b5e967 commit 8c28169
Showing 1 changed file with 75 additions and 0 deletions.
75 changes: 75 additions & 0 deletions bio2zarr/vcf2zarr/vcz.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ def convert_local_allele_field_types(fields):
gt = fields_by_name["call_genotype"]
if gt.shape[-1] != 2:
raise ValueError("Local alleles only supported on diploid data")
# TODO check if LAA is already in here

shape = gt.shape[:-1]
chunks = gt.chunks[:-1]

Expand All @@ -214,6 +216,7 @@ def convert_local_allele_field_types(fields):
)
pl = fields_by_name.get("call_PL", None)
if pl is not None:
# TODO check if call_LPL is in the list already
pl.name = "call_LPL"
pl.vcf_field = None
pl.shape = (*shape, 3)
Expand Down Expand Up @@ -511,6 +514,46 @@ def fromdict(d):
return ret


def compute_laa_field(genotypes, alleles) -> np.ndarray:
    """
    Computes the value of the LAA field for each sample given the genotypes
    for a variant.

    The LAA field is a list of one-based indices into the ALT alleles
    that indicates which alternate alleles are observed in the sample.

    :param genotypes: 2D integer array with one row per sample and one
        column per called allele. Negative values are treated like the
        reference allele and therefore never reported.
    :param alleles: The alleles for the variant (reference first); only
        its length is used.
    :return: 2D array with one row per sample listing the observed
        alternate allele indices in ascending order, padded on the right
        with ``constants.INT_FILL``. The number of columns is the largest
        number of distinct alternate alleles seen in any sample, and is
        at least 1. With zero samples an empty ``(0, 1)`` array is returned.
    :raises ValueError: If a genotype refers to an allele index that is
        not present in ``alleles``.
    """
    num_alleles = len(alleles)
    num_samples = genotypes.shape[0]
    if num_samples == 0:
        # np.apply_along_axis raises on zero-length iteration dimensions,
        # so handle the no-samples case explicitly.
        return np.empty((0, 1), dtype=np.intp)

    # Map negative values onto index 0 (the reference), which is
    # excluded from the output below.
    calls = genotypes.clip(0, None)
    if calls.size > 0 and calls.max() >= num_alleles:
        raise ValueError(
            f"Genotype refers to allele index {int(calls.max())} but only "
            f"{num_alleles} alleles are defined"
        )

    allele_counts = np.apply_along_axis(
        np.bincount, axis=1, arr=calls, minlength=num_alleles
    )
    allele_counts[:, 0] = 0  # We don't count the reference allele

    padded_width = max(1, num_alleles - 1)

    def pad_observed(counts: np.ndarray) -> np.ndarray:
        # Column j of the counts corresponds to allele j, so the nonzero
        # column indices are already one-based indices into the ALT alleles.
        observed = counts.nonzero()[0]
        return np.pad(
            observed,
            (0, padded_width - len(observed)),
            mode="constant",
            constant_values=constants.INT_FILL,
        )

    laa = np.apply_along_axis(pad_observed, axis=1, arr=allele_counts)

    # Trim trailing columns that are padding for every sample, but always
    # keep at least one column.
    max_row_length = max(1, int(np.count_nonzero(allele_counts, axis=1).max()))
    return laa[:, :max_row_length]


@dataclasses.dataclass
class VcfZarrWriteSummary(core.JsonDataclass):
num_partitions: int
Expand Down Expand Up @@ -543,6 +586,12 @@ def has_genotypes(self):
return True
return False

def has_local_alleles(self):
    """
    Return True if the schema contains a field named ``call_LAA`` whose
    ``vcf_field`` is None, and False otherwise.
    """
    # NOTE(review): vcf_field being None presumably marks a field that is
    # computed rather than read from the VCF - confirm against the schema.
    return any(
        spec.name == "call_LAA" and spec.vcf_field is None
        for spec in self.schema.fields
    )

#######################
# init
#######################
Expand Down Expand Up @@ -729,6 +778,8 @@ def encode_partition(self, partition_index):
self.encode_array_partition(array_spec, partition_index)
if self.has_genotypes():
self.encode_genotypes_partition(partition_index)
if self.has_local_alleles():
self.encode_local_alleles_partition(partition_index)

final_path = self.partition_path(partition_index)
logger.info(f"Finalising {partition_index} at {final_path}")
Expand Down Expand Up @@ -800,6 +851,30 @@ def encode_genotypes_partition(self, partition_index):
self.finalise_partition_array(partition_index, "call_genotype_mask")
self.finalise_partition_array(partition_index, "call_genotype_phased")

def encode_local_alleles_partition(self, partition_index):
    """
    Compute and write the ``call_LAA`` array for one partition.

    Opens the partition's work-in-progress ``call_genotype`` and
    ``variant_allele`` arrays read-only, walks them chunk by chunk along
    the first dimension, computes the LAA values for each variant with
    ``compute_laa_field`` and stores them through a buffered writer
    before finalising the ``call_LAA`` partition array.
    """
    field_name = "call_LAA"
    laa_target = self.init_partition_array(partition_index, field_name)
    start = self.metadata.partitions[partition_index].start
    laa_buffer = core.BufferedArray(laa_target, start)

    genotype_store = zarr.open_array(
        store=self.wip_partition_array_path(partition_index, "call_genotype"),
        mode="r",
    )
    allele_store = zarr.open_array(
        store=self.wip_partition_array_path(partition_index, "variant_allele"),
        mode="r",
    )

    num_chunks = genotype_store.cdata_shape[0]
    for chunk in range(num_chunks):
        allele_block = allele_store.blocks[chunk]
        genotype_block = genotype_store.blocks[chunk]
        for variant_alleles, variant_genotypes in zip(allele_block, genotype_block):
            row = laa_buffer.next_buffer_row()
            # TODO we should probably compute LAAs by chunk for efficiency
            # NOTE(review): compute_laa_field trims its output to the
            # widest row it saw; if that is narrower than the buffer row,
            # numpy assignment broadcasts instead of padding - confirm
            # against the declared shape of call_LAA.
            laa_buffer.buff[row] = compute_laa_field(
                variant_genotypes, variant_alleles
            )

    laa_buffer.flush()
    self.finalise_partition_array(partition_index, field_name)

def encode_alleles_partition(self, partition_index):
array_name = "variant_allele"
alleles_array = self.init_partition_array(partition_index, array_name)
Expand Down

0 comments on commit 8c28169

Please sign in to comment.