Source code for jax_rb.manifolds.glp_left_invariant

""":math:`GL^+`: Positive Component of the Generalized Linear group with left-invariant metric.
"""

import jax.numpy.linalg as jla
from .matrix_left_invariant import MatrixLeftInvariant
from ..utils.utils import (grand)


class GLpLeftInvariant(MatrixLeftInvariant):
    """:math:`GL^+`, the positive-determinant component of the general
    linear group, equipped with the left-invariant metric induced by g_mat.

    Points and ambient vectors are arrays of shape ``self.shape``.

    :param p: the size of the matrix
    :param g_mat: The matrix defining the inner product at the identity.
       g_mat is in :math:`\\mathbb{R}^{p^2\\times p^2}` .
    """

    def name(self):
        """Return the display name, built from the first entry of ``self.shape``."""
        return f"GL+({self.shape[0]})"

    def _mat_apply(self, mat, omg):
        """Apply the operator ``mat`` to ``omg`` viewed as a flat vector.

        ``omg`` is flattened, multiplied on the left by ``mat`` and the
        product is reshaped back to ``self.shape``.
        """
        flat_omg = omg.reshape(-1)
        applied = mat @ flat_omg
        return applied.reshape(self.shape)

    def _lie_algebra_proj(self, omg):
        """Projection onto the Lie algebra; here it returns ``omg`` unchanged."""
        return omg

    def rand_ambient(self, key):
        """Draw a random ambient vector of shape ``self.shape``.

        The result of ``grand`` is returned as is; ``rand_point`` unpacks it
        as a ``(sample, key)`` pair.
        """
        ambient_shape = self.shape
        return grand(key, ambient_shape)

    def rand_point(self, key):
        """Draw a random point on the manifold.

        An ambient sample is drawn first. When its determinant is negative
        the first row is negated, which flips the sign of the determinant.

        :param key: random key forwarded to ``rand_ambient``
        :returns: the pair ``(point, key)`` with the key produced by
           ``rand_ambient``
        """
        point, key = self.rand_ambient(key)
        # NOTE(review): this is a Python-level branch on an array value, so
        # the determinant presumably has to be concrete here - confirm
        # before calling under tracing.
        if jla.det(point) < 0:
            # Negating a single row negates the determinant.
            flipped = point.at[0, :].set(-point[0, :])
            return flipped, key
        return point, key

    def retract(self, x, v):
        """Second order retraction, in a simple form.

        Computes ``x + v - 0.5 * proj(x, gamma(x, v, v))``; ``proj`` and
        ``gamma`` are not defined in this class (presumably inherited from
        the base class).

        :param x: base point
        :param v: tangent vector at ``x``
        """
        correction = self.proj(x, self.gamma(x, v, v))
        return x + v - 0.5 * correction

    def approx_nearest(self, q):
        """Approximate nearest point on the manifold; returns ``q`` unchanged."""
        return q

    def pseudo_transport(self, x, y, v):
        """Move ``v`` from ``x`` to ``y`` the easy way.

        Solves ``x @ w = v`` for ``w`` and returns ``y @ w``.
        """
        at_identity = jla.solve(x, v)
        return y @ at_identity

    def sigma(self, x, dw):
        """Apply sigma at ``x`` to the increment ``dw``.

        sigma acts on ``dw`` through ``_mat_apply`` (on the flattened
        vector) rather than by a plain matrix product; the result is then
        multiplied on the left by ``x``.
        """
        scaled = self._mat_apply(self._i_sqrt_g_mat, dw)
        return x @ scaled