# Source code for geoopt.manifolds.euclidean

from typing import Union, Tuple, Optional
import torch
from .base import Manifold, ScalingInfo
import geoopt

__all__ = ["Euclidean"]


def _size2shape(*size) -> Tuple[int, ...]:
    """
    Normalize a size given either as separate ints or as a single sequence.

    ``_size2shape(2, 3)``, ``_size2shape((2, 3))``, ``_size2shape([2, 3])`` and
    ``_size2shape(torch.Size([2, 3]))`` all return ``(2, 3)``.
    """
    # torch.Size is a tuple subclass, so it is covered by the isinstance check
    if len(size) == 1 and isinstance(size[0], (tuple, list)):
        return tuple(size[0])
    return tuple(size)


class Euclidean(Manifold):
    """
    Simple Euclidean manifold, every coordinate is treated as an independent element.

    Parameters
    ----------
    ndim : int
        number of trailing dimensions treated as manifold dimensions. All operations
        that reduce over the manifold, such as inner products, norms and distances,
        reduce over these :attr:`ndim` trailing dimensions.
    """

    __scaling__ = Manifold.__scaling__.copy()
    name = "Euclidean"
    ndim = 0
    reversible = True

    def __init__(self, ndim=0):
        super().__init__()
        self.ndim = ndim

    def _manifold_dims(self) -> Tuple[int, ...]:
        """Negative indices of the trailing :attr:`ndim` manifold dimensions."""
        return tuple(range(-self.ndim, 0))

    def _check_point_on_manifold(
        self, x: torch.Tensor, *, atol=1e-5, rtol=1e-5
    ) -> Union[Tuple[bool, Optional[str]], bool]:
        # every tensor is a valid point in Euclidean space
        return True, None

    def _check_vector_on_tangent(
        self, x: torch.Tensor, u: torch.Tensor, *, atol=1e-5, rtol=1e-5
    ) -> Union[Tuple[bool, Optional[str]], bool]:
        # every tensor is a valid tangent vector in Euclidean space
        return True, None

    def retr(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Retraction; in Euclidean space it is plain addition ``x + u``."""
        return x + u

    def inner(
        self,
        x: torch.Tensor,
        u: torch.Tensor,
        v: Optional[torch.Tensor] = None,
        *,
        keepdim=False,
    ) -> torch.Tensor:
        """
        Euclidean inner product ``<u, v>`` (``<u, u>`` when ``v`` is None).

        The elementwise product is summed over the trailing :attr:`ndim`
        dimensions (nothing is summed when ``ndim == 0``), and the result is
        expanded to broadcast against the batch shape of ``x``.
        """
        if v is None:
            inner = u.pow(2)
        else:
            inner = u * v
        if self.ndim > 0:
            inner = inner.sum(dim=self._manifold_dims(), keepdim=keepdim)
            # batch part of x's shape, plus size-1 manifold dims if they were kept
            x_shape = tuple(x.shape[: -self.ndim]) + (
                (1,) * self.ndim if keepdim else ()
            )
        else:
            x_shape = x.shape
        target_shape = torch.broadcast_shapes(x_shape, inner.shape)
        return inner.expand(target_shape)

    def component_inner(
        self, x: torch.Tensor, u: torch.Tensor, v: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Per-component inner product (no reduction over manifold dimensions).

        Returns ``u * v`` (or ``u ** 2`` when ``v`` is None) expanded to
        broadcast against the shape of ``x``.
        """
        # it is possible to factorize the manifold
        if v is None:
            inner = u.pow(2)
        else:
            inner = u * v
        target_shape = torch.broadcast_shapes(x.shape, inner.shape)
        return inner.expand(target_shape)

    def norm(self, x: torch.Tensor, u: torch.Tensor, *, keepdim=False):
        """
        Euclidean norm of ``u`` over the trailing :attr:`ndim` dimensions.

        With ``ndim == 0`` this is the elementwise absolute value.
        """
        if self.ndim > 0:
            return u.norm(dim=self._manifold_dims(), keepdim=keepdim)
        return u.abs()

    def proju(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Project ``u`` onto the tangent space at ``x``; only broadcasting happens."""
        target_shape = torch.broadcast_shapes(x.shape, u.shape)
        return u.expand(target_shape)

    def projx(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` onto the manifold; every tensor already lies on it."""
        return x

    def logmap(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Logarithmic map; in Euclidean space it is ``y - x``."""
        return y - x

    def dist(self, x: torch.Tensor, y: torch.Tensor, *, keepdim=False) -> torch.Tensor:
        """
        Euclidean distance between ``x`` and ``y`` over the trailing :attr:`ndim` dims.

        With ``ndim == 0`` this is the elementwise ``|x - y|``.
        """
        if self.ndim > 0:
            return (x - y).norm(dim=self._manifold_dims(), keepdim=keepdim)
        return (x - y).abs()

    def dist2(self, x: torch.Tensor, y: torch.Tensor, *, keepdim=False) -> torch.Tensor:
        """
        Squared Euclidean distance over the trailing :attr:`ndim` dims.

        With ``ndim == 0`` this is the elementwise ``(x - y) ** 2``.
        """
        if self.ndim > 0:
            return (x - y).pow(2).sum(dim=self._manifold_dims(), keepdim=keepdim)
        return (x - y).pow(2)

    def egrad2rgrad(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """
        Convert a Euclidean gradient into a Riemannian gradient.

        The two coincide on this manifold, so ``u`` is only expanded to
        broadcast against the shape of ``x``.
        """
        target_shape = torch.broadcast_shapes(x.shape, u.shape)
        return u.expand(target_shape)

    def expmap(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Exponential map; in Euclidean space it is ``x + u``."""
        return x + u

    def transp(self, x: torch.Tensor, y: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """
        Parallel transport of ``v`` from ``x`` to ``y``.

        Transport is the identity here; ``v`` is only expanded to the common
        broadcast shape of ``x``, ``y`` and ``v``.
        """
        target_shape = torch.broadcast_shapes(x.shape, y.shape, v.shape)
        return v.expand(target_shape)

    @__scaling__(ScalingInfo(std=-1), "random")
    def random_normal(
        self, *size, mean=0.0, std=1.0, device=None, dtype=None
    ) -> "geoopt.ManifoldTensor":
        """
        Create a point on the manifold, measure is induced by Normal distribution.

        Parameters
        ----------
        size : shape
            the desired shape
        mean : float|tensor
            mean value for the Normal distribution
        std : float|tensor
            std value for the Normal distribution
        device : torch.device
            the desired device
        dtype : torch.dtype
            the desired dtype

        Returns
        -------
        ManifoldTensor
            random point on the manifold
        """
        self._assert_check_shape(_size2shape(*size), "x")
        mean = torch.as_tensor(mean, device=device, dtype=dtype)
        std = torch.as_tensor(std, device=device, dtype=dtype)
        # new_empty inherits std's device/dtype, so the sample matches them
        tens = std.new_empty(*size).normal_() * std + mean
        return geoopt.ManifoldTensor(tens, manifold=self)

    random = random_normal

    def origin(
        self, *size, dtype=None, device=None, seed=42
    ) -> "geoopt.ManifoldTensor":
        """
        Zero point origin.

        Parameters
        ----------
        size : shape
            the desired shape
        device : torch.device
            the desired device
        dtype : torch.dtype
            the desired dtype
        seed : int
            ignored

        Returns
        -------
        ManifoldTensor
            tensor of zeros of the requested shape
        """
        self._assert_check_shape(_size2shape(*size), "x")
        return geoopt.ManifoldTensor(
            torch.zeros(*size, dtype=dtype, device=device), manifold=self
        )

    def extra_repr(self):
        return "ndim={}".format(self.ndim)