""":math:`St`: Stiefel manifold.
"""
from functools import partial
import jax.numpy as jnp
import jax.numpy.linalg as jla
from jax import jit
from jax.scipy.linalg import expm
from ..expv.utils import (sym, asym, vcat, grand)
from ..expv.expv import LinearOperator
class StiefelParallelOperator(LinearOperator):
    """Operator P used in parallel transport on Stiefel manifolds.

    ``params['ar']`` stacks ``a`` (the first ``d = ar.shape[1]`` rows) over
    ``r`` (the remaining rows). ``params['alpha']`` is the metric parameter,
    stored as ``salp = sqrt(alpha)``.

    The jitted kernels are static methods that receive ``ar`` and ``salp`` as
    traced arguments. Marking ``self`` static in ``jit`` instead would key the
    compile cache on object identity: updates made through :meth:`set_params`
    would be ignored by already-compiled code, and every new instance would
    trigger a fresh compilation.
    """

    def __init__(self, params):
        self.ar = params['ar']
        self.salp = jnp.sqrt(params['alpha'])

    def set_params(self, params):
        """Replace the stored ``ar`` block; update ``sqrt(alpha)`` only when
        ``'alpha'`` is present in ``params``."""
        self.ar = params['ar']
        if 'alpha' in params:
            self.salp = jnp.sqrt(params['alpha'])

    @staticmethod
    @jit
    def _dot_kernel(ar, salp, b):
        """Jitted body of :meth:`dot`; every input is a traced array."""
        d = ar.shape[1]
        a, r = ar[:d, :], ar[d:, :]
        b_a, b_r = b[:d, :], b[d:, :]
        top = (4*salp**2 - 1)*asym(b_a@a) + salp*asym(r.T@b_r)
        bottom = salp**2*b_r@a - salp*r@b_a
        return vcat(top, bottom)

    def dot(self, b):
        """Apply P to ``b``, split like ``ar`` into ``b_a`` (first ``d`` rows)
        and ``b_r`` (the rest). With ``s = sqrt(alpha)`` the result stacks
        ``(4 s^2 - 1) asym(b_a a) + s asym(r^T b_r)`` over
        ``s^2 b_r a - s r b_a``.
        """
        return self._dot_kernel(self.ar, self.salp, b)

    @staticmethod
    @jit
    def _one_norm_kernel(ar, salp):
        """Jitted body of :meth:`one_norm_est`."""
        d = ar.shape[1]
        a, r = ar[:d, :], ar[d:, :]
        # Max absolute row sum of r (its inf-norm). ``initial=0.0`` keeps it
        # defined when r has no rows (n == d); jla.norm(r, jnp.inf) would
        # attempt an empty max reduction there.
        r_inf = jnp.max(jnp.sum(jnp.abs(r), axis=1), initial=0.0)
        na = salp*jnp.max(jnp.sum(jnp.abs(r), axis=0)
                          + jnp.abs(4*salp**2 - 1)/salp*jla.norm(a, 1))
        nr = salp**2*jnp.max(jnp.sum(jnp.abs(a), axis=0) + 1/salp*r_inf)
        return jnp.max(jnp.array([na, nr]))

    def one_norm_est(self):
        """Customized estimate of the 1-norm of P, built from column sums and
        matrix norms of the ``a`` and ``r`` blocks."""
        return self._one_norm_kernel(self.ar, self.salp)
class StiefelOperator(LinearOperator):
    """Evaluating expm([[a, -r.T], [r, 0]]) acting on a thin block.

    ``params`` is the stacked block ``ar``: ``a`` is its first
    ``d = ar.shape[1]`` rows and ``r`` the remaining rows. For testing.

    The jitted kernel takes ``ar`` as a traced argument rather than marking
    ``self`` static, so :meth:`set_params` updates are never served from a
    stale compiled trace.
    """

    def __init__(self, params):
        self.ar = params

    def set_params(self, params):
        """Replace the stored ``ar`` block."""
        self.ar = params

    @staticmethod
    @jit
    def _dot_kernel(ar, b):
        """Jitted body of :meth:`dot`."""
        d = ar.shape[1]
        a, r = ar[:d, :], ar[d:, :]
        b_a, b_r = b[:d, :], b[d:, :]
        return vcat(a@b_a - r.T@b_r, r@b_a)

    def dot(self, b):
        """Multiply ``[[a, -r.T], [r, 0]]`` by ``b``."""
        return self._dot_kernel(self.ar, b)

    def one_norm_est(self):
        """1-norm (max absolute column sum) of ``[[a, -r.T], [r, 0]]``.

        The first ``d`` columns of that matrix are exactly the columns of
        ``ar``; each remaining column is a row of ``r`` (negated). Hence the
        norm is ``max(norm(ar, 1), norm(r, inf))``. Returning ``norm(ar, inf)``
        alone can underestimate it.
        """
        d = self.ar.shape[1]
        # ``initial=0.0`` covers an ``r`` block with no rows.
        r_inf = jnp.max(jnp.sum(jnp.abs(self.ar[d:, :]), axis=1), initial=0.0)
        return jnp.maximum(jla.norm(self.ar, 1), r_inf)
class Stiefel():
    """:math:`\\mathrm{St}_{n,d}` with an invariant metric defined by a parameter.

    :param n: number of rows of a point (size of the ambient space)
    :param d: number of columns of a point
    :param alpha: the metric is :math:`tr \\eta^{T}\\eta+(\\alpha-1)tr\\eta^TYY^T\\eta`.
    """

    def __init__(self, n, d, alpha):
        self.shape = (n, d)
        self.alpha = alpha
        self.d = d

    def name(self):
        """ name of the object
        """
        return f"Stiefel({self.shape}) alpha={self.alpha}"

    def inner(self, x, xi, eta):
        """ Inner product: tr xi^T eta + (alpha-1) tr (x^T xi)^T (x^T eta).
        """
        alp = self.alpha
        return jnp.sum(xi*eta) + (alp-1)*jnp.sum((x.T@xi)*(x.T@eta))

    def proj(self, x, omg):
        """ Map an ambient matrix omg to the tangent space at x:
        omg - x sym(x^T omg).
        """
        return omg - x@sym(x.T@omg)

    def rand_ambient(self, key):
        """random ambient vector; returns a pair (matrix, key) from ``grand``.
        """
        return grand(key, self.shape)

    def rand_point(self, key):
        """ A random point on the manifold (Q factor of a random matrix).
        Returns (point, key).
        """
        tmp, key = self.rand_ambient(key)
        return jla.qr(tmp)[0], key

    def rand_vec(self, key, x):
        """ A random tangent vector to the manifold at x. Returns (vector, key).
        """
        tmp, key = self.rand_ambient(key)
        return self.proj(x, tmp), key

    def retract(self, x, v):
        """ second order retraction: x + v - 1/2 proj(Gamma(v, v)).
        """
        return x + v - 0.5*self.proj(x, self.christoffel_gamma(x, v, v))

    def approx_nearest(self, q):
        """ point on the manifold that is approximately nearest to q
        (the Q factor of the QR decomposition of q).
        """
        return jla.qr(q)[0]

    def make_ar(self, a, r):
        """ lift ar a tangent vector to the manifold at :math:`I_{n,d}`
        to a square matrix, the lifted horizontal vector at :math:`I_n\\in SO(n)`.
        Returns [[a, -r^T], [r, 0]].
        """
        k = r.shape[0]
        return jnp.concatenate([
            jnp.concatenate([a, - r.T], axis=1),
            jnp.concatenate([r, jnp.zeros((k, k))], axis=1)], axis=0)

    def _horizontal_basis(self, x, v):
        """Orthonormal columns q (k = min(n-d, d) of them) spanning the part of
        v orthogonal to x, from its thin SVD.

        ``k`` is a Python int so slicing and the shapes derived from it stay
        static under ``jax.jit``.
        NOTE(review): when ``v - x x^T v`` is rank deficient (e.g. v = x a),
        the trailing left singular vectors are not guaranteed to be orthogonal
        to x — confirm behaviour for such inputs.
        """
        n, d = x.shape
        u, _, _ = jla.svd(v - x@(x.T@v), full_matrices=False)
        return u[:, :min(n - d, d)]

    def exp(self, x, v):
        """ geodesic, or riemannian exponential:
        [x q] expm([[2 alpha a, -r^T], [r, 0]])[:, :d] expm((1-2 alpha) a),
        with a = x^T v and r = q^T v.
        """
        d = x.shape[1]
        q = self._horizontal_basis(x, v)
        a = x.T@v
        r = q.T@v
        aar = self.make_ar(2*self.alpha*a, r)
        return (jnp.concatenate([x, q], axis=1)@expm(aar)[:, :d]) \
            @ expm((1-2*self.alpha)*a)

    def dexp(self, x, v, t, ddexp=False):
        """ Higher derivative of Exponential function.

        :param x: the initial point :math:`\\gamma(0)`
        :param v: the initial velocity :math:`\\dot{\\gamma}(0)`
        :param t: time.

        If ddexp is False, we return :math:`\\gamma(t), \\dot{\\gamma}(t)`.
        Otherwise, we return :math:`\\gamma(t), \\dot{\\gamma}(t), \\ddot{\\gamma}(t)`.
        """
        d = x.shape[1]
        alp = self.alpha
        q = self._horizontal_basis(x, v)
        a = x.T@v
        r = q.T@v
        ar = self.make_ar(a, r)
        aar = self.make_ar(2*alp*a, r)
        prt0 = jnp.concatenate([x, q], axis=1)@expm(t*aar)
        prt1 = expm(t*(1-2*alp)*a)
        gamma = prt0[:, :d]@prt1
        dgamma = (prt0@ar)[:, :d]@prt1
        if not ddexp:
            return gamma, dgamma
        # Only the off-diagonal blocks are nonzero. For antisymmetric a this
        # equals ar@A0 - A0@ar with A0 = [[a, 0], [0, 0]].
        ra = r@a
        lie_ar_a0 = jnp.zeros_like(ar)
        lie_ar_a0 = lie_ar_a0.at[d:, :d].set(ra)
        lie_ar_a0 = lie_ar_a0.at[:d, d:].set(-ra.T)
        ddgamma = (prt0@(ar@ar + (1-2*alp)*lie_ar_a0))[:, :d]@prt1
        return gamma, dgamma, ddgamma

    def christoffel_gamma(self, x, xi, eta):
        """Christoffel function of the manifold
        """
        alp = self.alpha
        xTxi = x.T@xi
        xTeta = x.T@eta

        def sym2(a):
            return a + a.T
        return 0.5*x@(xi.T@eta + eta.T@xi) - (1-alp)*(
            xi@xTeta + eta@xTxi - x@sym2(xTxi@xTeta))

    def _sc(self, ar, ft):
        """ Scaling the a block (first ar.shape[1] rows) of ar by a factor ft.
        ``.at[].set`` already returns a new array, so no copy is needed.
        """
        d = ar.shape[1]
        return ar.at[:d, :].set(ar[:d, :]*ft)

    def parallel(self, x, xi, eta, t):
        """parallel transport. The exponential action is computed
        using expv, with our customized estimate of 1_norm of the operator P

        :param x: a point on the manifold
        :param xi: the initial velocity of the geodesic
        :param eta: the vector to be transported
        :param t: time.
        """
        d = x.shape[1]
        alp = self.alpha
        salp = jnp.sqrt(alp)
        q = self._horizontal_basis(x, xi)
        xq = jnp.concatenate([x, q], axis=1)
        ar = xq.T@xi
        a = ar[:d, :]
        r = ar[d:, :]
        aar = self.make_ar(2*alp*a, r)
        prt0 = xq@expm(t*aar)
        prt1 = expm(t*(1-2*alp)*a)
        sp_opt = StiefelParallelOperator({"ar": ar, "alpha": alp})
        # P acts on the coordinates of eta in the [x q] frame; the a block is
        # rescaled by sqrt(alpha) before expv and scaled back afterwards.
        w = self._sc(sp_opt.expv(self._sc(xq.T@eta, salp), t), 1/salp)
        # Component of eta outside span[x q] (computed as thin products, no
        # n x n intermediate) is transported by expm(t (1-alpha) a).
        eta_perp = eta - x@(x.T@eta) - q@(q.T@eta)
        return prt0@w@prt1 + eta_perp@expm(t*(1-alp)*a)