Source code for geoopt.manifolds.scaled

import inspect
import torch
import itertools
import types
from typing import Union, Tuple, Optional
import geoopt.utils
from geoopt.manifolds.base import Manifold, ScalingInfo
import functools

# Public names exported by ``from geoopt.manifolds.scaled import *``.
__all__ = ["Scaled"]


def rescale_value(value, scaling, power):
    """Multiply ``value`` by ``scaling`` raised to ``power``.

    When ``power`` is zero the multiplication is skipped and ``value`` itself
    is returned unchanged.
    """
    if power == 0:
        return value
    factor = scaling**power
    return value * factor


def rescale(function, scaling_info):
    """Build a scaled-manifold method out of a base-manifold function.

    Parameters
    ----------
    function : callable
        plain (unbound) function of the base manifold; its first parameter
        receives the base manifold instance
    scaling_info : ScalingInfo
        how to rescale the call: ``scaling_info.kwargs`` maps argument names to
        the power of ``self.scale`` the argument is multiplied by, and
        ``scaling_info.results`` lists one power per returned value (missing
        powers default to 0, i.e. unscaled). ``ScalingInfo.NotCompatible``
        means there is no scaled version of ``function``.

    Returns
    -------
    callable
        function taking ``self`` (an object exposing ``base`` and ``scale``)
        followed by ``function``'s remaining arguments. For
        ``ScalingInfo.NotCompatible`` the returned stub raises
        ``NotImplementedError`` when called.
    """
    if scaling_info is ScalingInfo.NotCompatible:

        # wrap ``function`` itself so the stub carries its name, docstring and
        # signature (wrapping anything else would copy foreign metadata)
        @functools.wraps(function)
        def stub(self, *args, **kwargs):
            raise NotImplementedError(
                "Scaled version of '{}' is not available".format(function.__name__)
            )

        return stub
    signature = inspect.signature(function)

    @functools.wraps(function)
    def rescaled_function(self, *args, **kwargs):
        # bind against the base manifold so defaults and names resolve exactly
        # as they would for a direct call on ``self.base``
        params = signature.bind(self.base, *args, **kwargs)
        params.apply_defaults()
        # ``params.arguments`` is live: editing it updates ``params.args`` and
        # ``params.kwargs``, so no new BoundArguments object is needed
        arguments = params.arguments
        for k, power in scaling_info.kwargs.items():
            arguments[k] = rescale_value(arguments[k], self.scale, power)
        results = function(*params.args, **params.kwargs)
        if not scaling_info.results:
            # nothing to rescale on the output side
            return results
        is_tuple = isinstance(results, tuple)
        results = geoopt.utils.make_tuple(results)
        # results without a listed power get power 0 and pass through unchanged
        wrapped_results = [
            rescale_value(res, self.scale, power)
            for res, power in itertools.zip_longest(
                results, scaling_info.results, fillvalue=0
            )
        ]
        if not is_tuple:
            return wrapped_results[0]
        if hasattr(results, "_fields"):
            # namedtuple constructors take their fields positionally
            return results.__class__(*wrapped_results)
        return results.__class__(wrapped_results)

    return rescaled_function


class Scaled(Manifold):
    """
    Scaled manifold.

    Scales all the distances on the manifold by a constant factor. Scaling may be learnable
    since the underlying representation is canonical.

    Parameters
    ----------
    manifold : Manifold
        base manifold providing the underlying representation and operations
    scale : float or tensor
        scaling factor, converted to a tensor of the default dtype
    learnable : bool
        if ``True`` the scale is stored as a ``_log_scale`` parameter,
        otherwise as a non-trainable ``_scale`` buffer

    Examples
    --------
    Here is a simple example of radius 2 Sphere

    >>> import geoopt, torch, numpy as np
    >>> sphere = geoopt.Sphere()
    >>> radius_2_sphere = Scaled(sphere, 2)
    >>> p1 = torch.tensor([-1., 0.])
    >>> p2 = torch.tensor([0., 1.])
    >>> np.testing.assert_allclose(sphere.dist(p1, p2), np.pi / 2)
    >>> np.testing.assert_allclose(radius_2_sphere.dist(p1, p2), np.pi)
    """

    def __init__(self, manifold: Manifold, scale=1.0, learnable=False):
        super().__init__()
        self.base = manifold
        # ``detach`` gives a tensor that does not require grad without touching
        # the caller's tensor (``as_tensor`` may return it as is) and without
        # failing on non-leaf tensors, unlike ``requires_grad_(False)``
        scale = torch.as_tensor(scale, dtype=torch.get_default_dtype()).detach()
        if not learnable:
            self.register_buffer("_scale", scale)
            self.register_buffer("_log_scale", None)
        else:
            self.register_buffer("_scale", None)
            # learn the log: ``scale`` is recovered with ``exp`` and stays positive
            self.register_parameter("_log_scale", torch.nn.Parameter(scale.log()))
        # do not rebuild scaled functions very frequently, save them
        for method, scaling_info in self.base.__scaling__.items():
            # register rescaled functions as bound methods of this particular instance
            unbound_method = getattr(self.base, method).__func__  # unbound method
            self.__setattr__(
                method, types.MethodType(rescale(unbound_method, scaling_info), self)
            )

    @property
    def scale(self) -> torch.Tensor:
        """Scaling factor (``exp`` of the learnable log-scale when learnable)."""
        if self._scale is None:
            return self._log_scale.exp()
        else:
            return self._scale

    @property
    def log_scale(self) -> torch.Tensor:
        """Logarithm of the scaling factor."""
        if self._log_scale is None:
            return self._scale.log()
        else:
            return self._log_scale

    # propagate all important stuff
    reversible = property(lambda self: self.base.reversible)
    ndim = property(lambda self: self.base.ndim)
    name = "Scaled"
    __scaling__ = property(lambda self: self.base.__scaling__)
    # Make AbstractMeta happy, to be fixed in __init__
    retr = NotImplemented
    expmap = NotImplemented

    def __getattr__(self, item):
        try:
            return super().__getattr__(item)
        except AttributeError as original:
            try:
                # propagate only public methods and attributes, ignore buffers, parameters, etc
                if isinstance(self.base, Scaled) and item in self._base_attributes:
                    return self.base.__getattr__(item)
                else:
                    return self.base.__getattribute__(item)
            except AttributeError as e:
                # report the lookup failure on this object, chained to the base one
                raise original from e

    @property
    def _base_attributes(self):
        # class attributes plus instance attributes of the innermost base manifold
        if isinstance(self.base, Scaled):
            return self.base._base_attributes
        else:
            base_attributes = set(dir(self.base.__class__))
            base_attributes |= set(self.base.__dict__.keys())
            return base_attributes

    def __dir__(self):
        return list(set(super().__dir__()) | self._base_attributes)

    def __repr__(self):
        extra = self.base.extra_repr()
        if extra:
            return self.name + "({})({}) manifold".format(self.base.name, extra)
        else:
            return self.name + "({}) manifold".format(self.base.name)

    def _check_shape(self, shape: Tuple[int], name: str) -> Tuple[bool, Optional[str]]:
        return self.base._check_shape(shape, name)

    def _check_point_on_manifold(
        self, x: torch.Tensor, *, atol=1e-5, rtol=1e-5
    ) -> Union[Tuple[bool, Optional[str]], bool]:
        return self.base._check_point_on_manifold(x, atol=atol, rtol=rtol)

    def _check_vector_on_tangent(
        self, x: torch.Tensor, u: torch.Tensor, *, atol=1e-5, rtol=1e-5
    ) -> Union[Tuple[bool, Optional[str]], bool]:
        return self.base._check_vector_on_tangent(x, u, atol=atol, rtol=rtol)

    # stuff that should remain the same but we need to override it
    def inner(
        self,
        x: torch.Tensor,
        u: torch.Tensor,
        v: torch.Tensor = None,
        *,
        keepdim=False,
        **kwargs,
    ) -> torch.Tensor:
        return self.base.inner(x, u, v, keepdim=keepdim, **kwargs)

    def norm(
        self, x: torch.Tensor, u: torch.Tensor, *, keepdim=False, **kwargs
    ) -> torch.Tensor:
        return self.base.norm(x, u, keepdim=keepdim, **kwargs)

    def proju(self, x: torch.Tensor, u: torch.Tensor, **kwargs) -> torch.Tensor:
        return self.base.proju(x, u, **kwargs)

    def projx(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        return self.base.projx(x, **kwargs)

    def egrad2rgrad(self, x: torch.Tensor, u: torch.Tensor, **kwargs) -> torch.Tensor:
        return self.base.egrad2rgrad(x, u, **kwargs)

    def transp(
        self, x: torch.Tensor, y: torch.Tensor, v: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        return self.base.transp(x, y, v, **kwargs)

    def random(self, *size, dtype=None, device=None, **kwargs) -> torch.Tensor:
        return self.base.random(*size, dtype=dtype, device=device, **kwargs)