# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch
from PIL.Image import Image
from collections import OrderedDict
from scepter.modules.utils.distribute import we
from scepter.modules.utils.config import Config
from scepter.modules.utils.logger import get_logger
from scepter.studio.utils.env import get_available_memory
from scepter.modules.model.registry import MODELS, BACKBONES, EMBEDDERS
from scepter.modules.utils.registry import Registry, build_from_config
def get_model(model_tuple):
    """Return the object stored under the ``'model'`` key of ``model_tuple``.

    Args:
        model_tuple: A mapping that is expected to contain a ``'model'`` key.

    Returns:
        The value stored at ``model_tuple['model']``.

    Raises:
        AssertionError: If ``'model'`` is not a key of ``model_tuple``
            (only when assertions are enabled).
    """
    has_model = 'model' in model_tuple
    assert has_model
    return model_tuple['model']

class BaseInference():
    """Base class for inference pipelines whose components load on demand.

    Each component is described by a "module" dict (see ``infer_model``) with
    the keys ``model``, ``cfg``, ``device``, ``name``, ``function_info`` and
    ``paras``. The ``device`` entry tracks the component's state:

    * ``'offline'``: no model object has been built yet;
    * ``'cpu'``: the model is built and sits on the CPU;
    * anything else: the value of ``we.device_id`` the model was moved to.

    NOTE: ``dynamic_load`` / ``dynamic_unload`` read ``self.loaded_model``
    (name -> module dict) and ``self.loaded_model_name`` (collection of
    names), and ``infer_model`` reads ``self.input``. None of them is set in
    ``__init__`` here; presumably subclasses define them (``self.input`` is
    set by ``load_default`` when a config is given).
    """
    def __init__(self, cfg, logger=None):
        """Store the logger (creating a default one if needed) and the name.

        Args:
            cfg: Configuration object; only ``cfg.NAME`` is read here.
            logger: Optional logger. When ``None``, ``get_logger`` is called
                with ``name='scepter'``.
        """
        if logger is None:
            logger = get_logger(name='scepter')
        self.logger = logger
        self.name = cfg.NAME

    def init_from_modules(self, modules):
        """Attach every ``key -> value`` pair of ``modules`` as an attribute."""
        for k, v in modules.items():
            self.__setattr__(k, v)

    def infer_model(self, cfg, module_paras=None):
        """Create the (not yet loaded) module dict for one component.

        Args:
            cfg: Component configuration; ``cfg.NAME`` becomes ``name``.
            module_paras: Optional mapping that may contain ``'PARAS'``
                (mapping of parameters) and ``'FUNCTION'`` (iterable of
                function descriptions exposing ``NAME`` and ``get``).

        Returns:
            A dict with ``model=None`` and ``device='offline'``. When
            ``module_paras`` is given, ``paras`` holds the lower-cased
            ``PARAS`` and ``function_info`` maps every function name to its
            ``dtype`` (default ``'float32'``) and to those of its declared
            inputs that are present in ``self.input``.
        """
        module = {
            'model': None,
            'cfg': cfg,
            'device': 'offline',
            'name': cfg.NAME,
            'function_info': {},
            'paras': {}
        }
        if module_paras is None:
            return module
        function_info = {}
        paras = {
            k.lower(): v
            for k, v in module_paras.get('PARAS', {}).items()
        }
        for function in module_paras.get('FUNCTION', []):
            input_dict = {}
            for inp in function.get('INPUT', []):
                key = inp.lower()
                # Only inputs known to this pipeline are forwarded.
                if key in self.input:
                    input_dict[key] = self.input[key]
            function_info[function.NAME] = {
                'dtype': function.get('DTYPE', 'float32'),
                'input': input_dict
            }
        module['paras'] = paras
        module['function_info'] = function_info
        return module

    def init_from_ckpt(self, path, model, ignore_keys=None):
        """Load a checkpoint into ``model`` non-strictly.

        Args:
            path: Checkpoint path. A path ending with ``'safetensors'`` is
                read with ``safetensors.torch.load_file``; anything else with
                ``torch.load(..., map_location='cpu', weights_only=True)``.
            model: Object exposing ``load_state_dict``.
            ignore_keys: Optional iterable of substrings; every state-dict
                key containing one of them is dropped before loading.
                ``None`` (the default) drops nothing.
        """
        # Avoid a mutable default argument shared between calls.
        if ignore_keys is None:
            ignore_keys = ()

        if path.endswith('safetensors'):
            from safetensors.torch import load_file as load_safetensors
            sd = load_safetensors(path)
        else:
            sd = torch.load(path, map_location='cpu', weights_only=True)

        new_sd = OrderedDict()
        for k, v in sd.items():
            if any(ik in k for ik in ignore_keys):
                if we.rank == 0:
                    self.logger.info(
                        'Ignore key {} from state_dict.'.format(k))
                continue
            new_sd[k] = v

        missing, unexpected = model.load_state_dict(new_sd, strict=False)
        # Only rank 0 reports, so the summary is logged once.
        if we.rank == 0:
            self.logger.info(
                f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
            )
            if len(missing) > 0:
                self.logger.info(f'Missing Keys:\n {missing}')
            if len(unexpected) > 0:
                self.logger.info(f'\nUnexpected Keys:\n {unexpected}')

    def load(self, module):
        """Bring ``module`` onto ``we.device_id``, building it if necessary.

        An ``'offline'`` module is first built from ``module['cfg']`` with the
        first of ``MODELS``, ``BACKBONES``, ``EMBEDDERS`` that knows
        ``cfg.NAME``, put in eval mode, optionally cast to ``cfg.DTYPE`` and
        optionally restored from ``cfg.RELOAD_MODEL``. A ``'cpu'`` module is
        then moved to ``we.device_id``. Modules in any other state are
        returned untouched.

        Args:
            module: Module dict as produced by ``infer_model``.

        Returns:
            The same dict, updated in place.

        Raises:
            NotImplementedError: If ``cfg.NAME`` is found in none of the
                three registries.
        """
        if module['device'] == 'offline':
            from scepter.modules.utils.import_utils import LazyImportModule
            cfg = module['cfg']
            registries = (('MODELS', MODELS), ('BACKBONES', BACKBONES),
                          ('EMBEDDERS', EMBEDDERS))
            # A registry is selected when either the lazy-import lookup is
            # truthy for the name or the name is already in its class_map.
            for registry_name, registry in registries:
                if (LazyImportModule.get_module_type(
                    (registry_name, cfg.NAME))
                        or cfg.NAME in registry.class_map):
                    model = registry.build(cfg, logger=self.logger).eval()
                    break
            else:
                raise NotImplementedError(
                    f'{cfg.NAME} is not registered in MODELS, BACKBONES '
                    f'or EMBEDDERS')
            if 'DTYPE' in module['cfg'] and module['cfg']['DTYPE'] is not None:
                model = model.to(getattr(torch, module['cfg'].DTYPE))
            if module['cfg'].get('RELOAD_MODEL', None):
                self.init_from_ckpt(module['cfg'].RELOAD_MODEL, model)
            module['model'] = model
            module['device'] = 'cpu'
        if module['device'] == 'cpu':
            module['device'] = we.device_id
            module['model'] = module['model'].to(we.device_id)
        return module

    def unload(self, module):
        """Move ``module`` off the device, dropping it when memory is short.

        When the available memory reported by ``get_available_memory()`` is
        below half of the total, the model reference is dropped and the
        module becomes ``'offline'``. Otherwise the model (if any) is moved
        to the CPU and the module becomes ``'cpu'``. CUDA caches are released
        afterwards when CUDA is available.

        Args:
            module: Module dict, or ``None``.

        Returns:
            The same dict updated in place, or ``None`` if ``None`` was given.
        """
        if module is None:
            return module
        mem = get_available_memory()
        free_mem = int(mem['available'] / (1024**2))
        total_mem = int(mem['total'] / (1024**2))
        if free_mem < 0.5 * total_mem:
            if module['model'] is not None:
                # Move the model to the CPU before dropping the reference.
                module['model'].to('cpu')
            module['model'] = None
            module['device'] = 'offline'
            self.logger.info('delete module')
        else:
            if module['model'] is not None:
                module['model'] = module['model'].to('cpu')
                module['device'] = 'cpu'
            else:
                module['device'] = 'offline'
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        return module

    def dynamic_load(self, module=None, name=''):
        """Load one component, or all tracked components, on demand.

        Args:
            module: Module dict to load (ignored when ``name == 'all'``).
            name: Component name. ``'all'`` loads every name in
                ``self.loaded_model_name`` from the attribute of the same
                name. A name outside ``self.loaded_model_name`` is loaded
                without being tracked in ``self.loaded_model``.

        Returns:
            The loaded module dict, or ``None`` when ``name == 'all'``.
        """
        self.logger.info('Loading {} model'.format(name))
        if name == 'all':
            for subname in self.loaded_model_name:
                self.loaded_model[subname] = self.dynamic_load(
                    getattr(self, subname), subname)
            return None
        if name not in self.loaded_model_name:
            return self.load(module)
        if name not in self.loaded_model:
            module = self.load(module)
            self.loaded_model[name] = module
            return module
        if module['cfg'] != self.loaded_model[name]['cfg']:
            # The configuration changed: replace the tracked module.
            self.unload(self.loaded_model[name])
            module = self.load(module)
            self.loaded_model[name] = module
            return module
        if module['device'] == 'cpu' or module['device'] == 'offline':
            return self.load(module)
        return module

    def dynamic_unload(self, module=None, name='', skip_loaded=False):
        """Unload one component, or all tracked components.

        Args:
            module: Module dict to unload when ``name`` is not tracked in
                ``self.loaded_model``.
            name: Component name; ``'all'`` unloads every entry of
                ``self.loaded_model``.
            skip_loaded: When true, a component that is tracked in
                ``self.loaded_model`` is left as it is.
        """
        self.logger.info('Unloading {} model'.format(name))
        if name == 'all':
            # Distinct loop variable: do not shadow the parameters.
            for loaded_name in list(self.loaded_model):
                self.loaded_model[loaded_name] = self.unload(
                    self.loaded_model[loaded_name])
        elif name in self.loaded_model_name:
            if name in self.loaded_model:
                if not skip_loaded:
                    self.loaded_model[name] = self.unload(
                        self.loaded_model[name])
            else:
                self.unload(module)
        else:
            self.unload(module)

    def load_default(self, cfg):
        """Read default parameters, inputs and outputs from ``cfg``.

        When ``cfg`` is not ``None`` this sets ``self.paras``,
        ``self.input_cfg`` (lower-cased ``cfg.INPUT``), ``self.input``
        (lower-cased keys mapped to the ``DEFAULT`` entry of mapping values,
        or to the value itself otherwise) and ``self.output``.

        Args:
            cfg: Configuration exposing ``PARAS``, ``INPUT``, ``OUTPUT`` and
                ``MODULES_PARAS``, or ``None``.

        Returns:
            ``cfg.MODULES_PARAS``, or an empty dict when ``cfg`` is ``None``.
        """
        module_paras = {}
        if cfg is not None:
            self.paras = cfg.PARAS
            self.input_cfg = {k.lower(): v for k, v in cfg.INPUT.items()}
            self.input = {
                k.lower(): dict(v).get('DEFAULT', None)
                if isinstance(v, (dict, OrderedDict, Config)) else v
                for k, v in cfg.INPUT.items()
            }
            self.output = {k.lower(): v for k, v in cfg.OUTPUT.items()}
            module_paras = cfg.MODULES_PARAS
        return module_paras

    def load_image(self, image, num_samples=1):
        """Stub: no branch does any work and ``None`` is always returned.

        Args:
            image: A ``torch.Tensor`` or a PIL ``Image`` (other types are
                ignored as well).
            num_samples: Unused.
        """
        if isinstance(image, torch.Tensor):
            pass
        elif isinstance(image, Image):
            pass
        return None

    def get_function_info(self, module, function_name=None):
        """Look up a function name and its dtype in ``module``.

        Args:
            module: Module dict whose ``'function_info'`` entry is consulted.
            function_name: Name to look up. When ``None`` and exactly one
                function is registered, that function is used.

        Returns:
            A ``(name, dtype)`` tuple, or ``None`` when nothing matches.
        """
        all_function = module['function_info']
        if function_name in all_function:
            return function_name, all_function[function_name]['dtype']
        if function_name is None and len(all_function) == 1:
            only_name, only_info = next(iter(all_function.items()))
            return only_name, only_info['dtype']
        return None

    @torch.no_grad()
    def __call__(self,
                 input,
                 **kwargs):
        """Run inference; this base implementation does nothing.

        Returns:
            ``None``.
        """
        return

def build_inference(cfg, registry, logger=None, *args, **kwargs):
    """Build an inference object from ``cfg`` through ``registry``.

    This is a thin, type-checked wrapper around ``build_from_config``; it
    performs no checkpoint or pretrained-weight loading itself.

    Args:
        cfg: Configuration describing the object to build; must be a
            ``Config`` instance.
        registry: Registry forwarded to ``build_from_config``.
        logger: Optional logger forwarded as the ``logger`` keyword.
        *args: Extra positional arguments forwarded to ``build_from_config``.
        **kwargs: Extra keyword arguments forwarded to ``build_from_config``.

    Returns:
        Whatever ``build_from_config`` returns.

    Raises:
        TypeError: If ``cfg`` is not a ``Config`` instance.
    """
    if not isinstance(cfg, Config):
        raise TypeError(f'cfg must be a Config instance, got {type(cfg)}')
    return build_from_config(cfg, registry, *args, logger=logger, **kwargs)

# Registry of inference classes for diffusion; it is created with
# build_inference as its build function.
INFERENCES = Registry('INFERENCE', build_func=build_inference)