From 561266ace66e0f7337482216c8bec6f007347360 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 11 Nov 2021 19:51:32 +0000 Subject: [PATCH] Do not use inspect to generate provenance. Closes #1899 --- CHANGELOG.md | 6 +++- msprime/ancestry.py | 75 +++++++++++++++++++++++++--------------- msprime/mutations.py | 26 ++++++++------ tests/test_provenance.py | 59 +++++++++---------------------- 4 files changed, 85 insertions(+), 81 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cb4891075..cc06547a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Changelog -## [1.0.3] - 2021-XX-XX +## [1.0.3] - 2021-11-12 **New features**: @@ -13,6 +13,10 @@ where ARG nodes were not correctly returned. ({issue}`1893`, {user}`jeromekelleher`, {user}`hyl317`) +- Fix memory leak when running ``sim_ancestry`` in a loop + ({pr}`1904`, {issue}`1899`, {user}`jeromekelleher`, {user}`grahamgower`). + + ## [1.0.2] - 2021-06-29 Improved Demes support and minor bugfixes. diff --git a/msprime/ancestry.py b/msprime/ancestry.py index c1b2aa5b0..cf602a125 100644 --- a/msprime/ancestry.py +++ b/msprime/ancestry.py @@ -25,7 +25,6 @@ import copy import dataclasses import enum -import inspect import json import logging import math @@ -411,28 +410,6 @@ def _parse_replicate_index(*, replicate_index, random_seed, num_replicates): return replicate_index -def _build_provenance(command, random_seed, frame): - """ - Builds a provenance dictionary suitable for use as the basis - of tree sequence provenance in replicate simulations. Uses the - specified stack frame to determine the values of the arguments - passed in, with a few exceptions. - """ - argspec = inspect.getargvalues(frame) - # num_replicates is excluded as provenance is per replicate - # replicate index is excluded as it is inserted for each replicate - parameters = { - "command": command, - **{ - arg: argspec.locals[arg] - for arg in argspec.args - if arg not in ["num_replicates", "replicate_index"] - }, - } - parameters["random_seed"] = random_seed - return provenance.get_provenance_dict(parameters) - - def simulate( sample_size=None, *, @@ -579,8 +556,31 @@ def simulate( random_seed = _parse_random_seed(random_seed) provenance_dict = None if record_provenance: - frame = inspect.currentframe() - provenance_dict = _build_provenance("simulate", random_seed, frame) + parameters = dict( + command="simulate", + sample_size=sample_size, + Ne=Ne, + length=length, + recombination_rate=recombination_rate, + recombination_map=recombination_map, + mutation_rate=mutation_rate, + population_configurations=population_configurations, + pedigree=pedigree, + migration_matrix=migration_matrix, + demographic_events=demographic_events, + samples=samples, + model=model, + record_migrations=record_migrations, + from_ts=from_ts, + start_time=start_time, + end_time=end_time, + record_full_arg=record_full_arg, + num_labels=num_labels, + random_seed=random_seed, + # num_replicates is excluded as provenance is per replicate + # replicate index is excluded as it is inserted for each replicate + ) + provenance_dict = provenance.get_provenance_dict(parameters) if mutation_generator is not None: # This error was added in version 0.6.1. @@ -1138,8 +1138,29 @@ def sim_ancestry( random_seed = _parse_random_seed(random_seed) provenance_dict = None if record_provenance: - frame = inspect.currentframe() - provenance_dict = _build_provenance("sim_ancestry", random_seed, frame) + parameters = dict( + command="sim_ancestry", + samples=samples, + demography=demography, + sequence_length=sequence_length, + discrete_genome=discrete_genome, + recombination_rate=recombination_rate, + gene_conversion_rate=gene_conversion_rate, + gene_conversion_tract_length=gene_conversion_tract_length, + population_size=population_size, + ploidy=ploidy, + model=model, + initial_state=initial_state, + start_time=start_time, + end_time=end_time, + record_migrations=record_migrations, + record_full_arg=record_full_arg, + num_labels=num_labels, + random_seed=random_seed, + # num_replicates is excluded as provenance is per replicate + # replicate index is excluded as it is inserted for each replicate + ) + provenance_dict = provenance.get_provenance_dict(parameters) sim = _parse_sim_ancestry( samples=samples, sequence_length=sequence_length, diff --git a/msprime/mutations.py b/msprime/mutations.py index 7a29aafb9..6851f8970 100644 --- a/msprime/mutations.py +++ b/msprime/mutations.py @@ -1331,6 +1331,21 @@ def sim_mutations( else: seed = int(seed) + parameters = dict( + command="sim_mutations", + tree_sequence=tree_sequence, + rate=rate, + model=model, + start_time=start_time, + end_time=end_time, + discrete_genome=discrete_genome, + keep=keep, + random_seed=seed, + ) + encoded_provenance = provenance.json_encode_provenance( + provenance.get_provenance_dict(parameters) + ) + if rate is None: rate = 0 try: @@ -1349,17 +1364,6 @@ def sim_mutations( keep = core._parse_flag(keep, default=True) model = mutation_model_factory(model) - - argspec = inspect.getargvalues(inspect.currentframe()) - parameters = { - "command": "sim_mutations", - **{arg: argspec.locals[arg] for arg in argspec.args}, - } - parameters["random_seed"] = seed - encoded_provenance = provenance.json_encode_provenance( - provenance.get_provenance_dict(parameters) - ) - rng = _msprime.RandomGenerator(seed) lwt = _msprime.LightweightTableCollection() lwt.fromdict(tables.asdict()) diff --git a/tests/test_provenance.py b/tests/test_provenance.py index b033fa82a..b6589d238 100644 --- a/tests/test_provenance.py +++ b/tests/test_provenance.py @@ -19,7 +19,6 @@ """ Tests for the provenance information attached to tree sequences. """ -import inspect import json import logging @@ -30,7 +29,6 @@ import msprime from msprime import _msprime -from msprime import ancestry from msprime import pedigrees @@ -49,41 +47,6 @@ def test_libraries(self): assert libs["tskit"] == {"version": tskit.__version__} -class TestBuildProvenance: - """ - Tests for the provenance dictionary building. This dictionary is used - to encode the parameters for the msprime simulations. - """ - - def test_basic(self): - def somefunc(a, b): - frame = inspect.currentframe() - return ancestry._build_provenance("cmd", 1234, frame) - - d = somefunc(42, 43) - tskit.validate_provenance(d) - params = d["parameters"] - assert params["command"] == "cmd" - assert params["random_seed"] == 1234 - assert params["a"] == 42 - assert params["b"] == 43 - - def test_replicates(self): - def somefunc(*, a, b, num_replicates, replicate_index): - frame = inspect.currentframe() - return ancestry._build_provenance("the_cmd", 42, frame) - - d = somefunc(b="b", a="a", num_replicates=100, replicate_index=1234) - tskit.validate_provenance(d) - params = d["parameters"] - assert params["command"] == "the_cmd" - assert params["random_seed"] == 42 - assert params["a"] == "a" - assert params["b"] == "b" - assert not ("num_replicates" in d) - assert not ("replicate_index" in d) - - class ValidateSchemas: """ Check that the schemas we produce in msprime are valid. @@ -216,7 +179,6 @@ def test_sim_mutations(self): assert decoded.parameters.start_time == 0 assert decoded.parameters.end_time == 100 assert not decoded.parameters.keep - assert decoded.parameters.model["__class__"] == "msprime.mutations.JC69" def test_mutate_model(self): ts = msprime.simulate(5, random_seed=1) @@ -224,7 +186,7 @@ def test_mutate_model(self): decoded = self.decode(ts.provenance(1).record) assert decoded.schema_version == "1.0.0" assert decoded.parameters.command == "sim_mutations" - assert decoded.parameters.model["__class__"] == "msprime.mutations.PAM" + assert decoded.parameters.model == "pam" def test_mutate_map(self): ts = msprime.simulate(5, random_seed=1) @@ -252,9 +214,11 @@ def test_mutate_numpy(self): assert decoded.schema_version == "1.0.0" assert decoded.parameters.command == "sim_mutations" assert decoded.parameters.random_seed == 1 - assert decoded.parameters.rate == 2 - assert decoded.parameters.start_time == 0 - assert decoded.parameters.end_time == 100 + # The dtype values change depending on platform, so not much + # point in trying to test exactly. + assert decoded.parameters.rate["__npgeneric__"] == "2" + assert decoded.parameters.start_time["__npgeneric__"] == "0" + assert decoded.parameters.end_time["__ndarray__"] == 100 class TestParseProvenance: @@ -366,6 +330,17 @@ def test_mutate_rate_map(self): ts = msprime.mutate(ts, rate=rate_map) self.verify(ts) + def test_mutate_numpy(self): + ts = msprime.sim_ancestry(5, random_seed=1) + ts = msprime.sim_mutations( + ts, + rate=np.array([2])[0], + random_seed=np.array([1])[0], + start_time=np.array([0])[0], + end_time=np.array([100][0]), + ) + self.verify(ts) + class TestRetainsProvenance: def test_simulate_retains_provenance(self):