Source code for jax_par_trans.expv.expv

"""Define a class of linear operators with an exponential action method.
"""
from functools import partial

import jax.numpy as jnp
from jax import jit, lax


# Lookup table with one (m, theta) pair per row: the first entry is the
# integer index m (used elsewhere in this file as a series degree) and the
# second is the theta value associated with it.
# Rows for m = 1..30 are from table A.3 of Computing Matrix Functions;
# the remaining rows come from the table cited inline below.
_theta_array = jnp.array((
    (1, 2.29e-16), (2, 2.58e-8), (3, 1.39e-5), (4, 3.40e-4), (5, 2.40e-3),
    (6, 9.07e-3), (7, 2.38e-2), (8, 5.00e-2), (9, 8.96e-2), (10, 1.44e-1),
    (11, 2.14e-1), (12, 3.00e-1), (13, 4.00e-1), (14, 5.14e-1), (15, 6.41e-1),
    (16, 7.81e-1), (17, 9.31e-1), (18, 1.09), (19, 1.26), (20, 1.44),
    (21, 1.62), (22, 1.82), (23, 2.01), (24, 2.22), (25, 2.43),
    (26, 2.64), (27, 2.86), (28, 3.08), (29, 3.31), (30, 3.54),
    # The rows below are from table 3.1 of
    # Computing the Action of the Matrix Exponential.
    (35, 4.7), (40, 6.0), (45, 7.2), (50, 8.5), (55, 9.9),
))



class LinearOperator:
    """A linear operator that provides an exponential action method.

    The operator acts on a vector space through :meth:`dot`. The
    exponential action ``exp(t * self) b`` is provided by :meth:`expv`.

    To use it, derive a subclass that implements :meth:`dot` and
    :meth:`one_norm_est` (an estimate of the operator 1-norm), and
    :meth:`set_params` if parameters need to be changed after construction.
    """

    def __init__(self, params=None):
        """Create the operator; the base class ignores ``params``."""
        pass

    def set_params(self, params):
        """Override the params supplied in the constructor.

        This exists to avoid creating a new object repeatedly.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def dot(self, b):
        """Apply the operator to ``b``.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def one_norm_est(self):
        """Return an estimate of the 1-norm of the operator.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    # NOTE(review): ``self`` (argnum 0) and ``tol`` (argnum 3) are static jit
    # arguments, so both must be hashable. Values that ``dot`` and
    # ``one_norm_est`` read from ``self`` while tracing may be baked into the
    # compiled function; confirm that changes made through ``set_params``
    # are picked up the way callers expect.
    @partial(jit, static_argnums=(0, 3))
    def expv(self, b, t, tol=None):
        """Exponential action on ``b``: return ``exp(t * self) b``.

        The computation is split into ``s`` identical steps. Each step sums
        the series ``v + (t/s) A v / 1! + ((t/s) A)^2 v / 2! + ...`` using at
        most ``m_star`` terms, where ``A v`` is ``self.dot(v)``. The pair
        ``(m_star, s)`` is the row of ``_theta_array`` that minimises
        ``m * ceil(|t| * one_norm_est() / theta_m)``.

        Args:
            b: array the exponential acts on.
            t: real scalar multiplying the operator; may be negative.
            tol: relative tolerance for the early exit of the series.
                Defaults to ``2 ** -53``.

        Returns:
            The approximation of ``exp(t * self) b``, with the shape of ``b``.
        """
        # The number of steps depends on the magnitude of t * A. Using the
        # signed t made the estimate negative for t < 0, which produced
        # s <= 0, skipped every step and returned b unchanged.
        norm_est = jnp.abs(t) * self.one_norm_est()

        degrees = _theta_array[:, 0]
        step_counts = jnp.ceil(norm_est / _theta_array[:, 1])
        best = jnp.argmin(step_counts * degrees)
        m_star, s = degrees[best], step_counts[best]

        if tol is None:
            tol = 2 ** -53

        def norm2(x):
            return jnp.sqrt(jnp.sum(x * x))

        def body_fun(carry):
            # term is the current series term, total is the running sum.
            j, term, total, _, last_norm = carry
            term = t / (s * (j + 1)) * self.dot(term)
            return j + 1, term, total + term, last_norm, norm2(term)

        def cond_fun(carry):
            # Continue while terms remain and the two most recent term norms
            # are still significant relative to the running sum.
            j, _, total, prev_norm, last_norm = carry
            return (j < m_star) & (prev_norm + last_norm > tol * norm2(total))

        def scaling_step(_, vec):
            # One step: replace vec by the truncated series of exp(t/s A) vec.
            start_norm = norm2(vec)
            # A strongly typed inf keeps the loop carry dtype fixed.
            no_prev = jnp.full_like(start_norm, jnp.inf)
            init = (0, vec, vec, no_prev, start_norm)
            return lax.while_loop(cond_fun, body_fun, init)[2]

        # JAX arrays are immutable, so b can seed the loop without a copy.
        # With s == 0 (zero norm estimate) the loop is skipped and b returned.
        return lax.fori_loop(0, s.astype(int), scaling_step, b)