# Source code for jax_rb.manifolds.matrix_left_invariant

"""Base class for matrix groups with left invariant metrics.
"""

from functools import partial

import jax
import jax.numpy as jnp
import jax.numpy.linalg as jla
from .global_manifold import GlobalManifold
from ..utils.utils import (grand, lie)


class MatrixLeftInvariant(GlobalManifold):
    """Matrix group with left invariant metric.

    This is an abstract base: :meth:`name`, :meth:`_lie_algebra_proj`,
    :meth:`_mat_apply`, :meth:`rand_point`, :meth:`retract` and
    :meth:`approx_nearest` raise ``NotImplementedError`` here and must be
    supplied by subclasses.  The constructor already calls
    :meth:`_lie_algebra_proj` and :meth:`_mat_apply` (through the drift
    builders), so those hooks must be usable when it runs.

    :param p: the size of the matrix
    :param g_mat: The matrix defining the inner product at the identity.
        Typically g_mat is of size :math:`\\dim \\mathrm{G}`.
    """

    def __init__(self, p, g_mat):
        """Constructor.

        Stores the metric matrix, caches its inverse and inverse square
        root, then precomputes the Stratonovich and Ito drifts at the
        identity.
        """
        self.shape = (p, p)
        self.dim = p*p
        self._g_mat = g_mat
        # Cached once so _inv_id_opt does not re-invert g_mat on every call
        # (it is called inside the p*p loops below and in every gamma /
        # inv_g_metric evaluation).
        self._inv_g_mat = jla.inv(g_mat)
        # g_mat^{-1/2} from the eigen-decomposition V diag(w) V^T.
        # NOTE(review): eigh/sqrt assume g_mat is symmetric positive
        # definite - confirm against callers.
        ei, ev = jla.eigh(g_mat)
        self._i_sqrt_g_mat = ev@((1/jnp.sqrt(ei))[:, None]*ev.T)
        #: Stratonovich drift at the identity.
        self.v0 = self._make_v0()
        #: Ito drift at the identity.
        self.id_drift = self._make_id_drift()

    def name(self):
        """Not implemented in the base class; subclasses must override."""
        raise NotImplementedError

    def _lie_algebra_proj(self, omg):
        """The projection :math:`p_{\\mathfrak{g}}` at the identity.

        Not implemented in the base class; subclasses must override.
        """
        raise NotImplementedError

    def _mat_apply(self, mat, omg):
        """Implementing the operator :math:`\\mathcal{I}` applied on omg
        in :math:`\\mathcal{E}`.

        :param mat: matrix defining the operator (the callers in this class
            pass ``g_mat``, its inverse, or its inverse square root).
        :param omg: the matrix the operator is applied to.

        Not implemented in the base class; subclasses must override.
        """
        raise NotImplementedError

    def _id_opt(self, omg):
        """The metric applied at the identity: ``_mat_apply(g_mat, omg)``.
        """
        return self._mat_apply(self._g_mat, omg)

    def _inv_id_opt(self, omg):
        """Inverse of :meth:`_id_opt`: ``_mat_apply(g_mat^{-1}, omg)``.

        Uses the inverse cached by the constructor.
        """
        return self._mat_apply(self._inv_g_mat, omg)

    def proj(self, x, omg):
        """Return ``x @ p_g(x^{-1} omg)``.

        ``omg`` is pulled back to the identity with a linear solve, passed
        through :meth:`_lie_algebra_proj`, and multiplied by ``x`` again.
        """
        return x@self._lie_algebra_proj(jla.solve(x, omg))

    def _d_proj(self, x, xi, eta):
        """Return ``xi @ p_g(x^{-1} eta) - x @ p_g(x^{-1} xi x^{-1} eta)``.

        This is the product-rule derivative of :meth:`proj` in ``x`` along
        ``xi``, evaluated on ``eta``, provided :meth:`_lie_algebra_proj` is
        linear (assumption - it is a subclass hook).
        """
        ivx = jla.inv(x)
        return xi@self._lie_algebra_proj(jla.solve(x, eta)) \
            - x@self._lie_algebra_proj(ivx@xi@ivx@eta)

    def rand_ambient(self, key):
        """Return ``grand(key, self.shape)`` unchanged.

        NOTE(review): :meth:`rand_vec` unpacks the same call as
        ``(sample, key)``, so this presumably returns that pair - confirm
        against ``utils.grand``.
        """
        return grand(key, self.shape)

    def rand_vec(self, key, x):
        """Draw a sample with ``grand`` and pass it through :meth:`proj`.

        :returns: the pair ``(proj(x, sample), key)`` where ``key`` is the
            second value returned by ``grand``.
        """
        omg, key = grand(key, self.shape)
        return self.proj(x, omg), key

    def rand_point(self, key):
        """Not implemented in the base class; subclasses must override."""
        raise NotImplementedError

    def inner(self, x, a, b):
        """Inner product of ``a`` and ``b`` at ``x``.

        Both arguments are pulled back to the identity
        (``x^{-1} a``, ``x^{-1} b``); the result is the entrywise sum of
        ``x^{-1} a`` times ``_id_opt(x^{-1} b)``.
        """
        return jnp.sum(jla.solve(x, a)*self._id_opt(jla.solve(x, b)))

    def g_metric(self, x, omg):
        """Return ``x^{-T} _id_opt(x^{-1} omg)``."""
        return jla.solve(x.T, self._id_opt(jla.solve(x, omg)))

    def inv_g_metric(self, x, omg):
        """Return ``x _inv_id_opt(x^T omg)``, the inverse of
        :meth:`g_metric`.
        """
        return x@self._inv_id_opt(x.T@omg)

    @partial(jax.jit, static_argnums=(0,))
    def gamma(self, x, xi, eta):
        """Christoffel function evaluated on ``xi`` and ``eta`` at ``x``.

        The original source records the equivalent form
        ``-_d_proj(x, xi, eta) + proj(x, gamma_ambient(x, xi, eta))``; the
        expression below is the simplified version of it.

        NOTE(review): ``lie`` comes from ``utils``; presumably the matrix
        commutator - confirm.
        """
        ivxi = jla.solve(x, xi)
        iveta = jla.solve(x, eta)
        return -0.5*(xi@iveta + eta@ivxi) \
            + 0.5*x@self._inv_id_opt(
                self._lie_algebra_proj(
                    lie(self._id_opt(ivxi), iveta.T)
                    + lie(self._id_opt(iveta), ivxi.T)))

    @partial(jax.jit, static_argnums=(0,))
    def gamma_ambient(self, x, xi, eta):
        """Christoffel function for ambient manifold.
        """
        ivx = jla.inv(x)
        return 0.5*x@self._inv_id_opt(
            - self._id_opt(ivx@xi@ivx@eta + ivx@eta@ivx@xi)
            + lie(self._id_opt(ivx@xi), eta.T@ivx.T)
            + lie(self._id_opt(ivx@eta), xi.T@ivx.T))

    def retract(self, x, v):
        """Not implemented in the base class; subclasses must override."""
        raise NotImplementedError

    def _make_id_drift_longer(self):
        """make the drift at identity. The longer way, sum gamma.x.

        Accumulates ``-gamma(I, E_ij, p_g(_inv_id_opt(E_ij)))`` over all
        elementary matrices ``E_ij`` and returns half of the sum.
        """
        p = self.shape[0]
        eye = jnp.eye(p)
        drft = jnp.zeros(self.shape)
        for i in range(p):
            for j in range(p):
                eij = jnp.zeros((p, p)).at[i, j].set(1.)
                drft -= self.gamma(
                    eye, eij,
                    self._lie_algebra_proj(self._inv_id_opt(eij)))
        return 0.5*drft

    def _make_id_drift(self):
        """make the drift at identity. Simplify so we dont need
        to evaluate gamma.
        """
        p = self.shape[0]
        v = jnp.zeros((p, p))
        zr = jnp.zeros((p, p))
        lp = self._lie_algebra_proj

        for i in range(p):
            for j in range(p):
                eij = zr.at[i, j].set(1.)
                v += - eij@self._inv_id_opt(lp(eij)) \
                    + self._inv_id_opt(lp(lie(lp(eij), eij.T)))
        return -0.5*v

    def _make_v0(self):
        """ make v0, the identity tangent vector corresponding to
        the Stratonovich drift.
        """
        p = self.shape[0]
        v = jnp.zeros((p, p))
        zr = jnp.zeros((p, p))
        for i in range(p):
            for j in range(p):
                eij = zr.at[i, j].set(1.)
                v += self._inv_id_opt(
                    self._lie_algebra_proj(
                        lie(self._lie_algebra_proj(eij), eij.T)))
        return -0.5*v

    def approx_nearest(self, q):
        """ find point on the manifold that is nearest to q,
        same order as the nearest point.

        Not implemented in the base class; subclasses must override.
        """
        raise NotImplementedError

    @partial(jax.jit, static_argnums=(0,))
    def ito_drift(self, x):
        """Ito drift at ``x``: the identity drift ``id_drift``
        left-multiplied by ``x``.
        """
        return x@self.id_drift

    @partial(jax.jit, static_argnums=(0,))
    def stratonovich_drift(self, x):
        """ Stratonovich drift.

        The identity drift ``v0`` left-multiplied by ``x``.
        """
        return x@self.v0

    # NOTE: jit is disabled for this method (the decorator was commented
    # out in the original source).
    def laplace_beltrami(self, x, egradx, ehessvp):
        """Laplace-Beltrami operator at ``x``.

        Sums the ``(i, j)`` entry of
        ``proj(x, inv_g_metric(x, ehessvp(x, E_ij)))`` over all elementary
        matrices ``E_ij`` and adds ``2 * sum(ito_drift(x) * egradx)``.

        :param x: point at which the operator is evaluated.
        :param egradx: array multiplied entrywise with ``ito_drift(x)``
            (presumably the Euclidean gradient at ``x`` - confirm).
        :param ehessvp: callable invoked as ``ehessvp(x, direction)``
            (presumably a Euclidean Hessian-vector product - confirm).
        """
        p = self.shape[0]
        tup = jnp.zeros((p, p))
        ret = 0
        for i in range(p):
            for j in range(p):
                e_ij = tup.at[i, j].set(1.)
                ret += self.proj(x, self.inv_g_metric(
                    x, ehessvp(x, e_ij)))[i, j]
        return ret + 2*jnp.sum(self.ito_drift(x)*egradx)

    def left_invariant_vector_field(self, x, v):
        """ map from a unit vector in the trace metric to
        a vector field with unit length in the left invariant metric.
        """
        return x@self._mat_apply(self._i_sqrt_g_mat, v)

    @partial(jax.jit, static_argnums=(0,))
    def pseudo_transport(self, x, y, v):
        """the easy one

        Returns ``y @ x^{-1} v``: ``v`` is pulled back from ``x`` to the
        identity and pushed forward to ``y``.
        """
        return y@jla.solve(x, v)

    @partial(jax.jit, static_argnums=(0,))
    def sigma_id(self, dw):
        """ sigma, to generate the Brownian motion at the identity.

        Applies ``g_mat^{-1/2}`` through :meth:`_mat_apply`, then
        :meth:`_lie_algebra_proj`.
        """
        return self._lie_algebra_proj(
            self._mat_apply(self._i_sqrt_g_mat, dw))

    @partial(jax.jit, static_argnums=(0,))
    def sigma(self, x, dw):
        """ sigma, to generate the Brownian motion.

        :meth:`sigma_id` of ``dw`` left-multiplied by ``x``.
        """
        return x@self.sigma_id(dw)