Source code for jax_rb.manifolds.se_left_invariant

""":math:`SE`: Special Euclidean group.
"""

from functools import partial

import jax
import jax.numpy as jnp
import jax.numpy.linalg as jla
from .matrix_left_invariant import MatrixLeftInvariant
from ..utils.utils import (sym, asym)


class SELeftInvariant(MatrixLeftInvariant):
    """Special Euclidean group of Euclidean transformation of
    :math:`\\mathbb{R}^{n}`, represented by a pair
    :math:`(A, v)\\in SO(n)\\times \\mathbb{R}^{n}`
    with action :math:`(A, v).w = Aw + v` for :math:`w\\in\\mathbb{R}^{n}`.

    Alternatively, it is represented as a matrix
    :math:`\\begin{bmatrix} A & v \\\\ 0 & 1 \\end{bmatrix}\\in GL(n+1)`
    where :math:`A\\in SO(n)`.

    :param n: size of A
    :param g_mat: a positive definite matrix in
        :math:`\\mathbb{R}^{\\frac{n(n+1)}{2}\\times\\frac{n(n+1)}{2}}`
        defining the metric at :math:`I_{n+1}`
    """

    def __init__(self, n, g_mat):
        """Initialise the group of :math:`(n+1)\\times(n+1)` matrices.

        :param n: size of the rotation block A
        :param g_mat: square matrix with n(n+1)//2 rows used to define
            the metric
        """
        # Points are stored as (n+1) x (n+1) matrices, hence the enlarged
        # size handed to the parent class.
        super().__init__(n+1, g_mat)
        self.dim = (n*(n+1))//2

    def name(self):
        """Return the display name ``SE(n)`` with ``n = self.shape[0] - 1``."""
        return f"SE({self.shape[0]-1})"

    @partial(jax.jit, static_argnums=(0,))
    def _lie_algebra_proj(self, omg):
        """Replace the top-left block of ``omg`` by ``asym`` of that block
        and set the last row to zero.

        :param omg: square matrix
        :returns: matrix of the same shape as ``omg``
        """
        return omg.at[:-1, :-1].set(asym(omg[:-1, :-1])).at[-1, :].set(0.)

    def _mat_apply(self, mat, omg):
        """Apply ``mat`` to the off-diagonal entries of ``omg``.

        Let ``p = omg.shape[0]``. The top-left (p-1) x (p-1) block of
        ``omg`` is split with ``asym`` and ``sym``. For each part, the
        strict upper triangle (scaled by sqrt(2)) is concatenated with the
        last column of ``omg`` (``asym`` part) or the last row of ``omg``
        (``sym`` part) into a vector of length p(p-1)/2, which is
        multiplied by ``mat``. The two results are scattered back to the
        same positions; the diagonal of ``omg`` is copied through unchanged.

        :param mat: square matrix with p(p-1)/2 rows
        :param omg: matrix of shape (p, p)
        :returns: matrix of shape (p, p)
        """
        p = omg.shape[0]
        sqrt2 = jnp.sqrt(2)
        block = omg[:-1, :-1]
        rows, cols = jnp.triu_indices(p-1, 1)
        # positions of the strict upper triangle in the flattened block
        flat_idx = rows*(p-1) + cols

        veca = mat@jnp.concatenate(
            [asym(block).take(flat_idx)*sqrt2, omg[:-1, -1]])
        vecs = mat@jnp.concatenate(
            [sym(block).take(flat_idx)*sqrt2, omg[-1, :-1]])

        # The leading entries rebuild the block (antisymmetric and
        # symmetric fill, respectively); the trailing p-1 entries go to the
        # last column / last row.
        tota = jnp.zeros((p-1, p-1))
        tota = tota.at[rows, cols].set(veca[:1-p])
        tota = tota.at[cols, rows].set(-veca[:1-p])

        tots = jnp.zeros((p-1, p-1))
        tots = tots.at[rows, cols].set(vecs[:1-p])
        tots = tots.at[cols, rows].set(vecs[:1-p])

        # Every entry is overwritten below; zeros make that explicit.
        ret = jnp.zeros((p, p))
        ret = ret.at[:-1, -1].set(veca[1-p:])
        ret = ret.at[-1, :-1].set(vecs[1-p:])
        ret = ret.at[:-1, :-1].set((tota+tots)/sqrt2)

        diag = jnp.diag_indices(p)
        return ret.at[diag].set(omg[diag])

    def rand_point(self, key):
        """A random point on the manifold.

        :param key: random key forwarded to ``self.rand_ambient``
        :returns: tuple ``(point, key)`` where ``key`` is the key returned
            by ``self.rand_ambient``
        """
        xt, key = self.rand_ambient(key)
        xo = jla.qr(xt[:-1, :-1])[0]
        # Negate the first row when the determinant is negative. Using
        # jnp.where instead of a Python ``if`` keeps this usable when the
        # method is traced (jit / vmap).
        sign = jnp.where(jla.det(xo) < 0, -1., 1.)
        xo = xo.at[0, :].set(sign*xo[0, :])
        return xt.at[:-1, :-1].set(xo).at[-1, :-1].set(0.).at[-1, -1].set(1.), key

    # @partial(jax.jit, static_argnums=(0,))
    def retract(self, x, v):
        """Second order retraction, but simple.

        Takes the step ``x + v - 0.5*proj(x, gamma(x, v, v))`` and then
        replaces the top-left block by ``u @ vh`` from its SVD.

        :param x: point, a square matrix
        :param v: tangent vector at ``x``
        :returns: matrix of the same shape as ``x``
        """
        x1 = x + v - 0.5*self.proj(x, self.gamma(x, v, v))
        # distinct names so the tangent vector ``v`` is not rebound
        u, _, vh = jla.svd(x1[:-1, :-1])
        return x1.at[:-1, :-1].set(u@vh)

    @partial(jax.jit, static_argnums=(0,))
    def approx_nearest(self, q):
        """Replace the top-left block of ``q`` by ``u @ vh`` from its SVD.

        :param q: square matrix
        :returns: matrix of the same shape as ``q``
        """
        u, _, vh = jla.svd(q[:-1, :-1])
        return q.at[:-1, :-1].set(u@vh)

    def pseudo_transport(self, x, y, v):
        """The easy one: left translation :math:`y x^{-1} v`.

        :param x: point where ``v`` is based
        :param y: target point
        :param v: tangent vector at ``x``
        :returns: matrix of the same shape as ``v``
        """
        # For x = [[A, t], [0, 1]] the transpose is not the inverse unless
        # t = 0, so y @ x.T @ v would have a non-zero last row. Solve with
        # x instead.
        return y@jla.solve(x, v)

    def sigma(self, x, dw):
        """Return ``x @ _lie_algebra_proj(_mat_apply(self._i_sqrt_g_mat, dw))``.

        :param x: point, a square matrix
        :param dw: square matrix of the same shape as ``x``
        :returns: matrix of the same shape as ``x``
        """
        return x@self._lie_algebra_proj(self._mat_apply(self._i_sqrt_g_mat, dw))