"""Stiefel manifold :math:`\\mathrm{St}((n, d), \\alpha)` with metric defined by a parameters :math:`\\alpha`.
"""
import numpy as np
import numpy.linalg as la
from numpy.random import randn
from scipy.linalg import expm
from scipy.sparse.linalg import LinearOperator, expm_multiply
from par_trans.utils.utils import (sym, asym, vcat)
[docs]
def par_bal(b, ar, salp):
""" balanced parallel operator
:math:`s_{salp}(P(ar, s_{\\frac{1}{salp}}(b)))`, where s is the operator scaling the top
:math:`d\\times d` block in a :math:`n\\times d` matrix by the subscript argument.
salp is typically :math:`\\alpha^{\\frac{1}{2}}`.
This operator is antisymmetric when restricted to the tangent space at :math:`I_{n,d}`.
"""
d = ar.shape[1]
b_a = b[:d, :]
b_r = b[d:, :]
a = ar[:d, :]
r = ar[d:, :]
return vcat(
((4*salp**2-1)*asym(b_a@a) + salp*asym(r.T@b_r)),
(salp**2*b_r@a-salp*r@b_a))
[docs]
def solve_w(b, ar, alp, t, tol=None):
"""The exponential action :math:`expv(tP_{ar}, b)` when the metric is
given by the parameter :math:`\\alpha`.
The calculation uses the 1-norm estimate in the local function one_norm_est.
"""
_theta = {
# The first 30 values are from table A.3 of Computing Matrix Functions.
1: 2.29e-16,
2: 2.58e-8,
3: 1.39e-5,
4: 3.40e-4,
5: 2.40e-3,
6: 9.07e-3,
7: 2.38e-2,
8: 5.00e-2,
9: 8.96e-2,
10: 1.44e-1,
# 11
11: 2.14e-1,
12: 3.00e-1,
13: 4.00e-1,
14: 5.14e-1,
15: 6.41e-1,
16: 7.81e-1,
17: 9.31e-1,
18: 1.09,
19: 1.26,
20: 1.44,
# 21
21: 1.62,
22: 1.82,
23: 2.01,
24: 2.22,
25: 2.43,
26: 2.64,
27: 2.86,
28: 3.08,
29: 3.31,
30: 3.54,
# The rest are from table 3.1 of
# Computing the Action of the Matrix Exponential.
35: 4.7,
40: 6.0,
45: 7.2,
50: 8.5,
55: 9.9,
}
salp = np.sqrt(alp)
_, d = ar.shape
def one_norm_est():
na = t*salp*la.norm(
np.concatenate([
ar[d:, :], np.abs(4*alp-1)/salp*la.norm(ar[:d, :],1)*np.ones((1, d))]), 1)
nr = t*salp*(salp*la.norm(np.concatenate(
[ar[:d, :], 1/salp*la.norm(ar[d:, :], np.inf)*np.ones((1, d))]),
1))
return max(na, nr)
norm_est = one_norm_est()
def calc_m_s(norm_est):
best_m = None
best_s = None
for m, theta in _theta.items():
s = int(np.ceil(norm_est / theta))
if best_m is None or m * s < best_m * best_s:
best_m = m
best_s = s
return best_m, best_s
m_star, s = calc_m_s(norm_est)
if tol is None:
u_d = 2 ** -53
tol = u_d
f = b.copy()
def norm2(x):
return np.sqrt(np.sum(x*x))
for _ in range(s):
c1 = norm2(b)
for j in range(m_star):
b = t / float(s*(j+1)) * par_bal(b, ar, salp)
c2 = norm2(b)
f = f + b
if c1 + c2 <= tol * norm2(f):
break
c1 = c2
b = f
return f
[docs]
class Stiefel():
""":math:`\\mathrm{St}_{n,d}` with an invariant metric defined by a parameter.
:param p: the size of the matrix
:param alpha: the metric is :math:`tr \\eta^{T}\\eta+(\\alpha-1)tr\\eta^TYY^T\\eta`.
"""
def __init__(self, n, d, alpha, null_cut_off=1e-12):
self.shape = (n, d)
self.alpha = alpha
self.d = d
self.null_cut_off = null_cut_off
[docs]
def name(self):
""" name of the object
"""
return f"Stiefel({self.shape}) alpha={self.alpha}"
[docs]
def inner(self, x, xi, eta):
""" Inner product
"""
alp = self.alpha
# ix_xi = x.T@xi
# ix_eta = x.T@eta
return np.sum(xi*eta) + (alp-1)*np.sum((x.T@xi)*(x.T@eta))
def proj(self, x, omg):
return omg - x@sym(x.T@omg)
[docs]
def rand_ambient(self):
"""random ambient vector
"""
return randn(*(self.shape))
[docs]
def rand_point(self):
""" A random point on the manifold
"""
return la.qr(self.rand_ambient())[0]
[docs]
def rand_vec(self, x):
""" A random tangent vector to the manifold at x
"""
return self.proj(x, self.rand_ambient())
[docs]
def retract(self, x, v):
""" second order retraction.
"""
return x + v - 0.5* self.proj(x, self.christoffel_gamma(x, v, v))
[docs]
def approx_nearest(self, q):
""" point on the manifold that is approximately nearest to q
"""
return la.qr(q)[0]
[docs]
def make_ar(self, a, r):
""" lift ar a tangent vector to the manifold at :math:`I_{n,d}`
to a square matrix, the lifted horizontal vector at :math:`I_n\\in SO(n)`.
"""
k = r.shape[0]
return np.concatenate([
np.concatenate([a, - r.T], axis=1),
np.concatenate([r, np.zeros((k, k))], axis=1)], axis=0)
[docs]
def exp(self, x, v):
""" geodesic, or riemannian exponential
"""
n, d = x.shape
u, _, _ = la.svd(v - x@(x.T@v), full_matrices=False)
k = min(n-d, d)
q = u[:, :k]
a = x.T@v
r = q.T@v
aar = self.make_ar(2*self.alpha*a, r)
return (np.concatenate([x, q], axis=1)@expm(aar)[:, :d])@expm((1-2*self.alpha)*a)
[docs]
def dexp(self, x, v, t, ddexp=False):
""" Higher derivative of Exponential function.
:param x: the initial point :math:`\\gamma(0)`
:param v: the initial velocity :math:`\\dot{\\gamma}(0)`
:param t: time.
If ddexp is False, we return :math:`\\gamma(t), \\dot{\\gamma}(t)`.
Otherwise, we return :math:`\\gamma(t), \\dot{\\gamma}(t), \\ddot{\\gamma}(t)`.
"""
n, d = x.shape
alp = self.alpha
u, _, _ = la.svd(v - x@(x.T@v), full_matrices=False)
k = min(n-d, d)
q = u[:, :k]
a = x.T@v
r = q.T@v
ar = self.make_ar(a, r)
aar = self.make_ar(2*alp*a, r)
prt0 = np.concatenate([x, q], axis=1)@expm(t*aar)
prt1 = expm(t*(1-2*self.alpha)*a)
if not ddexp:
return prt0[:, :d]@prt1, (prt0@ar)[:, :d]@prt1
lie_ar_a0 = np.zeros_like(ar)
lie_ar_a0[d:, :d] = ar[d:, :d]@a
lie_ar_a0[:d, d:] = - lie_ar_a0[d:, :d].T
return prt0[:, :d]@prt1, \
(prt0@ar)[:, :d]@prt1, \
(prt0@(ar@ar + (1-2*alp)*lie_ar_a0))[:, :d]@prt1
[docs]
def christoffel_gamma(self, x, xi, eta):
"""function representing the Christoffel symbols
"""
alp = self.alpha
xTxi = x.T@xi
xTeta = x.T@eta
def sym2(a):
return a + a.T
return 0.5*x@(xi.T@eta + eta.T@xi) - (1-alp)*(
xi@xTeta + eta@xTxi - x@sym2(xTxi@xTeta))
[docs]
def parallel_expm_multiply(self, x, xi, eta, t):
"""parallel transport. The exponential action is computed using expm_multiply from scipy.
:param x: a point on the manifold
:param xi: the initial velocity of the geodesic
:param eta: the vector to be transported
:param t: time.
"""
n, d = x.shape
alp = self.alpha
u, _, _ = la.svd(xi - x@(x.T@xi), full_matrices=False)
k = min(n-d, d)
q = u[:, :k]
a = x.T@xi
r = q.T@xi
# ar = self.make_ar(a, r)
xq = np.concatenate([x, q], axis=1)
aar = self.make_ar(2*alp*a, r)
prt0 = xq@expm(t*aar)
prt1 = expm(t*(1-2*self.alpha)*a)
def par(b):
b_a = b[:d, :]
b_r = b[d:, :]
return vcat(
(4*alp-1)*asym(b_a@a) + asym(r.T@b_r),
alp*(b_r@a-r@b_a))
# return 0.5*(lie(b, a) + (1-2*alp)*(lie(a_k_0, b) - lie(self._lie_proj_a(b), a)))
def par_T(b):
b_a = b[:d, :]
b_r = b[d:, :]
return vcat(
-(4*alp-1)*asym(b_a)@a - alp*r.T@b_r,
r@asym(b_a) - alp*b_r@a)
p_opt = LinearOperator(((d+k)*d, (d+k)*d),
matvec=lambda w: t*par(w.reshape(d+k, d)).reshape(-1),
rmatvec=lambda w: t*par_T(w.reshape(d+k, d)).reshape(-1))
return prt0@expm_multiply(p_opt, (xq.T@eta).reshape(-1), traceA=0).reshape(d+k, d)@prt1 \
+ (eta - x@x.T@eta - q@q.T@eta)@expm(t*(1-alp)*a)
[docs]
def parallel(self, x, xi, eta, t):
"""parallel transport. The exponential action is computed
using expv, with our customized estimate of 1_norm of the operator P
:param x: a point on the manifold
:param xi: the initial velocity of the geodesic
:param eta: the vector to be transported
:param t: time.
"""
n, d = x.shape
alp = self.alpha
u, _, _ = la.svd(xi - x@(x.T@xi), full_matrices=False)
k = min(n-d, d)
q = u[:, :k]
xq = np.concatenate([x, q], axis=1)
ar = xq.T@xi
a = ar[:d, :]
r = ar[d:, :]
# ar = self.make_ar(a, r)
xq = np.concatenate([x, q], axis=1)
aar = self.make_ar(2*alp*a, r)
prt0 = xq@expm(t*aar)
prt1 = expm(t*(1-2*self.alpha)*a)
def sc(ar, ft):
""" Scaling the a block of ar by a factor ft
"""
arn = ar.copy()
arn[:ar.shape[1], :] = ar[:ar.shape[1], :]*ft
return arn
# w = sc(sp_opt.expv(sc(xq.T@eta, salp), t), 1/salp)
salp = np.sqrt(alp)
w = sc(solve_w(sc(xq.T@eta, salp), ar, alp, t), 1/salp)
return prt0@w@prt1 + (eta - x@x.T@eta - q@q.T@eta)@expm(t*(1-alp)*a)