import gradio as gr

# demo defaults; image_path and sentence are not used by the Gradio app itself, which takes its inputs from the UI
image_path = './image001.png'
sentence = 'spoon on the dish'
weights = './checkpoints/gradio.pth'
device = 'cpu'

# imports
from PIL import Image
import torchvision.transforms as T
import numpy as np
import datetime
import os
import time

import torch
import torch.utils.data
from torch import nn
import torch.nn.functional as F

from bert.multimodal_bert import MultiModalBert
import torchvision

from lib import multimodal_segmentation_ppm
import utils

from modeling.MaskFormerModel import MaskFormerHead
from addict import Dict
import cv2
import textwrap

class WrapperModel(nn.Module):
    def __init__(self, image_model, language_model, classifier) :
        super(WrapperModel, self).__init__()
        self.image_model = image_model
        self.language_model = language_model
        self.classifier = classifier

        # Reference BERT-style config for the multimodal language model; it is built
        # here for documentation purposes but is not stored on the module or used below.
        config = Dict({
            "architectures": [
                "BertForMaskedLM"
            ],
            "attention_probs_dropout_prob": 0.1,
            "gradient_checkpointing": False,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "hidden_size": 512,
            "initializer_range": 0.02,
            "intermediate_size": 3072,
            "layer_norm_eps": 1e-12,
            # "max_position_embeddings": 16 + 20,
            "model_type": "bert",
            "num_attention_heads": 8,
            "num_hidden_layers": 8,
            "pad_token_id": 0,
            "position_embedding_type": "absolute",
            "transformers_version": "4.6.0.dev0",
            "type_vocab_size": 2,
            "use_cache": True,
            "vocab_size": 30522
        })



    def _get_binary_mask(self, target):
        # Return a binary mask per class (one channel per foreground class).
        # NOTE: relies on self.num_classes, which is not set in __init__ and must be
        # assigned before this helper is used.
        y, x = target.size()
        target_onehot = torch.zeros(self.num_classes + 1, y, x)
        target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1)
        return target_onehot[1:]

    def semantic_inference(self, mask_cls, mask_pred):
        # Combine per-query class probabilities and mask logits into per-class segmentation maps.
        mask_cls = F.softmax(mask_cls, dim=1)[..., 1:]
        mask_pred = mask_pred.sigmoid()
        semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
        return semseg

    def forward(self, image, sentences, attentions):
        # print(image.sum(), sentences.sum(), attentions.sum())  # debug
        input_shape = image.shape[-2:]
        l_mask = attentions.unsqueeze(dim=-1)  # (B, N_l, 1) language mask for the fusion modules

        i0, Wh, Ww = self.image_model.forward_stem(image)
        l0, extended_attention_mask = self.language_model.forward_stem(sentences, attentions)
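
        # Stages 1-4: interleaved Swin (image) and BERT (language) encoding. After each
        # stage, PWAM fuses the language features into the visual features (and vice versa);
        # the per-stage *_residual maps are collected for the segmentation head.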

        i1 = self.image_model.forward_stage1(i0, Wh, Ww)
        l1 = self.language_model.forward_stage1(l0, extended_attention_mask)
        i1_residual, H, W, i1_temp, Wh, Ww  = self.image_model.forward_pwam1(i1, Wh, Ww, l1, l_mask)
        l1_residual, l1 = self.language_model.forward_pwam1(i1, l1, extended_attention_mask) 
        i1 = i1_temp

        i2 = self.image_model.forward_stage2(i1, Wh, Ww)
        l2 = self.language_model.forward_stage2(l1, extended_attention_mask)
        i2_residual, H, W, i2_temp, Wh, Ww  = self.image_model.forward_pwam2(i2, Wh, Ww, l2, l_mask)
        l2_residual, l2 = self.language_model.forward_pwam2(i2, l2, extended_attention_mask) 
        i2 = i2_temp

        i3 = self.image_model.forward_stage3(i2, Wh, Ww)
        l3 = self.language_model.forward_stage3(l2, extended_attention_mask)
        i3_residual, H, W, i3_temp, Wh, Ww  = self.image_model.forward_pwam3(i3, Wh, Ww, l3, l_mask)
        l3_residual, l3 = self.language_model.forward_pwam3(i3, l3, extended_attention_mask) 
        i3 = i3_temp

        i4 = self.image_model.forward_stage4(i3, Wh, Ww)
        l4 = self.language_model.forward_stage4(l3, extended_attention_mask)
        i4_residual, H, W, i4_temp, Wh, Ww  = self.image_model.forward_pwam4(i4, Wh, Ww, l4, l_mask)
        l4_residual, l4 = self.language_model.forward_pwam4(i4, l4, extended_attention_mask) 
        i4 = i4_temp

        #i1_residual, i2_residual, i3_residual, i4_residual = features
        #x = self.classifier(i4_residual, i3_residual, i2_residual, i1_residual)
        #x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True)
        outputs = {}
        outputs['s1'] = i1_residual
        outputs['s2'] = i2_residual
        outputs['s3'] = i3_residual
        outputs['s4'] = i4_residual

        predictions = self.classifier(outputs)
        return predictions

# tokenizer for the raw sentence
from bert.tokenization_bert import BertTokenizer

# initialize the model and load weights

# construct a minimal args namespace, as would normally come from a config file


class args:
    swin_type = 'base'   # Swin-B backbone
    window12 = True      # use the 12x12 attention-window variant of Swin
    mha = ''
    fusion_drop = 0.0    # dropout in the vision-language fusion modules


# LAVT-style segmentation model; only its Swin backbone is used by the wrapper below
single_model = multimodal_segmentation_ppm.__dict__['lavt'](pretrained='', args=args)
single_model.to(device)
model_class = MultiModalBert
single_bert_model = model_class.from_pretrained('bert-base-uncased', embed_dim=single_model.backbone.embed_dim)
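# the BERT pooler is not needed for token-level features, so drop it rather than load or train its weights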
single_bert_model.pooler = None

input_shape = dict()
input_shape['s1'] = Dict({'channel': 128,  'stride': 4})
input_shape['s2'] = Dict({'channel': 256,  'stride': 8})
input_shape['s3'] = Dict({'channel': 512,  'stride': 16})
input_shape['s4'] = Dict({'channel': 1024, 'stride': 32})
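
# Minimal MaskFormer head configuration: a single object query and a single foreground
# class, since referring segmentation predicts one binary mask per expression.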
cfg = Dict()
cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
cfg.MODEL.MASK_FORMER.DROPOUT = 0.0 
cfg.MODEL.MASK_FORMER.NHEADS = 8
cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4
cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]

cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1
cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10
cfg.MODEL.MASK_FORMER.PRE_NORM = False


maskformer_head = MaskFormerHead(cfg, input_shape)
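
# Glue the Swin backbone, the multimodal BERT encoder and the MaskFormer head into one module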


model = WrapperModel(single_model.backbone, single_bert_model, maskformer_head)



# load the demo checkpoint; strict=False tolerates keys that do not exactly match the wrapper
checkpoint = torch.load(weights, map_location='cpu')
model.load_state_dict(checkpoint, strict=False)
model.to(device)
model.eval()
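# Inference runs on CPU by default; set device = 'cuda' above if a GPU is available.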

# visualization helper: overlay the predicted mask on the image, DAVIS-style
def overlay_davis(image, mask, colors=[[0, 0, 0], [255, 0, 0]], cscale=1, alpha=0.4):
    from scipy.ndimage import binary_dilation

    colors = np.reshape(colors, (-1, 3))
    colors = np.atleast_2d(colors) * cscale

    im_overlay = image.copy()
    object_ids = np.unique(mask)

    for object_id in object_ids[1:]:
        # Blend the overlay color into the region covered by this object id
        foreground = image * alpha + np.ones(image.shape) * (1 - alpha) * np.array(colors[object_id])
        binary_mask = mask == object_id

        # Compose image
        im_overlay[binary_mask] = foreground[binary_mask]

        # Draw a black contour around the mask (dilated mask XOR mask)
        contours = binary_dilation(binary_mask) ^ binary_mask
        im_overlay[contours, :] = 0

    return im_overlay.astype(image.dtype)


def run_model(img, sentence):
    # Gradio passes the uploaded image in as a numpy array
    img = Image.fromarray(img)
    img = img.convert("RGB")
    img_ndarray = np.array(img)  # (orig_h, orig_w, 3); for visualization
    original_w, original_h = img.size  # PIL .size returns width first and height second

    image_transforms = T.Compose(
        [
         T.Resize((480, 480)),
         T.ToTensor(),
         T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]
    )

    img = image_transforms(img).unsqueeze(0)  # (1, 3, 480, 480)
    img = img.to(device)  # for inference (input)

    # NOTE: the tokenizer is re-created on every call; it could be built once at module level instead
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    sentence_tokenized = tokenizer.encode(text=sentence, add_special_tokens=True)
    sentence_tokenized = sentence_tokenized[:20]  # truncate to at most 20 tokens
    # pad the tokenized sentence
    padded_sent_toks = [0] * 20
    padded_sent_toks[:len(sentence_tokenized)] = sentence_tokenized
    # create a sentence token mask: 1 for real words; 0 for padded tokens
    attention_mask = [0] * 20
    attention_mask[:len(sentence_tokenized)] = [1]*len(sentence_tokenized)
    # convert lists to tensors
    padded_sent_toks = torch.tensor(padded_sent_toks).unsqueeze(0)  # (1, 20)
    attention_mask = torch.tensor(attention_mask).unsqueeze(0)  # (1, 20)
    padded_sent_toks = padded_sent_toks.to(device)  # for inference (input)
    attention_mask = attention_mask.to(device)  # for inference (input)

    # run the wrapped model without tracking gradients
    with torch.no_grad():
        output = model(img, padded_sent_toks, attention_mask)[0]
    mask_cls_results = output["pred_logits"]
    mask_pred_results = output["pred_masks"]

    target_shape = img_ndarray.shape[:2]
    # upsample the low-resolution mask logits back to the network input size
    mask_pred_results = F.interpolate(mask_pred_results, size=(480, 480), mode='bilinear', align_corners=True)

    pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results)

    # resize the predicted mask to the original image size and binarize it
    output = F.interpolate(pred_masks, target_shape)
    output = (output > 0.5).data.cpu().numpy()

    output = output.astype(np.uint8)  # (1, 1, orig_h, orig_w), np.uint8
    # Overlay the mask on the image (red, per overlay_davis defaults)
    # print(img_ndarray.shape, output.shape)  # debug
    visualization = overlay_davis(img_ndarray, output[0][0])
    visualization = Image.fromarray(visualization)
    # show the visualization
    #visualization.show()
    # Save the visualization
    #visualization.save('./demo/spoon_on_the_dish.jpg')
    return visualization




demo = gr.Interface(run_model, inputs=[gr.Image(), "text"], outputs=["image"])
demo.launch()
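
# Quick sanity check without the UI (uses the module-level defaults above), e.g.:
#   result = run_model(np.array(Image.open(image_path).convert("RGB")), sentence)
#   result.save('./demo_output.jpg')
# To expose the demo beyond localhost, Gradio also supports options such as
# demo.launch(share=True) or demo.launch(server_name="0.0.0.0").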