Source code for jax_par_trans.manifolds.flag

""":math:`Flag`: Flag manifold.
"""
from functools import partial

import jax.numpy as jnp
import jax.numpy.linalg as jla
from jax import jit
from jax.scipy.linalg import expm

from jax_par_trans.expv.utils import (vcat, grand)
from jax_par_trans.expv.expv import LinearOperator


class FlagCanonicalParallelOperator(LinearOperator):
    """Linear operator used to implement expv for the Flag parallel
    transport operator, for the case alpha = .5.

    :param params: a dictionary with two entries: ``'ar'``, the stacked
        matrix whose first ``ar.shape[1]`` rows form the block ``a`` and
        whose remaining rows form the block ``r``; and ``'flag'``, the
        object whose ``proj_m`` method is applied in :meth:`dot`.
    """

    def __init__(self, params):
        self.ar = params['ar']
        self.flag = params['flag']

    def set_params(self, params):
        """Replace the stored ``ar`` matrix.

        :param params: the new ``ar`` matrix (the matrix itself, not a
            dictionary).
        """
        self.ar = params

    # NOTE: the jitted kernels below take ``ar`` as a traced argument.
    # Jitting the bound methods with ``self`` static would bake ``self.ar``
    # into the compiled trace; since the instance hashes by identity, a
    # later ``set_params`` call would then be silently ignored.
    @staticmethod
    @partial(jit, static_argnums=(0,))
    def _dot(flag, ar, b):
        """Jitted kernel of :meth:`dot`; ``flag`` is static, ``ar`` and
        ``b`` are traced.
        """
        salp = jnp.sqrt(.5)
        d = ar.shape[1]
        b_a = b[:d, :]
        b_r = b[d:, :]
        a = ar[:d, :]
        r = ar[d:, :]
        return vcat(
            flag.proj_m(b_a@a + salp*r.T@b_r),
            (0.5*b_r@a - salp*r@b_a))

    def dot(self, b):
        """Apply the operator to ``b``.

        ``b`` is split into its first ``d = ar.shape[1]`` rows ``b_a`` and
        the remaining rows ``b_r``. The result stacks
        ``flag.proj_m(b_a@a + salp*r.T@b_r)`` on top of
        ``0.5*b_r@a - salp*r@b_a``, where ``salp = sqrt(.5)``.

        :param b: matrix with the same row layout as ``ar``.
        :return: the stacked result, built with ``vcat``.
        """
        return self._dot(self.flag, self.ar, b)

    @staticmethod
    @jit
    def _one_norm_est(ar):
        """Jitted kernel of :meth:`one_norm_est`; ``ar`` is traced."""
        salp = jnp.sqrt(.5)
        d = ar.shape[1]
        na = jnp.max(
            salp*jnp.sum(jnp.abs(ar[d:, :]), axis=0)
            + jla.norm(ar[:d, :], 1))
        nr = 0.5*jnp.max(
            jnp.sum(jnp.abs(ar[:d, :]), axis=0)
            + 1/salp*jla.norm(ar[d:, :], jnp.inf))
        return jnp.max(jnp.array([na, nr]))

    def one_norm_est(self):
        """Estimate of the 1-norm of the operator, computed from the
        current ``ar`` only.

        :return: the larger of the two block estimates ``na`` and ``nr``.
        """
        return self._one_norm_est(self.ar)
class Flag():
    """:math:`Flag(\\vec{d})` with a homogeneous metric defined
    by a parameter. Realized as a quotient of a Stiefel manifold.

    :param dvec: integer vector of block sizes.
    :param alpha: the metric is
        :math:`tr \\eta^{T}\\eta+(\\alpha-1)tr\\eta^TYY^T\\eta`.

    For ease of implementation, :math:`d_{p+1}` (the last entry of the
    input) is renamed d[0] and saved at the top of ``self.dvec``.
    """

    def __init__(self, dvec, alpha=.5):
        dvec = jnp.asarray(dvec)
        self.n = jnp.sum(dvec)
        self.d = jnp.sum(dvec[:-1])
        self.shape = (self.n, self.d)
        self.alpha = alpha
        self.dvec = jnp.concatenate([dvec[-1:], dvec[:-1]])
        # Block boundaries inside the d columns, kept as plain Python ints
        # so that slices built from them are always static (also under jit).
        offsets = (self.dvec.cumsum() - self.dvec[0]).tolist()
        self._g_idx = {
            i + 1: (offsets[i], offsets[i + 1])
            for i in range(len(offsets) - 1)}
        self.p = self.dvec.shape[0] - 1

    def name(self):
        """Name of the object."""
        return f"Flag({self.dvec}) alpha={self.alpha}"

    @partial(jit, static_argnums=(0,))
    def symf(self, omg):
        """Symmetrize ``omg`` but keep its diagonal blocks unchanged.

        :param omg: a square matrix indexed by the blocks in ``_g_idx``.
        :return: ``0.5*(omg + omg.T)`` with each diagonal block replaced
            by the corresponding block of ``omg``.
        """
        ret = 0.5*(omg + omg.T)
        for tt in range(1, self.p + 1):
            bt, et = self._g_idx[tt]
            ret = ret.at[bt:et, bt:et].set(omg[bt:et, bt:et])
        return ret

    @partial(jit, static_argnums=(0,))
    def proj_m(self, omg):
        """Projection to the horizontal space.

        :param omg: a square matrix indexed by the blocks in ``_g_idx``.
        :return: ``0.5*(omg - omg.T)`` with each diagonal block set to 0.
        """
        ret = 0.5*(omg - omg.T)
        for tt in range(1, self.p + 1):
            bt, et = self._g_idx[tt]
            ret = ret.at[bt:et, bt:et].set(0.)
        return ret

    def inner(self, x, xi, eta):
        """Inner product of ``xi`` and ``eta`` at ``x``.

        :return: ``sum(xi*eta) + (alpha-1)*sum((x.T@xi)*(x.T@eta))``.
        """
        alp = self.alpha
        return jnp.sum(xi*eta) + (alp - 1)*jnp.sum((x.T@xi)*(x.T@eta))

    def proj(self, x, omg):
        """Projection of the ambient vector ``omg`` to the tangent bundle
        at ``x``.
        """
        return omg - x@self.symf(x.T@omg)

    def rand_ambient(self, key):
        """Random ambient vector.

        :param key: random key, passed to ``grand`` with ``self.shape``.
        :return: whatever ``grand`` returns; callers in this class unpack
            it as a pair ``(array, key)``.
        """
        return grand(key, self.shape)

    def rand_point(self, key):
        """A random point on the manifold.

        :return: the Q factor of a random ambient vector, and the key
            returned by :meth:`rand_ambient`.
        """
        tmp, key = self.rand_ambient(key)
        return jla.qr(tmp)[0], key

    def rand_vec(self, key, x):
        """A random tangent vector at ``x``.

        :return: the projection of a random ambient vector, and the key
            returned by :meth:`rand_ambient`.
        """
        tmp, key = self.rand_ambient(key)
        return self.proj(x, tmp), key

    def retract(self, x, v):
        """Second order retraction."""
        return x + v - 0.5*self.proj(x, self.christoffel_gamma(x, v, v))

    def approx_nearest(self, q):
        """Point on the manifold that is approximately nearest to ``q``
        (the Q factor of its QR decomposition).
        """
        return jla.qr(q)[0]

    def make_ar(self, a, r):
        """Lift ar, a tangent vector to the manifold at :math:`I_{n,d}`,
        to a square matrix, a horizontal vector at :math:`SO(n)`.

        :return: the block matrix ``[[a, -r.T], [r, 0]]``.
        """
        k = r.shape[0]
        return jnp.concatenate([
            jnp.concatenate([a, -r.T], axis=1),
            jnp.concatenate([r, jnp.zeros((k, k))], axis=1)], axis=0)

    def exp(self, x, v):
        """Geodesic, or Riemannian exponential.

        :param x: the initial point.
        :param v: the initial velocity.
        """
        n, d = x.shape
        u, _, _ = jla.svd(v - x@(x.T@v), full_matrices=False)
        k = min(n - d, d)
        q = u[:, :k]
        a = x.T@v
        r = q.T@v
        aar = self.make_ar(2*self.alpha*a, r)
        return (jnp.concatenate([x, q], axis=1)@expm(aar)[:, :d]) \
            @ expm((1 - 2*self.alpha)*a)

    def dexp(self, x, v, t, ddexp=False):
        """Higher derivative of the exponential function.

        :param x: the initial point :math:`\\gamma(0)`
        :param v: the initial velocity :math:`\\dot{\\gamma}(0)`
        :param t: time.
        :param ddexp: whether to also return the second derivative.

        If ddexp is False, we return
        :math:`\\gamma(t), \\dot{\\gamma}(t)`.
        Otherwise, we return
        :math:`\\gamma(t), \\dot{\\gamma}(t), \\ddot{\\gamma}(t)`.
        """
        n, d = x.shape
        alp = self.alpha
        u, _, _ = jla.svd(v - x@(x.T@v), full_matrices=False)
        # Python min keeps k a static int, matching exp/parallel_canonical
        # and keeping the slice below traceable under jit.
        k = min(n - d, d)
        q = u[:, :k]
        a = x.T@v
        r = q.T@v
        ar = self.make_ar(a, r)
        aar = self.make_ar(2*alp*a, r)
        prt0 = jnp.concatenate([x, q], axis=1)@expm(t*aar)
        prt1 = expm(t*(1 - 2*alp)*a)
        gamma = prt0[:, :d]@prt1
        gamma_dot = (prt0@ar)[:, :d]@prt1
        if not ddexp:
            return gamma, gamma_dot

        lie_ar_a0 = jnp.zeros_like(ar)
        lie_ar_a0 = lie_ar_a0.at[d:, :d].set(ar[d:, :d]@a)
        lie_ar_a0 = lie_ar_a0.at[:d, d:].set(-lie_ar_a0[d:, :d].T)
        gamma_ddot = (prt0@(ar@ar + (1 - 2*alp)*lie_ar_a0))[:, :d]@prt1
        return gamma, gamma_dot, gamma_ddot

    def christoffel_gamma(self, x, xi, eta):
        """Function representing the Christoffel symbols."""
        alp = self.alpha
        xTxi = x.T@xi
        xTeta = x.T@eta

        def sym2(a):
            return a + a.T

        return x@self.symf(xi.T@eta) - (1 - alp)*(
            xi@xTeta + eta@xTxi - x@sym2(xTxi@xTeta))

    def _sc(self, ar, ft):
        """Scale the a block of ar (its first ``ar.shape[1]`` rows) by a
        factor ft. A new array is returned; ``ar`` is not modified.
        """
        d = ar.shape[1]
        return ar.at[:d, :].set(ar[:d, :]*ft)

    def parallel_canonical(self, x, xi, eta, t):
        """Parallel transport. Only works for alpha = .5.

        The exponential action is computed using expv, with
        our customized estimate of 1_norm of the operator P.

        :param x: a point on the manifold
        :param xi: the initial velocity of the geodesic
        :param eta: the vector to be transported
        :param t: time.
        """
        n, d = x.shape
        salp = jnp.sqrt(self.alpha)
        u, _, _ = jla.svd(xi - x@(x.T@xi), full_matrices=False)
        k = min(n - d, d)
        q = u[:, :k]
        xq = jnp.concatenate([x, q], axis=1)
        ar = xq.T@xi
        prt0 = xq@expm(t*jnp.concatenate(
            [ar, vcat(-ar[d:, :].T, jnp.zeros((k, k)))], axis=1))

        flag_opt = FlagCanonicalParallelOperator({"ar": ar, 'flag': self})
        # The a block is scaled by salp before expv and by 1/salp after.
        w = self._sc(flag_opt.expv(self._sc(xq.T@eta, salp), t), 1/salp)

        return prt0@w + (eta - x@x.T@eta - q@q.T@eta)@expm(0.5*t*ar[:d, :])