diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d38f3d4..109dc3cd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [0.4.0a3] - ****-**-** + +**Fixes** + +- Properly account for "N" as an unknown ancestral state, and ban "" from being + set as an ancestral state ({pr}`963`, {user}`hyanwong`)) + ## [0.4.0a2] - 2024-09-06 2nd Alpha release of tsinfer 0.4.0 diff --git a/docs/usage.md b/docs/usage.md index 2c3d3544..ed2c044d 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -60,14 +60,14 @@ for sample in range(ds['call_genotype'].shape[1]): We wish to infer a genealogy that could have given rise to this data set. To run _tsinfer_ we wrap the .vcz file in a `tsinfer.VariantData` object. This requires an -*ancestral allele* to be specified for each site; there are +*ancestral state* to be specified for each site; there are many methods for calculating these: details are outside the scope of this manual, but we have started a [discussion topic](https://github.com/tskit-dev/tsinfer/discussions/523) on this issue to provide some recommendations. Sometimes VCF files will contain the -ancestral allele in the "AA" info field, in which case it will be encoded in the -`variant_AA` field of the .vcz file. It's also possible to provide a numpy array +ancestral state in the "AA" ("ancestral allele") info field, in which case it will be encoded +in the `variant_AA` field of the .vcz file. It's also possible to provide a numpy array of ancestral alleles, of the same length as the number of variants. Ancestral alleles that are not in the list of alleles for their respective site are treated as unknown and not used for inference (with a warning given). @@ -76,11 +76,11 @@ and not used for inference (with a warning given). import tsinfer # For this example take the REF allele (index 0) as ancestral -ancestral_allele = ds['variant_allele'][:,0].astype(str) +ancestral_state = ds['variant_allele'][:,0].astype(str) # This is just a numpy array, set the last site to an unknown value, for demo purposes -ancestral_allele[-1] = "." +ancestral_state[-1] = "." -vdata = tsinfer.VariantData("_static/example_data.vcz", ancestral_allele) +vdata = tsinfer.VariantData("_static/example_data.vcz", ancestral_state) ``` The `VariantData` object is a lightweight wrapper around the .vcz file. @@ -127,7 +127,7 @@ site_mask[ds.variant_position[:] >= 6] = True smaller_vdata = tsinfer.VariantData( "_static/example_data.vcz", - ancestral_allele=ancestral_allele[site_mask == False], + ancestral_state=ancestral_state[site_mask == False], site_mask=site_mask, ) print(f"The `smaller_vdata` object returns data for only {smaller_vdata.num_sites} sites") @@ -351,8 +351,8 @@ Once we have our `.vcz` file created, running the inference is straightforward. ```{code-cell} ipython3 # Infer & save a ts from the notebook simulation. -ancestral_alleles = np.load(f"{name}-AA.npy") -vdata = tsinfer.VariantData(f"{name}.vcz", ancestral_alleles) +ancestral_states = np.load(f"{name}-AA.npy") +vdata = tsinfer.VariantData(f"{name}.vcz", ancestral_states) tsinfer.infer(vdata, progress_monitor=True, num_threads=4).dump(name + ".trees") ``` @@ -477,12 +477,12 @@ vcf_location = "_static/P_dom_chr24_phased.vcf.gz" ``` This creates the `sparrows.vcz` datastore, which we open using `tsinfer.VariantData`. -The original VCF had ancestral alleles specified in the `AA` INFO field, so we can -simply provide the string `"variant_AA"` as the ancestral_allele parameter. +The original VCF had the ancestral allelic state specified in the `AA` INFO field, +so we can simply provide the string `"variant_AA"` as the ancestral_state parameter. ```{code-cell} ipython3 -# Do the inference: this VCF has ancestral alleles in the AA field -vdata = tsinfer.VariantData("sparrows.vcz", ancestral_allele="variant_AA") +# Do the inference: this VCF has ancestral states in the AA field +vdata = tsinfer.VariantData("sparrows.vcz", ancestral_state="variant_AA") ts = tsinfer.infer(vdata) print( "Inferred tree sequence: {} trees over {} Mb ({} edges)".format( @@ -534,7 +534,7 @@ Now when we carry out the inference, we get a tree sequence in which the nodes a correctly assigned to named populations ```{code-cell} ipython3 -vdata = tsinfer.VariantData("sparrows.vcz", ancestral_allele="variant_AA") +vdata = tsinfer.VariantData("sparrows.vcz", ancestral_state="variant_AA") sparrow_ts = tsinfer.infer(vdata) for sample_node_id in sparrow_ts.samples(): diff --git a/tests/test_inference.py b/tests/test_inference.py index 07d2a0c0..74727428 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1532,7 +1532,7 @@ def test_match_samples_batch(self, tmp_path, tmpdir): mat_wd = tsinfer.match_samples_batch_init( work_dir=tmpdir / "working_mat", sample_data_path=mat_sd.path, - ancestral_allele="variant_ancestral_allele", + ancestral_state="variant_ancestral_allele", ancestor_ts_path=tmpdir / "mat_anc.trees", min_work_per_job=1, max_num_partitions=10, @@ -1547,7 +1547,7 @@ def test_match_samples_batch(self, tmp_path, tmpdir): mask_wd = tsinfer.match_samples_batch_init( work_dir=tmpdir / "working_mask", sample_data_path=mask_sd.path, - ancestral_allele="variant_ancestral_allele", + ancestral_state="variant_ancestral_allele", ancestor_ts_path=tmpdir / "mask_anc.trees", min_work_per_job=1, max_num_partitions=10, diff --git a/tests/test_variantdata.py b/tests/test_variantdata.py index a353c4c0..550ae1a5 100644 --- a/tests/test_variantdata.py +++ b/tests/test_variantdata.py @@ -20,8 +20,10 @@ Tests for the data files. """ import json +import logging import sys import tempfile +import warnings import msprime import numcodecs @@ -627,14 +629,12 @@ def test_missing_ancestral_allele(tmp_path): @pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on Windows") -def test_ancestral_missingness(tmp_path): +def test_deliberate_ancestral_missingness(tmp_path): ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) ds = sgkit.load_dataset(zarr_path) ancestral_allele = ds.variant_ancestral_allele.values ancestral_allele[0] = "N" - ancestral_allele[11] = "-" - ancestral_allele[12] = "💩" - ancestral_allele[15] = "💩" + ancestral_allele[1] = "n" ds = ds.drop_vars(["variant_ancestral_allele"]) sgkit.save_dataset(ds, str(zarr_path) + ".tmp") tsutil.add_array_to_dataset( @@ -644,15 +644,57 @@ def test_ancestral_missingness(tmp_path): ["variants"], ) ds = sgkit.load_dataset(str(zarr_path) + ".tmp") + with warnings.catch_warnings(): + warnings.simplefilter("error") # No warning raised if AA deliberately missing + sd = tsinfer.VariantData(str(zarr_path) + ".tmp", "variant_ancestral_allele") + inf_ts = tsinfer.infer(sd) + for i, (inf_var, var) in enumerate(zip(inf_ts.variants(), ts.variants())): + if i in [0, 1]: + assert inf_var.site.metadata == {"inference_type": "parsimony"} + else: + assert inf_var.site.ancestral_state == var.site.ancestral_state + + +@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on Windows") +def test_ancestral_missing_warning(tmp_path): + ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) + ds = sgkit.load_dataset(zarr_path) + anc_state = ds.variant_ancestral_allele.values + anc_state[0] = "N" + anc_state[11] = "-" + anc_state[12] = "💩" + anc_state[15] = "💩" with pytest.warns( UserWarning, match=r"not found in the variant_allele array for the 4 [\s\S]*'💩': 2", ): - sd = tsinfer.VariantData(str(zarr_path) + ".tmp", "variant_ancestral_allele") - inf_ts = tsinfer.infer(sd) + vdata = tsinfer.VariantData(zarr_path, anc_state) + inf_ts = tsinfer.infer(vdata) + for i, (inf_var, var) in enumerate(zip(inf_ts.variants(), ts.variants())): + if i in [0, 11, 12, 15]: + assert inf_var.site.metadata == {"inference_type": "parsimony"} + assert inf_var.site.ancestral_state in var.site.alleles + else: + assert inf_var.site.ancestral_state == var.site.ancestral_state + + +@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on Windows") +def test_ancestral_missing_info(tmp_path, caplog): + ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) + ds = sgkit.load_dataset(zarr_path) + anc_state = ds.variant_ancestral_allele.values + anc_state[0] = "N" + anc_state[11] = "N" + anc_state[12] = "n" + anc_state[15] = "n" + with caplog.at_level(logging.INFO): + vdata = tsinfer.VariantData(zarr_path, anc_state) + assert f"4 sites ({4/ts.num_sites * 100 :.2f}%) were deliberately " in caplog.text + inf_ts = tsinfer.infer(vdata) for i, (inf_var, var) in enumerate(zip(inf_ts.variants(), ts.variants())): if i in [0, 11, 12, 15]: assert inf_var.site.metadata == {"inference_type": "parsimony"} + assert inf_var.site.ancestral_state in var.site.alleles else: assert inf_var.site.ancestral_state == var.site.ancestral_state @@ -670,6 +712,25 @@ def test_sgkit_ancestor(small_sd_fixture, tmp_path): class TestVariantDataErrors: + @staticmethod + def simulate_genotype_call_dataset(*args, **kwargs): + # roll our own simulate_genotype_call_dataset to hack around bug in sgkit where + # duplicate alleles are created. Doesn't need to be efficient: just for testing + if "seed" not in kwargs: + kwargs["seed"] = 123 + ds = sgkit.simulate_genotype_call_dataset(*args, **kwargs) + variant_alleles = ds["variant_allele"].values + allowed_alleles = np.array( + ["A", "T", "C", "G", "N"], dtype=variant_alleles.dtype + ) + for row in range(len(variant_alleles)): + alleles = variant_alleles[row] + if len(set(alleles)) != len(alleles): + # Just use a set that we know is unique + variant_alleles[row] = allowed_alleles[0 : len(alleles)] + ds["variant_allele"] = ds["variant_allele"].dims, variant_alleles + return ds + def test_bad_zarr_spec(self): ds = zarr.group() ds["call_genotype"] = zarr.array(np.zeros(10, dtype=np.int8)) @@ -680,7 +741,7 @@ def test_bad_zarr_spec(self): def test_missing_phase(self, tmp_path): path = tmp_path / "data.zarr" - ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3) + ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3) sgkit.save_dataset(ds, path) with pytest.raises( ValueError, match="The call_genotype_phased array is missing" @@ -689,7 +750,7 @@ def test_missing_phase(self, tmp_path): def test_phased(self, tmp_path): path = tmp_path / "data.zarr" - ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3) + ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3) ds["call_genotype_phased"] = ( ds["call_genotype"].dims, np.ones(ds["call_genotype"].shape, dtype=bool), @@ -700,13 +761,13 @@ def test_phased(self, tmp_path): def test_ploidy1_missing_phase(self, tmp_path): path = tmp_path / "data.zarr" # Ploidy==1 is always ok - ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1) + ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1) sgkit.save_dataset(ds, path) tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str)) def test_ploidy1_unphased(self, tmp_path): path = tmp_path / "data.zarr" - ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1) + ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1) ds["call_genotype_phased"] = ( ds["call_genotype"].dims, np.zeros(ds["call_genotype"].shape, dtype=bool), @@ -716,7 +777,7 @@ def test_ploidy1_unphased(self, tmp_path): def test_duplicate_positions(self, tmp_path): path = tmp_path / "data.zarr" - ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True) + ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True) ds["variant_position"][2] = ds["variant_position"][1] sgkit.save_dataset(ds, path) with pytest.raises(ValueError, match="duplicate or out-of-order values"): @@ -724,23 +785,46 @@ def test_duplicate_positions(self, tmp_path): def test_bad_order_positions(self, tmp_path): path = tmp_path / "data.zarr" - ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True) + ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True) ds["variant_position"][0] = ds["variant_position"][2] - 0.5 sgkit.save_dataset(ds, path) with pytest.raises(ValueError, match="duplicate or out-of-order values"): tsinfer.VariantData(path, "variant_ancestral_allele") + def test_bad_ancestral_state(self, tmp_path): + path = tmp_path / "data.zarr" + ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True) + ancestral_state = ds["variant_allele"][:, 0].values.astype(str) + ancestral_state[1] = "" + sgkit.save_dataset(ds, path) + with pytest.raises(ValueError, match="cannot contain empty strings"): + tsinfer.VariantData(path, ancestral_state) + def test_empty_alleles_not_at_end(self, tmp_path): path = tmp_path / "data.zarr" - ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1) + ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1) ds["variant_allele"] = ( ds["variant_allele"].dims, - np.array([["", "A", "C"], ["A", "C", ""], ["A", "C", ""]], dtype="S1"), + np.array([["A", "", "C"], ["A", "C", ""], ["A", "C", ""]], dtype="S1"), ) sgkit.save_dataset(ds, path) - vdata = tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str)) - with pytest.raises(ValueError, match="Empty alleles must be at the end"): - tsinfer.infer(vdata) + with pytest.raises( + ValueError, match='Bad alleles: fill value "" in middle of list' + ): + tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str)) + + def test_unique_alleles(self, tmp_path): + path = tmp_path / "data.zarr" + ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1) + ds["variant_allele"] = ( + ds["variant_allele"].dims, + np.array([["A", "C", "T"], ["A", "C", ""], ["A", "A", ""]], dtype="S1"), + ) + sgkit.save_dataset(ds, path) + with pytest.raises( + ValueError, match="Duplicate allele values provided at site 2" + ): + tsinfer.VariantData(path, np.array(["A", "A", "A"], dtype="S1")) def test_unimplemented_from_tree_sequence(self): # NB we should reimplement something like this functionality. diff --git a/tsinfer/formats.py b/tsinfer/formats.py index 39f59e13..99d48064 100644 --- a/tsinfer/formats.py +++ b/tsinfer/formats.py @@ -31,6 +31,7 @@ import sys import threading import warnings +from typing import Union # noqa: F401 import attr import humanize @@ -2091,12 +2092,11 @@ def sites(self, ids=None): ids = np.arange(0, self.num_sites, dtype=int) for j in ids: anc_idx = ancestral_allele_array[j] - alleles = tuple(alleles_array[j]) site = Site( id=j, position=position_array[j], ancestral_allele=anc_idx, - alleles=alleles, + alleles=tuple(alleles_array[j]), metadata=metadata_array[j], time=time_array[j], ) @@ -2293,6 +2293,48 @@ def populations(self): class VariantData(SampleData): + """ + Class representing input variant data used for inference. This is + mostly a thin wrapper for a Zarr dataset storing information in + the VCF Zarr (.vcz) format, plus information specifing the ancestral allele + and (optional) data masks. It then provides various derived properties and + methods for accessing the data in a form suitable for inference. + + .. note:: + In the VariantData object, "samples" refer to the individuals in the dataset, + each of which can be of arbitrary ploidy. This is in contrast to ``tskit``, + in which each *haploid genome* is treated as a separate "sample". For example + in a diploid dataset, the inferred tree sequence returned at the end of + the inference process will have ``inferred_ts.num_samples`` equal to double + the number returned by ``VariantData.num_samples``. + + :param str path: The path to the file containing the input dataset in VCF-Zarr + format. + :param Union(array, str) ancestral_state: A numpy array of strings specifying + the ancestral states (alleles) used in inference. This must be the same length + as the number of unmasked sites in the dataset. Alternatively, a single string + can be provided, giving the name of an array in the input dataset which contains + the ancestral states. Unknown ancestral states can be specified using "N". + Any ancestral states which do not match any of the known alleles at that site, + will be tallied, and a warning issued summarizing the unknown ancestral states. + :param Union(array, str) sample_mask: A numpy array of booleans specifying which + samples to mask out (exclude) from the dataset. Alternatively, a string + can be provided, giving the name of an array in the input dataset which contains + the sample mask. If ``None`` (default), all samples are included. + :param Union(array, str) site_mask: A numpy array of booleans specifying which + sites to mask out (exclude) from the dataset. Alternatively, a string + can be provided, giving the name of an array in the input dataset which contains + the site mask. If ``None`` (default), all sites are included. + :param Union(array, str) sites_time: A numpy array of floats specifying the relative + time of occurrence of the mutation to the derived state at each site. This must + be of the same length as the number of unmasked sites. Alternatively, a + string can be provided, giving the name of an array in the input dataset + which contains the site times. If ``None`` (default), the frequency of the + derived allele is used as a proxy for the time of occurrence: this is usually a + reasonable approximation to the relative order of ancestors used for inference. + Time values are ignored for sites not used in inference, such as singletons, + sites with more than two alleles, or sites with an unknown ancestral state. + """ FORMAT_NAME = "tsinfer-variant-data" FORMAT_VERSION = (0, 1) @@ -2300,7 +2342,7 @@ class VariantData(SampleData): def __init__( self, path_or_zarr, - ancestral_allele, + ancestral_state, *, sample_mask=None, site_mask=None, @@ -2393,33 +2435,32 @@ def __init__( f"The sites time {sites_time} was not found" f" in the dataset." ) - if isinstance(ancestral_allele, np.ndarray): - if ancestral_allele.shape[0] != self.num_sites: + if isinstance(ancestral_state, np.ndarray): + if ancestral_state.shape[0] != self.num_sites: raise ValueError( - "Ancestral allele array must be the same length as the number of" + "Ancestral state array must be the same length as the number of" " selected sites" ) - self._sites_ancestral_allele = ancestral_allele else: try: - self._sites_ancestral_allele = self.data[ancestral_allele][:][ - self.sites_select - ] + ancestral_state = self.data[ancestral_state][:][self.sites_select] except KeyError: raise ValueError( - f"The ancestral allele {ancestral_allele} was not" + f"The ancestral state array {ancestral_state} was not" f" found in the dataset." ) - self._sites_ancestral_allele = self._sites_ancestral_allele.astype(str) + ancestral_state = ancestral_state.astype(str) + if np.any(ancestral_state == ""): + raise ValueError("Ancestral state array cannot contain empty strings") + + self._sites_ancestral_allele = np.full(self.num_sites, -1, dtype=np.int8) unknown_alleles = collections.Counter() - converted = np.zeros(self.num_sites, dtype=np.int8) - for i, allele in enumerate(self._sites_ancestral_allele): - allele_index = -1 - try: - allele_index = np.where(allele == self.sites_alleles[i])[0][0] - except IndexError: - unknown_alleles[allele] += 1 - converted[i] = allele_index + for i, (anc_state, site) in enumerate(zip(ancestral_state, self.sites())): + if anc_state in {"N", "n"} or anc_state not in site.alleles: + unknown_alleles[anc_state] += 1 + else: + self._sites_ancestral_allele[i] = site.alleles.index(anc_state) + deliberately_unknown = sum([unknown_alleles.get(c, 0) for c in ("N", "n")]) tot = sum(unknown_alleles.values()) if tot > 0: frac_bad = tot / self.num_sites @@ -2428,13 +2469,18 @@ def __init__( f"'{k}': {v} ({frac * 100:.2f}% of sites)" # Summarise per allele type for (k, v), frac in zip(unknown_alleles.items(), frac_bad_per_type) ] - warnings.warn( - "An ancestral allele was not found in the variant_allele array for " - + f"the {tot} sites ({frac_bad * 100 :.2f}%) listed below. " - + "They will be treated as of unknown ancestral state:\n " - + "\n ".join(summarise_unknown) - ) - self._sites_ancestral_allele = converted + if tot == deliberately_unknown: + logging.info( + f"{tot} sites ({frac_bad * 100 :.2f}%) were deliberately marked as " + "of unknown ancestral state. They will not be used for inference" + ) + else: + warnings.warn( + "An ancestral allele was not found in the variant_allele array for " + + f"the {tot} sites ({frac_bad * 100 :.2f}%) listed below. " + + "They will be treated as of unknown ancestral state:\n " + + "\n ".join(summarise_unknown) + ) # Create zarr arrays for convenience when iterating over chunks self.z_sites_select = zarr.array( @@ -2664,6 +2710,45 @@ def individuals_flags(self): except KeyError: return np.full((self.num_individuals), 0, dtype=np.int32) + @staticmethod + def _trim_allele_array(allele_array, site_id): + # Trim a list of allelic states to remove any trailing "" entries. + assert allele_array.shape[0] > 0 + used = np.flatnonzero(allele_array != "") + allele_array = allele_array[: (used[-1] + 1)] + if np.any(allele_array == ""): + raise ValueError( + f'Bad alleles: fill value "" in middle of list: {allele_array}' + ) + if len(set(allele_array)) != len(allele_array): + raise ValueError(f"Duplicate allele values provided at site {site_id}") + return allele_array + + def sites(self, ids=None): + """ + Returns an iterator over the Site objects. A subset of the + sites can be returned using the ``ids`` parameter. This must + be a list of integer site IDs. + """ + position_array = self.sites_position[:] + alleles_array = self.sites_alleles[:] + metadata_array = self.sites_metadata[:] + time_array = self.sites_time[:] + ancestral_allele_array = self.sites_ancestral_allele[:] + if ids is None: + ids = np.arange(0, self.num_sites, dtype=int) + for j in ids: + anc_idx = ancestral_allele_array[j] + site = Site( + id=j, + position=position_array[j], + ancestral_allele=anc_idx, + alleles=tuple(self._trim_allele_array(alleles_array[j], j)), + metadata=metadata_array[j], + time=time_array[j], + ) + yield site + def variants(self, sites=None, recode_ancestral=None): """ Returns an iterator over the :class:`Variant` objects. This is equivalent to @@ -2708,19 +2793,8 @@ def variants(self, sites=None, recode_ancestral=None): geno_map[aa] = 0 geno_map[0:aa] += 1 genos = geno_map[genos] - # Filter out empty alleles, as sgkit pads with them so that all sites have - # the same number of alleles. This is only safe if the empty - # alleles are at the end of the list, so check this. - non_empty_alleles = [] - empty_seen = False - for allele in alleles: - if allele != b"" and allele != "": - if empty_seen: - raise ValueError("Empty alleles must be at the end") - non_empty_alleles.append(allele) - else: - empty_seen = True - alleles = non_empty_alleles + # Empty alleles (padded by sgkit) should never be seen + assert all(a != "" for a in alleles) yield Variant(site=site, alleles=alleles, genotypes=genos) def _all_haplotypes(self, sites=None, recode_ancestral=None, samples_slice=None): diff --git a/tsinfer/inference.py b/tsinfer/inference.py index ab5f477f..f22811b0 100644 --- a/tsinfer/inference.py +++ b/tsinfer/inference.py @@ -586,7 +586,7 @@ def match_ancestors( def match_ancestors_batch_init( working_dir, sample_data_path, - ancestral_allele, + ancestral_state, ancestor_data_path, min_work_per_job, *, @@ -614,7 +614,7 @@ def match_ancestors_batch_init( ancestors = formats.AncestorData.load(ancestor_data_path) sample_data = formats.VariantData( sample_data_path, - ancestral_allele=ancestral_allele, + ancestral_state=ancestral_state, sample_mask=sample_mask, site_mask=site_mask, ) @@ -666,7 +666,7 @@ def match_ancestors_batch_init( metadata = { "sample_data_path": str(sample_data_path), - "ancestral_allele": ancestral_allele, + "ancestral_state": ancestral_state, "ancestor_data_path": str(ancestor_data_path), "sample_mask": sample_mask, "site_mask": site_mask, @@ -690,7 +690,7 @@ def match_ancestors_batch_init( def initialize_ancestor_matcher(metadata, ancestors_ts=None, **kwargs): sample_data = formats.VariantData( metadata["sample_data_path"], - ancestral_allele=metadata["ancestral_allele"], + ancestral_state=metadata["ancestral_state"], sample_mask=metadata["sample_mask"], site_mask=metadata["site_mask"], ) @@ -910,7 +910,7 @@ def augment_ancestors( @dataclasses.dataclass class SampleBatchWorkDescriptor: sample_data_path: str - ancestral_allele: str + ancestral_state: str sample_mask: np.ndarray site_mask: np.ndarray ancestor_ts_path: str @@ -972,7 +972,7 @@ def numpy_decoder(dct): def load_variant_data_and_ancestors_ts(wd: SampleBatchWorkDescriptor): variant_data = formats.VariantData( wd.sample_data_path, - wd.ancestral_allele, + wd.ancestral_state, sample_mask=wd.sample_mask, site_mask=wd.site_mask, ) @@ -989,7 +989,7 @@ def load_variant_data_and_ancestors_ts(wd: SampleBatchWorkDescriptor): def match_samples_batch_init( work_dir, sample_data_path, - ancestral_allele, + ancestral_state, ancestor_ts_path, min_work_per_job, *, @@ -1022,7 +1022,7 @@ def match_samples_batch_init( wd = SampleBatchWorkDescriptor( sample_data_path=str(sample_data_path), - ancestral_allele=ancestral_allele, + ancestral_state=ancestral_state, sample_mask=sample_mask, site_mask=site_mask, ancestor_ts_path=str(ancestor_ts_path),