Developer Guide
Base Manifold¶
The common base class for all manifolds is geoopt.manifolds.base.Manifold.
-
class geoopt.manifolds.base.Manifold(**kwargs)
-
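_assert_check_shape(shape: Tuple[int], name: str)
Util to check shape and raise an error if needed.
Exhaustive implementation for checking whether a given point has a valid dimension size, shape, etc. It raises a ValueError if the check is not passed.
Parameters: - shape (Tuple[int]) – shape of the point on the manifold
- name (str) – name to be used in error messages
Raises: ValueError – if the check is not passed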
-
_check_point_on_manifold(x: torch.Tensor, *, atol=1e-05, rtol=1e-05) → Union[Tuple[bool, Optional[str]], bool]
Util to check that a point lies on the manifold.
Exhaustive implementation for checking whether a given point lies on the manifold. It should return a boolean and, if the check is not passed, the reason of failure. You can assume that assert_check_point has already been passed beforehand.
Parameters: - x (torch.Tensor) – point on the manifold
- atol (float) – absolute tolerance as in numpy.allclose()
- rtol (float) – relative tolerance as in numpy.allclose()
Returns: check result and the reason of failure, if any
Return type: Union[Tuple[bool, Optional[str]], bool]
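To illustrate the (bool, reason) contract, here is a minimal pure-Python sketch of such a check for a unit sphere. The helper name check_point_on_sphere is hypothetical, points are plain lists of floats rather than torch.Tensor, and the tolerance test follows numpy.allclose, i.e. |a - b| <= atol + rtol * |b|. It is not geoopt's implementation.

```python
from typing import Optional, Sequence, Tuple


def check_point_on_sphere(
    x: Sequence[float], *, atol: float = 1e-5, rtol: float = 1e-5
) -> Tuple[bool, Optional[str]]:
    """Hypothetical check that x lies on the unit sphere, returning (ok, reason)."""
    norm_sq = sum(xi * xi for xi in x)
    # numpy.allclose-style comparison of ||x||^2 with 1: |a - b| <= atol + rtol * |b|
    if abs(norm_sq - 1.0) > atol + rtol * 1.0:
        return False, f"'x' has squared norm {norm_sq}, expected 1.0"
    return True, None
```

A passing check returns (True, None); a failing one returns False together with a human-readable reason, which is what the public check_point_on_manifold can forward when explain=True.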
-
_check_shape(shape: Tuple[int], name: str) → Union[Tuple[bool, Optional[str]], bool]
Util to check shape.
Exhaustive implementation for checking whether a given point has a valid dimension size, shape, etc. It should return a boolean and, if the check is not passed, the reason of failure.
Parameters: - shape (Tuple[int]) – shape of the point on the manifold
- name (str) – name to be used in error messages
Returns: check result and the reason of failure, if any
Return type: Union[Tuple[bool, Optional[str]], bool]
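The split between a non-raising _check_shape and a raising _assert_check_shape can be sketched as follows. This assumes a manifold whose points occupy the last ndim dimensions, with any leading dimensions treated as batch dimensions; both function names and the ndim parameter are hypothetical and do not come from geoopt.

```python
from typing import Optional, Tuple


def check_shape_for_vector_manifold(
    shape: Tuple[int, ...], name: str, ndim: int = 1
) -> Tuple[bool, Optional[str]]:
    """Hypothetical _check_shape: points occupy the last `ndim` dims, the rest are batch dims."""
    if len(shape) < ndim:
        return False, f"'{name}' needs at least {ndim} dimension(s), got shape {shape}"
    # slice explicitly so that ndim=0 selects no dimensions (shape[-0:] would select all)
    if any(size <= 0 for size in shape[len(shape) - ndim:]):
        return False, f"'{name}' has an empty manifold dimension in shape {shape}"
    return True, None


def assert_check_shape(shape: Tuple[int, ...], name: str) -> None:
    """Raising counterpart in the spirit of _assert_check_shape."""
    ok, reason = check_shape_for_vector_manifold(shape, name)
    if not ok:
        raise ValueError(f"'{name}' seems to be an invalid tensor: {reason}")
```

Keeping the reason string in the non-raising variant lets the raising variant and any explain=True caller share the same diagnostics.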
-
_check_vector_on_tangent(x: torch.Tensor, u: torch.Tensor, *, atol=1e-05, rtol=1e-05) → Union[Tuple[bool, Optional[str]], bool]
Util to check that a vector belongs to the tangent space of a point.
Exhaustive implementation for checking whether a given vector lies in the tangent space of the manifold at x. It should return a boolean indicating whether the check was passed and, if it was not, the reason of failure. You can assume that assert_check_point has already been passed beforehand.
Parameters: - x (torch.Tensor) – point on the manifold
- u (torch.Tensor) – vector to be checked
- atol (float) – absolute tolerance
- rtol (float) – relative tolerance
Returns: check result and the reason of failure, if any
Return type: Union[Tuple[bool, Optional[str]], bool]
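For the unit sphere the tangent space at x is the set of vectors orthogonal to x, so the check reduces to testing that the inner product is close to zero. A minimal sketch, assuming plain lists of floats and a hypothetical helper name. Because the target value is 0, the rtol term of an allclose-style test vanishes, so the sketch only takes atol.

```python
from typing import Optional, Sequence, Tuple


def check_vector_on_sphere_tangent(
    x: Sequence[float], u: Sequence[float], *, atol: float = 1e-5
) -> Tuple[bool, Optional[str]]:
    """Hypothetical check that u lies in the tangent space of the unit sphere at x.

    The tangent space at x is {u : <x, u> = 0}. As in _check_vector_on_tangent,
    x is assumed to have passed the point checks already.
    """
    inner = sum(xi * ui for xi, ui in zip(x, u))
    # allclose-style test against 0: rtol * |0| vanishes, so only atol matters
    if abs(inner) > atol:
        return False, f"<x, u> = {inner}, expected 0"
    return True, None
```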
-
assert_check_point(x: torch.Tensor)
Check whether a point is valid for use with the manifold, and raise an error with an informative message on failure.
Parameters: x (torch.Tensor) – point on the manifold
Notes
This check is compatible with what the optimizer expects: the last dimensions are treated as manifold dimensions.
-
assert_check_point_on_manifold(x: torch.Tensor, *, atol=1e-05, rtol=1e-05)
Check whether point \(x\) lies on the manifold, and raise an error with an informative message on failure.
Parameters: - x (torch.Tensor) – point on the manifold
- atol (float) – absolute tolerance as in numpy.allclose()
- rtol (float) – relative tolerance as in numpy.allclose()
-
assert_check_vector(u: torch.Tensor)
Check whether a vector is valid for use with the manifold, and raise an error with an informative message on failure.
Parameters: u (torch.Tensor) – vector on the tangent plane
Notes
This check is compatible with what the optimizer expects: the last dimensions are treated as manifold dimensions.
-
assert_check_vector_on_tangent(x: torch.Tensor, u: torch.Tensor, *, ok_point=False, atol=1e-05, rtol=1e-05)
Check whether \(u\) lies in the tangent space at \(x\), and raise an error on failure.
Parameters: - x (torch.Tensor) – point on the manifold
- u (torch.Tensor) – vector on the tangent space to \(x\)
- atol (float) – absolute tolerance as in numpy.allclose()
- rtol (float) – relative tolerance as in numpy.allclose()
- ok_point (bool) – whether \(x\) is already known to be a valid point; if True, the checks on \(x\) are skipped
-
check_point(x: torch.Tensor, *, explain=False) → Union[Tuple[bool, Optional[str]], bool]
Check whether a point is valid for use with the manifold.
Parameters: - x (torch.Tensor) – point on the manifold
- explain (bool) – if True, also return the reason of failure
Returns: boolean indicating whether the tensor is valid, and the reason of failure if False
Return type: Union[Tuple[bool, Optional[str]], bool]
Notes
This check is compatible with what the optimizer expects: the last dimensions are treated as manifold dimensions.
-
check_point_on_manifold(x: torch.Tensor, *, explain=False, atol=1e-05, rtol=1e-05) → Union[Tuple[bool, Optional[str]], bool]
Check whether point \(x\) lies on the manifold.
Parameters: - x (torch.Tensor) – point on the manifold
- atol (float) – absolute tolerance as in numpy.allclose()
- rtol (float) – relative tolerance as in numpy.allclose()
- explain (bool) – if True, also return the reason of failure
Returns: boolean indicating whether the tensor is valid, and the reason of failure if False
Return type: Union[Tuple[bool, Optional[str]], bool]
Notes
This check is compatible with what the optimizer expects: the last dimensions are treated as manifold dimensions.
-
check_vector(u: torch.Tensor, *, explain=False)
Check whether a vector is valid for use with the manifold.
Parameters: - u (torch.Tensor) – vector on the tangent plane
- explain (bool) – if True, also return the reason of failure
Returns: boolean indicating whether the tensor is valid, and the reason of failure if False
Notes
This check is compatible with what the optimizer expects: the last dimensions are treated as manifold dimensions.
-
check_vector_on_tangent(x: torch.Tensor, u: torch.Tensor, *, ok_point=False, explain=False, atol=1e-05, rtol=1e-05) → Union[Tuple[bool, Optional[str]], bool]
Check whether \(u\) lies in the tangent space at \(x\).
Parameters: - x (torch.Tensor) – point on the manifold
- u (torch.Tensor) – vector on the tangent space to \(x\)
- atol (float) – absolute tolerance as in numpy.allclose()
- rtol (float) – relative tolerance as in numpy.allclose()
- explain (bool) – if True, also return the reason of failure
- ok_point (bool) – whether \(x\) is already known to be a valid point; if True, the checks on \(x\) are skipped
Returns: boolean indicating whether the tensor is valid, and the reason of failure if False
Return type: Union[Tuple[bool, Optional[str]], bool]
-
component_inner(x: torch.Tensor, u: torch.Tensor, v: torch.Tensor = None) → torch.Tensor
Inner product for tangent vectors at point \(x\) according to components of the manifold.
The result of this function is the same as inner with keepdim=True for all manifolds except ProductManifold. For ProductManifold, it computes the inner product for each component separately and then builds the output by tiling and reshaping the per-component results appropriately.
Parameters: - x (torch.Tensor) – point on the manifold
- u (torch.Tensor) – tangent vector at point \(x\)
- v (Optional[torch.Tensor]) – tangent vector at point \(x\)
Returns: component-wise inner product (broadcasted)
Return type: torch.Tensor
Notes
The purpose of this method is to improve the behaviour of adaptive optimization methods, since ProductManifold would otherwise “hide” its component structure behind the public API.
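The tiling idea can be sketched for a product of Euclidean factors stored as one flat (packed) vector. This is a hypothetical pure-Python illustration under two assumptions: each factor uses the Euclidean inner product, and the output mirrors the packed layout by repeating each factor's scalar over that factor's coordinates. It is not geoopt's ProductManifold code.

```python
from typing import List, Sequence


def component_inner_sketch(
    u: Sequence[float], v: Sequence[float], sizes: Sequence[int]
) -> List[float]:
    """Hypothetical component-wise inner product on a product of Euclidean factors.

    u and v are packed tangent vectors; sizes[i] is the length of factor i.
    Each factor's inner product is repeated over that factor's coordinates,
    so the result has the same layout as u.
    """
    out: List[float] = []
    start = 0
    for size in sizes:
        value = sum(a * b for a, b in zip(u[start:start + size], v[start:start + size]))
        out.extend([value] * size)  # tile the factor's scalar over its coordinates
        start += size
    return out
```

For example, component_inner_sketch([1, 2, 3, 4, 5], [1, 1, 1, 1, 1], [2, 3]) returns [3, 3, 12, 12, 12], whereas a plain inner product over the packed vector would collapse everything into the single value 15. An adaptive optimizer can use the tiled form to keep separate statistics per factor.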
-
device
Manifold device.
Return type: Optional[torch.device]
-
dist(x: torch.Tensor, y: torch.Tensor, *, keepdim=False) → torch.Tensor
Compute the distance between two points on the manifold, i.e. the length of the shortest path along geodesics.
Parameters: - x (torch.Tensor) – point on the manifold
- y (torch.Tensor) – point on the manifold
- keepdim (bool) – keep the last dimension?
Returns: distance between the two points
Return type: torch.Tensor
-
dist2(x: torch.Tensor, y: torch.Tensor, *, keepdim=False) → torch.Tensor
Compute the squared distance between two points on the manifold, i.e. the squared length of the shortest path along geodesics.
Parameters: - x (torch.Tensor) – point on the manifold
- y (torch.Tensor) – point on the manifold
- keepdim (bool) – keep the last dimension?
Returns: squared distance between the two points
Return type: torch.Tensor
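As a concrete case, on the unit sphere the geodesic distance is the angle between the two points, \(d(x, y) = \arccos\langle x, y\rangle\), and dist2 is its square. A minimal pure-Python sketch with hypothetical helper names, using lists of floats in place of tensors:

```python
import math
from typing import Sequence


def sphere_dist(x: Sequence[float], y: Sequence[float]) -> float:
    """Great-circle (geodesic) distance between unit vectors x and y."""
    cos_angle = sum(a * b for a, b in zip(x, y))
    # clamp to [-1, 1] so rounding errors cannot push acos outside its domain
    cos_angle = max(-1.0, min(1.0, cos_angle))
    return math.acos(cos_angle)


def sphere_dist2(x: Sequence[float], y: Sequence[float]) -> float:
    """Squared geodesic distance."""
    return sphere_dist(x, y) ** 2
```

The clamp matters in practice: for nearly identical points, floating-point error can make the inner product slightly exceed 1, and math.acos would then raise ValueError.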
-
dtype
Manifold dtype.
Return type: Optional[torch.dtype]
-
egrad2rgrad(x: torch.Tensor, u: torch.Tensor) → torch.Tensor
Transform a gradient computed using autodiff into the correct Riemannian gradient at point \(x\).
Parameters: - x (torch.Tensor) – point on the manifold
- u (torch.Tensor) – gradient to be projected
Returns: gradient vector on the Riemannian manifold
Return type: torch.Tensor
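For a submanifold of \(\mathbb{R}^n\) with the induced metric, the Riemannian gradient is the orthogonal projection of the Euclidean gradient onto the tangent space, which is why egrad2rgrad often coincides with proju. A minimal sketch for the unit sphere, where the projection is \(g - \langle x, g\rangle x\); the helper name is hypothetical.

```python
from typing import List, Sequence


def sphere_egrad2rgrad(x: Sequence[float], egrad: Sequence[float]) -> List[float]:
    """Riemannian gradient on the unit sphere with the metric induced from R^n.

    For this metric the Riemannian gradient is the orthogonal projection of the
    Euclidean gradient onto the tangent space at x: egrad - <x, egrad> x.
    """
    coef = sum(a * g for a, g in zip(x, egrad))
    return [g - coef * a for a, g in zip(x, egrad)]
```

At x = [1.0, 0.0, 0.0] a Euclidean gradient [2.0, 3.0, 4.0] becomes [0.0, 3.0, 4.0]: the radial part, which would only push the point off the sphere, is removed.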
-
expmap(x: torch.Tensor, u: torch.Tensor) → torch.Tensor
Perform an exponential map \(\operatorname{Exp}_x(u)\).
Parameters: - x (torch.Tensor) – point on the manifold
- u (torch.Tensor) – tangent vector at point \(x\)
Returns: transported point
Return type: torch.Tensor
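On the unit sphere the exponential map has the closed form \(\operatorname{Exp}_x(u) = \cos(\|u\|)\,x + \sin(\|u\|)\,u/\|u\|\): it follows the great circle through \(x\) in direction \(u\) for arc length \(\|u\|\). A minimal pure-Python sketch with a hypothetical helper name:

```python
import math
from typing import List, Sequence


def sphere_expmap(x: Sequence[float], u: Sequence[float]) -> List[float]:
    """Exp_x(u) on the unit sphere: follow the great circle from x in direction u.

    Exp_x(u) = cos(|u|) x + sin(|u|) u / |u|, with Exp_x(0) = x.
    """
    norm_u = math.sqrt(sum(ui * ui for ui in u))
    if norm_u < 1e-12:  # avoid dividing by zero for (near-)zero tangent vectors
        return list(x)
    c, s = math.cos(norm_u), math.sin(norm_u)
    return [c * xi + s * ui / norm_u for xi, ui in zip(x, u)]
```

Starting at [1, 0] with u = [0, π/2], the result is (up to rounding) [0, 1], a quarter turn along the circle.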
-
expmap_transp(x: torch.Tensor, u: torch.Tensor, v: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]
Perform an exponential map and vector transport from point \(x\) with given direction \(u\).
Parameters: - x (torch.Tensor) – point on the manifold
- u (torch.Tensor) – tangent vector at point \(x\)
- v (torch.Tensor) – tangent vector at point \(x\) to be transported
Returns: transported point and vector
Return type: Tuple[torch.Tensor, torch.Tensor]
-
extra_repr()
Set the extra representation of the module.
To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
-
inner(x: torch.Tensor, u: torch.Tensor, v: torch.Tensor = None, *, keepdim=False) → torch.Tensor
Inner product for tangent vectors at point \(x\).
Parameters: - x (torch.Tensor) – point on the manifold
- u (torch.Tensor) – tangent vector at point \(x\)
- v (Optional[torch.Tensor]) – tangent vector at point \(x\)
- keepdim (bool) – keep the last dimension?
Returns: inner product (broadcasted)
Return type: torch.Tensor
-
logmap(x: torch.Tensor, y: torch.Tensor) → torch.Tensor
Perform a logarithmic map \(\operatorname{Log}_{x}(y)\).
Parameters: - x (torch.Tensor) – point on the manifold
- y (torch.Tensor) – point on the manifold
Returns: tangent vector at point \(x\)
Return type: torch.Tensor
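The logarithmic map inverts the exponential map: it returns the tangent vector at \(x\) whose geodesic reaches \(y\). On the unit sphere its direction is the component of \(y\) orthogonal to \(x\) and its length is the geodesic distance \(\arccos\langle x, y\rangle\). A minimal sketch with a hypothetical helper name; antipodal points have no unique answer and are not handled meaningfully here.

```python
import math
from typing import List, Sequence


def sphere_logmap(x: Sequence[float], y: Sequence[float]) -> List[float]:
    """Log_x(y) on the unit sphere: the tangent vector at x pointing towards y.

    Direction: component of y orthogonal to x. Length: arccos(<x, y>).
    """
    cos_angle = max(-1.0, min(1.0, sum(a * b for a, b in zip(x, y))))
    w = [yi - cos_angle * xi for xi, yi in zip(x, y)]  # component of y orthogonal to x
    norm_w = math.sqrt(sum(wi * wi for wi in w))
    if norm_w < 1e-12:
        # y coincides with x (zero vector is correct); antipodal y also lands here,
        # where the log map is ill-defined
        return [0.0] * len(x)
    angle = math.acos(cos_angle)
    return [angle * wi / norm_w for wi in w]
```

The norm of the returned vector equals the geodesic distance between x and y, consistent with dist.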
-
norm(x: torch.Tensor, u: torch.Tensor, *, keepdim=False) → torch.Tensor
Norm of a tangent vector at point \(x\).
Parameters: - x (torch.Tensor) – point on the manifold
- u (torch.Tensor) – tangent vector at point \(x\)
- keepdim (bool) – keep the last dimension?
Returns: norm of the tangent vector (broadcasted)
Return type: torch.Tensor
-
origin(*size, dtype=None, device=None, seed: Optional[int] = 42) → torch.Tensor
Create some reasonable point on the manifold in a deterministic way.
Some manifolds have a natural special point, e.g. the zero vector or an analogue of it. If such a special point can be defined, it is returned with the desired size. Otherwise, the returned point is sampled on the manifold in a deterministic way.
Parameters: - size (Union[int, Tuple[int]]) – the desired shape
- device (torch.device) – the desired device
- dtype (torch.dtype) – the desired dtype
- seed (Optional[int]) – a parameter controlling deterministic randomness for manifolds that do not provide .origin but do provide .random (default: 42)
Return type: torch.Tensor
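For a manifold without a canonical special point, the deterministic fallback can be sketched as seeded sampling. In this hypothetical pure-Python helper (sphere_origin is our own name), a Gaussian vector normalized to unit length gives a point on the unit sphere. A private random.Random instance is seeded so that global random state is left untouched; that is a design choice of the sketch, not a statement about how geoopt handles the seed.

```python
import math
import random
from typing import List, Optional


def sphere_origin(n: int, seed: Optional[int] = 42) -> List[float]:
    """Hypothetical origin() fallback: a deterministic point on the unit sphere in R^n.

    With seed=None, random.Random seeds itself from system entropy, so the
    result is no longer reproducible.
    """
    rng = random.Random(seed)  # private generator: the global random state is untouched
    v = [rng.gauss(0.0, 1.0) for _ in range(n)]
    norm = math.sqrt(sum(vi * vi for vi in v))
    return [vi / norm for vi in v]
```

Two calls with the same seed return the same point, which keeps tests and default initializations reproducible.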
-
pack_point(*tensors) → torch.Tensor
Construct a tensor representation of a manifold point.
For regular manifolds this returns the same tensor. However, for e.g. a product manifold, this function packs all non-batch dimensions.
Parameters: tensors (Tuple[torch.Tensor])
Return type: torch.Tensor
-
proju(x: torch.Tensor, u: torch.Tensor) → torch.Tensor
Project vector \(u\) onto the tangent space at \(x\); this is usually the same as egrad2rgrad().
Parameters: - x (torch.Tensor) – point on the manifold
- u (torch.Tensor) – vector to be projected
Returns: projected vector
Return type: torch.Tensor
-
projx(x: torch.Tensor) → torch.Tensor
Project point \(x\) onto the manifold.
Parameters: x (torch.Tensor) – point to be projected
Returns: projected point
Return type: torch.Tensor
-
random(*size, dtype=None, device=None, **kwargs) → torch.Tensor
Random sampling on the manifold.
The exact implementation depends on the manifold and usually does not satisfy all assumptions about a uniform measure, etc.
-
retr(x: torch.Tensor, u: torch.Tensor) → torch.Tensor
Perform a retraction from point \(x\) with given direction \(u\).
Parameters: - x (torch.Tensor) – point on the manifold
- u (torch.Tensor) – tangent vector at point \(x\)
Returns: transported point
Return type: torch.Tensor
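A retraction is a cheaper substitute for the exponential map that agrees with it to first order. On the unit sphere a common choice, shown here as a hypothetical pure-Python sketch, is the projection retraction: step in the ambient space, then renormalize, \(\operatorname{retr}_x(u) = (x + u)/\|x + u\|\).

```python
import math
from typing import List, Sequence


def sphere_retr(x: Sequence[float], u: Sequence[float]) -> List[float]:
    """Projection retraction on the unit sphere: step in R^n, then renormalize.

    No trigonometry is needed, and the result matches Exp_x(u) to first order in u.
    """
    y = [xi + ui for xi, ui in zip(x, u)]
    # for |x| = 1 and <x, u> = 0, |x + u|^2 = 1 + |u|^2 > 0, so the division is safe
    norm = math.sqrt(sum(yi * yi for yi in y))
    return [yi / norm for yi in y]
```

For x = [1, 0] and u = [0, 1] the result is [1/√2, 1/√2], i.e. an angle of π/4, whereas the exponential map with the same u would travel an arc of length 1. The two agree only for small steps, which is all an optimizer needs.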
-
retr_transp(x: torch.Tensor, u: torch.Tensor, v: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]
Perform a retraction and a vector transport at once.
Parameters: - x (torch.Tensor) – point on the manifold
- u (torch.Tensor) – tangent vector at point \(x\)
- v (torch.Tensor) – tangent vector at point \(x\) to be transported
Returns: transported point and vector
Return type: Tuple[torch.Tensor, torch.Tensor]
Notes
Sometimes this is a much more efficient way to perform a retraction followed by a vector transport.
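The efficiency gain comes from sharing intermediate results. In this hypothetical sphere sketch (assuming the projection retraction and projection-based vector transport), the new point y is computed once and reused for the transport, instead of being recomputed by separate retraction and transport calls.

```python
import math
from typing import List, Sequence, Tuple


def sphere_retr_transp(
    x: Sequence[float], u: Sequence[float], v: Sequence[float]
) -> Tuple[List[float], List[float]]:
    """Retraction plus projection-based vector transport on the unit sphere, fused."""
    y = [xi + ui for xi, ui in zip(x, u)]
    norm = math.sqrt(sum(yi * yi for yi in y))
    y = [yi / norm for yi in y]  # retracted point, computed once
    # transport v by projecting it onto the tangent space at the new point y
    coef = sum(a * b for a, b in zip(y, v))
    v_new = [vi - coef * yi for yi, vi in zip(y, v)]
    return y, v_new
```

Optimizers with momentum need both results at every step (the new point and the transported momentum), so fusing them removes one normalization per step here, and potentially much more on manifolds where the retraction is expensive.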
-
transp(x: torch.Tensor, y: torch.Tensor, v: torch.Tensor) → torch.Tensor
Perform vector transport \(\mathfrak{T}_{x\to y}(v)\).
Parameters: - x (torch.Tensor) – start point on the manifold
- y (torch.Tensor) – target point on the manifold
- v (torch.Tensor) – tangent vector at point \(x\)
Returns: transported tensor
Return type: torch.Tensor
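A simple vector transport for embedded manifolds projects \(v\) onto the tangent space at the target point. For the unit sphere this gives \(\mathfrak{T}_{x\to y}(v) = v - \langle y, v\rangle y\). The sketch below uses a hypothetical name and is a projection-based transport, not parallel transport: it does not depend on \(x\) at all, and it may shrink \(v\).

```python
from typing import List, Sequence


def sphere_transp(
    x: Sequence[float], y: Sequence[float], v: Sequence[float]
) -> List[float]:
    """Projection-based vector transport on the unit sphere from T_x to T_y.

    v (tangent at x) is mapped to v - <y, v> y, which is tangent at y.
    x is accepted only to mirror the transp(x, y, v) signature; it is unused.
    """
    coef = sum(a * b for a, b in zip(y, v))
    return [vi - coef * yi for yi, vi in zip(y, v)]
```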
-
transp_follow_expmap(x: torch.Tensor, u: torch.Tensor, v: torch.Tensor) → torch.Tensor
Perform vector transport following \(u\): \(\mathfrak{T}_{x\to\operatorname{Exp}(x, u)}(v)\).
Here, \(\operatorname{Exp}\) is the best possible approximation of the true exponential map. In some cases the exact variant is hard or impossible to implement; in those cases a fallback, non-exact implementation is used.
Parameters: - x (torch.Tensor) – point on the manifold
- u (torch.Tensor) – tangent vector at point \(x\)
- v (torch.Tensor) – tangent vector at point \(x\) to be transported
Returns: transported tensor
Return type: torch.Tensor
-
transp_follow_retr(x: torch.Tensor, u: torch.Tensor, v: torch.Tensor) → torch.Tensor
Perform vector transport following \(u\): \(\mathfrak{T}_{x\to\operatorname{retr}(x, u)}(v)\).
This operation is sometimes much simpler and can be optimized.
Parameters: - x (torch.Tensor) – point on the manifold
- u (torch.Tensor) – tangent vector at point \(x\)
- v (torch.Tensor) – tangent vector at point \(x\) to be transported
Returns: transported tensor
Return type: torch.Tensor
-
unpack_tensor(tensor: torch.Tensor) → torch.Tensor
Construct a point on the manifold.
This method should help to work with product and compound manifolds. Internally, all points on the manifold are stored in an intuitive format. However, there may be cases where it is simpler or more efficient to store points in a different representation that is hard to use in practice.
Parameters: tensor (torch.Tensor)
Return type: torch.Tensor
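The round trip between pack_point and unpack_tensor for a product manifold can be sketched with flat lists. This is a hypothetical illustration with our own function names: it ignores batch dimensions, and it takes the component sizes as an explicit argument, whereas a real product manifold would know them from its own definition.

```python
from typing import List, Sequence, Tuple


def pack_components(*components: Sequence[float]) -> List[float]:
    """Concatenate the components of a product-manifold point into one flat list."""
    packed: List[float] = []
    for component in components:
        packed.extend(component)
    return packed


def unpack_components(
    packed: Sequence[float], sizes: Sequence[int]
) -> Tuple[List[float], ...]:
    """Split a flat representation back into per-component points."""
    if sum(sizes) != len(packed):
        raise ValueError(f"sizes {list(sizes)} do not sum to length {len(packed)}")
    parts: List[List[float]] = []
    start = 0
    for size in sizes:
        parts.append(list(packed[start:start + size]))
        start += size
    return tuple(parts)
```

The flat form is convenient for storage and for optimizers that treat a parameter as one tensor; the unpacked form is what a user works with when reasoning about each factor separately.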
-
-
class geoopt.manifolds.base.ScalingStorage
Helper class to make implementation transparent.
This is just a dictionary with an additionally overridden __call__, which gives a more explicit and elegant API for declaring members. A usage example may be found in Manifold.
Methods that require rescaling when wrapped into Scaled should be defined as follows:
1. Regular methods like dist, dist2, expmap, retr, etc. that are already present in the base class do not require registration; this has already happened in the base Manifold class.
2. New methods (like in PoincareBall) should be treated with care.

    class PoincareBall(Manifold):
        # make a class copy of __scaling__ info. Default methods are already present there
        __scaling__ = Manifold.__scaling__.copy()
        ...  # here comes the regular implementation of the required methods

        # `math` below is the PoincareBall math module, not the Python stdlib math
        @__scaling__(ScalingInfo(1))  # rescale output according to the rule `out * scaling ** 1`
        def dist0(self, x: torch.Tensor, *, dim=-1, keepdim=False):
            return math.dist0(x, c=self.c, dim=dim, keepdim=keepdim)

        @__scaling__(ScalingInfo(u=-1))  # rescale argument `u` according to the rule `u * scaling ** -1`
        def expmap0(self, u: torch.Tensor, *, dim=-1, project=True):
            res = math.expmap0(u, c=self.c, dim=dim)
            if project:
                return math.project(res, c=self.c, dim=dim)
            else:
                return res

        ...  # other special methods implementation

3. Some methods are not compliant with the above rescaling rules. We should mark them as NotCompatible:

        # continuation of the PoincareBall definition
        @__scaling__(ScalingInfo.NotCompatible)
        def mobius_fn_apply(
            self, fn: callable, x: torch.Tensor, *args, dim=-1, project=True, **kwargs
        ):
            res = math.mobius_fn_apply(fn, x, *args, c=self.c, dim=dim, **kwargs)
            if project:
                return math.project(res, c=self.c, dim=dim)
            else:
                return res
-
copy() → a shallow copy of D
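The registration pattern used with ScalingStorage can be reproduced with a plain dict subclass. This is a minimal sketch under our own assumptions: the *Sketch class names are hypothetical and this is not geoopt's source. The sketch overrides copy() because the inherited dict.copy() returns a plain dict, which would lose the decorator behaviour needed in a subclass such as PoincareBall.

```python
from typing import Any, Callable, Dict


class ScalingInfoSketch:
    """Hypothetical stand-in for ScalingInfo: powers for outputs and keyword inputs."""

    def __init__(self, *results: float, **kwargs: float) -> None:
        self.results = results
        self.kwargs: Dict[str, float] = kwargs


class ScalingStorageSketch(dict):
    """A dict whose __call__ yields a decorator that records a method's scaling info."""

    def __call__(self, scaling_info: ScalingInfoSketch) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        def register(fn: Callable[..., Any]) -> Callable[..., Any]:
            self[fn.__name__] = scaling_info  # key the registry by method name
            return fn  # the method itself is returned unchanged
        return register

    def copy(self) -> "ScalingStorageSketch":
        # dict.copy() would return a plain dict and lose __call__, so keep the class
        return self.__class__(self)


class BaseSketch:
    __scaling__ = ScalingStorageSketch()

    @__scaling__(ScalingInfoSketch(1))  # output rescaled as out * scaling ** 1
    def dist(self, x, y):
        raise NotImplementedError


class DerivedSketch(BaseSketch):
    # copy, so that new registrations do not leak into the base class registry
    __scaling__ = BaseSketch.__scaling__.copy()

    @__scaling__(ScalingInfoSketch(u=-1))  # input u rescaled as u * scaling ** -1
    def expmap0(self, u):
        raise NotImplementedError
```

The decorator works inside the class body because names assigned earlier in a class body, such as __scaling__, are visible to later statements in that body.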
-
class geoopt.manifolds.base.ScalingInfo(*results, **kwargs)
Scaling info for each argument that requires rescaling.
scaled_value = value * scaling ** power if power != 0 else value
For results, setting powers of scaling is not always required; when a power is not set, rescaling is a no-op.
The convention for this info is the following. The output of a function is either a tuple or a single object. In either case, outputs are treated as positionals. Function inputs, in contrast, are treated by keywords. It is common practice to keep the function signature when overriding, so this convention may be considered sufficient in this particular scenario. The only information required for the formula above is power.
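The rule above can be applied mechanically: keyword inputs are rescaled by name, outputs by position, and anything without a registered power is left unchanged. The wrapper below is a hypothetical sketch of how a scaled wrapper might use this info; call_rescaled and its parameters are our own names, not geoopt's API.

```python
from typing import Any, Callable, Dict, Sequence


def rescale(value: float, scaling: float, power: float) -> float:
    """The ScalingInfo rule: value * scaling ** power if power != 0 else value."""
    return value * scaling ** power if power != 0 else value


def call_rescaled(
    fn: Callable[..., Any],
    scaling: float,
    result_powers: Sequence[float],
    kwarg_powers: Dict[str, float],
    **kwargs: float,
) -> Any:
    """Hypothetical wrapper: rescale keyword inputs, call fn, rescale outputs by position."""
    scaled_kwargs = {
        name: rescale(value, scaling, kwarg_powers.get(name, 0))
        for name, value in kwargs.items()
    }
    out = fn(**scaled_kwargs)
    outs = out if isinstance(out, tuple) else (out,)
    # outputs without a registered power use power 0, i.e. they are left unchanged
    scaled = tuple(
        rescale(o, scaling, result_powers[i] if i < len(result_powers) else 0)
        for i, o in enumerate(outs)
    )
    return scaled if isinstance(out, tuple) else scaled[0]
```

With scaling=3.0 and a result power of 1 (as in ScalingInfo(1) for dist0), an output of 2.0 becomes 6.0. With a keyword power u=-1 and scaling=2.0 (as in ScalingInfo(u=-1) for expmap0), an input u=4.0 reaches the function as 2.0.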