# coding:utf-8
# chenjun
# date:2020-04-18
import torch.nn as nn 
import torch
import torch.nn.functional as F 
import numpy as np


# def get_non_pad_mask(seq, PAD):
#     assert seq.dim() == 2
#     return seq.ne(PAD).type(torch.float).unsqueeze(-1)

def get_pad_mask(seq, pad_idx):
    '''Return a boolean mask that is True where `seq` equals `pad_idx`.

    A singleton dimension is inserted before the last axis, so a (b, L)
    input yields a (b, 1, L) mask that broadcasts over query positions.
    '''
    is_pad = seq.eq(pad_idx)
    return is_pad.unsqueeze(-2)


def get_subsequent_mask(seq):
    ''' Build a causal mask that hides future positions.

    Args:
        seq: 2-D tensor (batch, length); only its size and device are used.

    Returns:
        uint8 tensor (batch, length, length) where entry (i, j) is 1 when
        j > i (strictly above the diagonal), else 0. The batch dimension is
        an expanded (non-copied) view.
    '''
    batch_size, seq_len = seq.size()
    idx = torch.arange(seq_len, device=seq.device)
    # Row index = query position i, column index = key position j.
    future = (idx.unsqueeze(0) > idx.unsqueeze(1)).to(torch.uint8)
    return future.unsqueeze(0).expand(batch_size, -1, -1)  # b x ls x ls


def get_attn_key_pad_mask(seq_k, seq_q, PAD):
    ''' Mask padded key positions for every query position.

    Args:
        seq_k: key sequence (b, lk), e.g. the source sequence.
        seq_q: query sequence (b, lq, ...); only its length (dim 1) is used.
        PAD: padding index.

    Returns:
        bool tensor (b, lq, lk), True where the key token equals PAD. The
        query dimension is an expanded view.
    '''
    query_len = seq_q.size(1)
    key_is_pad = seq_k.eq(PAD)                                  # b x lk
    return key_is_pad[:, None, :].expand(-1, query_len, -1)     # b x lq x lk


class PositionalEncoding(nn.Module):
    ''' Adds a fixed sinusoidal position encoding to batch-first sequences.

    Args:
        d_hid: size of the encoding; must match the last dim of the input.
        n_position: maximum sequence length that can be encoded.
    '''

    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()

        # Registered as a buffer, not a parameter: it follows the module's
        # device but is never trained.
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        ''' Build the (1, n_position, d_hid) sinusoid table.

        table[p, 2i]   = sin(p / 10000^(2i / d_hid))
        table[p, 2i+1] = cos(p / 10000^(2i / d_hid))
        '''
        # Vectorized in numpy (float64) instead of a per-element Python loop.
        # Columns 2i and 2i+1 share the same frequency.
        positions = np.arange(n_position, dtype=np.float64)[:, None]
        exponents = 2 * (np.arange(d_hid) // 2) / d_hid
        sinusoid_table = positions / np.power(10000, exponents)[None, :]
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

        return torch.FloatTensor(sinusoid_table).unsqueeze(0)

    def forward(self, x):
        ''' x: (batch, seq_len, d_hid). Returns x plus the first seq_len rows of the table.

        Raises:
            ValueError: if seq_len exceeds n_position.
        '''
        seq_len = x.size(1)
        max_len = self.pos_table.size(1)
        if seq_len > max_len:
            raise ValueError(
                'sequence length %d exceeds n_position=%d of PositionalEncoding'
                % (seq_len, max_len))
        # pos_table does not require grad and the addition allocates a new
        # tensor, so no clone()/detach() copy is needed.
        return x + self.pos_table[:, :seq_len]


class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention: softmax(q @ k^T / temperature) @ v.

    Args:
        temperature: divisor for the raw dot products (MultiHeadAttention
            passes sqrt(d_k)).
        attn_dropout: dropout probability applied to the attention weights.
    '''

    def __init__(self, temperature, attn_dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        '''
        q: (B, Lq, Dk), k: (B, Lk, Dk), v: (B, Lk, Dv) -- batched 3-D, as torch.bmm requires.
        mask: optional tensor broadcastable to (B, Lq, Lk); True / non-zero
            entries are excluded from attention.

        Returns (output (B, Lq, Dv), attn (B, Lq, Lk)). The returned attn is
        the weight tensor after dropout.
        '''
        attn = torch.bmm(q, k.transpose(1, 2))
        attn = attn / self.temperature

        if mask is not None:
            # Tensor.masked_fill requires a bool mask. uint8 masks (e.g. from
            # get_subsequent_mask) are converted here; bool masks pass unchanged.
            # -1e9 is not representable in float16, so clamp the fill value to
            # the dtype's range. For float32/float64 it stays -1e9.
            fill_value = max(-1e9, torch.finfo(attn.dtype).min)
            attn = attn.masked_fill(mask.bool(), fill_value)

        attn = self.softmax(attn)       # normalize over the key dimension
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)

        return output, attn


class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module.

    Projects q/k/v into n_head subspaces and runs ScaledDotProductAttention
    on all heads at once, with heads folded into the batch dimension. The
    concatenated heads are projected back to d_model. Dropout, a residual
    connection to the query input and LayerNorm are then applied (post-norm).

    Args:
        n_head: number of heads.
        d_model: input/output feature size.
        d_k: per-head query/key size.
        d_v: per-head value size.
        dropout: dropout probability after the output projection.
    '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super(MultiHeadAttention, self).__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v)
        qk_std = np.sqrt(2.0 / (d_model + d_k))
        nn.init.normal_(self.w_qs.weight, mean=0, std=qk_std)
        nn.init.normal_(self.w_ks.weight, mean=0, std=qk_std)
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))

        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
        self.layer_norm = nn.LayerNorm(d_model)

        self.fc = nn.Linear(n_head * d_v, d_model)
        nn.init.xavier_normal_(self.fc.weight)

        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        '''
        q, k, v: 3-D (batch, len, d_model) tensors.
        mask: optional 3-D mask (batch, len_q or 1, len_k); True / non-zero
            entries block attention.

        Returns (output (batch, len_q, d_model), attn (n_head*batch, len_q, len_k)).
        '''
        n_head, d_k, d_v = self.n_head, self.d_k, self.d_v
        _, len_q, _ = q.size()
        _, len_k, _ = k.size()
        sz_b, len_v, _ = v.size()

        def fold_heads(x, proj, length, depth):
            # b x l x (n*d) -> n x b x l x d -> (n*b) x l x d  (head-major order)
            heads = proj(x).view(sz_b, length, n_head, depth)
            return heads.permute(2, 0, 1, 3).contiguous().view(-1, length, depth)

        q_heads = fold_heads(q, self.w_qs, len_q, d_k)
        k_heads = fold_heads(k, self.w_ks, len_k, d_k)
        v_heads = fold_heads(v, self.w_vs, len_v, d_v)

        # Repeating along dim 0 produces the same head-major layout as
        # fold_heads: row h*b + i of the mask belongs to sample i of head h.
        if mask is not None:
            mask = mask.repeat(n_head, 1, 1)
        context, attn = self.attention(q_heads, k_heads, v_heads, mask=mask)

        # (n*b) x lq x dv -> b x lq x (n*dv)
        context = context.view(n_head, sz_b, len_q, d_v).permute(1, 2, 0, 3)
        context = context.contiguous().view(sz_b, len_q, -1)

        projected = self.dropout(self.fc(context))
        return self.layer_norm(projected + q), attn

class PositionwiseFeedForward(nn.Module):
    ''' Position-wise feed-forward block.

    Applies two kernel-size-1 Conv1d layers (d_in -> d_hid -> d_in) with a
    ReLU between them. Dropout, a residual connection and LayerNorm follow.

    Args:
        d_in: input/output feature size.
        d_hid: hidden feature size.
        dropout: dropout probability on the block output.
    '''

    def __init__(self, d_in, d_hid, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Conv1d(d_in, d_hid, 1)  # kernel size 1 == per-position linear
        self.w_2 = nn.Conv1d(d_hid, d_in, 1)
        self.layer_norm = nn.LayerNorm(d_in)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        ''' x: (batch, len, d_in) -> (batch, len, d_in). '''
        # Conv1d wants (batch, channels, len): move features to dim 1 and back.
        hidden = F.relu(self.w_1(x.transpose(1, 2)))
        projected = self.w_2(hidden).transpose(1, 2)
        return self.layer_norm(self.dropout(projected) + x)


class EncoderLayer(nn.Module):
    ''' One encoder layer: multi-head self-attention followed by a
    position-wise feed-forward block. Each sub-layer applies its own
    residual connection and LayerNorm. '''

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        ''' Returns (layer output, self-attention weights from MultiHeadAttention). '''
        attended, attn_weights = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)
        return self.pos_ffn(attended), attn_weights


class Torch_transformer_encoder(nn.Module):
    '''
    Sequence encoder built on PyTorch's nn.TransformerEncoder.

    Args:
        d_word_vec: size of the positional encoding (should equal d_model).
        n_layers: number of encoder layers.
        n_head: attention heads per layer.
        d_model: model feature size.
        dim_feedforward: hidden size of each layer's feed-forward network.
        n_position: maximum sequence length for the positional encoding.
    '''
    def __init__(self, d_word_vec=512, n_layers=2, n_head=8, d_model=512, dim_feedforward=1024, n_position=256):
        super(Torch_transformer_encoder, self).__init__()

        self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_head, dim_feedforward=dim_feedforward)
        self.layer_norm = nn.LayerNorm(d_model)
        # The final LayerNorm is applied by nn.TransformerEncoder via norm=.
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers, norm=self.layer_norm)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, cnn_feature, src_mask=None, return_attns=False):
        '''
        cnn_feature: (batch, seq, d_model), batch-first like Transforme_Encoder.
        src_mask: accepted for interface compatibility but ignored.
            NOTE(review): its (b x lq x lk) convention does not match
            nn.TransformerEncoder's mask arguments.
        return_attns: if True, also return an attention list. The list is
            always empty because nn.TransformerEncoder does not expose
            attention weights.

        Returns (enc_output,) or (enc_output, []), with enc_output of shape
        (batch, seq, d_model).
        '''
        enc_slf_attn_list = []

        # -- Forward
        enc_output = self.dropout(self.position_enc(cnn_feature))  # positional encoding

        # nn.TransformerEncoder (default batch_first=False) expects
        # (seq, batch, d_model). Transpose around the call; otherwise attention
        # would run across the batch instead of along the sequence.
        enc_output = self.encoder(enc_output.transpose(0, 1)).transpose(0, 1).contiguous()

        if return_attns:
            return enc_output, enc_slf_attn_list
        return enc_output,



class Transforme_Encoder(nn.Module):
    ''' Stack of self-attention encoder layers that captures global (spatial) dependencies.

    Args:
        d_word_vec: size of the sinusoidal positional encoding (should equal d_model).
        n_layers: number of stacked EncoderLayer blocks.
        n_head: number of attention heads.
        d_k: per-head query/key size (default 64).
        d_v: per-head value size (default 64).
        d_model: model feature size (default 512).
        d_inner: hidden size of the position-wise feed-forward (default 1024).
        dropout: dropout probability.
        n_position: maximum sequence length for the positional encoding.
    '''
    def __init__(
            self, d_word_vec=512, n_layers=2, n_head=8, d_k=64, d_v=64,
            d_model=512, d_inner=1024, dropout=0.1, n_position=256):

        super().__init__()

        self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_stack = nn.ModuleList(
            [EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
             for _ in range(n_layers)])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, cnn_feature, src_mask, return_attns=False):
        '''
        cnn_feature: (batch, seq, d_model) input features.
        src_mask: mask forwarded unchanged to every layer's self-attention (may be None).
        return_attns: also return the per-layer attention weights.

        Returns (output,) or (output, [attn per layer]).
        '''
        hidden = self.dropout(self.position_enc(cnn_feature))  # add positional encoding
        attentions = []

        for layer in self.layer_stack:
            hidden, layer_attn = layer(hidden, slf_attn_mask=src_mask)
            if return_attns:
                attentions.append(layer_attn)

        hidden = self.layer_norm(hidden)
        return (hidden, attentions) if return_attns else (hidden,)
    

class PVAM(nn.Module):
    ''' Parallel Visual Attention Module: decodes all character slots in parallel.

    Args:
        n_dim: feature size of the reading-order embedding and of enc_output.
        N_max_character: number of character slots (max characters per image).
        n_position: length of the CNN feature sequence. enc_output must have
            this many positions, because w0 maps N_max_character -> n_position
            and the result is added to wv(enc_output).
    '''
    def __init__(self,  n_dim=512, N_max_character=25, n_position=256):

        super(PVAM, self).__init__()
        self.character_len = N_max_character

        self.f0_embedding = nn.Embedding(N_max_character, n_dim)

        self.w0 = nn.Linear(N_max_character, n_position)
        self.wv = nn.Linear(n_dim, n_dim)
        self.we = nn.Linear(n_dim, N_max_character)  # n_dim -> one score per character slot

        self.active = nn.Tanh()
        self.softmax = nn.Softmax(dim=2)

    def forward(self, enc_output):
        ''' enc_output: (b, n_position, n_dim) -> glimpses (b, N_max_character, n_dim). '''
        batch = enc_output.size(0)
        order = torch.arange(self.character_len, dtype=torch.long, device=enc_output.device)
        order_embed = self.f0_embedding(order.unsqueeze(0).expand(batch, -1))  # b, N, n_dim

        # Map the N reading-order slots onto the n_position feature positions.
        query = self.w0(order_embed.transpose(1, 2)).transpose(1, 2)    # b, n_position, n_dim
        energy = self.active(query + self.wv(enc_output))               # b, n_position, n_dim
        scores = self.we(energy).transpose(1, 2)                        # b, N, n_position
        weights = self.softmax(scores)                                  # normalize over positions

        return torch.bmm(weights, enc_output)                           # b, N, n_dim


class GSRM(nn.Module):
    ''' Global Semantic Reasoning Module.

    Embeds the argmax class prediction of every character slot and runs a
    Transforme_Encoder over the embedded sequence. argmax is not
    differentiable, so no gradient flows back into e_out through this module.

    Args:
        n_dim: embedding size, also used as the encoder width.
        n_class: number of classes (embedding vocabulary size).
        PAD: padding class index, stored as self.PAD (not used in forward).
        n_layers: number of encoder layers.
        n_position: maximum sequence length (number of character slots).
    '''
    def __init__(self, n_dim=512, n_class=37, PAD=37-1, n_layers=4, n_position=25):

        super(GSRM, self).__init__()

        self.PAD = PAD
        self.argmax_embed = nn.Embedding(n_class, n_dim)

        # Encoder width follows n_dim so it matches the embedding size.
        self.transformer_units = Transforme_Encoder(
            d_word_vec=n_dim, d_model=n_dim, n_layers=n_layers, n_position=n_position)

    def forward(self, e_out):
        '''
        e_out: (b, N, n_class) class scores from the visual branch (PVAM + output layer).

        Returns the encoder's 1-tuple (s,), where s has shape (b, N, n_dim).
        '''
        e_argmax = e_out.argmax(dim=-1)     # b, N
        e = self.argmax_embed(e_argmax)     # b, N, n_dim

        # NOTE(review): no padding mask is applied (src_mask=None), so slots
        # predicted as PAD take part in attention. get_pad_mask(e_argmax,
        # self.PAD) would build one if masking is desired -- confirm intent.
        s = self.transformer_units(e, None)

        return s


class SRN_Decoder(nn.Module):
    ''' SRN decoder: PVAM visual branch, GSRM semantic branch, and their fusion.

    Args:
        n_dim: feature size (used by PVAM, GSRM and the output layers).
        n_class: number of character classes; class n_class-1 is passed to GSRM as PAD.
        N_max_character: max characters per image (number of decoded slots).
        n_position: length of the CNN feature sequence.
        GSRM_layer: number of encoder layers inside GSRM.

    forward returns three logit tensors (visual, semantic, fused).
    '''
    def __init__(self, n_dim=512, n_class=37, N_max_character=25, n_position=256, GSRM_layer=4 ):

        super(SRN_Decoder, self).__init__()

        # Forward n_dim so PVAM's width matches the output layers below.
        self.pvam = PVAM(n_dim=n_dim, N_max_character=N_max_character, n_position=n_position)
        self.w_e = nn.Linear(n_dim, n_class)    # visual-branch output layer

        self.GSRM = GSRM(n_class=n_class, PAD=n_class-1, n_dim=n_dim, n_position=N_max_character, n_layers=GSRM_layer)
        self.w_s = nn.Linear(n_dim, n_class)    # semantic-branch output layer

        self.w_f = nn.Linear(n_dim, n_class)    # fused output layer

    def forward(self, cnn_feature ):
        '''
        cnn_feature: (b, n_position, n_dim) output of the CNN/encoder.

        Returns (e_out, s_out, f_out), each of shape (b, N_max_character, n_class):
        visual logits, semantic logits, and logits of the summed features.
        '''
        g_output = self.pvam(cnn_feature)   # b, N, n_dim
        e_out = self.w_e(g_output)          # b, N, n_class -- first output

        s = self.GSRM(e_out)[0]             # b, N, n_dim
        s_out = self.w_s(s)                 # b, N, n_class

        # TODO: change the add to a gated unit
        f = g_output + s                    # b, N, n_dim
        f_out = self.w_f(f)

        return e_out, s_out, f_out


def cal_performance(preds, gold, mask=None, smoothing='1'):
    ''' Weighted loss over several prediction heads plus a correct-token count.

    Args:
        preds: sequence of logit tensors shaped (..., n_class), presumably
            SRN_Decoder's (e_out, s_out, f_out). At most three are used,
            weighted 1.0, 0.15 and 2.0.
        gold: target class indices, same number of positions as each pred.
        mask: optional tensor with one entry per position; non-zero entries
            count toward loss and accuracy.
        smoothing: loss mode forwarded to cal_loss.

    Returns:
        (loss, n_correct):
        - loss: weighted sum of the non-NaN head losses (0. if all are NaN).
        - n_correct: number of correct predictions at masked-in positions for
          the last head whose loss was not NaN. It is None when mask is None,
          and 0 if every loss was NaN.
    '''
    loss = 0.
    n_correct = 0
    weights = [1.0, 0.15, 2.0]

    gold = gold.contiguous().view(-1)
    # Build the flat boolean mask only when a mask was given; mask=None is valid.
    non_pad_mask = mask.reshape(-1).ne(0) if mask is not None else None

    for ori_pred, weight in zip(preds, weights):
        pred = ori_pred.view(-1, ori_pred.shape[-1])

        tloss = cal_loss(pred, gold, non_pad_mask, smoothing)
        if torch.isnan(tloss):
            print('have nan loss')
            continue
        loss += tloss * weight

        if non_pad_mask is None:
            n_correct = None
        else:
            hits = pred.max(1)[1].eq(gold)
            n_correct = hits.masked_select(non_pad_mask).sum().item()

    return loss, n_correct


def cal_loss(pred, gold, mask, smoothing):
    ''' Cross-entropy loss with optional label smoothing.

    Args:
        pred: (M, n_class) logits.
        gold: target indices with M elements (flattened here).
        mask: optional bool tensor with M elements; True marks positions that
            contribute to the loss.
        smoothing:
            '0'  -> label-smoothed CE (eps=0.1), summed over kept positions
                    (the caller averages). Kept positions are `mask` if given,
                    otherwise those with gold != 0.
            '1'  -> CE averaged over `mask` (NaN if mask selects nothing), or
                    plain mean CE when mask is None.
            else -> plain mean CE; mask ignored.

    Returns:
        scalar loss tensor.
    '''
    gold = gold.contiguous().view(-1)

    if smoothing == '0':
        eps = 0.1
        n_class = pred.size(1)

        one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)

        # Honour the caller's mask, as the '1' branch does. Fall back to
        # treating index 0 as padding only when no mask is supplied.
        non_pad_mask = mask if mask is not None else gold.ne(0)
        loss = -(one_hot * log_prb).sum(dim=1)
        loss = loss.masked_select(non_pad_mask).sum()  # average later
    elif smoothing == '1':
        if mask is not None:
            loss = F.cross_entropy(pred, gold, reduction='none')
            loss = loss.masked_select(mask)
            loss = loss.sum() / mask.sum()
        else:
            loss = F.cross_entropy(pred, gold)
    else:
        loss = F.cross_entropy(pred, gold)

    return loss


def cal_performance2(preds, gold, PAD, smoothing='1'):
    ''' Weighted loss over several prediction heads plus a correct-token count,
    using a PAD index instead of an explicit mask.

    Args:
        preds: sequence of logit tensors shaped (..., n_class), presumably
            SRN_Decoder's (e_out, s_out, f_out). At most three are used,
            weighted 1.0, 0.15 and 2.0.
        gold: target class indices, same number of positions as each pred.
        PAD: padding index; padded positions are ignored by the loss
            (mode-dependent, see cal_loss2) and by the accuracy count.
        smoothing: loss mode forwarded to cal_loss2.

    Returns:
        (loss, n_correct):
        - loss: weighted sum of the non-NaN head losses.
        - n_correct: number of correct non-PAD predictions of the last head
          whose loss was not NaN (0 if every loss was NaN).
    '''
    loss = 0.
    n_correct = 0
    weights = [1.0, 0.15, 2.0]

    # Flatten targets and build the PAD mask once, outside the loop.
    gold = gold.contiguous().view(-1)
    non_pad_mask = gold.ne(PAD)

    for ori_pred, weight in zip(preds, weights):
        pred = ori_pred.view(-1, ori_pred.shape[-1])

        tloss = cal_loss2(pred, gold, PAD, smoothing=smoothing)
        if torch.isnan(tloss):
            print('have nan loss')
            continue
        loss += tloss * weight

        hits = pred.max(1)[1].eq(gold)
        n_correct = hits.masked_select(non_pad_mask).sum().item()

    return loss, n_correct


def cal_loss2(pred, gold, PAD, smoothing='1'):
    ''' Cross-entropy loss with optional label smoothing, ignoring PAD targets.

    Args:
        pred: (M, n_class) logits.
        gold: target indices with M elements (flattened here).
        PAD: padding index.
        smoothing:
            '0'  -> label-smoothed CE (eps=0.1), summed over non-PAD positions
                    (the caller averages).
            '1'  -> CE averaged over non-PAD positions (ignore_index=PAD).
            else -> plain mean CE over all positions.

    Returns:
        scalar loss tensor.
    '''
    gold = gold.contiguous().view(-1)

    if smoothing == '0':
        eps = 0.1
        n_class = pred.size(1)

        one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)

        # Use the caller's PAD index (previously hard-coded 0, which dropped a real class).
        non_pad_mask = gold.ne(PAD)
        loss = -(one_hot * log_prb).sum(dim=1)
        loss = loss.masked_select(non_pad_mask).sum()  # average later
    elif smoothing == '1':
        loss = F.cross_entropy(pred, gold, ignore_index=PAD)
    else:
        loss = F.cross_entropy(pred, gold)

    return loss


if __name__ == '__main__':
    # Smoke test: feed a random feature map (batch 2, 256 positions,
    # 512 channels) through the encoder and the SRN decoder, then print the
    # shapes of the three decoder outputs.
    features = torch.rand((2, 256, 512))
    encoder = Transforme_Encoder()
    encoded = encoder(features, src_mask=None)[0]
    decoder = SRN_Decoder(N_max_character=30)

    for logits in decoder(encoded):
        print(logits.shape)