Skip to content

Commit

Permalink
Allow VariantData to take either a path to a vcz or the zarr store it…
Browse files Browse the repository at this point in the history
…self
  • Loading branch information
hyanwong committed Sep 5, 2024
1 parent a080e8f commit 21fbe57
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 86 deletions.
159 changes: 77 additions & 82 deletions tests/test_variantdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,146 +133,141 @@ def test_sgkit_individual_metadata_not_clobbered(tmp_path):


@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
def test_sgkit_dataset_accessors(tmp_path):
ts, zarr_path = tsutil.make_ts_and_zarr(
tmp_path, add_optional=True, shuffle_alleles=False
)
samples = tsinfer.VariantData(
zarr_path, "variant_ancestral_allele", sites_time="sites_time"
)
ds = sgkit.load_dataset(zarr_path)

assert samples.format_name == "tsinfer-sgkit-sample-data"
assert samples.format_version == (0, 1)
assert samples.finalised
assert samples.sequence_length == ts.sequence_length + 1337
assert samples.num_sites == ts.num_sites
assert samples.sites_metadata_schema == ts.tables.sites.metadata_schema.schema
assert samples.sites_metadata == [site.metadata for site in ts.sites()]
assert np.array_equal(samples.sites_time, np.arange(ts.num_sites) / ts.num_sites)
assert np.array_equal(samples.sites_position, ts.tables.sites.position)
for alleles, v in zip(samples.sites_alleles, ts.variants()):
@pytest.mark.parametrize("in_mem", [True, False])
def test_variantdata_accessors(tmp_path, in_mem):
path = None if in_mem else tmp_path
ts, data = tsutil.make_ts_and_zarr(path, add_optional=True, shuffle_alleles=False)
vd = tsinfer.VariantData(data, "variant_ancestral_allele", sites_time="sites_time")
ds = data if in_mem else sgkit.load_dataset(data)

assert vd.format_name == "tsinfer-sgkit-sample-data"
assert vd.format_version == (0, 1)
assert vd.finalised
assert vd.sequence_length == ts.sequence_length + 1337
assert vd.num_sites == ts.num_sites
assert vd.sites_metadata_schema == ts.tables.sites.metadata_schema.schema
assert vd.sites_metadata == [site.metadata for site in ts.sites()]
assert np.array_equal(vd.sites_time, np.arange(ts.num_sites) / ts.num_sites)
assert np.array_equal(vd.sites_position, ts.tables.sites.position)
for alleles, v in zip(vd.sites_alleles, ts.variants()):
# sgkit alleles are padded to be rectangular
assert np.all(alleles[: len(v.alleles)] == v.alleles)
assert np.all(alleles[len(v.alleles) :] == "")
assert np.array_equal(samples.sites_select, np.ones(ts.num_sites, dtype=bool))
assert np.array_equal(vd.sites_select, np.ones(ts.num_sites, dtype=bool))
assert np.array_equal(
samples.sites_ancestral_allele, np.zeros(ts.num_sites, dtype=np.int8)
vd.sites_ancestral_allele, np.zeros(ts.num_sites, dtype=np.int8)
)
assert np.array_equal(samples.sites_genotypes, ts.genotype_matrix())
assert np.array_equal(vd.sites_genotypes, ts.genotype_matrix())
assert np.array_equal(
samples.provenances_timestamp, ["2021-01-01T00:00:00", "2021-01-02T00:00:00"]
vd.provenances_timestamp, ["2021-01-01T00:00:00", "2021-01-02T00:00:00"]
)
assert samples.provenances_record == [{"foo": 1}, {"foo": 2}]
assert samples.num_samples == ts.num_samples
assert vd.provenances_record == [{"foo": 1}, {"foo": 2}]
assert vd.num_samples == ts.num_samples
assert np.array_equal(
samples.samples_individual, np.repeat(np.arange(ts.num_samples // 3), 3)
vd.samples_individual, np.repeat(np.arange(ts.num_samples // 3), 3)
)
assert samples.metadata_schema == tsutil.example_schema("example").schema
assert samples.metadata == ts.tables.metadata
assert vd.metadata_schema == tsutil.example_schema("example").schema
assert vd.metadata == ts.tables.metadata
assert (
samples.populations_metadata_schema
== ts.tables.populations.metadata_schema.schema
vd.populations_metadata_schema == ts.tables.populations.metadata_schema.schema
)
assert samples.populations_metadata == [pop.metadata for pop in ts.populations()]
assert samples.num_individuals == ts.num_individuals
assert vd.populations_metadata == [pop.metadata for pop in ts.populations()]
assert vd.num_individuals == ts.num_individuals
assert np.array_equal(
samples.individuals_time, np.arange(ts.num_individuals, dtype=np.float32)
vd.individuals_time, np.arange(ts.num_individuals, dtype=np.float32)
)
assert (
samples.individuals_metadata_schema
== ts.tables.individuals.metadata_schema.schema
vd.individuals_metadata_schema == ts.tables.individuals.metadata_schema.schema
)
assert samples.individuals_metadata == [
assert vd.individuals_metadata == [
{"variant_data_sample_id": sample_id, **ind.metadata}
for ind, sample_id in zip(ts.individuals(), ds["sample_id"].values)
for ind, sample_id in zip(ts.individuals(), ds.sample_id[:])
]
assert np.array_equal(
samples.individuals_location,
vd.individuals_location,
np.tile(np.array([["0", "1"]], dtype="float32"), (ts.num_individuals, 1)),
)
assert np.array_equal(
samples.individuals_population, np.zeros(ts.num_individuals, dtype="int32")
vd.individuals_population, np.zeros(ts.num_individuals, dtype="int32")
)
assert np.array_equal(
samples.individuals_flags,
vd.individuals_flags,
np.random.RandomState(42).randint(
0, 2_000_000, ts.num_individuals, dtype="int32"
),
)

# Need to shuffle for the ancestral allele test
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path, add_optional=True)
samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele")
ts, data = tsutil.make_ts_and_zarr(path, add_optional=True)
vd = tsinfer.VariantData(data, "variant_ancestral_allele")
for i in range(ts.num_sites):
assert (
samples.sites_alleles[i][samples.sites_ancestral_allele[i]]
vd.sites_alleles[i][vd.sites_ancestral_allele[i]]
== ts.site(i).ancestral_state
)


@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
def test_sgkit_accessors_defaults(tmp_path):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele")
ds = sgkit.load_dataset(zarr_path)
@pytest.mark.parametrize("in_mem", [True, False])
def test_variantdata_accessors_defaults(tmp_path, in_mem):
path = None if in_mem else tmp_path
ts, data = tsutil.make_ts_and_zarr(path)
vdata = tsinfer.VariantData(data, "variant_ancestral_allele")
ds = data if in_mem else sgkit.load_dataset(data)

default_schema = tskit.MetadataSchema.permissive_json().schema
assert samples.sequence_length == ts.sequence_length
assert samples.sites_metadata_schema == default_schema
assert samples.sites_metadata == [{} for _ in range(ts.num_sites)]
for time in samples.sites_time:
assert vdata.sequence_length == ts.sequence_length
assert vdata.sites_metadata_schema == default_schema
assert vdata.sites_metadata == [{} for _ in range(ts.num_sites)]
for time in vdata.sites_time:
assert tskit.is_unknown_time(time)
assert np.array_equal(samples.sites_select, np.ones(ts.num_sites, dtype=bool))
assert np.array_equal(samples.provenances_timestamp, [])
assert np.array_equal(samples.provenances_record, [])
assert samples.metadata_schema == default_schema
assert samples.metadata == {}
assert samples.populations_metadata_schema == default_schema
assert samples.populations_metadata == []
assert samples.individuals_metadata_schema == default_schema
assert samples.individuals_metadata == [
{"variant_data_sample_id": sample_id} for sample_id in ds["sample_id"].values
assert np.array_equal(vdata.sites_select, np.ones(ts.num_sites, dtype=bool))
assert np.array_equal(vdata.provenances_timestamp, [])
assert np.array_equal(vdata.provenances_record, [])
assert vdata.metadata_schema == default_schema
assert vdata.metadata == {}
assert vdata.populations_metadata_schema == default_schema
assert vdata.populations_metadata == []
assert vdata.individuals_metadata_schema == default_schema
assert vdata.individuals_metadata == [
{"variant_data_sample_id": sample_id} for sample_id in ds.sample_id[:]
]
for time in samples.individuals_time:
for time in vdata.individuals_time:
assert tskit.is_unknown_time(time)
assert np.array_equal(
samples.individuals_location, np.array([[]] * ts.num_individuals, dtype=float)
vdata.individuals_location, np.array([[]] * ts.num_individuals, dtype=float)
)
assert np.array_equal(
samples.individuals_population, np.full(ts.num_individuals, tskit.NULL)
vdata.individuals_population, np.full(ts.num_individuals, tskit.NULL)
)
assert np.array_equal(
samples.individuals_flags, np.zeros(ts.num_individuals, dtype=int)
vdata.individuals_flags, np.zeros(ts.num_individuals, dtype=int)
)


@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
def test_variantdata_sites_time_default(tmp_path):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele")
def test_variantdata_sites_time_default():
ts, data = tsutil.make_ts_and_zarr()
vdata = tsinfer.VariantData(data, "variant_ancestral_allele")

assert (
np.all(np.isnan(samples.sites_time))
and samples.sites_time.size == samples.num_sites
np.all(np.isnan(vdata.sites_time)) and vdata.sites_time.size == vdata.num_sites
)


@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows")
def test_variantdata_sites_time_array(tmp_path):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
def test_variantdata_sites_time_array():
ts, data = tsutil.make_ts_and_zarr()
sites_time = np.arange(ts.num_sites)
samples = tsinfer.VariantData(
zarr_path, "variant_ancestral_allele", sites_time=sites_time
)
assert np.array_equal(samples.sites_time, sites_time)
vdata = tsinfer.VariantData(data, "variant_ancestral_allele", sites_time=sites_time)
assert np.array_equal(vdata.sites_time, sites_time)
wrong_length_sites_time = np.arange(ts.num_sites + 1)
with pytest.raises(
ValueError,
match="Sites time array must be the same length as the number of selected sites",
):
tsinfer.VariantData(
zarr_path,
data,
"variant_ancestral_allele",
sites_time=wrong_length_sites_time,
)
Expand Down Expand Up @@ -302,17 +297,17 @@ def test_sgkit_variant_mask(self, tmp_path, sites):
for i in sites:
sites_mask[i] = False
tsutil.add_array_to_dataset("variant_mask_42", sites_mask, zarr_path)
samples = tsinfer.VariantData(
vdata = tsinfer.VariantData(
zarr_path,
"variant_ancestral_allele",
site_mask="variant_mask_42",
)
assert samples.num_sites == len(sites)
assert np.array_equal(samples.sites_select, ~sites_mask)
assert vdata.num_sites == len(sites)
assert np.array_equal(vdata.sites_select, ~sites_mask)
assert np.array_equal(
samples.sites_position, ts.tables.sites.position[~sites_mask]
vdata.sites_position, ts.tables.sites.position[~sites_mask]
)
inf_ts = tsinfer.infer(samples)
inf_ts = tsinfer.infer(vdata)
assert np.array_equal(
ts.genotype_matrix()[~sites_mask], inf_ts.genotype_matrix()
)
Expand Down
21 changes: 20 additions & 1 deletion tests/tsutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@
Extra utility functions used in several test files
"""
import json
import tempfile
from pathlib import Path

import msprime
import numpy as np
import sgkit
import tskit
import xarray as xr
import zarr

import tsinfer

Expand Down Expand Up @@ -219,7 +222,23 @@ def add_attribute_to_dataset(name, contents, zarr_path):
sgkit.save_dataset(ds, zarr_path, mode="a")


def make_ts_and_zarr(path, add_optional=False, shuffle_alleles=True):
def make_ts_and_zarr(path=None, add_optional=False, shuffle_alleles=True):
if path is None:
in_mem_copy = zarr.group()
with tempfile.TemporaryDirectory() as path:
ts, zarr_path = _make_ts_and_zarr(
Path(path), add_optional=add_optional, shuffle_alleles=shuffle_alleles
)
# For testing only, return an in-memory copy of the dataset we just made
zarr.convenience.copy_all(zarr.open(zarr_path), in_mem_copy)
return ts, in_mem_copy
else:
return _make_ts_and_zarr(
path, add_optional=add_optional, shuffle_alleles=shuffle_alleles
)


def _make_ts_and_zarr(path, add_optional=False, shuffle_alleles=True):
import sgkit.io.vcf

ts = msprime.sim_ancestry(
Expand Down
13 changes: 10 additions & 3 deletions tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2299,15 +2299,22 @@ class VariantData(SampleData):

def __init__(
self,
path,
path_or_zarr,
ancestral_allele,
*,
sample_mask=None,
site_mask=None,
sites_time=None,
):
self.path = path
self.data = zarr.open(path, mode="r")
try:
if len(path_or_zarr.call_genotype.shape) == 3:
# Assumed to be a VCF Zarr hierarchy
self.path = None
self.data = path_or_zarr
except AttributeError:
self.path = path_or_zarr
self.data = zarr.open(path_or_zarr, mode="r")

genotypes_arr = self.data["call_genotype"]
_, self._num_individuals_before_mask, self.ploidy = genotypes_arr.shape

Expand Down

0 comments on commit 21fbe57

Please sign in to comment.