# Source code for jax_rb.manifolds.affine_left_invariant

""":math:`Aff^+`: Positive Component of the Affine group with left invariant metric.
"""

import jax.numpy.linalg as jla
from .matrix_left_invariant import MatrixLeftInvariant
from ..utils.utils import (grand)


class AffineLeftInvariant(MatrixLeftInvariant):
    """Group of affine transformations of :math:`\\mathbb{R}^{n}`, represented
    by a pair :math:`(A, v)\\in GL^+(n)\\times \\mathbb{R}^{n}` with action
    :math:`(A, v).w = Aw + v` for :math:`w\\in\\mathbb{R}^{n}` .
    Alternatively, it is represented as a matrix
    :math:`\\begin{bmatrix} A & v \\\\ 0 & 1 \\end{bmatrix}\\in GL(n+1)`.

    :param n: size of A
    :param g_mat: a positive definite matrix in
        :math:`\\mathbb{R}^{n(n+1)\\times n(n+1)}` defining the metric at
        :math:`I_{n+1}`
    """

    def __init__(self, n, g_mat):
        """Build the group acting on R^n.

        The parent class is initialised with matrix size ``n+1`` because
        points are stored as ``(n+1) x (n+1)`` block matrices.

        :param n: size of A.
        :param g_mat: a matrix of size (n(n+1))**2 used to define the metric.
        """
        super().__init__(n+1, g_mat)
        # Manifold dimension: n*n entries of A plus n entries of v.
        self.dim = n*(n+1)

    def name(self):
        """Return the display name ``"Aff(n)"``.

        ``n`` is recovered as ``self.shape[0] - 1``. ``self.shape`` is set
        outside this class (presumably ``(n+1, n+1)`` by the parent; confirm).
        """
        return f"Aff({self.shape[0]-1})"

    @staticmethod
    def _set_affine_last_row(mat):
        """Return ``mat`` with its last row replaced by ``[0, ..., 0, 1]``.

        This puts a square matrix into the block form ``[[A, v], [0, 1]]``
        used to represent group elements.
        """
        return mat.at[-1, :].set(0.).at[-1, -1].set(1.)

    def _lie_algebra_proj(self, omg):
        """Project ``omg`` onto the Lie algebra at the identity.

        Lie-algebra elements have the form ``[[a, b], [0, 0]]``, so the
        projection zeroes the last row.
        """
        return omg.at[-1, :].set(0.)

    def _mat_apply(self, mat, omg):
        """Apply ``mat`` to the free (top ``p-1`` rows) part of ``omg``.

        :param mat: a matrix of size (p(p-1))**2, where ``p = omg.shape[0]``.
        :param omg: a ``p x p`` matrix.
        :return: ``omg`` with its top ``p-1`` rows replaced by
            ``mat @ vec(omg[:-1, :])``, reshaped back to ``(p-1, p)``.
            The last row is left unchanged.
        """
        p = omg.shape[0]
        return omg.at[:-1, :].set(
            (mat@omg[:-1, :].reshape(-1)).reshape(p-1, p))

    def rand_point(self, key):
        """Return a random point on the manifold and the updated key.

        A random ``self.shape`` matrix is drawn with ``grand``, and its last
        row is then set to ``[0, ..., 0, 1]``.

        NOTE(review): ``det(A)`` of the top-left block is not constrained to
        be positive, so the sample may fall outside the positive component.
        Confirm whether callers rely on this.

        :param key: random key passed to ``grand``.
        :return: tuple ``(point, key)``.
        """
        mat, key = grand(key, self.shape)
        return self._set_affine_last_row(mat), key

    def retract(self, x, v):
        """Second-order retraction, but simple.

        :param x: a point of the group.
        :param v: a tangent vector at ``x``.
        :return: ``x + v - 0.5*gamma(x, v, v)`` with its last row reset to
            ``[0, ..., 0, 1]``.
        """
        # Reset the whole last row to [0, ..., 0, 1]. Zeroing the corner
        # entry as well would make the result singular, i.e. not a group
        # element.
        return self._set_affine_last_row(x + v - 0.5*self.gamma(x, v, v))

    def approx_nearest(self, q):
        """Map a square matrix ``q`` to a nearby group element.

        The last row of ``q`` is replaced by ``[0, ..., 0, 1]``. This gives
        the closest matrix, in Frobenius norm, that has the affine block
        form.
        """
        return self._set_affine_last_row(q)

    def pseudo_transport(self, x, y, v):
        """Transport ``v`` from ``x`` to ``y`` by the easy formula.

        :return: ``y @ x^{-1} @ v``, computed with a linear solve rather than
            an explicit inverse.
        """
        return y@jla.solve(x, v)

    def sigma(self, x, dw):
        """Map a noise increment ``dw`` to a tangent vector at ``x``.

        The steps are:

        1. Apply ``self._i_sqrt_g_mat`` to ``dw`` with ``_mat_apply``.
           ``_i_sqrt_g_mat`` is set by the parent; presumably it is the
           inverse square root of ``g_mat``, to be confirmed.
        2. Project onto the Lie algebra.
        3. Left-translate by ``x``.
        """
        return x@self._lie_algebra_proj(
            self._mat_apply(self._i_sqrt_g_mat, dw))