"""Common util functions
"""
import jax.numpy as jnp
from jax import random
[docs]
def asym(a):
    """Return the skew-symmetric part of ``a``, i.e. ``(a - a.T) / 2``.

    The result ``s`` satisfies ``s.T == -s``.  Assumes ``a`` is a square
    2-D array (``a - a.T`` must broadcast) — TODO confirm at call sites.
    """
    transposed = a.T
    difference = a - transposed
    return difference * 0.5
[docs]
def sym(a):
    """Return the symmetric part of ``a``, i.e. ``(a + a.T) / 2``.

    The result ``s`` satisfies ``s.T == s``.  Assumes ``a`` is a square
    2-D array (``a + a.T`` must broadcast) — TODO confirm at call sites.
    """
    transposed = a.T
    total = a + transposed
    return total * 0.5
[docs]
def lie(a, b):
    """Return the Lie bracket (matrix commutator) ``[a, b] = a @ b - b @ a``.

    The result is zero exactly when ``a`` and ``b`` commute.
    """
    forward = a @ b
    backward = b @ a
    return forward - backward
[docs]
def vcat(x, y):
    """Stack ``x`` on top of ``y`` by concatenating along axis 0 (rows)."""
    pieces = (x, y)
    return jnp.concatenate(pieces, axis=0)
[docs]
def hcat(x, y):
    """Place ``y`` to the right of ``x`` by concatenating along axis 1 (columns)."""
    pieces = (x, y)
    return jnp.concatenate(pieces, axis=1)
[docs]
def grand(key, shape):
    """Draw a standard-normal sample of ``shape`` and advance the PRNG key.

    ``key`` is split once: one half is consumed to draw the sample, and the
    other half is returned so the caller can keep drawing fresh randomness.

    Returns:
        A ``(sample, next_key)`` tuple.
    """
    next_key, draw_key = random.split(key)
    sample = random.normal(draw_key, shape)
    return sample, next_key
[docs]
def cz(a):
    """Return the largest absolute entry of ``a`` (its max-norm).

    Useful as a "check if zero" measure: the result is 0 exactly when every
    entry of ``a`` is zero.  The return value is a scalar array, not a bool.

    Args:
        a: Array-like input (JAX/NumPy array, Python scalar, or nested list).
            Complex entries are measured by their modulus.

    Returns:
        A scalar array holding ``max(|a|)``.  An empty ``a`` yields 0 instead
        of raising.
    """
    # jnp reductions reject plain Python sequences, so coerce first.
    a = jnp.asarray(a)
    # ``initial=0`` makes the empty case well-defined (max over no entries is 0).
    # It does not change non-empty results because every |entry| >= 0.
    return jnp.max(jnp.abs(a), initial=0)