import torch.optim as optim
import numpy as np
import torch
from torch.autograd import Variable

import random
from torch.nn.utils import clip_grad_norm
import copy

import os
import pickle



def get_decoder_index_XY(batchY):
    '''
    Build decoder input/target index sequences from boundary label rows.

    For every row the decoder targets are the positions whose label equals 1.
    The decoder inputs are ``0`` followed by ``target + 1`` for every target
    except the last one.

    :param batchY: iterable of label rows, e.g. ``[0 0 1 0 0 0 0 1]``; each
        row is converted with ``np.asarray`` before comparison, so plain
        lists are accepted as well as numpy arrays
    :return: ``(returnX, returnY)``. When every row has the same number of
        boundaries both are regular 2-D integer arrays; otherwise they are
        1-D ``dtype=object`` arrays holding one index array per row
    '''
    decoder_inputs = []
    decoder_targets = []
    for cur_y in batchY:
        targets = np.where(np.asarray(cur_y) == 1)[0]
        # With a single boundary ``targets[:-1]`` is empty and this is just [0],
        # so no special case is needed.
        decoder_inputs.append(np.append([0], targets[:-1] + 1))
        decoder_targets.append(targets)

    return _stack_index_rows(decoder_inputs), _stack_index_rows(decoder_targets)


def _stack_index_rows(rows):
    '''Pack 1-D index arrays into one array, tolerating rows of unequal length.'''
    if len({len(row) for row in rows}) <= 1:
        # Uniform (or empty) input stacks into a regular array.
        return np.array(rows)
    # np.array() on ragged rows raises ValueError on NumPy >= 1.24, so the
    # object array is filled slot by slot instead.
    packed = np.empty(len(rows), dtype=object)
    for pos, row in enumerate(rows):
        packed[pos] = row
    return packed

def align_variable_numpy(X, maxL, paddingNumber):
    '''
    Right-pad every sequence in ``X`` to ``maxL`` items and stack the rows.

    :param X: iterable of sequences
    :param maxL: target length; sequences already at least this long are
        copied unchanged (nothing is truncated)
    :param paddingNumber: value appended to reach ``maxL``
    :return: numpy array built from the padded rows
    '''
    padded_rows = [
        list(seq) + [paddingNumber] * (maxL - len(seq))
        for seq in X
    ]
    return np.array(padded_rows)


def sample_a_sorted_batch_from_numpy(numpyX,numpyY,batch_size,use_cuda):
    '''
    Turn parallel token-id / label sequences into one padded batch.

    Despite the name, the batch keeps the input order. The previous
    ``np.sort(np.argsort(lengths))`` always produced the identity permutation,
    and ``TrainSolver.check_accuracy`` pairs the returned rows with other
    per-item data by position, so no reordering is applied here.

    :param numpyX: indexable of token-id sequences; its length defines the
        batch size
    :param numpyY: indexable of label rows, read at the same positions as
        ``numpyX`` (assumed to have the same per-item lengths -- TODO confirm)
    :param batch_size: unused; kept for interface compatibility
    :param use_cuda: move the padded id tensor to the GPU when true
    :return: tuple ``(numpy_batch_x, batch_x, batch_y, index_decoder_X,
        index_decoder_Y, all_lens, maxL)`` where ``numpy_batch_x`` is the list
        of unpadded deep copies of the inputs, ``batch_x`` the padded int64
        tensor, ``batch_y`` the padded label array, the two decoder index
        arrays come from ``get_decoder_index_XY``, ``all_lens`` holds the
        label-row lengths and ``maxL`` is their maximum
    :raises ValueError: if the batch is empty (maximum of an empty array)
    '''
    # presumably the id of the padding token in the vocabulary -- TODO confirm
    x_padding_id = 2000001
    # label value for padded positions; differs from the 0/1 boundary labels
    y_padding_label = 2

    item_count = len(numpyX)
    # Deep copies keep later processing from touching the caller's data.
    batch_x = [copy.deepcopy(numpyX[i]) for i in range(item_count)]
    batch_y = [copy.deepcopy(numpyY[i]) for i in range(item_count)]

    # The arrays are returned exactly as built: re-wrapping per-row index
    # arrays of unequal length in np.array() raises on NumPy >= 1.24.
    index_decoder_X, index_decoder_Y = get_decoder_index_XY(batch_y)

    all_lens = np.array([len(labels) for labels in batch_y])
    maxL = np.max(all_lens)

    numpy_batch_x = batch_x

    padded_x = align_variable_numpy(batch_x, maxL, x_padding_id)
    batch_y = align_variable_numpy(batch_y, maxL, y_padding_label)
    batch_x = Variable(torch.from_numpy(np.array(padded_x, dtype="int64")))

    if use_cuda:
        batch_x = batch_x.cuda()

    return numpy_batch_x, batch_x, batch_y, index_decoder_X, index_decoder_Y, all_lens, maxL




class TrainSolver(object):
    '''
    Holds a model together with data sets and hyper-parameters, and turns the
    model's predicted boundaries into text segments.

    The methods defined here read only ``model``, ``batch_size`` and
    ``use_cuda``; the remaining constructor arguments are stored as
    attributes without being used in this section.
    '''

    def __init__(self, model,train_x,train_y,dev_x,dev_y,save_path,batch_size,eval_size,epoch, lr,lr_decay_epoch,weight_decay,use_cuda):
        '''Store every argument on the instance under the same name.'''
        self.lr = lr
        self.model = model
        self.epoch = epoch
        self.train_x = train_x
        self.train_y = train_y
        self.use_cuda = use_cuda
        self.batch_size = batch_size
        self.lr_decay_epoch = lr_decay_epoch
        self.eval_size = eval_size

        self.dev_x, self.dev_y = dev_x, dev_y

        self.save_path = save_path
        self.weight_decay = weight_decay

    def get_batch_micro_metric(self,pre_b, ground_b, x,index2word, fukugen, nloop):
        '''
        Cut each row's surface strings into segments at predicted boundaries.

        For every row the last ground-truth boundary is removed from both the
        predicted and the ground-truth boundary sets. A segment is closed
        after each position that is a remaining predicted boundary but not a
        remaining ground-truth boundary. Whatever is left when the row ends
        is emitted as a final segment, which may be an empty string.

        Despite the name, no metric is computed; the segments are returned.

        :param pre_b: per-row predicted boundary positions
        :param ground_b: per-row numpy label arrays; positions equal to 1 are
            ground-truth boundaries
        :param x: per-row 1-D torch tensors of token ids
        :param index2word: iterable whose enumeration order maps id -> token
        :param fukugen: per-row sequences of strings that are concatenated
            into the output (presumably the original surface forms of the
            tokens -- TODO confirm)
        :param nloop: unused; kept for interface compatibility
        :return: flat list of segment strings for all rows, in row order
        :raises IndexError: if a ground-truth row contains no 1
        '''
        tokendic = dict(enumerate(index2word))

        sents = []
        for row, cur_seq_y in enumerate(ground_b):
            surfaces = fukugen[row]
            gold = np.where(cur_seq_y == 1)[0]
            predicted = np.array(pre_b[row])

            # The last ground-truth boundary is the end of the row; it is
            # never treated as a split point.
            end_b = gold[-1]
            predicted = predicted[predicted != end_b]
            gold = gold[gold != end_b]

            # Positions predicted as boundaries that the ground truth lacks.
            split_after = set(predicted.tolist()) - set(gold.tolist())

            # One device-to-host conversion per row instead of one per token.
            tokens = [tokendic[int(tok)] for tok in x[row].detach().cpu().tolist()]

            pieces = []
            for pos, (token, surface) in enumerate(zip(tokens, surfaces)):
                # The original compared the whole (token, surface) tuple with
                # "<pad>", which is never true; the token is what must be
                # checked.
                if token == "<pad>":
                    continue
                pieces.append(surface)
                if pos in split_after:
                    sents.append("".join(pieces))
                    pieces = []
            sents.append("".join(pieces))
        return sents

    def check_accuracy(self,data2X,data2Y,index2word, fukugen2):
        '''
        Run the model over the first data set in batches and collect the
        segmented texts.

        Only index 0 of ``data2X`` / ``data2Y`` / ``fukugen2`` is processed.

        :param data2X: sequence whose first entry holds the token-id sequences
        :param data2Y: sequence whose first entry holds the label rows
        :param index2word: forwarded to ``get_batch_micro_metric``
        :param fukugen2: sequence whose first entry holds the per-item surface
            strings
        :return: list of segment strings from every batch, in input order;
            empty when the data set is empty
        '''
        nloop = 0
        dataY = data2Y[nloop]
        dataX = data2X[nloop]
        fukugen = fukugen2[nloop]

        # Accumulate across batches: assigning inside the loop kept only the
        # last batch and left the name unbound for an empty data set.
        output_texts = []
        total = len(dataY)
        for startN in range(0, total, self.batch_size):
            endN = min(startN + self.batch_size, total)
            fukuge = fukugen[startN:endN]
            numpy_batch_x, batch_x, batch_y, index_decoder_X, index_decoder_Y, all_lens, maxL = sample_a_sorted_batch_from_numpy(
                dataX[startN:endN], dataY[startN:endN], None, self.use_cuda)

            # Only the predicted boundaries (second value) are used here.
            batch_ave_loss, batch_boundary, batch_boundary_start, batch_align_matrix = self.model.predict(batch_x,index_decoder_Y,all_lens)
            output_texts.extend(
                self.get_batch_micro_metric(batch_boundary,batch_y,batch_x,index2word, fukuge, nloop))

        return output_texts