The JAX version. It supports JAX automatic differentiation (AD) and can execute on GPU.
jax_par_trans.manifolds
Stiefel
\(St\): Stiefel manifold.
- class jax_par_trans.manifolds.stiefel.Stiefel(n, d, alpha)[source]
\(\mathrm{St}_{n,d}\) with an invariant metric defined by a parameter.
- Parameters:
n, d – the size of the matrix: points on the manifold are \(n\times d\) matrices
alpha – the metric is \(\operatorname{tr}(\eta^{T}\eta)+(\alpha-1)\operatorname{tr}(\eta^{T}YY^{T}\eta)\).
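The metric formula above can be evaluated directly. The following is a minimal sketch, not the library's implementation: the function name `metric_norm_sq` is hypothetical, and it assumes \(Y\) is an \(n\times d\) matrix with orthonormal columns and \(\eta\) is a matrix of the same shape.

```python
import jax.numpy as jnp


def metric_norm_sq(Y, eta, alpha):
    """Evaluate tr(eta^T eta) + (alpha - 1) tr(eta^T Y Y^T eta).

    Hypothetical helper for illustration; it is not part of jax_par_trans.
    """
    yt_eta = Y.T @ eta  # d x d block of eta along Y
    # tr(M^T M) equals the sum of squared entries of M.
    return jnp.sum(eta * eta) + (alpha - 1.0) * jnp.sum(yt_eta * yt_eta)


n, d = 4, 2
Y = jnp.eye(n)[:, :d]                      # the base point I_{n,d}
a = jnp.array([[0.0, -1.0], [1.0, 0.0]])   # skew-symmetric d x d block
r = jnp.array([[2.0, 0.0], [0.0, 3.0]])    # (n - d) x d block
eta = jnp.concatenate([a, r], axis=0)      # tangent vector at I_{n,d}

norm_sq_euclid = metric_norm_sq(Y, eta, 1.0)
norm_sq_half = metric_norm_sq(Y, eta, 0.5)
```

With alpha = 1 the second term vanishes and the value reduces to the squared Frobenius norm of \(\eta\); with alpha = 0.5 the block \(Y^{T}\eta\) is weighted by one half.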
- dexp(x, v, t, ddexp=False)[source]
Higher derivatives of the exponential map.
- Parameters:
x – the initial point \(\gamma(0)\)
v – the initial velocity \(\dot{\gamma}(0)\)
t – time.
If ddexp is False, we return \(\gamma(t), \dot{\gamma}(t)\). Otherwise, we return \(\gamma(t), \dot{\gamma}(t), \ddot{\gamma}(t)\).
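The return convention above (point and velocity, optionally acceleration) can be mimicked for any smooth curve with forward-mode AD. This is a sketch under stated assumptions: `curve_with_derivatives` is a hypothetical helper, and the curve \(t\mapsto \exp(tA)\,x\) with skew-symmetric \(A\) is only a convenient curve on the Stiefel manifold, not claimed to be a geodesic of the library's metric or the way `dexp` is computed.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm


def curve_with_derivatives(curve, t, second=False):
    """Return curve(t), curve'(t) and, if second is True, curve''(t).

    Hypothetical helper: derivatives come from nested jax.jvp calls.
    """
    one = jnp.ones_like(t)
    point, vel = jax.jvp(curve, (t,), (one,))
    if not second:
        return point, vel

    def velocity(s):
        # First derivative as a function of s, so it can be differentiated again.
        return jax.jvp(curve, (s,), (jnp.ones_like(s),))[1]

    _, acc = jax.jvp(velocity, (t,), (one,))
    return point, vel, acc


A = jnp.array([[0.0, -1.0, 0.5],
               [1.0, 0.0, -0.3],
               [-0.5, 0.3, 0.0]])          # skew-symmetric generator
x = jnp.eye(3)[:, :2]                      # start point I_{3,2}


def sample_curve(t):
    return expm(t * A) @ x


t = jnp.asarray(0.3)
point, vel, acc = curve_with_derivatives(sample_curve, t, second=True)
```

For this particular curve the derivatives are known in closed form, \(\dot{\gamma}(t)=A\gamma(t)\) and \(\ddot{\gamma}(t)=A^{2}\gamma(t)\), which gives a simple consistency check on the AD output.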
- make_ar(a, r)[source]
Lift the tangent vector given by a and r at \(I_{n,d}\) to a square matrix: the lifted horizontal vector at \(I_n\in SO(n)\).
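To illustrate what such a lift looks like, here is a sketch that assumes the plain block embedding \(\begin{bmatrix} a & -r^{T} \\ r & 0 \end{bmatrix}\), where a is a skew-symmetric \(d\times d\) block and r is an \((n-d)\times d\) block. The name `lift_to_skew` is hypothetical, and the library's `make_ar` may scale the blocks differently (for example by a factor depending on alpha), so treat this as an illustration rather than its definition.

```python
import jax.numpy as jnp


def lift_to_skew(a, r):
    """Embed the blocks (a, r) of a tangent vector at I_{n,d} into an n x n
    skew-symmetric matrix. Hypothetical helper; any alpha-dependent scaling
    used by the library is NOT reproduced here.
    """
    k = r.shape[0]                                  # k = n - d
    zeros = jnp.zeros((k, k), dtype=a.dtype)
    return jnp.block([[a, -r.T], [r, zeros]])


a = jnp.array([[0.0, -1.0], [1.0, 0.0]])            # skew-symmetric, d = 2
r = jnp.array([[2.0, 0.5]])                         # (n - d) x d with n = 3
lifted = lift_to_skew(a, r)

# Multiplying by I_{n,d} (taking the first d columns) recovers the tangent vector.
tangent = lifted[:, :2]
```

The lifted matrix is skew-symmetric, so it lies in the Lie algebra of \(SO(n)\), and its first d columns reproduce the stacked blocks a and r.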
Flag
\(Flag\): Flag manifold.
- class jax_par_trans.manifolds.flag.Flag(dvec, alpha=0.5)[source]
\(Flag(\vec{d})\) with a homogeneous metric defined by a parameter. Realized as a quotient of a Stiefel manifold.
- Parameters:
alpha – the metric is \(\operatorname{tr}(\eta^{T}\eta)+(\alpha-1)\operatorname{tr}(\eta^{T}YY^{T}\eta)\).
For ease of implementation, \(d_{p+1}\) is renamed d[0] and stored as the first entry of dvec.
- dexp(x, v, t, ddexp=False)[source]
Higher derivatives of the exponential map.
- Parameters:
x – the initial point \(\gamma(0)\)
v – the initial velocity \(\dot{\gamma}(0)\)
t – time.
If ddexp is False, we return \(\gamma(t), \dot{\gamma}(t)\). Otherwise, we return \(\gamma(t), \dot{\gamma}(t), \ddot{\gamma}(t)\).
- make_ar(a, r)[source]
Lift the tangent vector given by a and r at \(I_{n,d}\) to a square matrix: a horizontal vector at \(I_n\in SO(n)\).
- parallel_canonical(x, xi, eta, t)[source]
Parallel transport. Only works for alpha = 0.5. The exponential action is computed using expv, with a customized estimate of the 1-norm of the operator P.
- Parameters:
x – a point on the manifold
xi – the initial velocity of the geodesic
eta – the vector to be transported
t – time.
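The roles of x, xi, eta and t can be illustrated on the simplest related case, the unit sphere (a Stiefel manifold with d = 1), where parallel transport along a great circle has a closed form. This sketch is not the algorithm used by `parallel_canonical`, which relies on expv; the function name `sphere_parallel_transport` is hypothetical, and xi is assumed to be nonzero and tangent at x.

```python
import jax.numpy as jnp


def sphere_parallel_transport(x, xi, eta, t):
    """Transport eta along the great circle starting at x with velocity xi.

    Hypothetical helper for the unit sphere only; assumes xi != 0,
    x has unit norm, and xi, eta are orthogonal to x.
    """
    speed = jnp.linalg.norm(xi)
    u = xi / speed                      # unit direction of the geodesic
    s = t * speed                       # arc length travelled
    coef = jnp.dot(u, eta)              # component of eta along the geodesic
    # The component along u rotates with the geodesic; the rest is unchanged.
    return eta + coef * ((jnp.cos(s) - 1.0) * u - jnp.sin(s) * x)


x = jnp.array([1.0, 0.0, 0.0])          # a point on the sphere
xi = jnp.array([0.0, 2.0, 0.0])         # initial velocity of the geodesic
eta = jnp.array([0.0, 1.0, 3.0])        # the vector to be transported
t = 0.7

transported = sphere_parallel_transport(x, xi, eta, t)
speed = jnp.linalg.norm(xi)
end_point = jnp.cos(t * speed) * x + jnp.sin(t * speed) * xi / speed
```

Two properties are worth checking in any parallel transport routine: the transported vector stays tangent at the end point of the geodesic, and its norm is preserved.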
jax_par_trans.expv
Expv
Defines a class of linear operators with an exponential action method.
- class jax_par_trans.expv.expv.LinearOperator(params=None)[source]
A class of linear operators with a method for exponential action. The operator acts on a vector space via dot. The exponential action expv(t*self, b) is provided through the method expv. To use the operator, define a derived class of LinearOperator that supplies an estimate of the 1-norm.
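The pattern described above (an operator with a dot method, a 1-norm estimate, and an exponential action) can be sketched as follows. Everything here is an assumption made for illustration: the class `SketchOperator`, its method names and the signature `expv(b, t)` are hypothetical and do not reproduce the library's `LinearOperator` interface, and a scaled truncated Taylor series stands in for whatever algorithm the library uses.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm


class SketchOperator:
    """Hypothetical stand-in for an operator with an exponential action."""

    def __init__(self, matrix):
        self.matrix = matrix

    def dot(self, b):
        # Action of the operator on a vector.
        return self.matrix @ b

    def one_norm_est(self):
        # Exact 1-norm here (maximum absolute column sum); a matrix-free
        # operator would supply an estimate instead.
        return jnp.max(jnp.sum(jnp.abs(self.matrix), axis=0))

    def expv(self, b, t=1.0, terms=20):
        """Approximate exp(t * self) applied to b using only dot."""
        # Split t into `steps` substeps so each has 1-norm at most about 1.
        # The Python int conversion means this sketch is not jit-compatible.
        steps = max(1, int(jnp.ceil(abs(t) * self.one_norm_est())))
        h = t / steps
        out = b
        for _ in range(steps):
            term = out
            acc = out
            for k in range(1, terms + 1):
                term = (h / k) * self.dot(term)   # (h A)^k out / k!
                acc = acc + term
            out = acc
        return out


A = jnp.array([[0.0, 1.0], [-2.0, -3.0]])
b = jnp.array([1.0, 0.5])
op = SketchOperator(A)
result = op.expv(b, t=2.0)
reference = expm(2.0 * A) @ b           # dense reference for comparison
```

The 1-norm matters because it controls how many substeps are needed for the truncated series to converge quickly, which is why a derived class is expected to supply it.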
Utils
Common utility functions.