Source code for jax_rb.simulation.global_manifold_integrator

"""Module implementing simulation methods for embedded manifolds
"""
from functools import partial


import jax.numpy as jnp
from jax import jit


[docs] @partial(jit, static_argnums=(0,)) def geodesic_move(mnf, x, unit_move, scale): """ simulate using a second order retraction. The move is :math:`x_{new} = \\mathfrak{r}(x, \\Pi(x)\\sigma(x)(\\text{unit_move}(\\text{scale})^{\\frac{1}{2}}))` """ return mnf.retract(x, mnf.proj(x, mnf.sigma(x, unit_move.reshape(mnf.shape)*jnp.sqrt(scale))))
[docs] @partial(jit, static_argnums=(0,)) def geodesic_move_with_drift(mnf, x, unit_move, scale, additional_drift): """ This method is used to simulate a Riemanian Brownian motion with drift. The additional_drift is on top of the Brownian motion. Simulate using a second order retraction. The move is :math:`x_{new} = \\mathfrak{r}(x, \\Pi(x)\\sigma(x)(\\text{unit_move}(\\text{scale})^{\\frac{1}{2}}+\\text{scale}\\times\\text{additional_drift}))` """ return mnf.retract(x, mnf.proj(x, mnf.sigma(x, unit_move.reshape(mnf.shape)*jnp.sqrt(scale)) + scale*additional_drift))
[docs] @partial(jit, static_argnums=(0,)) def geodesic_move_normalized(mnf, x, unit_move, scale): """ similar to geodesic_move, but the move is normalized to have fixed length :math:`scale^{\\frac{1}{2}}` """ # stochastic dx v = mnf.proj(x, mnf.sigma(x, unit_move.reshape(mnf.shape))) v = v/jnp.sqrt(mnf.inner(x, v, v))*jnp.sqrt(scale) return mnf.retract(x, v)
[docs] @partial(jit, static_argnums=(0,)) def geodesic_move_exact(mnf, x, unit_move, scale): """ similar to geodesic_move, but use exact geodesic """ return mnf.exp(x, mnf.proj(x, mnf.sigma(x, unit_move.reshape(mnf.shape)*jnp.sqrt(scale))))
[docs] @partial(jit, static_argnums=(0,)) def geodesic_move_exact_normalized(mnf, x, unit_move, scale): """ similar to geodesic_move_exact, but use normalize the unit_move """ # stochastic dx v = mnf.proj(x, mnf.sigma(x, unit_move.reshape(mnf.shape))) v = v/jnp.sqrt(mnf.inner(x, v, v))*jnp.sqrt(scale) return mnf.exp(x, v)
[docs] @partial(jit, static_argnums=(0,)) def rbrownian_ito_move(mnf, x, unit_move, scale): """ Use Euler Maruyama and projection method to solve the Ito equation. """ return mnf.approx_nearest( x + mnf.proj(x, mnf.sigma(x, unit_move.reshape(mnf.shape)*jnp.sqrt(scale))) + mnf.ito_drift(x)*scale)
[docs] @partial(jit, static_argnums=(0,)) def ito_move_with_drift(mnf, x, unit_move, scale, additional_drift): """ This method is used to simulate a Riemanian Brownian motion with drift. The additional_drift is on top of the Brownian motion. Use Euler Maruyama and projection method to solve the Ito equation. """ return mnf.approx_nearest( x + mnf.proj(x, mnf.sigma(x, unit_move.reshape(mnf.shape)*jnp.sqrt(scale))) + (additional_drift+mnf.ito_drift(x))*scale)
[docs] @partial(jit, static_argnums=(0,)) def rbrownian_stratonovich_move(mnf, x, unit_move, scale): """ Use Euler Heun and projection method to solve the Stratonovich equation. """ # stochastic dx dxs = mnf.sigma(x, unit_move.reshape(mnf.shape)*jnp.sqrt(scale)) xbk = x + mnf.proj(x, dxs) return mnf.approx_nearest(x + mnf.proj(0.5*(x + xbk), dxs) + mnf.proj(x, mnf.ito_drift(x)*scale))
[docs] @partial(jit, static_argnums=(0,)) def stratonovich_move_with_drift(mnf, x, unit_move, scale, additional_drift): """ This method is used to simulate a Riemanian Brownian motion with drift. The additional_drift is on top of the Brownian motion. Use Euler Heun and projection method to solve the Stratonovich equation. """ # stochastic dx dxs = mnf.sigma(x, unit_move.reshape(mnf.shape)*jnp.sqrt(scale)) xbk = x + mnf.proj(x, dxs) return mnf.approx_nearest(x + mnf.proj(0.5*(x + xbk), dxs) + mnf.proj(x, mnf.ito_drift(x)+additional_drift)*scale)