import os
import time
import urllib.request

import cv2
import numpy as np
import torch

from midas.model_loader import load_model

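# download URLs for the supported MiDaS checkpoints (GitHub release assets)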
MODEL_FILE_URL = {
    "midas_v21_small_256" : "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt",
    "dpt_hybrid_384" : "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt",
    "dpt_large_384" : "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt",
    "dpt_swin2_large_384" : "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt",
    "dpt_beit_large_512" : "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt",  
}

class MonocularDepthEstimator:
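    """Monocular depth estimation with MiDaS: loads a checkpoint and predicts per-frame depth maps."""
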
    def __init__(self,
        model_type="midas_v21_small_256",
        model_weights_path="models/", 
        optimize=False, 
        side_by_side=False, 
        height=None, 
        square=False, 
        grayscale=False):

        # model type
        # MiDaS 3.1:
        # For highest quality: dpt_beit_large_512
        # For moderately less quality, but better speed-performance trade-off: dpt_swin2_large_384
        # For embedded devices: dpt_swin2_tiny_256, dpt_levit_224
        # For inference on Intel CPUs, OpenVINO may be used for the small legacy model: openvino_midas_v21_small .xml, .bin
        
        # MiDaS 3.0: 
        # Legacy transformer models dpt_large_384 and dpt_hybrid_384

        # MiDaS 2.1: 
        # Legacy convolutional models midas_v21_384 and midas_v21_small_256
        
        # params
        print("Initializing parameters and model...")
        self.is_optimize = optimize
        self.is_square = square
        self.is_grayscale = grayscale
        self.height = height
        self.side_by_side = side_by_side

        # select device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("Running inference on : %s" % self.device)

        # load the model, downloading the checkpoint first if it is missing
        model_path = os.path.join(model_weights_path, model_type + ".pt")
        if not os.path.exists(model_path):
            print("Model file not found. Downloading...")
            # create the weights directory if needed and fetch the checkpoint
            os.makedirs(model_weights_path, exist_ok=True)
            urllib.request.urlretrieve(MODEL_FILE_URL[model_type], model_path)
            print("Model file downloaded successfully.")

        self.model, self.transform, self.net_w, self.net_h = load_model(self.device, model_path,
                                                                        model_type, optimize, height, square)
        print("Net width and height: ", (self.net_w, self.net_h))
        

    def predict(self, image, model, target_size):
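        """Run the MiDaS network on a preprocessed image and resize the prediction to target_size (width, height)."""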

        # convert the image to a tensor, add a batch dimension and move it to the inference device
        img_tensor = torch.from_numpy(image).to(self.device).unsqueeze(0)

        if self.is_optimize and self.device == torch.device("cuda"):
            # use channels-last memory format and half precision on CUDA for faster inference
            img_tensor = img_tensor.to(memory_format=torch.channels_last)
            img_tensor = img_tensor.half()

        # forward pass, then resize the prediction back to the original frame resolution
        prediction = model.forward(img_tensor)
        prediction = (
            torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=target_size[::-1],
                mode="bicubic",
                align_corners=False,
            )
            .squeeze()
            .cpu()
            .numpy()
        )

        return prediction

    def process_prediction(self, depth_map):
        """
        Normalize the raw depth prediction for visualization and build a colored depth map from it.
        Args:
            depth_map: raw depth prediction from the model
        Returns:
            the depth map normalized to [0, 1] and an inferno-colormapped version of it (scaled to [0, 1])
        """

        # normalize the depth map to [0, 255] for visualization (guard against a constant map)
        depth_min = depth_map.min()
        depth_max = depth_map.max()
        normalized_depth = 255 * (depth_map - depth_min) / max(depth_max - depth_min, 1e-6)

        # replicate the single channel and apply the inferno colormap
        grayscale_depthmap = np.repeat(np.expand_dims(normalized_depth, 2), 3, axis=2)
        depth_colormap = cv2.applyColorMap(np.uint8(grayscale_depthmap), cv2.COLORMAP_INFERNO)

        return normalized_depth/255, depth_colormap/255

    def make_prediction(self, image):
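        """Estimate depth for a single BGR frame: preprocess, run the model and return the normalized and colored depth maps."""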
        image = image.copy()
        with torch.no_grad():
            original_image_rgb = np.flip(image, 2)  # in [0, 255] (flip required to get RGB)
            # resizing the image to feed to the model
            image_transformed = self.transform({"image": original_image_rgb/255})["image"]

            # monocular depth prediction
            pred = self.predict(image_transformed, self.model, target_size=original_image_rgb.shape[1::-1])

            # process the model predictions
            depthmap, depth_colormap = self.process_prediction(pred)
        return depthmap, depth_colormap

    def run(self, input_path):
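        """Read frames from input_path, estimate depth for each frame and display the colored result until ESC or end of stream."""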
        
        # open the input video file (or camera stream)
        cap = cv2.VideoCapture(input_path)

        # check whether the video stream opened successfully
        if not cap.isOpened():
            print("Error opening video file")
            return

        with torch.no_grad():
            while cap.isOpened():

                # Capture frame-by-frame
                inference_start_time = time.time()
                ret, frame = cap.read()

                if ret:
                    _, depth_colormap = self.make_prediction(frame)                    
                    inference_end_time = time.time()
                    fps = round(1/(inference_end_time - inference_start_time))
                    cv2.putText(depth_colormap, f'FPS: {fps}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (10, 255, 100), 2)
                    cv2.imshow('MiDaS Depth Estimation - Press Escape to close window ', depth_colormap)

                    # Press ESC on keyboard to exit
                    if cv2.waitKey(1) == 27:  # Escape key
                        break
                
                else:
                    break


        # When everything done, release
        # the video capture object
        cap.release()
        
        # Closes all the frames
        cv2.destroyAllWindows()



if __name__ == "__main__":
    # params
    INPUT_PATH = "assets/videos/testvideo2.mp4"

    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    # set torch options
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    
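    # build the estimator (downloads the checkpoint on first run) and process the input video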
    depth_estimator = MonocularDepthEstimator(model_type="dpt_hybrid_384")
    depth_estimator.run(INPUT_PATH)