from models import IntuitionKillingMachine
from transforms import undo_box_transforms_batch, ToTensor, Normalize, SquarePad, Resize, NormalizeBoxCoords
from torchvision.transforms import Compose
from encoders import get_tokenizer
from PIL import Image, ImageDraw
from zipfile import ZipFile
from copy import copy
import gradio as gr
import pandas as pd
import torch

def parse_model_args(model_path):
    """Decode the model hyper-parameters encoded in a checkpoint path.

    ``model_path`` is split on ``'_'`` and its first 13 fields are read
    positionally. Fields 0, 1, 9 and 10 are discarded, as is everything
    after field 12.

    Args:
        model_path: checkpoint path whose underscore-separated fields carry
            the hyper-parameters.

    Returns:
        dict with ``dataset`` and ``backbone`` (str); ``max_length``,
        ``input_size``, ``num_heads``, ``num_layers`` and ``num_conv`` (int);
        ``mu`` (float); and ``mask_pooling`` (bool, true only when the field
        is exactly ``'1'``).

    Raises:
        ValueError: if fewer than 13 fields are present or a numeric field
            cannot be parsed.
    """
    fields = model_path.split('_')[:13]
    (_lead_a, _lead_b, dataset, max_length, input_size, backbone,
     num_heads, num_layers, num_conv, _skip_a, _skip_b, mu,
     mask_pooling) = fields
    return dict(
        dataset=dataset,
        max_length=int(max_length),
        input_size=int(input_size),
        backbone=backbone,
        num_heads=int(num_heads),
        num_layers=int(num_layers),
        num_conv=int(num_conv),
        mu=float(mu),
        mask_pooling=(mask_pooling == '1'),
    )


class Prober:
    """Runs a referring-expression model on images of an evaluation split.

    Holds the evaluation dataframe, an open handle on the zipped image
    archive and the loaded model; ``probe`` draws the box the model predicts
    for a text query onto the selected image.
    """

    def __init__(self,
                 df_path=None,
                 dataset_path=None,
                 model_checkpoint=None):
        """Load the dataframe, model and image archive.

        Args:
            df_path: JSON file readable by ``pd.read_json`` with at least the
                columns ``sample_idx``, ``bbox``, ``file_path`` and ``sent``.
            dataset_path: zip archive containing the images named by
                ``file_path`` (once the ``refer/data/images/`` prefix is removed).
            model_checkpoint: checkpoint path; its name also encodes the
                model hyper-parameters decoded by ``parse_model_args``.
        """
        params = parse_model_args(model_checkpoint)
        # per-channel normalisation statistics (the common ImageNet values)
        mean = [0.485, 0.456, 0.406]
        sdev = [0.229, 0.224, 0.225]
        self.tokenizer = get_tokenizer()

        # .copy() detaches the column subset from the frame read_json returned,
        # so the column assignments below do not raise SettingWithCopyWarning.
        df = pd.read_json(df_path)[['sample_idx', 'bbox', 'file_path', 'sent']].copy()
        # numeric file stem, e.g. '.../12345.jpg' -> 12345
        # (assumes a 4-character extension such as '.jpg')
        df['image_id'] = df['file_path'].apply(lambda x: int(x.split('/')[-1][:-4]))
        # drop the prefix so the path names a member of the zip archive
        df['file_path'] = df['file_path'].apply(lambda x: x.replace('refer/data/images/', ''))
        self.df = df

        self.model = IntuitionKillingMachine(
            backbone=params['backbone'],
            pretrained=True,
            num_heads=params['num_heads'],
            num_layers=params['num_layers'],
            num_conv=params['num_conv'],
            # the segmentation head is only built when mu > 0
            segmentation_head=bool(params['mu'] > 0.0),
            mask_pooling=params['mask_pooling'],
        )
        self.load_model(model_checkpoint)
        self.transform = Compose([
            ToTensor(),
            Normalize(mean, sdev),
            SquarePad(),
            Resize(size=(params['input_size'], params['input_size'])),
            NormalizeBoxCoords(),
        ])
        # Token budget for the referring expression (used by probe()).
        # NOTE(review): the checkpoint name also encodes params['max_length'],
        # which is not used here; confirm 30 is the intended value.
        self.max_length = 30
        # Kept open for the prober's lifetime; members are read on demand.
        self.zipfile = ZipFile(dataset_path, 'r')

    def load_model(self, model_checkpoint):
        """Load weights from a Lightning-style checkpoint into ``self.model``.

        Keys in ``checkpoint['state_dict']`` starting with ``'model.'`` have
        that prefix removed before loading non-strictly. The model is left
        in eval mode.

        Raises:
            RuntimeError: if any model weight other than a segmentation
                ('segm') weight is absent from the checkpoint.
        """
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        # The map_location callable keeps every tensor on CPU.
        checkpoint = torch.load(
            model_checkpoint, map_location=lambda storage, loc: storage
        )

        # strip 'model.' from pl checkpoint keys (only where it is present)
        prefix = 'model.'
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in checkpoint['state_dict'].items()
        }

        missing, _ = self.model.load_state_dict(state_dict, strict=False)

        # Only segmentation-head weights may be missing. Raise instead of
        # assert so the check survives `python -O`.
        missing_required = [k for k in missing if 'segm' not in k]
        if missing_required:
            raise RuntimeError(
                f"checkpoint {model_checkpoint!r} is missing weights: {missing_required}"
            )

        self.model = self.model.eval()

    def _load_image(self, img_path):
        """Read member ``img_path`` of the zip archive as an RGB PIL image."""
        with self.zipfile.open(img_path) as fh:
            # convert() loads the pixel data while the member is still open
            return Image.open(fh).convert('RGB')

    def preview_image(self, idx):
        """Return the RGB image of the row whose dataframe index label is ``idx``."""
        return self._load_image(self.df.loc[idx, 'file_path'])

    @torch.no_grad()
    def probe(self, idx, re, search_by_sample_id: bool = True):
        """Draw the model's predicted box for expression ``re`` on an image.

        Args:
            idx: dataframe index label when ``search_by_sample_id`` is true,
                otherwise an ``image_id`` (the first matching row is used).
            re: referring expression. When empty, the image is returned
                without annotation.
            search_by_sample_id: selects how ``idx`` is interpreted.

        Returns:
            The RGB PIL image, with the predicted box outlined in place when
            ``re`` is non-empty.

        Raises:
            KeyError: if no row matches ``idx``.
        """
        if search_by_sample_id:
            img_path, target = self.df.loc[idx, ['file_path', 'bbox']].values
        else:
            matches = self.df.loc[self.df.image_id == idx, ['file_path', 'bbox']]
            if matches.empty:
                raise KeyError(f"no sample with image_id {idx!r}")
            img_path, target = matches.values[0]
        img = self._load_image(img_path)
        if not re:
            return img

        W0, H0 = img.size
        sample = {
            'image': img,
            'image_size': (H0, W0),  # image original size (height, width)
            'bbox': torch.tensor([copy(target)]),
            'bbox_raw': torch.tensor([copy(target)]),
            'mask': torch.ones((1, H0, W0), dtype=torch.float32),  # visibility mask
            'mask_bbox': None,  # target bbox mask
        }
        sample = self.transform(sample)
        tok = self.tokenizer(re,
                             max_length=self.max_length,
                             return_tensors='pt',
                             truncation=True)
        # batch of a single sample
        inn = {'image': torch.stack([sample['image']]),
               'mask': torch.stack([sample['mask']]),
               'tok': tok}
        # presumably maps the prediction back from the transformed frame to the
        # original image's coordinates, since it is drawn on `img` below
        output = undo_box_transforms_batch(self.model(inn)[0],
                                           [sample['tr_param']]).numpy().tolist()[0]
        ImageDraw.Draw(img).rectangle(output, outline ="#00FF0000", width=3)
        return img
    
# Build the prober once when the script runs; the checkpoint's directory
# name also encodes the model hyper-parameters (see parse_model_args).
prober = Prober(
    df_path='data/val-sim_metric.json',
    dataset_path='data/saiapr_tc-12.zip',
    model_checkpoint='cache/20211220_191132_refclef_32_512_resnet50_8_6_8_0.1_0.1_0.1_0_0.0001_0.0_12_4_90_1_0_0_0/best.ckpt',
)

# Gradio UI: a number (the dataframe row index) and a text box (the
# referring expression) go in; the annotated image comes out.
demo = gr.Interface(
    fn=prober.probe,
    inputs=['number', 'text'],
    outputs='image',
)

demo.queue(concurrency_count=10)
demo.launch()