# Source code for geoopt.manifolds.lorentz

import torch as th
import torch.nn
import numpy as np
from typing import Tuple, Optional
from . import math
import geoopt
from ..base import Manifold, ScalingInfo
from ...utils import size2shape, broadcast_shapes

__all__ = ["Lorentz"]

_lorentz_ball_doc = r"""
    Lorentz model

    Parameters
    ----------
    k : float|tensor
        manifold negative curvature

    Notes
    -----
    It is extremely recommended to work with this manifold in double precision
"""


class Lorentz(Manifold):
    __doc__ = r"""{}
    """.format(
        _lorentz_ball_doc
    )

    ndim = 1
    reversible = False
    name = "Lorentz"
    # Per-class copy of the base scaling table; the ``@__scaling__(...)``
    # decorators below are applied through this copy, not the base class's.
    __scaling__ = Manifold.__scaling__.copy()

    def __init__(self, k=1.0, learnable=False):
        """Create the manifold.

        Parameters
        ----------
        k : float|tensor
            hyperboloid parameter; points satisfy ``<x, x>_L = -k``
        learnable : bool
            whether ``k`` is registered with ``requires_grad=True``
        """
        super().__init__()
        k = torch.as_tensor(k)
        # Integer tensors cannot require gradients, so promote e.g. ``k=1``
        # to the default floating point dtype.
        if not torch.is_floating_point(k):
            k = k.to(torch.get_default_dtype())
        self.k = torch.nn.Parameter(k, requires_grad=learnable)

    def _check_point_on_manifold(
        self, x: torch.Tensor, *, atol=1e-5, rtol=1e-5, dim=-1
    ) -> Tuple[bool, Optional[str]]:
        """Check that ``-x_0^2 + sum_{i>0} x_i^2 == -k`` along ``dim``.

        Returns
        -------
        tuple
            ``(ok, reason)`` where ``reason`` is None when ``ok`` is True
        """
        dn = x.size(dim) - 1
        sq = x ** 2
        quad_form = -sq.narrow(dim, 0, 1) + sq.narrow(dim, 1, dn).sum(
            dim=dim, keepdim=True
        )
        # ``allclose`` requires matching dtypes/devices; ``k`` defaults to the
        # default dtype (usually float32) while points may be float64.
        ok = torch.allclose(quad_form, -self.k.to(quad_form), atol=atol, rtol=rtol)
        if not ok:
            reason = f"'x' minkowski quadratic form is not equal to {-self.k.item()}"
        else:
            reason = None
        return ok, reason

    def _check_vector_on_tangent(
        self, x: torch.Tensor, u: torch.Tensor, *, atol=1e-5, rtol=1e-5, dim=-1
    ) -> Tuple[bool, Optional[str]]:
        """Check that ``u`` is Minkowski-orthogonal to ``x`` (tangent at ``x``).

        Returns
        -------
        tuple
            ``(ok, reason)`` where ``reason`` is None when ``ok`` is True
        """
        inner_ = math.inner(u, x, dim=dim)
        # Compare against zeros of the same dtype/device as ``inner_``; a bare
        # ``torch.zeros(1)`` is float32 on CPU and breaks ``allclose`` otherwise.
        ok = torch.allclose(inner_, torch.zeros_like(inner_), atol=atol, rtol=rtol)
        if not ok:
            reason = "Minkowski inner product is not equal to zero"
        else:
            reason = None
        return ok, reason

    def dist(
        self, x: torch.Tensor, y: torch.Tensor, *, keepdim=False, dim=-1
    ) -> torch.Tensor:
        """Distance between points ``x`` and ``y`` (delegates to ``math.dist``)."""
        return math.dist(x, y, k=self.k, keepdim=keepdim, dim=dim)

    @__scaling__(ScalingInfo(1))
    def dist0(self, x: torch.Tensor, *, dim=-1, keepdim=False) -> torch.Tensor:
        """Distance from ``x`` to the origin (delegates to ``math.dist0``)."""
        return math.dist0(x, k=self.k, dim=dim, keepdim=keepdim)

    def norm(self, u: torch.Tensor, *, keepdim=False, dim=-1) -> torch.Tensor:
        """Norm of tangent vector ``u`` computed by ``math.norm``.

        ``k`` is not passed to ``math.norm``, so the result does not depend on it.
        """
        return math.norm(u, keepdim=keepdim, dim=dim)

    def projx(self, x: torch.Tensor, *, dim=-1) -> torch.Tensor:
        """Project ``x`` onto the hyperboloid (delegates to ``math.project``)."""
        return math.project(x, k=self.k, dim=dim)

    def proju(self, x: torch.Tensor, v: torch.Tensor, *, dim=-1) -> torch.Tensor:
        """Project ``v`` onto the tangent space at ``x`` (``math.project_u``)."""
        return math.project_u(x, v, k=self.k, dim=dim)

    def expmap(
        self, x: torch.Tensor, u: torch.Tensor, *, norm_tan=True, project=True, dim=-1
    ) -> torch.Tensor:
        """Exponential map of tangent vector ``u`` at point ``x``.

        Parameters
        ----------
        norm_tan : bool
            if true, ``u`` is first projected onto the tangent space at ``x``
        project : bool
            if true, the result is passed through ``math.project``
        """
        if norm_tan:
            u = self.proju(x, u, dim=dim)
        res = math.expmap(x, u, k=self.k, dim=dim)
        if project:
            return math.project(res, k=self.k, dim=dim)
        return res

    @__scaling__(ScalingInfo(u=-1))
    def expmap0(self, u: torch.Tensor, *, project=True, dim=-1) -> torch.Tensor:
        """Exponential map of ``u`` at the origin; optionally re-projected."""
        res = math.expmap0(u, k=self.k, dim=dim)
        if project:
            return math.project(res, k=self.k, dim=dim)
        return res

    def logmap(self, x: torch.Tensor, y: torch.Tensor, *, dim=-1) -> torch.Tensor:
        """Logarithmic map of point ``y`` at point ``x`` (``math.logmap``)."""
        return math.logmap(x, y, k=self.k, dim=dim)

    @__scaling__(ScalingInfo(1))
    def logmap0(self, y: torch.Tensor, *, dim=-1) -> torch.Tensor:
        """Logarithmic map of point ``y`` at the origin (``math.logmap0``)."""
        return math.logmap0(y, k=self.k, dim=dim)

    def logmap0back(self, x: torch.Tensor, *, dim=-1) -> torch.Tensor:
        """Delegates to ``math.logmap0back``.

        NOTE(review): presumably the tangent vector at ``x`` pointing to the
        origin; confirm against ``math.logmap0back``.
        """
        return math.logmap0back(x, k=self.k, dim=dim)

    def inner(
        self,
        x: torch.Tensor,
        u: torch.Tensor,
        v: Optional[torch.Tensor] = None,
        *,
        keepdim=False,
        dim=-1,
    ) -> torch.Tensor:
        """Minkowski inner product of ``u`` and ``v`` (``<u, u>`` if ``v`` is None).

        ``x`` is accepted for interface compatibility and is not used.
        """
        # TODO: x argument for maintaining the support of optims
        if v is None:
            v = u
        return math.inner(u, v, dim=dim, keepdim=keepdim)

    def inner0(
        self,
        v: Optional[torch.Tensor] = None,
        *,
        keepdim=False,
        dim=-1,
    ) -> torch.Tensor:
        """Inner product computed at the origin (delegates to ``math.inner0``)."""
        return math.inner0(v, k=self.k, dim=dim, keepdim=keepdim)

    def egrad2rgrad(self, x: torch.Tensor, u: torch.Tensor, *, dim=-1) -> torch.Tensor:
        """Convert Euclidean gradient ``u`` at ``x`` to a Riemannian gradient."""
        # A second, k-less definition used to precede this one and was
        # silently shadowed; only this version was ever in effect.
        return math.egrad2rgrad(x, u, k=self.k, dim=dim)

    def transp(
        self, x: torch.Tensor, y: torch.Tensor, v: torch.Tensor, *, dim=-1
    ) -> torch.Tensor:
        """Parallel transport of ``v`` from ``x`` to ``y``."""
        return math.parallel_transport(x, y, v, k=self.k, dim=dim)

    def transp0(self, y: torch.Tensor, u: torch.Tensor, *, dim=-1) -> torch.Tensor:
        """Parallel transport of ``u`` from the origin to ``y``."""
        return math.parallel_transport0(y, u, k=self.k, dim=dim)

    def transp0back(self, x: torch.Tensor, u: torch.Tensor, *, dim=-1) -> torch.Tensor:
        """Parallel transport of ``u`` from ``x`` back to the origin."""
        return math.parallel_transport0back(x, u, k=self.k, dim=dim)

    def transp_follow_expmap(
        self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor, *, dim=-1, project=True
    ) -> torch.Tensor:
        """Move ``x`` along ``u`` with ``expmap`` and transport ``v`` to the result."""
        y = self.expmap(x, u, dim=dim, project=project)
        return self.transp(x, y, v, dim=dim)

    @__scaling__(ScalingInfo(t=-1))
    def geodesic_unit(
        self, t: torch.Tensor, x: torch.Tensor, u: torch.Tensor, *, dim=-1, project=True
    ) -> torch.Tensor:
        """Point at time ``t`` on the geodesic from ``x`` in direction ``u``.

        ``dim`` is only used by the optional final projection.
        """
        # NOTE(review): ``dim`` is not forwarded to ``math.geodesic_unit``;
        # confirm that function is dimension-agnostic before using dim != -1.
        res = math.geodesic_unit(t, x, u, k=self.k)
        if project:
            return math.project(res, k=self.k, dim=dim)
        return res

    @__scaling__(ScalingInfo(std=-1), "random")
    def random_normal(
        self, *size, mean=0, std=1, dtype=None, device=None
    ) -> "geoopt.ManifoldTensor":
        r"""
        Create a point on the manifold, measure is induced by Normal distribution on the tangent space of zero.

        Parameters
        ----------
        size : shape
            the desired shape
        mean : float|tensor
            mean value for the Normal distribution
        std : float|tensor
            std value for the Normal distribution
        dtype: torch.dtype
            target dtype for sample, if not None, should match Manifold dtype
        device: torch.device
            target device for sample, if not None, should match Manifold device

        Returns
        -------
        ManifoldTensor
            random points on Hyperboloid

        Notes
        -----
        The device and dtype will match the device and dtype of the Manifold.
        Each Gaussian sample is rescaled to unit Euclidean norm along the last
        dimension before ``expmap0`` is applied.
        """
        self._assert_check_shape(size2shape(*size), "x")
        if device is not None and device != self.k.device:
            raise ValueError(
                "`device` does not match the projector `device`, set the `device` argument to None"
            )
        if dtype is not None and dtype != self.k.dtype:
            raise ValueError(
                "`dtype` does not match the projector `dtype`, set the `dtype` argument to None"
            )
        tens = torch.randn(*size, device=self.k.device, dtype=self.k.dtype) * std + mean
        tens /= tens.norm(dim=-1, keepdim=True)
        return geoopt.ManifoldTensor(self.expmap0(tens), manifold=self)

    def origin(
        self, *size, dtype=None, device=None, seed=42
    ) -> "geoopt.ManifoldTensor":
        """
        Zero point origin.

        Parameters
        ----------
        size : shape
            the desired shape
        device : torch.device
            the desired device (defaults to the device of ``k``)
        dtype : torch.dtype
            the desired dtype (defaults to the dtype of ``k``)
        seed : int
            ignored

        Returns
        -------
        ManifoldTensor
            zero point on the manifold: ``(sqrt(k), 0, ..., 0)`` along the last dim
        """
        if dtype is None:
            dtype = self.k.dtype
        if device is None:
            device = self.k.device

        zero_point = torch.zeros(*size, dtype=dtype, device=device)
        zero_point[..., 0] = torch.sqrt(self.k)
        return geoopt.ManifoldTensor(zero_point, manifold=self)

    retr = expmap