diff --git a/bio2zarr/vcf2zarr/vcz.py b/bio2zarr/vcf2zarr/vcz.py index d46bfc9..a427678 100644 --- a/bio2zarr/vcf2zarr/vcz.py +++ b/bio2zarr/vcf2zarr/vcz.py @@ -197,6 +197,8 @@ def convert_local_allele_field_types(fields): gt = fields_by_name["call_genotype"] if gt.shape[-1] != 2: raise ValueError("Local alleles only supported on diploid data") + # TODO check if LAA is already in here + shape = gt.shape[:-1] chunks = gt.chunks[:-1] @@ -214,6 +216,7 @@ def convert_local_allele_field_types(fields): ) pl = fields_by_name.get("call_PL", None) if pl is not None: + # TODO check if call_LPL is in the list already pl.name = "call_LPL" pl.vcf_field = None pl.shape = (*shape, 3) @@ -511,6 +514,46 @@ def fromdict(d): return ret +def compute_laa_field(genotypes, alleles) -> np.ndarray: + """ + Computes the value of the LAA field for each sample given the genotypes + for a variant. + + The LAA field is a list of one-based indices into the ALT alleles + that indicates which alternate alleles are observed in the sample. + """ + alt_allele_count = len(alleles) - 1 + allele_counts = np.zeros((genotypes.shape[0], len(alleles)), dtype=int) + + genotypes = genotypes.clip(0, None) + genotype_allele_counts = np.apply_along_axis( + np.bincount, axis=1, arr=genotypes, minlength=len(alleles) + ) + allele_counts += genotype_allele_counts + + allele_counts[:, 0] = 0 # We don't count the reference allele + max_row_length = 1 + + def nonzero_pad(arr: np.ndarray, *, length: int): + nonlocal max_row_length + alleles = arr.nonzero()[0] + max_row_length = max(max_row_length, len(alleles)) + pad_length = length - len(alleles) + return np.pad( + alleles, + (0, pad_length), + mode="constant", + constant_values=constants.INT_FILL, + ) + + alleles = np.apply_along_axis( + nonzero_pad, axis=1, arr=allele_counts, length=max(1, alt_allele_count) + ) + alleles = alleles[:, :max_row_length] + + return alleles + + @dataclasses.dataclass class VcfZarrWriteSummary(core.JsonDataclass): num_partitions: int @@ -543,6 +586,12 @@ def has_genotypes(self): return True return False + def has_local_alleles(self): + for field in self.schema.fields: + if field.name == "call_LAA" and field.vcf_field is None: + return True + return False + ####################### # init ####################### @@ -729,6 +778,8 @@ def encode_partition(self, partition_index): self.encode_array_partition(array_spec, partition_index) if self.has_genotypes(): self.encode_genotypes_partition(partition_index) + if self.has_local_alleles(): + self.encode_local_alleles_partition(partition_index) final_path = self.partition_path(partition_index) logger.info(f"Finalising {partition_index} at {final_path}") @@ -800,6 +851,30 @@ def encode_genotypes_partition(self, partition_index): self.finalise_partition_array(partition_index, "call_genotype_mask") self.finalise_partition_array(partition_index, "call_genotype_phased") + def encode_local_alleles_partition(self, partition_index): + call_LAA_array = self.init_partition_array(partition_index, "call_LAA") + partition = self.metadata.partitions[partition_index] + call_LAA = core.BufferedArray(call_LAA_array, partition.start) + + gt_array = zarr.open_array( + store=self.wip_partition_array_path(partition_index, "call_genotype"), + mode="r", + ) + alleles_array = zarr.open_array( + store=self.wip_partition_array_path(partition_index, "variant_allele"), + mode="r", + ) + for chunk_index in range(gt_array.cdata_shape[0]): + A = alleles_array.blocks[chunk_index] + G = gt_array.blocks[chunk_index] + for alleles, var in zip(A, G): + j = call_LAA.next_buffer_row() + # TODO we should probably compute LAAs by chunk for efficiency + call_LAA.buff[j] = compute_laa_field(var, alleles) + + call_LAA.flush() + self.finalise_partition_array(partition_index, "call_LAA") + def encode_alleles_partition(self, partition_index): array_name = "variant_allele" alleles_array = self.init_partition_array(partition_index, array_name)