This is the NumPy version. While it does not offer the automatic differentiation (AD) of JAX, it may be faster on CPU, and it can use expm_multiply, which is not yet implemented in JAX.
par_trans.manifolds
Stiefel
Stiefel manifold \(\mathrm{St}((n, d), \alpha)\) with the metric defined by a parameter \(\alpha\).
- par_trans.manifolds.stiefel.par_bal(b, ar, salp)
Balanced parallel operator \(s_{salp}(P(ar, s_{\frac{1}{salp}}(b)))\), where \(s\) is the operator that scales the top \(d\times d\) block of an \(n\times d\) matrix by its subscript argument. salp is typically \(\alpha^{\frac{1}{2}}\). This operator is antisymmetric when restricted to the tangent space at \(I_{n,d}\).
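The scaling operator \(s\) and the balanced composition can be sketched in NumPy as below. This only illustrates the formula: the operator \(P\) is not specified on this page, so the sketch takes a caller-supplied function `P(ar, b)` as a stand-in, and the names `scale_top` and `balanced` are hypothetical, not part of par_trans.

```python
import numpy as np
from typing import Callable


def scale_top(c: float, b: np.ndarray) -> np.ndarray:
    """s_c: multiply the top d x d block of an n x d matrix b by c."""
    d = b.shape[1]
    out = np.array(b, dtype=float, copy=True)
    out[:d, :] *= c
    return out


def balanced(P: Callable[[np.ndarray, np.ndarray], np.ndarray],
             ar: np.ndarray, b: np.ndarray, salp: float) -> np.ndarray:
    """s_salp(P(ar, s_{1/salp}(b))) for a caller-supplied (hypothetical) P."""
    return scale_top(salp, P(ar, scale_top(1.0 / salp, b)))


# Usage with a toy stand-in P(ar, b) = ar @ b (NOT the operator used by par_trans).
rng = np.random.default_rng(0)
n, d, alpha = 5, 2, 0.5
ar = rng.standard_normal((n, n))
b = rng.standard_normal((n, d))
salp = np.sqrt(alpha)  # "typically alpha^(1/2)"
out = balanced(lambda m, x: m @ x, ar, b, salp)
```

Since \(s_c\) is left multiplication by \(\mathrm{diag}(c I_d, I_{n-d})\), the balanced form of this toy operator is simply that diagonal matrix conjugating `ar`.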
- par_trans.manifolds.stiefel.solve_w(b, ar, alp, t, tol=None)
Compute the exponential action \(\mathrm{expv}(tP_{ar}, b)\) when the metric is given by the parameter \(\alpha\) (the argument alp). The calculation uses the 1-norm estimate in the local function one_norm_est.
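To make "exponential action with a 1-norm estimate" concrete, here is a simplified truncated-Taylor sketch for a small dense matrix `A` (a stand-in for \(P_{ar}\) acting on a flattened `b`). It is not the library's solve_w: it uses the exact 1-norm of `A` only to choose the number of sub-steps \(s\) with \(\|tA\|_1/s \le 1\), plus a fixed stopping tolerance, whereas expv (following Al-Mohy and Higham) also selects the Taylor degree from tabulated bounds. The name `expv_taylor` is hypothetical.

```python
import numpy as np


def expv_taylor(A, b, t=1.0, tol=1e-12, max_terms=60):
    """Approximate exp(t*A) @ b using only the 1-norm of A (simplified sketch)."""
    f = np.array(b, dtype=float, copy=True)
    norm1 = np.linalg.norm(A, 1)                  # max absolute column sum
    s = max(1, int(np.ceil(abs(t) * norm1)))      # sub-steps so that ||h A||_1 <= 1
    h = t / s
    for _ in range(s):
        term = f.copy()
        acc = f.copy()
        for k in range(1, max_terms + 1):
            term = h * (A @ term) / k             # next Taylor term (hA)^k f / k!
            acc = acc + term
            if np.linalg.norm(term, 1) <= tol * np.linalg.norm(acc, 1):
                break
        f = acc                                   # f <- exp(hA) f (truncated series)
    return f


# Usage: exp(tA) for this generator rotates the plane by angle t.
A = np.array([[0.0, -1.0], [1.0, 0.0]])
v = expv_taylor(A, np.array([1.0, 0.0]), t=np.pi / 2)  # close to [0, 1]
```

Splitting \(t\) into sub-steps keeps each series well scaled, which is why a cheap norm bound on the operator is all the method needs up front.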
- class par_trans.manifolds.stiefel.Stiefel(n, d, alpha, null_cut_off=1e-12)
\(\mathrm{St}_{n,d}\) with an invariant metric defined by a parameter.
- Parameters:
n, d – the matrix size: points are \(n\times d\) matrices with orthonormal columns
alpha – the metric is \(tr \eta^{T}\eta+(\alpha-1)tr\eta^TYY^T\eta\).
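The metric formula can be evaluated directly. This small sketch (the function name `stiefel_inner` is hypothetical) computes \(\langle\xi,\eta\rangle_Y = tr\,\xi^T\eta + (\alpha-1)\,tr\,\xi^TYY^T\eta\) and checks it at \(Y = I_{n,d}\), where a tangent vector splits as \(\eta = [a; r]\) with \(a\) antisymmetric and the squared norm becomes \(\alpha\|a\|^2 + \|r\|^2\).

```python
import numpy as np


def stiefel_inner(Y, xi, eta, alpha):
    """<xi, eta>_Y = tr(xi^T eta) + (alpha - 1) tr(xi^T Y Y^T eta)."""
    return np.trace(xi.T @ eta) + (alpha - 1.0) * np.trace(xi.T @ Y @ (Y.T @ eta))


# At Y = I_{n,d}, a tangent vector is eta = [a; r] with a antisymmetric (d x d).
n, d = 4, 2
Y = np.eye(n, d)
a = np.array([[0.0, 1.0], [-1.0, 0.0]])   # ||a||_F^2 = 2
r = np.ones((n - d, d))                    # ||r||_F^2 = 4
eta = np.vstack([a, r])

print(stiefel_inner(Y, eta, eta, 1.0))  # 6.0  (alpha = 1: Frobenius/embedded metric)
print(stiefel_inner(Y, eta, eta, 0.5))  # 5.0  (alpha = 1/2: canonical metric)
```

At \(I_{n,d}\) the same value equals the squared Frobenius norm of \(\eta\) after scaling its top block by \(\alpha^{1/2}\), which appears consistent with salp typically being \(\alpha^{1/2}\) in par_bal.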
- dexp(x, v, t, ddexp=False)
The Riemannian exponential (the geodesic \(\gamma\)) and its higher time derivatives.
- Parameters:
x – the initial point \(\gamma(0)\)
v – the initial velocity \(\dot{\gamma}(0)\)
t – time.
If ddexp is False, we return \(\gamma(t), \dot{\gamma}(t)\). Otherwise, we return \(\gamma(t), \dot{\gamma}(t), \ddot{\gamma}(t)\).
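For the special case \(\alpha = \frac{1}{2}\) (the canonical metric), the closed-form geodesic of Edelman, Arias and Smith gives a compact sketch of what dexp returns. This is an illustration under that assumption only; it does not cover general \(\alpha\), and the name `canonical_geodesic` is hypothetical. The derivatives follow from \(\frac{d}{dt}e^{tM} = e^{tM}M\).

```python
import numpy as np
from scipy.linalg import expm


def canonical_geodesic(Y, eta, t):
    """gamma(t), gamma'(t), gamma''(t) on St(n, d) for alpha = 1/2 (sketch).

    gamma(t) = [Y Q] expm(t M) [I; 0], with A = Y^T eta, Q R = (I - Y Y^T) eta
    (thin QR) and M = [[A, -R^T], [R, 0]].
    """
    d = Y.shape[1]
    A = Y.T @ eta
    Q, R = np.linalg.qr(eta - Y @ A)
    M = np.block([[A, -R.T], [R, np.zeros((d, d))]])
    U = np.hstack([Y, Q])                          # n x 2d
    E = np.vstack([np.eye(d), np.zeros((d, d))])   # I_{2d, d}
    Z = expm(t * M)
    return U @ Z @ E, U @ Z @ M @ E, U @ Z @ M @ M @ E


# Usage: random point Y and tangent vector eta (Y^T eta antisymmetric), n >= 2d.
rng = np.random.default_rng(1)
n, d = 6, 2
Y, _ = np.linalg.qr(rng.standard_normal((n, d)))
W = rng.standard_normal((d, d))
Z0 = rng.standard_normal((n, d))
eta = Y @ (W - W.T) + Z0 - Y @ (Y.T @ Z0)
g, dg, ddg = canonical_geodesic(Y, eta, 0.7)
```

The curve stays on the manifold because \(M\) is antisymmetric, so \(e^{tM}\) is orthogonal.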
- make_ar(a, r)
Lift a tangent vector to the manifold at \(I_{n,d}\), specified by a and r, to a square matrix ar: the lifted horizontal vector at \(I_n\in SO(n)\).
- parallel(x, xi, eta, t)
Parallel transport of eta along the geodesic with initial point x and initial velocity xi, up to time t. The exponential action is computed using expv, with our customized estimate of the 1-norm of the operator P.
- Parameters:
x – a point on the manifold
xi – the initial velocity of the geodesic
eta – the vector to be transported
t – time.
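The 1-norm of an operator such as \(P\), which acts on \(n\times d\) matrices rather than vectors, can be estimated by wrapping it as a linear operator on flattened vectors. The library uses its own customized estimate (one_norm_est); this sketch swaps in SciPy's generic block estimator `scipy.sparse.linalg.onenormest`, and uses a hypothetical stand-in map \(b \mapsto Xb - bW\). The helper name `matrix_map_operator` is hypothetical.

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, onenormest


def matrix_map_operator(X, W, n, d):
    """Wrap the (hypothetical) map b -> X b - b W on n x d matrices as a
    LinearOperator on flattened vectors of length n*d."""
    def matvec(v):
        b = v.reshape(n, d)
        return (X @ b - b @ W).ravel()

    def rmatvec(v):
        # Adjoint under the Frobenius inner product: c -> X^T c - c W^T.
        c = v.reshape(n, d)
        return (X.T @ c - c @ W.T).ravel()

    return LinearOperator((n * d, n * d), matvec=matvec, rmatvec=rmatvec, dtype=float)


rng = np.random.default_rng(2)
n, d = 5, 2
X = rng.standard_normal((n, n))
W = rng.standard_normal((d, d))
op = matrix_map_operator(X, W, n, d)
est = onenormest(op)                                       # estimate from a few products
exact = np.abs(op.matmat(np.eye(n * d))).sum(axis=0).max()  # explicit, small case only
```

The estimator only needs products with the operator and its adjoint, so the explicit \(nd\times nd\) matrix is never formed except for the comparison.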
Flag
\(Flag\): Flag manifold. Quotient of \(\mathrm{St}((n, d), \alpha)\) by a block diagonal group. For \(\alpha=\frac{1}{2}\), we have an efficient formula for parallel transport.
- par_trans.manifolds.flag.solve_w(b, ar, flg, t, tol=None)
The exponential action \(expv(tP_{ar}, b)\) when the metric is given by the parameter \(\alpha\). The calculation uses the 1-norm estimate in the local function one_norm_est.
- class par_trans.manifolds.flag.Flag(dvec, alpha=0.5)
\(Flag(\vec{d})\) with a homogeneous metric defined by a parameter, realized as a quotient of a Stiefel manifold.
- Parameters:
dvec – the vector \(\vec{d}\) defining the flag
alpha – the metric is \(tr \eta^{T}\eta+(\alpha-1)tr\eta^TYY^T\eta\).
For ease of implementation, \(d_{p+1}\) is renamed d[0] and stored at the top (first entry) of dvec.
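The dvec convention and the quotient by a block-diagonal group can be illustrated with a small sketch. Assumptions, not confirmed by this page: the flag has block sizes \(d_1, \dots, d_p\), the Stiefel factor has \(d = d_1 + \dots + d_p\) columns, \(d_{p+1} = n - d\), and the group is \(O(d_1)\times\dots\times O(d_p)\). The helpers `make_dvec` and `random_block_orthogonal` are hypothetical; how par_trans actually builds dvec may differ.

```python
import numpy as np


def make_dvec(n, dims):
    """Hypothetical helper: block sizes (d_1, ..., d_p) of a flag in R^n,
    with d_{p+1} = n - (d_1 + ... + d_p) stored first, as dvec[0]."""
    d_last = n - sum(dims)
    if d_last < 0:
        raise ValueError("block sizes exceed n")
    return np.array([d_last, *dims])


def random_block_orthogonal(dims, rng):
    """Random block-diagonal matrix diag(U_1, ..., U_p) with U_i in O(d_i)."""
    size = sum(dims)
    g = np.zeros((size, size))
    start = 0
    for k in dims:
        q, _ = np.linalg.qr(rng.standard_normal((k, k)))
        g[start:start + k, start:start + k] = q
        start += k
    return g


# Usage: Y on St(6, 3) and g in O(1) x O(2). Y g is again on St(6, 3), and its
# column blocks span the same subspaces, so Y and Y g give the same flag.
rng = np.random.default_rng(3)
dims = [1, 2]
dvec = make_dvec(6, dims)  # array([3, 1, 2])
Y, _ = np.linalg.qr(rng.standard_normal((6, sum(dims))))
g = random_block_orthogonal(dims, rng)
Yg = Y @ g
```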
- dexp(x, v, t, ddexp=False)
The Riemannian exponential (the geodesic \(\gamma\)) and its higher time derivatives.
- Parameters:
x – the initial point \(\gamma(0)\)
v – the initial velocity \(\dot{\gamma}(0)\)
t – time.
If ddexp is False, we return \(\gamma(t), \dot{\gamma}(t)\). Otherwise, we return \(\gamma(t), \dot{\gamma}(t), \ddot{\gamma}(t)\).
- make_ar(a, r)
Lift a tangent vector to the manifold at \(I_{n,d}\), specified by a and r, to a square matrix ar: the lifted horizontal vector at \(I_n\in SO(n)\).
- parallel_canonical_expm_multiply(x, xi, eta, t)
Parallel transport; only works for alpha = 0.5. The exponential action is computed using expv, with our customized estimate of the 1-norm of the operator P.
- Parameters:
x – a point on the manifold
xi – the initial velocity of the geodesic
eta – the vector to be transported
t – time.
\(\mathrm{GL}^+(n)\)
\(GL^+\): the identity component of the general linear group (matrices with positive determinant), with a Cheeger deformation metric.
- class par_trans.manifolds.glp_beta.GLpBeta(n, beta)
\(GL^+\) with left invariant metric defined by a parameter.
- Parameters:
n – the size of the matrix
beta – \((\beta_0, \beta_1)\): the metric at the identity is \(\beta_0tr \mathtt{g}^2 -\beta_1 tr\mathtt{g}_{\mathfrak{a}}^2\).
- dexp(x, v, t, ddexp=False)
The Riemannian exponential (the geodesic \(\gamma\)) and its higher time derivatives.
- Parameters:
x – the initial point \(\gamma(0)\)
v – the initial velocity \(\dot{\gamma}(0)\)
t – time.
If ddexp is False, we return \(\gamma(t), \dot{\gamma}(t)\). Otherwise, we return \(\gamma(t), \dot{\gamma}(t), \ddot{\gamma}(t)\).
\(\mathrm{SO}(n)\)
\(SO\): Special Orthogonal group with a Cheeger deformation metric.
- class par_trans.manifolds.so_alpha.SOAlpha(n, k, alpha)
\(SO\) with left invariant metric defined by a parameter.
- Parameters:
n – the size of the matrix
alpha – the metric at the identity is \(-\frac{1}{2}tr \mathtt{g}^2-\frac{2\alpha-1}{2}tr\mathtt{g}_{\mathfrak{a}}^2\)
- christoffel_gamma_lie(x, xi, eta)
Evaluate the Christoffel symbols in Lie bracket form.
- dexp(x, v, t, ddexp=False)
The Riemannian exponential (the geodesic \(\gamma\)) and its higher time derivatives.
- Parameters:
x – the initial point \(\gamma(0)\)
v – the initial velocity \(\dot{\gamma}(0)\)
t – time.
If ddexp is False, we return \(\gamma(t), \dot{\gamma}(t)\). Otherwise, we return \(\gamma(t), \dot{\gamma}(t), \ddot{\gamma}(t)\).
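When \(\alpha = \frac{1}{2}\) the deformation term vanishes and the metric at the identity, \(-\frac{1}{2}tr\,\mathtt{g}^2\), is bi-invariant, so geodesics are left translates of one-parameter subgroups: \(\gamma(t) = x\,\mathrm{expm}(t\,x^Tv)\). The sketch below covers only that special case (general \(\alpha\) needs the library's formula), and the name `so_geodesic_half` is hypothetical.

```python
import numpy as np
from scipy.linalg import expm


def so_geodesic_half(x, v, t):
    """gamma, gamma', gamma'' on SO(n) for alpha = 1/2 (bi-invariant metric).

    gamma(t) = x expm(t w) with w = x^T v antisymmetric, so
    gamma'(t) = gamma(t) w and gamma''(t) = gamma(t) w^2.
    """
    w = x.T @ v
    g = x @ expm(t * w)
    return g, g @ w, g @ w @ w


rng = np.random.default_rng(4)
n = 4
x, _ = np.linalg.qr(rng.standard_normal((n, n)))
if np.linalg.det(x) < 0:
    x[:, 0] = -x[:, 0]            # flip one column so that det(x) = +1
S = rng.standard_normal((n, n))
v = x @ (S - S.T)                 # tangent vector at x
g, dg, ddg = so_geodesic_half(x, v, 0.9)
```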
par_trans.expv
Expv
Compute the action of the matrix exponential. This module is taken from SciPy; the one difference is the added option use_frag_31, which lets us check whether the more involved variant of the algorithm, which requires estimating 1-norms of higher powers of the matrix, makes much of a difference. This more involved estimate is one of the reasons the algorithm has not yet been adapted for JAX. We use this module to show that, in our case, it is sufficient to use only the 1-norm.
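In Al-Mohy and Higham's algorithm, the "higher norms" are the quantities \(d_p = \|A^p\|_1^{1/p}\), which Fragment 3.1 estimates for several \(p\). The sketch below computes them exactly for a small dense matrix to show why they can matter: since \(\|A^p\|_1 \le \|A\|_1^p\), every \(d_p\) is at most \(\|A\|_1\), and for non-normal matrices the gap can be large, which may allow fewer sub-steps. The helper name `power_norms` is hypothetical; this is not the module's code.

```python
import numpy as np


def power_norms(A, p_max):
    """Exact d_p = ||A^p||_1^(1/p) for p = 1..p_max (small dense A only)."""
    out = []
    Ap = np.eye(A.shape[0])
    for p in range(1, p_max + 1):
        Ap = Ap @ A                                # A^p
        out.append(np.linalg.norm(Ap, 1) ** (1.0 / p))
    return np.array(out)


# Nilpotent example: ||A||_1 = 100, but A^2 = 0.
A = np.array([[0.0, 100.0], [0.0, 0.0]])
dp = power_norms(A, 3)  # d_1 = 100, d_2 = d_3 = 0
```

Estimating these for a large operator requires the more expensive norm estimator, which is the cost use_frag_31 lets one switch on or off.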
- par_trans.utils.expm_multiply_np.expm_multiply(A, B, start=None, stop=None, num=None, endpoint=None, traceA=None, use_frag_31=True)
Compute the action of the matrix exponential of A on B, using the algorithm described in [1] and [2].
Parameters
- A : transposable linear operator
The operator whose exponential is of interest.
- B : ndarray
The matrix or vector to be multiplied by the matrix exponential of A.
- start : scalar, optional
The starting time point of the sequence.
- stop : scalar, optional
The end time point of the sequence, unless endpoint is set to False. In that case, the sequence consists of all but the last of num + 1 evenly spaced time points, so that stop is excluded. Note that the step size changes when endpoint is False.
- num : int, optional
Number of time points to use.
- endpoint : bool, optional
If True, stop is the last time point. Otherwise, it is not included.
- traceA : scalar, optional
Trace of A. If not given, the trace is estimated for linear operators, or calculated exactly for sparse matrices. It is used to precondition A, so an approximate trace is acceptable. For linear operators, traceA should be provided to ensure performance, as the estimation is not guaranteed to be reliable in all cases.
- use_frag_31 : bool, optional
Whether to use the high-p branch (Fragment 3.1), which requires estimating 1-norms of higher powers of A. If False, only the 1-norm of A is used.
Returns
- expm_A_B : ndarray
The result of the action \(e^{t_k A} B\).
Warns
- UserWarning
If A is a linear operator and traceA=None (default).
Notes
The optional arguments defining the sequence of evenly spaced time points are compatible with the arguments of numpy.linspace.
The output ndarray shape is somewhat complicated so I explain it here. The ndim of the output could be either 1, 2, or 3. It would be 1 if you are computing the expm action on a single vector at a single time point. It would be 2 if you are computing the expm action on a vector at multiple time points, or if you are computing the expm action on a matrix at a single time point. It would be 3 if you want the action on a matrix with multiple columns at multiple time points. If multiple time points are requested, expm_A_B[0] will always be the action of the expm at the first time point, regardless of whether the action is on a vector or a matrix.
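The shape rules above can be checked with SciPy's `scipy.sparse.linalg.expm_multiply`, from which this module is taken. Assuming the par_trans copy keeps the same interface (its signature above adds only use_frag_31), the shapes should match.

```python
import numpy as np
from scipy.sparse.linalg import expm_multiply

A = np.array([[0.0, -1.0], [1.0, 0.0]])  # dense 2 x 2 generator
v = np.array([1.0, 0.0])                 # a single vector
B = np.eye(2)                            # a matrix with two columns

one_vec = expm_multiply(A, v)                                          # ndim 1
many_vec = expm_multiply(A, v, start=0, stop=1, num=4, endpoint=True)  # ndim 2
one_mat = expm_multiply(A, B)                                          # ndim 2
many_mat = expm_multiply(A, B, start=0, stop=1, num=4, endpoint=True)  # ndim 3

print(one_vec.shape, many_vec.shape, one_mat.shape, many_mat.shape)
# (2,) (4, 2) (2, 2) (4, 2, 2)
```

With start=0, the first entry along the time axis is the action at \(t = 0\), that is, B itself.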
References
Examples
>>> import numpy as np
>>> from scipy.sparse import csc_matrix
>>> from scipy.sparse.linalg import expm, expm_multiply
>>> A = csc_matrix([[1, 0], [0, 1]])
>>> A.toarray()
array([[1, 0],
       [0, 1]], dtype=int64)
>>> B = np.array([np.exp(-1.), np.exp(-2.)])
>>> B
array([ 0.36787944, 0.13533528])
>>> expm_multiply(A, B, start=1, stop=2, num=3, endpoint=True)
array([[ 1. , 0.36787944],
       [ 1.64872127, 0.60653066],
       [ 2.71828183, 1. ]])
>>> expm(A).dot(B) # Verify 1st timestep
array([ 1. , 0.36787944])
>>> expm(1.5*A).dot(B) # Verify 2nd timestep
array([ 1.64872127, 0.60653066])
>>> expm(2*A).dot(B) # Verify 3rd timestep
array([ 2.71828183, 1. ])
Utils
Common util functions