Source code for geoopt.manifolds.stereographic.math

r"""
:math:`\kappa`-Stereographic math module.

The functions for the mathematics in gyrovector spaces are taken from the
following resources:

    [1] Ganea, Octavian, Gary Bécigneul, and Thomas Hofmann. "Hyperbolic
           neural networks." Advances in neural information processing systems.
           2018.
    [2] Bachmann, Gregor, Gary Bécigneul, and Octavian-Eugen Ganea. "Constant
           Curvature Graph Convolutional Networks." arXiv preprint
           arXiv:1911.05076 (2019).
    [3] Skopek, Ondrej, Octavian-Eugen Ganea, and Gary Bécigneul.
           "Mixed-curvature Variational Autoencoders." arXiv preprint
           arXiv:1911.08411 (2019).
    [4] Ungar, Abraham A. Analytic hyperbolic geometry: Mathematical
           foundations and applications. World Scientific, 2005.
    [5] Ungar, Abraham Albert. Barycentric calculus in Euclidean and
           hyperbolic geometry: A comparative introduction. World Scientific,
           2010.
"""

import functools
import torch.jit
from typing import List, Optional
from ...utils import list_range, drop_dims, sign, clamp_abs, sabs


@torch.jit.script
def tanh(x):
    return x.clamp(-15, 15).tanh()


@torch.jit.script
def artanh(x: torch.Tensor):
    x = x.clamp(-1 + 1e-7, 1 - 1e-7)
    return (torch.log(1 + x).sub(torch.log(1 - x))).mul(0.5)


@torch.jit.script
def arsinh(x: torch.Tensor):
    return (x + torch.sqrt(1 + x.pow(2))).clamp_min(1e-15).log().to(x.dtype)


@torch.jit.script
def abs_zero_grad(x):
    # this op has derivative equal to 1 at zero
    return x * sign(x)


@torch.jit.script
def tan_k_zero_taylor(x: torch.Tensor, k: torch.Tensor, order: int = -1):
    if order == 0:
        return x
    k = abs_zero_grad(k)
    if order == -1 or order == 5:
        return (
            x
            + 1 / 3 * k * x**3
            + 2 / 15 * k**2 * x**5
            + 17 / 315 * k**3 * x**7
            + 62 / 2835 * k**4 * x**9
            + 1382 / 155925 * k**5 * x**11
            # + o(k**6)
        )
    elif order == 1:
        return x + 1 / 3 * k * x**3
    elif order == 2:
        return x + 1 / 3 * k * x**3 + 2 / 15 * k**2 * x**5
    elif order == 3:
        return (
            x
            + 1 / 3 * k * x**3
            + 2 / 15 * k**2 * x**5
            + 17 / 315 * k**3 * x**7
        )
    elif order == 4:
        return (
            x
            + 1 / 3 * k * x**3
            + 2 / 15 * k**2 * x**5
            + 17 / 315 * k**3 * x**7
            + 62 / 2835 * k**4 * x**9
        )
    else:
        raise RuntimeError("order not in [-1, 5]")


@torch.jit.script
def artan_k_zero_taylor(x: torch.Tensor, k: torch.Tensor, order: int = -1):
    if order == 0:
        return x
    k = abs_zero_grad(k)
    if order == -1 or order == 5:
        return (
            x
            - 1 / 3 * k * x**3
            + 1 / 5 * k**2 * x**5
            - 1 / 7 * k**3 * x**7
            + 1 / 9 * k**4 * x**9
            - 1 / 11 * k**5 * x**11
            # + o(k**6)
        )
    elif order == 1:
        return x - 1 / 3 * k * x**3
    elif order == 2:
        return x - 1 / 3 * k * x**3 + 1 / 5 * k**2 * x**5
    elif order == 3:
        return (
            x - 1 / 3 * k * x**3 + 1 / 5 * k**2 * x**5 - 1 / 7 * k**3 * x**7
        )
    elif order == 4:
        return (
            x
            - 1 / 3 * k * x**3
            + 1 / 5 * k**2 * x**5
            - 1 / 7 * k**3 * x**7
            + 1 / 9 * k**4 * x**9
        )
    else:
        raise RuntimeError("order not in [-1, 5]")


@torch.jit.script
def arsin_k_zero_taylor(x: torch.Tensor, k: torch.Tensor, order: int = -1):
    if order == 0:
        return x
    k = abs_zero_grad(k)
    if order == -1 or order == 5:
        return (
            x
            + k * x**3 / 6
            + 3 / 40 * k**2 * x**5
            + 5 / 112 * k**3 * x**7
            + 35 / 1152 * k**4 * x**9
            + 63 / 2816 * k**5 * x**11
            # + o(k**6)
        )
    elif order == 1:
        return x + k * x**3 / 6
    elif order == 2:
        return x + k * x**3 / 6 + 3 / 40 * k**2 * x**5
    elif order == 3:
        return x + k * x**3 / 6 + 3 / 40 * k**2 * x**5 + 5 / 112 * k**3 * x**7
    elif order == 4:
        return (
            x
            + k * x**3 / 6
            + 3 / 40 * k**2 * x**5
            + 5 / 112 * k**3 * x**7
            + 35 / 1152 * k**4 * x**9
        )
    else:
        raise RuntimeError("order not in [-1, 5]")


@torch.jit.script
def sin_k_zero_taylor(x: torch.Tensor, k: torch.Tensor, order: int = -1):
    if order == 0:
        return x
    k = abs_zero_grad(k)
    if order == -1 or order == 5:
        return (
            x
            - k * x**3 / 6
            + k**2 * x**5 / 120
            - k**3 * x**7 / 5040
            + k**4 * x**9 / 362880
            - k**5 * x**11 / 39916800
            # + o(k**6)
        )
    elif order == 1:
        return x - k * x**3 / 6
    elif order == 2:
        return x - k * x**3 / 6 + k**2 * x**5 / 120
    elif order == 3:
        return x - k * x**3 / 6 + k**2 * x**5 / 120 - k**3 * x**7 / 5040
    elif order == 4:
        return (
            x
            - k * x**3 / 6
            + k**2 * x**5 / 120
            - k**3 * x**7 / 5040
            + k**4 * x**9 / 362880
        )
    else:
        raise RuntimeError("order not in [-1, 5]")


@torch.jit.script
def tan_k(x: torch.Tensor, k: torch.Tensor):
    k_sign = k.sign()
    zero = torch.zeros((), device=k.device, dtype=k.dtype)
    k_zero = k.isclose(zero)
    # shrink sign
    k_sign = torch.masked_fill(k_sign, k_zero, zero.to(k_sign.dtype))
    if torch.all(k_zero):
        return tan_k_zero_taylor(x, k, order=1)
    k_sqrt = sabs(k).sqrt()
    scaled_x = x * k_sqrt

    if torch.all(k_sign.lt(0)):
        return k_sqrt.reciprocal() * tanh(scaled_x)
    elif torch.all(k_sign.gt(0)):
        return k_sqrt.reciprocal() * scaled_x.clamp_max(1e38).tan()
    else:
        tan_k_nonzero = (
            torch.where(k_sign.gt(0), scaled_x.clamp_max(1e38).tan(), tanh(scaled_x))
            * k_sqrt.reciprocal()
        )
        return torch.where(k_zero, tan_k_zero_taylor(x, k, order=1), tan_k_nonzero)


@torch.jit.script
def artan_k(x: torch.Tensor, k: torch.Tensor):
    k_sign = k.sign()
    zero = torch.zeros((), device=k.device, dtype=k.dtype)
    k_zero = k.isclose(zero)
    # shrink sign
    k_sign = torch.masked_fill(k_sign, k_zero, zero.to(k_sign.dtype))
    if torch.all(k_zero):
        return artan_k_zero_taylor(x, k, order=1)
    k_sqrt = sabs(k).sqrt()
    scaled_x = x * k_sqrt

    if torch.all(k_sign.lt(0)):
        return k_sqrt.reciprocal() * artanh(scaled_x)
    elif torch.all(k_sign.gt(0)):
        return k_sqrt.reciprocal() * scaled_x.atan()
    else:
        artan_k_nonzero = (
            torch.where(k_sign.gt(0), scaled_x.atan(), artanh(scaled_x))
            * k_sqrt.reciprocal()
        )
        return torch.where(k_zero, artan_k_zero_taylor(x, k, order=1), artan_k_nonzero)


@torch.jit.script
def arsin_k(x: torch.Tensor, k: torch.Tensor):
    k_sign = k.sign()
    zero = torch.zeros((), device=k.device, dtype=k.dtype)
    k_zero = k.isclose(zero)
    # shrink sign
    k_sign = torch.masked_fill(k_sign, k_zero, zero.to(k_sign.dtype))
    if torch.all(k_zero):
        return arsin_k_zero_taylor(x, k)
    k_sqrt = sabs(k).sqrt()
    scaled_x = x * k_sqrt

    if torch.all(k_sign.lt(0)):
        return k_sqrt.reciprocal() * arsinh(scaled_x)
    elif torch.all(k_sign.gt(0)):
        return k_sqrt.reciprocal() * scaled_x.asin()
    else:
        arsin_k_nonzero = (
            torch.where(
                k_sign.gt(0),
                scaled_x.clamp(-1 + 1e-7, 1 - 1e-7).asin(),
                arsinh(scaled_x),
            )
            * k_sqrt.reciprocal()
        )
        return torch.where(k_zero, arsin_k_zero_taylor(x, k, order=1), arsin_k_nonzero)


@torch.jit.script
def sin_k(x: torch.Tensor, k: torch.Tensor):
    k_sign = k.sign()
    zero = torch.zeros((), device=k.device, dtype=k.dtype)
    k_zero = k.isclose(zero)
    # shrink sign
    k_sign = torch.masked_fill(k_sign, k_zero, zero.to(k_sign.dtype))
    if torch.all(k_zero):
        return sin_k_zero_taylor(x, k)
    k_sqrt = sabs(k).sqrt()
    scaled_x = x * k_sqrt

    if torch.all(k_sign.lt(0)):
        return k_sqrt.reciprocal() * torch.sinh(scaled_x)
    elif torch.all(k_sign.gt(0)):
        return k_sqrt.reciprocal() * scaled_x.sin()
    else:
        sin_k_nonzero = (
            torch.where(k_sign.gt(0), scaled_x.sin(), torch.sinh(scaled_x))
            * k_sqrt.reciprocal()
        )
        return torch.where(k_zero, sin_k_zero_taylor(x, k, order=1), sin_k_nonzero)
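

# A minimal sketch (added for illustration; the ``_example_*`` helpers below
# are not part of the original module): the curvature-sign dispatch above
# reduces to the familiar special cases, tanh for k = -1 and tan for k = +1.
def _example_tan_k():
    x = torch.linspace(-0.5, 0.5, 7)
    assert torch.allclose(tan_k(x, torch.tensor(-1.0)), torch.tanh(x))
    assert torch.allclose(tan_k(x, torch.tensor(1.0)), torch.tan(x))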


def project(x: torch.Tensor, *, k: torch.Tensor, dim=-1, eps=-1):
    r"""
    Safe projection on the manifold for numerical stability.

    Parameters
    ----------
    x : tensor
        point on the Poincare ball
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension to compute norm
    eps : float
        stability parameter, uses default for dtype if not provided

    Returns
    -------
    tensor
        projected vector on the manifold
    """
    return _project(x, k, dim, eps)


@torch.jit.script
def _project(x, k, dim: int = -1, eps: float = -1.0):
    if eps < 0:
        if x.dtype == torch.float32:
            eps = 4e-3
        else:
            eps = 1e-5
    maxnorm = (1 - eps) / (sabs(k) ** 0.5)
    maxnorm = torch.where(k.lt(0), maxnorm, k.new_full((), 1e15))
    norm = x.norm(dim=dim, keepdim=True, p=2).clamp_min(1e-15)
    cond = norm > maxnorm
    projected = x / norm * maxnorm
    return torch.where(cond, projected, x)
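

# Illustrative sketch (not part of the original module): for negative
# curvature the ball has radius 1/sqrt(|k|), so a point pushed outside it is
# pulled back just inside the boundary.
def _example_project():
    k = torch.tensor(-1.0)
    x = torch.tensor([3.0, 4.0])  # Euclidean norm 5, far outside the unit ball
    p = _project(x, k)
    assert p.norm() < 1.0  # now strictly inside the Poincare ball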


def lambda_x(x: torch.Tensor, *, k: torch.Tensor, keepdim=False, dim=-1):
    r"""
    Compute the conformal factor :math:`\lambda^\kappa_x` for a point on the ball.

    .. math::

        \lambda^\kappa_x = \frac{2}{1 + \kappa \|x\|_2^2}

    Parameters
    ----------
    x : tensor
        point on the Poincare ball
    k : tensor
        sectional curvature of manifold
    keepdim : bool
        retain the last dim? (default: false)
    dim : int
        reduction dimension

    Returns
    -------
    tensor
        conformal factor
    """
    return _lambda_x(x, k, keepdim=keepdim, dim=dim)


@torch.jit.script
def _lambda_x(x: torch.Tensor, k: torch.Tensor, keepdim: bool = False, dim: int = -1):
    return 2 / (1 + k * x.pow(2).sum(dim=dim, keepdim=keepdim)).clamp_min(1e-15)
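

# Illustrative sketch (not part of the original module): at the origin the
# conformal factor is exactly 2 for any curvature, since ||0|| = 0.
def _example_lambda_origin():
    for kappa in (-1.0, 0.0, 1.0):
        k = torch.tensor(kappa)
        lam = _lambda_x(torch.zeros(3), k)
        assert torch.allclose(lam, torch.tensor(2.0))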


def inner(
    x: torch.Tensor, u: torch.Tensor, v: torch.Tensor, *, k, keepdim=False, dim=-1
):
    r"""
    Compute the inner product of two tangent vectors at :math:`x` w.r.t. the
    Riemannian metric on the Poincare ball.

    .. math::

        \langle u, v\rangle_x = (\lambda^\kappa_x)^2 \langle u, v \rangle

    Parameters
    ----------
    x : tensor
        point on the Poincare ball
    u : tensor
        tangent vector to :math:`x` on Poincare ball
    v : tensor
        tangent vector to :math:`x` on Poincare ball
    k : tensor
        sectional curvature of manifold
    keepdim : bool
        retain the last dim? (default: false)
    dim : int
        reduction dimension

    Returns
    -------
    tensor
        inner product
    """
    return _inner(x, u, v, k, keepdim=keepdim, dim=dim)


@torch.jit.script
def _inner(
    x: torch.Tensor,
    u: torch.Tensor,
    v: torch.Tensor,
    k: torch.Tensor,
    keepdim: bool = False,
    dim: int = -1,
):
    # keep the conformal factor and the reduced sum on matching shapes,
    # otherwise the product broadcasts incorrectly for keepdim=False
    return _lambda_x(x, k, keepdim=keepdim, dim=dim) ** 2 * (u * v).sum(
        dim=dim, keepdim=keepdim
    )


def norm(x: torch.Tensor, u: torch.Tensor, *, k: torch.Tensor, keepdim=False, dim=-1):
    r"""
    Compute vector norm on the tangent space w.r.t Riemannian metric on the
    Poincare ball.

    .. math::

        \|u\|_x = \lambda^\kappa_x \|u\|_2

    Parameters
    ----------
    x : tensor
        point on the Poincare ball
    u : tensor
        tangent vector to :math:`x` on Poincare ball
    k : tensor
        sectional curvature of manifold
    keepdim : bool
        retain the last dim? (default: false)
    dim : int
        reduction dimension

    Returns
    -------
    tensor
        norm of vector
    """
    return _norm(x, u, k, keepdim=keepdim, dim=dim)


@torch.jit.script
def _norm(
    x: torch.Tensor,
    u: torch.Tensor,
    k: torch.Tensor,
    keepdim: bool = False,
    dim: int = -1,
):
    return _lambda_x(x, k, keepdim=keepdim, dim=dim) * u.norm(
        dim=dim, keepdim=keepdim, p=2
    )
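

# Illustrative sketch (not part of the original module): with zero curvature
# the conformal factor is 2 everywhere, so the tangent-space inner product is
# 4 times the Euclidean one and the tangent norm is twice the Euclidean norm.
def _example_flat_metric():
    k = torch.tensor(0.0)
    x, u, v = torch.zeros(3), torch.tensor([1.0, 2.0, 2.0]), torch.ones(3)
    assert torch.allclose(_inner(x, u, v, k), 4 * (u * v).sum())
    assert torch.allclose(_norm(x, u, k), 2 * u.norm())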


def mobius_add(x: torch.Tensor, y: torch.Tensor, *, k: torch.Tensor, dim=-1):
    r"""
    Compute the Möbius gyrovector addition.

    .. math::

        x \oplus_\kappa y =
        \frac{
            (1 - 2 \kappa \langle x, y\rangle - \kappa \|y\|^2_2) x +
            (1 + \kappa \|x\|_2^2) y
        }{
            1 - 2 \kappa \langle x, y\rangle + \kappa^2 \|x\|^2_2 \|y\|^2_2
        }

    .. plot:: plots/extended/stereographic/mobius_add.py

    In general this operation is not commutative:

    .. math::

        x \oplus_\kappa y \ne y \oplus_\kappa x

    But in some cases this property holds:

    * zero vector case

    .. math::

        \mathbf{0} \oplus_\kappa x = x \oplus_\kappa \mathbf{0}

    * zero curvature case, which is the same as Euclidean addition

    .. math::

        x \oplus_0 y = y \oplus_0 x

    Another useful property is the so-called left-cancellation law:

    .. math::

        (-x) \oplus_\kappa (x \oplus_\kappa y) = y

    Parameters
    ----------
    x : tensor
        point on the manifold
    y : tensor
        point on the manifold
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        the result of the Möbius addition
    """
    return _mobius_add(x, y, k, dim=dim)


@torch.jit.script
def _mobius_add(x: torch.Tensor, y: torch.Tensor, k: torch.Tensor, dim: int = -1):
    x2 = x.pow(2).sum(dim=dim, keepdim=True)
    y2 = y.pow(2).sum(dim=dim, keepdim=True)
    xy = (x * y).sum(dim=dim, keepdim=True)
    num = (1 - 2 * k * xy - k * y2) * x + (1 + k * x2) * y
    denom = 1 - 2 * k * xy + k**2 * x2 * y2
    # minimize denom (omit k to simplify the notation)
    # 1)
    # {d(denom)/d(x) = 2 y + 2x * <y, y> = 0
    # {d(denom)/d(y) = 2 x + 2y * <x, x> = 0
    # 2)
    # {y + x * <y, y> = 0
    # {x + y * <x, x> = 0
    # 3)
    # {- y/<y, y> = x
    # {- x/<x, x> = y
    # 4)
    # minimum = 1 - 2 <y, y>/<y, y> + <y, y>/<y, y> = 0
    return num / denom.clamp_min(1e-15)
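

# Illustrative sketch (not part of the original module): a numerical check of
# the left-cancellation law (-x) ⊕ (x ⊕ y) = y on the Poincare ball (k = -1).
def _example_left_cancellation():
    k = torch.tensor(-1.0)
    x, y = torch.rand(5) * 0.3, torch.rand(5) * 0.3
    assert torch.allclose(_mobius_add(-x, _mobius_add(x, y, k), k), y, atol=1e-6)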


def mobius_sub(x: torch.Tensor, y: torch.Tensor, *, k: torch.Tensor, dim=-1):
    r"""
    Compute the Möbius gyrovector subtraction.

    The Möbius subtraction can be represented via the Möbius addition as
    follows:

    .. math::

        x \ominus_\kappa y = x \oplus_\kappa (-y)

    Parameters
    ----------
    x : tensor
        point on manifold
    y : tensor
        point on manifold
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        the result of the Möbius subtraction
    """
    return _mobius_sub(x, y, k, dim=dim)


@torch.jit.script
def _mobius_sub(x: torch.Tensor, y: torch.Tensor, k: torch.Tensor, dim: int = -1):
    return _mobius_add(x, -y, k, dim=dim)


def gyration(
    a: torch.Tensor, b: torch.Tensor, u: torch.Tensor, *, k: torch.Tensor, dim=-1
):
    r"""
    Compute the gyration of :math:`u` by :math:`[a, b]`.

    The gyration is a special operation of gyrovector spaces. The gyrovector
    space addition operation :math:`\oplus_\kappa` is not associative (as
    mentioned in :func:`mobius_add`), but it is gyroassociative, which means

    .. math::

        u \oplus_\kappa (v \oplus_\kappa w)
        =
        (u \oplus_\kappa v) \oplus_\kappa \operatorname{gyr}[u, v]w,

    where

    .. math::

        \operatorname{gyr}[u, v]w
        =
        \ominus_\kappa (u \oplus_\kappa v)
        \oplus_\kappa (u \oplus_\kappa (v \oplus_\kappa w))

    We can simplify this equation using the explicit formula for the Möbius
    addition [1]. Recall,

    .. math::

        A = - \kappa^2 \langle u, w\rangle \langle v, v\rangle
            - \kappa \langle v, w\rangle
            + 2 \kappa^2 \langle u, v\rangle \langle v, w\rangle\\
        B = - \kappa^2 \langle v, w\rangle \langle u, u\rangle
            + \kappa \langle u, w\rangle\\
        D = 1 - 2 \kappa \langle u, v\rangle
            + \kappa^2 \langle u, u\rangle \langle v, v\rangle\\
        \operatorname{gyr}[u, v]w = w + 2 \frac{A u + B v}{D}.

    Parameters
    ----------
    a : tensor
        first point on manifold
    b : tensor
        second point on manifold
    u : tensor
        vector field for operation
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        the result of automorphism

    References
    ----------
    [1] A. A. Ungar (2009), A Gyrovector Space Approach to Hyperbolic Geometry
    """
    return _gyration(a, b, u, k, dim=dim)


@torch.jit.script
def _gyration(
    u: torch.Tensor, v: torch.Tensor, w: torch.Tensor, k: torch.Tensor, dim: int = -1
):
    # non-simplified
    # mupv = -_mobius_add(u, v, K)
    # vpw = _mobius_add(u, w, K)
    # upvpw = _mobius_add(u, vpw, K)
    # return _mobius_add(mupv, upvpw, K)
    # simplified
    u2 = u.pow(2).sum(dim=dim, keepdim=True)
    v2 = v.pow(2).sum(dim=dim, keepdim=True)
    uv = (u * v).sum(dim=dim, keepdim=True)
    uw = (u * w).sum(dim=dim, keepdim=True)
    vw = (v * w).sum(dim=dim, keepdim=True)
    K2 = k**2
    a = -K2 * uw * v2 - k * vw + 2 * K2 * uv * vw
    b = -K2 * vw * u2 + k * uw
    d = 1 - 2 * k * uv + K2 * u2 * v2
    return w + 2 * (a * u + b * v) / d.clamp_min(1e-15)


def mobius_coadd(x: torch.Tensor, y: torch.Tensor, *, k: torch.Tensor, dim=-1):
    r"""
    Compute the Möbius gyrovector coaddition.

    The addition operation :math:`\oplus_\kappa` is neither associative, nor
    commutative. In contrast, the coaddition :math:`\boxplus_\kappa` (or
    cooperation) is an associative operation that is defined as follows.

    .. math::

        x \boxplus_\kappa y = y \boxplus_\kappa x
        = x \oplus_\kappa \operatorname{gyr}[x, -y]y\\
        = \frac{
            (1 + \kappa \|y\|^2_2) x + (1 + \kappa \|x\|_2^2) y
        }{
            1 + \kappa^2 \|x\|^2_2 \|y\|^2_2
        },

    where :math:`\operatorname{gyr}[a, b]v = \ominus_\kappa (a \oplus_\kappa b)
    \oplus_\kappa (a \oplus_\kappa (b \oplus_\kappa v))`

    The following right cancellation property holds

    .. math::

        (a \boxplus_\kappa b) \ominus_\kappa b = a\\
        (a \oplus_\kappa b) \boxminus_\kappa b = a

    Parameters
    ----------
    x : tensor
        point on manifold
    y : tensor
        point on manifold
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        the result of the Möbius coaddition
    """
    return _mobius_coadd(x, y, k, dim=dim)


# TODO: check numerical stability with Gregor's paper!!!
@torch.jit.script
def _mobius_coadd(x: torch.Tensor, y: torch.Tensor, k: torch.Tensor, dim: int = -1):
    # x2 = x.pow(2).sum(dim=dim, keepdim=True)
    # y2 = y.pow(2).sum(dim=dim, keepdim=True)
    # num = (1 + K * y2) * x + (1 + K * x2) * y
    # denom = 1 - K ** 2 * x2 * y2
    # avoid division by zero in this way
    # return num / denom.clamp_min(1e-15)
    return _mobius_add(x, _gyration(x, -y, y, k=k, dim=dim), k, dim=dim)


def mobius_cosub(x: torch.Tensor, y: torch.Tensor, *, k: torch.Tensor, dim=-1):
    r"""
    Compute the Möbius gyrovector cosubtraction.

    The Möbius cosubtraction is defined as follows:

    .. math::

        a \boxminus_\kappa b = a \boxplus_\kappa -b

    Parameters
    ----------
    x : tensor
        point on manifold
    y : tensor
        point on manifold
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        the result of the Möbius cosubtraction
    """
    return _mobius_cosub(x, y, k, dim=dim)


@torch.jit.script
def _mobius_cosub(x: torch.Tensor, y: torch.Tensor, k: torch.Tensor, dim: int = -1):
    return _mobius_coadd(x, -y, k, dim=dim)
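

# Illustrative sketch (not part of the original module): gyroassociativity
# ties the addition and the gyration together,
# u ⊕ (v ⊕ w) = (u ⊕ v) ⊕ gyr[u, v]w. A numerical check for k = -1:
def _example_gyroassociativity():
    k = torch.tensor(-1.0)
    u, v, w = (torch.rand(4) * 0.2 for _ in range(3))
    lhs = _mobius_add(u, _mobius_add(v, w, k), k)
    rhs = _mobius_add(_mobius_add(u, v, k), _gyration(u, v, w, k), k)
    assert torch.allclose(lhs, rhs, atol=1e-6)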


# TODO: can we make this operation somehow safer by breaking up the
# TODO: scalar multiplication for K>0 when the argument to the
# TODO: tan function gets close to pi/2+k*pi for k in Z?
# TODO: one could use the scalar associative law
# TODO: s_1 (X) s_2 (X) x = (s_1*s_2) (X) x
# TODO: to implement a more stable Möbius scalar mult
def mobius_scalar_mul(r: torch.Tensor, x: torch.Tensor, *, k: torch.Tensor, dim=-1):
    r"""
    Compute the Möbius scalar multiplication.

    .. math::

        r \otimes_\kappa x
        =
        \tan_\kappa(r\tan_\kappa^{-1}(\|x\|_2))\frac{x}{\|x\|_2}

    This operation has properties similar to the Euclidean scalar
    multiplication.

    * `n-addition` property

    .. math::

        r \otimes_\kappa x = x \oplus_\kappa \dots \oplus_\kappa x

    * Distributive property

    .. math::

        (r_1 + r_2) \otimes_\kappa x
        =
        r_1 \otimes_\kappa x \oplus_\kappa r_2 \otimes_\kappa x

    * Scalar associativity

    .. math::

        (r_1 r_2) \otimes_\kappa x = r_1 \otimes_\kappa (r_2 \otimes_\kappa x)

    * Monodistributivity

    .. math::

        r \otimes_\kappa (r_1 \otimes_\kappa x \oplus_\kappa r_2 \otimes_\kappa x)
        =
        r \otimes_\kappa (r_1 \otimes_\kappa x)
        \oplus_\kappa r \otimes_\kappa (r_2 \otimes_\kappa x)

    * Scaling property

    .. math::

        |r| \otimes_\kappa x / \|r \otimes_\kappa x\|_2 = x / \|x\|_2

    Parameters
    ----------
    r : tensor
        scalar for multiplication
    x : tensor
        point on manifold
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        the result of the Möbius scalar multiplication
    """
    return _mobius_scalar_mul(r, x, k, dim=dim)


@torch.jit.script
def _mobius_scalar_mul(
    r: torch.Tensor, x: torch.Tensor, k: torch.Tensor, dim: int = -1
):
    x_norm = x.norm(dim=dim, keepdim=True, p=2).clamp_min(1e-15)
    res_c = tan_k(r * artan_k(x_norm, k), k) * (x / x_norm)
    return res_c
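

# Illustrative sketch (not part of the original module): the n-addition
# property in the smallest case, 2 ⊗ x = x ⊕ x.
def _example_scalar_mul():
    k = torch.tensor(-1.0)
    x = torch.rand(4) * 0.3
    two = torch.tensor(2.0)
    assert torch.allclose(
        _mobius_scalar_mul(two, x, k), _mobius_add(x, x, k), atol=1e-6
    )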


def dist(x: torch.Tensor, y: torch.Tensor, *, k: torch.Tensor, keepdim=False, dim=-1):
    r"""
    Compute the geodesic distance between :math:`x` and :math:`y` on the
    manifold.

    .. math::

        d_\kappa(x, y) = 2\tan_\kappa^{-1}(\|(-x)\oplus_\kappa y\|_2)

    .. plot:: plots/extended/stereographic/distance.py

    Parameters
    ----------
    x : tensor
        point on manifold
    y : tensor
        point on manifold
    k : tensor
        sectional curvature of manifold
    keepdim : bool
        retain the last dim? (default: false)
    dim : int
        reduction dimension

    Returns
    -------
    tensor
        geodesic distance between :math:`x` and :math:`y`
    """
    return _dist(x, y, k, keepdim=keepdim, dim=dim)


@torch.jit.script
def _dist(
    x: torch.Tensor,
    y: torch.Tensor,
    k: torch.Tensor,
    keepdim: bool = False,
    dim: int = -1,
):
    return 2.0 * artan_k(
        _mobius_add(-x, y, k, dim=dim).norm(dim=dim, p=2, keepdim=keepdim), k
    )


def dist0(x: torch.Tensor, *, k: torch.Tensor, keepdim=False, dim=-1):
    r"""
    Compute the geodesic distance to the manifold's origin.

    Parameters
    ----------
    x : tensor
        point on manifold
    k : tensor
        sectional curvature of manifold
    keepdim : bool
        retain the last dim? (default: false)
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        geodesic distance between :math:`x` and :math:`0`
    """
    return _dist0(x, k, keepdim=keepdim, dim=dim)


@torch.jit.script
def _dist0(x: torch.Tensor, k: torch.Tensor, keepdim: bool = False, dim: int = -1):
    return 2.0 * artan_k(x.norm(dim=dim, p=2, keepdim=keepdim), k)
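

# Illustrative sketch (not part of the original module): the distance to the
# origin agrees with the general distance evaluated against a zero vector.
def _example_dist0():
    k = torch.tensor(-1.0)
    x = torch.rand(4) * 0.5
    assert torch.allclose(_dist0(x, k), _dist(torch.zeros(4), x, k), atol=1e-6)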


def geodesic(
    t: torch.Tensor, x: torch.Tensor, y: torch.Tensor, *, k: torch.Tensor, dim=-1
):
    r"""
    Compute the point on the path connecting :math:`x` and :math:`y` at time
    :math:`t`.

    The path can also be treated as an extension of the line segment to an
    unbounded geodesic that goes through :math:`x` and :math:`y`. The equation
    of the geodesic is given as:

    .. math::

        \gamma_{x\to y}(t)
        =
        x \oplus_\kappa t \otimes_\kappa ((-x) \oplus_\kappa y)

    The properties of the geodesic are the following:

    .. math::

        \gamma_{x\to y}(0) = x\\
        \gamma_{x\to y}(1) = y\\
        \dot\gamma_{x\to y}(t) = v

    Furthermore, the geodesic also satisfies the property of local distance
    minimization:

    .. math::

        d_\kappa(\gamma_{x\to y}(t_1), \gamma_{x\to y}(t_2)) = v|t_1-t_2|

    "Natural parametrization" of the curve ensures unit speed geodesics which
    yields the above formula with :math:`v=1`. However, we can always compute
    the constant speed :math:`v` from the points that the particular path
    connects:

    .. math::

        v = d_\kappa(\gamma_{x\to y}(0), \gamma_{x\to y}(1)) = d_\kappa(x, y)

    Parameters
    ----------
    t : tensor
        travelling time
    x : tensor
        starting point on manifold
    y : tensor
        target point on manifold
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        point on the geodesic going through x and y
    """
    return _geodesic(t, x, y, k, dim=dim)


@torch.jit.script
def _geodesic(
    t: torch.Tensor, x: torch.Tensor, y: torch.Tensor, k: torch.Tensor, dim: int = -1
):
    # this is not very numerically stable
    v = _mobius_add(-x, y, k, dim=dim)
    tv = _mobius_scalar_mul(t, v, k, dim=dim)
    gamma_t = _mobius_add(x, tv, k, dim=dim)
    return gamma_t
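

# Illustrative sketch (not part of the original module): the geodesic
# interpolates its endpoints, γ(0) = x and γ(1) = y (the latter follows from
# the left-cancellation law).
def _example_geodesic_endpoints():
    k = torch.tensor(-1.0)
    x, y = torch.rand(4) * 0.3, torch.rand(4) * 0.3
    assert torch.allclose(_geodesic(torch.tensor(0.0), x, y, k), x, atol=1e-6)
    assert torch.allclose(_geodesic(torch.tensor(1.0), x, y, k), y, atol=1e-6)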


def expmap(x: torch.Tensor, u: torch.Tensor, *, k: torch.Tensor, dim=-1):
    r"""
    Compute the exponential map of :math:`u` at :math:`x`.

    The expmap is tightly related with :func:`geodesic`. Intuitively, the
    expmap represents a smooth travel along a geodesic from the starting point
    :math:`x`, into the initial direction :math:`u` at speed :math:`\|u\|_x`
    for the duration of one time unit. In formulas one can express this as the
    travel along the curve :math:`\gamma_{x, u}(t)` such that

    .. math::

        \gamma_{x, u}(0) = x\\
        \dot\gamma_{x, u}(0) = u\\
        \|\dot\gamma_{x, u}(t)\|_{\gamma_{x, u}(t)} = \|u\|_x

    The existence of this curve relies on the uniqueness of the differential
    equation solution, which is local. For the universal manifold the solution
    is well defined globally and we have

    .. math::

        \operatorname{exp}^\kappa_x(u) = \gamma_{x, u}(1) = \\
        x \oplus_\kappa \tan_\kappa(\|u\|_x/2) \frac{u}{\|u\|_2}

    Parameters
    ----------
    x : tensor
        starting point on manifold
    u : tensor
        speed vector in tangent space at x
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        :math:`\gamma_{x, u}(1)` end point
    """
    return _expmap(x, u, k, dim=dim)


@torch.jit.script
def _expmap(x: torch.Tensor, u: torch.Tensor, k: torch.Tensor, dim: int = -1):
    u_norm = u.norm(dim=dim, p=2, keepdim=True).clamp_min(1e-15)
    lam = _lambda_x(x, k, dim=dim, keepdim=True)
    second_term = tan_k((lam / 2.0) * u_norm, k) * (u / u_norm)
    y = _mobius_add(x, second_term, k, dim=dim)
    return y


def expmap0(u: torch.Tensor, *, k: torch.Tensor, dim=-1):
    r"""
    Compute the exponential map of :math:`u` at the origin :math:`0`.

    Since :math:`\lambda^\kappa_0 = 2`, the general formula simplifies to

    .. math::

        \operatorname{exp}^\kappa_0(u)
        =
        \tan_\kappa(\|u\|_2) \frac{u}{\|u\|_2}

    Parameters
    ----------
    u : tensor
        speed vector on manifold
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        :math:`\gamma_{0, u}(1)` end point
    """
    return _expmap0(u, k, dim=dim)


@torch.jit.script
def _expmap0(u: torch.Tensor, k: torch.Tensor, dim: int = -1):
    u_norm = u.norm(dim=dim, p=2, keepdim=True).clamp_min(1e-15)
    gamma_1 = tan_k(u_norm, k) * (u / u_norm)
    return gamma_1


def geodesic_unit(
    t: torch.Tensor, x: torch.Tensor, u: torch.Tensor, *, k: torch.Tensor, dim=-1
):
    r"""
    Compute the point on the unit speed geodesic.

    The point on the unit speed geodesic at time :math:`t`, starting from
    :math:`x` with initial direction :math:`u/\|u\|_x`, is computed as follows:

    .. math::

        \gamma_{x,u}(t) = x \oplus_\kappa \tan_\kappa(t/2) \frac{u}{\|u\|_2}

    Parameters
    ----------
    t : tensor
        travelling time
    x : tensor
        initial point on manifold
    u : tensor
        initial direction in tangent space at x
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        the point on the unit speed geodesic
    """
    return _geodesic_unit(t, x, u, k, dim=dim)


@torch.jit.script
def _geodesic_unit(
    t: torch.Tensor,
    x: torch.Tensor,
    u: torch.Tensor,
    k: torch.Tensor,
    dim: int = -1,
):
    u_norm = u.norm(dim=dim, p=2, keepdim=True).clamp_min(1e-15)
    second_term = tan_k(t / 2.0, k) * (u / u_norm)
    gamma_1 = _mobius_add(x, second_term, k, dim=dim)
    return gamma_1


def logmap(x: torch.Tensor, y: torch.Tensor, *, k: torch.Tensor, dim=-1):
    r"""
    Compute the logarithmic map of :math:`y` at :math:`x`.

    .. math::

        \operatorname{log}^\kappa_x(y) = \frac{2}{\lambda_x^\kappa}
        \tan_\kappa^{-1}(\|(-x)\oplus_\kappa y\|_2)
        * \frac{(-x)\oplus_\kappa y}{\|(-x)\oplus_\kappa y\|_2}

    The result of the logmap is a vector :math:`u` in the tangent space of
    :math:`x` such that

    .. math::

        y = \operatorname{exp}^\kappa_x(\operatorname{log}^\kappa_x(y))

    Parameters
    ----------
    x : tensor
        starting point on manifold
    y : tensor
        target point on manifold
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        tangent vector :math:`u\in T_x M` that transports :math:`x` to :math:`y`
    """
    return _logmap(x, y, k, dim=dim)


@torch.jit.script
def _logmap(x: torch.Tensor, y: torch.Tensor, k: torch.Tensor, dim: int = -1):
    sub = _mobius_add(-x, y, k, dim=dim)
    sub_norm = sub.norm(dim=dim, p=2, keepdim=True).clamp_min(1e-15)
    lam = _lambda_x(x, k, keepdim=True, dim=dim)
    return 2.0 * artan_k(sub_norm, k) * (sub / (lam * sub_norm))
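

# Illustrative sketch (not part of the original module): logmap inverts
# expmap, so exp_x(log_x(y)) recovers y.
def _example_logmap_roundtrip():
    k = torch.tensor(-1.0)
    x, y = torch.rand(4) * 0.3, torch.rand(4) * 0.3
    assert torch.allclose(_expmap(x, _logmap(x, y, k), k), y, atol=1e-5)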


def logmap0(y: torch.Tensor, *, k: torch.Tensor, dim=-1):
    r"""
    Compute the logarithmic map of :math:`y` at the origin :math:`0`.

    .. math::

        \operatorname{log}^\kappa_0(y) = \tan_\kappa^{-1}(\|y\|_2) \frac{y}{\|y\|_2}

    The result of the logmap at the origin is a vector :math:`u` in the
    tangent space of the origin :math:`0` such that

    .. math::

        y = \operatorname{exp}^\kappa_0(\operatorname{log}^\kappa_0(y))

    Parameters
    ----------
    y : tensor
        target point on manifold
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        tangent vector :math:`u\in T_0 M` that transports :math:`0` to :math:`y`
    """
    return _logmap0(y, k, dim=dim)


@torch.jit.script
def _logmap0(y: torch.Tensor, k: torch.Tensor, dim: int = -1):
    y_norm = y.norm(dim=dim, p=2, keepdim=True).clamp_min(1e-15)
    return (y / y_norm) * artan_k(y_norm, k)
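

# Illustrative sketch (not part of the original module): the same round trip
# at the origin, here with positive curvature.
def _example_logmap0_roundtrip():
    k = torch.tensor(1.0)
    y = torch.rand(4) * 0.3
    assert torch.allclose(_expmap0(_logmap0(y, k), k), y, atol=1e-5)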


def mobius_matvec(m: torch.Tensor, x: torch.Tensor, *, k: torch.Tensor, dim=-1):
    r"""
    Compute the generalization of matrix-vector multiplication in gyrovector
    spaces.

    The Möbius matrix-vector operation is defined as follows:

    .. math::

        M \otimes_\kappa x = \tan_\kappa\left(
            \frac{\|Mx\|_2}{\|x\|_2}\tan_\kappa^{-1}(\|x\|_2)
        \right)\frac{Mx}{\|Mx\|_2}

    .. plot:: plots/extended/stereographic/mobius_matvec.py

    Parameters
    ----------
    m : tensor
        matrix for multiplication. Batched matmul is performed if
        ``m.dim() > 2``, but only last dim reduction is supported
    x : tensor
        point on manifold
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        Möbius matvec result
    """
    return _mobius_matvec(m, x, k, dim=dim)


@torch.jit.script
def _mobius_matvec(m: torch.Tensor, x: torch.Tensor, k: torch.Tensor, dim: int = -1):
    if m.dim() > 2 and dim != -1:
        raise RuntimeError(
            "broadcasted Möbius matvec is supported for the last dim only"
        )
    x_norm = x.norm(dim=dim, keepdim=True, p=2).clamp_min(1e-15)
    if dim != -1 or m.dim() == 2:
        mx = torch.tensordot(x, m, ([dim], [1]))
    else:
        mx = torch.matmul(m, x.unsqueeze(-1)).squeeze(-1)
    mx_norm = mx.norm(dim=dim, keepdim=True, p=2).clamp_min(1e-15)
    res_c = tan_k(mx_norm / x_norm * artan_k(x_norm, k), k) * (mx / mx_norm)
    cond = (mx == 0).prod(dim=dim, keepdim=True, dtype=torch.bool)
    res_0 = torch.zeros(1, dtype=res_c.dtype, device=res_c.device)
    res = torch.where(cond, res_0, res_c)
    return res
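

# Illustrative sketch (not part of the original module): the Möbius matvec is
# equivalent to applying the matrix in the tangent space at the origin,
# M ⊗ x = exp_0(M log_0(x)), which we can confirm numerically.
def _example_matvec_via_tangent():
    k = torch.tensor(-1.0)
    m = torch.rand(3, 4)
    x = torch.rand(4) * 0.3
    direct = _mobius_matvec(m, x, k)
    via_tangent = _expmap0(m @ _logmap0(x, k), k)
    assert torch.allclose(direct, via_tangent, atol=1e-5)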


# TODO: check if this extends to gyrovector spaces for positive curvature
# TODO: add plot
def mobius_pointwise_mul(w: torch.Tensor, x: torch.Tensor, *, k: torch.Tensor, dim=-1):
    r"""
    Compute the generalization of point-wise multiplication in gyrovector
    spaces.

    The Möbius point-wise multiplication is defined as follows

    .. math::

        \operatorname{diag}(w) \otimes_\kappa x = \tan_\kappa\left(
            \frac{\|\operatorname{diag}(w)x\|_2}{\|x\|_2}\tan_\kappa^{-1}(\|x\|_2)
        \right)\frac{\operatorname{diag}(w)x}{\|\operatorname{diag}(w)x\|_2}

    Parameters
    ----------
    w : tensor
        weights for multiplication (should be broadcastable to x)
    x : tensor
        point on manifold
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        Möbius point-wise mul result
    """
    return _mobius_pointwise_mul(w, x, k, dim=dim)


@torch.jit.script
def _mobius_pointwise_mul(
    w: torch.Tensor, x: torch.Tensor, k: torch.Tensor, dim: int = -1
):
    x_norm = x.norm(dim=dim, keepdim=True, p=2).clamp_min(1e-15)
    wx = w * x
    wx_norm = wx.norm(dim=dim, keepdim=True, p=2).clamp_min(1e-15)
    res_c = tan_k(wx_norm / x_norm * artan_k(x_norm, k), k) * (wx / wx_norm)
    zero = torch.zeros((), dtype=res_c.dtype, device=res_c.device)
    cond = wx.isclose(zero).prod(dim=dim, keepdim=True, dtype=torch.bool)
    res = torch.where(cond, zero, res_c)
    return res


def mobius_fn_apply_chain(x: torch.Tensor, *fns: callable, k: torch.Tensor, dim=-1):
    r"""
    Compute the generalization of sequential function application in gyrovector
    spaces.

    First, a gyrovector is mapped to the tangent space (first-order approx.)
    via :math:`\operatorname{log}^\kappa_0` and then the sequence of functions
    is applied to the vector in the tangent space. The resulting tangent
    vector is then mapped back with :math:`\operatorname{exp}^\kappa_0`.

    .. math::

        f^{\otimes_\kappa}(x)
        =
        \operatorname{exp}^\kappa_0(f(\operatorname{log}^\kappa_0(x)))

    The definition of the Möbius function application allows chaining, since

    .. math::

        x = \operatorname{exp}^\kappa_0(\operatorname{log}^\kappa_0(x))

    Resulting in

    .. math::

        (f \circ g)^{\otimes_\kappa}(x)
        =
        \operatorname{exp}^\kappa_0(
            (f \circ g) (\operatorname{log}^\kappa_0(x))
        )

    Parameters
    ----------
    x : tensor
        point on manifold
    fns : callable[]
        functions to apply
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        Apply chain result
    """
    if not fns:
        return x
    else:
        ex = _logmap0(x, k, dim=dim)
        for fn in fns:
            ex = fn(ex)
        y = _expmap0(ex, k, dim=dim)
        return y


def mobius_fn_apply(
    fn: callable, x: torch.Tensor, *args, k: torch.Tensor, dim=-1, **kwargs
):
    r"""
    Compute the generalization of function application in gyrovector spaces.

    First, a gyrovector is mapped to the tangent space (first-order approx.)
    via :math:`\operatorname{log}^\kappa_0` and then the function is applied
    to the vector in the tangent space. The resulting tangent vector is then
    mapped back with :math:`\operatorname{exp}^\kappa_0`.

    .. math::

        f^{\otimes_\kappa}(x)
        =
        \operatorname{exp}^\kappa_0(f(\operatorname{log}^\kappa_0(x)))

    .. plot:: plots/extended/stereographic/mobius_sigmoid_apply.py

    Parameters
    ----------
    x : tensor
        point on manifold
    fn : callable
        function to apply
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        Result of function in hyperbolic space
    """
    ex = _logmap0(x, k, dim=dim)
    ex = fn(ex, *args, **kwargs)
    y = _expmap0(ex, k, dim=dim)
    return y


def mobiusify(fn: callable):
    r"""
    Wrap a function so that it works in gyrovector spaces.

    Parameters
    ----------
    fn : callable
        function in Euclidean space

    Returns
    -------
    callable
        function working in gyrovector spaces

    Notes
    -----
    The new function will accept additional arguments ``k`` and ``dim``.
    """

    @functools.wraps(fn)
    def mobius_fn(x, *args, k, dim=-1, **kwargs):
        ex = _logmap0(x, k, dim=dim)
        ex = fn(ex, *args, **kwargs)
        y = _expmap0(ex, k, dim=dim)
        return y

    return mobius_fn
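

# Illustrative sketch (not part of the original module): lift a Euclidean
# nonlinearity to the gyrovector space and apply it to a point on the ball.
def _example_mobiusify():
    mobius_tanh = mobiusify(torch.tanh)
    k = torch.tensor(-1.0)
    x = torch.rand(4) * 0.3
    y = mobius_tanh(x, k=k)
    assert y.shape == x.shape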


def dist2plane(
    x: torch.Tensor,
    p: torch.Tensor,
    a: torch.Tensor,
    *,
    k: torch.Tensor,
    keepdim=False,
    signed=False,
    scaled=False,
    dim=-1,
):
    r"""
    Compute the geodesic distance from :math:`x` to a hyperplane
    :math:`H_{a, b}`.

    The hyperplane is such that its set of points is orthogonal to :math:`a`
    and contains :math:`p`.

    .. plot:: plots/extended/stereographic/distance2plane.py

    To form an intuition of what a hyperplane in gyrovector spaces looks like,
    let's first consider a Euclidean hyperplane

    .. math::

        H_{a, b} = \left\{
            x \in \mathbb{R}^n\;:\;\langle x, a\rangle - b = 0
        \right\},

    where :math:`a\in \mathbb{R}^n\backslash \{\mathbf{0}\}` and
    :math:`b\in \mathbb{R}`.

    This formulation of a hyperplane is hard to generalize, therefore we can
    rewrite :math:`\langle x, a\rangle - b` utilizing orthogonal completion.
    Setting any :math:`p` s.t. :math:`b=\langle a, p\rangle` we have

    .. math::

        H_{a, b} = \left\{
            x \in \mathbb{R}^n\;:\;\langle x, a\rangle - b = 0
        \right\}\\
        = H_{a, \langle a, p\rangle} = \tilde{H}_{a, p}\\
        = \left\{
            x \in \mathbb{R}^n\;:\;\langle x, a\rangle - \langle a, p\rangle = 0
        \right\}\\
        = \left\{
            x \in \mathbb{R}^n\;:\;\langle -p + x, a\rangle = 0
        \right\}\\
        = p + \{a\}^\perp

    Naturally we have a set :math:`\{a\}^\perp` with the :math:`+` operator
    applied to each element. Generalizing the notion of summation to the
    gyrovector space we replace :math:`+` with :math:`\oplus_\kappa`.

    Next, we should figure out what :math:`\{a\}^\perp` is in the gyrovector
    space. The first thing we should acknowledge is that the notion of
    orthogonality is defined for vectors in tangent spaces. Let's consider now
    :math:`p\in \mathcal{M}_\kappa^n` and
    :math:`a\in T_p\mathcal{M}_\kappa^n\backslash \{\mathbf{0}\}`.
    Slightly deviating from traditional notation let's write
    :math:`\{a\}_p^\perp`, highlighting the tight relationship of
    :math:`a\in T_p\mathcal{M}_\kappa^n\backslash \{\mathbf{0}\}` with
    :math:`p \in \mathcal{M}_\kappa^n`. We then define

    .. math::

        \{a\}_p^\perp := \left\{
            z\in T_p\mathcal{M}_\kappa^n \;:\; \langle z, a\rangle_p = 0
        \right\}

    Recalling that a tangent vector :math:`z` for point :math:`p` yields
    :math:`x = \operatorname{exp}^\kappa_p(z)`, we rewrite the above equation
    as

    .. math::

        \{a\}_p^\perp := \left\{
            x\in \mathcal{M}_\kappa^n \;:\; \langle
            \operatorname{log}_p^\kappa(x), a\rangle_p = 0
        \right\}

    This formulation is something more pleasant to work with. Putting all
    together

    .. math::

        \tilde{H}_{a, p}^\kappa = p + \{a\}^\perp_p\\
        = \left\{
            x \in \mathcal{M}_\kappa^n\;:\;\langle
            \operatorname{log}^\kappa_p(x), a\rangle_p = 0
        \right\}\\
        = \left\{
            x \in \mathcal{M}_\kappa^n\;:\;\langle -p \oplus_\kappa x, a\rangle = 0
        \right\}

    To compute the distance :math:`d_\kappa(x, \tilde{H}_{a, p}^\kappa)` we
    find

    .. math::

        d_\kappa(x, \tilde{H}_{a, p}^\kappa)
        =
        \inf_{w\in \tilde{H}_{a, p}^\kappa} d_\kappa(x, w)\\
        =
        \sin^{-1}_\kappa\left\{
            \frac{
                2 |\langle (-p)\oplus_\kappa x, a\rangle|
            }{
                (1 + \kappa\|(-p)\oplus_\kappa x\|^2_2)\|a\|_2
            }
        \right\}

    Parameters
    ----------
    x : tensor
        point on manifold to compute distance for
    a : tensor
        hyperplane normal vector in tangent space of :math:`p`
    p : tensor
        point on manifold lying on the hyperplane
    k : tensor
        sectional curvature of manifold
    keepdim : bool
        retain the last dim? (default: false)
    signed : bool
        return signed distance
    scaled : bool
        scale distance by tangent norm
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        distance to the hyperplane
    """
    return _dist2plane(
        x, a, p, k, keepdim=keepdim, signed=signed, dim=dim, scaled=scaled
    )


@torch.jit.script
def _dist2plane(
    x: torch.Tensor,
    a: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
    keepdim: bool = False,
    signed: bool = False,
    scaled: bool = False,
    dim: int = -1,
):
    diff = _mobius_add(-p, x, k, dim=dim)
    diff_norm2 = diff.pow(2).sum(dim=dim, keepdim=keepdim).clamp_min(1e-15)
    sc_diff_a = (diff * a).sum(dim=dim, keepdim=keepdim)
    if not signed:
        sc_diff_a = sc_diff_a.abs()
    a_norm = a.norm(dim=dim, keepdim=keepdim, p=2)
    num = 2.0 * sc_diff_a
    denom = clamp_abs((1 + k * diff_norm2) * a_norm)
    distance = arsin_k(num / denom, k)
    if scaled:
        distance = distance * a_norm
    return distance
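

# Illustrative sketch (not part of the original module): a point lying on the
# hyperplane has zero distance to it, e.g. x = p itself.
def _example_dist2plane_zero():
    k = torch.tensor(-1.0)
    p = torch.rand(4) * 0.3
    a = torch.rand(4)
    d = _dist2plane(p, a, p, k)
    assert torch.allclose(d, torch.tensor(0.0), atol=1e-6)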


def parallel_transport(
    x: torch.Tensor, y: torch.Tensor, v: torch.Tensor, *, k: torch.Tensor, dim=-1
):
    r"""
    Compute the parallel transport of :math:`v` from :math:`x` to :math:`y`.

    The parallel transport is essential for adaptive algorithms on Riemannian
    manifolds. For gyrovector spaces the parallel transport is expressed
    through the gyration.

    .. plot:: plots/extended/stereographic/gyrovector_parallel_transport.py

    To recover parallel transport we first need to study isomorphisms between
    gyrovectors and vectors. The reason is that originally, parallel transport
    is well defined for gyrovectors as

    .. math::

        P_{x\to y}(z) = \operatorname{gyr}[y, -x]z,

    where :math:`x,\:y,\:z \in \mathcal{M}_\kappa^n` and
    :math:`\operatorname{gyr}[a, b]c = \ominus (a \oplus_\kappa b)
    \oplus_\kappa (a \oplus_\kappa (b \oplus_\kappa c))`

    But we want to obtain parallel transport for vectors, not for gyrovectors.
    The blessing is the isomorphism mentioned above. This mapping is given by

    .. math::

        U^\kappa_p \: : \: T_p\mathcal{M}_\kappa^n \to \mathbb{G},
        \quad v \mapsto \lambda^\kappa_p v

    Finally, having the points :math:`x,\:y \in \mathcal{M}_\kappa^n` and a
    tangent vector :math:`u\in T_x\mathcal{M}_\kappa^n` we obtain

    .. math::

        P^\kappa_{x\to y}(v)
        =
        (U^\kappa_y)^{-1}\left(\operatorname{gyr}[y, -x] U^\kappa_x(v)\right)\\
        =
        \operatorname{gyr}[y, -x] v \lambda^\kappa_x / \lambda^\kappa_y

    .. plot:: plots/extended/stereographic/parallel_transport.py

    Parameters
    ----------
    x : tensor
        starting point
    y : tensor
        end point
    v : tensor
        tangent vector at x to be transported to y
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        transported vector
    """
    return _parallel_transport(x, y, v, k, dim=dim)


@torch.jit.script
def _parallel_transport(
    x: torch.Tensor, y: torch.Tensor, u: torch.Tensor, k: torch.Tensor, dim: int = -1
):
    return (
        _gyration(y, -x, u, k, dim=dim)
        * _lambda_x(x, k, keepdim=True, dim=dim)
        / _lambda_x(y, k, keepdim=True, dim=dim)
    )


def parallel_transport0(y: torch.Tensor, v: torch.Tensor, *, k: torch.Tensor, dim=-1):
    r"""
    Compute the parallel transport of :math:`v` from the origin :math:`0` to
    :math:`y`.

    This is just a special case of the parallel transport with the starting
    point at the origin that can be computed more efficiently and in a more
    numerically stable way.

    Parameters
    ----------
    y : tensor
        target point
    v : tensor
        vector to be transported from the origin to y
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        transported vector
    """
    return _parallel_transport0(y, v, k, dim=dim)


@torch.jit.script
def _parallel_transport0(
    y: torch.Tensor, v: torch.Tensor, k: torch.Tensor, dim: int = -1
):
    return v * (1 + k * y.pow(2).sum(dim=dim, keepdim=True)).clamp_min(1e-15)


def parallel_transport0back(
    x: torch.Tensor, v: torch.Tensor, *, k: torch.Tensor, dim: int = -1
):
    r"""
    Perform parallel transport to the zero point.

    This is a special case of parallel transport with the last point at zero
    that can be computed more efficiently and in a more numerically stable
    way.

    Parameters
    ----------
    x : tensor
        target point
    v : tensor
        vector to be transported
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        transported vector
    """
    return _parallel_transport0back(x, v, k=k, dim=dim)


@torch.jit.script
def _parallel_transport0back(
    x: torch.Tensor, v: torch.Tensor, k: torch.Tensor, dim: int = -1
):
    return v / (1 + k * x.pow(2).sum(dim=dim, keepdim=True)).clamp_min(1e-15)
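

# Illustrative sketch (not part of the original module): parallel transport is
# an isometry between tangent spaces, so the Riemannian norm of v is
# preserved.
def _example_transport_isometry():
    k = torch.tensor(-1.0)
    x, y = torch.rand(4) * 0.3, torch.rand(4) * 0.3
    v = torch.rand(4)
    vt = _parallel_transport(x, y, v, k)
    assert torch.allclose(_norm(y, vt, k), _norm(x, v, k), atol=1e-5)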


def egrad2rgrad(x: torch.Tensor, grad: torch.Tensor, *, k: torch.Tensor, dim=-1):
    r"""
    Convert the Euclidean gradient to the Riemannian gradient.

    .. math::

        \nabla_x = \nabla^E_x / (\lambda_x^\kappa)^2

    Parameters
    ----------
    x : tensor
        point on the manifold
    grad : tensor
        Euclidean gradient for :math:`x`
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        Riemannian gradient :math:`u\in T_x\mathcal{M}_\kappa^n`
    """
    return _egrad2rgrad(x, grad, k, dim=dim)


@torch.jit.script
def _egrad2rgrad(x: torch.Tensor, grad: torch.Tensor, k: torch.Tensor, dim: int = -1):
    return grad / _lambda_x(x, k, keepdim=True, dim=dim) ** 2


def sproj(x: torch.Tensor, *, k: torch.Tensor, dim: int = -1):
    """
    Stereographic Projection from hyperboloid or sphere.

    Parameters
    ----------
    x : tensor
        point to be projected
    k : tensor
        constant sectional curvature
    dim : int
        dimension to operate on

    Returns
    -------
    tensor
        the result of the projection
    """
    return _sproj(x, k, dim=dim)


@torch.jit.script
def _sproj(x: torch.Tensor, k: torch.Tensor, dim: int = -1):
    inv_r = torch.sqrt(sabs(k))
    factor = 1.0 / (1.0 + inv_r * x.narrow(dim, -1, 1))
    proj = factor * x.narrow(dim, 0, x.size(dim) - 1)
    return proj


def inv_sproj(x: torch.Tensor, *, k: torch.Tensor, dim: int = -1):
    """
    Inverse of Stereographic Projection to hyperboloid or sphere.

    Parameters
    ----------
    x : tensor
        point to be projected
    k : tensor
        constant sectional curvature
    dim : int
        dimension to operate on

    Returns
    -------
    tensor
        the result of the projection
    """
    return _inv_sproj(x, k, dim=dim)


@torch.jit.script
def _inv_sproj(x: torch.Tensor, k: torch.Tensor, dim: int = -1):
    inv_r = torch.sqrt(sabs(k))
    lam_x = _lambda_x(x, k, keepdim=True, dim=dim)
    A = lam_x * x
    B = 1.0 / inv_r * (lam_x - 1.0)
    proj = torch.cat((A, B), dim=dim)
    return proj


def antipode(x: torch.Tensor, *, k: torch.Tensor, dim: int = -1):
    r"""
    Compute the antipode of a point :math:`x_1,...,x_n` for :math:`\kappa > 0`.

    Let :math:`x` be a point on some sphere. Then :math:`-x` is its antipode.
    Since we're dealing with stereographic projections, for :math:`sproj(x)`
    we get the antipode :math:`sproj(-x)`, which is given as follows:

    .. math::

        \text{antipode}(x)
        =
        \frac{1+\kappa\|x\|^2_2}{2\kappa\|x\|^2_2}(-x)

    Parameters
    ----------
    x : tensor
        points :math:`x_1,...,x_n` on manifold to compute antipode for
    k : tensor
        sectional curvature of manifold
    dim : int
        reduction dimension for operations

    Returns
    -------
    tensor
        antipode
    """
    return _antipode(x, k, dim=dim)


@torch.jit.script
def _antipode(x: torch.Tensor, k: torch.Tensor, dim: int = -1):
    # NOTE: implementation that uses stereographic projections seems to be
    # less accurate: sproj(-inv_sproj(x))
    if torch.all(k.le(0)):
        return -x
    v = x / x.norm(p=2, dim=dim, keepdim=True).clamp_min(1e-15)
    R = sabs(k).sqrt().reciprocal()
    pi = 3.141592653589793
    a = _geodesic_unit(pi * R, x, v, k, dim=dim)
    return torch.where(k.gt(0), a, -x)


def weighted_midpoint(
    xs: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    *,
    k: torch.Tensor,
    reducedim: Optional[List[int]] = None,
    dim: int = -1,
    keepdim: bool = False,
    lincomb: bool = False,
    posweight: bool = False,
):
    r"""
    Compute the weighted Möbius gyromidpoint.

    The weighted Möbius gyromidpoint of a set of points
    :math:`x_1,...,x_n` according to weights
    :math:`\alpha_1,...,\alpha_n` is computed as follows:

    .. math::

        m_{\kappa}(x_1,\ldots,x_n,\alpha_1,\ldots,\alpha_n)
        =
        \frac{1}{2} \otimes_\kappa \left(\sum_{i=1}^n
        \frac{
            \alpha_i\lambda_{x_i}^\kappa
        }{
            \sum_{j=1}^n\alpha_j(\lambda_{x_j}^\kappa-1)
        }x_i
        \right)

    where the weights :math:`\alpha_1,...,\alpha_n` do not necessarily need
    to sum to 1 (only their relative weight matters). Note that this formula
    also requires choosing between the midpoint and its antipode for
    :math:`\kappa > 0`.

    Parameters
    ----------
    xs : tensor
        points on poincare ball
    weights : tensor
        weights for averaging (make sure they broadcast correctly and
        manifold dimension is skipped)
    reducedim : int|list|tuple
        reduce dimension
    dim : int
        dimension to calculate conformal and Lorentz factors
    k : tensor
        constant sectional curvature
    keepdim : bool
        retain the last dim? (default: false)
    lincomb : bool
        linear combination implementation
    posweight : bool
        make all weights positive. Negative weight will weight antipode of
        entry with positive weight instead. This will give experimentally
        better numerics and nice interpolation properties for linear
        combination and averaging

    Returns
    -------
    tensor
        Einstein midpoint in poincare coordinates
    """
    return _weighted_midpoint(
        xs=xs,
        k=k,
        weights=weights,
        reducedim=reducedim,
        dim=dim,
        keepdim=keepdim,
        lincomb=lincomb,
        posweight=posweight,
    )


@torch.jit.script
def _weighted_midpoint(
    xs: torch.Tensor,
    k: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    reducedim: Optional[List[int]] = None,
    dim: int = -1,
    keepdim: bool = False,
    lincomb: bool = False,
    posweight: bool = False,
):
    if reducedim is None:
        reducedim = list_range(xs.dim())
        reducedim.pop(dim)
    gamma = _lambda_x(xs, k=k, dim=dim, keepdim=True)
    if weights is None:
        weights = torch.tensor(1.0, dtype=xs.dtype, device=xs.device)
    else:
        weights = weights.unsqueeze(dim)
    if posweight and weights.lt(0).any():
        xs = torch.where(weights.lt(0), _antipode(xs, k=k, dim=dim), xs)
        weights = weights.abs()
    denominator = ((gamma - 1) * weights).sum(reducedim, keepdim=True)
    nominator = (gamma * weights * xs).sum(reducedim, keepdim=True)
    two_mean = nominator / clamp_abs(denominator, 1e-10)
    a_mean = _mobius_scalar_mul(
        torch.tensor(0.5, dtype=xs.dtype, device=xs.device), two_mean, k=k, dim=dim
    )
    if torch.any(k.gt(0)):
        # check antipode
        b_mean = _antipode(a_mean, k, dim=dim)
        a_dist = _dist(a_mean, xs, k=k, keepdim=True, dim=dim).sum(
            reducedim, keepdim=True
        )
        b_dist = _dist(b_mean, xs, k=k, keepdim=True, dim=dim).sum(
            reducedim, keepdim=True
        )
        better = k.gt(0) & (b_dist < a_dist)
        a_mean = torch.where(better, b_mean, a_mean)
    if lincomb:
        if weights.numel() == 1:
            alpha = weights.clone()
            for d in reducedim:
                alpha *= xs.size(d)
        else:
            weights, _ = torch.broadcast_tensors(weights, gamma)
            alpha = weights.sum(reducedim, keepdim=True)
        a_mean = _mobius_scalar_mul(alpha, a_mean, k=k, dim=dim)
    if not keepdim:
        a_mean = drop_dims(a_mean, reducedim)
    return a_mean
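

# Illustrative sketch (not part of the original module): with two points and
# equal weights the gyromidpoint lies halfway along the connecting geodesic,
# so it is equidistant from both points.
def _example_midpoint():
    k = torch.tensor(-1.0)
    xs = torch.stack([torch.rand(4) * 0.3, torch.rand(4) * 0.3])
    mid = _weighted_midpoint(xs, k)
    d0 = _dist(mid, xs[0], k)
    d1 = _dist(mid, xs[1], k)
    assert torch.allclose(d0, d1, atol=1e-5)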