# Source code for geoopt.manifolds.product

import torch.nn
from typing import Tuple, Union, Optional
import operator
import functools
import geoopt.utils
from ..utils import size2shape
from .base import Manifold, ScalingInfo
from .stereographic import Stereographic


# Names exported by ``from <this module> import *``.
__all__ = ["ProductManifold", "StereographicProductManifold"]


def _shape2size(shape: Tuple[int]):
    return functools.reduce(operator.mul, shape, 1)


def _calculate_target_batch_dim(*dims: int):
    return max(dims) - 1


class ProductManifold(Manifold):
    """
    Product Manifold.

    A point is stored as one flat "ambient" tensor whose last dimension has
    ``n_elements`` entries. The i'th submanifold owns the contiguous range
    ``slices[i]`` of that dimension, which is reshaped to ``shapes[i]``
    whenever the component is handed to the submanifold.

    Examples
    --------
    A Torus

    >>> import geoopt
    >>> sphere = geoopt.Sphere()
    >>> torus = ProductManifold((sphere, 2), (sphere, 2))
    """

    ndim = 1

    def __init__(
        self,
        *manifolds_with_shape: Tuple[Manifold, Union[Tuple[int, ...], int]],
    ):
        """
        Build the product from ``(manifold, shape)`` pairs.

        Parameters
        ----------
        manifolds_with_shape : tuple[Manifold, tuple[int, ...] | int]
            one pair per factor; ``shape`` is the shape of a single point on
            that factor

        Raises
        ------
        ValueError
            if no pair is given, if a factor rejects its shape, or if the
            factors disagree on device or dtype
        """
        if len(manifolds_with_shape) < 1:
            raise ValueError(
                "There should be at least one manifold in a product manifold"
            )
        super().__init__()
        self.shapes = []
        self.slices = []
        name_parts = []
        manifolds = []
        dtype = None
        device = None
        pos0 = 0
        for i, (manifold, shape) in enumerate(manifolds_with_shape):
            # check shape consistency
            shape = geoopt.utils.size2shape(shape)
            ok, reason = manifold._check_shape(shape, "{}'th shape".format(i))
            if not ok:
                raise ValueError(reason)
            # check device consistency
            if manifold.device is not None and device is not None:
                if device != manifold.device:
                    raise ValueError("Not all manifold share the same device")
            elif device is None:
                device = manifold.device
            # check dtype consistency
            if manifold.dtype is not None and dtype is not None:
                if dtype != manifold.dtype:
                    raise ValueError("Not all manifold share the same dtype")
            elif dtype is None:
                dtype = manifold.dtype
            name_parts.append(manifold.name)
            manifolds.append(manifold)
            self.shapes.append(shape)
            # each factor occupies the next ``prod(shape)`` ambient entries
            pos1 = pos0 + _shape2size(shape)
            self.slices.append(slice(pos0, pos1))
            pos0 = pos1
        self.name = "x".join(["({})".format(name) for name in name_parts])
        self.n_elements = pos0
        self.n_manifolds = len(manifolds)
        self.manifolds = torch.nn.ModuleList(manifolds)

    @property
    def reversible(self) -> bool:
        """Whether every submanifold reports itself as reversible."""
        return all(m.reversible for m in self.manifolds)

    def take_submanifold_value(
        self, x: torch.Tensor, i: int, reshape=True
    ) -> torch.Tensor:
        """
        Take i'th slice of the ambient tensor and possibly reshape.

        Parameters
        ----------
        x : tensor
            Ambient tensor
        i : int
            submanifold index
        reshape : bool
            reshape the slice?

        Returns
        -------
        torch.Tensor
        """
        slc = self.slices[i]
        part = x.narrow(-1, slc.start, slc.stop - slc.start)
        if reshape:
            part = part.reshape((*part.shape[:-1], *self.shapes[i]))
        return part

    def _check_shape(self, shape: Tuple[int], name: str) -> Tuple[bool, Optional[str]]:
        """
        Check that the last dimension of ``shape`` equals ``n_elements``.

        Returns ``(ok, reason)`` where ``reason`` is ``None`` on success.
        """
        if len(shape) < 1:
            # An empty shape has no last dimension; report it instead of
            # failing with IndexError on ``shape[-1]``.
            return (
                False,
                "'{}' should have at least one dimension, but got shape {}".format(
                    name, tuple(shape)
                ),
            )
        ok = shape[-1] == self.n_elements
        if not ok:
            return (
                ok,
                "The last dimension should be equal to {}, but got {}".format(
                    self.n_elements, shape[-1]
                ),
            )
        return ok, None

    def _check_point_on_manifold(
        self, x: torch.Tensor, *, atol=1e-5, rtol=1e-5
    ) -> Tuple[bool, Optional[str]]:
        """Check each component on its submanifold; stop at the first failure."""
        ok, reason = True, None
        for i, manifold in enumerate(self.manifolds):
            point = self.take_submanifold_value(x, i)
            ok, reason = manifold.check_point_on_manifold(
                point, atol=atol, rtol=rtol, explain=True
            )
            if not ok:
                break
        return ok, reason

    def _check_vector_on_tangent(
        self, x, u, *, atol=1e-5, rtol=1e-5
    ) -> Tuple[bool, Optional[str]]:
        """Check each tangent component on its submanifold; stop at the first failure."""
        ok, reason = True, None
        for i, manifold in enumerate(self.manifolds):
            point = self.take_submanifold_value(x, i)
            tangent = self.take_submanifold_value(u, i)
            ok, reason = manifold.check_vector_on_tangent(
                point, tangent, atol=atol, rtol=rtol, explain=True
            )
            if not ok:
                break
        return ok, reason

    def _map_components(self, method: str, *tensors: torch.Tensor) -> torch.Tensor:
        """
        Apply ``method`` of every submanifold to the matching slices of ``tensors``.

        Each per-submanifold result keeps its leading ``max(dim) - 1``
        dimensions and has everything after them flattened, then the results
        are concatenated along the last dimension.
        """
        target_batch_dim = _calculate_target_batch_dim(*(t.dim() for t in tensors))
        pieces = []
        for i, manifold in enumerate(self.manifolds):
            parts = [self.take_submanifold_value(t, i) for t in tensors]
            out = getattr(manifold, method)(*parts)
            pieces.append(out.reshape((*out.shape[:target_batch_dim], -1)))
        return torch.cat(pieces, -1)

    def _map_components_pair(
        self, method: str, *tensors: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Same as ``_map_components`` for submanifold methods returning two tensors.

        Both outputs are flattened and concatenated independently.
        """
        target_batch_dim = _calculate_target_batch_dim(*(t.dim() for t in tensors))
        firsts = []
        seconds = []
        for i, manifold in enumerate(self.manifolds):
            parts = [self.take_submanifold_value(t, i) for t in tensors]
            first, second = getattr(manifold, method)(*parts)
            firsts.append(first.reshape((*first.shape[:target_batch_dim], -1)))
            seconds.append(second.reshape((*second.shape[:target_batch_dim], -1)))
        return torch.cat(firsts, -1), torch.cat(seconds, -1)

    def inner(
        self, x: torch.Tensor, u: torch.Tensor, v=None, *, keepdim=False
    ) -> torch.Tensor:
        """
        Sum of the submanifold inner products of ``u`` and ``v`` at ``x``.

        ``v`` is forwarded as ``None`` to every submanifold when omitted. With
        ``keepdim=True`` a trailing dimension of size one is appended.
        """
        if v is not None:
            target_batch_dim = _calculate_target_batch_dim(x.dim(), u.dim(), v.dim())
        else:
            target_batch_dim = _calculate_target_batch_dim(x.dim(), u.dim())
        products = []
        for i, manifold in enumerate(self.manifolds):
            point = self.take_submanifold_value(x, i)
            u_vec = self.take_submanifold_value(u, i)
            v_vec = self.take_submanifold_value(v, i) if v is not None else None
            inner = manifold.inner(point, u_vec, v_vec, keepdim=True)
            # reshape (not view): the submanifold result may be non-contiguous
            inner = inner.reshape((*inner.shape[:target_batch_dim], -1)).sum(-1)
            products.append(inner)
        result = sum(products)
        if keepdim:
            result = torch.unsqueeze(result, -1)
        return result

    def component_inner(self, x: torch.Tensor, u: torch.Tensor, v=None) -> torch.Tensor:
        """
        Component-wise inner product, packed back into an ambient tensor.

        Every submanifold result is expanded to the broadcast shape of its
        point and ``u`` slices before packing.
        """
        products = []
        for i, manifold in enumerate(self.manifolds):
            point = self.take_submanifold_value(x, i)
            u_vec = self.take_submanifold_value(u, i)
            target_shape = geoopt.utils.broadcast_shapes(point.shape, u_vec.shape)
            v_vec = self.take_submanifold_value(v, i) if v is not None else None
            inner = manifold.component_inner(point, u_vec, v_vec)
            inner = inner.expand(target_shape)
            products.append(inner)
        return self.pack_point(*products)

    def projx(self, x: torch.Tensor) -> torch.Tensor:
        """Project every component of ``x`` with its submanifold's ``projx``."""
        batch_shape = x.shape[: len(x.shape) - 1]
        projected = []
        for i, manifold in enumerate(self.manifolds):
            point = self.take_submanifold_value(x, i)
            proj = manifold.projx(point)
            # reshape (not view): the submanifold result may be non-contiguous
            proj = proj.reshape((*batch_shape, -1))
            projected.append(proj)
        return torch.cat(projected, -1)

    def proju(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Apply each submanifold's ``proju`` to the matching slices of ``x`` and ``u``."""
        return self._map_components("proju", x, u)

    def expmap(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Apply each submanifold's ``expmap`` to the matching slices of ``x`` and ``u``."""
        return self._map_components("expmap", x, u)

    def retr(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Apply each submanifold's ``retr`` to the matching slices of ``x`` and ``u``."""
        return self._map_components("retr", x, u)

    def transp(self, x: torch.Tensor, y: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Apply each submanifold's ``transp`` to the matching slices of ``x``, ``y``, ``v``."""
        return self._map_components("transp", x, y, v)

    def logmap(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Apply each submanifold's ``logmap`` to the matching slices of ``x`` and ``y``."""
        return self._map_components("logmap", x, y)

    def transp_follow_retr(
        self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor
    ) -> torch.Tensor:
        """Apply each submanifold's ``transp_follow_retr`` component-wise."""
        return self._map_components("transp_follow_retr", x, u, v)

    def transp_follow_expmap(
        self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor
    ) -> torch.Tensor:
        """Apply each submanifold's ``transp_follow_expmap`` component-wise."""
        return self._map_components("transp_follow_expmap", x, u, v)

    def expmap_transp(
        self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply each submanifold's ``expmap_transp``; return ``(points, vectors)``."""
        return self._map_components_pair("expmap_transp", x, u, v)

    def retr_transp(self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor):
        """Apply each submanifold's ``retr_transp``; return ``(points, vectors)``."""
        return self._map_components_pair("retr_transp", x, u, v)

    def dist2(self, x: torch.Tensor, y: torch.Tensor, *, keepdim=False):
        """
        Sum of the submanifold ``dist2`` values between ``x`` and ``y``.

        With ``keepdim=True`` a trailing dimension of size one is appended.
        """
        target_batch_dim = _calculate_target_batch_dim(x.dim(), y.dim())
        mini_dists2 = []
        for i, manifold in enumerate(self.manifolds):
            point = self.take_submanifold_value(x, i)
            point1 = self.take_submanifold_value(y, i)
            mini_dist2 = manifold.dist2(point, point1, keepdim=True)
            mini_dist2 = mini_dist2.reshape(
                (*mini_dist2.shape[:target_batch_dim], -1)
            ).sum(-1)
            mini_dists2.append(mini_dist2)
        result = sum(mini_dists2)
        if keepdim:
            result = torch.unsqueeze(result, -1)
        return result

    def dist(self, x, y, *, keepdim=False):
        """Square root of ``dist2``, with ``dist2`` clamped from below at 1e-15 first."""
        return self.dist2(x, y, keepdim=keepdim).clamp_min_(1e-15) ** 0.5

    def egrad2rgrad(self, x: torch.Tensor, u: torch.Tensor):
        """Apply each submanifold's ``egrad2rgrad`` to the matching slices of ``x`` and ``u``."""
        return self._map_components("egrad2rgrad", x, u)

    def unpack_tensor(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Split an ambient tensor into one reshaped tensor per submanifold."""
        return tuple(
            self.take_submanifold_value(tensor, i) for i in range(self.n_manifolds)
        )

    def pack_point(self, *tensors: torch.Tensor) -> torch.Tensor:
        """
        Flatten one tensor per submanifold and concatenate them into an ambient tensor.

        Raises
        ------
        ValueError
            if the number of tensors differs from the number of submanifolds,
            or a tensor's trailing dimensions do not match the stored shape
        """
        if len(tensors) != len(self.manifolds):
            raise ValueError(
                "{} tensors expected, got {}".format(len(self.manifolds), len(tensors))
            )
        flattened = []
        for part, shape in zip(tensors, self.shapes):
            if len(shape) > 0:
                actual = part.shape[-len(shape) :]
                if actual != shape:
                    # "required" is the stored submanifold shape, "got" is
                    # what the caller supplied
                    raise ValueError(
                        "last shape dimension does not seem to be valid. {} required, but got {}".format(
                            shape, actual
                        )
                    )
                new_shape = (*part.shape[: -len(shape)], -1)
            else:
                new_shape = (*part.shape, -1)
            flattened.append(part.reshape(new_shape))
        return torch.cat(flattened, -1)

    @classmethod
    def from_point(cls, *parts: "geoopt.ManifoldTensor", batch_dims=0):
        """
        Construct Product manifold from given points.

        Parameters
        ----------
        parts : tuple[geoopt.ManifoldTensor]
            Manifold tensors to construct Product manifold from
        batch_dims : int
            number of first dims to treat as batch dims and not include in the Product manifold

        Returns
        -------
        ProductManifold

        Raises
        ------
        ValueError
            if the parts do not share the same leading ``batch_dims`` shape
        """
        batch_shape = None
        init = []
        for tens in parts:
            manifold = tens.manifold
            if batch_shape is None:
                batch_shape = tens.shape[:batch_dims]
            elif batch_shape != tens.shape[:batch_dims]:
                raise ValueError("Not all parts have same batch shape")
            init.append((manifold, tens.shape[batch_dims:]))
        return cls(*init)

    def random_combined(
        self, *size, dtype=None, device=None
    ) -> "geoopt.ManifoldTensor":
        """
        Draw a random point by calling ``random`` on every submanifold.

        ``size`` is the ambient shape; all dimensions but the last are used
        as the batch shape of each component.
        """
        shape = geoopt.utils.size2shape(*size)
        self._assert_check_shape(shape, "x")
        batch_shape = shape[:-1]
        points = []
        for manifold, sub_shape in zip(self.manifolds, self.shapes):
            points.append(
                manifold.random(batch_shape + sub_shape, dtype=dtype, device=device)
            )
        tensor = self.pack_point(*points)
        return geoopt.ManifoldTensor(tensor, manifold=self)

    random = random_combined

    def origin(
        self, *size, dtype=None, device=None, seed=42
    ) -> "geoopt.ManifoldTensor":
        """
        Build a point by calling ``origin`` on every submanifold.

        ``size`` is the ambient shape; all dimensions but the last are used
        as the batch shape of each component. ``seed`` is forwarded unchanged.
        """
        shape = geoopt.utils.size2shape(*size)
        self._assert_check_shape(shape, "x")
        batch_shape = shape[:-1]
        points = []
        for manifold, sub_shape in zip(self.manifolds, self.shapes):
            points.append(
                manifold.origin(
                    batch_shape + sub_shape, dtype=dtype, device=device, seed=seed
                )
            )
        tensor = self.pack_point(*points)
        return geoopt.ManifoldTensor(tensor, manifold=self)
class StereographicProductManifold(ProductManifold):
    """
    Product Manifold for Stereographic manifolds.

    Every operation is carried out slice by slice on the submanifolds and the
    results are packed back into one ambient tensor.

    Examples
    --------
    A Torus

    >>> import geoopt
    >>> sphere = geoopt.SphereProjection()
    >>> torus = StereographicProductManifold((sphere, 2), (sphere, 2))
    """

    __scaling__ = Stereographic.__scaling__.copy()

    def __init__(
        self,
        *manifolds_with_shape: Tuple[Stereographic, Union[Tuple[int, ...], int]],
    ):
        """
        Build the product and verify that every factor is Stereographic.

        Raises
        ------
        TypeError
            if any submanifold is not a Stereographic manifold
        """
        super().__init__(*manifolds_with_shape)
        for man in self.manifolds:
            if not geoopt.utils.ismanifold(man, Stereographic):
                raise TypeError("Every submanifold has to be Stereographic manifold")

    def dist2plane(
        self,
        x: torch.Tensor,
        p: torch.Tensor,
        a: torch.Tensor,
        *,
        keepdim=False,
        signed=False,
        scaled=False,
    ) -> torch.Tensor:
        """
        Combine the submanifold ``dist2plane`` values into one distance.

        The per-submanifold distances are stacked, squared, summed and
        square-rooted. Because of the squaring the result is non-negative
        even when ``signed=True`` is forwarded to the submanifolds.
        """
        dists = []
        for i, manifold in enumerate(self.manifolds):
            dists.append(
                manifold.dist2plane(
                    self.take_submanifold_value(x, i),
                    self.take_submanifold_value(p, i),
                    self.take_submanifold_value(a, i),
                    dim=-1,
                    keepdim=keepdim,
                    signed=signed,
                    scaled=scaled,
                )
            )
        dists = torch.stack(dists, -1)
        return (dists**2).sum(axis=-1).sqrt()

    def mobius_add(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        *,
        project=True,
    ) -> torch.Tensor:
        """Apply ``mobius_add`` on every submanifold."""
        return self._mobius_2_manifold_args(x, y, "mobius_add", project=project)

    def mobius_coadd(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        *,
        project=True,
    ) -> torch.Tensor:
        """Apply ``mobius_coadd`` on every submanifold."""
        return self._mobius_2_manifold_args(x, y, "mobius_coadd", project=project)

    def mobius_sub(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        *,
        project=True,
    ) -> torch.Tensor:
        """Apply ``mobius_sub`` on every submanifold."""
        return self._mobius_2_manifold_args(x, y, "mobius_sub", project=project)

    def mobius_cosub(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        *,
        project=True,
    ) -> torch.Tensor:
        """Apply ``mobius_cosub`` on every submanifold."""
        return self._mobius_2_manifold_args(x, y, "mobius_cosub", project=project)

    def _mobius_2_manifold_args(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        kind,
        *,
        project=True,
    ) -> torch.Tensor:
        """
        Call the submanifold method named ``kind`` on matching slices of ``x`` and ``y``.

        Each result keeps its leading ``max(x.dim(), y.dim()) - 1`` dimensions,
        has the rest flattened, and is then handed to ``pack_point``.
        """
        target_batch_dim = _calculate_target_batch_dim(x.dim(), y.dim())
        mapped_tensors = []
        for i, manifold in enumerate(self.manifolds):
            x_ = self.take_submanifold_value(x, i)
            y_ = self.take_submanifold_value(y, i)
            mapped = getattr(manifold, kind)(x_, y_, dim=-1, project=project)
            mapped = mapped.reshape((*mapped.shape[:target_batch_dim], -1))
            mapped_tensors.append(mapped)
        return self.pack_point(*mapped_tensors)

    def mobius_scalar_mul(
        self,
        r: torch.Tensor,
        x: torch.Tensor,
        *,
        project=True,
    ) -> torch.Tensor:
        """Apply ``mobius_scalar_mul`` with the same ``r`` on every submanifold."""
        mapped_tensors = [
            manifold.mobius_scalar_mul(
                r, self.take_submanifold_value(x, i), project=project
            )
            for i, manifold in enumerate(self.manifolds)
        ]
        return self.pack_point(*mapped_tensors)

    def mobius_pointwise_mul(
        self,
        w: torch.Tensor,
        x: torch.Tensor,
        *,
        project=True,
    ) -> torch.Tensor:
        """Apply ``mobius_pointwise_mul`` on matching slices of ``w`` and ``x``."""
        mapped_tensors = []
        for i, manifold in enumerate(self.manifolds):
            w_ = self.take_submanifold_value(w, i)
            x_ = self.take_submanifold_value(x, i)
            mapped_tensors.append(
                manifold.mobius_pointwise_mul(w_, x_, project=project)
            )
        return self.pack_point(*mapped_tensors)

    def take_submanifold_matrix(
        self, x: torch.Tensor, i: int, reshape=True
    ) -> torch.Tensor:
        """
        Take i'th diagonal block of the ambient matrix and possibly reshape.

        The same slice is applied to both of the last two dimensions.

        Parameters
        ----------
        x : tensor
            Ambient tensor
        i : int
            submanifold index
        reshape : bool
            reshape the slice?

        Returns
        -------
        torch.Tensor
        """
        slc = self.slices[i]
        part = x[..., slc, slc]
        if reshape:
            part = part.reshape((*part.shape[:-2], *self.shapes[i], *self.shapes[i]))
        return part

    def mobius_matvec(
        self,
        m: torch.Tensor,
        x: torch.Tensor,
        *,
        project=True,
    ) -> torch.Tensor:
        """Apply ``mobius_matvec`` using the i'th diagonal block of ``m`` for the i'th submanifold."""
        mapped_tensors = []
        for i, manifold in enumerate(self.manifolds):
            m_ = self.take_submanifold_matrix(m, i)
            x_ = self.take_submanifold_value(x, i)
            mapped_tensors.append(manifold.mobius_matvec(m_, x_, project=project))
        return self.pack_point(*mapped_tensors)

    @__scaling__(ScalingInfo(std=-1))
    def wrapped_normal(
        self,
        *size,
        mean: torch.Tensor,
        std: Union[torch.Tensor, int, float] = 1,
        dtype=None,
        device=None,
    ) -> "geoopt.ManifoldTensor":
        """
        Sample by calling ``wrapped_normal`` on every submanifold.

        Parameters
        ----------
        size : shape
            ambient shape of the sample; all dimensions but the last are used
            as the batch shape of each component
        mean : torch.Tensor
            ambient tensor, sliced per submanifold
        std : torch.Tensor | int | float
            a tensor is sliced per submanifold like ``mean``; a Python scalar
            is first expanded to a vector of length ``mean.shape[-1]``
        dtype, device
            forwarded unchanged to every submanifold

        Returns
        -------
        geoopt.ManifoldTensor
        """
        shape = size2shape(*size)
        self._assert_check_shape(shape, "x")
        batch_shape = shape[:-1]
        if isinstance(std, (int, float)):
            # Expand the scalar into a per-coordinate vector. This must be
            # ones * std: zeros * std is always zero and would collapse every
            # sample onto the mean.
            std = (
                torch.ones(mean.shape[-1], dtype=mean.dtype, device=mean.device) * std
            )
        points = []
        for i, (manifold, sub_shape) in enumerate(zip(self.manifolds, self.shapes)):
            points.append(
                manifold.wrapped_normal(
                    *(batch_shape + sub_shape),
                    mean=self.take_submanifold_value(mean, i),
                    std=self.take_submanifold_value(std, i),
                    dtype=dtype,
                    device=device,
                )
            )
        tensor = self.pack_point(*points)
        return geoopt.ManifoldTensor(tensor, manifold=self)

    def geodesic(
        self,
        t: torch.Tensor,
        x: torch.Tensor,
        y: torch.Tensor,
        *,
        dim=-1,
    ) -> torch.Tensor:
        """
        Apply ``geodesic`` with the same ``t`` on every submanifold.

        ``dim`` is accepted but not forwarded; submanifolds are always called
        with ``dim=-1``.
        """
        res_list = []
        for i, manifold in enumerate(self.manifolds):
            x_ = self.take_submanifold_value(x, i)
            y_ = self.take_submanifold_value(y, i)
            res_list.append(manifold.geodesic(t, x_, y_, dim=-1))
        return self.pack_point(*res_list)

    def geodesic_unit(
        self,
        t: torch.Tensor,
        x: torch.Tensor,
        u: torch.Tensor,
        *,
        project=True,
    ) -> torch.Tensor:
        """Apply ``geodesic_unit`` with the same ``t`` on every submanifold."""
        res_list = []
        for i, manifold in enumerate(self.manifolds):
            x_ = self.take_submanifold_value(x, i)
            u_ = self.take_submanifold_value(u, i)
            res_list.append(
                manifold.geodesic_unit(t, x_, u_, dim=-1, project=project)
            )
        return self.pack_point(*res_list)

    def dist0(self, x: torch.Tensor, *, keepdim=False) -> torch.Tensor:
        """
        Square root of the sum of squared submanifold ``dist0`` values.

        With ``keepdim=True`` a trailing dimension of size one is appended.
        """
        squared = [
            manifold.dist0(self.take_submanifold_value(x, i)) ** 2
            for i, manifold in enumerate(self.manifolds)
        ]
        res = sum(squared) ** 0.5
        if keepdim:
            res = torch.unsqueeze(res, -1)
        return res

    def expmap0(self, u: torch.Tensor, *, project=True) -> torch.Tensor:
        """Apply ``expmap0`` on every submanifold."""
        res = [
            manifold.expmap0(
                self.take_submanifold_value(u, i), dim=-1, project=project
            )
            for i, manifold in enumerate(self.manifolds)
        ]
        return self.pack_point(*res)

    def logmap0(self, x: torch.Tensor, *, project=True) -> torch.Tensor:
        """
        Apply ``logmap0`` on every submanifold.

        ``project`` is accepted but not forwarded to the submanifolds.
        """
        res = [
            manifold.logmap0(self.take_submanifold_value(x, i), dim=-1)
            for i, manifold in enumerate(self.manifolds)
        ]
        return self.pack_point(*res)

    def transp0(self, y: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Apply ``transp0`` on matching slices of ``y`` and ``u``."""
        res = []
        for i, manifold in enumerate(self.manifolds):
            y_ = self.take_submanifold_value(y, i)
            u_ = self.take_submanifold_value(u, i)
            res.append(manifold.transp0(y_, u_, dim=-1))
        return self.pack_point(*res)

    def transp0back(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Apply ``transp0back`` on matching slices of ``x`` and ``u``."""
        res = []
        for i, manifold in enumerate(self.manifolds):
            x_ = self.take_submanifold_value(x, i)
            u_ = self.take_submanifold_value(u, i)
            res.append(manifold.transp0back(x_, u_, dim=-1))
        return self.pack_point(*res)

    def gyration(
        self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, *, project=True
    ) -> torch.Tensor:
        """
        Apply ``gyration`` on matching slices of ``x``, ``y`` and ``z``.

        ``project`` is accepted but not forwarded to the submanifolds.
        """
        res = []
        for i, manifold in enumerate(self.manifolds):
            x_ = self.take_submanifold_value(x, i)
            y_ = self.take_submanifold_value(y, i)
            z_ = self.take_submanifold_value(z, i)
            res.append(manifold.gyration(x_, y_, z_, dim=-1))
        return self.pack_point(*res)

    def antipode(self, x: torch.Tensor, *, project=True) -> torch.Tensor:
        """
        Apply ``antipode`` on every submanifold.

        ``project`` is accepted but not forwarded to the submanifolds.
        """
        res = [
            manifold.antipode(self.take_submanifold_value(x, i), dim=-1)
            for i, manifold in enumerate(self.manifolds)
        ]
        return self.pack_point(*res)

    def mobius_fn_apply(
        self, fn: callable, x: torch.Tensor, *args, project=True, **kwargs
    ) -> torch.Tensor:
        """Apply ``mobius_fn_apply`` with ``fn`` and the extra arguments on every submanifold."""
        res = []
        for i, manifold in enumerate(self.manifolds):
            x_ = self.take_submanifold_value(x, i)
            res.append(
                manifold.mobius_fn_apply(
                    fn, x_, *args, dim=-1, project=project, **kwargs
                )
            )
        return self.pack_point(*res)

    def mobius_fn_apply_chain(
        self,
        x: torch.Tensor,
        *fns: callable,
        project=True,
    ) -> torch.Tensor:
        """Apply ``mobius_fn_apply_chain`` with the same ``fns`` on every submanifold."""
        res = []
        for i, manifold in enumerate(self.manifolds):
            x_ = self.take_submanifold_value(x, i)
            res.append(
                manifold.mobius_fn_apply_chain(
                    x_,
                    *fns,
                    dim=-1,
                    project=project,
                )
            )
        return self.pack_point(*res)