import torch
import torch.nn as nn
import torch.distributions as D
from torch.nn import functional as F
import numpy as np
from torch.autograd import Variable

class BaseFlow(nn.Module):
    """Common base for flows that can draw samples from a standard normal base.

    Subclasses must set ``self.dim`` (an int or a sequence of ints giving the
    per-sample shape) and, when ``sample`` is called without a context,
    ``self.context_dim``. Their ``forward`` receives a single
    ``(samples, logdet, context)`` tuple.
    """

    def __init__(self):
        super().__init__()

    def sample(self, n=1, context=None, **kwargs):
        """Draw ``n`` base samples and push them through ``forward``.

        Args:
            n: number of samples to draw.
            context: optional context tensor; defaults to a float32 tensor of
                ones with shape ``(n, self.context_dim)``.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            Whatever ``self.forward((samples, logdet, context))`` returns, where
            ``samples`` is a float32 tensor of shape ``(n, *dim)`` drawn from
            N(0, 1) and ``logdet`` is a float32 zero vector of length ``n``.
        """
        dim = self.dim
        if isinstance(dim, int):
            dim = [dim]

        spl = torch.empty(n, *dim, dtype=torch.float32).normal_()
        lgd = torch.zeros(n, dtype=torch.float32)
        if context is None:
            context = torch.ones(n, self.context_dim, dtype=torch.float32)

        # ``gpu`` is only set once ``cuda()``/``cpu()`` has been called.
        if getattr(self, 'gpu', False):
            spl = spl.cuda()
            lgd = lgd.cuda()
            # torch.Tensor has no ``.gpu()`` method; move it like the others.
            context = context.cuda()

        return self.forward((spl, lgd, context))

    def cuda(self, device=None):
        """Move the module to CUDA and make ``sample`` create CUDA tensors."""
        self.gpu = True
        return super().cuda(device)

    def cpu(self):
        """Move the module to the CPU and make ``sample`` create CPU tensors."""
        self.gpu = False
        return super().cpu()

def varify(x):
    """Wrap the NumPy array ``x`` as a (legacy) autograd ``Variable`` tensor."""
    tensor = torch.from_numpy(x)
    return torch.autograd.Variable(tensor)

def oper(array, oper, axis=-1, keepdims=False):
    """Apply the reduction callable ``oper`` to ``array``, optionally keeping ``axis``.

    Args:
        array: input tensor.
        oper: callable applied to ``array``; expected to reduce along ``axis``.
        axis: the axis ``oper`` reduces; only consulted when ``keepdims`` is set.
        keepdims: if True, reshape the result to ``array``'s shape with
            ``axis`` inferred by ``view`` (``-1``), i.e. keep the reduced axis.

    Returns:
        ``oper(array)``, reshaped when ``keepdims`` is True.
    """
    result = oper(array)
    if not keepdims:
        return result
    target_shape = list(array.size())
    target_shape[axis] = -1
    return result.view(*target_shape)

def log_sum_exp(A, axis=-1, sum_op=torch.sum):
    """Numerically stable ``log(sum_op(exp(A), axis))``, keeping the reduced axis.

    The per-slice maximum is subtracted before exponentiating and added back
    afterwards. ``sum_op`` is called as ``sum_op(tensor, axis)``.
    """
    def _max_along(t):
        return t.max(axis)[0]

    peak = oper(A, _max_along, axis, True)

    def _shifted_sum(t):
        return sum_op(torch.exp(t - peak), axis)

    total = oper(A, _shifted_sum, axis, True)
    return torch.log(total) + peak

delta = 1e-6  # additive floor applied by ``softplus`` below


def logsigmoid(x):
    """Return log(sigmoid(x)), computed as -softplus(-x)."""
    return -F.softplus(-x)


def log(x):
    """Return the natural log of ``x``, evaluated as log(100 * x) - log(100)."""
    return torch.log(x * 1e2) - np.log(1e2)


softplus_ = nn.Softplus()


def softplus(x):
    """Return softplus(x) + ``delta``, keeping outputs bounded away from zero."""
    return softplus_(x) + delta

def softmax(x, dim=-1):
    """Return the softmax of ``x`` along ``dim``, shifted by the max for stability."""
    shifted = x - torch.max(x, dim=dim, keepdim=True).values
    exps = shifted.exp()
    return exps / exps.sum(dim=dim, keepdim=True)

class DenseSigmoidFlow(nn.Module):
    """One dense block whose weights are shifted by externally supplied parameters.

    The block owns two learnable base matrices, ``u_`` of shape
    ``(hidden_dim, in_dim)`` and ``w_`` of shape ``(out_dim, hidden_dim)``.
    Each call adds per-position offsets taken from ``dsparams`` to them.
    Note: despite the class name, the hidden nonlinearity is ``torch.selu``,
    and no logit is applied to the output.

    Args:
        hidden_dim: number of hidden units.
        in_dim: size of the last axis of the input ``x``.
        out_dim: size of the last axis of the output.

    Attributes:
        num_params: number of values this block reads from the last axis of
            ``dsparams`` (``3 * hidden_dim + in_dim``).
    """

    def __init__(self, hidden_dim, in_dim=1, out_dim=1):
        super().__init__()
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim

        self.u_ = torch.nn.Parameter(torch.empty(hidden_dim, in_dim))
        self.w_ = torch.nn.Parameter(torch.empty(out_dim, hidden_dim))
        # Last axis of dsparams: [a (hidden) | b (hidden) | w (hidden) | ...],
        # with u read from the trailing ``in_dim`` values.
        self.num_params = 3 * hidden_dim + in_dim
        self.reset_parameters()

    # Activations are static methods instead of lambda instance attributes so
    # that instances remain picklable (e.g. for ``torch.save(model)``).
    @staticmethod
    def act_a(x):
        # softplus keeps the input scales strictly positive
        return F.softplus(x)

    @staticmethod
    def act_b(x):
        return x

    @staticmethod
    def act_w(x):
        # pre_w is (..., out_dim, hidden_dim): normalize over hidden units
        return torch.softmax(x, dim=3)

    @staticmethod
    def act_u(x):
        # pre_u is (..., hidden_dim, in_dim): normalize over input features
        return torch.softmax(x, dim=3)

    def reset_parameters(self):
        self.u_.data.uniform_(-0.001, 0.001)
        self.w_.data.uniform_(-0.001, 0.001)

    def forward(self, x, dsparams):
        """Apply the block.

        Args:
            x: rank-3 tensor whose last axis has size ``in_dim``
                (e.g. ``(batch, dim, in_dim)``).
            dsparams: rank-3 tensor whose leading two axes match (or broadcast
                to) those of ``x``; its last axis holds the per-position
                parameters described by ``num_params``.

        Returns:
            Tensor of shape ``x.shape[:2] + (out_dim,)``.
        """
        eps = 1e-7
        # Inverse softplus of (1 - eps): zero offsets give scales a of ~1.
        inv = np.log(np.exp(1 - eps) - 1)
        nh = self.hidden_dim
        pre_u = self.u_[None, None, :, :] + dsparams[:, :, -self.in_dim:][:, :, None, :]
        pre_w = self.w_[None, None, :, :] + dsparams[:, :, 2 * nh:3 * nh][:, :, None, :]
        a = self.act_a(dsparams[:, :, 0 * nh:1 * nh] + inv)
        b = self.act_b(dsparams[:, :, 1 * nh:2 * nh])
        w = self.act_w(pre_w)
        u = self.act_u(pre_u)

        pre_act = torch.sum(u * a[:, :, :, None] * x[:, :, None, :], 3) + b
        hidden = torch.selu(pre_act)
        # Output is returned as-is (no logit transform is applied).
        return torch.sum(w * hidden[:, :, None, :], dim=3)


class DDSF(nn.Module):
    """A stack of ``DenseSigmoidFlow`` blocks applied to every position of ``x``.

    The first block maps 1 -> ``hidden_dim`` features. Intermediate blocks map
    ``hidden_dim`` -> ``hidden_dim``, and the last maps back to 1. With
    ``n_blocks == 1`` a single 1 -> 1 block is used. Each position's
    parameters come from the matching row of ``dsparams`` at call time.

    Args:
        n_blocks: number of stacked blocks; must be >= 1.
        hidden_dim: hidden width of every block.

    Attributes:
        num_params: total number of values consumed from the last axis of
            ``dsparams`` (sum over blocks).

    Raises:
        ValueError: if ``n_blocks < 1``.
    """

    def __init__(self, n_blocks=1, hidden_dim=16):
        super().__init__()
        if n_blocks < 1:
            raise ValueError(f"n_blocks must be >= 1, got {n_blocks}")
        # Feature widths between consecutive blocks: 1 -> hidden ... hidden -> 1.
        widths = [1] + [hidden_dim] * (n_blocks - 1) + [1]
        blocks = [
            DenseSigmoidFlow(hidden_dim=hidden_dim, in_dim=w_in, out_dim=w_out)
            for w_in, w_out in zip(widths[:-1], widths[1:])
        ]
        self.model = nn.Sequential(*blocks)
        self.num_params = sum(block.num_params for block in self.model)

    def forward(self, x, dsparams):
        """Transform ``x`` using parameters sliced consecutively from ``dsparams``.

        Args:
            x: rank-2 tensor (e.g. ``(batch, dim)``).
            dsparams: rank-3 tensor whose leading axes match ``x``; its last
                axis must hold at least ``self.num_params`` values. Any extra
                trailing values are ignored.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if ``dsparams`` supplies fewer than ``num_params`` values.
        """
        if dsparams.size(-1) < self.num_params:
            # A short slice could otherwise broadcast silently inside a block.
            raise ValueError(
                f"dsparams provides {dsparams.size(-1)} values per position, "
                f"expected at least {self.num_params}"
            )
        h = x.unsqueeze(2)
        offset = 0
        for block in self.model:
            h = block(h, dsparams[:, :, offset:offset + block.num_params])
            offset += block.num_params
        return h.squeeze(2)

def compute_jacobian(inputs, outputs, *, verbose=True):
    """Compute per-sample Jacobians of ``outputs`` with respect to ``inputs``.

    The method differentiates the batch-sum of ``outputs`` once per output
    feature. The result equals the true per-sample Jacobian only when each
    output row depends solely on the matching input row. NOTE(review):
    confirm this holds for the model being inspected.

    Args:
        inputs: tensor with ``requires_grad=True``; first axis is the batch.
        outputs: tensor computed from ``inputs`` with the same batch size.
        verbose: if True (the default, kept for backward compatibility),
            print the Jacobian of the second sample, as this helper always
            did. The print is skipped when the batch has a single sample.

    Returns:
        Tensor of shape ``(batch, out_features, in_features)``, where the
        feature counts are the flattened per-sample sizes. It is built with
        ``create_graph=True`` and therefore remains differentiable.
    """
    batch_size = outputs.size(0)
    out_flat = torch.sum(outputs, 0).view(-1)
    rows = [
        torch.autograd.grad(out_flat[i], inputs,
                            retain_graph=True, create_graph=True)[0].view(batch_size, -1)
        for i in range(out_flat.size(0))
    ]
    jac = torch.stack(rows, dim=1)
    if verbose and batch_size > 1:
        print(jac[1])
    return jac

if __name__ == '__main__':
    # Smoke test: push a small grid through a deep DDSF, print the inputs and
    # outputs, then call compute_jacobian on the result.
    flow = DDSF(n_blocks=10, hidden_dim=50)
    x = (torch.arange(20).view(10, 2) / 10. - 1.).requires_grad_()

    # DDSF consumes exactly flow.num_params values per position; every row of
    # the batch shares the same random parameters.
    dsparams = torch.randn(1, 2, flow.num_params).repeat(10, 1, 1)
    y = flow(x, dsparams)
    print(x, y)
    compute_jacobian(x, y)