import torch.optim

from .mixin import OptimMixin
from ..tensor import ManifoldParameter, ManifoldTensor
from ..utils import copy_or_set_

r"""
Riemannian Adam with the same API as :class:`torch.optim.Adam`.

Parameters
----------
params : iterable
iterable of parameters to optimize or dicts defining
parameter groups
lr : float (optional)
learning rate (default: 1e-3)
betas : Tuple[float, float] (optional)
coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps : float (optional)
term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay : float (optional)
weight decay (L2 penalty) (default: 0)
amsgrad : bool (optional)
whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)

Other Parameters
----------------
stabilize : int
Stabilize parameters if they are off-manifold due to numerical
reasons every stabilize steps (default: None -- no stabilize)

.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ

"""

def step(self, closure=None):
    """Perform a single Riemannian Adam optimization step.

    Parameters
    ----------
    closure : callable, optional
        A closure that reevaluates the model and returns the loss.

    Returns
    -------
    loss : Tensor or None
        The loss returned by ``closure``, if one was given.

    Raises
    ------
    RuntimeError
        If a parameter has a sparse gradient (not supported).
    """
    loss = None
    if closure is not None:
        loss = closure()
    # the update mutates parameters in place; keep autograd out of it
    with torch.no_grad():
        for group in self.param_groups:
            if "step" not in group:
                group["step"] = 0
            betas = group["betas"]
            weight_decay = group["weight_decay"]
            eps = group["eps"]
            learning_rate = group["lr"]
            amsgrad = group["amsgrad"]
            # one shared step counter per group, used for bias correction
            group["step"] += 1
            for point in group["params"]:
                grad = point.grad
                if grad is None:
                    # parameter did not participate in the backward pass
                    continue
                if isinstance(point, (ManifoldParameter, ManifoldTensor)):
                    manifold = point.manifold
                else:
                    # plain tensors are treated as living on the default
                    # (Euclidean) manifold of the optimizer
                    manifold = self._default_manifold

                if grad.is_sparse:
                    raise RuntimeError(
                        "RiemannianAdam does not support sparse gradients"
                    )

                state = self.state[point]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(point)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(point)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(point)
                # make local variables for easy access
                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]
                # actual step: apply weight decay, then move the Euclidean
                # gradient onto the manifold's tangent space
                grad = grad.add(point, alpha=weight_decay)
                grad = manifold.egrad2rgrad(point, grad)
                exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0])
                # second moment uses the Riemannian inner product, not grad**2
                exp_avg_sq.mul_(betas[1]).add_(
                    manifold.component_inner(point, grad), alpha=1 - betas[1]
                )
                if amsgrad:
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(eps)
                else:
                    denom = exp_avg_sq.sqrt().add_(eps)
                # betas is a (beta1, beta2) tuple -- index it, don't
                # exponentiate the tuple itself
                bias_correction1 = 1 - betas[0] ** group["step"]
                bias_correction2 = 1 - betas[1] ** group["step"]
                # fold both bias corrections into the scalar step size
                step_size = (
                    learning_rate * bias_correction2 ** 0.5 / bias_correction1
                )

                # get the direction for descent
                direction = exp_avg / denom
                # retract onto the manifold and transport the exponential
                # averaging to the new point in one fused call
                new_point, exp_avg_new = manifold.retr_transp(
                    point, -step_size * direction, exp_avg
                )
                # use copy only for user facing point
                copy_or_set_(point, new_point)
                exp_avg.set_(exp_avg_new)

            if (
                self._stabilize is not None
                and group["step"] % self._stabilize == 0
            ):
                # periodically project parameters back onto the manifold
                # to counter accumulated numerical drift
                self.stabilize_group(group)
    return loss