Skip to content

Commit

Permalink
Merge pull request #1904 from jeromekelleher/remove-provenance-parame…
Browse files Browse the repository at this point in the history
…ter-fanciness

Do not use inspect to generate provenance.
  • Loading branch information
jeromekelleher authored Nov 12, 2021
2 parents 813f6c5 + 561266a commit 2dbf702
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 81 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Changelog

## [1.0.3] - 2021-XX-XX
## [1.0.3] - 2021-11-12

**New features**:

Expand All @@ -13,6 +13,10 @@
where ARG nodes were not correctly returned. ({issue}`1893`,
{user}`jeromekelleher`, {user}`hyl317`)

- Fix memory leak when running ``sim_ancestry`` in a loop
({pr}`1904`, {issue}`1899`, {user}`jeromekelleher`, {user}`grahamgower`).


## [1.0.2] - 2021-06-29

Improved Demes support and minor bugfixes.
Expand Down
75 changes: 48 additions & 27 deletions msprime/ancestry.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import copy
import dataclasses
import enum
import inspect
import json
import logging
import math
Expand Down Expand Up @@ -411,28 +410,6 @@ def _parse_replicate_index(*, replicate_index, random_seed, num_replicates):
return replicate_index


def _build_provenance(command, random_seed, frame):
"""
Builds a provenance dictionary suitable for use as the basis
of tree sequence provenance in replicate simulations. Uses the
specified stack frame to determine the values of the arguments
passed in, with a few exceptions.
"""
argspec = inspect.getargvalues(frame)
# num_replicates is excluded as provenance is per replicate
# replicate index is excluded as it is inserted for each replicate
parameters = {
"command": command,
**{
arg: argspec.locals[arg]
for arg in argspec.args
if arg not in ["num_replicates", "replicate_index"]
},
}
parameters["random_seed"] = random_seed
return provenance.get_provenance_dict(parameters)


def simulate(
sample_size=None,
*,
Expand Down Expand Up @@ -579,8 +556,31 @@ def simulate(
random_seed = _parse_random_seed(random_seed)
provenance_dict = None
if record_provenance:
frame = inspect.currentframe()
provenance_dict = _build_provenance("simulate", random_seed, frame)
parameters = dict(
command="simulate",
sample_size=sample_size,
Ne=Ne,
length=length,
recombination_rate=recombination_rate,
recombination_map=recombination_map,
mutation_rate=mutation_rate,
population_configurations=population_configurations,
pedigree=pedigree,
migration_matrix=migration_matrix,
demographic_events=demographic_events,
samples=samples,
model=model,
record_migrations=record_migrations,
from_ts=from_ts,
start_time=start_time,
end_time=end_time,
record_full_arg=record_full_arg,
num_labels=num_labels,
random_seed=random_seed,
# num_replicates is excluded as provenance is per replicate
# replicate index is excluded as it is inserted for each replicate
)
provenance_dict = provenance.get_provenance_dict(parameters)

if mutation_generator is not None:
# This error was added in version 0.6.1.
Expand Down Expand Up @@ -1138,8 +1138,29 @@ def sim_ancestry(
random_seed = _parse_random_seed(random_seed)
provenance_dict = None
if record_provenance:
frame = inspect.currentframe()
provenance_dict = _build_provenance("sim_ancestry", random_seed, frame)
parameters = dict(
command="sim_ancestry",
samples=samples,
demography=demography,
sequence_length=sequence_length,
discrete_genome=discrete_genome,
recombination_rate=recombination_rate,
gene_conversion_rate=gene_conversion_rate,
gene_conversion_tract_length=gene_conversion_tract_length,
population_size=population_size,
ploidy=ploidy,
model=model,
initial_state=initial_state,
start_time=start_time,
end_time=end_time,
record_migrations=record_migrations,
record_full_arg=record_full_arg,
num_labels=num_labels,
random_seed=random_seed,
# num_replicates is excluded as provenance is per replicate
# replicate index is excluded as it is inserted for each replicate
)
provenance_dict = provenance.get_provenance_dict(parameters)
sim = _parse_sim_ancestry(
samples=samples,
sequence_length=sequence_length,
Expand Down
26 changes: 15 additions & 11 deletions msprime/mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,6 +1331,21 @@ def sim_mutations(
else:
seed = int(seed)

parameters = dict(
command="sim_mutations",
tree_sequence=tree_sequence,
rate=rate,
model=model,
start_time=start_time,
end_time=end_time,
discrete_genome=discrete_genome,
keep=keep,
random_seed=seed,
)
encoded_provenance = provenance.json_encode_provenance(
provenance.get_provenance_dict(parameters)
)

if rate is None:
rate = 0
try:
Expand All @@ -1349,17 +1364,6 @@ def sim_mutations(
keep = core._parse_flag(keep, default=True)

model = mutation_model_factory(model)

argspec = inspect.getargvalues(inspect.currentframe())
parameters = {
"command": "sim_mutations",
**{arg: argspec.locals[arg] for arg in argspec.args},
}
parameters["random_seed"] = seed
encoded_provenance = provenance.json_encode_provenance(
provenance.get_provenance_dict(parameters)
)

rng = _msprime.RandomGenerator(seed)
lwt = _msprime.LightweightTableCollection()
lwt.fromdict(tables.asdict())
Expand Down
59 changes: 17 additions & 42 deletions tests/test_provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
"""
Tests for the provenance information attached to tree sequences.
"""
import inspect
import json
import logging

Expand All @@ -30,7 +29,6 @@

import msprime
from msprime import _msprime
from msprime import ancestry
from msprime import pedigrees


Expand All @@ -49,41 +47,6 @@ def test_libraries(self):
assert libs["tskit"] == {"version": tskit.__version__}


class TestBuildProvenance:
"""
Tests for the provenance dictionary building. This dictionary is used
to encode the parameters for the msprime simulations.
"""

def test_basic(self):
def somefunc(a, b):
frame = inspect.currentframe()
return ancestry._build_provenance("cmd", 1234, frame)

d = somefunc(42, 43)
tskit.validate_provenance(d)
params = d["parameters"]
assert params["command"] == "cmd"
assert params["random_seed"] == 1234
assert params["a"] == 42
assert params["b"] == 43

def test_replicates(self):
def somefunc(*, a, b, num_replicates, replicate_index):
frame = inspect.currentframe()
return ancestry._build_provenance("the_cmd", 42, frame)

d = somefunc(b="b", a="a", num_replicates=100, replicate_index=1234)
tskit.validate_provenance(d)
params = d["parameters"]
assert params["command"] == "the_cmd"
assert params["random_seed"] == 42
assert params["a"] == "a"
assert params["b"] == "b"
assert not ("num_replicates" in d)
assert not ("replicate_index" in d)


class ValidateSchemas:
"""
Check that the schemas we produce in msprime are valid.
Expand Down Expand Up @@ -216,15 +179,14 @@ def test_sim_mutations(self):
assert decoded.parameters.start_time == 0
assert decoded.parameters.end_time == 100
assert not decoded.parameters.keep
assert decoded.parameters.model["__class__"] == "msprime.mutations.JC69"

def test_mutate_model(self):
ts = msprime.simulate(5, random_seed=1)
ts = msprime.sim_mutations(ts, model="pam")
decoded = self.decode(ts.provenance(1).record)
assert decoded.schema_version == "1.0.0"
assert decoded.parameters.command == "sim_mutations"
assert decoded.parameters.model["__class__"] == "msprime.mutations.PAM"
assert decoded.parameters.model == "pam"

def test_mutate_map(self):
ts = msprime.simulate(5, random_seed=1)
Expand Down Expand Up @@ -252,9 +214,11 @@ def test_mutate_numpy(self):
assert decoded.schema_version == "1.0.0"
assert decoded.parameters.command == "sim_mutations"
assert decoded.parameters.random_seed == 1
assert decoded.parameters.rate == 2
assert decoded.parameters.start_time == 0
assert decoded.parameters.end_time == 100
# The dtype values change depending on platform, so not much
# point in trying to test exactly.
assert decoded.parameters.rate["__npgeneric__"] == "2"
assert decoded.parameters.start_time["__npgeneric__"] == "0"
assert decoded.parameters.end_time["__ndarray__"] == 100


class TestParseProvenance:
Expand Down Expand Up @@ -366,6 +330,17 @@ def test_mutate_rate_map(self):
ts = msprime.mutate(ts, rate=rate_map)
self.verify(ts)

def test_mutate_numpy(self):
ts = msprime.sim_ancestry(5, random_seed=1)
ts = msprime.sim_mutations(
ts,
rate=np.array([2])[0],
random_seed=np.array([1])[0],
start_time=np.array([0])[0],
end_time=np.array([100][0]),
)
self.verify(ts)


class TestRetainsProvenance:
def test_simulate_retains_provenance(self):
Expand Down

0 comments on commit 2dbf702

Please sign in to comment.