from typing import List

import argparse
import os

from atoms_detection.create_crop_dataset import create_contrastive_crops_dataset
from atoms_detection.dl_detection import DLDetection
from atoms_detection.dl_detection_with_gmm import DLGMMdetection
from atoms_detection.evaluation import Evaluation
from atoms_detection.training_model import train_model
from utils.paths import (
    CROPS_PATH,
    CROPS_DATASET,
    MODELS_PATH,
    LOGS_PATH,
    DETECTION_PATH,
    PREDS_PATH,
    PRED_GT_VIS_PATH,
)
from utils.constants import ModelArgs, Split
from matplotlib import pyplot as plt
import pandas as pd
from PIL import Image
import numpy as np

from visualizations.prediction_gt_images import get_gt_coords
from visualizations.utils import plot_gt_pred_on_img


def dl_full_pipeline(
    extension_name: str,
    architecture: ModelArgs,
    coords_csv: str,
    thresholds_list: List[float],
    force_create_dataset: bool = False,
    force_evaluation: bool = False,
    show_sampling_image: bool = False,
    train: bool = False,
    visualise: bool = False,
    upsample: bool = False,
    upsample_neg_amount: float = 0,
    clip_max: float = 1,
    negative_dist: float = 1.1,
):
    # Create crops data
    crops_folder = CROPS_PATH + f"_{extension_name}"
    crops_dataset = CROPS_DATASET.replace(".csv", f"_{extension_name}.csv")
    print(os.path.exists(crops_dataset))
    if force_create_dataset or not os.path.exists(crops_dataset):
        print("Creating crops dataset...")
        create_contrastive_crops_dataset(
            crops_folder,
            coords_csv,
            crops_dataset,
            show_sampling_result=show_sampling_image,
            pos_data_upsampling=upsample,
            neg_upsample_multiplier=upsample_neg_amount,
            contrastive_distance_multiplier=negative_dist,
        )  # , clip=clip_max
        # training DL model
    ckpt_filename = os.path.join(MODELS_PATH, f"model_{extension_name}.ckpt")
    if train or not os.path.exists(ckpt_filename):
        print("Training DL crops model...")
        train_model(architecture, crops_dataset, crops_folder, ckpt_filename)

    for threshold in thresholds_list:
        inference_cache_path = os.path.join(
            PREDS_PATH, f"dl_detection_{extension_name}"
        )
        detections_path = os.path.join(
            DETECTION_PATH,
            f"dl_detection_{extension_name}",
            f"dl_detection_{extension_name}_{threshold}",
        )
        if force_evaluation or visualise or not os.path.exists(detections_path):
            print(f"Detecting atoms on test data with threshold={threshold}...")
            if args.run_gmm_for_multimers:
                detection_pipeline = DLGMMdetection
            else:
                detection_pipeline = DLDetection

            detection = detection_pipeline(
                model_name=architecture,
                ckpt_filename=ckpt_filename,
                dataset_csv=coords_csv,
                threshold=threshold,
                detections_path=detections_path,
                inference_cache_path=inference_cache_path,
            )
            detection.run()

        logging_filename = os.path.join(
            LOGS_PATH,
            f"dl_evaluation_{extension_name}",
            f"dl_evaluation_{extension_name}_{threshold}.csv",
        )
        if force_evaluation or visualise or not os.path.exists(logging_filename):
            evaluation = Evaluation(
                coords_csv=coords_csv,
                predictions_path=detections_path,
                logging_filename=logging_filename,
            )
            evaluation.run()
        if visualise:
            vis_folder = os.path.join(
                PRED_GT_VIS_PATH, f"dl_detection_{extension_name}"
            )
            if not os.path.exists(vis_folder):
                os.makedirs(vis_folder)

            vis_folder = os.path.join(
                vis_folder, f"dl_detection_{extension_name}_{threshold}"
            )
            if not os.path.exists(vis_folder):
                os.makedirs(vis_folder)
            is_evaluation = True
            if is_evaluation:
                gt_coords_dict = get_gt_coords(evaluation.coordinates_dataset)

            for image_path in detection.image_dataset.iterate_data(Split.TEST):
                img_name = os.path.split(image_path)[-1]
                gt_coords = gt_coords_dict[img_name] if is_evaluation else None
                pred_df_path = os.path.join(
                    detections_path, os.path.splitext(img_name)[0] + ".csv"
                )
                df_predicted = pd.read_csv(pred_df_path)
                pred_coords = [
                    (row["x"], row["y"]) for _, row in df_predicted.iterrows()
                ]
                img = Image.open(image_path)
                img_arr = np.array(img).astype(np.float32)
                img_normed = (img_arr - img_arr.min()) / (img_arr.max() - img_arr.min())

                plot_gt_pred_on_img(img_normed, gt_coords, pred_coords)
                clean_image_name = os.path.splitext(img_name)[0]
                vis_path = os.path.join(vis_folder, f"{clean_image_name}.png")
                plt.savefig(
                    vis_path, bbox_inches="tight", pad_inches=0.0, transparent=True
                )
                plt.close()


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("extension_name", type=str, help="Experiment extension name")
    parser.add_argument(
        "architecture", type=ModelArgs, choices=ModelArgs, help="Architecture name"
    )
    parser.add_argument(
        "coords_csv", type=str, help="Coordinates CSV file to use as input"
    )
    parser.add_argument(
        "-t" "--thresholds", nargs="+", type=float, help="Threshold value"
    )
    parser.add_argument(
        "-c", type=float, default=1, help="Clipping quantile (0..1]. CURRENTLY USELESS!"
    )
    parser.add_argument(
        "-nd", type=float, default=1.1, help="Negative contrastive crop distance"
    )
    parser.add_argument("--force_create_dataset", action="store_true")
    parser.add_argument("--force_evaluation", action="store_true")
    parser.add_argument("--show_sampling_result", action="store_true")
    parser.add_argument("--train", action="store_true")
    parser.add_argument("--visualise", action="store_true")
    parser.add_argument("--upsample", action="store_true")
    parser.add_argument(
        "--run_gmm_for_multimers",
        action="store_true",
        help="If selected, a postprocessing will be run to split large atoms (possible multimers) with a GMM",
    )
    parser.add_argument(
        "--upsample_neg",
        type=float,
        default=0,
        help="Upsample amount for negative crops during training",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    print(args)
    dl_full_pipeline(
        args.extension_name,
        args.architecture,
        args.coords_csv,
        args.t__thresholds,
        args.force_create_dataset,
        args.force_evaluation,
        args.show_sampling_result,
        args.train,
        args.visualise,
        args.upsample,
        args.upsample_neg,
        args.c,
        args.nd,
    )