# Source code for geoopt.manifolds.siegel.bounded_domain

from typing import Optional, Tuple, Union
import torch
from geoopt import linalg as lalg
from geoopt.utils import COMPLEX_DTYPES
from .siegel import SiegelManifold
from .upper_half import UpperHalf
from .vvd_metrics import SiegelMetricType
from ..siegel import csym_math as sm

__all__ = ["BoundedDomain"]

class BoundedDomain(SiegelManifold):
    r"""
    Bounded domain Manifold.

    This model generalizes the Poincare ball model.
    Points in the space are complex symmetric matrices.

    .. math::

        \mathcal{B}_n := \{ Z \in \operatorname{Sym}(n, \mathbb{C}) | Id - Z^*Z >> 0 \}

    Parameters
    ----------
    metric: SiegelMetricType
        one of Riemannian, Finsler One, Finsler Infinity, Finsler metric of minimum entropy,
        or learnable weighted sum.
    rank: int
        Rank of the space. Only mandatory for "fmin" and "wsum" metrics.
    """

    name = "Bounded Domain"

    def __init__(
        self,
        metric: SiegelMetricType = SiegelMetricType.RIEMANNIAN,
        rank: Optional[int] = None,
    ):
        super().__init__(metric=metric, rank=rank)

    def dist(
        self, z1: torch.Tensor, z2: torch.Tensor, *, keepdim=False
    ) -> torch.Tensor:
        """
        Compute distance in the Bounded domain model.

        To compute distances in the Bounded Domain Model we need to map the elements to the
        Upper Half Space Model by means of the Cayley Transform, and then compute distances
        in that domain.

        Parameters
        ----------
        z1 : torch.Tensor
            point on the manifold
        z2 : torch.Tensor
            point on the manifold
        keepdim : bool, optional
            keep the last dim?, by default False. Forwarded to the parent implementation.

        Returns
        -------
        torch.Tensor
            distance between two points
        """
        uhsm_z1 = sm.cayley_transform(z1)
        uhsm_z2 = sm.cayley_transform(z2)
        # Forward ``keepdim`` so the caller's choice is not silently dropped.
        return super().dist(uhsm_z1, uhsm_z2, keepdim=keepdim)

    def egrad2rgrad(self, z: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        r"""
        Transform gradient computed using autodiff to the correct Riemannian gradient for the point :math:`Z`.

        For a function :math:`f(Z)` on :math:`\mathcal{B}_n`, the gradient is:

        .. math::

            \operatorname{grad}_{R}(f(Z)) = A \cdot \operatorname{grad}_{E}(f(Z)) \cdot A

        where :math:`A = Id - \overline{Z}Z`

        Parameters
        ----------
        z : torch.Tensor
            point on the manifold
        u : torch.Tensor
            gradient to be transformed

        Returns
        -------
        torch.Tensor
            Riemannian gradient
        """
        a = get_id_minus_conjugate_z_times_z(z)
        return lalg.sym(a @ u @ a)  # impose symmetry due to numerical instabilities

    def projx(self, z: torch.Tensor) -> torch.Tensor:
        r"""
        Project point :math:`Z` on the manifold.

        In the Bounded domain model, we need to ensure that
        :math:`Id - \overline{Z}Z` is positive definite.

        Steps to project: Z complex symmetric matrix
        1) Diagonalize Z: :math:`Z = \overline{S} D S^*`
        2) Clamp eigenvalues: :math:`D' = clamp(D, max=1 - epsilon)`
        3) Rebuild Z: :math:`Z' = \overline{S} D' S^*`

        Parameters
        ----------
        z : torch.Tensor
            point on the manifold

        Returns
        -------
        torch.Tensor
            Projected points
        """
        z = super().projx(z)

        evalues, s = sm.takagi_eig(z)
        eps = sm.EPS[evalues.dtype]
        evalues_tilde = torch.clamp(evalues, max=1 - eps)

        diag_tilde = torch.diag_embed(evalues_tilde).type_as(z)
        s_conj = s.conj()
        z_tilde = s_conj @ diag_tilde @ s_conj.transpose(-1, -2)

        # we do this so no operation is applied on the points that already belong to the space.
        # This prevents modifying values due to numerical instabilities
        batch_wise_mask = torch.all(evalues < 1 - eps, dim=-1, keepdim=True)
        # Add a trailing axis so the per-matrix flag lines up with the two matrix axes of ``z``.
        already_in_space_mask = batch_wise_mask.unsqueeze(-1).expand_as(z)

        return torch.where(already_in_space_mask, z, z_tilde)

    def inner(
        self,
        z: torch.Tensor,
        u: torch.Tensor,
        v: Optional[torch.Tensor] = None,
        *,
        keepdim=False,
    ) -> torch.Tensor:
        r"""
        Inner product for tangent vectors at point :math:`Z`.

        The inner product at point :math:`Z = X + iY` of the vectors :math:`U, V` is:

        .. math::

            g_{Z}(U, V) = \operatorname{Tr}[(Id - \overline{Z}Z)^{-1} U (Id - Z\overline{Z})^{-1} \overline{V}]

        Parameters
        ----------
        z : torch.Tensor
            point on the manifold
        u : torch.Tensor
            tangent vector at point :math:`z`
        v : torch.Tensor
            tangent vector at point :math:`z`. When omitted, ``u`` is used.
        keepdim : bool
            keep the last dim?

        Returns
        -------
        torch.Tensor
            trace of the matrix product in the formula above
        """
        if v is None:
            v = u
        identity = sm.identity_like(z)
        conj_z = z.conj()

        inv_id_minus_conjz_z = sm.inverse(identity - (conj_z @ z))
        inv_id_minus_z_conjz = sm.inverse(identity - (z @ conj_z))

        res = inv_id_minus_conjz_z @ u @ inv_id_minus_z_conjz @ v.conj()
        return lalg.trace(res, keepdim=keepdim)

    def _check_point_on_manifold(
        self, x: torch.Tensor, *, atol=1e-4, rtol=1e-5
    ) -> Tuple[bool, Optional[str]]:
        """
        Check that ``x`` is symmetric and that ``Id - conj(x) @ x`` has only positive eigenvalues.

        Returns
        -------
        Tuple[bool, Optional[str]]
            ``(True, None)`` when both checks pass, otherwise ``(False, reason)``.
        """
        if not self._check_matrices_are_symmetric(x, atol=atol, rtol=rtol):
            return False, "Matrices are not symmetric"

        # Id - \overline{Z}Z is Hermitian and should be positive definite
        id_minus_zz = get_id_minus_conjugate_z_times_z(x)
        # Convert to a plain bool so both return paths yield the same type.
        ok = bool(torch.all(sm.eigvalsh(id_minus_zz) > 0))
        reason = None if ok else "'Id - overline{Z}Z' is not definite positive"
        return ok, reason

    def random(self, *size, dtype=None, device=None, **kwargs) -> torch.Tensor:
        """
        Draw random points on the manifold.

        Points are drawn with ``UpperHalf().random`` and then mapped into this model
        with the inverse Cayley transform.

        Parameters
        ----------
        size : int
            the desired shape, forwarded to ``UpperHalf().random``
        dtype : torch.dtype
            the desired dtype, forwarded to ``UpperHalf().random``
        device : torch.device
            the desired device, forwarded to ``UpperHalf().random``
        **kwargs
            extra keyword arguments, forwarded to ``UpperHalf().random``

        Returns
        -------
        torch.Tensor
            random points
        """
        points = UpperHalf().random(*size, dtype=dtype, device=device, **kwargs)
        return sm.inverse_cayley_transform(points)

    def origin(
        self,
        *size: Union[int, Tuple[int]],
        dtype=None,
        device=None,
        seed: Optional[int] = 42,
    ) -> torch.Tensor:
        """
        Create points at the origin of the manifold in a deterministic way.

        For the Bounded domain model, the origin is the zero matrix.
        This is, a matrix whose real and imaginary parts are all zeros.

        Parameters
        ----------
        size : Union[int, Tuple[int]]
            the desired shape
        device : torch.device
            the desired device
        dtype : torch.dtype
            the desired dtype; must be a complex dtype, defaults to ``torch.complex128``
        seed : Optional[int]
            A parameter controlling deterministic randomness for manifolds that do not provide ``.origin``,
            but provide ``.random``. (default: 42). Not used by this implementation.

        Returns
        -------
        torch.Tensor
            zero tensor of the requested shape

        Raises
        ------
        ValueError
            if ``dtype`` is given and is not a complex dtype
        """
        if dtype is None:
            dtype = torch.complex128
        elif dtype not in COMPLEX_DTYPES:
            raise ValueError(f"dtype must be one of {COMPLEX_DTYPES}")
        return torch.zeros(*size, dtype=dtype, device=device)
r"""Given a complex symmetric matrix :math:`Z`, it returns a Hermitian matrix :math:`Id - \overline{Z}Z`."""