Source code for geoopt.manifolds.stiefel

import torch
from typing import Union, Tuple, Optional
from .base import Manifold
from .. import linalg
from ..utils import size2shape
from ..tensor import ManifoldTensor

# Names exported by ``from <this module> import *``.
__all__ = [
    "Stiefel",
    "EuclideanStiefel",
    "CanonicalStiefel",
    "EuclideanStiefelExact",
]

# Shared docstring fragment describing the Stiefel constraint. The manifold
# classes in this module splice it into their ``__doc__`` with ``str.format``.
_stiefel_doc = r"""
    Manifold induced by the following matrix constraint:

    .. math::

        X^\top X = I\\
        X \in \mathrm{R}^{n\times m}\\
        n \ge m
"""

class Stiefel(Manifold):
    __doc__ = r"""
    {}

    Parameters
    ----------
    canonical : bool
        Use canonical inner product instead of euclidean one (defaults to canonical)

    See Also
    --------
    :class:`CanonicalStiefel`, :class:`EuclideanStiefel`, :class:`EuclideanStiefelExact`
    """.format(
        _stiefel_doc
    )
    ndim = 2

    def __new__(cls, canonical=True):
        # Instantiating the generic ``Stiefel`` dispatches on ``canonical`` to
        # one of the two concrete metrics; a subclass instantiated directly is
        # created as itself.
        if cls is Stiefel:
            if canonical:
                return super().__new__(CanonicalStiefel)
            else:
                return super().__new__(EuclideanStiefel)
        else:
            return super().__new__(cls)

    def _check_shape(
        self, shape: Tuple[int], name: str
    ) -> Union[Tuple[bool, Optional[str]], bool]:
        """
        Check that ``shape`` is admissible for a point on the manifold.

        Runs the parent class check first, then requires the last dimension
        not to exceed the second to last one (``n >= m``).

        Parameters
        ----------
        shape : Tuple[int]
            shape to validate
        name : str
            name used for the offending argument in the failure message

        Returns
        -------
        Tuple[bool, Optional[str]]
            ``(True, None)`` on success, ``(False, reason)`` otherwise
        """
        ok, reason = super()._check_shape(shape, name)
        if not ok:
            return False, reason
        shape_is_ok = shape[-1] <= shape[-2]
        if not shape_is_ok:
            return (
                False,
                "`{}` should have shape[-1] <= shape[-2], got {} </= {}".format(
                    name, shape[-1], shape[-2]
                ),
            )
        return True, None

    def _check_point_on_manifold(
        self, x: torch.Tensor, *, atol=1e-5, rtol=1e-5
    ) -> Union[Tuple[bool, Optional[str]], bool]:
        """
        Check the constraint ``X^T X = I`` up to the given tolerances.

        Parameters
        ----------
        x : torch.Tensor
            candidate point, matrices in the last two dimensions
        atol : float
            absolute tolerance passed to :func:`torch.allclose`
        rtol : float
            relative tolerance passed to :func:`torch.allclose`

        Returns
        -------
        Tuple[bool, Optional[str]]
            ``(True, None)`` on success, ``(False, reason)`` otherwise
        """
        xtx = x.transpose(-1, -2) @ x
        # Subtract the identity in place on the diagonal instead of allocating
        # a full identity matrix: less memory usage.
        diag = torch.arange(x.shape[-1])
        xtx[..., diag, diag] -= 1
        # Compare against a single broadcast zero rather than a full zero matrix.
        ok = torch.allclose(xtx, xtx.new_zeros(1), atol=atol, rtol=rtol)
        if not ok:
            return False, "`X^T X != I` with atol={}, rtol={}".format(atol, rtol)
        return True, None

    def _check_vector_on_tangent(
        self, x: torch.Tensor, u: torch.Tensor, *, atol=1e-5, rtol=1e-5
    ) -> Union[Tuple[bool, Optional[str]], bool]:
        """
        Check the tangency condition ``u^T x + x^T u = 0`` up to tolerances.

        Parameters
        ----------
        x : torch.Tensor
            point on the manifold
        u : torch.Tensor
            candidate tangent vector at ``x``
        atol : float
            absolute tolerance passed to :func:`torch.allclose`
        rtol : float
            relative tolerance passed to :func:`torch.allclose`

        Returns
        -------
        Tuple[bool, Optional[str]]
            ``(True, None)`` on success, ``(False, reason)`` otherwise
        """
        diff = u.transpose(-1, -2) @ x + x.transpose(-1, -2) @ u
        # Compare against a single broadcast zero rather than a full zero matrix.
        ok = torch.allclose(diff, diff.new_zeros(1), atol=atol, rtol=rtol)
        if not ok:
            return False, "`u^T x + x^T u !=0` with atol={}, rtol={}".format(atol, rtol)
        return True, None

    def projx(self, x: torch.Tensor) -> torch.Tensor:
        """
        Map ``x`` to the manifold using its reduced SVD.

        The singular values are discarded and the first and third factors
        returned by ``linalg.svd`` are multiplied together.
        """
        U, _, V = linalg.svd(x, full_matrices=False)
        return torch.einsum("...ik,...kj->...ij", U, V)

    def random_naive(self, *size, dtype=None, device=None) -> torch.Tensor:
        """
        Naive approach to get random matrix on Stiefel manifold.

        A helper function to sample a random point on the Stiefel manifold.
        The measure is non-uniform for this method, but fast to compute.

        Parameters
        ----------
        size : shape
            the desired output shape
        dtype : torch.dtype
            desired dtype
        device : torch.device
            desired device

        Returns
        -------
        ManifoldTensor
            random point on Stiefel manifold
        """
        self._assert_check_shape(size2shape(*size), "x")
        tens = torch.randn(*size, device=device, dtype=dtype)
        # Keep only the first factor of the QR decomposition of a Gaussian sample.
        return ManifoldTensor(linalg.qr(tens)[0], manifold=self)

    random = random_naive

    def origin(self, *size, dtype=None, device=None, seed=42) -> torch.Tensor:
        """
        Identity matrix point origin.

        Parameters
        ----------
        size : shape
            the desired shape
        device : torch.device
            the desired device
        dtype : torch.dtype
            the desired dtype
        seed : int
            ignored

        Returns
        -------
        ManifoldTensor
        """
        self._assert_check_shape(size2shape(*size), "x")
        eye = torch.zeros(*size, dtype=dtype, device=device)
        # Put ones on the main diagonal of every (possibly rectangular) matrix.
        diag = torch.arange(eye.shape[-1])
        eye[..., diag, diag] += 1
        return ManifoldTensor(eye, manifold=self)
class CanonicalStiefel(Stiefel):
    __doc__ = r"""Stiefel Manifold with Canonical inner product

    {}
    """.format(
        _stiefel_doc
    )

    name = "Stiefel(canonical)"
    reversible = True

    @staticmethod
    def _amat(x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Return the skew-symmetric matrix ``u x^T - x u^T``."""
        return u @ x.transpose(-1, -2) - x @ u.transpose(-1, -2)

    def inner(
        self,
        x: torch.Tensor,
        u: torch.Tensor,
        v: Optional[torch.Tensor] = None,
        *,
        keepdim=False
    ) -> torch.Tensor:
        """
        Canonical inner product of ``u`` and ``v`` at ``x``.

        Parameters
        ----------
        x : torch.Tensor
            point on the manifold
        u : torch.Tensor
            first tangent vector
        v : Optional[torch.Tensor]
            second tangent vector, defaults to ``u``
        keepdim : bool
            keep the two reduced matrix dimensions in the result

        Returns
        -------
        torch.Tensor
            inner product, reduced over the last two dimensions
        """
        # <u, v>_x = tr(u^T (I - 1/2 x x^T) v)
        #          = tr(u^T v) - 1/2 tr(u^T x x^T v)
        #          = tr(u^T v) - 1/2 tr(x^T v u^T x)
        #          = \sum_ij (u * v)_ij - 1/2 \sum_ij (x^T v * x^T u)_ij
        xtu = x.transpose(-1, -2) @ u
        if v is None:
            # Reuse x^T u instead of computing the same product twice.
            xtv = xtu
            v = u
        else:
            xtv = x.transpose(-1, -2) @ v
        return (u * v).sum([-1, -2], keepdim=keepdim) - 0.5 * (xtv * xtu).sum(
            [-1, -2], keepdim=keepdim
        )

    def _transp_follow_one(
        self, x: torch.Tensor, v: torch.Tensor, *, u: torch.Tensor
    ) -> torch.Tensor:
        """
        Solve ``(I - A/2) q = (I + A/2) v`` with ``A = u x^T - x u^T``.

        This single linear solve backs the retraction (``v = x``) as well as
        the vector transport along ``u`` from ``x``.
        """
        a = self._amat(x, u)
        half_a = 0.5 * a
        rhs = v + half_a @ v
        lhs = -half_a
        # ``a`` is square, so one index vector addresses its whole diagonal.
        diag = torch.arange(a.shape[-2])
        lhs[..., diag, diag] += 1
        qv = torch.linalg.solve(lhs, rhs)
        return qv

    def transp_follow_retr(
        self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor
    ) -> torch.Tensor:
        """Transport ``v`` from ``x`` along the direction ``u``."""
        return self._transp_follow_one(x, v, u=u)

    transp_follow_expmap = transp_follow_retr

    def retr_transp(
        self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Retract ``x`` along ``u`` and transport ``v`` in one linear solve.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            the new point and the transported vector
        """
        # Stack x and v column-wise so both share one right-hand side.
        xvs = torch.cat((x, v), -1)
        # Split the trailing dimension of the solution back into its two halves.
        qxvs = self._transp_follow_one(x, xvs, u=u).view(
            x.shape[:-1] + (2, x.shape[-1])
        )
        new_x, new_v = qxvs.unbind(-2)
        return new_x, new_v

    expmap_transp = retr_transp

    def proju(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Return ``u - x u^T x``."""
        return u - x @ u.transpose(-1, -2) @ x

    egrad2rgrad = proju

    def retr(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Retract ``x`` along ``u`` by applying the transport solve to ``x`` itself."""
        return self._transp_follow_one(x, x, u=u)

    expmap = retr
class EuclideanStiefel(Stiefel):
    __doc__ = r"""Stiefel Manifold with Euclidean inner product

    {}
    """.format(
        _stiefel_doc
    )

    name = "Stiefel(euclidean)"
    reversible = False

    def proju(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Return ``u - x sym(x^T u)``, with ``sym`` taken from ``linalg``."""
        return u - x @ linalg.sym(x.transpose(-1, -2) @ u)

    egrad2rgrad = proju

    def transp(self, x: torch.Tensor, y: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Transport ``v`` to ``y`` by projecting it there; ``x`` is unused."""
        return self.proju(y, v)

    def inner(
        self,
        x: torch.Tensor,
        u: torch.Tensor,
        v: Optional[torch.Tensor] = None,
        *,
        keepdim=False
    ) -> torch.Tensor:
        """
        Euclidean (Frobenius) inner product of ``u`` and ``v``.

        Parameters
        ----------
        x : torch.Tensor
            point on the manifold (unused by this metric)
        u : torch.Tensor
            first tangent vector
        v : Optional[torch.Tensor]
            second tangent vector, defaults to ``u``
        keepdim : bool
            keep the two reduced matrix dimensions in the result

        Returns
        -------
        torch.Tensor
            inner product, reduced over the last two dimensions
        """
        if v is None:
            v = u
        return (u * v).sum([-1, -2], keepdim=keepdim)

    def retr(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """QR-based retraction of ``x + u`` with a fixed column sign convention."""
        q, r = linalg.qr(x + u)
        # sign() yields -1/0/+1; adding 0.5 and taking sign() again maps 0 to +1,
        # so every column of q is multiplied by +1 or -1 and never by zero.
        unflip = linalg.extract_diag(r).sign().add(0.5).sign()
        q *= unflip[..., None, :]
        return q

    def expmap(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """
        Exponential map for the Euclidean metric.

        Evaluates ``[x, u] expm([[A, -S], [I, A]]) [expm(-A); 0]`` with
        ``A = x^T u`` and ``S = u^T u``.
        """
        xtu = x.transpose(-1, -2) @ u
        utu = u.transpose(-1, -2) @ u
        eye = torch.zeros_like(utu)
        diag = torch.arange(utu.shape[-2])
        eye[..., diag, diag] += 1
        logw = linalg.block_matrix(((xtu, -utu), (eye, xtu)))
        w = linalg.expm(logw)
        # Selects the first block column of w and applies the trailing
        # expm(-x^T u) factor of the geodesic formula.
        z = torch.cat((linalg.expm(-xtu), torch.zeros_like(utu)), dim=-2)
        y = torch.cat((x, u), dim=-1) @ w @ z
        return y
class EuclideanStiefelExact(EuclideanStiefel):
    __doc__ = r"""{}

    Notes
    -----
    The implementation of retraction is an exact exponential map, this retraction will be used in optimization
    """.format(
        EuclideanStiefel.__doc__
    )

    # Route every retraction-based entry point to its exponential-map
    # counterpart on the parent class.
    retr_transp = EuclideanStiefel.expmap_transp
    transp_follow_retr = EuclideanStiefel.transp_follow_expmap
    retr = EuclideanStiefel.expmap

    def extra_repr(self):
        """Return the constant tag ``"exact"``."""
        return "exact"