"""Hypersurface with a constraint of the form :math:`\\sum_i d_i x_i^p = 1`
"""
import jax.numpy as jnp
import jax.numpy.linalg as jla
from .global_manifold import GlobalManifold
from ..utils.utils import (grand)
class DiagHypersurface(GlobalManifold):
    """Hypersurface of the form :math:`\\sum_i d_ix_i^p = 1`.

    Points and tangent vectors are one-dimensional arrays with the same
    shape as ``dvec``. The metric is the one induced by the ambient
    Euclidean space (``g_metric`` is the identity).

    :param dvec: vector :math:`d_i` of coefficients.
       Sort dvec so dvec[-1] is positive.
    :param p: :math:`p > 0` is an integer, degree of the constraint.
    """

    def __init__(self, dvec, p):
        self.dvec = dvec
        self.shape = dvec.shape
        self.p = p
        # One scalar constraint, so the manifold has one dimension less
        # than the ambient space.
        self.dim = dvec.shape[0]-1

    def name(self):
        """Short label built from the manifold dimension and the degree."""
        return f"DH{self.shape[0]-1}, {self.p}"

    def g_metric(self, x, omg):
        """Metric operator at x; the identity for the embedded metric."""
        return omg

    def inv_g_metric(self, x, omg):
        """Inverse metric operator at x; the identity as well."""
        return omg

    def inner(self, x, a, b):
        """Inner product of a and b: the Euclidean one, independent of x."""
        return jnp.sum(a*b)

    def cfunc(self, x):
        """Constraint function; the surface is the set cfunc(x) = 1.

        :param x: ambient point.
        :returns: :math:`\\sum_i d_i x_i^p`.
        """
        return jnp.sum(self.dvec*x**self.p)

    def grad_c(self, x):
        """Gradient of cfunc at x, with entries :math:`p d_i x_i^{p-1}`."""
        return self.p*self.dvec*x**(self.p-1)

    def rand_point(self, key):
        """Random point on the manifold.

        An ambient sample is drawn with ``grand`` and then moved onto the
        surface: by a (signed) radial rescaling when that is possible, and
        otherwise, for even p with a negative constraint value, by solving
        the constraint for the last coordinate.

        :param key: random key passed to ``grand``.
        :returns: tuple ``(point, key)`` with the key returned by ``grand``.
        """
        p = self.p
        dvec = self.dvec
        x, key = grand(key, self.shape)
        val = self.cfunc(x)
        if p % 2 == 1:
            # Odd degree: cfunc(c*x) = c**p*cfunc(x) and c may be negative,
            # so the signed p-th root of val always gives a valid scaling.
            ret = x/jnp.abs(val)**(1/p)*jnp.sign(val)
        elif val < 0:
            # Even degree and val < 0: no real scaling reaches level 1.
            # Keep x[:-1] and solve d_n*x_n**p = 1 - sum_{i<n} d_i*x_i**p.
            # With dvec[-1] > 0 and val < 0 the right-hand side exceeds 1,
            # so the root is real.
            rest = 1 - jnp.sum(dvec[:-1]*x[:-1]**p)
            last = (rest/dvec[-1])**(1/p)
            ret = jnp.concatenate([x[:-1], jnp.array([last])])
        else:
            ret = x/val**(1/p)
        return ret, key

    def rand_vec(self, key, x):
        """Random tangent vector at x.

        :param key: random key passed to ``grand``.
        :param x: point on the manifold.
        :returns: tuple ``(vector, key)``; the vector is an ambient sample
           pushed to the tangent space with ``proj_scale``.
        """
        omg, key = grand(key, self.shape)
        return self.proj_scale(x, omg), key

    def proj(self, x, omg):
        """Orthogonal projection of omg onto the tangent space at x.

        Removes the component of omg along grad_c(x).
        """
        gcx = self.grad_c(x)
        return omg - gcx*jnp.sum(gcx*omg)/jnp.sum(gcx*gcx)

    def approx_nearest(self, q):
        """Tubular retraction by radial rescaling.

        Need some work to show this is actually approx_nearest.

        :param q: ambient point with cfunc(q) != 0.
        :returns: the multiple of q that satisfies the constraint. For even
           p this requires cfunc(q) > 0 (otherwise the result is NaN).
        """
        val = self.cfunc(q)
        if self.p % 2 == 1:
            # Signed root keeps the result real when cfunc(q) < 0; it is
            # the same value as val**(1/p) when cfunc(q) > 0.
            return q/(jnp.sign(val)*jnp.abs(val)**(1/self.p))
        return q/val**(1/self.p)

    def retract(self, x, v):
        """Retraction of the tangent vector v at x.

        Takes the ambient step x + v, corrected by half of the rescale
        projection of gamma(x, v, v), then maps it back to the surface with
        approx_nearest.
        """
        correction = 0.5*self.proj_scale(x, self.gamma(x, v, v))
        return self.approx_nearest(x + v - correction)

    def proj_scale(self, x, omg):
        """Rescale projection onto the tangent space at x.

        Subtracts a multiple of x (rather than of the gradient) from omg.
        Since grad_c(x).x = p*cfunc(x) = p on the surface, the result is
        orthogonal to grad_c(x).
        """
        return omg - x*jnp.sum(self.dvec*x**(self.p-1)*omg)

    def gamma(self, x, xi, eta):
        """Christoffel function.

        :returns: the vector along grad_c(x) with coefficient
           :math:`p(p-1)\\sum_i d_i x_i^{p-2}\\xi_i\\eta_i / |\\nabla c|^2`.
        """
        p = self.p
        gcx = self.grad_c(x)
        hess_val = p*(p-1)*jnp.sum(self.dvec*x**(p-2)*xi*eta)
        return gcx*hess_val/jnp.sum(gcx*gcx)

    def ito_drift(self, x):
        """Ito drift at x.

        The expression is -1/2 times gamma summed over an orthonormal basis
        of the tangent space, written as the sum over the ambient basis
        minus the contribution of the unit normal.
        """
        p = self.p
        gcx = self.grad_c(x)
        gnorm2 = jnp.sum(gcx*gcx)
        diag_h = self.dvec*x**(p-2)
        return -0.5*p*(p-1)*gcx*(
            jnp.sum(diag_h) - jnp.sum(diag_h*gcx*gcx)/gnorm2
        )/gnorm2

    def pseudo_transport(self, x, y, v):
        """Move the vector v from x to a tangent vector at y.

        The result is v minus a combination of grad_c(x) and grad_c(y),
        chosen so that it is orthogonal to grad_c(y).
        """
        gcx = self.grad_c(x)
        gcy = self.grad_c(y)
        gcy_v = jnp.sum(gcy*v)
        gcx_gcy = jnp.sum(gcx*gcy)
        # Same coefficient as
        #   gcy_v*(|gcx||gcy| - gcx_gcy)/(|gcx|^2|gcy|^2 - gcx_gcy^2)
        # with the common factor (|gcx||gcy| - gcx_gcy) cancelled, so it
        # stays finite when the two gradients are parallel (e.g. x == y).
        a = gcy_v/(jla.norm(gcx)*jla.norm(gcy) + gcx_gcy)
        return v - a*gcx - (gcy_v - a*gcx_gcy)/jnp.sum(gcy*gcy)*gcy

    def sigma(self, x, dw):
        """Noise map at x; the identity on dw."""
        return dw

    def rtr_tan_scale(self, yv, dyv):
        """Retraction to the tangent bundle using the rescale projection.

        :param yv: array whose column 0 is a point and column 1 a tangent
           vector at that point.
        :param dyv: increment with the same layout.
        :returns: array with the same layout; the new point is
           retract(point, point increment) and the new vector is the rescale
           projection of the updated vector, normalised to keep the length
           of the incoming vector.
        """
        y0 = yv[:, 0]
        v0 = yv[:, 1]
        y1 = self.retract(y0, dyv[:, 0])
        v1 = self.proj_scale(y1, v0 + dyv[:, 1])
        # Measure the old vector at its own base point yv[:, 0].
        v1 = v1*jnp.sqrt(self.inner(y0, v0, v0)/self.inner(y1, v1, v1))
        return jnp.concatenate([y1[:, None], v1[:, None]], axis=1)

    def geodesic(self, x, v, t, nstep=100):
        """Approximate geodesic from x with initial velocity v.

        Integrates the system y' = v, v' = -gamma(y, v, v) with nstep
        Runge-Kutta (RK4) steps of size t/nstep, applying each increment
        through the tangent-bundle retraction rtr_tan_scale.

        :param x: starting point.
        :param v: initial tangent vector at x.
        :param t: total time.
        :param nstep: number of integration steps.
        :returns: tuple ``(point, velocity)`` at time t.
        """
        yv = jnp.concatenate([x[:, None], v[:, None]], axis=1)
        h = t/nstep

        def dyvdt(state):
            # The system is autonomous: the derivative depends only on
            # the current (point, velocity) pair.
            accel = -self.gamma(state[:, 0], state[:, 1], state[:, 1])
            return jnp.concatenate([state[:, [1]], accel[:, None]], axis=1)

        for _ in range(nstep):
            k1 = h * dyvdt(yv)
            k2 = h * dyvdt(yv + 0.5 * k1)
            k3 = h * dyvdt(yv + 0.5 * k2)
            k4 = h * dyvdt(yv + k3)
            yv = self.rtr_tan_scale(
                yv, (1.0 / 6.0)*(k1 + 2 * k2 + 2 * k3 + k4))
        return yv[:, 0], yv[:, 1]

    def make_tangent_basis(self, x):
        """Tangent basis at x.

        :param x: point on the manifold.
        :returns: matrix whose columns form a basis of the tangent space at
           x, orthonormalised with respect to ``inner``.
        """
        d = self.dim
        gcx = self.grad_c(x)
        proj_mat = jnp.eye(self.shape[0]) \
            - gcx[:, None]@gcx[None, :]/jnp.sum(gcx*gcx)
        # eigh sorts eigenvalues in ascending order; the projector has a
        # single eigenvalue 0 (the normal direction), so dropping the first
        # eigenvector leaves a basis of the tangent space.
        _, ev = jla.eigh(proj_mat)
        cmat = ev[:, 1:]
        # Gram matrix of that basis in the manifold metric.
        mat = jnp.array(
            [[self.inner(x, cmat[:, i], cmat[:, j]) for j in range(d)]
             for i in range(d)]).reshape((d, d))
        ei, ev = jla.eigh(mat)
        # Multiply by the inverse square root of the Gram matrix.
        return cmat@ev@(1/jnp.sqrt(jnp.abs(ei))[:, None]*ev.T)