# Source code for geoopt.manifolds.stiefel

import torch
from typing import Union, Tuple, Optional
from .base import Manifold
from .. import linalg
from ..utils import size2shape
from ..tensor import ManifoldTensor

# Public manifold classes exported by this module.
__all__ = [
    "Stiefel",
    "EuclideanStiefel",
    "CanonicalStiefel",
    "EuclideanStiefelExact",
]

# Shared description of the Stiefel constraint, spliced into the class
# docstrings of this module via ``str.format``. It is indented by 4 spaces to
# match those docstrings. The equations under ``.. math::`` must be indented
# deeper than the directive, otherwise reStructuredText does not treat them as
# the directive's content.
_stiefel_doc = r"""
    Manifold induced by the following matrix constraint:

    .. math::

        X^\top X = I\\
        X \in \mathbb{R}^{n\times m}\\
        n \ge m
"""

class Stiefel(Manifold):
    __doc__ = r"""
    {}

    Parameters
    ----------
    canonical : bool
        Use canonical inner product instead of euclidean one (defaults to canonical)

    See Also
    --------
    :class:`CanonicalStiefel`, :class:`EuclideanStiefel`, :class:`EuclideanStiefelExact`
    """.format(
        _stiefel_doc
    )
    # points are matrices stored in the last two tensor dimensions
    ndim = 2

    def __new__(cls, canonical=True):
        # ``Stiefel(...)`` itself acts as a factory: it builds a CanonicalStiefel
        # or an EuclideanStiefel depending on ``canonical``. Subclasses are
        # constructed as themselves.
        if cls is Stiefel:
            if canonical:
                return super().__new__(CanonicalStiefel)
            else:
                return super().__new__(EuclideanStiefel)
        else:
            return super().__new__(cls)

    def _check_shape(
        self, shape: Tuple[int, ...], name: str
    ) -> Union[Tuple[bool, Optional[str]], bool]:
        """
        Validate that ``shape`` describes (batches of) tall-or-square matrices.

        Runs the base-class shape check first, then requires
        ``shape[-1] <= shape[-2]`` (no more columns than rows).

        Parameters
        ----------
        shape : tuple of int
            shape to validate
        name : str
            name used in the error message

        Returns
        -------
        tuple
            ``(True, None)`` if the shape is acceptable, otherwise
            ``(False, reason)``
        """
        ok, reason = super()._check_shape(shape, name)
        if not ok:
            return False, reason
        shape_is_ok = shape[-1] <= shape[-2]
        if not shape_is_ok:
            return (
                False,
                "{} should have shape[-1] <= shape[-2], got {} </= {}".format(
                    name, shape[-1], shape[-2]
                ),
            )
        return True, None

    def _check_point_on_manifold(
        self, x: torch.Tensor, *, atol=1e-5, rtol=1e-5
    ) -> Union[Tuple[bool, Optional[str]], bool]:
        """
        Check the orthonormality constraint ``x^T x = I`` within ``atol``/``rtol``.

        Returns
        -------
        tuple
            ``(True, None)`` on success, otherwise ``(False, reason)``
        """
        xtx = x.transpose(-1, -2) @ x
        # subtract the identity in place on the diagonal instead of allocating an eye
        diag = torch.arange(x.shape[-1], device=x.device)
        xtx[..., diag, diag] -= 1
        ok = torch.allclose(xtx, xtx.new_zeros((1,)), atol=atol, rtol=rtol)
        if not ok:
            return False, "X^T X != I with atol={}, rtol={}".format(atol, rtol)
        return True, None

    def _check_vector_on_tangent(
        self, x: torch.Tensor, u: torch.Tensor, *, atol=1e-5, rtol=1e-5
    ) -> Union[Tuple[bool, Optional[str]], bool]:
        """
        Check the tangency condition ``u^T x + x^T u = 0`` within ``atol``/``rtol``.

        Returns
        -------
        tuple
            ``(True, None)`` on success, otherwise ``(False, reason)``
        """
        diff = u.transpose(-1, -2) @ x + x.transpose(-1, -2) @ u
        ok = torch.allclose(diff, diff.new_zeros((1,)), atol=atol, rtol=rtol)
        if not ok:
            return False, "u^T x + x^T u !=0 with atol={}, rtol={}".format(atol, rtol)
        return True, None

    def projx(self, x: torch.Tensor) -> torch.Tensor:
        """
        Project ``x`` onto the Stiefel manifold via its polar factor.

        Computes the thin SVD ``x = U S V^H`` and returns ``U V^H``, the
        closest matrix with orthonormal columns in Frobenius norm.

        Parameters
        ----------
        x : torch.Tensor
            point(s) to project, matrices in the last two dimensions

        Returns
        -------
        torch.Tensor
            projected point(s), same shape as ``x``
        """
        U, _, V = linalg.svd(x, full_matrices=False)
        # NOTE(review): the ``full_matrices`` keyword suggests ``linalg.svd``
        # mirrors ``torch.linalg.svd``, whose third output is already V^H, so
        # the polar factor is ``U @ V`` (no extra transpose) — confirm.
        return U @ V

    def random_naive(self, *size, dtype=None, device=None) -> torch.Tensor:
        """
        Naive approach to get random matrix on Stiefel manifold.

        A helper function to sample a random point on the Stiefel manifold.
        The measure is non-uniform for this method, but fast to compute.

        Parameters
        ----------
        size : shape
            the desired output shape
        dtype : torch.dtype
            desired dtype
        device : torch.device
            desired device

        Returns
        -------
        ManifoldTensor
            random point on Stiefel manifold
        """
        self._assert_check_shape(size2shape(*size), "x")
        tens = torch.randn(*size, device=device, dtype=dtype)
        # the Q factor of a Gaussian matrix has orthonormal columns
        return ManifoldTensor(linalg.qr(tens)[0], manifold=self)

    random = random_naive

    def origin(self, *size, dtype=None, device=None, seed=42) -> torch.Tensor:
        """
        Identity matrix point origin.

        Ones on the main diagonal of each trailing matrix, zeros elsewhere.

        Parameters
        ----------
        size : shape
            the desired shape
        device : torch.device
            the desired device
        dtype : torch.dtype
            the desired dtype
        seed : int
            ignored

        Returns
        -------
        ManifoldTensor
        """
        self._assert_check_shape(size2shape(*size), "x")
        eye = torch.zeros(*size, dtype=dtype, device=device)
        # shape[-1] <= shape[-2] is guaranteed by the shape check above
        diag = torch.arange(eye.shape[-1], device=eye.device)
        eye[..., diag, diag] += 1
        return ManifoldTensor(eye, manifold=self)

class CanonicalStiefel(Stiefel):
    __doc__ = r"""Stiefel Manifold with Canonical inner product

    {}
    """.format(
        _stiefel_doc
    )

    name = "Stiefel(canonical)"
    reversible = True

    @staticmethod
    def _amat(x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        # Skew-symmetric generator A = u x^T - x u^T (A^T = -A by construction),
        # used by the Cayley transform in ``_transp_follow_one``.
        return u @ x.transpose(-1, -2) - x @ u.transpose(-1, -2)

    def inner(
        self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor = None, *, keepdim=False
    ) -> torch.Tensor:
        """
        Canonical inner product ``<u, v>_x = tr(u^T (I - 1/2 x x^T) v)``.

        Parameters
        ----------
        x : torch.Tensor
            point on the manifold
        u : torch.Tensor
            tangent vector at ``x``
        v : torch.Tensor, optional
            second tangent vector; defaults to ``u`` (squared norm)
        keepdim : bool
            keep the reduced last two dimensions

        Returns
        -------
        torch.Tensor
            inner product reduced over the last two dimensions
        """
        # <u, v>_x = tr(u^T (I - 1/2 x x^T) v)
        #          = tr(u^T v) - 1/2 tr(u^T x x^T v)
        #          = sum_ij (u * v)_ij - 1/2 sum_ij (x^T u * x^T v)_ij
        xtu = x.transpose(-1, -2) @ u
        if v is None:
            xtv = xtu
            v = u
        else:
            xtv = x.transpose(-1, -2) @ v
        return (u * v).sum([-1, -2], keepdim=keepdim) - 0.5 * (xtv * xtu).sum(
            [-1, -2], keepdim=keepdim
        )

    def _transp_follow_one(
        self, x: torch.Tensor, v: torch.Tensor, *, u: torch.Tensor
    ) -> torch.Tensor:
        # Cayley transform: solve (I - A/2) q = (I + A/2) v with A = _amat(x, u)
        a = self._amat(x, u)
        rhs = v + 1 / 2 * a @ v
        lhs = -1 / 2 * a
        # add the identity on the diagonal of the square matrix ``a``
        diag = torch.arange(a.shape[-1], device=a.device)
        lhs[..., diag, diag] += 1
        qv = torch.linalg.solve(lhs, rhs)
        return qv

    def transp_follow_retr(
        self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor
    ) -> torch.Tensor:
        """
        Transport ``v`` along the Cayley retraction from ``x`` in direction ``u``.

        Applies the same Cayley transform used by :meth:`retr` to ``v``.
        """
        return self._transp_follow_one(x, v, u=u)

    transp_follow_expmap = transp_follow_retr

    def retr_transp(
        self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Retract ``x`` along ``u`` and transport ``v`` with a single linear solve.

        ``x`` and ``v`` are stacked column-wise so that one Cayley transform
        maps both. The result is split back into ``(new_x, new_v)``.
        """
        xvs = torch.cat((x, v), -1)
        qxvs = self._transp_follow_one(x, xvs, u=u)
        # (..., n, 2p) -> (..., n, 2, p). Using the solver output's own batch
        # shape also supports batch dims that broadcast from ``u``.
        qxvs = qxvs.reshape(qxvs.shape[:-1] + (2, x.shape[-1]))
        new_x, new_v = qxvs.unbind(-2)
        return new_x, new_v

    expmap_transp = retr_transp

    def proju(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """
        Map ``u`` to the tangent space at ``x`` as ``u - x u^T x``.

        For ``x`` with ``x^T x = I`` the result ``w`` satisfies
        ``x^T w + w^T x = 0``.
        """
        return u - x @ u.transpose(-1, -2) @ x

    def retr(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """
        Cayley retraction ``(I - A/2)^{-1} (I + A/2) x`` with ``A = u x^T - x u^T``.
        """
        return self._transp_follow_one(x, x, u=u)

    expmap = retr

class EuclideanStiefel(Stiefel):
    __doc__ = r"""Stiefel Manifold with Euclidean inner product

    {}
    """.format(
        _stiefel_doc
    )

    name = "Stiefel(euclidean)"
    reversible = False

    def proju(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """
        Project ``u`` onto the tangent space at ``x`` as ``u - x sym(x^T u)``.
        """
        # NOTE(review): assumes ``linalg.sym`` returns the symmetric part
        # (A + A^T) / 2 — confirm against geoopt.linalg.
        return u - x @ linalg.sym(x.transpose(-1, -2) @ u)

    def transp(self, x: torch.Tensor, y: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """
        Vector transport by projecting ``v`` onto the tangent space at ``y``.
        """
        return self.proju(y, v)

    def inner(
        self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor = None, *, keepdim=False
    ) -> torch.Tensor:
        """
        Euclidean (Frobenius) inner product ``sum(u * v)`` over the last two dims.

        ``v`` defaults to ``u``, which gives the squared norm. ``x`` is unused.
        """
        if v is None:
            v = u
        return (u * v).sum([-1, -2], keepdim=keepdim)

    def retr(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """
        QR-based retraction: the Q factor of ``x + u`` with a fixed column sign.

        QR is unique only up to the signs of the columns of Q (and rows of R).
        Each column of Q is multiplied by the sign of the matching diagonal
        entry of R, so the result varies continuously with ``x + u``.
        """
        q, r = linalg.qr(x + u)
        # sign(diag(r)) with zeros mapped to +1: (+-1, 0) -> (+-0.5 or 1.5, 0.5)
        # -> sign -> (+-1, +1)
        unflip = torch.diagonal(r, dim1=-2, dim2=-1).sign().add(0.5).sign()
        # out-of-place so the tensor returned by ``linalg.qr`` is not mutated
        # (an in-place update could conflict with autograd if it saved that output)
        q = q * unflip[..., None, :]
        return q

    def expmap(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """
        Exact exponential map of the embedded (Euclidean) metric.

        With ``A = x^T u`` and ``S = u^T u``, this computes
        ``[x, u] expm([[A, -S], [I, A]]) [expm(-A); 0]``, the closed-form
        geodesic of Edelman, Arias & Smith (1998).
        """
        xtu = x.transpose(-1, -2) @ u
        utu = u.transpose(-1, -2) @ u
        eye = torch.zeros_like(utu)
        diag = torch.arange(utu.shape[-2], device=utu.device)
        eye[..., diag, diag] += 1
        logw = linalg.block_matrix(((xtu, -utu), (eye, xtu)))
        w = linalg.expm(logw)
        z = torch.cat((linalg.expm(-xtu), torch.zeros_like(utu)), dim=-2)
        y = torch.cat((x, u), dim=-1) @ w @ z
        return y

class EuclideanStiefelExact(EuclideanStiefel):
    __doc__ = r"""{}

    Notes
    -----
    The retraction is implemented as the exact exponential map; this exact map
    is the one used during optimization.
    """.format(
        EuclideanStiefel.__doc__
    )

    # Route every retraction entry point to its exponential-map counterpart.
    # NOTE(review): ``expmap_transp`` and ``transp_follow_expmap`` are not
    # defined on EuclideanStiefel here, so they presumably come from the
    # ``Manifold`` base class — confirm.
    retr_transp = EuclideanStiefel.expmap_transp
    transp_follow_retr = EuclideanStiefel.transp_follow_expmap
    retr = EuclideanStiefel.expmap

    def extra_repr(self):
        # marks this variant in the manifold's repr
        return "exact"