from typing import Optional, Tuple, Union
import torch
from geoopt import linalg as lalg
from geoopt.utils import COMPLEX_DTYPES
from .siegel import SiegelManifold
from .vvd_metrics import SiegelMetricType
from ..siegel import csym_math as sm
# Public names exported by ``from <module> import *``.
__all__ = ["UpperHalf"]
class UpperHalf(SiegelManifold):
    r"""
    Upper Half Space Manifold.

    This model generalizes the upper half plane model of the hyperbolic plane.
    Points in the space are complex symmetric matrices.

    .. math::

        \mathcal{S}_n = \{Z = X + iY \in \operatorname{Sym}(n, \mathbb{C}) | Y >> 0 \}.

    Parameters
    ----------
    metric: SiegelMetricType
        one of Riemannian, Finsler One, Finsler Infinity, Finsler metric of minimum entropy,
        or learnable weighted sum.
    rank: int
        Rank of the space. Only mandatory for "fmin" and "wsum" metrics.
    """

    name = "Upper Half Space"

    def __init__(
        self,
        metric: SiegelMetricType = SiegelMetricType.RIEMANNIAN,
        rank: Optional[int] = None,
    ):
        super().__init__(metric=metric, rank=rank)

    def egrad2rgrad(self, z: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        r"""
        Transform gradient computed using autodiff to the correct Riemannian gradient for the point :math:`Z`.

        For a function :math:`f(Z)` on :math:`\mathcal{S}_n`, the gradient is:

        .. math::

            \operatorname{grad}_{R}(f(Z)) = Y \cdot \operatorname{grad}_E(f(Z)) \cdot Y

        where :math:`Y` is the imaginary part of :math:`Z`.

        Parameters
        ----------
        z : torch.Tensor
            point on the manifold
        u : torch.Tensor
            gradient to be projected

        Returns
        -------
        torch.Tensor
            Riemannian gradient
        """
        y = z.imag
        # Y is real, so the congruence Y * G * Y is applied to the real and
        # imaginary parts of the Euclidean gradient separately.
        real_grad = y @ u.real @ y
        imag_grad = y @ u.imag @ y
        # Impose symmetry to counter numerical instabilities.
        return lalg.sym(sm.to_complex(real_grad, imag_grad))

    def projx(self, z: torch.Tensor) -> torch.Tensor:
        """
        Project point :math:`Z` on the manifold.

        In this space, we need to ensure that :math:`Y = Im(Z)` is positive definite.
        Since the matrix Y is symmetric, it is possible to diagonalize it.
        For a diagonal matrix the condition is just that all diagonal entries are positive,
        so we clamp the values that are <= 0 in the diagonal to an epsilon, and then restore
        the matrix back into non-diagonal form using the base change matrix that was obtained
        from the diagonalization.

        Parameters
        ----------
        z : torch.Tensor
            point on the manifold

        Returns
        -------
        torch.Tensor
            Projected points
        """
        # The parent projection runs first; only the imaginary part is then
        # replaced by its positive projection, the real part is kept as is.
        z = super().projx(z)
        y = sm.positive_conjugate_projection(z.imag)
        return sm.to_complex(z.real, y)

    def inner(
        self,
        z: torch.Tensor,
        u: torch.Tensor,
        v: Optional[torch.Tensor] = None,
        *,
        keepdim: bool = False,
    ) -> torch.Tensor:
        r"""
        Inner product for tangent vectors at point :math:`Z`.

        The inner product at point :math:`Z = X + iY` of the vectors :math:`U, V` is:

        .. math::

            g_{Z}(U, V) = \operatorname{Tr}[ Y^{-1} U Y^{-1} \overline{V} ]

        Parameters
        ----------
        z : torch.Tensor
            point on the manifold
        u : torch.Tensor
            tangent vector at point :math:`z`
        v : torch.Tensor
            tangent vector at point :math:`z`; defaults to ``u`` when omitted
        keepdim : bool
            keep the last dim?

        Returns
        -------
        torch.Tensor
            inner product (broadcasted)
        """
        if v is None:
            v = u
        # Cast Y^{-1} to the dtype of z so it can be multiplied with the
        # complex tangent vectors.
        inv_y = sm.inverse(z.imag).type_as(z)
        res = inv_y @ u @ inv_y @ v.conj()
        return lalg.trace(res, keepdim=keepdim)

    def _check_point_on_manifold(
        self, z: torch.Tensor, *, atol=1e-5, rtol=1e-5
    ) -> Tuple[bool, Optional[str]]:
        """
        Check that :math:`Z` is symmetric and that :math:`Im(Z)` is positive definite.

        Parameters
        ----------
        z : torch.Tensor
            point to check
        atol : float
            absolute tolerance forwarded to the symmetry check
        rtol : float
            relative tolerance forwarded to the symmetry check

        Returns
        -------
        Tuple[bool, Optional[str]]
            ``(True, None)`` when both checks pass, otherwise ``False`` and the reason
        """
        if not self._check_matrices_are_symmetric(z, atol=atol, rtol=rtol):
            return False, "Matrices are not symmetric"
        # Im(Z) should be positive definite, i.e. every eigenvalue is > 0.
        # Convert to a Python bool so both return paths have the same type.
        ok = bool(torch.all(sm.eigvalsh(z.imag) > 0))
        reason = None if ok else "Imaginary part of Z is not positive definite"
        return ok, reason

    def random(self, *size, dtype=None, device=None, **kwargs) -> torch.Tensor:
        """
        Sample random points on the manifold.

        A complex Gaussian tensor scaled by 0.5 is symmetrized, and its imaginary part
        is then replaced by ``lalg.expm`` of itself. The matrix exponential of a real
        symmetric matrix is symmetric positive definite, which is the condition
        required of :math:`Im(Z)`.

        Parameters
        ----------
        size : int
            the desired shape
        dtype : torch.dtype
            the desired complex dtype (default: ``torch.complex128``)
        device : torch.device
            the desired device
        **kwargs
            accepted for interface compatibility and ignored

        Returns
        -------
        torch.Tensor
            random points on the manifold

        Raises
        ------
        ValueError
            if ``dtype`` is given and is not a complex dtype
        """
        if dtype is None:
            dtype = torch.complex128
        elif dtype not in COMPLEX_DTYPES:
            raise ValueError(f"dtype must be one of {COMPLEX_DTYPES}")
        tens = 0.5 * torch.randn(*size, dtype=dtype, device=device)
        tens = lalg.sym(tens)
        tens.imag = lalg.expm(tens.imag)
        return tens

    def origin(
        self,
        *size: Union[int, Tuple[int]],
        dtype=None,
        device=None,
        seed: Optional[int] = 42,
    ) -> torch.Tensor:
        """
        Create points at the origin of the manifold in a deterministic way.

        For the Upper half model, the origin is the imaginary identity.
        This is, a matrix whose real part is all zeros, and the identity as the imaginary part.

        Parameters
        ----------
        size : Union[int, Tuple[int]]
            the desired shape, given either as separate integers or as one tuple.
            The last two entries describe the (square) matrix; any leading entries
            are batch dimensions. The matrix order is read from the second-to-last entry.
        device : torch.device
            the desired device
        dtype : torch.dtype
            the desired dtype; a complex dtype is reduced to its real counterpart
            before the complex result is assembled
        seed : Optional[int]
            A parameter controlling deterministic randomness for manifolds that do not provide ``.origin``,
            but provide ``.random``. (default: 42). Unused here.

        Returns
        -------
        torch.Tensor
            tensor of shape ``(*batch, n, n)`` holding :math:`iI` in every batch entry

        Raises
        ------
        ValueError
            if fewer than two dimensions are given
        """
        # Accept both origin(n, n) and origin((n, n)).
        if len(size) == 1 and isinstance(size[0], (tuple, list)):
            size = tuple(size[0])
        if len(size) < 2:
            raise ValueError(
                f"size must describe at least a matrix (n, n), got {tuple(size)}"
            )
        *batch, n, _ = size
        imag = torch.eye(n, dtype=dtype, device=device)
        if imag.dtype in COMPLEX_DTYPES:
            # torch.complex needs real-valued inputs.
            imag = imag.real
        # Replicate the identity across the leading batch dimensions; with no
        # batch dimensions this is simply a copy of the (n, n) identity.
        imag = imag.repeat(*batch, 1, 1)
        return torch.complex(torch.zeros_like(imag), imag)