diff --git a/README.rst b/README.rst index b6d560e..cdd6c40 100644 --- a/README.rst +++ b/README.rst @@ -92,6 +92,9 @@ Supported Optimizers | `Adahessian`_ | https://arxiv.org/abs/2006.00719 | +---------------+--------------------------------------------------------------------------------------------------------------------------------------+ | | | +| `AdamD`_ | https://arxiv.org/abs/2110.10828 | ++---------------+--------------------------------------------------------------------------------------------------------------------------------------+ +| | | | `AdamP`_ | https://arxiv.org/abs/2006.08217 | +---------------+--------------------------------------------------------------------------------------------------------------------------------------+ | | | @@ -200,9 +203,9 @@ see if there is any improvement. A2GradExp --------- -+--------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_A2GradExp.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_A2GradExp.png | -+--------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. image:: docs/rastrigin_A2GradExp.png | .. image:: docs/rosenbrock_A2GradExp.png | ++----------------------------------------------------+----------------------------------------------------+ .. code:: python @@ -227,9 +230,9 @@ A2GradExp A2GradInc --------- -+--------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_A2GradInc.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_A2GradInc.png | -+--------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. image:: docs/rastrigin_A2GradInc.png | .. image:: docs/rosenbrock_A2GradInc.png | ++----------------------------------------------------+----------------------------------------------------+ .. code:: python @@ -253,9 +256,9 @@ A2GradInc A2GradUni --------- -+--------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_A2GradUni.png | .. 
image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_A2GradUni.png | -+--------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. image:: docs/rastrigin_A2GradUni.png | .. image:: docs/rosenbrock_A2GradUni.png | ++----------------------------------------------------+----------------------------------------------------+ .. code:: python @@ -279,9 +282,9 @@ A2GradUni AccSGD ------ -+-----------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_AccSGD.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_AccSGD.png | -+-----------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. image:: docs/rastrigin_AccSGD.png | .. image:: docs/rosenbrock_AccSGD.png | ++----------------------------------------------------+----------------------------------------------------+ .. code:: python @@ -307,9 +310,9 @@ AccSGD AdaBelief --------- -+-------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_AdaBelief.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_AdaBelief.png | -+-------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. image:: docs/rastrigin_AdaBelief.png | .. image:: docs/rosenbrock_AdaBelief.png | ++----------------------------------------------------+----------------------------------------------------+ .. code:: python @@ -338,9 +341,9 @@ AdaBelief AdaBound -------- -+------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_AdaBound.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_AdaBound.png | -+------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. 
image:: docs/rastrigin_AdaBound.png | .. image:: docs/rosenbrock_AdaBound.png | ++----------------------------------------------------+----------------------------------------------------+ .. code:: python @@ -371,9 +374,9 @@ upper bounds. The dynamic learning rate bounds are based on the exponential moving averages of the adaptive learning rates themselves, which smooth out unexpected large learning rates and stabilize the training of deep neural networks. -+------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_AdaMod.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_AdaMod.png | -+------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. image:: docs/rastrigin_AdaMod.png | .. image:: docs/rosenbrock_AdaMod.png | ++----------------------------------------------------+----------------------------------------------------+ .. code:: python @@ -397,9 +400,9 @@ unexpected large learning rates and stabilize the training of deep neural networ Adafactor --------- -+------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_Adafactor.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_Adafactor.png | -+------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. image:: docs/rastrigin_Adafactor.png | .. image:: docs/rosenbrock_Adafactor.png | ++----------------------------------------------------+----------------------------------------------------+ .. code:: python @@ -427,9 +430,10 @@ Adafactor Adahessian ---------- -+-------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_Adahessian.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_Adahessian.png | -+-------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. image:: docs/rastrigin_Adahessian.png | .. 
image:: docs/rosenbrock_Adahessian.png | ++----------------------------------------------------+----------------------------------------------------+ + .. code:: python @@ -452,6 +456,59 @@ Adahessian **Reference Code**: https://github.com/amirgholami/adahessian +AdamD +----- +AdamD is really an option available to a number of optimizers. The goal of this +option is to not inflate the first few training steps before the running mean +is initialized. + +rastrigin +########## ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| AdamP | .. image:: docs/rastrigin_AdamP.png | .. image:: docs/rastrigin_AdamP_adamD.png | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| Adam | .. image:: docs/rastrigin_Adam_internal.png | .. image:: docs/rastrigin_Adam_internal_adamD.png | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| AdaBelief | .. image:: docs/rastrigin_AdaBelief.png | .. image:: docs/rastrigin_AdaBelief_adamD.png | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| AdaBound | .. image:: docs/rastrigin_AdaBound.png | .. image:: docs/rastrigin_AdaBound_adamD.png | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| AdaMod | .. image:: docs/rastrigin_AdaMod.png | .. image:: docs/rastrigin_AdaMod_adamD.png | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| Adahessian | .. image:: docs/rastrigin_Adahessian.png | .. image:: docs/rastrigin_Adahessian_adamD.png | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| DiffGrad | .. image:: docs/rastrigin_DiffGrad.png | .. image:: docs/rastrigin_DiffGrad_adamD.png | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| Lamb | .. image:: docs/rastrigin_Lamb.png | .. image:: docs/rastrigin_Lamb_adamD.png | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| SWATS | .. image:: docs/rastrigin_SWATS.png | .. image:: docs/rastrigin_SWATS_adamD.png | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| Yogi | .. image:: docs/rastrigin_Yogi.png | .. image:: docs/rastrigin_Yogi_adamD.png | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ + +rosenbrock +########## ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| AdamP | .. image:: docs/rosenbrock_AdamP.png | .. image:: docs/rosenbrock_AdamP_adamD.png | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| Adam | .. image:: docs/rosenbrock_Adam_internal.png | .. 
image:: docs/rosenbrock_Adam_internal_adamD.png | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| AdaBelief | .. image:: docs/rosenbrock_AdaBelief.png | .. image:: docs/rosenbrock_AdaBelief_adamD.png | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| AdaBound | .. image:: docs/rosenbrock_AdaBound.png | .. image:: docs/rosenbrock_AdaBound_adamD.png | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| AdaMod | .. image:: docs/rosenbrock_AdaMod.png | .. image:: docs/rosenbrock_AdaMod_adamD.png | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| Adahessian | .. image:: docs/rosenbrock_Adahessian.png | .. image:: docs/rosenbrock_Adahessian_adamD.png | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| DiffGrad | .. image:: docs/rosenbrock_DiffGrad.png | .. image:: docs/rosenbrock_DiffGrad_adamD.png | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| Lamb | .. image:: docs/rosenbrock_Lamb.png | .. image:: docs/rosenbrock_Lamb_adamD.png | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| SWATS | .. image:: docs/rosenbrock_SWATS.png | .. image:: docs/rosenbrock_SWATS_adamD.png | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ +| Yogi | .. image:: docs/rosenbrock_Yogi.png | .. image:: docs/rosenbrock_Yogi_adamD.png | ++--------------+----------------------------------------------------------+----------------------------------------------------------+ AdamP ------ @@ -461,9 +518,9 @@ remove the radial component (i.e., parallel to the weight vector) from the updat Intuitively, this operation prevents the unnecessary update along the radial direction that only increases the weight norm without contributing to the loss minimization. -+------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_AdamP.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_AdamP.png | -+------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------+ ++--------------------------------------------+--------------------------------------------+ +| .. image:: docs/rastrigin_AdamP.png | .. image:: docs/rosenbrock_AdamP.png | ++--------------------------------------------+--------------------------------------------+ .. 
code:: python @@ -489,9 +546,9 @@ that only increases the weight norm without contributing to the loss minimizatio AggMo ----- -+------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_AggMo.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_AggMo.png | -+------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------+ ++--------------------------------------------+--------------------------------------------+ +| .. image:: docs/rastrigin_AggMo.png | .. image:: docs/rosenbrock_AggMo.png | ++--------------------------------------------+--------------------------------------------+ .. code:: python @@ -514,9 +571,9 @@ AggMo Apollo ------ -+------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_Apollo.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_Apollo.png | -+------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------+ ++--------------------------------------------+--------------------------------------------+ +| .. image:: docs/rastrigin_Apollo.png | .. image:: docs/rosenbrock_Apollo.png | ++--------------------------------------------+--------------------------------------------+ .. code:: python @@ -546,9 +603,9 @@ gradient, the step size is adjusted for each parameter in such a way that it should have a larger step size for faster gradient changing parameters and a lower step size for lower gradient changing parameters. -+------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_DiffGrad.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_DiffGrad.png | -+------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------+ ++--------------------------------------------+--------------------------------------------+ +| .. image:: docs/rastrigin_DiffGrad.png | .. image:: docs/rosenbrock_DiffGrad.png | ++--------------------------------------------+--------------------------------------------+ .. code:: python @@ -572,9 +629,9 @@ parameters and a lower step size for lower gradient changing parameters. 
Lamb ---- -+--------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_Lamb.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_Lamb.png | -+--------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------+ ++--------------------------------------------+--------------------------------------------+ +| .. image:: docs/rastrigin_Lamb.png | .. image:: docs/rosenbrock_Lamb.png | ++--------------------------------------------+--------------------------------------------+ .. code:: python @@ -598,9 +655,9 @@ Lamb Lookahead --------- -+-----------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_LookaheadYogi.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_LookaheadYogi.png | -+-----------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. image:: docs/rastrigin_LookaheadYogi.png | .. image:: docs/rosenbrock_LookaheadYogi.png | ++----------------------------------------------------+----------------------------------------------------+ .. code:: python @@ -629,9 +686,9 @@ Lookahead MADGRAD --------- -+-----------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_MADGRAD.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_MADGRAD.png | -+-----------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. image:: docs/rastrigin_MADGRAD.png | .. image:: docs/rosenbrock_MADGRAD.png | ++----------------------------------------------------+----------------------------------------------------+ .. code:: python @@ -656,9 +713,10 @@ MADGRAD NovoGrad -------- -+------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_NovoGrad.png | .. 
image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_NovoGrad.png | -+------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. image:: docs/rastrigin_NovoGrad.png | .. image:: docs/rosenbrock_NovoGrad.png | ++----------------------------------------------------+----------------------------------------------------+ + .. code:: python @@ -685,9 +743,9 @@ NovoGrad PID --- -+-------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_PID.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_PID.png | -+-------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. image:: docs/rastrigin_PID.png | .. image:: docs/rosenbrock_PID.png | ++----------------------------------------------------+----------------------------------------------------+ .. code:: python @@ -714,9 +772,9 @@ PID QHAdam ------ -+----------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_QHAdam.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_QHAdam.png | -+----------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. image:: docs/rastrigin_QHAdam.png | .. image:: docs/rosenbrock_QHAdam.png | ++----------------------------------------------------+----------------------------------------------------+ .. code:: python @@ -743,9 +801,10 @@ QHAdam QHM --- -+-------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_QHM.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_QHM.png | -+-------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. image:: docs/rastrigin_QHM.png | .. 
image:: docs/rosenbrock_QHM.png | ++----------------------------------------------------+----------------------------------------------------+ + .. code:: python @@ -800,9 +859,10 @@ Deprecated, please use version provided by PyTorch_. Ranger ------ -+----------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_Ranger.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_Ranger.png | -+----------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. image:: docs/rastrigin_Ranger.png | .. image:: docs/rosenbrock_Ranger.png | ++----------------------------------------------------+----------------------------------------------------+ + .. code:: python @@ -830,9 +890,10 @@ Ranger RangerQH -------- -+------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_RangerQH.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_RangerQH.png | -+------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. image:: docs/rastrigin_RangerQH.png | .. image:: docs/rosenbrock_RangerQH.png | ++----------------------------------------------------+----------------------------------------------------+ + .. code:: python @@ -861,9 +922,10 @@ RangerQH RangerVA -------- -+------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_RangerVA.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_RangerVA.png | -+------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. image:: docs/rastrigin_RangerVA.png | .. image:: docs/rosenbrock_RangerVA.png | ++----------------------------------------------------+----------------------------------------------------+ + .. 
code:: python @@ -895,9 +957,10 @@ RangerVA SGDP ---- -+--------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_SGDP.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_SGDP.png | -+--------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. image:: docs/rastrigin_SGDP.png | .. image:: docs/rosenbrock_SGDP.png | ++----------------------------------------------------+----------------------------------------------------+ + .. code:: python @@ -925,9 +988,10 @@ SGDP SGDW ---- -+--------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_SGDW.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_SGDW.png | -+--------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. image:: docs/rastrigin_SGDW.png | .. image:: docs/rosenbrock_SGDW.png | ++----------------------------------------------------+----------------------------------------------------+ + .. code:: python @@ -953,9 +1017,9 @@ SGDW SWATS ----- -+---------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_SWATS.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_SWATS.png | -+---------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. image:: docs/rastrigin_SWATS.png | .. image:: docs/rosenbrock_SWATS.png | ++----------------------------------------------------+----------------------------------------------------+ .. code:: python @@ -982,9 +1046,10 @@ SWATS Shampoo ------- -+-----------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_Shampoo.png | .. 
image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_Shampoo.png | -+-----------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. image:: docs/rastrigin_Shampoo.png | .. image:: docs/rosenbrock_Shampoo.png | ++----------------------------------------------------+----------------------------------------------------+ + .. code:: python @@ -1013,9 +1078,10 @@ Yogi Yogi is optimization algorithm based on ADAM with more fine grained effective learning rate control, and has similar theoretical guarantees on convergence as ADAM. -+--------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_Yogi.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_Yogi.png | -+--------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. image:: docs/rastrigin_Yogi.png | .. image:: docs/rosenbrock_Yogi.png | ++----------------------------------------------------+----------------------------------------------------+ + .. code:: python @@ -1041,16 +1107,19 @@ learning rate control, and has similar theoretical guarantees on convergence as Adam (PyTorch built-in) ----------------------- -+---------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_Adam.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_Adam.png | -+---------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------+ ++-------------------+-------------------------------------------------+-------------------------------------------------+ +| Adam (pytorch) | .. image:: docs/rastrigin_Adam.png | .. image:: docs/rosenbrock_Adam.png | ++-------------------+-------------------------------------------------+-------------------------------------------------+ +| Adam (Ours) | .. image:: docs/rastrigin_Adam_internal.png | .. image:: docs/rosenbrock_Adam_internal.png | ++-------------------+-------------------------------------------------+-------------------------------------------------+ + SGD (PyTorch built-in) ---------------------- -+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------+ -| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_SGD.png | .. 
image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_SGD.png | -+--------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------+ ++----------------------------------------------------+----------------------------------------------------+ +| .. image:: docs/rastrigin_SGD.png | .. image:: docs/rosenbrock_SGD.png | ++----------------------------------------------------+----------------------------------------------------+ .. _Python: https://www.python.org .. _PyTorch: https://github.com/pytorch/pytorch diff --git a/docs/index.rst b/docs/index.rst index 5362451..01e56f3 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -60,6 +60,9 @@ Supported Optimizers | :ref:`Adafactor`| https://arxiv.org/abs/1804.04235 | +-----------------+-------------------------------------------------------------------------------+ | | | +| :ref:`AdamD` | https://arxiv.org/abs/2110.10828 | ++-----------------+-------------------------------------------------------------------------------+ +| | | | :ref:`AdamP` | https://arxiv.org/abs/1804.00325 | +-----------------+-------------------------------------------------------------------------------+ | | | diff --git a/docs/rastrigin_A2GradExp.png b/docs/rastrigin_A2GradExp.png index 851542e..50cfcdc 100644 Binary files a/docs/rastrigin_A2GradExp.png and b/docs/rastrigin_A2GradExp.png differ diff --git a/docs/rastrigin_A2GradInc.png b/docs/rastrigin_A2GradInc.png index cbee4a2..4f79c51 100644 Binary files a/docs/rastrigin_A2GradInc.png and b/docs/rastrigin_A2GradInc.png differ diff --git a/docs/rastrigin_A2GradUni.png b/docs/rastrigin_A2GradUni.png index d26dce8..744fe79 100644 Binary files a/docs/rastrigin_A2GradUni.png and b/docs/rastrigin_A2GradUni.png differ diff --git a/docs/rastrigin_AccSGD.png b/docs/rastrigin_AccSGD.png index 086497d..1fa455c 100644 Binary files a/docs/rastrigin_AccSGD.png and b/docs/rastrigin_AccSGD.png differ diff --git a/docs/rastrigin_AdaBelief.png b/docs/rastrigin_AdaBelief.png index dc68ec3..6b8ac87 100644 Binary files a/docs/rastrigin_AdaBelief.png and b/docs/rastrigin_AdaBelief.png differ diff --git a/docs/rastrigin_AdaBelief_adamD.png b/docs/rastrigin_AdaBelief_adamD.png new file mode 100644 index 0000000..09b5c97 Binary files /dev/null and b/docs/rastrigin_AdaBelief_adamD.png differ diff --git a/docs/rastrigin_AdaBound.png b/docs/rastrigin_AdaBound.png index f6f1f0f..173681e 100644 Binary files a/docs/rastrigin_AdaBound.png and b/docs/rastrigin_AdaBound.png differ diff --git a/docs/rastrigin_AdaBound_adamD.png b/docs/rastrigin_AdaBound_adamD.png new file mode 100644 index 0000000..61b918a Binary files /dev/null and b/docs/rastrigin_AdaBound_adamD.png differ diff --git a/docs/rastrigin_AdaMod.png b/docs/rastrigin_AdaMod.png index ba510e0..9f55a4c 100644 Binary files a/docs/rastrigin_AdaMod.png and b/docs/rastrigin_AdaMod.png differ diff --git a/docs/rastrigin_AdaMod_adamD.png b/docs/rastrigin_AdaMod_adamD.png new file mode 100644 index 0000000..e128795 Binary files /dev/null and b/docs/rastrigin_AdaMod_adamD.png differ diff --git a/docs/rastrigin_Adafactor.png b/docs/rastrigin_Adafactor.png index d34d8ff..38dc137 100644 Binary files a/docs/rastrigin_Adafactor.png and b/docs/rastrigin_Adafactor.png differ diff --git a/docs/rastrigin_Adahessian.png b/docs/rastrigin_Adahessian.png index fa4442c..d18d46d 100644 Binary files a/docs/rastrigin_Adahessian.png and 
b/docs/rastrigin_Adahessian.png differ diff --git a/docs/rastrigin_Adahessian_adamD.png b/docs/rastrigin_Adahessian_adamD.png new file mode 100644 index 0000000..0cffdc2 Binary files /dev/null and b/docs/rastrigin_Adahessian_adamD.png differ diff --git a/docs/rastrigin_Adam.png b/docs/rastrigin_Adam.png index 9f334c6..1cca630 100644 Binary files a/docs/rastrigin_Adam.png and b/docs/rastrigin_Adam.png differ diff --git a/docs/rastrigin_AdamP.png b/docs/rastrigin_AdamP.png index b98200b..f71d8c1 100644 Binary files a/docs/rastrigin_AdamP.png and b/docs/rastrigin_AdamP.png differ diff --git a/docs/rastrigin_AdamP_adamD.png b/docs/rastrigin_AdamP_adamD.png new file mode 100644 index 0000000..b91d0e9 Binary files /dev/null and b/docs/rastrigin_AdamP_adamD.png differ diff --git a/docs/rastrigin_AdamW_internal.png b/docs/rastrigin_AdamW_internal.png new file mode 100644 index 0000000..d965ba7 Binary files /dev/null and b/docs/rastrigin_AdamW_internal.png differ diff --git a/docs/rastrigin_AdamW_internal_adamD.png b/docs/rastrigin_AdamW_internal_adamD.png new file mode 100644 index 0000000..22a91cf Binary files /dev/null and b/docs/rastrigin_AdamW_internal_adamD.png differ diff --git a/docs/rastrigin_Adam_internal.png b/docs/rastrigin_Adam_internal.png new file mode 100644 index 0000000..75ee3b1 Binary files /dev/null and b/docs/rastrigin_Adam_internal.png differ diff --git a/docs/rastrigin_Adam_internal_adamD.png b/docs/rastrigin_Adam_internal_adamD.png new file mode 100644 index 0000000..4864fa2 Binary files /dev/null and b/docs/rastrigin_Adam_internal_adamD.png differ diff --git a/docs/rastrigin_AggMo.png b/docs/rastrigin_AggMo.png index 1767de4..18c5628 100644 Binary files a/docs/rastrigin_AggMo.png and b/docs/rastrigin_AggMo.png differ diff --git a/docs/rastrigin_Apollo.png b/docs/rastrigin_Apollo.png index 4868c56..842fdf1 100644 Binary files a/docs/rastrigin_Apollo.png and b/docs/rastrigin_Apollo.png differ diff --git a/docs/rastrigin_DiffGrad.png b/docs/rastrigin_DiffGrad.png index 3bce8e1..d5910f1 100644 Binary files a/docs/rastrigin_DiffGrad.png and b/docs/rastrigin_DiffGrad.png differ diff --git a/docs/rastrigin_DiffGrad_adamD.png b/docs/rastrigin_DiffGrad_adamD.png new file mode 100644 index 0000000..8f014cd Binary files /dev/null and b/docs/rastrigin_DiffGrad_adamD.png differ diff --git a/docs/rastrigin_Lamb.png b/docs/rastrigin_Lamb.png index ce86255..c6f69d9 100644 Binary files a/docs/rastrigin_Lamb.png and b/docs/rastrigin_Lamb.png differ diff --git a/docs/rastrigin_Lamb_adamD.png b/docs/rastrigin_Lamb_adamD.png new file mode 100644 index 0000000..b7098c1 Binary files /dev/null and b/docs/rastrigin_Lamb_adamD.png differ diff --git a/docs/rastrigin_LookaheadYogi.png b/docs/rastrigin_LookaheadYogi.png index 5880a54..333e026 100644 Binary files a/docs/rastrigin_LookaheadYogi.png and b/docs/rastrigin_LookaheadYogi.png differ diff --git a/docs/rastrigin_MADGRAD.png b/docs/rastrigin_MADGRAD.png index 9e46428..92ab21c 100644 Binary files a/docs/rastrigin_MADGRAD.png and b/docs/rastrigin_MADGRAD.png differ diff --git a/docs/rastrigin_NovoGrad.png b/docs/rastrigin_NovoGrad.png index eb6d0c6..b114f66 100644 Binary files a/docs/rastrigin_NovoGrad.png and b/docs/rastrigin_NovoGrad.png differ diff --git a/docs/rastrigin_PID.png b/docs/rastrigin_PID.png index 98fc481..be4639d 100644 Binary files a/docs/rastrigin_PID.png and b/docs/rastrigin_PID.png differ diff --git a/docs/rastrigin_QHAdam.png b/docs/rastrigin_QHAdam.png index 03e0c88..033adb2 100644 Binary files a/docs/rastrigin_QHAdam.png and 
b/docs/rastrigin_QHAdam.png differ diff --git a/docs/rastrigin_QHM.png b/docs/rastrigin_QHM.png index 4331eb3..d9b324b 100644 Binary files a/docs/rastrigin_QHM.png and b/docs/rastrigin_QHM.png differ diff --git a/docs/rastrigin_Ranger.png b/docs/rastrigin_Ranger.png index f2188ca..cb9443d 100644 Binary files a/docs/rastrigin_Ranger.png and b/docs/rastrigin_Ranger.png differ diff --git a/docs/rastrigin_RangerQH.png b/docs/rastrigin_RangerQH.png index 16d1e24..fa967d2 100644 Binary files a/docs/rastrigin_RangerQH.png and b/docs/rastrigin_RangerQH.png differ diff --git a/docs/rastrigin_RangerVA.png b/docs/rastrigin_RangerVA.png index fb2f490..e16fabd 100644 Binary files a/docs/rastrigin_RangerVA.png and b/docs/rastrigin_RangerVA.png differ diff --git a/docs/rastrigin_SGD.png b/docs/rastrigin_SGD.png index 1e07df3..345218b 100644 Binary files a/docs/rastrigin_SGD.png and b/docs/rastrigin_SGD.png differ diff --git a/docs/rastrigin_SGDP.png b/docs/rastrigin_SGDP.png index f1d5a29..9470507 100644 Binary files a/docs/rastrigin_SGDP.png and b/docs/rastrigin_SGDP.png differ diff --git a/docs/rastrigin_SGDW.png b/docs/rastrigin_SGDW.png index e45bebe..e82c687 100644 Binary files a/docs/rastrigin_SGDW.png and b/docs/rastrigin_SGDW.png differ diff --git a/docs/rastrigin_SWATS.png b/docs/rastrigin_SWATS.png index 458df35..586720e 100644 Binary files a/docs/rastrigin_SWATS.png and b/docs/rastrigin_SWATS.png differ diff --git a/docs/rastrigin_SWATS_adamD.png b/docs/rastrigin_SWATS_adamD.png new file mode 100644 index 0000000..1417f35 Binary files /dev/null and b/docs/rastrigin_SWATS_adamD.png differ diff --git a/docs/rastrigin_Shampoo.png b/docs/rastrigin_Shampoo.png index 47e7ff2..50109c3 100644 Binary files a/docs/rastrigin_Shampoo.png and b/docs/rastrigin_Shampoo.png differ diff --git a/docs/rastrigin_Yogi.png b/docs/rastrigin_Yogi.png index a31f3d2..1f3aaf5 100644 Binary files a/docs/rastrigin_Yogi.png and b/docs/rastrigin_Yogi.png differ diff --git a/docs/rastrigin_Yogi_adamD.png b/docs/rastrigin_Yogi_adamD.png new file mode 100644 index 0000000..e4a8cd8 Binary files /dev/null and b/docs/rastrigin_Yogi_adamD.png differ diff --git a/docs/rosenbrock_A2GradExp.png b/docs/rosenbrock_A2GradExp.png index e7b165c..ef4c9fc 100644 Binary files a/docs/rosenbrock_A2GradExp.png and b/docs/rosenbrock_A2GradExp.png differ diff --git a/docs/rosenbrock_A2GradInc.png b/docs/rosenbrock_A2GradInc.png index 25be224..c006607 100644 Binary files a/docs/rosenbrock_A2GradInc.png and b/docs/rosenbrock_A2GradInc.png differ diff --git a/docs/rosenbrock_A2GradUni.png b/docs/rosenbrock_A2GradUni.png index 0314ad1..d5b5e84 100644 Binary files a/docs/rosenbrock_A2GradUni.png and b/docs/rosenbrock_A2GradUni.png differ diff --git a/docs/rosenbrock_AccSGD.png b/docs/rosenbrock_AccSGD.png index 364fce3..3b932ce 100644 Binary files a/docs/rosenbrock_AccSGD.png and b/docs/rosenbrock_AccSGD.png differ diff --git a/docs/rosenbrock_AdaBelief.png b/docs/rosenbrock_AdaBelief.png index cfff2c8..7d7c937 100644 Binary files a/docs/rosenbrock_AdaBelief.png and b/docs/rosenbrock_AdaBelief.png differ diff --git a/docs/rosenbrock_AdaBelief_adamD.png b/docs/rosenbrock_AdaBelief_adamD.png new file mode 100644 index 0000000..02c0c71 Binary files /dev/null and b/docs/rosenbrock_AdaBelief_adamD.png differ diff --git a/docs/rosenbrock_AdaBound.png b/docs/rosenbrock_AdaBound.png index 0805c78..c782cc8 100644 Binary files a/docs/rosenbrock_AdaBound.png and b/docs/rosenbrock_AdaBound.png differ diff --git a/docs/rosenbrock_AdaBound_adamD.png 
b/docs/rosenbrock_AdaBound_adamD.png new file mode 100644 index 0000000..0fffd0f Binary files /dev/null and b/docs/rosenbrock_AdaBound_adamD.png differ diff --git a/docs/rosenbrock_AdaMod.png b/docs/rosenbrock_AdaMod.png index 1dc4607..a4a92da 100644 Binary files a/docs/rosenbrock_AdaMod.png and b/docs/rosenbrock_AdaMod.png differ diff --git a/docs/rosenbrock_AdaMod_adamD.png b/docs/rosenbrock_AdaMod_adamD.png new file mode 100644 index 0000000..f487d08 Binary files /dev/null and b/docs/rosenbrock_AdaMod_adamD.png differ diff --git a/docs/rosenbrock_Adafactor.png b/docs/rosenbrock_Adafactor.png index 0b57b00..1d8bc70 100644 Binary files a/docs/rosenbrock_Adafactor.png and b/docs/rosenbrock_Adafactor.png differ diff --git a/docs/rosenbrock_Adahessian.png b/docs/rosenbrock_Adahessian.png index 2f26e2e..d48f546 100644 Binary files a/docs/rosenbrock_Adahessian.png and b/docs/rosenbrock_Adahessian.png differ diff --git a/docs/rosenbrock_Adahessian_adamD.png b/docs/rosenbrock_Adahessian_adamD.png new file mode 100644 index 0000000..544a9c5 Binary files /dev/null and b/docs/rosenbrock_Adahessian_adamD.png differ diff --git a/docs/rosenbrock_Adam.png b/docs/rosenbrock_Adam.png index 4695095..b09d3c3 100644 Binary files a/docs/rosenbrock_Adam.png and b/docs/rosenbrock_Adam.png differ diff --git a/docs/rosenbrock_AdamP.png b/docs/rosenbrock_AdamP.png index a1b42dd..27724e0 100644 Binary files a/docs/rosenbrock_AdamP.png and b/docs/rosenbrock_AdamP.png differ diff --git a/docs/rosenbrock_AdamP_adamD.png b/docs/rosenbrock_AdamP_adamD.png new file mode 100644 index 0000000..c7248d2 Binary files /dev/null and b/docs/rosenbrock_AdamP_adamD.png differ diff --git a/docs/rosenbrock_AdamW_internal.png b/docs/rosenbrock_AdamW_internal.png new file mode 100644 index 0000000..2121b4c Binary files /dev/null and b/docs/rosenbrock_AdamW_internal.png differ diff --git a/docs/rosenbrock_AdamW_internal_adamD.png b/docs/rosenbrock_AdamW_internal_adamD.png new file mode 100644 index 0000000..cde9889 Binary files /dev/null and b/docs/rosenbrock_AdamW_internal_adamD.png differ diff --git a/docs/rosenbrock_Adam_internal.png b/docs/rosenbrock_Adam_internal.png new file mode 100644 index 0000000..93a461e Binary files /dev/null and b/docs/rosenbrock_Adam_internal.png differ diff --git a/docs/rosenbrock_Adam_internal_adamD.png b/docs/rosenbrock_Adam_internal_adamD.png new file mode 100644 index 0000000..815d47d Binary files /dev/null and b/docs/rosenbrock_Adam_internal_adamD.png differ diff --git a/docs/rosenbrock_AggMo.png b/docs/rosenbrock_AggMo.png index 1576607..aed477d 100644 Binary files a/docs/rosenbrock_AggMo.png and b/docs/rosenbrock_AggMo.png differ diff --git a/docs/rosenbrock_Apollo.png b/docs/rosenbrock_Apollo.png index 5d6190c..c884698 100644 Binary files a/docs/rosenbrock_Apollo.png and b/docs/rosenbrock_Apollo.png differ diff --git a/docs/rosenbrock_DiffGrad.png b/docs/rosenbrock_DiffGrad.png index 33554aa..33f0e08 100644 Binary files a/docs/rosenbrock_DiffGrad.png and b/docs/rosenbrock_DiffGrad.png differ diff --git a/docs/rosenbrock_DiffGrad_adamD.png b/docs/rosenbrock_DiffGrad_adamD.png new file mode 100644 index 0000000..aca3597 Binary files /dev/null and b/docs/rosenbrock_DiffGrad_adamD.png differ diff --git a/docs/rosenbrock_Lamb.png b/docs/rosenbrock_Lamb.png index 9ccf1e1..56e8edc 100644 Binary files a/docs/rosenbrock_Lamb.png and b/docs/rosenbrock_Lamb.png differ diff --git a/docs/rosenbrock_Lamb_adamD.png b/docs/rosenbrock_Lamb_adamD.png new file mode 100644 index 0000000..162d4ae Binary files 
/dev/null and b/docs/rosenbrock_Lamb_adamD.png differ diff --git a/docs/rosenbrock_LookaheadYogi.png b/docs/rosenbrock_LookaheadYogi.png index 8b8f390..c0d71a3 100644 Binary files a/docs/rosenbrock_LookaheadYogi.png and b/docs/rosenbrock_LookaheadYogi.png differ diff --git a/docs/rosenbrock_MADGRAD.png b/docs/rosenbrock_MADGRAD.png index 67ae7dd..2a3e197 100644 Binary files a/docs/rosenbrock_MADGRAD.png and b/docs/rosenbrock_MADGRAD.png differ diff --git a/docs/rosenbrock_NovoGrad.png b/docs/rosenbrock_NovoGrad.png index 3580386..0a4a9cc 100644 Binary files a/docs/rosenbrock_NovoGrad.png and b/docs/rosenbrock_NovoGrad.png differ diff --git a/docs/rosenbrock_PID.png b/docs/rosenbrock_PID.png index 085c44b..546e11d 100644 Binary files a/docs/rosenbrock_PID.png and b/docs/rosenbrock_PID.png differ diff --git a/docs/rosenbrock_QHAdam.png b/docs/rosenbrock_QHAdam.png index 9bec432..dd89aae 100644 Binary files a/docs/rosenbrock_QHAdam.png and b/docs/rosenbrock_QHAdam.png differ diff --git a/docs/rosenbrock_QHM.png b/docs/rosenbrock_QHM.png index 14f258e..fba5117 100644 Binary files a/docs/rosenbrock_QHM.png and b/docs/rosenbrock_QHM.png differ diff --git a/docs/rosenbrock_Ranger.png b/docs/rosenbrock_Ranger.png index 11043da..e682332 100644 Binary files a/docs/rosenbrock_Ranger.png and b/docs/rosenbrock_Ranger.png differ diff --git a/docs/rosenbrock_RangerQH.png b/docs/rosenbrock_RangerQH.png index 0f1b5a2..86a1568 100644 Binary files a/docs/rosenbrock_RangerQH.png and b/docs/rosenbrock_RangerQH.png differ diff --git a/docs/rosenbrock_RangerVA.png b/docs/rosenbrock_RangerVA.png index 759e492..9be84b3 100644 Binary files a/docs/rosenbrock_RangerVA.png and b/docs/rosenbrock_RangerVA.png differ diff --git a/docs/rosenbrock_SGD.png b/docs/rosenbrock_SGD.png index d86ac24..a6a2512 100644 Binary files a/docs/rosenbrock_SGD.png and b/docs/rosenbrock_SGD.png differ diff --git a/docs/rosenbrock_SGDP.png b/docs/rosenbrock_SGDP.png index f9bb39e..69ea5d0 100644 Binary files a/docs/rosenbrock_SGDP.png and b/docs/rosenbrock_SGDP.png differ diff --git a/docs/rosenbrock_SGDW.png b/docs/rosenbrock_SGDW.png index 5680fe1..a5c8783 100644 Binary files a/docs/rosenbrock_SGDW.png and b/docs/rosenbrock_SGDW.png differ diff --git a/docs/rosenbrock_SWATS.png b/docs/rosenbrock_SWATS.png index b2d51fb..e45a24b 100644 Binary files a/docs/rosenbrock_SWATS.png and b/docs/rosenbrock_SWATS.png differ diff --git a/docs/rosenbrock_SWATS_adamD.png b/docs/rosenbrock_SWATS_adamD.png new file mode 100644 index 0000000..985581f Binary files /dev/null and b/docs/rosenbrock_SWATS_adamD.png differ diff --git a/docs/rosenbrock_Shampoo.png b/docs/rosenbrock_Shampoo.png index decbb38..cd7b505 100644 Binary files a/docs/rosenbrock_Shampoo.png and b/docs/rosenbrock_Shampoo.png differ diff --git a/docs/rosenbrock_Yogi.png b/docs/rosenbrock_Yogi.png index fd30e3a..7c5e5f8 100644 Binary files a/docs/rosenbrock_Yogi.png and b/docs/rosenbrock_Yogi.png differ diff --git a/docs/rosenbrock_Yogi_adamD.png b/docs/rosenbrock_Yogi_adamD.png new file mode 100644 index 0000000..50dc4e7 Binary files /dev/null and b/docs/rosenbrock_Yogi_adamD.png differ diff --git a/examples/viz_optimizers.py b/examples/viz_optimizers.py index baf738d..698e24d 100644 --- a/examples/viz_optimizers.py +++ b/examples/viz_optimizers.py @@ -1,4 +1,5 @@ import math +from typing import List, Tuple import matplotlib.pyplot as plt import numpy as np @@ -9,6 +10,9 @@ plt.style.use('seaborn-white') +NUM_ITER: int = 500 +NUM_ITER_HPARAM: int = 200 + def rosenbrock(tensor): # 
https://en.wikipedia.org/wiki/Test_functions_for_optimization @@ -29,11 +33,10 @@ def rastrigin(tensor, lib=torch): def execute_steps( - func, initial_state, optimizer_class, optimizer_config, num_iter=500 + func, initial_state, optimizer_class, optimizer_config, num_iter=NUM_ITER ): x = torch.Tensor(initial_state).requires_grad_(True) optimizer = optimizer_class([x], **optimizer_config) - steps = [] steps = np.zeros((2, num_iter + 1)) steps[:, 0] = np.array(initial_state) for i in range(1, num_iter + 1): @@ -49,10 +52,11 @@ def execute_steps( def objective_rastrigin(params): lr = params['lr'] optimizer_class = params['optimizer_class'] + kwargs = params['kwargs'] initial_state = (-2.0, 3.5) minimum = (0, 0) - optimizer_config = dict(lr=lr) - num_iter = 100 + optimizer_config = dict(lr=lr, **kwargs) + num_iter = NUM_ITER_HPARAM steps = execute_steps( rastrigin, initial_state, optimizer_class, optimizer_config, num_iter ) @@ -62,74 +66,167 @@ def objective_rastrigin(params): def objective_rosenbrok(params): lr = params['lr'] optimizer_class = params['optimizer_class'] + kwargs = params['kwargs'] minimum = (1.0, 1.0) initial_state = (-2.0, 2.0) - optimizer_config = dict(lr=lr) - num_iter = 100 + optimizer_config = dict(lr=lr, **kwargs) + num_iter = NUM_ITER_HPARAM steps = execute_steps( rosenbrock, initial_state, optimizer_class, optimizer_config, num_iter ) return (steps[0][-1] - minimum[0]) ** 2 + (steps[1][-1] - minimum[1]) ** 2 -def plot_rastrigin(grad_iter, optimizer_name, lr): +def plot_rastrigin(grad_iters, optimizer_name, lr): x = np.linspace(-4.5, 4.5, 250) y = np.linspace(-4.5, 4.5, 250) minimum = (0, 0) X, Y = np.meshgrid(x, y) Z = rastrigin([X, Y], lib=np) - - iter_x, iter_y = grad_iter[0, :], grad_iter[1, :] - + assert len(grad_iters) <= 3, "Cannot handle more than three states" + l = None fig = plt.figure(figsize=(8, 8)) ax = fig.add_subplot(1, 1, 1) - ax.contour(X, Y, Z, 20, cmap='jet') - ax.plot(iter_x, iter_y, color='r', marker='x') - ax.set_title( + plt.contour(X, Y, Z, 20, cmap='jet', alpha=0.75) + for grad_iter, color in zip(grad_iters, ['r', 'm', 'c']): + iter_x, iter_y = grad_iter[0, :], grad_iter[1, :] + if l is None: + l = len(iter_x) + ax.plot(iter_x, iter_y, color=color, marker=None, alpha=0.75) + for px, py, pdx, pdy in zip( + iter_x[:-1], + iter_y[:-1], + iter_x[1:] - iter_x[:-1], + iter_y[1:] - iter_y[:-1], + ): + ax.arrow( + x=px, + y=py, + dx=pdx, + dy=pdy, + overhang=0.5, + width=0.001, + head_width=0.08, + length_includes_head=True, + color=color, + visible=True, + ) + # Starting point + ax.plot( + iter_x[0], + iter_y[0], + marker="s", + markersize=11, + markeredgecolor="black", + markerfacecolor=color, + markeredgewidth=2, + ) + # Ending point + ax.plot( + iter_x[-1], + iter_y[-1], + marker="P", + markersize=11, + markeredgecolor="black", + markerfacecolor=color, + markeredgewidth=2, + ) + plt.title( 'Rastrigin func: {} with ' - '{} iterations, lr={:.6}'.format(optimizer_name, len(iter_x), lr) + '{} iterations, lr={:.6}'.format(optimizer_name, l, lr) ) - plt.plot(*minimum, 'gD') - plt.plot(iter_x[-1], iter_y[-1], 'rD') + plt.xlim(-4.5, 4.5) + plt.ylim(-4.5, 4.5) + plt.plot(*minimum, 'X', color="green", markersize=11) plt.savefig('docs/rastrigin_{}.png'.format(optimizer_name)) + plt.close() -def plot_rosenbrok(grad_iter, optimizer_name, lr): +def plot_rosenbrok(grad_iters, optimizer_name, lr): x = np.linspace(-2, 2, 250) y = np.linspace(-1, 3, 250) minimum = (1.0, 1.0) + assert len(grad_iters) <= 3, "Cannot handle more than three states" X, Y = np.meshgrid(x, 
y) Z = rosenbrock([X, Y]) - - iter_x, iter_y = grad_iter[0, :], grad_iter[1, :] - + l = None fig = plt.figure(figsize=(8, 8)) - ax = fig.add_subplot(1, 1, 1) - ax.contour(X, Y, Z, 90, cmap='jet') - ax.plot(iter_x, iter_y, color='r', marker='x') - - ax.set_title( + ax.contour(X, Y, Z, 90, cmap='jet', alpha=0.75) + for grad_iter, color in zip(grad_iters, ['r', 'm', 'c']): + iter_x, iter_y = grad_iter[0, :], grad_iter[1, :] + if l is None: + l = len(iter_x) + ax.plot(iter_x, iter_y, color=color, marker=None, alpha=0.75) + + for px, py, pdx, pdy in zip( + iter_x[:-1], + iter_y[:-1], + iter_x[1:] - iter_x[:-1], + iter_y[1:] - iter_y[:-1], + ): + ax.arrow( + x=px, + y=py, + dx=pdx, + dy=pdy, + overhang=0.5, + width=0.001, + head_width=0.0375, + length_includes_head=True, + color=color, + visible=True, + ) + # Starting point + ax.plot( + iter_x[0], + iter_y[0], + marker="s", + markersize=11, + markeredgecolor="black", + markerfacecolor=color, + markeredgewidth=2, + ) + # Ending point + ax.plot( + iter_x[-1], + iter_y[-1], + marker="P", + markersize=11, + markeredgecolor="black", + markerfacecolor=color, + markeredgewidth=2, + ) + plt.title( 'Rosenbrock func: {} with {} ' - 'iterations, lr={:.6}'.format(optimizer_name, len(iter_x), lr) + 'iterations, lr={:.6}'.format(optimizer_name, l, lr) ) - plt.plot(*minimum, 'gD') - plt.plot(iter_x[-1], iter_y[-1], 'rD') + plt.plot(*minimum, 'X', color="green", markersize=11) + plt.xlim(-2, 2) + plt.ylim(-1, 3) plt.savefig('docs/rosenbrock_{}.png'.format(optimizer_name)) + plt.close() def execute_experiments( - optimizers, objective, func, plot_func, initial_state, seed=1 + optimizers, + objective, + func, + plot_func, + initial_states: List[Tuple[float, float]], + seed=1, ): seed = seed for item in optimizers: - optimizer_class, lr_low, lr_hi = item + optimizer_class, lr_low, lr_hi, kwargs, extra_desc = item + extra_desc_str = '' if not extra_desc else f'_{extra_desc}' space = { 'optimizer_class': hp.choice('optimizer_class', [optimizer_class]), 'lr': hp.loguniform('lr', lr_low, lr_hi), + 'kwargs': kwargs, } best = fmin( fn=objective, @@ -139,15 +236,21 @@ def execute_experiments( rstate=np.random.RandomState(seed), ) print(best['lr'], optimizer_class) - - steps = execute_steps( - func, - initial_state, - optimizer_class, - {'lr': best['lr']}, - num_iter=500, + steps_lst = [] + for initial_state in initial_states: + steps = execute_steps( + func, + initial_state, + optimizer_class, + {'lr': best['lr'], **kwargs}, + num_iter=NUM_ITER, + ) + steps_lst.append(steps) + plot_func( + steps_lst, + f'{optimizer_class.__name__}{extra_desc_str}', + best['lr'], ) - plot_func(steps, optimizer_class.__name__, best['lr']) def LookaheadYogi(*a, **kw): @@ -162,46 +265,78 @@ def LookaheadYogi(*a, **kw): # help to converge on better lr faster. 
optimizers = [ # baselines - (torch.optim.Adam, -8, 0.5), - (torch.optim.SGD, -8, -1.0), + (torch.optim.Adam, -8, 0.5, {}, None), + (torch.optim.SGD, -8, -1.0, {}, None), # Adam based - (optim.AdaBound, -8, 0.3), - (optim.Adahessian, -1, 8), - (optim.AdaMod, -8, 0.2), - (optim.AdamP, -8, 0.2), - (optim.DiffGrad, -8, 0.4), - (optim.Lamb, -8, -2.9), - (optim.MADGRAD, -8, 0.5), - (optim.NovoGrad, -8, -1.7), - (optim.RAdam, -8, 0.5), - (optim.Yogi, -8, 0.1), + (optim.Adam, -8, 0.5, {}, 'internal'), + ( + optim.Adam, + -8, + 0.5, + {'adamd_bias_correction': True}, + 'internal_adamD', + ), + (optim.AdamW, -8, 0.5, {}, 'internal'), + ( + optim.AdamW, + -8, + 0.5, + {'adamd_bias_correction': True}, + 'internal_adamD', + ), + (optim.AdaBound, -8, 0.3, {}, None), + (optim.AdaBound, -8, 0.3, {'adamd_bias_correction': True}, 'adamD'), + # TODO + (optim.Adahessian, -1, 8, {}, None), + (optim.Adahessian, -1, 8, {'adamd_bias_correction': True}, 'adamD'), + (optim.AdaMod, -8, 0.2, {}, None), + (optim.AdaMod, -8, 0.2, {'adamd_bias_correction': True}, 'adamD'), + (optim.AdamP, -8, 0.2, {}, None), + (optim.AdamP, -8, 0.2, {'adamd_bias_correction': True}, 'adamD'), + (optim.DiffGrad, -8, 0.4, {}, None), + (optim.DiffGrad, -8, 0.4, {'adamd_bias_correction': True}, 'adamD'), + (optim.Lamb, -8, -2.9, {}, None), + ( + optim.Lamb, + -8, + -2.9, + {'debias': True, 'adamd_bias_correction': True}, + 'adamD', + ), + (optim.MADGRAD, -8, 0.5, {}, None), + (optim.NovoGrad, -8, -1.7, {}, None), + (optim.RAdam, -8, 0.5, {}, None), + (optim.Yogi, -8, 0.1, {}, None), + (optim.Yogi, -8, 0.1, {'adamd_bias_correction': True}, 'adamD'), # SGD/Momentum based - (optim.AccSGD, -8, -1.4), - (optim.SGDW, -8, -1.5), - (optim.SGDP, -8, -1.5), - (optim.PID, -8, -1.0), - (optim.QHM, -6, -0.2), - (optim.QHAdam, -8, 0.1), - (optim.Ranger, -8, 0.1), - (optim.RangerQH, -8, 0.1), - (optim.RangerVA, -8, 0.1), - (optim.Shampoo, -8, 0.1), - (LookaheadYogi, -8, 0.1), - (optim.AggMo, -8, -1.5), - (optim.SWATS, -8, -1.5), - (optim.Adafactor, -8, 0.5), - (optim.A2GradUni, -8, 0.1), - (optim.A2GradInc, -8, 0.1), - (optim.A2GradExp, -8, 0.1), - (optim.AdaBelief, -8, 0.1), - (optim.Apollo, -8, 0.1), + (optim.AccSGD, -8, -1.4, {}, None), + (optim.SGDW, -8, -1.5, {}, None), + (optim.SGDP, -8, -1.5, {}, None), + (optim.PID, -8, -1.0, {}, None), + (optim.QHM, -6, -0.2, {}, None), + (optim.QHAdam, -8, 0.1, {}, None), + (optim.Ranger, -8, 0.1, {}, None), + (optim.RangerQH, -8, 0.1, {}, None), + (optim.RangerVA, -8, 0.1, {}, None), + (optim.Shampoo, -8, 0.1, {}, None), + (LookaheadYogi, -8, 0.1, {}, None), + (optim.AggMo, -8, -1.5, {}, None), + (optim.SWATS, -8, -1.5, {}, None), + (optim.SWATS, -8, -1.5, {'adamd_bias_correction': True}, 'adamD'), + (optim.Adafactor, -8, 0.5, {}, None), + (optim.A2GradUni, -8, 0.1, {}, None), + (optim.A2GradInc, -8, 0.1, {}, None), + (optim.A2GradExp, -8, 0.1, {}, None), + (optim.AdaBelief, -8, 0.1, {}, None), + (optim.AdaBelief, -8, 0.1, {'adamd_bias_correction': True}, 'adamD'), + (optim.Apollo, -8, 0.1, {}, None), ] execute_experiments( optimizers, objective_rastrigin, rastrigin, plot_rastrigin, - (-2.0, 3.5), + [(-2.0, 3.5), (1.0, -2.0)], ) execute_experiments( @@ -209,5 +344,5 @@ def LookaheadYogi(*a, **kw): objective_rosenbrok, rosenbrock, plot_rosenbrok, - (-2.0, 2.0), + [(-2.0, 2.0), (-0.5, 2.75)], ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_basic.py b/tests/test_basic.py index 89c43db..95b79b7 100644 --- a/tests/test_basic.py +++ 
b/tests/test_basic.py @@ -55,19 +55,70 @@ def build_lookahead(*a, **kw): (optim.RAdam, {'lr': 0.01, 'betas': (0.9, 0.95), 'eps': 1e-3}, 800), (optim.SGDW, {'lr': 0.002, 'momentum': 0.91}, 900), (optim.DiffGrad, {'lr': 0.5}, 500), + (optim.DiffGrad, {'lr': 0.5, 'adamd_bias_correction': True}, 500), (optim.AdaMod, {'lr': 1.0}, 800), + (optim.AdaMod, {'lr': 1.0, 'adamd_bias_correction': True}, 800), (optim.AdaBound, {'lr': 1.0}, 800), + (optim.AdaBound, {'lr': 1.0, 'adamd_bias_correction': True}, 800), (optim.Yogi, {'lr': 1.0}, 500), + (optim.Yogi, {'lr': 1.0, 'adamd_bias_correction': True}, 500), (optim.AccSGD, {'lr': 0.015}, 800), (build_lookahead, {'lr': 1.0}, 500), (optim.QHAdam, {'lr': 1.0}, 500), + # FIXME find params that work for Lamb + # (optim.Lamb, {'lr': 0.1, 'betas': (0.9, 0.95)}, 900), + # (optim.Lamb, {'lr': 0.1, 'betas': (0.9, 0.95), 'debias': True, + # 'adamd_bias_correction': True}, 900), + (optim.Adam, {'lr': 0.01, 'betas': (0.9, 0.95)}, 900), + (optim.AdamW, {'lr': 0.01, 'betas': (0.9, 0.95)}, 900), + ( + optim.Adam, + {'lr': 0.01, 'betas': (0.9, 0.95), 'adamd_bias_correction': True}, + 900, + ), + ( + optim.AdamW, + {'lr': 0.01, 'betas': (0.9, 0.95), 'adamd_bias_correction': True}, + 900, + ), (optim.AdamP, {'lr': 0.01, 'betas': (0.9, 0.95), 'eps': 1e-3}, 800), + ( + optim.AdamP, + { + 'lr': 0.01, + 'betas': (0.9, 0.95), + 'eps': 1e-3, + 'adamd_bias_correction': True, + }, + 800, + ), (optim.SGDP, {'lr': 0.002, 'momentum': 0.91}, 900), (optim.AggMo, {'lr': 0.003}, 1800), (optim.SWATS, {'lr': 0.1, 'amsgrad': True, 'nesterov': True}, 900), + ( + optim.SWATS, + { + 'lr': 0.2, + 'amsgrad': True, + 'nesterov': True, + 'adamd_bias_correction': True, + }, + 900, + ), (optim.Adafactor, {'lr': None, 'decay_rate': -0.3, 'beta1': 0.9}, 800), (optim.AdaBelief, {'lr': 1.0}, 500), + (optim.AdaBelief, {'lr': 1.0, 'adamd_bias_correction': True}, 500), (optim.Adahessian, {'lr': 0.15, 'hessian_power': 0.6, 'seed': 0}, 900), + ( + optim.Adahessian, + { + 'lr': 0.15, + 'hessian_power': 0.6, + 'seed': 0, + 'adamd_bias_correction': True, + }, + 900, + ), (optim.MADGRAD, {'lr': 0.02}, 500), (optim.LARS, {'lr': 0.002, 'momentum': 0.91}, 900), ] diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index db547d3..bbc2268 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -74,6 +74,8 @@ def build_lookahead(*a, **kw): optim.AdaBound, optim.AdaMod, optim.Adafactor, + optim.Adam, + optim.AdamW, optim.AdamP, optim.AggMo, optim.Apollo, diff --git a/tests/test_optimizer_with_nn.py b/tests/test_optimizer_with_nn.py index 08d6247..77844cf 100644 --- a/tests/test_optimizer_with_nn.py +++ b/tests/test_optimizer_with_nn.py @@ -57,19 +57,71 @@ def build_lookahead(*a, **kw): (optim.A2GradUni, {'lips': 5.0, 'beta': 1e-3}, 500), (optim.AccSGD, {'lr': 1.0, 'weight_decay': 1e-3}, 200), (optim.AdaBelief, {'lr': 0.1, 'weight_decay': 1e-3}, 200), + ( + optim.AdaBelief, + {'lr': 0.1, 'weight_decay': 1e-3, 'adamd_bias_correction': True}, + 200, + ), (optim.AdaBound, {'lr': 1.5, 'gamma': 0.1, 'weight_decay': 1e-3}, 200), + ( + optim.AdaBound, + { + 'lr': 1.5, + 'gamma': 0.1, + 'weight_decay': 1e-3, + 'adamd_bias_correction': True, + }, + 200, + ), (optim.AdaMod, {'lr': 2.0, 'weight_decay': 1e-3}, 200), + ( + optim.AdaMod, + {'lr': 2.0, 'weight_decay': 1e-3, 'adamd_bias_correction': True}, + 200, + ), (optim.Adafactor, {'lr': 0.004466, 'weight_decay': 1e-3}, 1500), + (optim.Adam, {'lr': 0.045, 'weight_decay': 1e-3}, 800), + ( + optim.Adam, + {'lr': 0.045, 'weight_decay': 1e-3, 
'adamd_bias_correction': True}, + 800, + ), + (optim.AdamW, {'lr': 0.045, 'weight_decay': 1e-3}, 800), + ( + optim.AdamW, + {'lr': 0.045, 'weight_decay': 1e-3, 'adamd_bias_correction': True}, + 800, + ), (optim.AdamP, {'lr': 0.045, 'weight_decay': 1e-3}, 800), + ( + optim.AdamP, + {'lr': 0.045, 'weight_decay': 1e-3, 'adamd_bias_correction': True}, + 800, + ), (optim.AggMo, {'lr': 0.17059, 'weight_decay': 1e-3}, 1000), (optim.Apollo, {'lr': 0.1, 'weight_decay': 1e-3}, 200), (optim.DiffGrad, {'lr': 0.5, 'weight_decay': 1e-3}, 200), + ( + optim.DiffGrad, + {'lr': 0.5, 'weight_decay': 1e-3, 'adamd_bias_correction': True}, + 200, + ), ( optim.LARS, {'lr': 1.0, 'weight_decay': 1e-3, 'trust_coefficient': 0.01}, 200, ), (optim.Lamb, {'lr': 0.0151, 'weight_decay': 1e-3}, 1000), + ( + optim.Lamb, + { + 'lr': 0.0151, + 'weight_decay': 1e-3, + 'debias': True, + 'adamd_bias_correction': True, + }, + 1000, + ), (optim.MADGRAD, {'lr': 1.0, 'weight_decay': 1e-3}, 200), (optim.NovoGrad, {'lr': 0.01, 'weight_decay': 1e-3}, 200), (optim.PID, {'lr': 0.01, 'weight_decay': 1e-3, 'momentum': 0.1}, 200), @@ -82,13 +134,28 @@ def build_lookahead(*a, **kw): (optim.SGDP, {'lr': 1.0, 'weight_decay': 1e-3}, 200), (optim.SGDW, {'lr': 1.0, 'weight_decay': 1e-3}, 200), (optim.SWATS, {'lr': 0.703, 'weight_decay': 1e-3}, 600), + ( + optim.SWATS, + {'lr': 0.703, 'weight_decay': 1e-3, 'adamd_bias_correction': True}, + 600, + ), ( optim.Shampoo, {'lr': 0.279, 'weight_decay': 1e-3, 'momentum': 0.05}, 1600, ), (optim.Yogi, {'lr': 0.1, 'weight_decay': 1e-3}, 200), + ( + optim.Yogi, + {'lr': 0.1, 'weight_decay': 1e-3, 'adamd_bias_correction': True}, + 200, + ), (optim.Adahessian, {'lr': 0.1, 'weight_decay': 1e-3}, 200), + ( + optim.Adahessian, + {'lr': 0.1, 'weight_decay': 1e-3, 'adamd_bias_correction': True}, + 200, + ), ] diff --git a/torch_optimizer/__init__.py b/torch_optimizer/__init__.py index dc97b28..bd42f11 100644 --- a/torch_optimizer/__init__.py +++ b/torch_optimizer/__init__.py @@ -25,6 +25,7 @@ from .adabound import AdaBound from .adafactor import Adafactor from .adahessian import Adahessian +from .adam import Adam, AdamW from .adamod import AdaMod from .adamp import AdamP from .aggmo import AggMo @@ -55,6 +56,8 @@ 'AdaMod', 'Adafactor', 'Adahessian', + 'Adam', + 'AdamW', 'AdamP', 'AggMo', 'Apollo', @@ -87,6 +90,8 @@ AccSGD, AdaBound, AdaMod, + Adam, + AdamW, AdamP, AggMo, DiffGrad, diff --git a/torch_optimizer/adabelief.py b/torch_optimizer/adabelief.py index a65d76f..c5837a2 100644 --- a/torch_optimizer/adabelief.py +++ b/torch_optimizer/adabelief.py @@ -37,6 +37,10 @@ class AdaBelief(Optimizer): rate (lr). 
(default: False) rectify: (default: False) If set as True, then perform the rectified update similar to RAdam + adamd_bias_correction: When performing bias correction (debias=True), + only correct the denominator to avoid inflating step sizes early + in training as suggested in `AdamD: Improved bias-correction in + Adam`__ (default: False) Example: >>> import torch_optimizer as optim @@ -46,6 +50,7 @@ class AdaBelief(Optimizer): >>> optimizer.step() __ https://arxiv.org/abs/2010.07468 + __ https://arxiv.org/abs/2110.10828 Note: Reference code: https://github.com/juntang-zhuang/Adabelief-Optimizer @@ -62,6 +67,7 @@ def __init__( weight_decouple: bool = False, fixed_decay: bool = False, rectify: bool = False, + adamd_bias_correction: bool = False, ) -> None: if lr <= 0.0: raise ValueError('Invalid learning rate: {}'.format(lr)) @@ -85,6 +91,7 @@ def __init__( eps=eps, weight_decay=weight_decay, amsgrad=amsgrad, + adamd_bias_correction=adamd_bias_correction, ) super(AdaBelief, self).__init__(params, defaults) @@ -96,6 +103,7 @@ def __setstate__(self, state): super(AdaBelief, self).__setstate__(state) for group in self.param_groups: group.setdefault('amsgrad', False) + group.setdefault('adamd_bias_correction', False) def step(self, closure: OptLossClosure = None) -> OptFloat: r"""Performs a single optimization step. @@ -187,7 +195,10 @@ def step(self, closure: OptLossClosure = None) -> OptFloat: if not self._rectify: # Default update - step_size = group['lr'] / bias_correction1 + if group['adamd_bias_correction']: + step_size = group['lr'] + else: + step_size = group['lr'] / bias_correction1 p.data.addcdiv_(exp_avg, denom, value=-step_size) else: # Rectified update @@ -209,8 +220,10 @@ def step(self, closure: OptLossClosure = None) -> OptFloat: / rho_t ) rt = math.sqrt(rt) - - step_size = rt * group['lr'] / bias_correction1 + if group['adamd_bias_correction']: + step_size = rt * group['lr'] + else: + step_size = rt * group['lr'] / bias_correction1 p.data.addcdiv_(-step_size, exp_avg, denom) diff --git a/torch_optimizer/adabound.py b/torch_optimizer/adabound.py index 067c343..4ce1689 100644 --- a/torch_optimizer/adabound.py +++ b/torch_optimizer/adabound.py @@ -27,6 +27,10 @@ class AdaBound(Optimizer): (default: 1e-8) weight_decay: weight decay (L2 penalty) (default: 0) amsbound: whether to use the AMSBound variant of this algorithm + adamd_bias_correction: When performing bias correction (debias=True), + only correct the denominator to avoid inflating step sizes early + in training as suggested in `AdamD: Improved bias-correction in + Adam`__ (default: False) Example: >>> import torch_optimizer as optim @@ -36,6 +40,7 @@ class AdaBound(Optimizer): >>> optimizer.step() __ https://arxiv.org/abs/1902.09843 + __ https://arxiv.org/abs/2110.10828 Note: Reference code: https://github.com/Luolc/AdaBound @@ -51,6 +56,7 @@ def __init__( eps: float = 1e-8, weight_decay: float = 0, amsbound: bool = False, + adamd_bias_correction: bool = False, ) -> None: if lr <= 0.0: raise ValueError('Invalid learning rate: {}'.format(lr)) @@ -82,6 +88,7 @@ def __init__( eps=eps, weight_decay=weight_decay, amsbound=amsbound, + adamd_bias_correction=adamd_bias_correction, ) super(AdaBound, self).__init__(params, defaults) self.base_lrs = [group['lr'] for group in self.param_groups] @@ -90,6 +97,7 @@ def __setstate__(self, state: State) -> None: super(AdaBound, self).__setstate__(state) for group in self.param_groups: group.setdefault('amsbound', False) + group.setdefault('adamd_bias_correction', False) def step(self, 
closure: OptLossClosure = None) -> OptFloat: r"""Performs a single optimization step. @@ -158,11 +166,15 @@ def step(self, closure: OptLossClosure = None) -> OptFloat: bias_correction1 = 1 - beta1 ** state['step'] bias_correction2 = 1 - beta2 ** state['step'] - step_size = ( - group['lr'] - * math.sqrt(bias_correction2) - / bias_correction1 - ) + + if group['adamd_bias_correction']: + step_size = group['lr'] * math.sqrt(bias_correction2) + else: + step_size = ( + group['lr'] + * math.sqrt(bias_correction2) + / bias_correction1 + ) # Applies bounds on actual learning rate # lr_scheduler cannot affect final_lr, this is a workaround diff --git a/torch_optimizer/adahessian.py b/torch_optimizer/adahessian.py index ca37329..671c6cb 100644 --- a/torch_optimizer/adahessian.py +++ b/torch_optimizer/adahessian.py @@ -27,6 +27,10 @@ class Adahessian(Optimizer): weight_decay (float, optional): weight decay (L2 penalty) (default: 0) hessian_power (float, optional): Hessian power (default: 0.5) seed (int, optional): Random number generator seed (default: None) + adamd_bias_correction: When performing bias correction (debias=True), + only correct the denominator to avoid inflating step sizes early + in training as suggested in `AdamD: Improved bias-correction in + Adam`__ (default: False) Example: >>> import torch_optimizer as optim @@ -36,6 +40,7 @@ class Adahessian(Optimizer): >>> optimizer.step() __ https://arxiv.org/abs/2006.00719 + __ https://arxiv.org/abs/2110.10828 Note: Reference code: https://github.com/amirgholami/adahessian @@ -50,6 +55,7 @@ def __init__( weight_decay: float = 0, hessian_power: float = 0.5, seed: Optional[int] = None, + adamd_bias_correction: bool = False, ) -> None: if lr <= 0.0: raise ValueError('Invalid learning rate: {}'.format(lr)) @@ -75,9 +81,15 @@ def __init__( eps=eps, weight_decay=weight_decay, hessian_power=hessian_power, + adamd_bias_correction=adamd_bias_correction, ) super(Adahessian, self).__init__(params, defaults) + def __setstate__(self, state): + super(Adahessian, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('adamd_bias_correction', False) + def get_trace(self, params: Params, grads: Grads) -> List[torch.Tensor]: """Get an estimate of Hessian Trace. This is done by computing the Hessian vector product with a random @@ -197,9 +209,12 @@ def step(self, closure: OptLossClosure = None) -> OptFloat: ).add_(group['eps']) # make update - p.data = p.data - group['lr'] * ( - exp_avg / bias_correction1 / denom - + group['weight_decay'] * p.data + if group['adamd_bias_correction']: + step_size = group['lr'] + else: + step_size = group['lr'] / bias_correction1 + p.data = p.data - step_size * ( + exp_avg / denom + group['weight_decay'] * p.data ) return loss diff --git a/torch_optimizer/adam.py b/torch_optimizer/adam.py new file mode 100644 index 0000000..bef29c6 --- /dev/null +++ b/torch_optimizer/adam.py @@ -0,0 +1,500 @@ +import math +from typing import List + +import torch +from torch import Tensor +from torch.optim.optimizer import Optimizer + +from .types import Params + +__all__ = ('Adam', 'AdamW') + + +def adam( + params: Params, + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[int], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + adamd_bias_correction: bool +): + r"""Functional API that performs Adam algorithm computation. + See :class:`~torch.optim.Adam` for details. 
+ """ + + for i, param in enumerate(params): + + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step = state_steps[i] + + bias_correction1 = 1 - beta1 ** step + bias_correction2 = 1 - beta2 ** step + + if weight_decay != 0: + grad = grad.add(param, alpha=weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.maximum( + max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i] + ) + # Use the max. for normalizing running avg. of gradient + denom = ( + max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2) + ).add_(eps) + else: + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + if adamd_bias_correction: + step_size = lr + else: + step_size = lr / bias_correction1 + + param.addcdiv_(exp_avg, denom, value=-step_size) + + +def adamw( + params: Params, + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[int], + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + adamd_bias_correction: bool +): + r"""Functional API that performs AdamW algorithm computation. + See :class:`~torch.optim.AdamW` for details. + """ + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step = state_steps[i] + + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + + bias_correction1 = 1 - beta1 ** step + bias_correction2 = 1 - beta2 ** step + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.maximum( + max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i] + ) + # Use the max. for normalizing running avg. of gradient + denom = ( + max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2) + ).add_(eps) + else: + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + if adamd_bias_correction: + step_size = lr + else: + step_size = lr / bias_correction1 + param.addcdiv_(exp_avg, denom, value=-step_size) + + +class Adam(Optimizer): + r"""Implements Adam algorithm. + .. 
math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2 + \text{ (betas)},\theta_0 \text{ (params)},f(\theta) + \text{ (objective)} \\ + &\hspace{13mm} \lambda \text{ (weight decay)}, + \: amsgrad \\ + &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, + v_0\leftarrow 0 \text{ (second moment)},\: + \widehat{v_0}^{max}\leftarrow 0\\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow + \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + + (1 - \beta_1) g_t \\ + &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + + (1-\beta_2) g^2_t \\ + &\hspace{5mm}\widehat{m_t} \leftarrow + m_t/\big(1-\beta_1^t \big) \\ + &\hspace{5mm}\widehat{v_t} \leftarrow + v_t/\big(1-\beta_2^t \big) \\ + &\hspace{5mm}\textbf{if} \: amsgrad \\ + &\hspace{10mm}\widehat{v_t}^{max} \leftarrow + \mathrm{max}(\widehat{v_t}^{max}, + \widehat{v_t}) \\ + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - + \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - + \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + For further details regarding the algorithm we refer to `Adam: A Method for + Stochastic Optimization`_. + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + adamd_bias_correction (boolean, optional): When performing bias + correction, only correct the denominator to avoid inflating step + sizes early in training as suggested in `AdamD: Improved + bias-correction in Adam`_ (default: False) + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + .. 
_AdamD: Improved bias-correction in Adam: + https://arxiv.org/abs/2110.10828 + """ + + def __init__( + self, + params: Params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + adamd_bias_correction: bool = False, + ): + if not 0.0 <= lr: + raise ValueError('Invalid learning rate: {}'.format(lr)) + if not 0.0 <= eps: + raise ValueError('Invalid epsilon value: {}'.format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError( + 'Invalid beta parameter at index 0: {}'.format(betas[0]) + ) + if not 0.0 <= betas[1] < 1.0: + raise ValueError( + 'Invalid beta parameter at index 1: {}'.format(betas[1]) + ) + if not 0.0 <= weight_decay: + raise ValueError( + 'Invalid weight_decay value: {}'.format(weight_decay) + ) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + adamd_bias_correction=adamd_bias_correction, + ) + super(Adam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(Adam, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + group.setdefault('adamd_bias_correction', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is not None: + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError( + 'Adam does not support sparse gradients, ' + 'please consider SparseAdam instead' + ) + grads.append(p.grad) + + state = self.state[p] + # Lazy state initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + if group['amsgrad']: + # Maintains max of all exp. moving avg. of sq. + # grad. values + state['max_exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + + if group['amsgrad']: + max_exp_avg_sqs.append(state['max_exp_avg_sq']) + + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step']) + + adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group['amsgrad'], + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + adamd_bias_correction=group['adamd_bias_correction'], + ) + return loss + + +class AdamW(Optimizer): + r"""Implements AdamW algorithm. + .. 
math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2 + \text{(betas)}, \: \theta_0 \text{(params)}, \: + f(\theta) \text{(objective)}, + \: \epsilon \text{ (epsilon)} \\ + &\hspace{13mm} \lambda \text{(weight decay)}, \: amsgrad \\ + &\textbf{initialize} : m_0 \leftarrow 0 + \text{ (first moment)}, v_0 \leftarrow 0 + \text{ ( second moment)}, \: + \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow + \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - + \gamma \lambda \theta_{t-1} \\ + &\hspace{5mm}m_t \leftarrow + \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ + &\hspace{5mm}v_t \leftarrow + \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ + &\hspace{5mm}\widehat{m_t} \leftarrow + m_t/\big(1-\beta_1^t \big) \\ + &\hspace{5mm}\widehat{v_t} \leftarrow + v_t/\big(1-\beta_2^t \big) \\ + &\hspace{5mm}\textbf{if} \: amsgrad \\ + &\hspace{10mm}\widehat{v_t}^{max} \leftarrow + \mathrm{max}(\widehat{v_t}^{max}, + \widehat{v_t}) \\ + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - + \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + For further details regarding the algorithm we refer to `Decoupled Weight + Decay Regularization`_. + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient + (default: 1e-2) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + adamd_bias_correction (boolean, optional): When performing bias + correction, only correct the denominator to avoid inflating step + sizes early in training as suggested in `AdamD: Improved + bias-correction in Adam`_ (default: False) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + .. 
_AdamD: Improved bias-correction in Adam: + https://arxiv.org/abs/2110.10828 + """ + + def __init__( + self, + params: Params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + adamd_bias_correction: bool = False, + ): + if not 0.0 <= lr: + raise ValueError('Invalid learning rate: {}'.format(lr)) + if not 0.0 <= eps: + raise ValueError('Invalid epsilon value: {}'.format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError( + 'Invalid beta parameter at index 0: {}'.format(betas[0]) + ) + if not 0.0 <= betas[1] < 1.0: + raise ValueError( + 'Invalid beta parameter at index 1: {}'.format(betas[1]) + ) + if not 0.0 <= weight_decay: + raise ValueError( + 'Invalid weight_decay value: {}'.format(weight_decay) + ) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + adamd_bias_correction=adamd_bias_correction, + ) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + amsgrad = group['amsgrad'] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError( + 'AdamW does not support sparse gradients' + ) + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. + # grad. 
values + state['max_exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + + if amsgrad: + max_exp_avg_sqs.append(state['max_exp_avg_sq']) + + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step']) + + adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + adamd_bias_correction=group['adamd_bias_correction'], + ) + + return loss diff --git a/torch_optimizer/adamod.py b/torch_optimizer/adamod.py index 4aebe01..f2c1ac1 100644 --- a/torch_optimizer/adamod.py +++ b/torch_optimizer/adamod.py @@ -25,6 +25,10 @@ class AdaMod(Optimizer): eps: term added to the denominator to improve numerical stability (default: 1e-8) weight_decay: weight decay (L2 penalty) (default: 0) + adamd_bias_correction: When performing bias correction (debias=True), + only correct the denominator to avoid inflating step sizes early + in training as suggested in `AdamD: Improved bias-correction in + Adam`__ (default: False) Example: >>> import torch_optimizer as optim @@ -34,6 +38,7 @@ class AdaMod(Optimizer): >>> optimizer.step() __ https://arxiv.org/abs/1910.12249 + __ https://arxiv.org/abs/2110.10828 Note: Reference code: https://github.com/lancopku/AdaMod @@ -47,6 +52,7 @@ def __init__( beta3: float = 0.999, eps: float = 1e-8, weight_decay: float = 0, + adamd_bias_correction: bool = False, ) -> None: if lr <= 0.0: raise ValueError('Invalid learning rate: {}'.format(lr)) @@ -67,10 +73,20 @@ def __init__( 'Invalid weight_decay value: {}'.format(weight_decay) ) defaults = dict( - lr=lr, betas=betas, beta3=beta3, eps=eps, weight_decay=weight_decay + lr=lr, + betas=betas, + beta3=beta3, + eps=eps, + weight_decay=weight_decay, + adamd_bias_correction=adamd_bias_correction, ) super(AdaMod, self).__init__(params, defaults) + def __setstate__(self, state): + super(AdaMod, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('adamd_bias_correction', False) + def step(self, closure: OptLossClosure = None) -> OptFloat: """Performs a single optimization step. 
@@ -125,11 +141,14 @@ def step(self, closure: OptLossClosure = None) -> OptFloat: bias_correction1 = 1 - beta1 ** state['step'] bias_correction2 = 1 - beta2 ** state['step'] - step_size = ( - group['lr'] - * math.sqrt(bias_correction2) - / bias_correction1 - ) + if group['adamd_bias_correction']: + step_size = group['lr'] * math.sqrt(bias_correction2) + else: + step_size = ( + group['lr'] + * math.sqrt(bias_correction2) + / bias_correction1 + ) if group['weight_decay'] != 0: p.data.add_( diff --git a/torch_optimizer/adamp.py b/torch_optimizer/adamp.py index 1d8f159..09b9052 100644 --- a/torch_optimizer/adamp.py +++ b/torch_optimizer/adamp.py @@ -28,6 +28,10 @@ class AdamP(Optimizer): wd_ratio: relative weight decay applied on scale-invariant parameters compared to that applied on scale-variant parameters (default: 0.1) nesterov: enables Nesterov momentum (default: False) + adamd_bias_correction: When performing bias correction (debias=True), + only correct the denominator to avoid inflating step sizes early + in training as suggested in `AdamD: Improved bias-correction in + Adam`__ (default: False) Example: @@ -38,6 +42,7 @@ class AdamP(Optimizer): >>> optimizer.step() __ https://arxiv.org/abs/2006.08217 + __ https://arxiv.org/abs/2110.10828 Note: Reference code: https://github.com/clovaai/AdamP @@ -53,6 +58,7 @@ def __init__( delta: float = 0.1, wd_ratio: float = 0.1, nesterov: bool = False, + adamd_bias_correction: bool = False, ) -> None: if lr <= 0.0: raise ValueError('Invalid learning rate: {}'.format(lr)) @@ -83,9 +89,16 @@ def __init__( delta=delta, wd_ratio=wd_ratio, nesterov=nesterov, + adamd_bias_correction=adamd_bias_correction, ) super(AdamP, self).__init__(params, defaults) + def __setstate__(self, state): + super(AdamP, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('nesterov', False) + group.setdefault('adamd_bias_correction', False) + @staticmethod def _channel_view(x): return x.view(x.size(0), -1) @@ -169,7 +182,10 @@ def step(self, closure: OptLossClosure = None) -> OptFloat: denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( group['eps'] ) - step_size = group['lr'] / bias_correction1 + if group['adamd_bias_correction']: + step_size = group['lr'] + else: + step_size = group['lr'] / bias_correction1 if nesterov: perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom diff --git a/torch_optimizer/diffgrad.py b/torch_optimizer/diffgrad.py index e812e4e..a989f8a 100644 --- a/torch_optimizer/diffgrad.py +++ b/torch_optimizer/diffgrad.py @@ -23,7 +23,10 @@ class DiffGrad(Optimizer): eps: term added to the denominator to improve numerical stability (default: 1e-8) weight_decay: weight decay (L2 penalty) (default: 0) - + adamd_bias_correction: When performing bias correction (debias=True), + only correct the denominator to avoid inflating step sizes early + in training as suggested in `AdamD: Improved bias-correction in + Adam`__ (default: False) Example: >>> import torch_optimizer as optim >>> optimizer = optim.DiffGrad(model.parameters(), lr=0.1) @@ -32,6 +35,7 @@ class DiffGrad(Optimizer): >>> optimizer.step() __ https://arxiv.org/abs/1909.11015 + __ https://arxiv.org/abs/2110.10828 Note: Reference code: https://github.com/shivram1987/diffGrad @@ -44,6 +48,7 @@ def __init__( betas: Betas2 = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.0, + adamd_bias_correction: bool = False, ) -> None: if lr <= 0.0: raise ValueError('Invalid learning rate: {}'.format(lr)) @@ -62,7 +67,13 @@ def __init__( 'Invalid weight_decay value: 
{}'.format(weight_decay) ) - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + adamd_bias_correction=adamd_bias_correction, + ) super(DiffGrad, self).__init__(params, defaults) def step(self, closure: OptLossClosure = None) -> OptFloat: @@ -133,12 +144,14 @@ def step(self, closure: OptLossClosure = None) -> OptFloat: # update momentum with dfc exp_avg1 = exp_avg * dfc - - step_size = ( - group['lr'] - * math.sqrt(bias_correction2) - / bias_correction1 - ) + if group['adamd_bias_correction']: + step_size = group['lr'] * math.sqrt(bias_correction2) + else: + step_size = ( + group['lr'] + * math.sqrt(bias_correction2) + / bias_correction1 + ) p.data.addcdiv_(exp_avg1, denom, value=-step_size) diff --git a/torch_optimizer/lamb.py b/torch_optimizer/lamb.py index 4ac5fbb..a169789 100644 --- a/torch_optimizer/lamb.py +++ b/torch_optimizer/lamb.py @@ -28,7 +28,10 @@ class Lamb(Optimizer): adam: always use trust ratio = 1, which turns this into Adam. Useful for comparison purposes. (default: False) debias: debias adam by (1 - beta**step) (default: False) - + adamd_bias_correction: When performing bias correction (debias=True), + only correct the denominator to avoid inflating step sizes early + in training as suggested in `AdamD: Improved bias-correction in + Adam`__ (default: False) Example: >>> import torch_optimizer as optim >>> optimizer = optim.Lamb(model.parameters(), lr=0.1) @@ -37,6 +40,7 @@ class Lamb(Optimizer): >>> optimizer.step() __ https://arxiv.org/abs/1904.00962 + __ https://arxiv.org/abs/2110.10828 Note: Reference code: https://github.com/cybertronai/pytorch-lamb @@ -52,6 +56,7 @@ def __init__( clamp_value: float = 10, adam: bool = False, debias: bool = False, + adamd_bias_correction: bool = False, ) -> None: if lr <= 0.0: raise ValueError('Invalid learning rate: {}'.format(lr)) @@ -72,13 +77,24 @@ def __init__( if clamp_value < 0.0: raise ValueError('Invalid clamp value: {}'.format(clamp_value)) - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + adamd_bias_correction=adamd_bias_correction, + ) self.clamp_value = clamp_value self.adam = adam self.debias = debias super(Lamb, self).__init__(params, defaults) + def __setstate__(self, state): + super(Lamb, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('adamd_bias_correction', False) + def step(self, closure: OptLossClosure = None) -> OptFloat: r"""Performs a single optimization step. @@ -129,7 +145,8 @@ def step(self, closure: OptLossClosure = None) -> OptFloat: # Paper v3 does not use debiasing. 
if self.debias: bias_correction = math.sqrt(1 - beta2 ** state['step']) - bias_correction /= 1 - beta1 ** state['step'] + if not group['adamd_bias_correction']: + bias_correction /= 1 - beta1 ** state['step'] else: bias_correction = 1 diff --git a/torch_optimizer/swats.py b/torch_optimizer/swats.py index b383c42..67c6325 100644 --- a/torch_optimizer/swats.py +++ b/torch_optimizer/swats.py @@ -24,7 +24,10 @@ class SWATS(Optimizer): algorithm from the paper `On the Convergence of Adam and Beyond` (default: False) nesterov: enables Nesterov momentum (default: False) - + adamd_bias_correction: When performing bias correction (debias=True), + only correct the denominator to avoid inflating step sizes early + in training as suggested in `AdamD: Improved bias-correction in + Adam`__ (default: False) Example: >>> import torch_optimizer as optim @@ -34,6 +37,7 @@ class SWATS(Optimizer): >>> optimizer.step() __ https://arxiv.org/pdf/1712.07628.pdf + __ https://arxiv.org/abs/2110.10828 Note: Reference code: https://github.com/Mrpatekful/swats @@ -48,6 +52,7 @@ def __init__( weight_decay: float = 0, amsgrad: bool = False, nesterov: bool = False, + adamd_bias_correction: bool = False, ): if not 0.0 <= lr: raise ValueError('Invalid learning rate: {}'.format(lr)) @@ -73,6 +78,7 @@ def __init__( weight_decay=weight_decay, amsgrad=amsgrad, nesterov=nesterov, + adamd_bias_correction=adamd_bias_correction, ) super().__init__(params, defaults) @@ -82,6 +88,7 @@ def __setstate__(self, state: State) -> None: for group in self.param_groups: group.setdefault('amsgrad', False) group.setdefault('nesterov', False) + group.setdefault('adamd_bias_correction', False) def step(self, closure: OptLossClosure = None) -> OptFloat: r"""Performs a single optimization step. @@ -176,9 +183,15 @@ def step(self, closure: OptLossClosure = None) -> OptFloat: bias_correction1 = 1 - beta1 ** state['step'] bias_correction2 = 1 - beta2 ** state['step'] - step_size = ( - group['lr'] * (bias_correction2 ** 0.5) / bias_correction1 - ) + + if group['adamd_bias_correction']: + step_size = group['lr'] * (bias_correction2 ** 0.5) + else: + step_size = ( + group['lr'] + * (bias_correction2 ** 0.5) + / bias_correction1 + ) p = -step_size * (exp_avg / denom) w.data.add_(p) diff --git a/torch_optimizer/yogi.py b/torch_optimizer/yogi.py index 29d14ae..a0a4794 100644 --- a/torch_optimizer/yogi.py +++ b/torch_optimizer/yogi.py @@ -24,6 +24,10 @@ class Yogi(Optimizer): initial_accumulator: initial values for first and second moments (default: 1e-6) weight_decay: weight decay (L2 penalty) (default: 0) + adamd_bias_correction: When performing bias correction (debias=True), + only correct the denominator to avoid inflating step sizes early + in training as suggested in `AdamD: Improved bias-correction in + Adam`__ (default: False) Example: >>> import torch_optimizer as optim @@ -33,6 +37,7 @@ class Yogi(Optimizer): >>> optimizer.step() __ https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization # noqa + __ https://arxiv.org/abs/2110.10828 Note: Reference code: https://github.com/4rtemi5/Yogi-Optimizer_Keras @@ -46,6 +51,7 @@ def __init__( eps: float = 1e-3, initial_accumulator: float = 1e-6, weight_decay: float = 0, + adamd_bias_correction: bool = False, ) -> None: if lr <= 0.0: raise ValueError('Invalid learning rate: {}'.format(lr)) @@ -70,9 +76,15 @@ def __init__( eps=eps, initial_accumulator=initial_accumulator, weight_decay=weight_decay, + adamd_bias_correction=adamd_bias_correction, ) super(Yogi, self).__init__(params, 
defaults)
 
+    def __setstate__(self, state):
+        super(Yogi, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('adamd_bias_correction', False)
+
     def step(self, closure: OptLossClosure = None) -> OptFloat:
         r"""Performs a single optimization step.
 
@@ -142,7 +154,10 @@ def step(self, closure: OptLossClosure = None) -> OptFloat:
                 denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
                     group['eps']
                 )
-                step_size = group['lr'] / bias_correction1
+                if group['adamd_bias_correction']:
+                    step_size = group['lr']
+                else:
+                    step_size = group['lr'] / bias_correction1
 
                 p.data.addcdiv_(exp_avg, denom, value=-step_size)
 
         return loss
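
Usage note (not part of the patch itself): the hunks above thread an ``adamd_bias_correction`` flag through the Adam-family optimizers so that only the second-moment (denominator) bias correction is applied, i.e. the effective step size becomes ``lr * sqrt(1 - beta2**t)`` rather than ``lr * sqrt(1 - beta2**t) / (1 - beta1**t)``, avoiding inflated updates early in training. The sketch below shows how the flag might be enabled on the patched ``torch_optimizer.Adam``; the toy model, random data, and the ``lr``/``betas`` values are placeholders chosen for illustration, while ``adamd_bias_correction`` and the ``optim.Adam`` constructor follow the signatures introduced in this diff.

.. code:: python

    # Minimal sketch, assuming the patched torch_optimizer is installed.
    # The model and data below are placeholders for illustration only.
    import torch
    import torch_optimizer as optim

    model = torch.nn.Linear(10, 1)

    # With adamd_bias_correction=True the 1 / (1 - beta1**t) numerator
    # correction is skipped; only the denominator correction remains.
    optimizer = optim.Adam(
        model.parameters(),
        lr=1e-3,
        betas=(0.9, 0.999),
        adamd_bias_correction=True,
    )

    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()

The same keyword is accepted by the other optimizers touched in this patch (AdaBelief, AdaBound, AdaMod, AdamP, Adahessian, DiffGrad, Lamb, SWATS, Yogi, and the added AdamW), as exercised in the updated tests.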