Source code for geoopt.optim.radam
import torch.optim
from .mixin import OptimMixin
from ..tensor import ManifoldParameter, ManifoldTensor
__all__ = ["RiemannianAdam"]
class RiemannianAdam(OptimMixin, torch.optim.Adam):
r"""
    Riemannian Adam with the same API as :class:`torch.optim.Adam`.

    Parameters
----------
params : iterable
iterable of parameters to optimize or dicts defining
parameter groups
lr : float (optional)
learning rate (default: 1e-3)
betas : Tuple[float, float] (optional)
coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps : float (optional)
term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay : float (optional)
weight decay (L2 penalty) (default: 0)
amsgrad : bool (optional)
whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
        (default: False)

    Other Parameters
----------------
stabilize : int
Stabilize parameters if they are off-manifold due to numerical
reasons every ``stabilize`` steps (default: ``None`` -- no stabilize)
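
    Examples
    --------
    A minimal usage sketch; the sphere manifold and the quadratic objective
    below are illustrative choices, not anything this optimizer requires:

    >>> import torch
    >>> import geoopt
    >>> sphere = geoopt.Sphere()
    >>> start = sphere.projx(torch.randn(10))
    >>> param = geoopt.ManifoldParameter(start, manifold=sphere)
    >>> optim = RiemannianAdam([param], lr=1e-2, stabilize=10)
    >>> target = sphere.projx(torch.ones(10))
    >>> optim.zero_grad()
    >>> loss = ((param - target) ** 2).sum()
    >>> loss.backward()
    >>> optim.step()
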
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
[docs] def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
with torch.no_grad():
for group in self.param_groups:
if "step" not in group:
group["step"] = 0
betas = group["betas"]
weight_decay = group["weight_decay"]
eps = group["eps"]
learning_rate = group["lr"]
amsgrad = group["amsgrad"]
group["step"] += 1
for point in group["params"]:
grad = point.grad
if grad is None:
continue
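                    # use the parameter's own manifold when available; plain tensors
                    # fall back to the optimizer's default manifold (Euclidean by default)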
if isinstance(point, (ManifoldParameter, ManifoldTensor)):
manifold = point.manifold
else:
manifold = self._default_manifold
if grad.is_sparse:
raise RuntimeError(
"RiemannianAdam does not support sparse gradients, use SparseRiemannianAdam instead"
)
state = self.state[point]
# State initialization
if len(state) == 0:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(point)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(point)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(point)
# make local variables for easy access
exp_avg = state["exp_avg"]
exp_avg_sq = state["exp_avg_sq"]
# actual step
grad.add_(point, alpha=weight_decay)
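                    # convert the Euclidean gradient into a Riemannian gradient
                    # living in the tangent space at ``point``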
grad = manifold.egrad2rgrad(point, grad)
exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0])
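                    # the second moment accumulates the component-wise manifold inner
                    # product of the gradient with itself, taking the place of
                    # ``grad ** 2`` in Euclidean Adam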
exp_avg_sq.mul_(betas[1]).add_(
manifold.component_inner(point, grad), alpha=1 - betas[1]
)
bias_correction1 = 1 - betas[0] ** group["step"]
bias_correction2 = 1 - betas[1] ** group["step"]
if amsgrad:
max_exp_avg_sq = state["max_exp_avg_sq"]
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = max_exp_avg_sq.div(bias_correction2).sqrt_()
else:
denom = exp_avg_sq.div(bias_correction2).sqrt_()
                    # get the (ascent) direction; the update moves against it
                    direction = exp_avg.div(bias_correction1) / denom.add_(eps)
                    # retract the point along the negative direction and transport the
                    # first-moment buffer to the new point
new_point, exp_avg_new = manifold.retr_transp(
point, -learning_rate * direction, exp_avg
)
                    # write back in place so the user-facing parameter and its state
                    # buffers keep their tensor identity
point.copy_(new_point)
exp_avg.copy_(exp_avg_new)
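                # periodically re-project parameters and buffers that may have drifted
                # off the manifold due to numerical error (only when ``stabilize`` is set)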
if (
group["stabilize"] is not None
and group["step"] % group["stabilize"] == 0
):
self.stabilize_group(group)
        return loss

    @torch.no_grad()
def stabilize_group(self, group):
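        """Project parameters back onto their manifolds and momentum buffers onto the tangent spaces."""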
for p in group["params"]:
if not isinstance(p, (ManifoldParameter, ManifoldTensor)):
continue
state = self.state[p]
if not state: # due to None grads
continue
manifold = p.manifold
exp_avg = state["exp_avg"]
p.copy_(manifold.projx(p))
exp_avg.copy_(manifold.proju(p, exp_avg))