import torch
import torch.nn as nn
import torch.nn.functional as F


class CAM(nn.Module):
    '''
    Convolutional Alignment Module (CAM).

    Cascades the multiscale backbone features into the deepest feature map,
    then predicts one spatial attention map per decoding step with a small
    conv / deconv encoder-decoder.
    '''

    # The current version only supports inputs whose spatial size is a power
    # of 2, such as 32, 64 or 128. You can adapt it to other input sizes by
    # changing the paddings or strides.
    def __init__(self,
                 channels_list=[64, 128, 256, 512],
                 strides_list=[[2, 2], [1, 1], [1, 1]],
                 in_shape=[8, 32],
                 maxT=25,
                 depth=4,
                 num_channels=128):
        super(CAM, self).__init__()
        # cascade multiscale features
        fpn = []
        for i in range(1, len(channels_list)):
            fpn.append(
                nn.Sequential(
                    nn.Conv2d(channels_list[i - 1], channels_list[i], (3, 3),
                              (strides_list[i - 1][0], strides_list[i - 1][1]),
                              1), nn.BatchNorm2d(channels_list[i]),
                    nn.ReLU(True)))
        self.fpn = nn.Sequential(*fpn)
        # convolutional alignment
        # convs
        assert depth % 2 == 0, 'the depth of CAM must be an even number.'
        strides = []
        conv_ksizes = []
        deconv_ksizes = []
        h, w = in_shape[0], in_shape[1]
        for i in range(0, depth // 2):
            # use stride 2 only while the remaining downsampling steps still
            # fit into the spatial dimension
            stride = [2] if 2**(depth // 2 - i) <= h else [1]
            stride = stride + [2] if 2**(depth // 2 - i) <= w else stride + [1]
            strides.append(stride)
            conv_ksizes.append([3, 3])
            deconv_ksizes.append([s**2 for s in stride])
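        # e.g. with the default in_shape=[8, 32] and depth=4 this yields
        # strides == [[2, 2], [2, 2]] and deconv_ksizes == [[4, 4], [4, 4]]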
        convs = [
            nn.Sequential(
                nn.Conv2d(channels_list[-1], num_channels,
                          tuple(conv_ksizes[0]), tuple(strides[0]),
                          ((conv_ksizes[0][0] - 1) // 2,
                           (conv_ksizes[0][1] - 1) // 2)),
                nn.BatchNorm2d(num_channels), nn.ReLU(True))
        ]
        for i in range(1, depth // 2):
            convs.append(
                nn.Sequential(
                    nn.Conv2d(num_channels, num_channels,
                              tuple(conv_ksizes[i]), tuple(strides[i]),
                              ((conv_ksizes[i][0] - 1) // 2,
                               (conv_ksizes[i][1] - 1) // 2)),
                    nn.BatchNorm2d(num_channels), nn.ReLU(True)))
        self.convs = nn.Sequential(*convs)
        # deconvs
        deconvs = []
        for i in range(1, depth // 2):
            deconvs.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        num_channels, num_channels,
                        tuple(deconv_ksizes[depth // 2 - i]),
                        tuple(strides[depth // 2 - i]),
                        (deconv_ksizes[depth // 2 - i][0] // 4,
                         deconv_ksizes[depth // 2 - i][1] // 4)),
                    nn.BatchNorm2d(num_channels), nn.ReLU(True)))
        # the last deconv outputs maxT attention maps squashed into [0, 1]
        deconvs.append(
            nn.Sequential(
                nn.ConvTranspose2d(num_channels, maxT, tuple(deconv_ksizes[0]),
                                   tuple(strides[0]),
                                   (deconv_ksizes[0][0] // 4,
                                    deconv_ksizes[0][1] // 4)),
                nn.Sigmoid()))
        self.deconvs = nn.Sequential(*deconvs)

    def forward(self, inputs):
        # fuse the multiscale features top-down into the deepest map
        x = inputs[0]
        for i in range(0, len(self.fpn)):
            x = self.fpn[i](x) + inputs[i + 1]
        # encoder half: strided convs, keeping the intermediates for skips
        conv_feats = []
        for i in range(0, len(self.convs)):
            x = self.convs[i](x)
            conv_feats.append(x)
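        # decoder half: upsample and add the matching conv feature (skip
        # connection) at each scale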
        for i in range(0, len(self.deconvs) - 1):
            x = self.deconvs[i](x)
            x = x + conv_feats[len(conv_feats) - 2 - i]
        x = self.deconvs[-1](x)
        return x
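

# A minimal shape sketch for CAM (illustrative; the sizes assume the default
# arguments, i.e. multiscale features from a hypothetical backbone whose
# deepest map is 8x32):
#
#   cam = CAM()
#   feats = [torch.randn(2, 64, 16, 64),
#            torch.randn(2, 128, 8, 32),
#            torch.randn(2, 256, 8, 32),
#            torch.randn(2, 512, 8, 32)]
#   A = cam(feats)  # -> (2, 25, 8, 32): one attention map per time step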


class CAMSimp(nn.Module):
    '''
    Simplified alignment module: a 1x1 convolution maps the input feature
    directly to maxT attention maps, without the FPN cascade of CAM.
    '''

    def __init__(self, maxT=25, num_channels=128):
        super(CAMSimp, self).__init__()
        self.conv = nn.Sequential(nn.Conv2d(num_channels, maxT, 1, 1, 0),
                                  nn.Sigmoid())

    def forward(self, x):
        x = self.conv(x)
        return x


class DANDecoder(nn.Module):
    '''
    Decoupled Text Decoder.

    Pools the visual feature into one vector per decoding step using the
    attention maps produced by the alignment module, then decodes the
    character sequence with a bidirectional LSTM followed by a GRU cell.
    '''

    def __init__(self,
                 out_channels,
                 in_channels,
                 use_cam=True,
                 max_len=25,
                 channels_list=[64, 128, 256, 512],
                 strides_list=[[2, 2], [1, 1], [1, 1]],
                 in_shape=[8, 32],
                 depth=4,
                 dropout=0.3,
                 **kwargs):
        super(DANDecoder, self).__init__()
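        # symbol layout in the output vocabulary:
        #   0 -> <EOS>, out_channels - 2 -> <BOS>, out_channels - 1 -> padding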
        self.eos = 0
        self.bos = out_channels - 2
        self.ignore_index = out_channels - 1
        nchannel = in_channels
        self.nchannel = in_channels
        self.use_cam = use_cam
        if use_cam:
            self.cam = CAM(channels_list=channels_list,
                           strides_list=strides_list,
                           in_shape=in_shape,
                           maxT=max_len + 1,
                           depth=depth,
                           num_channels=nchannel)
        else:
            self.cam = CAMSimp(maxT=max_len + 1, num_channels=nchannel)
        self.pre_lstm = nn.LSTM(nchannel,
                                nchannel // 2,
                                bidirectional=True)
        self.rnn = nn.GRUCell(nchannel * 2, nchannel)
        self.dropout = dropout
        self.generator = nn.Sequential(nn.Dropout(p=dropout),
                                       nn.Linear(nchannel, out_channels - 2))
        self.char_embeddings = nn.Embedding(out_channels,
                                            embedding_dim=in_channels,
                                            padding_idx=out_channels - 1)

    def forward(self, inputs, data=None):
        A = self.cam(inputs)  # (B, T, H, W) attention maps
        if isinstance(inputs, list):
            feature = inputs[-1]
        else:
            feature = inputs
        nB, nC, nH, nW = feature.shape
        nT = A.shape[1]
        # normalize each attention map to sum to 1 over the spatial grid
        A = A / A.view(nB, nT, -1).sum(2).view(nB, nT, 1, 1)
        # attention-weighted sum: pool the feature into one vector per step
        C = feature.view(nB, 1, nC, nH, nW) * A.view(nB, nT, 1, nH, nW)
        C = C.view(nB, nT, nC, -1).sum(3).transpose(1, 0)  # (T, B, C)
        C, _ = self.pre_lstm(C)  # (T, B, C)
        C = F.dropout(C, p=self.dropout, training=self.training)
        if self.training:
            # teacher forcing with the ground-truth characters; `text` is
            # expected to start with <BOS>
            text = data[0]
            text_length = data[-1]
            nsteps = int(text_length.max())
            gru_res = torch.zeros_like(C)
            hidden = torch.zeros(nB, self.nchannel).type_as(C)
            prev_emb = self.char_embeddings(text[:, 0])
            for i in range(0, nsteps + 1):
                hidden = self.rnn(torch.cat((C[i, :, :], prev_emb), dim=1),
                                  hidden)
                gru_res[i, :, :] = hidden
                # guard the lookahead so the last step does not index past
                # the end of the target sequence
                if i < nsteps:
                    prev_emb = self.char_embeddings(text[:, i + 1])
            gru_res = self.generator(gru_res)
            return gru_res[:nsteps + 1, :, :].transpose(1, 0)
        else:
            # greedy decoding: feed the argmax prediction back as the next
            # input character
            gru_res = torch.zeros_like(C)
            hidden = torch.zeros(nB, self.nchannel).type_as(C)
            prev_emb = self.char_embeddings(
                torch.full((nB, ),
                           self.bos,
                           dtype=torch.int64,
                           device=feature.device))
            dec_seq = torch.full((nB, nT),
                                 self.ignore_index,
                                 dtype=torch.int64,
                                 device=feature.device)
            for i in range(0, nT):
                hidden = self.rnn(torch.cat((C[i, :, :], prev_emb), dim=1),
                                  hidden)
                gru_res[i, :, :] = hidden
                mid_res = self.generator(hidden).argmax(-1)
                dec_seq[:, i] = mid_res
                # stop once every sequence in the batch has emitted <EOS>
                if (dec_seq == self.eos).any(dim=-1).all():
                    break
                prev_emb = self.char_embeddings(mid_res)
            gru_res = self.generator(gru_res)
            return F.softmax(gru_res.transpose(1, 0), -1)
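

# A minimal inference smoke test (an illustrative sketch, not part of the
# module API): the multiscale feature shapes below assume the default
# channels_list / strides_list and a hypothetical backbone whose deepest
# feature map is 8x32 with 512 channels.
if __name__ == '__main__':
    decoder = DANDecoder(out_channels=39, in_channels=512, max_len=25)
    feats = [
        torch.randn(2, 64, 16, 64),
        torch.randn(2, 128, 8, 32),
        torch.randn(2, 256, 8, 32),
        torch.randn(2, 512, 8, 32),
    ]
    decoder.eval()
    with torch.no_grad():
        probs = decoder(feats)
    print(probs.shape)  # (2, 26, 37): batch, max_len + 1, out_channels - 2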