# Source code for jax_par_trans.expv.utils

"""Common util functions
"""
import jax.numpy as jnp
from jax import random


def asym(a):
    """Return the antisymmetric (skew-symmetric) part of ``a``.

    Computes ``0.5 * (a - a.T)``.

    Args:
        a: Array supporting ``.T`` and elementwise arithmetic. ``a`` and
            ``a.T`` must have compatible shapes, which for a 2-D input
            means a square matrix.

    Returns:
        The array ``0.5 * (a - a.T)``.
    """
    transposed = a.T
    return 0.5 * (a - transposed)
def sym(a):
    """Return the symmetric part of ``a``.

    Computes ``0.5 * (a + a.T)``.

    Args:
        a: Array supporting ``.T`` and elementwise arithmetic. ``a`` and
            ``a.T`` must have compatible shapes, which for a 2-D input
            means a square matrix.

    Returns:
        The array ``0.5 * (a + a.T)``.
    """
    transposed = a.T
    return 0.5 * (a + transposed)
def lie(a, b):
    """Return the Lie bracket (matrix commutator) of ``a`` and ``b``.

    Computes ``a @ b - b @ a``.

    Args:
        a: Left operand; must be matrix-multipliable with ``b`` in both
            orders.
        b: Right operand.

    Returns:
        The commutator ``a @ b - b @ a``.
    """
    forward = a @ b
    backward = b @ a
    return forward - backward
def vcat(x, y):
    """Concatenate ``x`` and ``y`` vertically (along axis 0).

    Args:
        x: First array; placed on top.
        y: Second array; placed below ``x``.

    Returns:
        The result of ``jnp.concatenate([x, y], axis=0)``.
    """
    parts = [x, y]
    return jnp.concatenate(parts, axis=0)
def hcat(x, y):
    """Concatenate ``x`` and ``y`` horizontally (along axis 1).

    Args:
        x: First array; placed on the left. Must have at least two axes.
        y: Second array; placed to the right of ``x``.

    Returns:
        The result of ``jnp.concatenate([x, y], axis=1)``.
    """
    parts = [x, y]
    return jnp.concatenate(parts, axis=1)
def grand(key, shape):
    """Draw a standard-normal sample and return it with a fresh PRNG key.

    The incoming key is split in two. One half is used to draw the sample
    and the other half is handed back, so the caller can keep threading
    keys through later calls without reusing one.

    Args:
        key: JAX PRNG key, passed to ``random.split``.
        shape: Shape of the sample, passed to ``random.normal``.

    Returns:
        A tuple ``(sample, new_key)``: the sample first, then the key to
        use for the next draw.
    """
    new_key, subkey = random.split(key)
    sample = random.normal(subkey, shape)
    return sample, new_key
def cz(a):
    """Return the largest absolute entry of ``a`` (a "check if zero" helper).

    This returns a magnitude, not a boolean. The result is zero exactly
    when every entry of ``a`` is zero, so callers can compare it against a
    tolerance.

    Args:
        a: Array passed to ``jnp.abs`` and ``jnp.max``.

    Returns:
        ``jnp.max(jnp.abs(a))``.
    """
    magnitudes = jnp.abs(a)
    return jnp.max(magnitudes)