Source code for jax_rb.simulation.simulator

"""Simulator for global_manifold
"""
from collections import namedtuple

import jax.numpy as jnp
# import jax.numpy.linalg as jla
from jax import random, vmap


class RunParams(namedtuple(
        'RunParams',
        'x_0 key t_final n_path n_div d_coeff wiener_dim m_size '
        'normalize run_type')):
    """Immutable record describing one simulation run.

    :param x_0: starting point of the simulation.
    :param key: random key used to draw the random numbers of the
        simulation. Created from jax.random.PRNGKey, then jax.random.split.
    :param t_final: final time of the simulation. The starting time
        is :math:`t=0`.
    :param n_path: number of simulated paths.
    :param n_div: number of subdivisions of the time interval (each step
        has length t_final/n_div).
    :param d_coeff: diffusion coefficient; d_coeff = 0.5 gives the
        Riemannian Brownian motion.
    :param wiener_dim: dimension of the Wiener process driving the
        simulation. Usually the dimension of the ambient
        space :math:`\\mathcal{E}`; in some cases the dimension of the
        manifold itself can be used.
    :param m_size: a value indicating the size of the manifold, used to tell
        runs apart when several manifolds are simulated.
    :param normalize: whether to normalize the move to a fixed length.
    :param run_type: string naming one of the simulation moves. It is a tag
        to distinguish the outputs and does not affect the results.
    """
def simulate(x_0, integrator, path_pay_off, final_pay_off, params):
    """Simulate n_path paths from :math:`t=0` up to :math:`t=t_final`.

    The time increment is :math:`h=\\frac{t_final}{n_div}`. The standard
    normal draws driving the paths are truncated to
    :math:`[-a_h, a_h]` with :math:`a_h = \\sqrt{2 p |\\log h|}` and
    accuracy level :math:`p = 0.5`. The full distribution of the
    simulation is returned (one pay-off per path).

    :param x_0: starting point of the simulation
    :param integrator: one of the integrators (geodesic, ito,
        stratonovich). Called as ``integrator(x, noise, 2*d_coeff*h)``
        where ``noise`` has length wiener_dim.
    :param path_pay_off: the cost evaluated along the path, called as
        ``path_pay_off(x, t)``; skipped when falsy (e.g. None)
    :param final_pay_off: the contribution evaluated at the final point
    :param params: tuple ``(sk, t_final, n_path, n_div, d_coeff,
        wiener_dim)`` where ``sk`` is the jax random key
    :returns: ``(pay_offs, x_all)``: the per-path pay-offs and the
        truncated random draws of shape ``(wiener_dim, n_div, n_path)``
    """
    sk, t_final, n_path, n_div, d_coeff, wiener_dim = params
    p2 = 0.5
    # log(h) is negative for any step h < 1, so the square root must be
    # taken of |log(h)|. Without the abs the cut-off is NaN, every
    # comparison against it is False, and no truncation takes place.
    a_h = (2*p2*jnp.abs(jnp.log(t_final/n_div)))**.5

    x_all = random.normal(sk, (wiener_dim, n_div, n_path))
    x_all = jnp.clip(x_all, -a_h, a_h)

    def do_one_path(seq):
        # seq holds the draws of a single path, shape (wiener_dim, n_div)
        path_sum = 0.
        x_i = x_0.copy()
        for j in range(n_div):
            x_i = integrator(x_i, seq[:, j], t_final/n_div*2*d_coeff)
            if path_pay_off:
                path_sum += path_pay_off(x_i, j*t_final/n_div)*t_final/n_div
        return path_sum + final_pay_off(x_i)

    # map over the last axis of x_all: one call per simulated path
    pay_offs = vmap(do_one_path, in_axes=2)(x_all)
    return pay_offs, x_all
class Simulator():
    """Run simulations on a manifold for a given pair of pay-off functions.

    Every call to :meth:`run` appends its result to ``self.runs``.

    :param path_pay_off: the function value evaluated along the path
    :param final_pay_off: the function value evaluated at final time
    """

    def __init__(self, path_pay_off, final_pay_off):
        self.path_pay_off = path_pay_off
        self.final_pay_off = final_pay_off
        # each entry is a two-element list [params, pay_offs]
        self.runs = []

    def run(self, integrator, params):
        """Run one simulation and record it in ``self.runs``.

        :param integrator: the integrator used
        :param params: run description, of class RunParams
        """
        # simulate() expects its settings as a plain positional tuple
        sim_params = (
            params.key,
            params.t_final,
            params.n_path,
            params.n_div,
            params.d_coeff,
            params.wiener_dim,
        )
        result = simulate(
            params.x_0, integrator,
            self.path_pay_off, self.final_pay_off,
            sim_params)
        # only the pay-offs are kept; the random draws are discarded
        self.runs.append([params, result[0]])

    def save_runs(self, save_path):
        """Save all the recorded runs to save_path.

        The archive holds the pay-offs of every run under ``pay_offs`` and,
        under ``params``, one dict per run with every parameter field
        except ``x_0``.

        :param save_path: destination handed to ``jnp.savez``
        """
        pay_offs = [entry[1] for entry in self.runs]
        prms = [
            {
                name: value
                for name, value in zip(entry[0]._fields, entry[0])
                if name != 'x_0'
            }
            for entry in self.runs
        ]
        jnp.savez(save_path, pay_offs=pay_offs, params=prms,
                  allow_pickle=True)