import os
import time
import pickle
import datetime
import itertools
import numpy as np
import torch
import torch.nn.functional as F

from onmt_modules.misc import sequence_mask
from model_autopst import Generator_2 as Predictor



class Solver(object):
    """Fine-tune a pretrained ``Predictor`` on spectrogram reconstruction.

    The ``encoder_cd`` sub-module is frozen. The rest of the model is trained
    with Adam on a masked MSE spectrogram loss plus a BCE-with-logits stop
    loss. Checkpoints are written to ``./assets`` every ``save_step``
    iterations.
    """

    def __init__(self, data_loader, config, hparams):
        """Initialize configurations.

        Args:
            data_loader: iterable of batches (see ``train`` for the expected
                9-tuple layout); it is re-iterated whenever it is exhausted.
            config: run options; reads ``device_id``, ``num_iters``,
                ``log_step`` and, optionally, ``save_step`` (default 10000).
            hparams: model hyper-parameters passed to ``Predictor``; also
                reads ``gate_threshold`` and ``pretrained_path``.
        """
        self.data_loader = data_loader
        self.hparams = hparams
        self.gate_threshold = hparams.gate_threshold

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:{}'.format(config.device_id) if self.use_cuda else 'cpu')
        self.num_iters = config.num_iters
        self.log_step = config.log_step
        # Checkpoint interval; was hard-coded to 10000, which remains the default.
        self.save_step = getattr(config, 'save_step', 10000)

        # Build the model
        self.build_model()

    def build_model(self):
        """Create the predictor, freeze its encoder, and load pretrained weights.

        The model is moved to ``self.device`` *before* the optimizer is built,
        as the ``torch.optim`` documentation recommends.
        """
        self.P = Predictor(self.hparams)
        self.freeze_layers(self.P.encoder_cd)
        self.P.to(self.device)

        # Frozen parameters are also handed to Adam. They never receive a
        # gradient, so Adam leaves them untouched. Keeping the full parameter
        # list preserves the optimizer state_dict layout of saved checkpoints.
        self.optimizer = torch.optim.Adam(self.P.parameters(), 0.0001, [0.9, 0.999])

        self.BCELoss = torch.nn.BCEWithLogitsLoss().to(self.device)

        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        checkpoint = torch.load(self.hparams.pretrained_path,
                                map_location=lambda storage, loc: storage)

        # strict=True: the checkpoint keys must match the model's exactly.
        self.P.load_state_dict(checkpoint['model'], strict=True)
        print('Loaded pretrained encoder .........................................')

    def freeze_layers(self, layer):
        """Disable gradient computation for every parameter of ``layer``."""
        print('Fixing layers!')
        for param in layer.parameters():
            param.requires_grad = False

    def _next_batch(self, data_iter):
        """Return ``(batch, data_iter)``, restarting ``self.data_loader`` when exhausted.

        Only ``StopIteration`` triggers a restart; any other error raised while
        loading a batch propagates instead of being silently swallowed.
        """
        try:
            return next(data_iter), data_iter
        except StopIteration:
            data_iter = iter(self.data_loader)
            return next(data_iter), data_iter

    @staticmethod
    def _format_elapsed(seconds):
        """Format a duration in seconds as ``H:MM:SS``, dropping the fraction.

        Unlike ``str(timedelta(...))[:-7]``, this does not mangle durations
        whose microsecond part is exactly zero.
        """
        return str(datetime.timedelta(seconds=int(seconds)))

    def train(self):
        """Run ``self.num_iters`` optimisation steps with logging and checkpointing.

        Each batch must unpack as
        ``(sp, cep, cd, num_rep, _, len_real, len_short, _, spk_emb)``.
        The ``cd`` element is not used by this loop.
        """
        # Set data loader
        data_iter = iter(self.data_loader)

        # Print logs in specified order
        keys = ['P/loss_tx2sp', 'P/loss_stop_sp']

        # Nothing below switches the model to eval mode, so set train mode once.
        self.P.train()

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(self.num_iters):

            batch, data_iter = self._next_batch(data_iter)
            sp_real, cep_real, _, num_rep, _, len_real, len_short, _, spk_emb = batch

            sp_real = sp_real.to(self.device)
            cep_real = cep_real.to(self.device)
            len_real = len_real.to(self.device)
            spk_emb = spk_emb.to(self.device)
            num_rep = num_rep.to(self.device)
            len_short = len_short.to(self.device)

            # Real-spectrogram masks. This assumes sequence_mask(lengths, T) is
            # True on valid frames (onmt convention), so mask_sp_real is True
            # on padding. TODO confirm.
            mask_sp_real = ~sequence_mask(len_real, sp_real.size(1))
            mask_long = (~mask_sp_real).float()

            # The reconstruction loss also covers up to 10 frames past each
            # sequence end, capped at the padded length.
            len_real_mask = torch.min(len_real + 10,
                                      torch.full_like(len_real, sp_real.size(1)))
            loss_tx2sp_mask = sequence_mask(len_real_mask, sp_real.size(1)).float().unsqueeze(-1)

            # text input masks
            codes_mask = sequence_mask(len_short, num_rep.size(1)).float()

            # =================================================================================== #
            #                                    2. Train                                         #
            # =================================================================================== #

            # Target shifted right by one step along dim 1, with zeros in the
            # first frame; presumably the teacher-forced decoder input.
            sp_real_sft = torch.zeros_like(sp_real)
            sp_real_sft[:, 1:, :] = sp_real[:, :-1, :]

            spect_pred, stop_pred_sp = self.P(cep_real.transpose(2,1),
                                              mask_long,
                                              codes_mask,
                                              num_rep,
                                              len_short+1,
                                              sp_real_sft.transpose(1,0),
                                              len_real+1,
                                              spk_emb)

            # Squared error summed over the feature dim, averaged over unmasked
            # frames. The predictions come back with dims 0/1 swapped relative
            # to sp_real, hence the permute.
            loss_tx2sp = (F.mse_loss(spect_pred.permute(1,0,2), sp_real, reduction='none')
                          * loss_tx2sp_mask).sum() / loss_tx2sp_mask.sum()

            # Stop-token loss: the target is 1 wherever mask_sp_real is True.
            loss_stop_sp = self.BCELoss(stop_pred_sp.squeeze(-1).t(), mask_sp_real.float())

            loss_total = loss_tx2sp + loss_stop_sp

            # Backward and optimize
            self.optimizer.zero_grad()
            loss_total.backward()
            self.optimizer.step()

            # Logging
            loss = {}
            loss['P/loss_tx2sp'] = loss_tx2sp.item()
            loss['P/loss_stop_sp'] = loss_stop_sp.item()

            # =================================================================================== #
            #                                 4. Miscellaneous                                    #
            # =================================================================================== #

            # Print out training information
            if (i+1) % self.log_step == 0:
                et = self._format_elapsed(time.time() - start_time)
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
                for tag in keys:
                    log += ", {}: {:.8f}".format(tag, loss[tag])
                print(log)

            # Save model checkpoints (create the target directory if missing).
            if (i+1) % self.save_step == 0:
                os.makedirs('./assets', exist_ok=True)
                torch.save({'model': self.P.state_dict(),
                            'optimizer': self.optimizer.state_dict()}, f'./assets/{i+1}-B.ckpt')
                print('Saved model checkpoints into assets ...')