Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optimizer_config to replace original parameters #377

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 11 additions & 17 deletions libmultilabel/nn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ class MultiLabelModel(L.LightningModule):

Args:
num_classes (int): Total number of classes.
learning_rate (float, optional): Learning rate for optimizer. Defaults to 0.0001.
optimizer (str, optional): Optimizer name (i.e., sgd, adam, or adamw). Defaults to 'adam'.
momentum (float, optional): Momentum factor for SGD only. Defaults to 0.9.
weight_decay (int, optional): Weight decay factor. Defaults to 0.
optimizer_config (dict, optional): Optimizer parameters. The keys in the dictionary should match the parameter names defined by PyTorch for the optimizer.
lr_scheduler: (str, optional): Learning rate scheduler. Defaults to None, i.e., no learning rate scheduler. Currently, the only supported lr_scheduler is 'ReduceLROnPlateau'.
scheduler_config (dict, optional): Learning rate scheduler parameters. The keys in the dictionary should match the parameter names defined by PyTorch for the learning rate scheduler.
metric_threshold (float, optional): The decision value threshold over which a label is predicted as positive. Defaults to 0.5.
monitor_metrics (list, optional): Metrics to monitor while validating. Defaults to None.
log_path (str): Path to a directory holding the log files and models.
Expand All @@ -30,10 +30,8 @@ class MultiLabelModel(L.LightningModule):
def __init__(
self,
num_classes,
learning_rate=0.0001,
optimizer="adam",
momentum=0.9,
weight_decay=0,
optimizer_config=None,
lr_scheduler=None,
scheduler_config=None,
val_metric=None,
Expand All @@ -43,15 +41,13 @@ def __init__(
multiclass=False,
silent=False,
save_k_predictions=0,
**kwargs
**kwargs,
):
super().__init__()

# optimizer
self.learning_rate = learning_rate
self.optimizer = optimizer
self.momentum = momentum
self.weight_decay = weight_decay
self.optimizer_config = optimizer_config if optimizer_config is not None else {}

# lr_scheduler
self.lr_scheduler = lr_scheduler
Expand All @@ -78,17 +74,15 @@ def configure_optimizers(self):
parameters = [p for p in self.parameters() if p.requires_grad]
optimizer_name = self.optimizer
if optimizer_name == "sgd":
optimizer = optim.SGD(
parameters, self.learning_rate, momentum=self.momentum, weight_decay=self.weight_decay
)
optimizer = optim.SGD(parameters, **self.optimizer_config)
elif optimizer_name == "adam":
optimizer = optim.Adam(parameters, weight_decay=self.weight_decay, lr=self.learning_rate)
optimizer = optim.Adam(parameters, **self.optimizer_config)
elif optimizer_name == "adamw":
optimizer = optim.AdamW(parameters, weight_decay=self.weight_decay, lr=self.learning_rate)
optimizer = optim.AdamW(parameters, **self.optimizer_config)
elif optimizer_name == "adamax":
optimizer = optim.Adamax(parameters, weight_decay=self.weight_decay, lr=self.learning_rate)
optimizer = optim.Adamax(parameters, **self.optimizer_config)
else:
raise RuntimeError("Unsupported optimizer: {self.optimizer}")
raise RuntimeError(f"Unsupported optimizer: {self.optimizer}")

if self.lr_scheduler:
if self.lr_scheduler == "ReduceLROnPlateau":
Expand Down
12 changes: 3 additions & 9 deletions libmultilabel/nn/nn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,8 @@ def init_model(
embed_vecs=None,
init_weight=None,
log_path=None,
learning_rate=0.0001,
optimizer="adam",
momentum=0.9,
weight_decay=0,
optimizer_config=None,
lr_scheduler=None,
scheduler_config=None,
val_metric=None,
Expand All @@ -69,10 +67,8 @@ def init_model(
For example, the `init_weight` of `torch.nn.init.kaiming_uniform_`
is `kaiming_uniform`. Defaults to None.
log_path (str): Path to a directory holding the log files and models.
learning_rate (float, optional): Learning rate for optimizer. Defaults to 0.0001.
optimizer (str, optional): Optimizer name (i.e., sgd, adam, or adamw). Defaults to 'adam'.
momentum (float, optional): Momentum factor for SGD only. Defaults to 0.9.
weight_decay (int, optional): Weight decay factor. Defaults to 0.
optimizer_config (dict, optional): Optimizer parameters. The keys in the dictionary should match the parameter names defined by PyTorch for the optimizer.
lr_scheduler (str, optional): Name of the learning rate scheduler. Defaults to None.
scheduler_config (dict, optional): The configuration for learning rate scheduler. Defaults to None.
val_metric (str, optional): The metric to select the best model for testing. Used by some of the schedulers. Defaults to None.
Expand Down Expand Up @@ -102,10 +98,8 @@ def init_model(
word_dict=word_dict,
network=network,
log_path=log_path,
learning_rate=learning_rate,
optimizer=optimizer,
momentum=momentum,
weight_decay=weight_decay,
optimizer_config=optimizer_config,
lr_scheduler=lr_scheduler,
scheduler_config=scheduler_config,
val_metric=val_metric,
Expand Down
3 changes: 3 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,9 @@ def get_config():
args.early_stopping_metric = args.val_metric
if not hasattr(args, "scheduler_config"):
args.scheduler_config = None
args.optimizer_config = {"lr": args.learning_rate, "weight_decay": args.weight_decay}
if args.optimizer == "sgd":
args.optimizer_config["momentum"] = args.momentum
config = AttributeDict(vars(args))

config.run_name = "{}_{}_{}".format(
Expand Down
4 changes: 1 addition & 3 deletions torch_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,8 @@ def _setup_model(
embed_vecs=embed_vecs,
init_weight=self.config.init_weight,
log_path=log_path,
learning_rate=self.config.learning_rate,
optimizer=self.config.optimizer,
momentum=self.config.momentum,
weight_decay=self.config.weight_decay,
optimizer_config=self.config.optimizer_config,
lr_scheduler=self.config.lr_scheduler,
scheduler_config=self.config.scheduler_config,
val_metric=self.config.val_metric,
Expand Down
Loading