Source code for jax_rb.manifolds.so_left_invariant

""":math:`SO`: Special Orthogonal group.
"""

from functools import partial

import jax
import jax.numpy as jnp
import jax.numpy.linalg as jla
from .matrix_left_invariant import MatrixLeftInvariant
from ..utils.utils import (asym, lie)


[docs] class SOLeftInvariant(MatrixLeftInvariant): """The group :math:`SO(n)` of orthogonal matrices :math:`U\\in R^{n\\times n}` of determinant 1 with metric :math:`Tr(\\omega^TU^T\\mathcal{I}(U^T\\omega))` where :math:`\\mathcal{I}` is the metric defined by g_mat, a matrix of size :math:`\\mathbb{R}^{\\frac{n(n-1)}{2}\\times \\frac{n(n-1)}{2}}`. """ def __init__(self, n, g_mat): """ g_mat is a matrix of size :math:`\\frac{n(n-1)}{2}` used to define the metric. """ super().__init__(n, g_mat) self.dim = (n*(n-1)) // 2 def name(self): return f"SO({self.shape[0]})" @partial(jax.jit, static_argnums=(0,)) def _lie_algebra_proj(self, omg): return asym(omg) @partial(jax.jit, static_argnums=(0,)) def _mat_apply(self, mat, omg): """ mat is a matrix of size (p(p-1))**2 """ p = omg.shape[0] rows, cols = jnp.triu_indices(p, 1) ret = jnp.empty((p, p)) ret = ret.at[rows, cols].set(mat@omg.take(rows*p+cols)) ret = ret.at[cols, rows].set(mat@omg.T.take(rows*p+cols)) ret = ret.at[jnp.diag_indices(p)].set(omg[jnp.diag_indices(p)]) return ret def rand_point(self, key): """ A random point on the manifold """ xt, key = self.rand_ambient(key) x, _ = jla.qr(xt) return x, key # @partial(jax.jit, static_argnums=(0,)) def retract(self, x, v): """ second order retraction, but simple """ x1 = x + v - 0.5* self.proj_gamma(x, v, v) u, _, v = jla.svd(x1) return u@v # @partial(jax.jit, static_argnums=(0,)) def approx_nearest(self, q): u, _, v = jla.svd(q) return u@v # @partial(jax.jit, static_argnums=(0,)) def pseudo_transport(self, x, y, v): """the easy one """ return y@x.T@v @partial(jax.jit, static_argnums=(0,)) def proj_gamma(self, x, xi, eta): """projection of christoffel function """ # return - self.d_proj(x, xi, eta) # + self.proj(x, self.gamma_ambient(x, xi, eta)) ivxi = jla.solve(x, xi) iveta = jla.solve(x, eta) return -0.5*x@self._lie_algebra_proj(ivxi@iveta + iveta@ivxi) \ + 0.5*x@self._inv_id_opt( self._lie_algebra_proj( lie(self._id_opt(ivxi), iveta.T) \ + lie(self._id_opt(iveta), ivxi.T))) # @partial(jax.jit, static_argnums=(0,)) def sigma_la(self, vec_dw): """ sigma is applied on the lie agebra identified with a vector """ p = self.shape[0] v = jnp.zeros(self.shape) rows, cols = jnp.triu_indices(p, 1) for idx in range(vec_dw.shape[0]): i, j = rows[idx], cols[idx] v += 1/jnp.sqrt(2)*vec_dw[idx]*self._mat_apply( self._i_sqrt_g_mat, jnp.zeros((p, p)).at[i, j].set(1.).at[j, i].set(-1.) ) return v