# Source code for geoopt.samplers.rhmc

import math

import numpy as np
import torch

from geoopt.tensor import ManifoldParameter, ManifoldTensor
from geoopt.samplers.base import Sampler
from ..utils import copy_or_set_

__all__ = ["RHMC"]


class RHMC(Sampler):
    r"""Riemannian Hamiltonian Monte-Carlo.

    Parameters
    ----------
    params : iterable
        iterables of tensors for which to perform sampling
    epsilon : float
        leapfrog step size
    n_steps : int
        number of leapfrog steps per transition
    """

    def __init__(self, params, epsilon=1e-3, n_steps=1):
        defaults = dict(epsilon=epsilon)
        super().__init__(params, defaults)
        self.n_steps = n_steps

    def _step(self, p, r, epsilon):
        """Perform a single leapfrog move of parameter ``p`` with momentum ``r``.

        Both ``p`` and ``r`` are updated in place.

        Parameters
        ----------
        p : torch.Tensor
            parameter tensor; ``p.grad`` must hold the (Euclidean) gradient
            of the log-probability
        r : torch.Tensor
            momentum tensor living in the tangent space at ``p``
        epsilon : float
            leapfrog step size
        """
        if isinstance(p, (ManifoldParameter, ManifoldTensor)):
            manifold = p.manifold
        else:
            # plain tensors are treated as points on the default manifold
            manifold = self._default_manifold

        egrad2rgrad = manifold.egrad2rgrad
        retr_transp = manifold.retr_transp

        # momentum update with the Riemannian gradient of the log-probability
        r.add_(epsilon * egrad2rgrad(p, p.grad))
        # retract along the momentum direction and parallel-transport the
        # momentum to the new point
        p_, r_ = retr_transp(p, r * epsilon, r)
        copy_or_set_(p, p_)
        r.set_(r_)
[docs] def step(self, closure): logp = closure() logp.backward() old_logp = logp.item() old_H = -old_logp with torch.no_grad(): for group in self.param_groups: for p in group["params"]: if p.grad is None: continue if isinstance(p, (ManifoldParameter, ManifoldTensor)): manifold = p.manifold else: manifold = self._default_manifold egrad2rgrad = manifold.egrad2rgrad state = self.state[p] if "r" not in state: state["old_p"] = torch.zeros_like(p) state["old_r"] = torch.zeros_like(p) state["r"] = torch.zeros_like(p) r = state["r"] r.normal_() r.set_(egrad2rgrad(p, r)) old_H += 0.5 * (r * r).sum().item() state["old_p"].copy_(p) state["old_r"].copy_(r) epsilon = group["epsilon"] self._step(p, r, epsilon) p.grad.zero_() for i in range(1, self.n_steps): logp = closure() logp.backward() with torch.no_grad(): for group in self.param_groups: for p in group["params"]: if p.grad is None: continue self._step(p, self.state[p]["r"], group["epsilon"]) p.grad.zero_() logp = closure() logp.backward() new_logp = logp.item() new_H = -new_logp with torch.no_grad(): for group in self.param_groups: for p in group["params"]: if p.grad is None: continue if isinstance(p, (ManifoldParameter, ManifoldTensor)): manifold = p.manifold else: manifold = self._default_manifold egrad2rgrad = manifold.egrad2rgrad r = self.state[p]["r"] r.add_(0.5 * epsilon * egrad2rgrad(p, p.grad)) p.grad.zero_() new_H += 0.5 * (r * r).sum().item() rho = min(1.0, math.exp(old_H - new_H)) if not self.burnin: self.steps += 1 self.acceptance_probs.append(rho) if np.random.rand(1) >= rho: # reject if not self.burnin: self.n_rejected += 1 for group in self.param_groups: for p in group["params"]: if p.grad is None: continue state = self.state[p] r = state["r"] p.copy_(state["old_p"]) r.copy_(state["old_r"]) self.log_probs.append(old_logp) else: self.log_probs.append(new_logp)
@torch.no_grad() def stabilize_group(self, group): for p in group["params"]: if not isinstance(p, (ManifoldParameter, ManifoldTensor)): continue copy_or_set_(p, p.manifold.projx(p)) state = self.state[p] if not state: # due to None grads continue copy_or_set_(state["old_p"], p.manifold.projx(state["old_p"]))