import os
from typing import Iterator, List, Tuple, Union

import numpy as np
import torch
import torch.nn
import torch.nn.functional

from atoms_detection.detection import Detection
from atoms_detection.training_model import model_pipeline
from atoms_detection.image_preprocessing import dl_prepro_image
from utils.constants import ModelArgs
from utils.paths import PREDS_PATH


class DLDetection(Detection):
    """Sliding-window detector driven by a 2-class PyTorch classifier.

    Every ``window_size`` crop of the preprocessed, zero-padded image is fed to
    the network. The softmax probability of class 1 is written at the crop's
    center, which yields a prediction map the same size as the preprocessed
    image. How that map is consumed is up to the ``Detection`` base class,
    which is not shown here.
    """

    def __init__(self,
                 model_name: ModelArgs,
                 ckpt_filename: str,
                 dataset_csv: str,
                 threshold: float,
                 detections_path: str,
                 inference_cache_path: str,
                 batch_size: int = 64,
                 ):
        """
        Args:
            model_name: key into ``model_pipeline`` selecting the architecture.
            ckpt_filename: path to a checkpoint whose ``'state_dict'`` entry
                holds the model weights.
            dataset_csv, threshold, detections_path, inference_cache_path:
                forwarded unchanged to ``Detection.__init__``.
            batch_size: number of window crops per forward pass.
        """
        self.model_name = model_name
        self.ckpt_filename = ckpt_filename
        self.device = self.get_torch_device()
        self.batch_size = batch_size

        self.stride = 1
        # padding must equal the window's center offset (21 // 2 == 10): the
        # reported centers are shifted by -padding, and only this value maps
        # them exactly onto [0, H) x [0, W) of the prediction map.
        self.padding = 10
        self.window_size = (21, 21)
        # Network loaded on first use and reused for every later image.
        self._model = None

        super().__init__(dataset_csv, threshold, detections_path, inference_cache_path)

    @staticmethod
    def get_torch_device() -> torch.device:
        """Return the preferred device: Apple MPS, then CUDA, then CPU."""
        # torch.backends.mps does not exist on older torch releases; guard the
        # lookup so those installs fall through instead of raising AttributeError.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def sliding_window(self, image: np.ndarray, padding: int = 0) -> Iterator[Tuple[int, int, np.ndarray]]:
        """Yield every ``window_size`` crop of ``image``, stepping by ``self.stride``.

        Args:
            image: array indexed as ``image[row, col]``.
            padding: subtracted from the reported center coordinates. For an
                image padded by ``padding`` on each side, this expresses the
                centers in un-padded coordinates.

        Yields:
            ``(center_x, center_y, crop)``, where ``crop`` is a slice (view) of
            ``image`` with ``window_size[1]`` rows and ``window_size[0]`` columns.
        """
        win_w, win_h = self.window_size
        # Even-sized windows have no single center pixel; use the one just
        # left of / above the geometric center.
        x_to_center = win_w // 2 - 1 if win_w % 2 == 0 else win_w // 2
        y_to_center = win_h // 2 - 1 if win_h % 2 == 0 else win_h // 2

        for y in range(0, image.shape[0] - win_h + 1, self.stride):
            for x in range(0, image.shape[1] - win_w + 1, self.stride):
                center_x = x + x_to_center
                center_y = y + y_to_center
                yield center_x - padding, center_y - padding, image[y:y + win_h, x:x + win_w]

    def batch_sliding_window(self, image: np.ndarray, padding: int = 0
                             ) -> Iterator[Tuple[List[int], List[int], List[np.ndarray]]]:
        """Group the output of ``sliding_window`` into batches of ``self.batch_size``.

        Yields:
            ``(x_centers, y_centers, crops)`` lists of equal length. The final
            batch may be shorter; no empty batch is ever yielded.
        """
        x_idx_list: List[int] = []
        y_idx_list: List[int] = []
        images_list: List[np.ndarray] = []
        for _x, _y, _img in self.sliding_window(image, padding=padding):
            x_idx_list.append(_x)
            y_idx_list.append(_y)
            images_list.append(_img)
            if len(images_list) == self.batch_size:
                yield x_idx_list, y_idx_list, images_list
                x_idx_list, y_idx_list, images_list = [], [], []
        if images_list:
            yield x_idx_list, y_idx_list, images_list

    def padding_image(self, img: np.ndarray) -> np.ndarray:
        """Return a float64 copy of ``img`` with ``self.padding`` zero rows/columns on every side."""
        pad = self.padding
        rows, cols = img.shape[0], img.shape[1]
        image_padded = np.zeros((rows + pad * 2, cols + pad * 2))
        # Explicit end indices: a ``pad:-pad`` slice is empty when pad == 0.
        image_padded[pad:pad + rows, pad:pad + cols] = img
        return image_padded

    def load_model(self) -> torch.nn.Module:
        """Build the ``model_name`` architecture and load weights from ``ckpt_filename``.

        Returns:
            The model on ``self.device``, switched to eval mode.
        """
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(self.ckpt_filename, map_location=self.device)

        model = model_pipeline[self.model_name](num_classes=2).to(self.device)
        model.load_state_dict(checkpoint['state_dict'])
        model.eval()
        return model

    def _get_model(self) -> torch.nn.Module:
        """Return the cached network, calling ``load_model`` only on first use."""
        # getattr fallback keeps this working for instances created without __init__.
        model = getattr(self, "_model", None)
        if model is None:
            model = self.load_model()
            self._model = model
        return model

    def images_to_torch_input(self, images_list: List[np.ndarray]) -> torch.Tensor:
        """Stack N equally-sized 2-D crops into a float32 ``(N, 1, H, W)`` tensor on ``self.device``."""
        expanded_img = np.expand_dims(images_list, axis=1)  # (N, H, W) -> (N, 1, H, W)
        input_tensor = torch.from_numpy(expanded_img).float()
        return input_tensor.to(self.device)

    def get_prediction_map(self, padded_image: np.ndarray) -> np.ndarray:
        """Classify every window of ``padded_image`` and collect class-1 probabilities.

        Args:
            padded_image: image already padded by ``self.padding`` on each side
                (see ``padding_image``).

        Returns:
            Float array of shape ``padded_image.shape`` minus ``2 * self.padding``
            per axis. Each entry holds the softmax probability of class 1 for the
            window centered there. Pixels that are never a window center (e.g.
            with ``stride > 1``) stay 0.
        """
        _shape = padded_image.shape
        pred_map = np.zeros((_shape[0] - self.padding * 2, _shape[1] - self.padding * 2))
        model = self._get_model()
        # Pure inference: disable autograd so no graph is built for each batch.
        with torch.no_grad():
            for x_idxs, y_idxs, image_crops in self.batch_sliding_window(padded_image, padding=self.padding):
                output = model(self.images_to_torch_input(image_crops))
                pred_prob = torch.nn.functional.softmax(output, 1)
                pred_map[np.array(y_idxs), np.array(x_idxs)] = pred_prob.cpu().numpy()[:, 1]
        return pred_map

    def image_to_pred_map(self, img: np.ndarray, return_intermediate: bool = False
                          ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
        """Run the full pipeline on a raw image: preprocess, pad, then classify.

        Args:
            img: raw image, passed to ``dl_prepro_image``.
            return_intermediate: if True, also return the intermediate arrays.

        Returns:
            ``pred_map``. If ``return_intermediate`` is True, returns
            ``(preprocessed_img, padded_image, pred_map)`` instead.
        """
        preprocessed_img = dl_prepro_image(img)
        print(f"preprocessed_img.shape: {preprocessed_img.shape}, μ: {np.mean(preprocessed_img)}, σ: {np.std(preprocessed_img)}")
        padded_image = self.padding_image(preprocessed_img)
        print(f"padded_image.shape: {padded_image.shape}, μ: {np.mean(padded_image)}, σ: {np.std(padded_image)}")
        pred_map = self.get_prediction_map(padded_image)
        print(f"pred_map.shape: {pred_map.shape}, μ: {np.mean(pred_map)}, σ: {np.std(pred_map)}")
        if return_intermediate:
            return preprocessed_img, padded_image, pred_map
        return pred_map