The JAX version. This version supports automatic differentiation (AD) through JAX and can execute on GPU.

jax_par_trans.manifolds

Stiefel

\(St\): Stiefel manifold.

class jax_par_trans.manifolds.stiefel.Stiefel(n, d, alpha)[source]

\(\mathrm{St}_{n,d}\) with an invariant metric defined by a parameter.

Parameters:
  • n, d – the size of the matrix: points are \(n\times d\) matrices with orthonormal columns

  • alpha – the metric is \(tr \eta^{T}\eta+(\alpha-1)tr\eta^TYY^T\eta\).
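
The bilinear form behind this metric is obtained by polarizing the quadratic form: \(\langle\xi,\eta\rangle_Y = tr\,\xi^T\eta+(\alpha-1)tr\,\xi^TYY^T\eta\). Below is a minimal sketch of evaluating it. The helper name `stiefel_inner` is hypothetical and is not the library's `inner` method.

```python
import jax.numpy as jnp


def stiefel_inner(y, xi, eta, alpha):
    """Hypothetical sketch of the alpha-metric on St_{n,d} at the point y:
    tr(xi^T eta) + (alpha - 1) tr(xi^T y y^T eta)."""
    euclid = jnp.trace(xi.T @ eta)
    # y^T xi and y^T eta are the d x d components of the vectors along y
    correction = jnp.trace((y.T @ xi).T @ (y.T @ eta))
    return euclid + (alpha - 1.0) * correction
```

At \(Y=I_{n,d}\) with \(\xi=[a; r]\) (a antisymmetric), the squared norm is \(\alpha\|a\|^2+\|r\|^2\), so \(\alpha\) reweights only the block along \(Y\); \(\alpha=1\) recovers the Euclidean (embedded) metric.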

approx_nearest(q)[source]

point on the manifold that is approximately nearest to q
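
The documentation does not say how `approx_nearest` is computed. One standard way to map a full-rank \(n\times d\) matrix q onto \(\mathrm{St}_{n,d}\) is its orthonormal polar factor \(UV^T\) from a thin SVD, which is the nearest point in the Frobenius norm. The sketch below shows that choice under a hypothetical name; the library's approximation may use a different, cheaper scheme.

```python
import jax.numpy as jnp


def nearest_stiefel_point(q):
    """Hypothetical helper: orthonormal polar factor U V^T of an n x d matrix q.
    For full-rank q this is the Frobenius-nearest point on St_{n,d}."""
    u, _, vt = jnp.linalg.svd(q, full_matrices=False)
    return u @ vt
```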

christoffel_gamma(x, xi, eta)[source]

Christoffel function of the manifold

dexp(x, v, t, ddexp=False)[source]

The exponential map together with its higher time derivatives.

Parameters:
  • x – the initial point \(\gamma(0)\)

  • v – the initial velocity \(\dot{\gamma}(0)\)

  • t – time.

If ddexp is False, we return \(\gamma(t), \dot{\gamma}(t)\). Otherwise, we return \(\gamma(t), \dot{\gamma}(t), \ddot{\gamma}(t)\).
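
As an illustration of how JAX's forward-mode AD can supply \(\dot{\gamma}(t)\), here is a sketch in the spirit of `dexp`; it is not the library's implementation. It assumes the base point \(I_{n,d}\), an initial velocity \([a; r]\) with a antisymmetric, and the closed-form geodesic \(\gamma(t)=\exp\left(t\begin{bmatrix}2\alpha a & -r^T\\ r & 0\end{bmatrix}\right)I_{n,d}\exp(t(1-2\alpha)a)\) for this metric family (taken here as an assumption). The function names are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm


def geodesic_at_identity(a, r, t, alpha):
    """Hypothetical sketch: geodesic on St_{n,d} from I_{n,d} with initial
    velocity [a; r], a antisymmetric (d x d), r of shape (n-d) x d."""
    d, k = a.shape[0], r.shape[0]
    big = jnp.block([[2.0 * alpha * a, -r.T],
                     [r, jnp.zeros((k, k), dtype=a.dtype)]])
    i_nd = jnp.eye(d + k, d, dtype=a.dtype)
    return expm(t * big) @ i_nd @ expm(t * (1.0 - 2.0 * alpha) * a)


def dexp_sketch(a, r, t, alpha):
    """Return (gamma(t), gamma'(t)) by differentiating in t with jax.jvp."""
    t = jnp.asarray(t, dtype=a.dtype)
    f = lambda s: geodesic_at_identity(a, r, s, alpha)
    return jax.jvp(f, (t,), (jnp.ones_like(t),))
```

Applying `jax.jvp` once more to the velocity as a function of t would give \(\ddot{\gamma}(t)\), analogous to `ddexp=True`.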

exp(x, v)[source]

geodesic, or Riemannian exponential

inner(x, xi, eta)[source]

Inner product

make_ar(a, r)[source]

lift the pair (a, r), which represents a tangent vector to the manifold at \(I_{n,d}\), to a square matrix: the lifted horizontal vector at \(I_n\in SO(n)\).
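
A minimal sketch of such a lift, assuming the common block convention in which a is \(d\times d\) antisymmetric, r is \((n-d)\times d\), and the lift is the antisymmetric matrix \(\begin{bmatrix}a & -r^T\\ r & 0\end{bmatrix}\). Whether `make_ar` applies any \(\alpha\)-dependent scaling to the a block is not stated in this documentation, so treat the convention as an assumption; `make_ar_sketch` is a hypothetical name.

```python
import jax.numpy as jnp


def make_ar_sketch(a, r):
    """Hypothetical sketch: lift [a; r] at I_{n,d} to the n x n
    antisymmetric matrix [[a, -r^T], [r, 0]]."""
    k = r.shape[0]
    return jnp.block([[a, -r.T],
                      [r, jnp.zeros((k, k), dtype=r.dtype)]])
```

Multiplying the lift by \(I_{n,d}\) keeps its first d columns and recovers the tangent vector \([a; r]\).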

name()[source]

name of the object

parallel(x, xi, eta, t)[source]

parallel transport. The exponential action is computed using expv, with our customized estimate of the 1-norm of the operator P.

Parameters:
  • x – a point on the manifold

  • xi – the initial velocity of the geodesic

  • eta – the vector to be transported

  • t – time.

rand_ambient(key)[source]

random ambient vector

rand_point(key)[source]

A random point on the manifold

rand_vec(key, x)[source]

A random tangent vector to the manifold at x

retract(x, v)[source]

second order retraction.

class jax_par_trans.manifolds.stiefel.StiefelParallelOperator(params)[source]

Defines the operator P used in parallel transport on Stiefel manifolds.

set_params(params)[source]

Override the params supplied in the constructor. This avoids creating new objects repeatedly.

Flag

\(Flag\): Flag manifold.

class jax_par_trans.manifolds.flag.Flag(dvec, alpha=0.5)[source]

\(Flag(\vec{d})\) with a homogeneous metric defined by a parameter. Realized as a quotient of a Stiefel manifold.

Parameters:
  • dvec – the vector of dimensions \(\vec{d}\)

  • alpha – the metric is \(tr \eta^{T}\eta+(\alpha-1)tr\eta^TYY^T\eta\).

For ease of implementation, \(d_{p+1}\) is renamed d[0] and stored at the front of dvec.

approx_nearest(q)[source]

point on the manifold that is approximately nearest to q

christoffel_gamma(x, xi, eta)[source]

function representing the Christoffel symbols

dexp(x, v, t, ddexp=False)[source]

The exponential map together with its higher time derivatives.

Parameters:
  • x – the initial point \(\gamma(0)\)

  • v – the initial velocity \(\dot{\gamma}(0)\)

  • t – time.

If ddexp is False, we return \(\gamma(t), \dot{\gamma}(t)\). Otherwise, we return \(\gamma(t), \dot{\gamma}(t), \ddot{\gamma}(t)\).

exp(x, v)[source]

geodesic, or Riemannian exponential

inner(x, xi, eta)[source]

Inner product

make_ar(a, r)[source]

lift the pair (a, r), which represents a tangent vector to the manifold at \(I_{n,d}\), to a square matrix: a horizontal vector at \(I_n\in SO(n)\).

name()[source]

name of the object

parallel_canonical(x, xi, eta, t)[source]

parallel transport. Only works for alpha = .5. The exponential action is computed using expv, with our customized estimate of the 1-norm of the operator P.

Parameters:
  • x – a point on the manifold

  • xi – the initial velocity of the geodesic

  • eta – the vector to be transported

  • t – time.

proj(x, omg)[source]

projection to the tangent bundle

proj_m(omg)[source]

projection to horizontal space

rand_ambient(key)[source]

random ambient vector

rand_point(key)[source]

A random point on the manifold

rand_vec(key, x)[source]

A random tangent vector at x

retract(x, v)[source]

second order retraction

symf(omg)[source]

symmetrize but keep diagonal blocks unchanged

class jax_par_trans.manifolds.flag.FlagCanonicalParallelOperator(params)[source]

Implements expv of the Flag parallel operator, for alpha = .5.

set_params(params)[source]

Override the params supplied in the constructor. This avoids creating new objects repeatedly.

jax_par_trans.expv

Expv

define a class of linear operators with an exponential action method.

class jax_par_trans.expv.expv.LinearOperator(params=None)[source]

A class of linear operators with a method for the exponential action. The operator acts on a vector space via dot. The exponential action exp(t*self)b is provided by the method expv. To use the operator, define a derived class of LinearOperator that supplies an estimate of the 1-norm.
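
The derived-class pattern can be sketched as follows. The class names are hypothetical, and since the documentation does not describe the algorithm inside expv, the sketch swaps in a simple scheme: split [0, t] into s steps chosen from the 1-norm estimate, then apply a truncated Taylor series of exp(h*self) on each step (a simplified relative of scaling-based expmv methods, not the library's code).

```python
import math

import jax.numpy as jnp


class LinearOperatorSketch:
    """Hypothetical stand-in for the documented pattern: subclasses supply
    dot() and one_norm_est(); expv() is built from them."""

    def __init__(self, params=None):
        self.params = params

    def set_params(self, params):
        # Reuse the same object with new parameters.
        self.params = params

    def dot(self, b):
        raise NotImplementedError

    def one_norm_est(self):
        raise NotImplementedError

    def expv(self, b, t, tol=1e-8, max_terms=60):
        """Approximate exp(t*self) b by s Taylor-truncated substeps."""
        # Choose s so that |h| * ||self||_1 <= 1 on each substep.
        s = max(1, math.ceil(abs(float(t)) * self.one_norm_est()))
        h = t / s
        f = b
        for _ in range(s):
            term, acc = f, f
            for k in range(1, max_terms + 1):
                term = (h / k) * self.dot(term)  # next Taylor term
                acc = acc + term
                if float(jnp.linalg.norm(term)) <= tol * float(jnp.linalg.norm(acc)):
                    break
            f = acc
        return f


class MatrixOperator(LinearOperatorSketch):
    """Example subclass: params is a square matrix M acting by M @ b."""

    def dot(self, b):
        return self.params @ b

    def one_norm_est(self):
        # Exact 1-norm (max absolute column sum) for a small dense matrix.
        return float(jnp.max(jnp.sum(jnp.abs(self.params), axis=0)))
```

The Python loops and float() conversions keep the sketch short but prevent tracing under jax.jit; a jit-friendly version would need lax control flow and a fixed step count.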

dot(b)[source]

Apply the operator to b.

expv(b, t, tol=None)[source]

Exponential action on b. Returns exp(t*self)b.

one_norm_est()[source]

Estimate the 1-norm

set_params(params)[source]

Override the params supplied in the constructor. This avoids creating new objects repeatedly.

Utils

Common util functions

jax_par_trans.expv.utils.asym(a)[source]

antisymmetrize

jax_par_trans.expv.utils.cz(a)[source]

check if zero

jax_par_trans.expv.utils.grand(key, shape)[source]

random array of the given shape, generated from key

jax_par_trans.expv.utils.hcat(x, y)[source]

horizontal concatenate

jax_par_trans.expv.utils.lie(a, b)[source]

Lie bracket

jax_par_trans.expv.utils.sym(a)[source]

symmetrize

jax_par_trans.expv.utils.vcat(x, y)[source]

vertical concatenate
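
Possible one-line definitions of several of these helpers are sketched below. They are assumptions consistent with the descriptions above (in particular the factor 1/2 in sym and asym), not the library's source; cz and grand are omitted because their exact conventions are not stated.

```python
import jax.numpy as jnp


def sym(a):
    # Assumed convention: symmetric part (a + a^T) / 2
    return 0.5 * (a + a.T)


def asym(a):
    # Assumed convention: antisymmetric part (a - a^T) / 2
    return 0.5 * (a - a.T)


def lie(a, b):
    # Matrix commutator [a, b] = ab - ba
    return a @ b - b @ a


def hcat(x, y):
    # Stack side by side (along columns)
    return jnp.concatenate([x, y], axis=1)


def vcat(x, y):
    # Stack top to bottom (along rows)
    return jnp.concatenate([x, y], axis=0)
```

With these conventions every square matrix splits as a = sym(a) + asym(a).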