Skip to content

Commit

Permalink
Replaced multiprocessing with multiprocess, pickle with dill, simplif…
Browse files Browse the repository at this point in the history
…ied the examples (#6)

* Replaced multiprocessing with multiprocess, pickle with dill, simplified examples

* Fixed linting mistakes

* Added multiprocess to the depedencies in pyproject.toml
  • Loading branch information
berndie authored May 31, 2024
1 parent e2ad1d4 commit 187aac0
Show file tree
Hide file tree
Showing 12 changed files with 321 additions and 294 deletions.
2 changes: 1 addition & 1 deletion brain_pipe/preprocessing/brain/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from brain_pipe.pipeline.base import PipelineStep
from brain_pipe.utils.list import flatten
from brain_pipe.utils.multiprocess import MultiprocessingSingleton
from brain_pipe.utils.parallellization import MultiprocessingSingleton
from brain_pipe.utils.path import BIDSStimulusGrouper


Expand Down
2 changes: 1 addition & 1 deletion brain_pipe/runner/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from brain_pipe.dataloaders.base import DataLoader
from brain_pipe.pipeline.base import Pipeline
from brain_pipe.utils.log import default_logging
from brain_pipe.utils.multiprocess import MultiprocessingSingleton
from brain_pipe.utils.parallellization import MultiprocessingSingleton


class DefaultRunner(Runner):
Expand Down
66 changes: 41 additions & 25 deletions brain_pipe/save/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from brain_pipe.save.base import Save
from brain_pipe.utils.list import wrap_in_list
from brain_pipe.utils.multiprocess import MultiprocessingSingleton
from brain_pipe.utils.parallellization import MultiprocessingSingleton

# Shorthand interfaces.
CheckInterface = Callable[[Dict[str, Any], str, Dict[str, Any]], Union[str, bool]]
Expand Down Expand Up @@ -408,6 +408,21 @@ def __init__(
super().__init__(key_fn=key_fn)
self.filename = filename
self.saver = None
# Make sure the lock is at least created before multiprocessing is used
self._get_lock()

def attach_saver(self, saver):
"""Attach a saver to the metadata.
Parameters
----------
saver: DefaultSave
The saver to attach.
"""
new_metadata = super(DefaultSaveMetadata, self).attach_saver(saver)
# Make sure the lock is at least created before multiprocessing is used
new_metadata._get_lock()
return new_metadata

def get_path(self):
"""Get the path to the metadata file.
Expand Down Expand Up @@ -443,6 +458,9 @@ def get_relpath(self, path: str):
return os.path.relpath(path, self.saver.root_dir)
return path

def _get_lock(self):
return MultiprocessingSingleton.get_lock(self.get_path())

@property
def lock(self):
"""Retrieve the lock to use for the metadata file.
Expand All @@ -452,15 +470,14 @@ def lock(self):
multiprocessing.Lock
The lock to use for the metadata file.
"""
return MultiprocessingSingleton.get_lock(self.get_path())
return self._get_lock()

def clear(self):
"""Clear the metadata."""
self.lock.acquire()
metadata_path = self.get_path()
if os.path.exists(metadata_path):
os.remove(metadata_path)
self.lock.release()
with self.lock:
metadata_path = self.get_path()
if os.path.exists(metadata_path):
os.remove(metadata_path)

def get_metadata_for_savepath(
self,
Expand Down Expand Up @@ -524,10 +541,8 @@ def get(self):
metadata_path = self.get_path()
if not os.path.exists(metadata_path):
return {}
self.lock.acquire()
with open(metadata_path) as fp:
metadata = json.load(fp)
self.lock.release()
return metadata

def add(
Expand All @@ -550,18 +565,19 @@ def add(
set_name: Optional[str]
The name of the set.
"""
metadata = self.get()
key = self.key_fn(data_dict)
if key not in metadata:
metadata[key] = []
all_filepaths = wrap_in_list(filepath)
for path in all_filepaths:
metadata_for_savepath = self.get_metadata_for_savepath(
path, feature_name, set_name
)
if metadata_for_savepath not in metadata[key]:
metadata[key] += [metadata_for_savepath]
self.write(metadata)
with self.lock:
metadata = self.get()
key = self.key_fn(data_dict)
if key not in metadata:
metadata[key] = []
all_filepaths = wrap_in_list(filepath)
for path in all_filepaths:
metadata_for_savepath = self.get_metadata_for_savepath(
path, feature_name, set_name
)
if metadata_for_savepath not in metadata[key]:
metadata[key] += [metadata_for_savepath]
self.write(metadata)

def write(self, metadata_dict: Dict[str, Any]):
"""Write the metadata to disk.
Expand All @@ -571,10 +587,8 @@ def write(self, metadata_dict: Dict[str, Any]):
metadata_dict: Dict[str, Any]
A dictionary containing the metadata.
"""
self.lock.acquire()
with open(self.get_path(), "w") as fp:
json.dump(metadata_dict, fp)
self.lock.release()

def __contains__(self, item: Any):
"""Check if the metadata contains a certain item.
Expand All @@ -588,7 +602,8 @@ def __contains__(self, item: Any):
bool
Whether the item is contained.
"""
return item in self.get()
with self.lock:
return item in self.get()

def __getitem__(self, key: Any):
"""Retrieve a metadata item.
Expand All @@ -603,7 +618,8 @@ def __getitem__(self, key: Any):
Any
The metadata item.
"""
return self.get()[key]
with self.lock:
return self.get()[key]


class DefaultSave(Save):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import abc
import logging

import multiprocessing
import multiprocess


class ProgressCallbackFn(abc.ABC):
Expand Down Expand Up @@ -45,7 +45,7 @@ def __call__(self, result):
class MultiprocessingSingleton:
"""Singleton class for multiprocessing."""

manager = multiprocessing.Manager()
manager = multiprocess.Manager()
locks = {}

to_clean = []
Expand Down Expand Up @@ -76,8 +76,8 @@ def get_map_fn(
"""
if nb_processes != 0:
if nb_processes == -1:
nb_processes = multiprocessing.cpu_count()
pool = multiprocessing.Pool(nb_processes, maxtasksperchild=maxtasksperchild)
nb_processes = multiprocess.cpu_count()
pool = multiprocess.Pool(nb_processes, maxtasksperchild=maxtasksperchild)
cls.to_clean += [pool]

def dummy_map_fn(fn, iterable):
Expand Down
6 changes: 3 additions & 3 deletions brain_pipe/utils/serialization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Serialization utilities."""
import pickle
import dill


def pickle_dump_wrapper(path, obj):
Expand All @@ -13,7 +13,7 @@ def pickle_dump_wrapper(path, obj):
The object to dump.
"""
with open(path, "wb") as f:
pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
dill.dump(obj, f)


def pickle_load_wrapper(path):
Expand All @@ -30,4 +30,4 @@ def pickle_load_wrapper(path):
The loaded object.
"""
with open(path, "rb") as f:
return pickle.load(f)
return dill.load(f)
Loading

0 comments on commit 187aac0

Please sign in to comment.