import torch.nn
from typing import Tuple, Union, Optional
import operator
import functools
import geoopt.utils
from ..utils import size2shape
from .base import Manifold, ScalingInfo
from .stereographic import Stereographic
__all__ = ["ProductManifold", "StereographicProductManifold"]
def _shape2size(shape: Tuple[int]):
return functools.reduce(operator.mul, shape, 1)
def _calculate_target_batch_dim(*dims: int):
return max(dims) - 1
class ProductManifold(Manifold):
    """
    Product Manifold.

    A point on the product is stored as one ambient tensor whose last
    dimension is the concatenation of the flattened points of every
    submanifold (see :meth:`pack_point` and :meth:`take_submanifold_value`).

    Parameters
    ----------
    manifolds_with_shape : tuple[Manifold, Union[tuple[int, ...], int]]
        ``(manifold, shape)`` pairs, at least one is required

    Examples
    --------
    A Torus

    >>> import geoopt
    >>> sphere = geoopt.Sphere()
    >>> torus = ProductManifold((sphere, 2), (sphere, 2))
    """

    ndim = 1

    def __init__(
        self,
        *manifolds_with_shape: Tuple[Manifold, Union[Tuple[int, ...], int]],
    ):
        if len(manifolds_with_shape) < 1:
            raise ValueError(
                "There should be at least one manifold in a product manifold"
            )
        super().__init__()
        self.shapes = []
        self.slices = []
        name_parts = []
        manifolds = []
        dtype = None
        device = None
        pos0 = 0
        for i, (manifold, shape) in enumerate(manifolds_with_shape):
            # check shape consistency
            shape = geoopt.utils.size2shape(shape)
            ok, reason = manifold._check_shape(shape, "{}'th shape".format(i))
            if not ok:
                raise ValueError(reason)
            # check device consistency
            if manifold.device is not None and device is not None:
                if device != manifold.device:
                    raise ValueError("Not all manifold share the same device")
            elif device is None:
                device = manifold.device
            # check dtype consistency
            if manifold.dtype is not None and dtype is not None:
                if dtype != manifold.dtype:
                    raise ValueError("Not all manifold share the same dtype")
            elif dtype is None:
                dtype = manifold.dtype
            name_parts.append(manifold.name)
            manifolds.append(manifold)
            self.shapes.append(shape)
            # every submanifold owns a contiguous range of the last ambient dim
            pos1 = pos0 + _shape2size(shape)
            self.slices.append(slice(pos0, pos1))
            pos0 = pos1
        self.name = "x".join(["({})".format(name) for name in name_parts])
        self.n_elements = pos0
        self.n_manifolds = len(manifolds)
        self.manifolds = torch.nn.ModuleList(manifolds)

    @property
    def reversible(self) -> bool:
        """``True`` only if every submanifold reports ``reversible``."""
        return all(m.reversible for m in self.manifolds)

    def take_submanifold_value(
        self, x: torch.Tensor, i: int, reshape=True
    ) -> torch.Tensor:
        """
        Take i'th slice of the ambient tensor and possibly reshape.

        Parameters
        ----------
        x : tensor
            Ambient tensor
        i : int
            submanifold index
        reshape : bool
            reshape the slice?

        Returns
        -------
        torch.Tensor
        """
        slc = self.slices[i]
        part = x.narrow(-1, slc.start, slc.stop - slc.start)
        if reshape:
            part = part.reshape((*part.shape[:-1], *self.shapes[i]))
        return part

    def _check_shape(self, shape: Tuple[int], name: str) -> Tuple[bool, Optional[str]]:
        """Check that the last dimension of ``shape`` equals ``n_elements``."""
        # an empty shape has no last dimension; report it instead of raising
        if len(shape) < self.ndim:
            return (
                False,
                "'{}' should have at least {} dimension(s), but got shape {}".format(
                    name, self.ndim, tuple(shape)
                ),
            )
        ok = shape[-1] == self.n_elements
        if not ok:
            return (
                ok,
                "The last dimension should be equal to {}, but got {}".format(
                    self.n_elements, shape[-1]
                ),
            )
        return ok, None

    def _check_point_on_manifold(
        self, x: torch.Tensor, *, atol=1e-5, rtol=1e-5
    ) -> Tuple[bool, Optional[str]]:
        """Check every component of ``x``; stop at the first failing submanifold."""
        ok, reason = True, None
        for i, manifold in enumerate(self.manifolds):
            point = self.take_submanifold_value(x, i)
            ok, reason = manifold.check_point_on_manifold(
                point, atol=atol, rtol=rtol, explain=True
            )
            if not ok:
                break
        return ok, reason

    def _check_vector_on_tangent(
        self, x, u, *, atol=1e-5, rtol=1e-5
    ) -> Tuple[bool, Optional[str]]:
        """Check every component of ``u`` at ``x``; stop at the first failure."""
        ok, reason = True, None
        for i, manifold in enumerate(self.manifolds):
            point = self.take_submanifold_value(x, i)
            tangent = self.take_submanifold_value(u, i)
            ok, reason = manifold.check_vector_on_tangent(
                point, tangent, atol=atol, rtol=rtol, explain=True
            )
            if not ok:
                break
        return ok, reason

    def inner(
        self, x: torch.Tensor, u: torch.Tensor, v=None, *, keepdim=False
    ) -> torch.Tensor:
        """
        Sum the inner products computed by every submanifold.

        ``v=None`` is forwarded unchanged to the submanifolds. With
        ``keepdim=True`` a trailing dimension of size one is appended.
        """
        dims = [x.dim(), u.dim()]
        if v is not None:
            dims.append(v.dim())
        target_batch_dim = _calculate_target_batch_dim(*dims)
        products = []
        for i, manifold in enumerate(self.manifolds):
            point = self.take_submanifold_value(x, i)
            u_vec = self.take_submanifold_value(u, i)
            v_vec = self.take_submanifold_value(v, i) if v is not None else None
            inner = manifold.inner(point, u_vec, v_vec, keepdim=True)
            # reshape (not view): the submanifold result may be non-contiguous
            inner = inner.reshape((*inner.shape[:target_batch_dim], -1)).sum(-1)
            products.append(inner)
        result = sum(products)
        if keepdim:
            result = torch.unsqueeze(result, -1)
        return result

    def component_inner(self, x: torch.Tensor, u: torch.Tensor, v=None) -> torch.Tensor:
        """Pack the ``component_inner`` results of every submanifold into one tensor."""
        products = []
        for i, manifold in enumerate(self.manifolds):
            point = self.take_submanifold_value(x, i)
            u_vec = self.take_submanifold_value(u, i)
            target_shape = geoopt.utils.broadcast_shapes(point.shape, u_vec.shape)
            v_vec = self.take_submanifold_value(v, i) if v is not None else None
            inner = manifold.component_inner(point, u_vec, v_vec)
            # expand so that pack_point sees the full component shape
            inner = inner.expand(target_shape)
            products.append(inner)
        return self.pack_point(*products)

    def projx(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``projx`` of every submanifold and concatenate the results."""
        projected = []
        for i, manifold in enumerate(self.manifolds):
            point = self.take_submanifold_value(x, i)
            proj = manifold.projx(point)
            # reshape (not view): the submanifold result may be non-contiguous
            proj = proj.reshape((*x.shape[:-1], -1))
            projected.append(proj)
        return torch.cat(projected, -1)

    def proju(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Apply ``proju`` of every submanifold and concatenate the results."""
        target_batch_dim = _calculate_target_batch_dim(x.dim(), u.dim())
        projected = []
        for i, manifold in enumerate(self.manifolds):
            point = self.take_submanifold_value(x, i)
            tangent = self.take_submanifold_value(u, i)
            proj = manifold.proju(point, tangent)
            proj = proj.reshape((*proj.shape[:target_batch_dim], -1))
            projected.append(proj)
        return torch.cat(projected, -1)

    def expmap(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Apply ``expmap`` of every submanifold and concatenate the results."""
        target_batch_dim = _calculate_target_batch_dim(x.dim(), u.dim())
        mapped_tensors = []
        for i, manifold in enumerate(self.manifolds):
            point = self.take_submanifold_value(x, i)
            tangent = self.take_submanifold_value(u, i)
            mapped = manifold.expmap(point, tangent)
            mapped = mapped.reshape((*mapped.shape[:target_batch_dim], -1))
            mapped_tensors.append(mapped)
        return torch.cat(mapped_tensors, -1)

    def retr(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Apply ``retr`` of every submanifold and concatenate the results."""
        target_batch_dim = _calculate_target_batch_dim(x.dim(), u.dim())
        mapped_tensors = []
        for i, manifold in enumerate(self.manifolds):
            point = self.take_submanifold_value(x, i)
            tangent = self.take_submanifold_value(u, i)
            mapped = manifold.retr(point, tangent)
            mapped = mapped.reshape((*mapped.shape[:target_batch_dim], -1))
            mapped_tensors.append(mapped)
        return torch.cat(mapped_tensors, -1)

    def transp(self, x: torch.Tensor, y: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Apply ``transp`` of every submanifold and concatenate the results."""
        target_batch_dim = _calculate_target_batch_dim(x.dim(), y.dim(), v.dim())
        transported_tensors = []
        for i, manifold in enumerate(self.manifolds):
            point = self.take_submanifold_value(x, i)
            point1 = self.take_submanifold_value(y, i)
            tangent = self.take_submanifold_value(v, i)
            transported = manifold.transp(point, point1, tangent)
            transported = transported.reshape(
                (*transported.shape[:target_batch_dim], -1)
            )
            transported_tensors.append(transported)
        return torch.cat(transported_tensors, -1)

    def logmap(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Apply ``logmap`` of every submanifold and concatenate the results."""
        target_batch_dim = _calculate_target_batch_dim(x.dim(), y.dim())
        logmapped_tensors = []
        for i, manifold in enumerate(self.manifolds):
            point = self.take_submanifold_value(x, i)
            point1 = self.take_submanifold_value(y, i)
            logmapped = manifold.logmap(point, point1)
            logmapped = logmapped.reshape((*logmapped.shape[:target_batch_dim], -1))
            logmapped_tensors.append(logmapped)
        return torch.cat(logmapped_tensors, -1)

    def transp_follow_retr(
        self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor
    ) -> torch.Tensor:
        """Apply ``transp_follow_retr`` of every submanifold and concatenate."""
        target_batch_dim = _calculate_target_batch_dim(x.dim(), u.dim(), v.dim())
        results = []
        for i, manifold in enumerate(self.manifolds):
            point = self.take_submanifold_value(x, i)
            direction = self.take_submanifold_value(u, i)
            vector = self.take_submanifold_value(v, i)
            transported = manifold.transp_follow_retr(point, direction, vector)
            transported = transported.reshape(
                (*transported.shape[:target_batch_dim], -1)
            )
            results.append(transported)
        return torch.cat(results, -1)

    def transp_follow_expmap(
        self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor
    ) -> torch.Tensor:
        """Apply ``transp_follow_expmap`` of every submanifold and concatenate."""
        target_batch_dim = _calculate_target_batch_dim(x.dim(), u.dim(), v.dim())
        results = []
        for i, manifold in enumerate(self.manifolds):
            point = self.take_submanifold_value(x, i)
            direction = self.take_submanifold_value(u, i)
            vector = self.take_submanifold_value(v, i)
            transported = manifold.transp_follow_expmap(point, direction, vector)
            transported = transported.reshape(
                (*transported.shape[:target_batch_dim], -1)
            )
            results.append(transported)
        return torch.cat(results, -1)

    def expmap_transp(
        self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply ``expmap_transp`` component-wise; return ``(points, vectors)``."""
        target_batch_dim = _calculate_target_batch_dim(x.dim(), u.dim(), v.dim())
        results = []
        for i, manifold in enumerate(self.manifolds):
            point = self.take_submanifold_value(x, i)
            direction = self.take_submanifold_value(u, i)
            vector = self.take_submanifold_value(v, i)
            new_point, transported = manifold.expmap_transp(point, direction, vector)
            transported = transported.reshape(
                (*transported.shape[:target_batch_dim], -1)
            )
            new_point = new_point.reshape((*new_point.shape[:target_batch_dim], -1))
            results.append((new_point, transported))
        points, vectors = zip(*results)
        return torch.cat(points, -1), torch.cat(vectors, -1)

    def retr_transp(self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor):
        """Apply ``retr_transp`` component-wise; return ``(points, vectors)``."""
        target_batch_dim = _calculate_target_batch_dim(x.dim(), u.dim(), v.dim())
        results = []
        for i, manifold in enumerate(self.manifolds):
            point = self.take_submanifold_value(x, i)
            direction = self.take_submanifold_value(u, i)
            vector = self.take_submanifold_value(v, i)
            new_point, transported = manifold.retr_transp(point, direction, vector)
            transported = transported.reshape(
                (*transported.shape[:target_batch_dim], -1)
            )
            new_point = new_point.reshape((*new_point.shape[:target_batch_dim], -1))
            results.append((new_point, transported))
        points, vectors = zip(*results)
        return torch.cat(points, -1), torch.cat(vectors, -1)

    def dist2(self, x: torch.Tensor, y: torch.Tensor, *, keepdim=False):
        """Sum the ``dist2`` values computed by every submanifold."""
        target_batch_dim = _calculate_target_batch_dim(x.dim(), y.dim())
        mini_dists2 = []
        for i, manifold in enumerate(self.manifolds):
            point = self.take_submanifold_value(x, i)
            point1 = self.take_submanifold_value(y, i)
            mini_dist2 = manifold.dist2(point, point1, keepdim=True)
            mini_dist2 = mini_dist2.reshape(
                (*mini_dist2.shape[:target_batch_dim], -1)
            ).sum(-1)
            mini_dists2.append(mini_dist2)
        result = sum(mini_dists2)
        if keepdim:
            result = torch.unsqueeze(result, -1)
        return result

    def dist(self, x, y, *, keepdim=False):
        """Square root of :meth:`dist2`, clamped below at ``1e-15`` before the root."""
        # out-of-place clamp: never mutate the tensor handed back by dist2
        return self.dist2(x, y, keepdim=keepdim).clamp_min(1e-15) ** 0.5

    def egrad2rgrad(self, x: torch.Tensor, u: torch.Tensor):
        """Apply ``egrad2rgrad`` of every submanifold and concatenate the results."""
        target_batch_dim = _calculate_target_batch_dim(x.dim(), u.dim())
        transformed_tensors = []
        for i, manifold in enumerate(self.manifolds):
            point = self.take_submanifold_value(x, i)
            grad = self.take_submanifold_value(u, i)
            transformed = manifold.egrad2rgrad(point, grad)
            transformed = transformed.reshape(
                (*transformed.shape[:target_batch_dim], -1)
            )
            transformed_tensors.append(transformed)
        return torch.cat(transformed_tensors, -1)

    def unpack_tensor(self, tensor: torch.Tensor) -> Tuple[torch.Tensor]:
        """Split an ambient tensor into a tuple of reshaped submanifold tensors."""
        return tuple(
            self.take_submanifold_value(tensor, i) for i in range(self.n_manifolds)
        )

    def pack_point(self, *tensors: torch.Tensor) -> torch.Tensor:
        """
        Flatten one tensor per submanifold and concatenate along the last dim.

        Parameters
        ----------
        tensors : torch.Tensor
            one tensor per submanifold, in order; the trailing dimensions of
            the i'th tensor must equal ``self.shapes[i]``

        Returns
        -------
        torch.Tensor

        Raises
        ------
        ValueError
            if the number of tensors or a trailing shape does not match
        """
        if len(tensors) != len(self.manifolds):
            raise ValueError(
                "{} tensors expected, got {}".format(len(self.manifolds), len(tensors))
            )
        flattened = []
        for part, shape in zip(tensors, self.shapes):
            if len(shape) > 0:
                trailing = part.shape[-len(shape) :]
                if tuple(trailing) != tuple(shape):
                    # report the required shape first, then the offending one
                    raise ValueError(
                        "last shape dimension does not seem to be valid. {} required, but got {}".format(
                            shape, trailing
                        )
                    )
                new_shape = (*part.shape[: -len(shape)], -1)
            else:
                new_shape = (*part.shape, -1)
            flattened.append(part.reshape(new_shape))
        return torch.cat(flattened, -1)

    @classmethod
    def from_point(cls, *parts: "geoopt.ManifoldTensor", batch_dims=0):
        """
        Construct Product manifold from given points.

        Parameters
        ----------
        parts : tuple[geoopt.ManifoldTensor]
            Manifold tensors to construct Product manifold from
        batch_dims : int
            number of first dims to treat as batch dims and not include in the Product manifold

        Returns
        -------
        ProductManifold

        Raises
        ------
        ValueError
            if the parts do not share the same leading ``batch_dims`` shape
        """
        batch_shape = None
        init = []
        for tens in parts:
            manifold = tens.manifold
            if batch_shape is None:
                batch_shape = tens.shape[:batch_dims]
            elif batch_shape != tens.shape[:batch_dims]:
                raise ValueError("Not all parts have same batch shape")
            init.append((manifold, tens.shape[batch_dims:]))
        return cls(*init)

    def random_combined(
        self, *size, dtype=None, device=None
    ) -> "geoopt.ManifoldTensor":
        """
        Sample every submanifold with its ``random`` method and pack the result.

        ``size`` is the shape of the packed tensor; its last dimension must
        equal ``n_elements`` and the leading dimensions are used as batch shape.
        """
        shape = geoopt.utils.size2shape(*size)
        self._assert_check_shape(shape, "x")
        batch_shape = shape[:-1]
        points = [
            manifold.random(batch_shape + sub_shape, dtype=dtype, device=device)
            for manifold, sub_shape in zip(self.manifolds, self.shapes)
        ]
        tensor = self.pack_point(*points)
        return geoopt.ManifoldTensor(tensor, manifold=self)

    random = random_combined

    def origin(
        self, *size, dtype=None, device=None, seed=42
    ) -> "geoopt.ManifoldTensor":
        """
        Pack the ``origin`` of every submanifold into one manifold tensor.

        ``size`` is the shape of the packed tensor; its last dimension must
        equal ``n_elements`` and the leading dimensions are used as batch shape.
        """
        shape = geoopt.utils.size2shape(*size)
        self._assert_check_shape(shape, "x")
        batch_shape = shape[:-1]
        points = [
            manifold.origin(
                batch_shape + sub_shape, dtype=dtype, device=device, seed=seed
            )
            for manifold, sub_shape in zip(self.manifolds, self.shapes)
        ]
        tensor = self.pack_point(*points)
        return geoopt.ManifoldTensor(tensor, manifold=self)
class StereographicProductManifold(ProductManifold):
    """
    Product Manifold for Stereographic manifolds.

    Every method below slices the packed operands per submanifold, calls the
    method of the same name on each submanifold along the last dimension and
    combines the partial results.

    Examples
    --------
    A Torus

    >>> import geoopt
    >>> sphere = geoopt.SphereProjection()
    >>> torus = StereographicProductManifold((sphere, 2), (sphere, 2))
    """

    __scaling__ = Stereographic.__scaling__.copy()

    def __init__(
        self,
        *manifolds_with_shape: Tuple[Stereographic, Union[Tuple[int, ...], int]],
    ):
        super().__init__(*manifolds_with_shape)
        for man in self.manifolds:
            if not geoopt.utils.ismanifold(man, Stereographic):
                raise TypeError("Every submanifold has to be Stereographic manifold")

    def dist2plane(
        self,
        x: torch.Tensor,
        p: torch.Tensor,
        a: torch.Tensor,
        *,
        keepdim=False,
        signed=False,
        scaled=False,
    ) -> torch.Tensor:
        """Root of the summed squares of the per-submanifold ``dist2plane`` values."""
        dists = [
            manifold.dist2plane(
                self.take_submanifold_value(x, i),
                self.take_submanifold_value(p, i),
                self.take_submanifold_value(a, i),
                dim=-1,
                keepdim=keepdim,
                signed=signed,
                scaled=scaled,
            )
            for i, manifold in enumerate(self.manifolds)
        ]
        stacked = torch.stack(dists, -1)
        return (stacked**2).sum(dim=-1).sqrt()

    def mobius_add(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        *,
        project=True,
    ) -> torch.Tensor:
        """Apply ``mobius_add`` of every submanifold component-wise."""
        return self._mobius_2_manifold_args(x, y, "mobius_add", project=project)

    def mobius_coadd(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        *,
        project=True,
    ) -> torch.Tensor:
        """Apply ``mobius_coadd`` of every submanifold component-wise."""
        return self._mobius_2_manifold_args(x, y, "mobius_coadd", project=project)

    def mobius_sub(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        *,
        project=True,
    ) -> torch.Tensor:
        """Apply ``mobius_sub`` of every submanifold component-wise."""
        return self._mobius_2_manifold_args(x, y, "mobius_sub", project=project)

    def mobius_cosub(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        *,
        project=True,
    ) -> torch.Tensor:
        """Apply ``mobius_cosub`` of every submanifold component-wise."""
        return self._mobius_2_manifold_args(x, y, "mobius_cosub", project=project)

    def _mobius_2_manifold_args(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        kind,
        *,
        project=True,
    ) -> torch.Tensor:
        """
        Call the two-operand submanifold method named ``kind`` on every component.

        Parameters
        ----------
        x, y : torch.Tensor
            packed operands
        kind : str
            name of the submanifold method, looked up with ``getattr``
        project : bool
            forwarded to the submanifold method

        Returns
        -------
        torch.Tensor
            packed result
        """
        mapped_tensors = []
        for i, manifold in enumerate(self.manifolds):
            x_ = self.take_submanifold_value(x, i)
            y_ = self.take_submanifold_value(y, i)
            # hand the unflattened result to pack_point: it validates the
            # trailing dims against self.shapes[i] and flattens them itself
            mapped_tensors.append(
                getattr(manifold, kind)(x_, y_, dim=-1, project=project)
            )
        return self.pack_point(*mapped_tensors)

    def mobius_scalar_mul(
        self,
        r: torch.Tensor,
        x: torch.Tensor,
        *,
        project=True,
    ) -> torch.Tensor:
        """Apply ``mobius_scalar_mul`` with the same ``r`` on every component of ``x``."""
        mapped_tensors = []
        for i, manifold in enumerate(self.manifolds):
            x_ = self.take_submanifold_value(x, i)
            mapped_tensors.append(manifold.mobius_scalar_mul(r, x_, project=project))
        return self.pack_point(*mapped_tensors)

    def mobius_pointwise_mul(
        self,
        w: torch.Tensor,
        x: torch.Tensor,
        *,
        project=True,
    ) -> torch.Tensor:
        """Apply ``mobius_pointwise_mul`` on matching slices of ``w`` and ``x``."""
        mapped_tensors = []
        for i, manifold in enumerate(self.manifolds):
            w_ = self.take_submanifold_value(w, i)
            x_ = self.take_submanifold_value(x, i)
            mapped_tensors.append(
                manifold.mobius_pointwise_mul(w_, x_, project=project)
            )
        return self.pack_point(*mapped_tensors)

    def take_submanifold_matrix(
        self, x: torch.Tensor, i: int, reshape=True
    ) -> torch.Tensor:
        """
        Take i'th diagonal block of the ambient matrix and possibly reshape.

        The same slice is applied to the last two dimensions of ``x``.

        Parameters
        ----------
        x : tensor
            Ambient tensor
        i : int
            submanifold index
        reshape : bool
            reshape the slice?

        Returns
        -------
        torch.Tensor
        """
        slc = self.slices[i]
        part = x[..., slc, slc]
        if reshape:
            part = part.reshape((*part.shape[:-2], *self.shapes[i], *self.shapes[i]))
        return part

    def mobius_matvec(
        self,
        m: torch.Tensor,
        x: torch.Tensor,
        *,
        project=True,
    ) -> torch.Tensor:
        """Apply ``mobius_matvec`` with the i'th diagonal block of ``m`` on the i'th slice of ``x``."""
        mapped_tensors = []
        for i, manifold in enumerate(self.manifolds):
            m_ = self.take_submanifold_matrix(m, i)
            x_ = self.take_submanifold_value(x, i)
            mapped_tensors.append(manifold.mobius_matvec(m_, x_, project=project))
        return self.pack_point(*mapped_tensors)

    @__scaling__(ScalingInfo(std=-1))
    def wrapped_normal(
        self,
        *size,
        mean: torch.Tensor,
        std: Union[torch.Tensor, int, float] = 1,
        dtype=None,
        device=None,
    ) -> "geoopt.ManifoldTensor":
        """
        Sample every submanifold with its ``wrapped_normal`` and pack the result.

        Parameters
        ----------
        size : shape
            shape of the packed sample; the last dimension must equal
            ``n_elements``, the leading dimensions are used as batch shape
        mean : torch.Tensor
            packed mean, sliced per submanifold
        std : Union[torch.Tensor, int, float]
            packed std, sliced per submanifold; a Python scalar is expanded to
            a vector with ``mean.shape[-1]`` entries
        dtype, device
            forwarded to the submanifolds

        Returns
        -------
        geoopt.ManifoldTensor
        """
        shape = size2shape(*size)
        self._assert_check_shape(shape, "x")
        batch_shape = shape[:-1]
        if isinstance(std, (int, float)):
            # fill with the scalar itself (zeros * std would always be zero)
            std = mean.new_full((mean.shape[-1],), std)
        points = []
        for i, (manifold, sub_shape) in enumerate(zip(self.manifolds, self.shapes)):
            points.append(
                manifold.wrapped_normal(
                    *(batch_shape + sub_shape),
                    mean=self.take_submanifold_value(mean, i),
                    std=self.take_submanifold_value(std, i),
                    dtype=dtype,
                    device=device,
                )
            )
        tensor = self.pack_point(*points)
        return geoopt.ManifoldTensor(tensor, manifold=self)

    def geodesic(
        self,
        t: torch.Tensor,
        x: torch.Tensor,
        y: torch.Tensor,
        *,
        dim=-1,
    ) -> torch.Tensor:
        """
        Apply ``geodesic`` of every submanifold component-wise.

        ``dim`` is not forwarded: components are always sliced from, and
        processed along, the last dimension.
        """
        res_list = []
        for i, manifold in enumerate(self.manifolds):
            x_ = self.take_submanifold_value(x, i)
            y_ = self.take_submanifold_value(y, i)
            res_list.append(manifold.geodesic(t, x_, y_, dim=-1))
        return self.pack_point(*res_list)

    def geodesic_unit(
        self,
        t: torch.Tensor,
        x: torch.Tensor,
        u: torch.Tensor,
        *,
        project=True,
    ) -> torch.Tensor:
        """Apply ``geodesic_unit`` of every submanifold component-wise."""
        res_list = []
        for i, manifold in enumerate(self.manifolds):
            x_ = self.take_submanifold_value(x, i)
            u_ = self.take_submanifold_value(u, i)
            res_list.append(
                manifold.geodesic_unit(t, x_, u_, dim=-1, project=project)
            )
        return self.pack_point(*res_list)

    def dist0(self, x: torch.Tensor, *, keepdim=False) -> torch.Tensor:
        """Root of the summed squares of the per-submanifold ``dist0`` values."""
        squared = []
        for i, manifold in enumerate(self.manifolds):
            x_ = self.take_submanifold_value(x, i)
            squared.append(manifold.dist0(x_) ** 2)
        res = sum(squared) ** 0.5
        if keepdim:
            res = torch.unsqueeze(res, -1)
        return res

    def expmap0(self, u: torch.Tensor, *, project=True) -> torch.Tensor:
        """Apply ``expmap0`` of every submanifold component-wise."""
        res = []
        for i, manifold in enumerate(self.manifolds):
            u_ = self.take_submanifold_value(u, i)
            res.append(manifold.expmap0(u_, dim=-1, project=project))
        return self.pack_point(*res)

    def logmap0(self, x: torch.Tensor, *, project=True) -> torch.Tensor:
        """
        Apply ``logmap0`` of every submanifold component-wise.

        ``project`` is accepted but not forwarded to the submanifolds.
        """
        res = []
        for i, manifold in enumerate(self.manifolds):
            x_ = self.take_submanifold_value(x, i)
            res.append(manifold.logmap0(x_, dim=-1))
        return self.pack_point(*res)

    def transp0(self, y: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Apply ``transp0`` of every submanifold component-wise."""
        res = []
        for i, manifold in enumerate(self.manifolds):
            y_ = self.take_submanifold_value(y, i)
            u_ = self.take_submanifold_value(u, i)
            res.append(manifold.transp0(y_, u_, dim=-1))
        return self.pack_point(*res)

    def transp0back(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Apply ``transp0back`` of every submanifold component-wise."""
        res = []
        for i, manifold in enumerate(self.manifolds):
            x_ = self.take_submanifold_value(x, i)
            u_ = self.take_submanifold_value(u, i)
            res.append(manifold.transp0back(x_, u_, dim=-1))
        return self.pack_point(*res)

    def gyration(
        self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, *, project=True
    ) -> torch.Tensor:
        """
        Apply ``gyration`` of every submanifold component-wise.

        ``project`` is accepted but not forwarded to the submanifolds.
        """
        res = []
        for i, manifold in enumerate(self.manifolds):
            x_ = self.take_submanifold_value(x, i)
            y_ = self.take_submanifold_value(y, i)
            z_ = self.take_submanifold_value(z, i)
            res.append(manifold.gyration(x_, y_, z_, dim=-1))
        return self.pack_point(*res)

    def antipode(self, x: torch.Tensor, *, project=True) -> torch.Tensor:
        """
        Apply ``antipode`` of every submanifold component-wise.

        ``project`` is accepted but not forwarded to the submanifolds.
        """
        res = []
        for i, manifold in enumerate(self.manifolds):
            x_ = self.take_submanifold_value(x, i)
            res.append(manifold.antipode(x_, dim=-1))
        return self.pack_point(*res)

    def mobius_fn_apply(
        self, fn: callable, x: torch.Tensor, *args, project=True, **kwargs
    ) -> torch.Tensor:
        """Apply ``mobius_fn_apply`` with ``fn`` on every component of ``x``."""
        res = []
        for i, manifold in enumerate(self.manifolds):
            x_ = self.take_submanifold_value(x, i)
            res.append(
                manifold.mobius_fn_apply(
                    fn, x_, *args, dim=-1, project=project, **kwargs
                )
            )
        return self.pack_point(*res)

    def mobius_fn_apply_chain(
        self,
        x: torch.Tensor,
        *fns: callable,
        project=True,
    ) -> torch.Tensor:
        """Apply ``mobius_fn_apply_chain`` with ``fns`` on every component of ``x``."""
        res = []
        for i, manifold in enumerate(self.manifolds):
            x_ = self.take_submanifold_value(x, i)
            res.append(
                manifold.mobius_fn_apply_chain(
                    x_,
                    *fns,
                    dim=-1,
                    project=project,
                )
            )
        return self.pack_point(*res)