# Source code for geoopt.optim.rsgd

import torch.optim.optimizer
from ..tensor import ManifoldParameter, ManifoldTensor
from .mixin import OptimMixin
from ..utils import copy_or_set_

__all__ = ["RiemannianSGD"]

class RiemannianSGD(OptimMixin, torch.optim.Optimizer):
    r"""
    Riemannian Stochastic Gradient Descent with the same API as :class:`torch.optim.SGD`.

    Parameters
    ----------
    params : iterable
        iterable of parameters to optimize or dicts defining
        parameter groups
    lr : float
        learning rate
    momentum : float (optional)
        momentum factor (default: 0)
    weight_decay : float (optional)
        weight decay (L2 penalty) (default: 0)
    dampening : float (optional)
        dampening for momentum (default: 0)
    nesterov : bool (optional)
        enables Nesterov momentum (default: False)

    Other Parameters
    ----------------
    stabilize : int
        Stabilize parameters if they are off-manifold due to numerical
        reasons every ``stabilize`` steps (default: ``None`` -- no stabilize)

    Notes
    -----
    Like :class:`torch.optim.SGD`, the optimizer does not modify the
    user-visible ``.grad`` tensors; weight decay and Nesterov corrections
    are applied to temporaries.
    """

    def __init__(
        self,
        params,
        lr,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        stabilize=None,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super().__init__(params, defaults, stabilize=stabilize)

    def step(self, closure=None):
        """Perform a single Riemannian SGD step over every parameter group.

        Parameters
        ----------
        closure : callable (optional)
            A closure that re-evaluates the model and returns the loss.
            It is called before the update, outside of ``torch.no_grad``.

        Returns
        -------
        The value returned by ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            loss = closure()
        with torch.no_grad():
            for group in self.param_groups:
                if "step" not in group:
                    group["step"] = 0
                weight_decay = group["weight_decay"]
                momentum = group["momentum"]
                dampening = group["dampening"]
                nesterov = group["nesterov"]
                learning_rate = group["lr"]
                for point in group["params"]:
                    grad = point.grad
                    if grad is None:
                        continue
                    state = self.state[point]
                    if isinstance(point, (ManifoldParameter, ManifoldTensor)):
                        manifold = point.manifold
                    else:
                        manifold = self._default_manifold

                    # Out-of-place so the caller's ``point.grad`` is left
                    # untouched; skipping the zero case also avoids
                    # ``0 * inf = nan`` leaking into the gradient.
                    if weight_decay != 0:
                        grad = grad.add(point, alpha=weight_decay)
                    grad = manifold.egrad2rgrad(point, grad)

                    if momentum > 0:
                        if "momentum_buffer" not in state:
                            # Seed with the *Riemannian* gradient (after weight
                            # decay and egrad2rgrad), as torch.optim.SGD does on
                            # its first step. ``clone`` prevents aliasing
                            # ``point.grad`` when egrad2rgrad returns its input.
                            momentum_buffer = grad.clone()
                            state["momentum_buffer"] = momentum_buffer
                        else:
                            momentum_buffer = state["momentum_buffer"]
                            momentum_buffer.mul_(momentum).add_(
                                grad, alpha=1 - dampening
                            )
                        if nesterov:
                            grad = grad.add(momentum_buffer, alpha=momentum)
                        else:
                            grad = momentum_buffer
                        # the direction and the buffer are both tangent at
                        # ``point``: retract along the direction and transport
                        # the buffer to the new point in one call
                        new_point, new_momentum_buffer = manifold.retr_transp(
                            point, -learning_rate * grad, momentum_buffer
                        )
                        momentum_buffer.set_(new_momentum_buffer)
                        # use copy only for user facing point
                        copy_or_set_(point, new_point)
                    else:
                        new_point = manifold.retr(point, -learning_rate * grad)
                        copy_or_set_(point, new_point)
                group["step"] += 1
                if self._stabilize is not None and group["step"] % self._stabilize == 0:
                    self.stabilize_group(group)
        return loss

    @torch.no_grad()
    def stabilize_group(self, group):
        """Project a group's manifold parameters back onto their manifold.

        Each ``ManifoldParameter``/``ManifoldTensor`` in ``group`` is replaced
        by ``manifold.projx(p)``. When the group uses momentum, an existing
        momentum buffer is re-projected onto the tangent space at the new
        point with ``manifold.proju``. Plain tensors are skipped.
        """
        momentum = group["momentum"]
        for p in group["params"]:
            if not isinstance(p, (ManifoldParameter, ManifoldTensor)):
                continue
            manifold = p.manifold
            copy_or_set_(p, manifold.projx(p))
            if momentum <= 0:
                continue
            # the state may be empty (or lack a buffer) for parameters
            # that never received a gradient
            buf = self.state[p].get("momentum_buffer")
            if buf is not None:
                buf.set_(manifold.proju(p, buf))