from typing import Callable
import os
from typing import Optional, Tuple, Union
import warnings

from mmcv import Config
import torch
import wandb

from risk_biased.predictors.biased_predictor import (
    LitTrajectoryPredictor,
    LitTrajectoryPredictorParams,
)

from risk_biased.utils.config_argparse import config_argparse
from risk_biased.utils.cost import TTCCostParams
from risk_biased.utils.torch_utils import load_weights

from risk_biased.scene_dataset.loaders import SceneDataLoaders
from risk_biased.scene_dataset.scene import load_create_dataset

from risk_biased.utils.waymo_dataloader import WaymoDataloaders


def get_predictor(
    config: Config, unnormalizer: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
):
    """Build a fresh (non-loaded) trajectory predictor from a configuration.

    Args:
        config: configuration from which both the predictor parameters and the
            TTC cost parameters are extracted via their ``from_config`` constructors.
        unnormalizer: callable forwarded unchanged to the predictor as its
            ``unnormalizer`` (callers in this file pass the dataloaders'
            ``unnormalize_trajectory``).

    Returns:
        A new ``LitTrajectoryPredictor`` instance.
    """
    # Keyword arguments are evaluated left to right, so the predictor parameters
    # are still built before the cost parameters.
    return LitTrajectoryPredictor(
        params=LitTrajectoryPredictorParams.from_config(config),
        unnormalizer=unnormalizer,
        cost_params=TTCCostParams.from_config(config),
    )


def load_from_wandb_id(
    log_id: str,
    log_path: str,
    entity: str,
    project: str,
    config: Optional[Config] = None,
    load_last: bool = False,
) -> Tuple[LitTrajectoryPredictor, Union[WaymoDataloaders, SceneDataLoaders], Config]:
    """
    Load a model, its dataloaders and its configuration using a wandb id code.

    The run directory is looked up in ``log_path`` as the single entry whose name
    contains ``log_id``. When no local entry matches, the checkpoint and config files
    of the run are downloaded from WandB into
    ``<log_path>/downloaded_run-<log_id>/files`` and the lookup is retried.

    Args:
        log_id: the wandb id code
        log_path: the wandb log directory path
        entity: the WandB entity owning the run (only used to download the run)
        project: the WandB project of the run (only used to download the run)
        config: An optional configuration. If not None, its settings are used;
            otherwise the settings are read from the run's ``learning_config.py``.
        load_last: set to True to load ``last.ckpt`` instead of the best checkpoint
    Returns:
        A ``(model, dataloaders, config)`` triple: the predictor with the checkpoint
        weights loaded, the dataloaders built from the config, and the config itself
        (loaded from the run or the one passed as argument) with ``load_from`` set to
        ``log_id``.
    Raises:
        RuntimeError: if several log entries match ``log_id``, if the run cannot be
            found on WandB, or if the checkpoint weights do not fit the model built
            from the config.
    """
    list_matching = [path for path in os.listdir(log_path) if log_id in path]

    if len(list_matching) > 1:
        # Refuse to guess between several runs whose directory name contains the id.
        raise RuntimeError(
            f"Error while loading checkpoint: Found {len(list_matching)} occurences of the given id {log_id} in the logs at {log_path}."
        )

    if not list_matching:
        print("Trying to download logs from WandB...")
        api = wandb.Api()
        run = api.run(entity + "/" + project + "/" + log_id)
        if run is None:
            raise RuntimeError(
                f"Error while loading checkpoint: Found {len(list_matching)} occurences of the given id {log_id} in the logs at {log_path}."
            )
        download_path = os.path.join(log_path, "downloaded_run-" + log_id, "files")
        os.makedirs(download_path, exist_ok=True)
        for file in run.files():
            if file.name.endswith(("ckpt", "config.py")):
                file.download(download_path)
        # The downloaded directory name contains log_id, so the retry finds it.
        return load_from_wandb_id(log_id, log_path, entity, project, config, load_last)

    run_files = os.path.join(log_path, list_matching[0], "files")
    list_ckpt = [
        path for path in os.listdir(run_files) if "epoch" in path and ".ckpt" in path
    ]
    if not load_last and len(list_ckpt) == 1:
        print(f"Loading best model: {list_ckpt[0]}.")
        checkpoint_path = os.path.join(run_files, list_ckpt[0])
    else:
        print("Loading last checkpoint.")
        checkpoint_path = os.path.join(run_files, "last.ckpt")
    config_path = os.path.join(run_files, "learning_config.py")

    if config is None:
        config = config_argparse(config_path)
        distant_model_type = config.model_type
    else:
        distant_model_type = config_argparse(config_path).model_type
    config["load_from"] = log_id

    if config.model_type == "interaction_biased":
        dataloaders = WaymoDataloaders(config)
    else:
        data_train, data_val, data_test = load_create_dataset(config)
        dataloaders = SceneDataLoaders(
            state_dim=config.state_dim,
            num_steps=config.num_steps,
            num_steps_future=config.num_steps_future,
            batch_size=config.batch_size,
            data_train=data_train,
            data_val=data_val,
            data_test=data_test,
            num_workers=config.num_workers,
        )

    predictor = get_predictor(config, dataloaders.unnormalize_trajectory)
    # Always deserialize on CPU: "gpu" is not a valid torch location string. The
    # previous branch already mapped to CPU whenever GPUs were configured.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    try:
        model = load_weights(predictor, checkpoint, strict=True)
    except RuntimeError as err:
        # Only weight/architecture mismatches are reported this way; the original
        # error is chained so the offending keys/shapes remain visible.
        raise RuntimeError(
            f"The source model is of type {distant_model_type}."
            + f" It cannot be used to load the weights of the {config.model_type} model."
        ) from err

    return model, dataloaders, config


def _merge_local_config(local_cfg: Config, loaded_cfg: Config) -> str:
    """Reconcile a local configuration with the one loaded from a pre-trained run.

    Values already present in ``loaded_cfg`` are kept. Keys that only exist in
    ``local_cfg`` are copied into ``loaded_cfg`` (in place).

    Args:
        local_cfg: the configuration requested by the caller.
        loaded_cfg: the configuration loaded alongside the checkpoint; mutated.

    Returns:
        A multi-line warning message describing the changes, or ``""`` when there is
        nothing to report. Differing values of bookkeeping keys (project, dataset
        parameters, load options) are not reported.
    """
    ignored_keys = {
        "project",
        "dataset_parameters",
        "load_from",
        "force_config",
        "load_last",
    }
    details = []
    for key, item in local_cfg.items():
        try:
            loaded_item = loaded_cfg[key]
        except KeyError:
            details.append(
                f"    The parameter '{key}' with value '{item}' does not exist for the model you are loading from, it is added."
            )
            loaded_cfg[key] = item
            continue
        if loaded_item != item and key not in ignored_keys:
            details.append(
                f"    The value of '{key}' is now '{loaded_item}' instead of '{item}'."
            )
    if not details:
        return ""
    header = "When loading the model, the configuration was changed to match the configuration of the pre-trained model to be loaded."
    return "\n".join([header, *details])


def load_from_config(cfg: Config):
    """
    Load the predictor model and the data depending on which one is selected in the config.

    If the "load_from" field is set and non-empty, the pre-trained model of that WandB
    run is loaded from its checkpoint. Unless "force_config" is set, the run's own
    configuration is used, and keys missing from it are completed with the local
    ones (a warning lists every change). Otherwise a fresh predictor and its
    dataloaders are built from ``cfg``.

    Args:
        cfg : Configuration that defines the model to be loaded

    Returns:
        A ``(predictor, dataloaders, config)`` triple, where ``config`` is a version of
        the configuration compatible with the checkpoint model when one was loaded, and
        ``cfg`` itself otherwise.
    """
    if "load_from" in cfg.keys() and cfg.load_from:
        log_path = os.path.join(cfg.log_path, "wandb")
        load_last = cfg["load_last"] if "load_last" in cfg.keys() else False
        # Treat a missing "force_config" like "load_last": default to False.
        force_config = cfg["force_config"] if "force_config" in cfg.keys() else False

        if force_config:
            warnings.warn(
                f"Using local configuration but loading from run {cfg.load_from}. Will fail if local configuration is not compatible."
            )
            return load_from_wandb_id(
                log_id=cfg.load_from,
                log_path=log_path,
                entity=cfg.entity,
                project=cfg.project,
                config=cfg,
                load_last=load_last,
            )

        predictor, dataloaders, config = load_from_wandb_id(
            log_id=cfg.load_from,
            log_path=log_path,
            entity=cfg.entity,
            project=cfg.project,
            load_last=load_last,
        )
        warning_message = _merge_local_config(cfg, config)
        if warning_message:
            warnings.warn(warning_message)
        return predictor, dataloaders, config

    if cfg.model_type == "interaction_biased":
        dataloaders = WaymoDataloaders(cfg)
    else:
        data_train, data_val, data_test = load_create_dataset(cfg)
        dataloaders = SceneDataLoaders(
            state_dim=cfg.state_dim,
            num_steps=cfg.num_steps,
            num_steps_future=cfg.num_steps_future,
            batch_size=cfg.batch_size,
            data_train=data_train,
            data_val=data_val,
            data_test=data_test,
            num_workers=cfg.num_workers,
        )

    predictor = get_predictor(cfg, dataloaders.unnormalize_trajectory)
    return predictor, dataloaders, cfg