# Source code for geoopt.samplers.rsgld
import math
import torch
from geoopt.tensor import ManifoldParameter, ManifoldTensor
from geoopt.samplers.base import Sampler
# Public names exported by this module.
__all__ = ["RSGLD"]
class RSGLD(Sampler):
    r"""
    Riemannian Stochastic Gradient Langevin Dynamics.

    On every :meth:`step` each parameter :math:`p` with gradient :math:`g`
    is moved along

    .. math::

        r = \operatorname{egrad2rgrad}\left(p,\ \tfrac{\epsilon}{2} g + n\right),
        \qquad n \sim \mathcal{N}(0, \epsilon I),

    and replaced by ``retr(p, r)`` of its manifold.

    Parameters
    ----------
    params : iterable
        iterables of tensors for which to perform sampling
    epsilon : float
        step size
    """

    def __init__(self, params, epsilon=1e-3):
        defaults = dict(epsilon=epsilon)
        super().__init__(params, defaults)

    def step(self, closure):
        """
        Perform a single sampling step, updating all parameters in place.

        Parameters
        ----------
        closure : callable
            Called with no arguments. Must return a scalar tensor (treated as
            the log-probability) whose ``backward()`` fills in the ``.grad``
            of the sampled parameters.

        Notes
        -----
        Parameters whose ``.grad`` is ``None`` after the backward pass (i.e.
        the closure did not depend on them) are left untouched.
        """
        logp = closure()
        logp.backward()

        with torch.no_grad():
            for group in self.param_groups:
                # Constant within a group, so look up / compute only once.
                epsilon = group["epsilon"]
                noise_scale = math.sqrt(epsilon)

                for p in group["params"]:
                    if p.grad is None:
                        # Nothing to follow for this tensor; skip it instead of
                        # failing on ``float * None`` below.
                        continue

                    if isinstance(p, (ManifoldParameter, ManifoldTensor)):
                        manifold = p.manifold
                    else:
                        # Plain tensors use the sampler-wide fallback manifold
                        # (expected to be supplied by the base class).
                        manifold = self._default_manifold

                    # Gaussian noise with variance ``epsilon``.
                    noise = torch.randn_like(p).mul_(noise_scale)
                    direction = manifold.egrad2rgrad(
                        p, 0.5 * epsilon * p.grad + noise
                    )
                    # use copy only for user facing point
                    p.copy_(manifold.retr(p, direction))
                    # Reset so the next backward pass does not accumulate.
                    p.grad.zero_()

        # Statistics are only recorded once ``self.burnin`` is falsy.
        if not self.burnin:
            self.steps += 1
            self.log_probs.append(logp.item())

    @torch.no_grad()
    def stabilize_group(self, group):
        """
        Overwrite each manifold tensor in ``group`` with ``manifold.projx`` of itself.

        Tensors that are neither ``ManifoldParameter`` nor ``ManifoldTensor``
        are left unchanged.

        Parameters
        ----------
        group : dict
            A parameter group; only its ``"params"`` entry is read.
        """
        for p in group["params"]:
            if isinstance(p, (ManifoldParameter, ManifoldTensor)):
                p.copy_(p.manifold.projx(p))