from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
from PIL import Image, ImageDraw
import numpy as np
from torch import nn
import gradio as gr
import os
import torch
import time

# Hugging Face Hub checkpoint id; the preprocessor and the model below are both
# loaded from it, so it is kept in one place.
MODEL_CHECKPOINT = "nvidia/segformer-b5-finetuned-cityscapes-1024-1024"

feature_extractor = SegformerFeatureExtractor.from_pretrained(MODEL_CHECKPOINT)
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_CHECKPOINT)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f"Is CUDA available: {torch.cuda.is_available()} --> {device=}")
if device == 'cuda':
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

model.to(device)

# https://github.com/NielsRogge/Transformers-Tutorials/blob/master/SegFormer/Segformer_inference_notebook.ipynb

def cityscapes_palette():
    """Return the Cityscapes colour palette: 19 ``[R, G, B]`` lists, one per label index.

    A fresh list is built on every call, so callers may mutate the result.
    """
    colours = (
        (128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156),
        (190, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
        (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60),
        (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 80, 100),
        (0, 0, 230), (119, 11, 32),
    )
    return [list(rgb) for rgb in colours]

def cityscapes_classes():
    """Return the 19 Cityscapes class names as a new list, in label-index order."""
    # Comma-separated because some names ("traffic light") contain spaces.
    joined = (
        "road,sidewalk,building,wall,fence,pole,"
        "traffic light,traffic sign,vegetation,terrain,sky,"
        "person,rider,car,truck,bus,train,motorcycle,"
        "bicycle"
    )
    return joined.split(",")

def hot_squares(color_seg: np.ndarray, blocks: int = 4, threshold: int = 100000) -> list:
    """Return the grid squares of *color_seg* whose mask intensity exceeds *threshold*.

    The mask is split into a ``blocks`` x ``blocks`` grid of equal squares
    (remainder pixels on the right/bottom edge, if any, are not scanned). The
    intensity of a square is the sum of every channel value of every pixel in it.

    Args:
        color_seg: (height, width, 3) colour mask.
        blocks: number of squares per side of the grid.
        threshold: a square is reported when its intensity is strictly greater.

    Returns:
        ``(x, y, square_width, square_height)`` tuples, with x iterated in the
        outer loop and y in the inner loop.
    """
    height, width = color_seg.shape[:2]
    step_x, step_y = width // blocks, height // blocks
    # Collapse the colour channels once (not once per square); int64 keeps sums exact.
    intensity = color_seg.sum(axis=2, dtype=np.int64)
    return [
        (x, y, step_x, step_y)
        for x in range(0, blocks * step_x, step_x)
        for y in range(0, blocks * step_y, step_y)
        if intensity[y:y + step_y, x:x + step_x].sum() > threshold
    ]


def annotation(image: "Image.Image", color_seg: np.ndarray, *, blocks: int = 4, threshold: int = 100000):
    """Outline, on *image* in place, every grid square where the mask is strong.

    Args:
        image: PIL image to draw on; modified in place. ``image.size`` is
            (width, height) and must agree with ``color_seg``.
        color_seg: (height, width, 3) colour mask aligned with *image*; only its
            per-pixel channel sum is used.
        blocks: squares per side of the grid (the default 4 gives 256-pixel
            squares on a 1024x1024 image).
        threshold: a square is outlined when its summed channel values are
            strictly greater than this.

    Raises:
        ValueError: if the shapes disagree or the image is smaller than the grid.
    """
    width, height = image.size
    if color_seg.shape != (height, width, 3):
        raise ValueError(
            f"color_seg shape {color_seg.shape} does not match image size {image.size}"
        )
    if blocks < 1 or width < blocks or height < blocks:
        raise ValueError(f"cannot split a {width}x{height} image into {blocks}x{blocks} squares")

    draw = ImageDraw.Draw(image)
    for x, y, step_x, step_y in hot_squares(color_seg, blocks, threshold):
        print("light found at square ", x, y)
        # Pillow's bounding box includes both corners, so end on the last pixel
        # of this square instead of spilling one pixel into the next (or off-image).
        draw.rectangle([(x, y), (x + step_x - 1, y + step_y - 1)], outline="white", width=3)

def label_mask_image(seg, label, color):
    """Return an (H, W, 3) uint8 mask painted with *color* where ``seg == label``.

    *seg* is an (H, W) array of per-pixel class indices; every other pixel is
    black. The colour channels of the result are reversed (RGB -> BGR).
    """
    color_seg = np.zeros((*seg.shape, 3), dtype=np.uint8)
    color_seg[seg == label] = color
    # NOTE(review): the mask is later blended onto an RGB image and returned as
    # RGB, so this reversal shows the overlay channel-swapped. Kept to preserve
    # the existing output; the channel sum used by annotation() is unaffected.
    return color_seg[..., ::-1]


def blend_overlay(base, overlay):
    """Return the 50/50 blend of two same-shaped images as uint8."""
    return (base * 0.5 + overlay * 0.5).astype(np.uint8)


def call(image):  # nparray
    """Detect traffic lights in *image* and return an annotated 1024x1024 PIL image.

    Args:
        image: picture from the Gradio ``"image"`` input, passed to
            ``Image.fromarray`` (presumably an (H, W, 3) uint8 array — TODO
            confirm for other Gradio image modes).

    Returns:
        PIL image: the resized input blended 50/50 with the traffic-light
        mask, with qualifying grid squares outlined by ``annotation``.
    """
    start = time.time()

    # The model input and annotation() both use a fixed 1024x1024 canvas.
    # convert("RGB") keeps grayscale/RGBA inputs from breaking the 3-channel blend.
    resized = Image.fromarray(image).convert("RGB").resize((1024, 1024))
    resized_image = np.array(resized)  # (1024, 1024, 3)
    print(f"*processing time: {(time.time() - start):.2f} s")

    inputs = feature_extractor(images=resized_image, return_tensors="pt").to(device)
    print(f"**processing time: {(time.time() - start):.2f} s")

    # Inference only: skip building an autograd graph for every request.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits.cpu()  # (batch_size, num_labels, height/4, width/4)
    print(f"***processing time: {(time.time() - start):.2f} s")

    # Rescale the logits to the canvas size, then take the best class per pixel.
    interpolated_logits = nn.functional.interpolate(
        logits,
        size=resized_image.shape[:2],  # (height, width)
        mode='bilinear',
        align_corners=False)
    seg = interpolated_logits.argmax(dim=1)[0].numpy()

    # Only the traffic-light class is painted; every other class stays black.
    label = cityscapes_classes().index('traffic light')
    color_seg = label_mask_image(seg, label, cityscapes_palette()[label])
    print(f"****processing time: {(time.time() - start):.2f} s")

    out_im_file = Image.fromarray(blend_overlay(resized_image, color_seg))
    annotation(out_im_file, color_seg)

    print(f"--> processing time: {(time.time() - start):.2f} s")
    return out_im_file

# original_image = Image.open("./examples/1.jpg")
# print(f"{np.array(original_image).shape=}") # eg 729, 1000, 3

# out = call(original_image)
# out.save("out2.jpeg")

title = "Traffic Light Detector"
description = "Experiment traffic light detection to evaluate the value of captcha security controls"

# Gradio UI: one image in, one annotated image out. The sample pictures
# examples/1.jpg ... examples/6.jpg are resolved relative to this script.
iface = gr.Interface(
    fn=call,
    inputs="image",
    outputs="image",
    title=title,
    description=description,
    examples=[
        os.path.join(os.path.dirname(__file__), f"examples/{n}.jpg")
        for n in range(1, 7)
    ],
    thumbnail="thumbnail.webp",
)
iface.launch()