#!/usr/bin/env python3
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch as th
from torch.autograd import Function
from .euclidean import EuclideanManifold
class PoincareManifold(EuclideanManifold):
    """
    Poincaré Ball model of hyperbolic geometry. This is the manifold used
    in `"Poincaré Embeddings for Learning Hierarchical Representations"
    (Nickel et al., 2017) <https://arxiv.org/abs/1705.08039>`_

    Args:
        eps (float): :math:`\\epsilon` value to restrict the radius of the
            ball. This is used to prevent numerical overflow/underflow
        **kwargs: forwarded unchanged to ``EuclideanManifold.__init__``.
    """

    def __init__(self, eps=1e-5, **kwargs):
        super(PoincareManifold, self).__init__(**kwargs)
        self.eps = eps
        # Largest admissible radius: points are kept strictly inside the
        # open unit ball, at distance eps from its boundary.
        self.boundary = 1 - eps
        # NOTE(review): max_norm is not read in this class; presumably the
        # EuclideanManifold base uses it (e.g. when re-normalizing points).
        # Confirm against the base class.
        self.max_norm = self.boundary

    def distance(self, u, v):
        """
        See :func:`~hype.manifold.Manifold.distance`

        :math:`d(u, v) = \\text{arcosh}\\left(
        1 + 2 \\frac{||u - v||^2}{(1 - ||u||^2)(1 - ||v||^2)} \\right)`

        Squared norms are clamped to ``1 - eps`` inside ``Distance`` so the
        denominator never reaches zero.
        """
        return Distance.apply(u, v, self.eps)

    def rgrad(self, p, d_p):
        """
        See :func:`~hype.manifold.Manifold.rgrad`

        Converts a Euclidean gradient into the Riemannian gradient of the
        Poincaré ball by scaling it with the inverse metric factor
        :math:`(1 - ||p||^2)^2 / 4`.

        Args:
            p (Tensor): the current points. For a dense ``d_p`` the squared
                norm is taken over the last dimension; for a sparse ``d_p``
                the rows ``p[d_p._indices()[0]]`` are used.
            d_p (Tensor): Euclidean gradient, either dense or sparse COO
                (one sparse index dimension selecting rows of ``p``).

        Returns:
            Tensor: the rescaled gradient. It is sparse iff ``d_p`` is
            sparse, and has the same size as ``d_p``.
        """
        if d_p.is_sparse:
            indices = d_p._indices()
            values = d_p._values()
            # Index with the full 1-D row-index vector. Squeezing it would
            # turn a single non-zero row into a 0-d index, producing a 1-D
            # slice on which the dim=1 reduction below fails.
            p_sqnorm = th.sum(
                p[indices[0]] ** 2, dim=1, keepdim=True
            ).expand_as(values)
            n_vals = values * ((1 - p_sqnorm) ** 2) / 4
            # Clip every row of the rescaled gradient to L2 norm <= 5.
            # NOTE(review): only the sparse path clips; the dense path below
            # does not. Kept as-is to preserve existing behavior.
            n_vals.renorm_(2, 0, 5)
            # sparse_coo_tensor takes dtype/device from n_vals. The legacy
            # th.sparse.DoubleTensor constructor was CPU/float64-only.
            d_p = th.sparse_coo_tensor(indices, n_vals, d_p.size())
        else:
            p_sqnorm = th.sum(p ** 2, dim=-1, keepdim=True)
            d_p = d_p * ((1 - p_sqnorm) ** 2 / 4).expand_as(d_p)
        return d_p
class Distance(Function):
    """Autograd function for the Poincaré-ball distance.

    Forward evaluates, along the last dimension,
    :math:`\\text{arcosh}\\left(1 + 2 \\frac{||u - v||^2}
    {(1 - ||u||^2)(1 - ||v||^2)}\\right)` with both squared norms clamped
    into ``[0, 1 - eps]``. Backward uses a closed-form gradient instead of
    differentiating through the forward graph.
    """

    @staticmethod
    def grad(x, v, sqnormx, sqnormv, sqdist, eps):
        """Closed-form gradient of the distance with respect to ``x``.

        Args:
            x: the point being differentiated.
            v: the other point.
            sqnormx, sqnormv: clamped squared norms of ``x`` and ``v``,
                as saved by ``forward``.
            sqdist: squared Euclidean distance ``||x - v||^2``.
            eps: lower bound applied to the denominator so it never hits 0.

        Returns:
            Tensor shaped like ``x`` holding the partial derivative of the
            distance for a unit upstream gradient.
        """
        alpha = 1 - sqnormx
        beta = 1 - sqnormv
        delta = 1 + 2 * sqdist / (alpha * beta)

        # Bracketed term of d(delta)/dx:
        # ((||v||^2 - 2<x, v> + 1) / alpha^2) * x - v / alpha
        inner = th.sum(x * v, dim=-1)
        coef = (sqnormv - 2 * inner + 1) / th.pow(alpha, 2)
        direction = (
            coef.unsqueeze(-1).expand_as(x) * x
            - v / alpha.unsqueeze(-1).expand_as(v)
        )

        # d arcosh(delta) = d delta / sqrt(delta^2 - 1); the beta factor of
        # d delta is folded into the (clamped) denominator.
        denom = th.clamp(th.sqrt(th.pow(delta, 2) - 1) * beta, min=eps)
        return 4 * direction / denom.unsqueeze(-1).expand_as(x)

    @staticmethod
    def forward(ctx, u, v, eps):
        """Return the Poincaré distance between ``u`` and ``v``.

        Reduces over the last dimension and stashes the intermediates that
        ``backward`` needs.
        """
        def clamped_sqnorm(t):
            # Keep squared norms inside [0, 1 - eps] so (1 - norm) > 0.
            return th.clamp(th.sum(t * t, dim=-1), 0, 1 - eps)

        squnorm = clamped_sqnorm(u)
        sqvnorm = clamped_sqnorm(v)
        sqdist = th.sum(th.pow(u - v, 2), dim=-1)

        ctx.eps = eps
        ctx.save_for_backward(u, v, squnorm, sqvnorm, sqdist)

        arg = sqdist / ((1 - squnorm) * (1 - sqvnorm)) * 2 + 1
        # arcosh(arg) = log(arg + sqrt(arg^2 - 1))
        return th.log(arg + th.sqrt(th.pow(arg, 2) - 1))

    @staticmethod
    def backward(ctx, g):
        """Chain the upstream gradient ``g`` through both inputs.

        ``eps`` is not differentiable, so its gradient is ``None``.
        """
        u, v, squnorm, sqvnorm, sqdist = ctx.saved_tensors
        upstream = g.unsqueeze(-1)
        local_grads = (
            Distance.grad(u, v, squnorm, sqvnorm, sqdist, ctx.eps),
            Distance.grad(v, u, sqvnorm, squnorm, sqdist, ctx.eps),
        )
        grad_u, grad_v = [upstream.expand_as(t) * t for t in local_grads]
        return grad_u, grad_v, None