import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# private scikit-image helper used by enforce_connectivity() below
from skimage.segmentation._slic import _enforce_label_connectivity_cython

def initWave(nPeriodic):
    # Build fixed phase offsets for nPeriodic periodic dimensions, returned
    # as a (1, 2*nPeriodic, 1, 1) tensor of multiples of pi so it broadcasts
    # over an NCHW feature map.
    buf = []
    for i in range(nPeriodic // 4 + 1):
        v = 0.5 + i / float(nPeriodic // 4 + 1e-10)  # 1e-10 avoids division by zero when nPeriodic < 4
        buf += [0, v, v, 0]
        buf += [0, -v, v, 0]  # mirrored sign, so the other quadrants are covered as well
    buf = buf[:2 * nPeriodic]
    awave = np.array(buf, dtype=np.float32) * np.pi
    awave = torch.FloatTensor(awave).unsqueeze(-1).unsqueeze(-1).unsqueeze(0)
    return awave
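
# A minimal usage sketch for initWave; the name _check_init_wave is an
# illustrative addition, not part of the original code.
def _check_init_wave():
    waves = initWave(8)                   # 8 periodic dimensions
    assert waves.shape == (1, 16, 1, 1)   # 2 * nPeriodic phase entries
    return waves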

class SPADEGenerator(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        nf = hidden_dim // 16

        self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf)

        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf)
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf)

        self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf)
        self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf)
        self.up_2 = SPADEResnetBlock(4 * nf, nf)
        #self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf)

        final_nc = nf

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)  # nearest-neighbor 2x upsampling


    def forward(self, x, seg):

        x = self.head_0(x, seg)

        x = self.up(x)
        x = self.G_middle_0(x, seg)
        x = self.G_middle_1(x, seg)

        x = self.up(x)
        x = self.up_0(x, seg)
        x = self.up(x)
        x = self.up_1(x, seg)
        x = self.up(x)
        x = self.up_2(x, seg)
        #x = self.up(x)
        #x = self.up_3(x, seg)

        x = self.conv_img(F.leaky_relu(x, 2e-1))
        return x
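
# Hedged end-to-end sketch of SPADEGenerator. The shapes are assumptions chosen
# to match the four 2x upsampling stages above (8x8 latent -> 128x128 image);
# the segmentation embedding must have 256 channels because SPADEResnetBlock
# hardcodes SPADE(..., 256) below. _demo_spade_generator is an illustrative name.
def _demo_spade_generator():
    G = SPADEGenerator(hidden_dim=256)    # nf = 16, so the head runs at 16*nf = 256 channels
    x = torch.randn(1, 256, 8, 8)         # latent feature map
    seg = torch.randn(1, 256, 128, 128)   # segmentation embedding, resized inside SPADE
    return G(x, seg)                      # -> (1, 3, 128, 128)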

class SPADE(nn.Module):
    def __init__(self, norm_nc, label_nc):
        super().__init__()

        ks = 3

        self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128

        pw = ks // 2
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
            nn.ReLU()
        )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, segmap):
        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on the semantic map
        # (bilinear resizing here; the original SPADE uses mode='nearest')
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='bilinear', align_corners=False)
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        out = normalized * (1 + gamma) + beta

        return out
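
# Hedged sketch of SPADE in isolation (illustrative _demo_spade name): x is
# instance-normalized, then modulated per pixel by (1 + gamma) and beta
# predicted from the segmentation map. Channel counts are assumptions.
def _demo_spade():
    spade = SPADE(norm_nc=64, label_nc=256)
    x = torch.randn(2, 64, 32, 32)
    segmap = torch.randn(2, 256, 128, 128)   # resized to 32x32 inside forward
    return spade(x, segmap)                  # same shape as x: (2, 64, 32, 32)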

class SPADEResnetBlock(nn.Module):
    def __init__(self, fin, fout):
        super().__init__()
        # Attributes
        self.learned_shortcut = (fin != fout)
        fmiddle = min(fin, fout)

        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # define normalization layers; the conditioning map is assumed to
        # have 256 channels (SPADE's label_nc is hardcoded here)
        self.norm_0 = SPADE(fin, 256)
        self.norm_1 = SPADE(fmiddle, 256)
        if self.learned_shortcut:
            self.norm_s = SPADE(fin, 256)

    # note the resnet block with SPADE also takes in |seg|,
    # the semantic segmentation map as input
    def forward(self, x, seg):
        x_s = self.shortcut(x, seg)

        dx = self.conv_0(self.actvn(self.norm_0(x, seg)))
        dx = self.conv_1(self.actvn(self.norm_1(dx, seg)))

        out = x_s + dx

        return out

    def shortcut(self, x, seg):
        if self.learned_shortcut:
            x_s = self.conv_s(self.norm_s(x, seg))
        else:
            x_s = x
        return x_s

    def actvn(self, x):
        return F.leaky_relu(x, 2e-1)
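
# Hedged sketch of SPADEResnetBlock with fin != fout, which exercises the
# learned 1x1 shortcut path (illustrative _demo_spade_resblock name).
def _demo_spade_resblock():
    block = SPADEResnetBlock(fin=128, fout=64)
    x = torch.randn(1, 128, 16, 16)
    seg = torch.randn(1, 256, 64, 64)   # 256 channels to match SPADE(fin, 256)
    return block(x, seg)                # -> (1, 64, 16, 16)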

def get_edges(sp_label, sp_num):
    # For each image, build an (sp_num x sp_num) affinity matrix N where
    # N[i, j] = 1 if superpixels i and j are neighbors (8-connectivity)
    # and 0 otherwise. Returns the per-image (2, E) edge-index tensors and
    # the affinity matrices concatenated along the batch dimension.
    top = sp_label[:, :, :-1, :] - sp_label[:, :, 1:, :]
    left = sp_label[:, :, :, :-1] - sp_label[:, :, :, 1:]
    top_left = sp_label[:, :, :-1, :-1] - sp_label[:, :, 1:, 1:]
    top_right = sp_label[:, :, :-1, 1:] - sp_label[:, :, 1:, :-1]
    n_affs = []
    edge_indices = []
    for i in range(sp_label.shape[0]):
        # change to torch.ones below to include self-loop in graph
        n_aff = torch.zeros(sp_num, sp_num).unsqueeze(0).to(sp_label.device)
        # top/bottom
        top_i = top[i].squeeze()
        x, y = torch.nonzero(top_i, as_tuple=True)
        sp1 = sp_label[i, :, x, y].squeeze().long()
        sp2 = sp_label[i, :, x+1, y].squeeze().long()
        n_aff[:, sp1, sp2] = 1
        n_aff[:, sp2, sp1] = 1

        # left/right
        left_i = left[i].squeeze()
        x, y = torch.nonzero(left_i, as_tuple=True)
        sp1 = sp_label[i, :, x, y].squeeze().long()
        sp2 = sp_label[i, :, x, y+1].squeeze().long()
        n_aff[:, sp1, sp2] = 1
        n_aff[:, sp2, sp1] = 1

        # top left
        top_left_i = top_left[i].squeeze()
        x, y = torch.nonzero(top_left_i, as_tuple=True)
        sp1 = sp_label[i, :, x, y].squeeze().long()
        sp2 = sp_label[i, :, x+1, y+1].squeeze().long()
        n_aff[:, sp1, sp2] = 1
        n_aff[:, sp2, sp1] = 1

        # top right
        top_right_i = top_right[i].squeeze()
        x, y = torch.nonzero(top_right_i, as_tuple=True)
        sp1 = sp_label[i, :, x, y+1].squeeze().long()
        sp2 = sp_label[i, :, x+1, y].squeeze().long()
        n_aff[:, sp1, sp2] = 1
        n_aff[:, sp2, sp1] = 1

        n_affs.append(n_aff)
        edge_index = torch.stack(torch.nonzero(n_aff.squeeze(), as_tuple=True))
        edge_indices.append(edge_index.to(sp_label.device))
    return edge_indices, torch.cat(n_affs)
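
# Hedged sketch for get_edges on a tiny 4x4 label grid with four superpixels
# (illustrative _demo_get_edges name). Every pair of the four labels touches,
# including diagonally, so the off-diagonal of the affinity matrix is all ones.
# The grid must be at least 3x3 so the squeezed difference maps stay 2-D.
def _demo_get_edges():
    sp_label = torch.tensor([[0, 0, 1, 1],
                             [0, 0, 1, 1],
                             [2, 2, 3, 3],
                             [2, 2, 3, 3]]).view(1, 1, 4, 4).float()
    edge_indices, n_aff = get_edges(sp_label, sp_num=4)
    # edge_indices[0] is a (2, E) tensor of neighboring label pairs;
    # n_aff has shape (1, 4, 4)
    return edge_indices, n_aff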

def enforce_connectivity(segs, H, W, sp_num=196, min_size=None, max_size=None):
    # Merge stray fragments in each label map using scikit-image's (private)
    # SLIC connectivity pass, then restack into a (B, 1, H, W) label tensor.
    segment_size = H * W / sp_num
    if min_size is None:
        min_size = int(0.1 * segment_size)   # components below ~10% of the mean size get merged
    if max_size is None:
        max_size = int(1000.0 * segment_size)
    rets = []
    for i in range(segs.shape[0]):
        seg = segs[i].squeeze().cpu().numpy()
        seg = _enforce_label_connectivity_cython(seg[None], min_size, max_size)[0]
        seg = torch.from_numpy(seg).unsqueeze(0).unsqueeze(0)
        rets.append(seg)
    return torch.cat(rets)
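
# Hedged sketch for enforce_connectivity on random labels (illustrative
# _demo_enforce_connectivity name). Assumes a 64-bit platform, where the
# int64 labels are compatible with the integer dtype the cython helper expects.
def _demo_enforce_connectivity():
    H = W = 28
    segs = torch.randint(0, 16, (2, 1, H, W))   # two random 16-label maps
    out = enforce_connectivity(segs, H, W, sp_num=16)
    return out                                  # (2, 1, H, W), relabeled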