import copy
import datetime
import os
import random
import time

import numpy as np
import torch
from tqdm import tqdm

from openrec.losses import build_loss
from openrec.metrics import build_metric
from openrec.modeling import build_model
from openrec.optimizer import build_optimizer
from openrec.postprocess import build_post_process
from tools.data import build_dataloader
from tools.utils.ckpt import load_ckpt, save_ckpt
from tools.utils.logging import get_logger
from tools.utils.stats import TrainingStats
from tools.utils.utility import AverageMeter

__all__ = ['Trainer']


def get_parameter_number(model):
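    """Return the total and trainable parameter counts of ``model``."""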
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters()
                        if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}


class Trainer(object):

    def __init__(self, cfg, mode='train'):
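        """Build the dataloaders, model, loss, optimizer, metric, and
        post-processing from ``cfg`` according to ``mode``."""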
        self.cfg = cfg.cfg

        self.local_rank = int(os.environ.get('LOCAL_RANK', 0))
        self.set_device(self.cfg['Global']['device'])
        mode = mode.lower()
        assert mode in [
            'train_eval',
            'train',
            'eval',
            'test',
        ], 'mode should be train_eval, train, eval or test'
        if torch.cuda.device_count() > 1 and 'train' in mode:
            torch.distributed.init_process_group(backend='nccl')
            torch.cuda.set_device(self.device)
            self.cfg['Global']['distributed'] = True
        else:
            self.cfg['Global']['distributed'] = False
            self.local_rank = 0

        self.cfg['Global']['output_dir'] = self.cfg['Global'].get(
            'output_dir', 'output')
        os.makedirs(self.cfg['Global']['output_dir'], exist_ok=True)

        self.writer = None
        if self.local_rank == 0 and self.cfg['Global'][
                'use_tensorboard'] and 'train' in mode:
            from torch.utils.tensorboard import SummaryWriter

            self.writer = SummaryWriter(self.cfg['Global']['output_dir'])

        self.logger = get_logger(
            'openrec',
            os.path.join(self.cfg['Global']['output_dir'], 'train.log')
            if 'train' in mode else None,
        )

        cfg.print_cfg(self.logger.info)

        if self.cfg['Global']['device'] == 'gpu' and self.device.type == 'cpu':
            self.logger.info('CUDA is not available; falling back to CPU')

        self.grad_clip_val = self.cfg['Global'].get('grad_clip_val', 0)
        self.all_ema = self.cfg['Global'].get('all_ema', True)
        self.use_ema = self.cfg['Global'].get('use_ema', True)

        self.set_random_seed(self.cfg['Global'].get('seed', 48))

        # build data loader
        self.train_dataloader = None
        if 'train' in mode:
            cfg.save(
                os.path.join(self.cfg['Global']['output_dir'], 'config.yml'),
                self.cfg)
            self.train_dataloader = build_dataloader(self.cfg, 'Train',
                                                     self.logger)
            self.logger.info(
                f'train dataloader has {len(self.train_dataloader)} iters')
        self.valid_dataloader = None
        if 'eval' in mode and self.cfg['Eval']:
            self.valid_dataloader = build_dataloader(self.cfg, 'Eval',
                                                     self.logger)
            self.logger.info(
                f'valid dataloader has {len(self.valid_dataloader)} iters')

        # build post process
        self.post_process_class = build_post_process(self.cfg['PostProcess'],
                                                     self.cfg['Global'])
        # build model
        # for rec algorithms, the decoder's out_channels must match the
        # character-set size reported by the post-process label converter
        char_num = self.post_process_class.get_character_num()
        self.cfg['Architecture']['Decoder']['out_channels'] = char_num

        self.model = build_model(self.cfg['Architecture'])
        self.logger.info(get_parameter_number(model=self.model))
        self.model = self.model.to(self.device)

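        # keep a shadow copy of the model on rank 0 for EMA; it is updated
        # in train() and is only ever used for evaluation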
        if self.local_rank == 0:
            ema_model = build_model(self.cfg['Architecture'])
            self.ema_model = ema_model.to(self.device)
            self.ema_model.eval()

        use_sync_bn = self.cfg['Global'].get('use_sync_bn', False)
        if use_sync_bn:
            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                self.model)
            self.logger.info('convert_sync_batchnorm')

        # build loss
        self.loss_class = build_loss(self.cfg['Loss'])

        self.optimizer, self.lr_scheduler = None, None
        if self.train_dataloader is not None:
            # build optim
            self.optimizer, self.lr_scheduler = build_optimizer(
                self.cfg['Optimizer'],
                self.cfg['LRScheduler'],
                epochs=self.cfg['Global']['epoch_num'],
                step_each_epoch=len(self.train_dataloader),
                model=self.model,
            )

        self.eval_class = build_metric(self.cfg['Metric'])

        self.status = load_ckpt(self.model, self.cfg, self.optimizer,
                                self.lr_scheduler)

        if self.cfg['Global']['distributed']:
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model, [self.local_rank], find_unused_parameters=False)

        # automatic mixed precision: create a GradScaler only when
        # Global.use_amp is enabled
        self.scaler = (torch.cuda.amp.GradScaler() if self.cfg['Global'].get(
            'use_amp', False) else None)

        self.logger.info(
            f'run with torch {torch.__version__} and device {self.device}')

    def load_params(self, params):
        self.model.load_state_dict(params)

    def set_random_seed(self, seed):
        torch.manual_seed(seed)  # seed the CPU RNG
        if self.device.type == 'cuda':
            torch.backends.cudnn.benchmark = True
            torch.cuda.manual_seed(seed)  # seed the current GPU
            torch.cuda.manual_seed_all(seed)  # seed all GPUs
        random.seed(seed)
        np.random.seed(seed)

    def set_device(self, device):
        if device == 'gpu' and torch.cuda.is_available():
            device = torch.device(f'cuda:{self.local_rank}')
        else:
            device = torch.device('cpu')
        self.device = device

    def train(self):
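        """Run the training loop: step- and epoch-based evaluation, optional
        EMA of the weights in the final epochs, and periodic checkpointing."""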
        cal_metric_during_train = self.cfg['Global'].get(
            'cal_metric_during_train', False)
        log_smooth_window = self.cfg['Global']['log_smooth_window']
        epoch_num = self.cfg['Global']['epoch_num']
        print_batch_step = self.cfg['Global']['print_batch_step']
        eval_epoch_step = self.cfg['Global'].get('eval_epoch_step', 1)

        start_eval_epoch = 0
        if self.valid_dataloader is not None:
            if isinstance(eval_epoch_step, list) and len(eval_epoch_step) >= 2:
                start_eval_epoch = eval_epoch_step[0]
                eval_epoch_step = eval_epoch_step[1]
                if len(self.valid_dataloader) == 0:
                    start_eval_epoch = float('inf')
                    self.logger.info(
                        'No Images in eval dataset, evaluation during training will be disabled'
                    )
                self.logger.info(
                    f'During the training process, after the {start_eval_epoch}th epoch, '
                    f'an evaluation is run every {eval_epoch_step} epochs')
        else:
            start_eval_epoch = float('inf')

        eval_batch_step = self.cfg['Global']['eval_batch_step']

        global_step = self.status.get('global_step', 0)

        start_eval_step = 0
        if isinstance(eval_batch_step, list) and len(eval_batch_step) >= 2:
            start_eval_step = eval_batch_step[0]
            eval_batch_step = eval_batch_step[1]
            if self.valid_dataloader is None or len(self.valid_dataloader) == 0:
                self.logger.info(
                    'No Images in eval dataset, evaluation during training '
                    'will be disabled')
                start_eval_step = float('inf')
            self.logger.info(
                'During the training process, after the {}th iteration, '
                'an evaluation is run every {} iterations'.format(
                    start_eval_step, eval_batch_step))

        start_epoch = self.status.get('epoch', 1)
        best_metric = self.status.get('metrics', {})
        if self.eval_class.main_indicator not in best_metric:
            best_metric[self.eval_class.main_indicator] = 0
        # use a fresh dict here: self.status.get('metrics', {}) would alias
        # best_metric, and resetting the indicator would clobber it
        ema_best_metric = {self.eval_class.main_indicator: 0}
        train_stats = TrainingStats(log_smooth_window, ['lr'])
        self.model.train()

        total_samples = 0
        train_reader_cost = 0.0
        train_batch_cost = 0.0
        best_iter = 0
        ema_step = 1
        ema_eval_iter = 0
        loss_avg = 0.
        reader_start = time.time()
        eta_meter = AverageMeter()

        for epoch in range(start_epoch, epoch_num + 1):
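            # rebuild the train dataloader when the dataset signals a reset;
            # the epoch argument cycles through 1..20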
            if self.train_dataloader.dataset.need_reset:
                self.train_dataloader = build_dataloader(
                    self.cfg,
                    'Train',
                    self.logger,
                    epoch=epoch % 20 if epoch % 20 != 0 else 20,
                )

            for idx, batch in enumerate(self.train_dataloader):
                batch = [t.to(self.device) for t in batch]
                self.optimizer.zero_grad()
                train_reader_cost += time.time() - reader_start
                # use amp
                if self.scaler:
                    with torch.cuda.amp.autocast():
                        preds = self.model(batch[0], data=batch[1:])
                        loss = self.loss_class(preds, batch)
                    self.scaler.scale(loss['loss']).backward()
                    if self.grad_clip_val > 0:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(),
                            max_norm=self.grad_clip_val)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    preds = self.model(batch[0], data=batch[1:])
                    loss = self.loss_class(preds, batch)
                    avg_loss = loss['loss']
                    avg_loss.backward()
                    if self.grad_clip_val > 0:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(),
                            max_norm=self.grad_clip_val)
                    self.optimizer.step()

                if cal_metric_during_train:  # only needed by rec and cls tasks
                    post_result = self.post_process_class(preds,
                                                          batch,
                                                          training=True)
                    self.eval_class(post_result, batch, training=True)
                    metric = self.eval_class.get_metric()
                    train_stats.update(metric)

                train_batch_time = time.time() - reader_start
                train_batch_cost += train_batch_time
                eta_meter.update(train_batch_time)
                global_step += 1
                total_samples += len(batch[0])

                self.lr_scheduler.step()

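                # EMA of the weights runs only in the final ~10% of epochs;
                # unless all_ema is set, a weight snapshot is folded in only
                # when its loss does not exceed the running mean loss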
                if self.local_rank == 0 and self.use_ema and epoch > (
                        epoch_num - epoch_num // 10):
                    with torch.no_grad():
                        loss_curr = loss['loss'].detach().cpu().numpy().mean()
                        # running mean of the loss over the EMA window
                        loss_avg = ((loss_avg * (ema_step - 1)) +
                                    loss_curr) / ema_step
                        if ema_step == 1:
                            # initialize the EMA weights from the current model
                            ema_state_dict = copy.deepcopy(
                                self.model.module.state_dict()
                                if self.cfg['Global']['distributed'] else
                                self.model.state_dict())
                            self.ema_model.load_state_dict(ema_state_dict)
                        elif loss_curr <= loss_avg or self.all_ema:
                            # fold the snapshot into the cumulative average
                            # with weight 1 / (ema_step + 1)
                            current_weight = copy.deepcopy(
                                self.model.module.state_dict()
                                if self.cfg['Global']['distributed'] else
                                self.model.state_dict())
                            k1 = 1 / (ema_step + 1)
                            k2 = 1 - k1
                            for k, v in ema_state_dict.items():
                                ema_state_dict[k] = v * k2 + current_weight[k] * k1
                            self.ema_model.load_state_dict(ema_state_dict)
                    ema_step += 1
                    if global_step > start_eval_step and (
                            global_step -
                            start_eval_step) % eval_batch_step == 0:
                        ema_cur_metric = self.eval_ema()
                        ema_cur_metric_str = f"cur ema metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_cur_metric.items()])}"
                        self.logger.info(ema_cur_metric_str)
                        state = {
                            'epoch': epoch,
                            'global_step': global_step,
                            'state_dict': self.ema_model.state_dict(),
                            'optimizer': None,
                            'scheduler': None,
                            'config': self.cfg,
                            'metrics': ema_cur_metric,
                        }
                        save_path = os.path.join(
                            self.cfg['Global']['output_dir'],
                            'ema_' + str(ema_eval_iter) + '.pth')
                        torch.save(state, save_path)
                        self.logger.info(f'save ema ckpt to {save_path}')
                        ema_eval_iter += 1
                        if (ema_cur_metric[self.eval_class.main_indicator] >=
                                ema_best_metric[self.eval_class.main_indicator]):
                            ema_best_metric.update(ema_cur_metric)
                            ema_best_metric['best_epoch'] = epoch
                        best_ema_str = f"best ema metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_best_metric.items()])}"
                        self.logger.info(best_ema_str)

                # logger
                # scalar losses become floats; vector losses are averaged
                stats = {
                    k: float(v) if v.ndim == 0 else
                    v.detach().cpu().numpy().mean()
                    for k, v in loss.items()
                }
                stats['lr'] = self.lr_scheduler.get_last_lr()[0]
                train_stats.update(stats)

                if self.writer is not None:
                    for k, v in train_stats.get().items():
                        self.writer.add_scalar(f'TRAIN/{k}', v, global_step)

                if self.local_rank == 0 and (
                    (global_step > 0 and global_step % print_batch_step == 0)
                        or (idx >= len(self.train_dataloader) - 1)):
                    logs = train_stats.log()

                    eta_sec = (
                        (epoch_num + 1 - epoch) * len(self.train_dataloader) -
                        idx - 1) * eta_meter.avg
                    eta_sec_format = str(
                        datetime.timedelta(seconds=int(eta_sec)))
                    strs = (
                        f'epoch: [{epoch}/{epoch_num}], global_step: {global_step}, {logs}, '
                        f'avg_reader_cost: {train_reader_cost / print_batch_step:.5f} s, '
                        f'avg_batch_cost: {train_batch_cost / print_batch_step:.5f} s, '
                        f'avg_samples: {total_samples / print_batch_step}, '
                        f'ips: {total_samples / train_batch_cost:.5f} samples/s, '
                        f'eta: {eta_sec_format}')
                    self.logger.info(strs)
                    total_samples = 0
                    train_reader_cost = 0.0
                    train_batch_cost = 0.0
                reader_start = time.time()
                # step-based evaluation on rank 0
                if (global_step > start_eval_step and
                    (global_step - start_eval_step) % eval_batch_step
                        == 0) and self.local_rank == 0:
                    cur_metric = self.eval()
                    cur_metric_str = f"cur metric, {', '.join(['{}: {}'.format(k, v) for k, v in cur_metric.items()])}"
                    self.logger.info(cur_metric_str)

                    # logger metric
                    if self.writer is not None:
                        for k, v in cur_metric.items():
                            if isinstance(v, (float, int)):
                                self.writer.add_scalar(f'EVAL/{k}',
                                                       cur_metric[k],
                                                       global_step)

                    if (cur_metric[self.eval_class.main_indicator] >=
                            best_metric[self.eval_class.main_indicator]):
                        best_metric.update(cur_metric)
                        best_metric['best_epoch'] = epoch
                        if self.writer is not None:
                            self.writer.add_scalar(
                                f'EVAL/best_{self.eval_class.main_indicator}',
                                best_metric[self.eval_class.main_indicator],
                                global_step,
                            )
                        if epoch > (epoch_num - epoch_num // 10 - 2):
                            save_ckpt(self.model,
                                      self.cfg,
                                      self.optimizer,
                                      self.lr_scheduler,
                                      epoch,
                                      global_step,
                                      best_metric,
                                      is_best=True,
                                      prefix='best_' + str(best_iter))
                            best_iter += 1
                        save_ckpt(self.model,
                                  self.cfg,
                                  self.optimizer,
                                  self.lr_scheduler,
                                  epoch,
                                  global_step,
                                  best_metric,
                                  is_best=True,
                                  prefix=None)
                    best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in best_metric.items()])}"
                    self.logger.info(best_str)
            if self.local_rank == 0 and epoch > start_eval_epoch and (
                    epoch - start_eval_epoch) % eval_epoch_step == 0:
                cur_metric = self.eval()
                cur_metric_str = f"cur metric, {', '.join(['{}: {}'.format(k, v) for k, v in cur_metric.items()])}"
                self.logger.info(cur_metric_str)

                # logger metric
                if self.writer is not None:
                    for k, v in cur_metric.items():
                        if isinstance(v, (float, int)):
                            self.writer.add_scalar(f'EVAL/{k}', cur_metric[k],
                                                   global_step)

                if (cur_metric[self.eval_class.main_indicator] >=
                        best_metric[self.eval_class.main_indicator]):
                    best_metric.update(cur_metric)
                    best_metric['best_epoch'] = epoch
                    if self.writer is not None:
                        self.writer.add_scalar(
                            f'EVAL/best_{self.eval_class.main_indicator}',
                            best_metric[self.eval_class.main_indicator],
                            global_step,
                        )
                    if epoch > (epoch_num - epoch_num // 10 - 2):
                        save_ckpt(self.model,
                                  self.cfg,
                                  self.optimizer,
                                  self.lr_scheduler,
                                  epoch,
                                  global_step,
                                  best_metric,
                                  is_best=True,
                                  prefix='best_' + str(best_iter))
                        best_iter += 1
                    save_ckpt(self.model,
                              self.cfg,
                              self.optimizer,
                              self.lr_scheduler,
                              epoch,
                              global_step,
                              best_metric,
                              is_best=True,
                              prefix=None)
                best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in best_metric.items()])}"
                self.logger.info(best_str)

            if self.local_rank == 0:
                save_ckpt(self.model,
                          self.cfg,
                          self.optimizer,
                          self.lr_scheduler,
                          epoch,
                          global_step,
                          best_metric,
                          is_best=False,
                          prefix=None)
                if epoch > (epoch_num - epoch_num // 10 - 2):
                    save_ckpt(self.model,
                              self.cfg,
                              self.optimizer,
                              self.lr_scheduler,
                              epoch,
                              global_step,
                              best_metric,
                              is_best=False,
                              prefix='epoch_' + str(epoch))
                if self.use_ema and epoch > (epoch_num - epoch_num // 10):
                    ema_cur_metric = self.eval_ema()
                    ema_cur_metric_str = f"cur ema metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_cur_metric.items()])}"
                    self.logger.info(ema_cur_metric_str)
                    state = {
                        'epoch': epoch,
                        'global_step': global_step,
                        'state_dict': self.ema_model.state_dict(),
                        'optimizer': None,
                        'scheduler': None,
                        'config': self.cfg,
                        'metrics': ema_cur_metric,
                    }
                    save_path = os.path.join(
                        self.cfg['Global']['output_dir'],
                        'ema_' + str(ema_eval_iter) + '.pth')
                    torch.save(state, save_path)
                    self.logger.info(f'save ema ckpt to {save_path}')
                    ema_eval_iter += 1
                    if (ema_cur_metric[self.eval_class.main_indicator] >=
                            ema_best_metric[self.eval_class.main_indicator]):
                        ema_best_metric.update(ema_cur_metric)
                        ema_best_metric['best_epoch'] = epoch
                    best_ema_str = f"best ema metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_best_metric.items()])}"
                    self.logger.info(best_ema_str)
        best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in best_metric.items()])}"
        self.logger.info(best_str)
        if self.writer is not None:
            self.writer.close()
        if self.cfg['Global']['distributed']:
            torch.distributed.destroy_process_group()

    def eval(self):
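        """Evaluate the current model on the validation set and return the
        metric dict (including an ``fps`` entry)."""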
        self.model.eval()
        with torch.no_grad():
            total_frame = 0.0
            total_time = 0.0
            pbar = tqdm(
                total=len(self.valid_dataloader),
                desc='eval model:',
                position=0,
                leave=True,
            )
            sum_images = 0
            for idx, batch in enumerate(self.valid_dataloader):
                batch = [t.to(self.device) for t in batch]
                start = time.time()
                if self.scaler:
                    with torch.cuda.amp.autocast():
                        preds = self.model(batch[0], data=batch[1:])
                else:
                    preds = self.model(batch[0], data=batch[1:])

                total_time += time.time() - start
                # Obtain usable results from post-processing methods
                # Evaluate the results of the current batch
                post_result = self.post_process_class(preds, batch)
                self.eval_class(post_result, batch)

                pbar.update(1)
                total_frame += len(batch[0])
                sum_images += 1
            # Get the final metric, e.g. acc or hmean
            metric = self.eval_class.get_metric()

        pbar.close()
        self.model.train()
        metric['fps'] = total_frame / total_time
        return metric

    def eval_ema(self):
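        """Same as :meth:`eval`, but runs the EMA shadow model, which is
        kept permanently in eval mode."""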
        with torch.no_grad():
            total_frame = 0.0
            total_time = 0.0
            pbar = tqdm(
                total=len(self.valid_dataloader),
                desc='eval ema_model:',
                position=0,
                leave=True,
            )
            sum_images = 0
            for idx, batch in enumerate(self.valid_dataloader):
                batch = [t.to(self.device) for t in batch]
                start = time.time()
                if self.scaler:
                    with torch.cuda.amp.autocast():
                        preds = self.ema_model(batch[0], data=batch[1:])
                else:
                    preds = self.ema_model(batch[0], data=batch[1:])

                total_time += time.time() - start
                # Obtain usable results from post-processing methods
                # Evaluate the results of the current batch
                post_result = self.post_process_class(preds, batch)
                self.eval_class(post_result, batch)

                pbar.update(1)
                total_frame += len(batch[0])
                sum_images += 1
            # Get the final metric, e.g. acc or hmean
            metric = self.eval_class.get_metric()

        pbar.close()
        metric['fps'] = total_frame / total_time
        return metric

    def test_dataloader(self):
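        """Iterate over the train dataloader once, logging per-batch shapes
        and read times; useful for debugging the data pipeline."""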
        start_time = time.time()
        count = 0
        try:
            for data in self.train_dataloader:
                count += 1
                batch_time = time.time() - start_time
                start_time = time.time()
                self.logger.info(
                    f'reader: {count}, {data[0].shape}, {batch_time}')
        except Exception:
            import traceback

            self.logger.info(traceback.format_exc())
        self.logger.info(f'finish reader: {count}, Success!')
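

# Minimal usage sketch (hypothetical; the real entry point lives in the
# training script under tools/, and the config below is illustrative only):
#
#     cfg = ...  # a config wrapper exposing .cfg, .save(), .print_cfg()
#     trainer = Trainer(cfg, mode='train_eval')
#     trainer.train()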