Skip to content

Commit

Permalink
Merge pull request #1909 from jeromekelleher/cache-pop-table-v2
Browse files Browse the repository at this point in the history
Improve performance in the lots-of-simulations case
  • Loading branch information
jeromekelleher authored Nov 12, 2021
2 parents 2dbf702 + d838347 commit 5c58c87
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 34 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
- Support for running full ARG simulations with gene conversion
({pr}`1801`, {issue}`1773`, {user}`JereKoskela`).

- Improved performance when running many small simulations
({pr}`1909`, {user}`jeromekelleher`).

**Bug fixes**:

- Fix bug in full ARG simulation with missing regions of the genome,
Expand Down
113 changes: 79 additions & 34 deletions msprime/demography.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import enum
import inspect
import itertools
import json
import logging
import math
import numbers
Expand Down Expand Up @@ -58,6 +59,78 @@ class IncompletePopulationMetadataWarning(UserWarning):
"""


class LruCache(collections.OrderedDict):
    """
    Size-bounded mapping that evicts the least recently used key when full.

    Based on the LRU example from the OrderedDict documentation. Both
    looking up a key with ``cache[key]`` and assigning to it count as a
    "use" and move the key to the most-recently-used end. Membership tests
    (``key in cache``) do not affect the ordering.

    :param int maxsize: The maximum number of items retained. Once an
        insertion takes the size above this, the least recently used item
        is discarded.
    :param args: Passed through to ``collections.OrderedDict``.
    :param kwds: Passed through to ``collections.OrderedDict``.
    """

    def __init__(self, maxsize=128, *args, **kwds):
        # Set maxsize before calling the base initialiser, since inserting
        # any initial items may go through __setitem__, which reads it.
        self.maxsize = maxsize
        super().__init__(*args, **kwds)

    def __getitem__(self, key):
        """
        Return the value for ``key`` and mark it as most recently used.
        Raises KeyError if the key is not present.
        """
        value = super().__getitem__(key)
        self.move_to_end(key)
        return value

    def __setitem__(self, key, value):
        """
        Store ``value`` under ``key``, mark it as most recently used, and
        evict the least recently used item if the cache is over capacity.
        """
        # Reassigning an existing key leaves its position unchanged in an
        # OrderedDict, so refresh it explicitly; otherwise a key that was
        # just updated could be the very next one evicted.
        if key in self:
            self.move_to_end(key)
        super().__setitem__(key, value)
        if len(self) > self.maxsize:
            # Iteration order is oldest-first, so the first key is the LRU.
            oldest = next(iter(self))
            del self[oldest]


# Module-level LRU cache of population tables built by
# _build_population_table, keyed by the JSON-serialised population metadata.
# Holds at most 16 tables.
_population_table_cache = LruCache(16)


def _build_population_table(populations):
    """
    Return a tskit PopulationTable holding one row of metadata for each of
    the specified populations.

    Encoding metadata is relatively expensive, so tables are kept in an LRU
    cache keyed on the JSON serialisation of the metadata; a table is only
    built when no cached table exists for the same metadata.

    Raises a ValueError if a population's ``extra_metadata`` tries to set
    one of the standard keys (``name`` or ``description``).
    """
    rows = []
    for population in populations:
        row = {
            "name": population.name,
            "description": population.description,
        }
        extra = population.extra_metadata
        if extra is not None:
            # The standard keys must be set via the Population properties.
            clashes = set(extra.keys()).intersection(row.keys())
            if clashes:
                printed_list = sorted(clashes)
                raise ValueError(
                    f"Cannot set standard metadata key(s) {printed_list} "
                    "using extra_metadata. Please set using the corresponding "
                    "property of the Population class."
                )
            row.update(extra)
        rows.append(row)

    # Metadata is the only thing stored in the Population table, so its
    # serialised form fully identifies the table.
    cache_key = json.dumps(rows, sort_keys=True)
    if cache_key in _population_table_cache:
        return _population_table_cache[cache_key]

    schema = tskit.MetadataSchema(
        {
            "codec": "json",
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "description": {"type": ["string", "null"]},
            },
            # msprime always fills in name and description; marking them
            # "required" tells downstream tools they can rely on this.
            "required": ["name", "description"],
            "additionalProperties": True,
        }
    )
    table = tskit.PopulationTable()
    table.metadata_schema = schema
    for row in rows:
        table.add_row(metadata=row)
    _population_table_cache[cache_key] = table
    return table


def check_num_populations(num_populations):
"""
Check if an input number of populations is valid.
Expand Down Expand Up @@ -1038,41 +1111,13 @@ def insert_populations(self, tables):
:meta private:
"""
metadata_schema = tskit.MetadataSchema(
{
"codec": "json",
"type": "object",
"properties": {
"name": {"type": "string"},
"description": {"type": ["string", "null"]},
},
# The name and description fields are always filled out by
# msprime, so we tell downstream tools this by making them
# "required" by the schema.
"required": ["name", "description"],
"additionalProperties": True,
}
)
assert len(tables.populations) == 0
tables.populations.metadata_schema = metadata_schema
for population in self.populations:
metadata = {
"name": population.name,
"description": population.description,
}
if population.extra_metadata is not None:
intersection = set(population.extra_metadata.keys()) & set(
metadata.keys()
)
if len(intersection) > 0:
printed_list = list(sorted(intersection))
raise ValueError(
f"Cannot set standard metadata key(s) {printed_list} "
"using extra_metadata. Please set using the corresponding "
"property of the Population class."
)
metadata.update(population.extra_metadata)
tables.populations.add_row(metadata=metadata)
population_table = _build_population_table(self.populations)
tables.populations.metadata_schema = population_table.metadata_schema
tables.populations.set_columns(
metadata=population_table.metadata,
metadata_offset=population_table.metadata_offset,
)

def insert_extra_populations(self, tables):
"""
Expand Down
12 changes: 12 additions & 0 deletions tests/test_demography.py
Original file line number Diff line number Diff line change
Expand Up @@ -6651,3 +6651,15 @@ def test_ooa_manual(self):
if len(pop_map) == 1:
assert np.all(epoch_local.migration_matrix == 0)
assert np.all(epoch_sps.migration_matrix == 0)


def test_lru_cache():
    # Only a basic check, since the implementation is taken directly from
    # the Python docs: the cache grows up to maxsize and then evicts the
    # oldest key.
    cache = demog_mod.LruCache(2)
    sizes_after_insert = [1, 2, 2]
    for key, expected_size in enumerate(sizes_after_insert):
        cache[key] = key
        assert len(cache) == expected_size
    assert 0 not in cache

0 comments on commit 5c58c87

Please sign in to comment.