# Source code for jax_rb.utils.utils

"""Various numerical utilities for the project: small matrix helpers,
sinc-type functions, and random test-data generators.
"""
import jax.numpy as jnp
import jax.numpy.linalg as jla
from jax import random, jit


def pos(x):
    """Positive part of ``x``, i.e. ``max(x, 0)`` elementwise.

    Written with ``abs`` instead of a branch so it also works on traced
    values under ``jax.jit``.
    """
    twice_pos = x + jnp.abs(x)
    return 0.5*twice_pos


def asym(mat):
    """Skew-symmetric part of a square matrix: ``(mat - mat.T)/2``."""
    skew_twice = mat - mat.T
    return 0.5*skew_twice


def sym2(mat):
    """Twice the symmetric part of a square matrix: ``mat + mat.T``."""
    transposed = mat.T
    return mat + transposed


def sym(a):
    """Symmetric part of a square matrix: ``(a + a.T)/2``."""
    doubled = a + a.T
    return 0.5*doubled


def grand(key, dims):
    """Draw a standard-normal array of shape ``dims``.

    ``key`` is split in two: one half is consumed by the draw, the other is
    returned so the caller can keep threading it through.

    Returns
    -------
    (sample, new_key)
    """
    new_key, draw_key = random.split(key)
    sample = random.normal(draw_key, dims)
    return sample, new_key
def lyapunov(a, b):
    """Solve ``a@U + U@a = b`` for ``U``.

    ``a`` is diagonalized with ``eigh`` (so it is taken to be symmetric);
    ``b`` and ``U`` are symmetric as well.  In the eigenbasis of ``a`` the
    equation decouples entrywise into ``(e_i + e_j) * U_ij = b_ij``, so
    every ``e_i + e_j`` must be nonzero (e.g. ``a`` positive definite).
    """
    yei, yv = jla.eigh(a)
    return yv@((yv.T@b@yv)/(yei[:, None] + yei[None, :]))@yv.T


def vcat(x, y):
    """Concatenate ``x`` and ``y`` along axis 0 (vertical stacking)."""
    return jnp.concatenate([x, y], axis=0)


# Below this |x| the closed forms of sinc1 / sinc2 / dsinc are replaced by
# their Taylor series.  The closed forms subtract nearly equal quantities and
# lose roughly eps/x**2 (sinc1, dsinc) resp. eps/x**4 (sinc2) absolute
# accuracy, while the truncated series below stay accurate to ~1e-16 up to
# this cutoff.  A cutoff this large also keeps float32 (JAX's default) usable.
_SINC_SERIES_CUTOFF = 0.5

# Taylor coefficients of sinc1(x) = (x*cos(x) - sin(x))/x**3 in powers of
# x**2; the n-th one is (-1)**(n+1) * (2n+2)/(2n+3)!.
_SINC1_COEFFS = (-1/3, 1/30, -1/840, 1/45360, -1/3991680,
                 1/518918400, -1/93405312000, 1/22230464256000)

# Taylor coefficients of sinc2(x) = sinc1'(x)/x in powers of x**2; the
# (n-1)-th one is 2n times the n-th coefficient of sinc1.
_SINC2_COEFFS = (1/15, -1/210, 1/7560, -1/498960, 1/51891840,
                 -1/7783776000, 1/1587890304000)


def _even_series(coeffs, x):
    """Evaluate ``sum_k coeffs[k] * x**(2k)`` with Horner's rule in ``x*x``."""
    x2 = x*x
    acc = coeffs[-1]
    for c in reversed(coeffs[:-1]):
        acc = acc*x2 + c
    return acc


def sinc(x):
    """``sin(x)/x``, returning 1 at (and extremely near) zero.

    ``sin(x)/x`` has no cancellation, so only a zero-division guard is needed.
    """
    if jnp.abs(x) <= 1e-20:
        return 1
    return jnp.sin(x)/x


def sinc1(x):
    """``sinc1(x) = dsinc(x)/x = (x*cos(x) - sin(x))/x**3``; ``sinc1(0) = -1/3``.

    Uses the Taylor series for ``|x| < _SINC_SERIES_CUTOFF`` to avoid
    catastrophic cancellation in the closed form.
    """
    if jnp.abs(x) < _SINC_SERIES_CUTOFF:
        return _even_series(_SINC1_COEFFS, x)
    return (x*jnp.cos(x)-jnp.sin(x))/x/x/x


def sinc2(x):
    """Helper defined by ``sinc1'(x) = x*sinc2(x)``; ``sinc2(0) = 1/15``.

    Closed form ``-((x**2 - 3)*sin(x) + 3*x*cos(x))/x**5``; the Taylor
    series is used for ``|x| < _SINC_SERIES_CUTOFF``.
    """
    if jnp.abs(x) < _SINC_SERIES_CUTOFF:
        return _even_series(_SINC2_COEFFS, x)
    return -((x*x-3)*jnp.sin(x) + 3*x*jnp.cos(x))/x**5


def dsinc(x):
    """Derivative of ``sinc``: ``(x*cos(x) - sin(x))/x**2 = x*sinc1(x)``."""
    if jnp.abs(x) < _SINC_SERIES_CUTOFF:
        return x*_even_series(_SINC1_COEFFS, x)
    return (x*jnp.cos(x)-jnp.sin(x))/x/x


def dsinc1(x):
    """Derivative of ``sinc1`` (which is ``dsinc/x``): ``x*sinc2(x)``."""
    return x*sinc2(x)
def complement_basis_for_vector(xraw):
    """Orthonormal basis of the orthogonal complement of a nonzero vector.

    Parameters
    ----------
    xraw : (n,) array, nonzero; it is normalized internally.

    Returns
    -------
    (n, n-1) array whose columns are orthonormal and orthogonal to ``xraw``.

    The construction works from the leading entry ``x[0]`` of the
    normalized vector.  Its sign is folded in so that the denominator
    ``1 + |x[0]|`` is always >= 1.  As a result ``x[0] == 0`` is handled as
    well.
    """
    unit = xraw/jnp.sqrt(jnp.sum(xraw*xraw))
    lead, rest = unit[0], unit[1:]
    sign = 2.*(lead > 0) - 1  # +1 if lead > 0, else -1
    abs_lead = lead/sign
    top_row = (-sign*rest)[None, :]
    lower = jnp.eye(unit.shape[0] - 1) - jnp.outer(1/(1 + abs_lead)*rest, rest)
    return jnp.concatenate([top_row, lower])
def esqrtm(x):
    """Matrix square root of a symmetric positive semi-definite matrix.

    Computed from the eigendecomposition ``x = V diag(w) V.T`` as
    ``V diag(sqrt(w)) V.T``.  Negative eigenvalues (``x`` not PSD) produce
    NaNs.
    """
    w, v = jla.eigh(x)
    root_w = jnp.sqrt(w)
    return v@(root_w[:, None]*v.T)
def make_complement_basis(x):
    """Basis of the orthogonal complement of the columns of ``x``.

    Matrix analogue of ``complement_basis_for_vector``.  ``x`` has shape
    ``(n, p)``; ``x_top`` denotes its first ``p`` rows and ``x_bot`` the
    remaining ``n - p`` rows.  With ``P = esqrtm(x_top.T @ x_top)`` and
    ``Q = (P^{-1} @ x_top.T).T`` the ``(n, n-p)`` result is::

        [[ -Q @ x_bot.T                              ],
         [ I - x_bot @ (P + I)^{-1} @ x_bot.T         ]]

    ``x_top`` must be nonsingular so that ``P`` can be solved against.
    NOTE(review): unlike the vector version, ``x`` is not normalized here;
    presumably callers pass ``x`` with orthonormal columns — confirm.
    """
    n, p = x.shape
    top, bottom = x[:p, :], x[p:, :]
    root = esqrtm(top.T@top)
    q_mat = jla.solve(root, top.T).T
    upper = -q_mat@bottom.T
    lower = jnp.eye(n - p) - bottom@jla.solve(root + jnp.eye(p), bottom.T)
    return jnp.concatenate([upper, lower], axis=0)
def generate_symmetric_tensor(key, k, m):
    """Random fully symmetric tensor of order ``m`` with every axis of size ``k``.

    One standard-normal value is drawn (via ``grand(key, (1,))``) per sorted
    index tuple.  That value is copied to every permutation of the index, so
    ``T[i_1, ..., i_m]`` does not change when the ``i``'s are reordered.

    Indices are visited in lexicographic (row-major) order.  A value is drawn
    the first time a sorted tuple is met.  A sorted tuple is the
    lexicographically smallest of its permutations, so the draws happen in
    lexicographic order of the sorted tuples.

    Parameters
    ----------
    key : JAX PRNG key
    k : size of each axis (``k == 0`` gives an empty tensor and no draws)
    m : order of the tensor (number of axes)

    Returns
    -------
    (tensor of shape ``(k,)*m``, new key)
    """
    from itertools import product  # local import keeps file-level imports as-is

    shape = (k,)*m
    drawn = {}    # sorted index tuple -> its random value
    entries = []  # value for every index, in row-major order
    for idx in product(range(k), repeat=m):
        canon = tuple(sorted(idx))
        if canon not in drawn:
            tval, key = grand(key, (1,))
            drawn[canon] = tval[0]
        entries.append(drawn[canon])
    if not entries:
        return jnp.full(shape, jnp.nan), key
    # One stack instead of an `.at[].set` per entry: outside jit each
    # `.at[].set` copies the whole array, which made the old loop quadratic.
    return jnp.stack(entries).reshape(shape), key
def _fill_symmetric(p_raw, k):
    """Build a symmetric ``k x k`` matrix from a flat vector.

    The first ``k*(k-1)//2`` entries of ``p_raw`` fill the strict upper
    triangle row by row and are mirrored into the lower triangle.  The last
    ``k`` entries of ``p_raw`` become the diagonal.
    """
    n_off = (k*(k - 1))//2
    rows, cols = jnp.triu_indices(k, 1)  # row-major order, same as row-by-row filling
    off_diag = p_raw[:n_off]
    p = jnp.zeros((k, k))
    p = p.at[rows, cols].set(off_diag)
    p = p.at[cols, rows].set(off_diag)
    return jnp.fill_diagonal(p, p_raw[-k:], inplace=False)


def tv_mode_product(tensor, x, modes):
    """Contract ``x`` into the trailing axis of ``tensor``, ``modes`` times.

    Each step is ``jnp.tensordot(v, x, axes=1)``, i.e. the last axis of the
    running result against the first axis of ``x``.  For a vector ``x`` this
    evaluates the tensor with ``x`` substituted into its last ``modes``
    indices.  With ``modes == 0`` the input ``tensor`` is returned as is.
    """
    result = tensor
    for _step in range(modes):
        result = jnp.tensordot(result, x, axes=1)
    return result


def _gen_so_inertia_matrix(key, n):
    """Random symmetric ``n x n`` matrix with nonnegative entries and unit diagonal.

    Built as the symmetric part of ``|G|`` for a standard-normal ``G``, with
    the diagonal then overwritten by 1.  Returns ``(matrix, new_key)``.
    """
    raw, key = grand(key, (n, n))
    mat = sym(jnp.abs(raw))
    mat = mat.at[jnp.diag_indices(n)].set(1.)
    return mat, key


def _old_rand_positive_definite(key, n):
    """Older variant (per its name) of the random SPD generator.

    Returns ``(sym(G @ G.T), new_key)`` for a standard-normal ``n x n``
    ``G``.  The result is positive semi-definite, and definite with
    probability one.
    """
    gauss, key = grand(key, (n, n))
    gram = gauss@gauss.T
    return sym(gram), key
def rand_positive_definite(key, n, bounds=None):
    """Random symmetric positive definite ``n x n`` matrix.

    Parameters
    ----------
    key : JAX PRNG key
    n : matrix size
    bounds : optional indexable pair ``(low, high)`` (tuple, list, NumPy or
        JAX array).  If it is ``None`` or empty, the result is
        ``sym(G @ G.T)`` for a standard-normal ``G``.  Otherwise the
        eigenvalues are drawn uniformly from ``[low, high)``.  The
        eigenvectors are then the ``Q`` factor of ``G``'s QR decomposition.
        ``low > 0`` is needed to guarantee definiteness.

    Returns
    -------
    (matrix, new_key)
    """
    mat, key = grand(key, (n, n))
    # `is None` / len() rather than `not bounds`: the truth value of a
    # two-element NumPy/JAX array is ambiguous and raises.
    if bounds is None or len(bounds) == 0:
        return sym(mat@mat.T), key
    mat, _ = jla.qr(mat)
    key, sk = random.split(key)
    ei = random.uniform(sk, (n,), minval=bounds[0], maxval=bounds[1])
    mat = mat@(ei[:, None]*mat.T)
    return sym(mat), key
def _so_offdiag_apply(op, a):
    """Shared body of ``_so_metric_opt`` and ``_inv_so_metric_opt``.

    Applies ``op`` (a map on length ``p(p-1)/2`` vectors) separately to two
    vectors:

    - the strict upper triangle of ``a``;
    - the strict lower triangle, read in transposed (upper-triangle) order.

    Each result is written back to the positions it was read from, and the
    diagonal of ``a`` is kept unchanged.
    """
    p = a.shape[0]
    rows, cols = jnp.triu_indices(p, 1)
    flat = rows*p + cols
    out = jnp.empty((p, p))
    out = out.at[rows, cols].set(op(a.take(flat)))
    out = out.at[cols, rows].set(op(a.T.take(flat)))
    diag = jnp.diag_indices(p)
    return out.at[diag].set(a[diag])


def _so_metric_opt(lu_mat, a):
    """Bilinear form given by ``lu_mat`` applied to the ``p x p`` matrix ``a``.

    ``lu_mat`` is a ``(p(p-1)/2) x (p(p-1)/2)`` matrix.  It is applied to the
    vectorized strict upper triangle of ``a``, and separately to the
    vectorized (transposed) strict lower triangle, so the overall operation
    is self-adjoint.  The diagonal of ``a`` passes through unchanged.
    """
    return _so_offdiag_apply(lambda vec: lu_mat@vec, a)


def _inv_so_metric_opt(lu_mat, a):
    """Inverse of ``_so_metric_opt``: solves with ``lu_mat`` instead of multiplying."""
    return _so_offdiag_apply(lambda vec: jla.solve(lu_mat, vec), a)
def unvec_skew(v):
    """Unravel a length ``n(n-1)/2`` vector into an ``n x n`` skew-symmetric matrix.

    ``v`` fills the strict upper triangle ``U`` row by row, and the result
    is ``(U.T - U)/sqrt(2)``.  So ``-v/sqrt(2)`` lands above the diagonal
    and ``+v/sqrt(2)`` below it.

    ``n`` is computed with Python arithmetic from the static length of
    ``v``.  This keeps the shape concrete, so the function also works under
    ``jax.jit``.
    """
    sqrt2 = jnp.sqrt(2.)
    # n(n-1)/2 = len(v)  =>  n = (1 + sqrt(1 + 8 len(v)))/2
    rows = int(round(.5*(1 + (1 + 8*v.shape[0])**.5)))
    result = jnp.zeros((rows, rows))
    result = result.at[jnp.triu_indices(rows, 1)].set(v)
    return (result.T - result)/sqrt2
def unvec_anti_hermitian(v):
    """Unravel a length ``n(n-1)/2`` vector into an ``n x n`` anti-Hermitian matrix.

    ``v`` (real or complex) fills the strict upper triangle ``U`` row by
    row.  The result is ``(U^H - U)/sqrt(2)``, which satisfies ``M^H = -M``.

    The working array uses the promotion of the default float dtype with
    ``v``'s dtype, so complex entries keep their imaginary part.  Real
    input gives the same dtype as before.

    ``n`` is computed with Python arithmetic from the static length of
    ``v``, so the function also works under ``jax.jit``.
    """
    v = jnp.asarray(v)
    sqrt2 = jnp.sqrt(2)
    # n(n-1)/2 = len(v)  =>  n = (1 + sqrt(1 + 8 len(v)))/2
    rows = int(round(.5*(1 + (1 + 8*len(v))**.5)))
    result = jnp.zeros((rows, rows))
    # promote (e.g. to complex) instead of silently casting Im(v) away on `.set`
    result = result.astype(jnp.result_type(result, v))
    result = result.at[jnp.triu_indices(rows, 1)].set(v)
    return (result.T.conjugate() - result)/sqrt2


def unvech(v):
    """Inverse of the scaled half-vectorization ``vech`` of a symmetric matrix.

    ``v`` has length ``n(n+1)/2``.  It holds the upper triangle *including*
    the diagonal, row by row (``jnp.triu_indices`` order).  Off-diagonal
    entries are multiplied by ``sqrt(2)``; diagonal entries are unscaled.
    With this scaling ``v1 @ v2 == trace(A1 @ A2)`` for the corresponding
    symmetric matrices, i.e. it is compatible with the trace metric.

    Parameters
    ----------
    v : vector of length ``n(n+1)/2``

    Returns
    -------
    the symmetric ``n x n`` matrix ``A`` whose scaled vech is ``v``
    (``A_ij = v_ij/sqrt(2)`` off the diagonal, ``A_ii = v_ii``).

    ``n`` is computed with Python arithmetic from the static length of
    ``v``, so the function also works under ``jax.jit``.
    """
    sqrt2 = jnp.sqrt(2)
    # n(n+1)/2 = len(v)  =>  n = (-1 + sqrt(1 + 8 len(v)))/2
    rows = int(round(.5*(-1 + (1 + 8*v.shape[0])**.5)))
    result = jnp.zeros((rows, rows))
    result = result.at[jnp.triu_indices(rows)].set(v/jnp.sqrt(2))
    # the diagonal now holds v/sqrt2; reduce it to v/2 so that the
    # symmetrization below restores exactly v on the diagonal
    result = result.at[jnp.diag_indices(rows)].set(
        result[jnp.diag_indices(rows)]/sqrt2)
    return result + result.T


def lie(a, b):
    """Lie bracket (matrix commutator) ``[a, b] = a@b - b@a``."""
    return a@b - b@a
@jit
def jpolar(x):
    """Orthogonal polar factor of ``x``: ``x @ (x.T @ x)^{-1/2}``.

    The inverse square root is applied through a linear solve against
    ``esqrtm(x.T @ x)`` rather than an explicit inverse.  ``x.T @ x`` must
    be nonsingular.  The function is compiled with ``jax.jit``.
    """
    gram_root = esqrtm(x.T@x)
    return jla.solve(gram_root, x.T).T