"""
Utilities for loading models (optionally 4-bit quantized) onto CPU/meta devices for FSDP training.
"""
import os
import time
from typing import List, Optional, Type, Union

import safetensors
import torch
from accelerate import init_empty_weights
from bitsandbytes.nn import Linear4bit, Params4bit
from fastcore.parallel import parallel
from torch import Tensor, nn
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub


def _replace_linear(
    model: nn.Module,
    linear_replacement: Type[nn.Module],
    quant_config: Union[dict, None] = None,
    skip_modules=None,
    **kwargs,
):
    """
    Replace linear modules with a new Linear module.
    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        linear_replacement (`torch.nn.Module`):
            The linear module that replaces the old one. Only expects standard arguments.
            If other arguments need to be passed, use a lambda.
        skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
            List of modules names not to convert. Defaults to `lm_head`.
    """
    if skip_modules is None:
        skip_modules = ["lm_head"]
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            _replace_linear(
                module, linear_replacement, quant_config, skip_modules, **kwargs
            )

        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
            if issubclass(linear_replacement, Linear4bit):
                model._modules[  # pylint: disable=protected-access
                    name
                ] = linear_replacement(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    **kwargs,
                )
            else:
                raise ValueError(
                    f"Unsupported linear replacement: {type(linear_replacement)}"
                )
    return model


def load_and_quantize(
    module: nn.Module,
    name: str,
    value: Tensor,
    device: Optional[Union[torch.device, int, str]] = None,
    dtype: Optional[torch.dtype] = None,
    skip_names: Optional[List[str]] = None,
    to_cpu: bool = False,
    to_meta: bool = False,
    verbose: bool = False,
    quant_method: str = "bnb",
):
    """
    Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.

    Quantizes `Params4bit` on `device` then places on "cpu" if to_cpu=True or "meta" if to_meta=True.

    Args:
        module: Root module; `name` is resolved relative to it.
        name: Dotted path of the target tensor, e.g. "block.linear.weight".
        value: Tensor to load.
        device: Device used to quantize `Params4bit` weights. It is also the final device
            for all other tensors when neither `to_cpu` nor `to_meta` is set.
        dtype: dtype to cast `value` to (None keeps the current dtype).
        skip_names: Substrings; a `name` containing any of them is not loaded.
        to_cpu: Place the loaded tensor on CPU.
        to_meta: Place the loaded tensor on the meta device (takes precedence over `to_cpu`).
        verbose: Print a message for skipped tensors.
        quant_method: Only "bnb" casts/places `value`; for any other method `value` is
            assigned as given.

    Returns:
        None. The attribute on the submodule is replaced in place. If the submodule path
        does not exist, a message is printed and nothing is loaded.
    """
    if skip_names and any(skip_name in name for skip_name in skip_names):
        if verbose:
            print(f"Skipping {name} because it is in skip_names")
        return

    # Final placement for non-quantized tensors, resolved once: meta wins over cpu, and
    # without either flag the tensor goes to `device`.
    if to_meta:
        final_device = "meta"
    elif to_cpu:
        final_device = "cpu"
    else:
        final_device = device

    def place_on_device(tensor: Tensor) -> Tensor:
        return tensor.to(device=final_device, dtype=dtype)

    module_key, _, value_key = name.rpartition(".")
    try:
        submodule = module.get_submodule(module_key)
    except AttributeError as exc:
        print(f"Module {module_key} not found:\n{exc}")
        return

    if quant_method == "bnb":
        # Keep the try minimal: only `get_parameter` signals "not a parameter" via
        # AttributeError. Errors raised while quantizing must not be misread as a buffer.
        try:
            param = submodule.get_parameter(value_key)
        except AttributeError:
            # Not an nn.Parameter (e.g. a registered buffer): just cast and place it.
            value = place_on_device(value)
        else:
            if isinstance(param, Params4bit):
                # With `sync_module_states=True`, a meta device Params4bit needs to be the same
                # shape as the quantized Params4bit with an initialized quant_state. However,
                # FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
                # workaround quantizes Params4bit to initialize quant_state on all ranks, then
                # replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
                value = type(param)(
                    value.to(device=device, dtype=dtype).data, **param.__dict__
                ).cuda(device)
                if to_meta:
                    value = type(param)(value.data.to("meta"), **value.__dict__)
                elif to_cpu:
                    value = type(param)(value.data.to("cpu"), **value.__dict__)
            else:
                # Re-wrap in the parameter's own class so the attribute stays a Parameter.
                value = type(param)(place_on_device(value).data)

    setattr(submodule, value_key, value)


def n_loading_workers(quant_method: str, param_count: float):
    """
    Heuristically choose how many threads to use when loading/quantizing weights.

    The result is the smaller of two bounds:
      * CPU bound: logical CPUs per visible CUDA device.
      * GPU-memory bound: a base of 4 workers for "hqq" (8 otherwise), scaled by the
        current device's total memory relative to 40 GB and by how many times smaller
        the model is than a 70B-parameter reference size.

    Args:
        quant_method: Quantization method name; only "hqq" changes the base worker count.
        param_count: Total number of model parameters (must be non-zero).

    Returns:
        int: the suggested worker count. Integer truncation can make this 0.

    Requires at least one visible CUDA device.
    """
    devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
    # os.cpu_count() returns None when the CPU count cannot be determined; assume 1 then.
    cpu_bound = int((os.cpu_count() or 1) / torch.cuda.device_count())
    reference_params_b = 70  # reference model size, in billions of parameters
    base_workers = 4 if quant_method == "hqq" else 8
    gpu_mem_ratio = devprops.total_memory / 1e9 / 40
    model_size_ratio = reference_params_b / (param_count / 1e9)
    memory_bound = int(base_workers * gpu_mem_ratio * model_size_ratio)
    return min(cpu_bound, memory_bound)


def load_sharded_model(
    model_name,
    model_config,
    cfg,
    torch_dtype=torch.bfloat16,
    low_memory=True,
):
    """
    Load an unquantized causal LM for FSDP training.

    Ranks that need real weights (every rank when `low_memory` is false, otherwise only
    local rank 0) load them with `from_pretrained` in float32. They then cast to
    `torch_dtype`; the dtype is left unchanged when `cfg.float32` is set. Finally they
    move the model to CPU when `low_memory` is true, or to device `cfg.local_rank`
    otherwise.

    In low-memory mode every other rank only builds the architecture from `model_config`
    under `init_empty_weights()`, so no real weights are loaded there.

    Returns:
        The constructed model.
    """
    needs_real_weights = not low_memory or cfg.local_rank == 0

    if not needs_real_weights:
        with init_empty_weights():
            return AutoModelForCausalLM.from_config(
                model_config,
                torch_dtype=torch_dtype,
                trust_remote_code=cfg.trust_remote_code,
            )

    attn_impl = model_config._attn_implementation  # pylint: disable=protected-access
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        use_cache=False,
        torch_dtype=torch.float32,
        _attn_implementation=attn_impl,
        trust_remote_code=cfg.trust_remote_code,
    )
    target_dtype = None if cfg.float32 else torch_dtype
    target_device = "cpu" if low_memory else cfg.local_rank
    model.to(dtype=target_dtype, device=target_device)
    return model


def _resolve_safetensors_files(model_name) -> List[str]:
    """
    Resolve the local path(s) of the safetensors checkpoint file(s) for `model_name`.

    A sharded checkpoint is located through its `SAFE_WEIGHTS_INDEX_NAME` index file.
    If that lookup fails with `OSError`, the model is treated as unsharded and the single
    `SAFE_WEIGHTS_NAME` file is used. An `OSError` from that second lookup propagates;
    it most likely means the model has no safetensors weights.
    """
    try:
        index_file = hub.cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME)
        shard_files, _ = hub.get_checkpoint_shard_files(model_name, index_file)
        return shard_files
    except OSError:
        # No index file: the checkpoint is not sharded.
        return [hub.cached_file(model_name, SAFE_WEIGHTS_NAME)]


def load_sharded_model_quant(
    model_name,
    model_config,
    cfg,
    compute_dtype=torch.bfloat16,
    quant_storage=torch.float32,
    low_memory=True,
    verbose=False,
    loading_workers=2,
):
    """
    Build a 4-bit (bitsandbytes NF4) causal LM and stream its safetensors weights into it.

    The architecture is created under `init_empty_weights()`. The `nn.Linear` layers
    under `model.transformer` (or `model.model`) are swapped for `Linear4bit`. Each
    checkpoint shard is then read and every tensor is loaded with `load_and_quantize`,
    which quantizes `Params4bit` weights on device `cfg.local_rank`.

    Args:
        model_name: Hub id or local path of the checkpoint.
        model_config: Config used to build the empty model.
        cfg: Run config; reads `trust_remote_code` and `local_rank`.
        compute_dtype: Compute dtype of the `Linear4bit` layers.
        quant_storage: Storage dtype of the quantized weights. Every loaded tensor is
            also cast to this dtype.
        low_memory: If true, rank 0 moves loaded tensors to CPU and all other ranks to the
            meta device (presumably so FSDP `sync_module_states` can broadcast from
            rank 0).
        verbose: Print the worker count and load time on rank 0.
        loading_workers: Threads used per shard; -1 picks a value via `n_loading_workers`.

    Returns:
        The loaded model, with `is_loaded_in_4bit` set to True.
    """
    # Import the submodule explicitly: `import safetensors` alone does not load
    # `safetensors.torch`.
    from safetensors.torch import (  # pylint: disable=import-outside-toplevel
        load_file,
    )

    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(
            model_config,
            trust_remote_code=cfg.trust_remote_code,
        )
        # `model` is the more common attribute for HF transformers models.
        backbone_attr = "transformer" if hasattr(model, "transformer") else "model"
        setattr(
            model,
            backbone_attr,
            _replace_linear(
                getattr(model, backbone_attr),
                Linear4bit,
                compute_dtype=compute_dtype,
                quant_type="nf4",
                quant_storage=quant_storage,
            ),
        )
    model.is_loaded_in_4bit = True

    # Grab the safetensors files that hold the weights
    files = _resolve_safetensors_files(model_name)

    # Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
    # and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
    def load_and_quantize_parallel(name_param, model, **kwargs):
        name, param = name_param
        load_and_quantize(model, name, param, **kwargs)

    quant_method = "bnb"
    if loading_workers == -1:
        # Only count parameters when the heuristic actually needs the total.
        param_count = sum(p.numel() for p in model.parameters())
        n_workers = n_loading_workers(quant_method, param_count)
    else:
        n_workers = loading_workers
    if cfg.local_rank == 0 and verbose:
        print(f"Using n_workers: {n_workers} for loading")

    start = time.perf_counter()
    for filename in tqdm(
        files,
        desc="Loading & Quantizing Model Shards",
        disable=cfg.local_rank != 0,
        position=0,
    ):
        weights = load_file(filename)
        parallel(
            load_and_quantize_parallel,
            iter(weights.items()),
            n_workers=n_workers,
            threadpool=True,
            model=model,
            dtype=quant_storage,
            device=cfg.local_rank,
            skip_names=[],
            to_cpu=(low_memory and cfg.local_rank == 0),
            to_meta=(low_memory and cfg.local_rank != 0),
            verbose=verbose,
            quant_method=quant_method,
        )
        # Release this shard's state dict now. Merely rebinding `weights` on the next
        # iteration would keep it alive until the next shard has been fully read.
        del weights

    if cfg.local_rank == 0 and verbose:
        print(f"Loaded model weights in {time.perf_counter()-start:.3f} seconds")
    # cleanup any extra memory usage from parallel loading
    torch.cuda.empty_cache()

    return model