enable model and data sharding
- create partition manager object
- make MAP compatible
- migrate to Orbax checkpointing
- refactor predictive
gianlucadetommaso committed Jul 6, 2023
1 parent 404840e commit 4444907
Showing 78 changed files with 3,479 additions and 1,725 deletions.
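The commit message above mentions a migration to Orbax checkpointing. Below is a minimal, hypothetical sketch of the Orbax pattern being adopted, not Fortuna's actual checkpointing wrapper; the directory, the toy state pytree, and the option values are all illustrative.

```python
import jax.numpy as jnp
import orbax.checkpoint as ocp

# Toy state pytree; Fortuna's real train/calibration states carry more fields.
state = {"params": {"w": jnp.zeros((3, 3))}, "step": 0}

# A CheckpointManager saves and prunes checkpoints under a directory,
# keeping only the most recent `max_to_keep` of them.
options = ocp.CheckpointManagerOptions(max_to_keep=2, create=True)
manager = ocp.CheckpointManager("/tmp/fortuna_ckpts", ocp.PyTreeCheckpointer(), options)

manager.save(0, state)                             # write a checkpoint for step 0
restored = manager.restore(manager.latest_step())  # read the latest one back
```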
10 changes: 5 additions & 5 deletions benchmarks/transformers/masked_language_modeling.py
@@ -193,13 +193,13 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path:

try:
logger.info(list(pathlib.Path(args.restore_checkpoint_dir).rglob("*")))
restore_checkpoint_path = unpack_model_tar(
restore_checkpoint_dir = unpack_model_tar(
list(pathlib.Path(args.restore_checkpoint_dir).rglob("*"))[0]
)
logger.info(list(pathlib.Path(restore_checkpoint_path).rglob("*")))
logger.info(list(pathlib.Path(restore_checkpoint_dir).rglob("*")))
except:
logger.info("No checkpoint to restore")
restore_checkpoint_path = None
restore_checkpoint_dir = None

tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

@@ -341,7 +341,7 @@ def accuracy_mlm(preds: Array, targets: Array) -> jnp.ndarray:
save_checkpoint_dir=args.save_checkpoint_dir,
save_every_n_steps=args.save_every_n_steps,
keep_top_n_checkpoints=args.keep_top_n_checkpoints,
restore_checkpoint_path=restore_checkpoint_path,
restore_checkpoint_dir=restore_checkpoint_dir,
),
)
if args.last_layer_only and (
@@ -357,7 +357,7 @@ def accuracy_mlm(preds: Array, targets: Array) -> jnp.ndarray:
and args.last_layer_only
else None,
)
if restore_checkpoint_path is not None:
if restore_checkpoint_dir is not None:
fit_config.optimizer = last_layer_optimizer
train_kwargs = {"fit_config": fit_config}
else:
26 changes: 16 additions & 10 deletions benchmarks/transformers/prob_model_text_classification.py
@@ -240,13 +240,13 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path:

try:
logger.info(list(pathlib.Path(args.load_model_dir).rglob("*")))
restore_checkpoint_path = unpack_model_tar(
restore_checkpoint_dir = unpack_model_tar(
list(pathlib.Path(args.load_model_dir).rglob("*"))[0]
)
logger.info(list(pathlib.Path(restore_checkpoint_path).rglob("*")))
logger.info(list(pathlib.Path(restore_checkpoint_dir).rglob("*")))
except:
logger.info("No checkpoint to restore")
restore_checkpoint_path = None
restore_checkpoint_dir = None

tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

@@ -400,11 +400,17 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path:

model_editor = None
if args.enable_probit_model_editor:
probit_freeze_fun = lambda p, v: True if "classifier" in p else False if args.probit_last_layer_only else None
probit_freeze_fun = (
lambda p, v: True
if "classifier" in p
else False
if args.probit_last_layer_only
else None
)
model_editor = ProbitModelEditor(
freeze_fun=probit_freeze_fun,
init_log_var=args.probit_init_log_var,
stop_gradient=args.probit_stop_gradient
stop_gradient=args.probit_stop_gradient,
)

### TRAINING
@@ -438,7 +444,7 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path:
save_checkpoint_dir=args.output_data_dir,
save_every_n_steps=args.save_every_n_steps,
keep_top_n_checkpoints=args.keep_top_n_checkpoints,
restore_checkpoint_path=restore_checkpoint_path,
restore_checkpoint_dir=restore_checkpoint_dir,
),
callbacks=[
ResetCovarianceCallback(
@@ -469,7 +475,7 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path:
last_layer_optimizer = FitOptimizer(
method=optimizer, n_epochs=args.num_train_epochs, freeze_fun=freeze_fun
)
if restore_checkpoint_path is not None:
if restore_checkpoint_dir is not None:
fit_config.optimizer = last_layer_optimizer
train_kwargs = {"fit_config": fit_config}
else:
@@ -494,11 +500,11 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path:
calib_data_loader=None,
**train_kwargs,
)
elif restore_checkpoint_path is not None:
prob_model.load_state(restore_checkpoint_path)
elif restore_checkpoint_dir is not None:
prob_model.load_state(restore_checkpoint_dir)
else:
raise ValueError(
"Either restore_checkpoint_path or num_train_epochs > 0 should be specified."
"Either restore_checkpoint_dir or num_train_epochs > 0 should be specified."
)

if args.enable_probit_model_editor:
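The nested conditional lambda defining probit_freeze_fun above can be hard to read; the spelled-out version below is an illustrative equivalent only, with the last_layer_only argument standing in for args.probit_last_layer_only.

```python
def probit_freeze_fun(path, value, last_layer_only=True):
    # Equivalent to: True if "classifier" in path else False if last_layer_only else None
    if "classifier" in path:
        return True
    if last_layer_only:
        return False
    return None
```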
@@ -8,19 +8,19 @@ Please find their references below.

.. automodule:: fortuna.output_calib_model.classification
:members:
:exclude-members: get_path_latest_checkpoint, save_checkpoint, restore_checkpoint
:exclude-members: save_checkpoint, restore_checkpoint

.. _output_calib_regressor:

.. automodule:: fortuna.output_calib_model.regression
:members:
:exclude-members: get_path_latest_checkpoint, save_checkpoint, restore_checkpoint
:exclude-members: save_checkpoint, restore_checkpoint

.. _output_calib_base:

.. automodule:: fortuna.output_calib_model.base
:members:
:exclude-members: get_path_latest_checkpoint, save_checkpoint, restore_checkpoint
:exclude-members: save_checkpoint, restore_checkpoint

.. toctree::
:maxdepth: 1
2 changes: 1 addition & 1 deletion examples/scaling_up_bayesian_inference.pct.py
@@ -89,7 +89,7 @@ def __call__(self, x, train: bool = False, **kwargs) -> jnp.ndarray:

# We are ready to call `prob_model.train`, which will perform posterior inference under-the-hood. In order to do Bayesian inference on the last layer only and freeze the other parameters, all we need to do is to pass a function `freeze_fun` to the optimizer configuration object, deciding which parameters should be "frozen" and which should be "trainable".
#
# In addition, we configure `map_fit_config` to make a preliminary run with MAP, and set the frozen parameters to a meaningful value. Alternatively, if any of these is available, you can also either restore an existing checkpoint by configuring `FitCheckpointer.restore_checkpoint_path`, or start from a current state by setting `FitCheckpointer.start_from_current_state` to `True`.
# In addition, we configure `map_fit_config` to make a preliminary run with MAP, and set the frozen parameters to a meaningful value. Alternatively, if any of these is available, you can also either restore an existing checkpoint by configuring `FitCheckpointer.restore_checkpoint_dir`, or start from a current state by setting `FitCheckpointer.start_from_current_state` to `True`.

from fortuna.prob_model import FitConfig, FitOptimizer

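To make the checkpointing options described in the example above concrete, here is a hedged sketch of a FitConfig using the renamed restore_checkpoint_dir argument. It assumes FitCheckpointer can be imported from fortuna.prob_model alongside FitConfig and FitOptimizer; the paths and the epoch count are placeholders.

```python
from fortuna.prob_model import FitCheckpointer, FitConfig, FitOptimizer

fit_config = FitConfig(
    optimizer=FitOptimizer(n_epochs=3),
    checkpointer=FitCheckpointer(
        save_checkpoint_dir="./checkpoints",
        restore_checkpoint_dir=None,     # point at an existing checkpoint directory to resume
        start_from_current_state=False,  # ignored whenever restore_checkpoint_dir is given
        keep_top_n_checkpoints=2,
    ),
)
# prob_model.train(..., fit_config=fit_config) would then restore or resume accordingly.
```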
26 changes: 12 additions & 14 deletions fortuna/calib_model/base.py
@@ -137,11 +137,11 @@ def _calibrate(
rng=self.rng.get(),
state=state,
loss_fun=loss,
training_dataloader=calib_data_loader,
training_data_loader=calib_data_loader,
training_dataset_size=n_calib_data,
n_epochs=config.optimizer.n_epochs,
metrics=config.monitor.metrics,
validation_dataloader=val_data_loader,
validation_data_loader=val_data_loader,
validation_dataset_size=n_val_data,
verbose=config.monitor.verbose,
callbacks=config.callbacks,
@@ -158,30 +158,28 @@ def _calibrate(
logging.info("Calibration completed.")
return status

def load_state(self, checkpoint_path: Path) -> None:
def load_state(self, checkpoint_dir: Path) -> None:
"""
Load the state of the posterior distribution from a checkpoint path. The checkpoint must be compatible with the
probabilistic model.
Parameters
----------
checkpoint_path : Path
checkpoint_dir : Path
Path to a checkpoint file or directory to restore.
"""
try:
self.restore_checkpoint(checkpoint_path)
self.restore_checkpoint(checkpoint_dir)
except ValueError:
raise ValueError(
f"No checkpoint was found in `checkpoint_path={checkpoint_path}`."
f"No checkpoint was found in `checkpoint_dir={checkpoint_dir}`."
)
self.predictive.state = CalibStateRepository(checkpoint_dir=checkpoint_path)
self.predictive.state = CalibStateRepository(checkpoint_dir=checkpoint_dir)

def save_state(
self, checkpoint_path: Path, keep_top_n_checkpoints: int = 1
) -> None:
def save_state(self, checkpoint_dir: Path, keep_top_n_checkpoints: int = 1) -> None:
return self.predictive.state.put(
self.predictive.state.get(),
checkpoint_path=checkpoint_path,
checkpoint_dir=checkpoint_dir,
keep=keep_top_n_checkpoints,
)

@@ -224,7 +222,7 @@ def _init(self, data_loader: DataLoader, config: Config):
)

def _init_state(self, calib_data_loader: DataLoader, config: Config) -> CalibState:
if config.checkpointer.restore_checkpoint_path is None:
if config.checkpointer.restore_checkpoint_dir is None:
if config.checkpointer.start_from_current_state:
state = self.predictive.state.get(optimizer=config.optimizer.method)
else:
@@ -233,10 +231,10 @@ def _init_state(self, calib_data_loader: DataLoader, config: Config) -> CalibSta
if config.checkpointer.start_from_current_state:
logging.warning(
"`config.checkpointer.start_from_current_state` will be ignored since "
"`config.checkpointer.restore_checkpoint_path` is given."
"`config.checkpointer.restore_checkpoint_dir` is given."
)
state = self.restore_checkpoint(
restore_checkpoint_path=config.checkpointer.restore_checkpoint_path,
restore_checkpoint_dir=config.checkpointer.restore_checkpoint_dir,
optimizer=config.optimizer.method,
)
return state
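A short, hypothetical usage of the renamed load_state/save_state methods shown in this file's diff; calib_model stands for an already-built Fortuna calibration model and the directory is a placeholder.

```python
# Persist the current calibration state, keeping a single checkpoint,
# then restore it later from the same directory.
calib_model.save_state(checkpoint_dir="./calib_ckpts", keep_top_n_checkpoints=1)
calib_model.load_state(checkpoint_dir="./calib_ckpts")
```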
60 changes: 31 additions & 29 deletions fortuna/calib_model/calib_mixin.py
@@ -1,40 +1,42 @@
import os
from typing import Optional

from flax.training import checkpoints

from fortuna.calib_model.state import CalibState
from fortuna.training.mixin import WithCheckpointingMixin
from fortuna.training.mixins.checkpointing import WithCheckpointingMixin
from fortuna.typing import (
OptaxOptimizer,
Path,
)

# from flax.training import checkpoints

class WithCalibCheckpointingMixin(WithCheckpointingMixin):
def restore_checkpoint(
self,
restore_checkpoint_path: Path,
optimizer: Optional[OptaxOptimizer] = None,
prefix: str = "checkpoint_",
**kwargs,
) -> CalibState:
if not os.path.isdir(restore_checkpoint_path) and not os.path.isfile(
restore_checkpoint_path
):
raise ValueError(
f"`restore_checkpoint_path={restore_checkpoint_path}` was not found."
)
d = checkpoints.restore_checkpoint(
ckpt_dir=str(restore_checkpoint_path),
target=None,
step=None,
prefix=prefix,
parallel=True,
)
if d is None:
raise ValueError(
f"No checkpoint was found in `restore_checkpoint_path={restore_checkpoint_path}`."
)

return CalibState.init_from_dict(d, optimizer, **kwargs)

class WithCalibCheckpointingMixin(WithCheckpointingMixin):
pass
# def restore_checkpoint(
# self,
# restore_checkpoint_dir: Path,
# optimizer: Optional[OptaxOptimizer] = None,
# prefix: str = "",
# **kwargs,
# ) -> CalibState:
# if not os.path.isdir(restore_checkpoint_dir) and not os.path.isfile(
# restore_checkpoint_dir
# ):
# raise ValueError(
# f"`restore_checkpoint_dir={restore_checkpoint_dir}` was not found."
# )
# d = checkpoints.restore_checkpoint(
# ckpt_dir=str(restore_checkpoint_dir),
# target=None,
# step=None,
# prefix=prefix,
# parallel=True,
# )
# if d is None:
# raise ValueError(
# f"No checkpoint was found in `restore_checkpoint_dir={restore_checkpoint_dir}`."
# )
#
# return CalibState.init_from_dict(d, optimizer, **kwargs)
8 changes: 3 additions & 5 deletions fortuna/calib_model/calib_model_calibrator.py
@@ -13,11 +13,9 @@
from optax._src.base import PyTree

from fortuna.calib_model.state import CalibState
from fortuna.training.trainer import (
JittedMixin,
MultiDeviceMixin,
TrainerABC,
)
from fortuna.training.mixins.jitted import JittedMixin
from fortuna.training.mixins.multi_device import MultiDeviceMixin
from fortuna.training.trainer import TrainerABC
from fortuna.typing import (
Array,
Batch,
8 changes: 4 additions & 4 deletions fortuna/calib_model/config/checkpointer.py
@@ -7,7 +7,7 @@ class Checkpointer:
def __init__(
self,
save_checkpoint_dir: Optional[Path] = None,
restore_checkpoint_path: Optional[Path] = None,
restore_checkpoint_dir: Optional[Path] = None,
start_from_current_state: bool = False,
save_every_n_steps: Optional[int] = None,
keep_top_n_checkpoints: Optional[int] = 2,
@@ -20,10 +20,10 @@ def __init__(
----------
save_checkpoint_dir: Optional[Path] = None
Save directory location.
restore_checkpoint_path: Optional[Path]
restore_checkpoint_dir: Optional[Path]
Path to checkpoint file or directory to restore.
start_from_current_state: bool = False
If True, the optimization will start from the current state. If `restore_checkpoint_path` is given, then
If True, the optimization will start from the current state. If `restore_checkpoint_dir` is given, then
`start_from_current_state` is ignored.
save_every_n_steps: int
Number of training steps between checkpoints. To disable, set `every_n_train_steps` to None or 0 (no
@@ -36,7 +36,7 @@
"""
self.save_checkpoint_dir = save_checkpoint_dir
self.save_every_n_steps = save_every_n_steps
self.restore_checkpoint_path = restore_checkpoint_path
self.restore_checkpoint_dir = restore_checkpoint_dir
self.start_from_current_state = start_from_current_state
self.keep_top_n_checkpoints = keep_top_n_checkpoints
self.dump_state = dump_state
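For reference, a hedged sketch of constructing the calibration Checkpointer with the renamed argument; the Config import path and the concrete values are assumptions, not taken from this diff.

```python
from fortuna.calib_model import Config  # assumed import path
from fortuna.calib_model.config.checkpointer import Checkpointer

config = Config(
    checkpointer=Checkpointer(
        save_checkpoint_dir="./calib_ckpts",
        restore_checkpoint_dir=None,     # set to an existing directory to restore a state
        start_from_current_state=False,  # ignored if restore_checkpoint_dir is given
        save_every_n_steps=100,
        keep_top_n_checkpoints=2,
    )
)
```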
4 changes: 2 additions & 2 deletions fortuna/data/dataset/huggingface_datasets.py
@@ -112,12 +112,12 @@ def get_data_loader(
drop_last: bool
if True, the last batch (which potentially is smaller then the default batch size) is dropped.
verbose: bool
Whether to show a progress bar while iterating over the dataloader or not.
Whether to show a progress bar while iterating over the data_loader or not.
Returns
-------
HuggingFaceDataLoader
The dataloader
The data_loader
"""
iterable = IterableData.from_callable(
lambda *args, **kwargs: self._get_data_loader(
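A hypothetical call illustrating the get_data_loader parameters documented above; the dataset object and the batch-size and shuffle argument names are assumptions not shown in this diff.

```python
# `hf_dataset` and `tokenized_dataset` are placeholders for objects built elsewhere.
data_loader = hf_dataset.get_data_loader(
    tokenized_dataset,
    per_device_batch_size=16,  # assumed argument name
    shuffle=True,              # assumed argument name
    drop_last=True,            # drop the final, possibly smaller, batch
    verbose=True,              # show a progress bar while iterating over the data_loader
)
for batch in data_loader:
    pass  # feed batches to training or evaluation
```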