# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
from scepter.modules.solver import LatentDiffusionSolver
from scepter.modules.solver.registry import SOLVERS
from scepter.modules.utils.data import transfer_data_to_cuda
from scepter.modules.utils.distribute import we
from scepter.modules.utils.probe import ProbeData
from tqdm import tqdm
@SOLVERS.register_class()
class FormalACEPlusSolver(LatentDiffusionSolver):
    def __init__(self, cfg, logger=None):
        """Build the solver and read the optional probe settings from ``cfg``.

        Args:
            cfg: Solver configuration; forwarded unchanged to the parent
                constructor. ``PROBE_PROMPT`` and ``PROBE_HW`` are read from
                it via ``cfg.get``.
            logger: Optional logger forwarded to the parent constructor.
        """
        super().__init__(cfg, logger=logger)
        # Size used for the blank canvas of prompt-only probes; it is indexed
        # as [0] for height and [1] for width in ``probe_data``.
        self.probe_hw = cfg.get('PROBE_HW', [])
        # Prompts sampled during training probes; a falsy value disables
        # the prompt-only probe pass.
        self.probe_prompt = cfg.get('PROBE_PROMPT', None)

    @torch.no_grad()
    def run_eval(self):
        """Run one pass over the eval dataloader and register image probes.

        Switches the solver to eval mode, feeds every batch through
        ``run_step_eval`` under CUDA autocast, converts the collected results
        with ``save_results`` and registers them as the ``eval_label`` and
        ``eval_image`` probes. The hooks registered for the current mode are
        invoked around the whole pass and around each batch.
        """
        self.eval_mode()
        self.before_all_iter(self.hooks_dict[self._mode])
        all_results = []
        # Wrap the dataloader itself rather than the ``enumerate`` generator:
        # a generator has no length, so tqdm could never display a total/ETA.
        for batch_idx, batch_data in enumerate(
                tqdm(self.datas[self._mode].dataloader)):
            self.before_iter(self.hooks_dict[self._mode])
            if self.sample_args:
                # Sampling arguments are merged into the batch and override
                # any keys the batch already carries.
                batch_data.update(self.sample_args.get_lowercase_dict())
            with torch.autocast(device_type='cuda',
                                enabled=self.use_amp,
                                dtype=self.dtype):
                results = self.run_step_eval(transfer_data_to_cuda(batch_data),
                                             batch_idx,
                                             step=self.total_iter,
                                             rank=we.rank)
            # Pure Python list bookkeeping; it does not need the autocast scope.
            all_results.extend(results)
            self.after_iter(self.hooks_dict[self._mode])
        log_data, log_label = self.save_results(all_results)
        self.register_probe({'eval_label': log_label})
        self.register_probe({
            'eval_image':
                ProbeData(log_data,
                          is_image=True,
                          build_html=True,
                          build_label=log_label)
        })
        self.after_all_iter(self.hooks_dict[self._mode])

    @torch.no_grad()
    def run_test(self):
        """Run one pass over the test dataloader and register image probes.

        Switches the solver to test mode, feeds every batch through
        ``run_step_eval`` under CUDA autocast, converts the collected results
        with ``save_results`` and registers them as the ``test_label`` and
        ``test_image`` probes. The hooks registered for the current mode are
        invoked around the whole pass and around each batch.
        """
        self.test_mode()
        self.before_all_iter(self.hooks_dict[self._mode])
        all_results = []
        # Wrap the dataloader itself rather than the ``enumerate`` generator:
        # a generator has no length, so tqdm could never display a total/ETA.
        for batch_idx, batch_data in enumerate(
                tqdm(self.datas[self._mode].dataloader)):
            self.before_iter(self.hooks_dict[self._mode])
            if self.sample_args:
                # Sampling arguments are merged into the batch and override
                # any keys the batch already carries.
                batch_data.update(self.sample_args.get_lowercase_dict())
            with torch.autocast(device_type='cuda',
                                enabled=self.use_amp,
                                dtype=self.dtype):
                results = self.run_step_eval(transfer_data_to_cuda(batch_data),
                                             batch_idx,
                                             step=self.total_iter,
                                             rank=we.rank)
            # Pure Python list bookkeeping; it does not need the autocast scope.
            all_results.extend(results)
            self.after_iter(self.hooks_dict[self._mode])
        log_data, log_label = self.save_results(all_results)
        self.register_probe({'test_label': log_label})
        self.register_probe({
            'test_image':
                ProbeData(log_data,
                          is_image=True,
                          build_html=True,
                          build_label=log_label)
        })

        self.after_all_iter(self.hooks_dict[self._mode])

    def run_step_val(self, batch_data, batch_idx=0, step=None, rank=None):
        sample_id_list = batch_data['sample_id']
        loss_dict = {}
        with torch.autocast(device_type='cuda',
                            enabled=self.use_amp,
                            dtype=self.dtype):
            results = self.model.forward_train(**batch_data)
            loss = results['loss']
        for sample_id in sample_id_list:
            loss_dict[sample_id] = loss.detach().cpu().numpy()
        return loss_dict

    def save_results(self, results):
        log_data, log_label = [], []
        for result in results:
            ret_images, ret_labels = [], []
            edit_image = result.get('edit_image', None)
            modify_image = result.get('modify_image', None)
            edit_mask = result.get('edit_mask', None)
            if edit_image is not None:
                for i, edit_img in enumerate(result['edit_image']):
                    if edit_img is None:
                        continue
                    ret_images.append((edit_img.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
                    ret_labels.append(f'edit_image{i}; ')
                    ret_images.append((modify_image[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
                    ret_labels.append(f'modify_image{i}; ')
                    if edit_mask is not None:
                        ret_images.append((edit_mask[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
                        ret_labels.append(f'edit_mask{i}; ')

            target_image = result.get('target_image', None)
            target_mask = result.get('target_mask', None)
            if target_image is not None:
                ret_images.append((target_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
                ret_labels.append(f'target_image; ')
                if target_mask is not None:
                    ret_images.append((target_mask.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
                    ret_labels.append(f'target_mask; ')
            teacher_image = result.get('image', None)
            if teacher_image is not None:
                ret_images.append((teacher_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
                ret_labels.append(f"teacher_image")
            reconstruct_image = result.get('reconstruct_image', None)
            if reconstruct_image is not None:
                ret_images.append((reconstruct_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
                ret_labels.append(f"{result['instruction']}")
            log_data.append(ret_images)
            log_label.append(ret_labels)
        return log_data, log_label
    @property
    def probe_data(self):
        """Collect probe data, adding sampled images while training.

        When not in debug mode and the solver mode is ``'train'``, the
        current training batch is run through ``run_step_eval`` in eval mode
        and registered as the ``train_image`` / ``train_label`` probes. If
        ``probe_prompt`` is set, every prompt is additionally sampled on a
        blank canvas of size ``probe_hw`` and registered as ``probe_image``.
        Train mode is restored after each sampling phase, even on failure.

        Returns:
            The ``probe_data`` of the class that follows
            ``LatentDiffusionSolver`` in the method resolution order.

        Raises:
            ValueError: If ``probe_prompt`` is set but ``probe_hw`` does not
                provide both a height and a width.
        """
        if not we.debug and self.mode == 'train':
            batch_data = transfer_data_to_cuda(
                self.current_batch_data[self.mode])
            self.eval_mode()
            try:
                with torch.autocast(device_type='cuda',
                                    enabled=self.use_amp,
                                    dtype=self.dtype):
                    batch_data['log_num'] = self.log_train_num
                    # Guarded like run_eval/run_test: sample_args may be unset.
                    if self.sample_args:
                        batch_data.update(
                            self.sample_args.get_lowercase_dict())
                    results = self.run_step_eval(batch_data)
            finally:
                # Never leave the solver in eval mode if sampling fails.
                self.train_mode()
            log_data, log_label = self.save_results(results)
            self.register_probe({
                'train_image':
                    ProbeData(log_data,
                              is_image=True,
                              build_html=True,
                              build_label=log_label)
            })
            self.register_probe({'train_label': log_label})
            if self.probe_prompt:
                if not self.probe_hw or len(self.probe_hw) < 2:
                    raise ValueError(
                        'PROBE_HW must provide [height, width] when '
                        f'PROBE_PROMPT is set, got {self.probe_hw!r}')
                height, width = self.probe_hw[0], self.probe_hw[1]
                self.eval_mode()
                all_results = []
                try:
                    for prompt in self.probe_prompt:
                        with torch.autocast(device_type='cuda',
                                            enabled=self.use_amp,
                                            dtype=self.dtype):
                            # Prompt-only batch: zero image, all-ones mask
                            # and no source / edit inputs.
                            # NOTE(review): these tensors are created on the
                            # CPU and not passed through
                            # transfer_data_to_cuda -- confirm that
                            # run_step_eval moves them to the device.
                            batch_data = {
                                "prompt": [[prompt]],
                                "image": [torch.zeros(3, height, width)],
                                "image_mask": [torch.ones(1, height, width)],
                                "src_image_list": [[]],
                                "modify_image_list": [[]],
                                "src_mask_list": [[]],
                                "edit_id": [[]],
                                "height": height,
                                "width": width
                            }
                            # Applied last so sampling arguments can override
                            # the defaults above.
                            if self.sample_args:
                                batch_data.update(
                                    self.sample_args.get_lowercase_dict())
                            results = self.run_step_eval(batch_data)
                        all_results.extend(results)
                finally:
                    self.train_mode()
                log_data, log_label = self.save_results(all_results)
                self.register_probe({
                    'probe_image':
                        ProbeData(log_data,
                                  is_image=True,
                                  build_html=True,
                                  build_label=log_label)
                })

        # Starts the attribute lookup after LatentDiffusionSolver in the MRO,
        # i.e. that class's own probe_data implementation is bypassed.
        return super(LatentDiffusionSolver, self).probe_data