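# Gradio demo for referring-expression grounding with the
# IntuitionKillingMachine model: given a sample id and a referring
# expression, the app draws the predicted bounding box on the image.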
from models import IntuitionKillingMachine
from transforms import undo_box_transforms_batch, ToTensor, Normalize, SquarePad, Resize, NormalizeBoxCoords
from torchvision.transforms import Compose
from encoders import get_tokenizer
from PIL import Image, ImageDraw
from zipfile import ZipFile
from copy import copy
import gradio as gr
import pandas as pd
import torch
def parse_model_args(model_path):
    # checkpoint directories encode the training hyperparameters as an
    # underscore-separated string; the first 13 fields are enough to
    # rebuild the model
    (_, _, dataset, max_length, input_size, backbone, num_heads,
     num_layers, num_conv, _, _, mu, mask_pooling) = model_path.split('_')[:13]
return {
'dataset': dataset,
'max_length': int(max_length),
'input_size': int(input_size),
'backbone': backbone,
'num_heads': int(num_heads),
'num_layers': int(num_layers),
'num_conv': int(num_conv),
'mu': float(mu),
        'mask_pooling': mask_pooling == '1'
}
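# e.g. for the checkpoint used below,
# parse_model_args('cache/20211220_191132_refclef_32_512_resnet50_8_6_8_0.1_0.1_0.1_0_.../best.ckpt')
# returns {'dataset': 'refclef', 'max_length': 32, 'input_size': 512,
#          'backbone': 'resnet50', 'num_heads': 8, 'num_layers': 6,
#          'num_conv': 8, 'mu': 0.1, 'mask_pooling': False}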
class Prober:
def __init__(self,
df_path=None,
dataset_path=None,
model_checkpoint=None):
params = parse_model_args(model_checkpoint)
        mean = [0.485, 0.456, 0.406]  # ImageNet statistics used by the backbone
        sdev = [0.229, 0.224, 0.225]
self.tokenizer = get_tokenizer()
        # keep only the fields used by the demo and derive an integer image
        # id from the file name (strip the directory and the extension)
        self.df = pd.read_json(df_path)[['sample_idx', 'bbox', 'file_path', 'sent']]
        self.df.loc[:, 'image_id'] = self.df.file_path.apply(lambda x: int(x.split('/')[-1][:-4]))
        # paths in the annotations are relative to the refer repository, but
        # the zip archive stores images without that prefix
        self.df.file_path = self.df.file_path.apply(lambda x: x.replace('refer/data/images/', ''))
self.model = IntuitionKillingMachine(
backbone=params['backbone'],
pretrained=True,
num_heads=params['num_heads'],
num_layers=params['num_layers'],
num_conv=params['num_conv'],
            segmentation_head=(params['mu'] > 0.0),
mask_pooling=params['mask_pooling']
)
self.load_model(model_checkpoint)
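        # same preprocessing as training: to-tensor, channel normalization,
        # padding to a square, resizing to the model input size, and
        # normalization of the box coordinates to match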
self.transform = Compose([
ToTensor(),
Normalize(mean, sdev),
SquarePad(),
Resize(size=(params['input_size'], params['input_size'])),
NormalizeBoxCoords(),
])
self.max_length = 30
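        # images are read directly from the zip archive, no extraction needed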
self.zipfile = ZipFile(dataset_path, 'r')
def load_model(self, model_checkpoint):
checkpoint = torch.load(
model_checkpoint, map_location=lambda storage, loc: storage
)
# strip 'model.' from pl checkpoint
state_dict = {
k[len('model.'):]: v
for k, v in checkpoint['state_dict'].items()
}
missing, _ = self.model.load_state_dict(state_dict, strict=False)
        # ensure that the only missing keys are those of the segmentation head
        assert [k for k in missing if 'segm' not in k] == []
self.model = self.model.eval()
    def preview_image(self, idx):
        img_path = self.df.loc[idx, 'file_path']
        img = Image.open(self.zipfile.open(img_path)).convert('RGB')
        return img
@torch.no_grad()
    def probe(self, idx, re, search_by_sample_id: bool = True):
        if search_by_sample_id:
            img_path, target = self.df.loc[idx][['file_path', 'bbox']].values
        else:
            img_path, target = self.df[self.df.image_id == idx][['file_path', 'bbox']].values[0]
img = Image.open(self.zipfile.open(img_path)).convert('RGB')
if re != "":
W0, H0 = img.size
sample = {
'image': img,
'image_size': (H0, W0), # image original size
'bbox': torch.tensor([copy(target)]),
'bbox_raw': torch.tensor([copy(target)]),
                'mask': torch.ones((1, H0, W0), dtype=torch.float32),  # visibility mask
'mask_bbox': None, # target bbox mask
}
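            # the transform records its resize/pad parameters in
            # sample['tr_param'], which is used below to map the predicted
            # box back to the original image coordinates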
sample = self.transform(sample)
            tok = self.tokenizer(re,
                                 max_length=self.max_length,
                                 return_tensors='pt',
                                 truncation=True)
inn = {'image': torch.stack([sample['image']]),
'mask': torch.stack([sample['mask']]),
'tok': tok}
            output = undo_box_transforms_batch(self.model(inn)[0],
                                               [sample['tr_param']]).numpy().tolist()[0]
            img1 = ImageDraw.Draw(img)
            # img1.rectangle(target, outline='#0000FF', width=3)  # ground-truth box
            img1.rectangle(output, outline='#00FF00', width=3)  # predicted box
return img
else:
return img
prober = Prober(
    df_path='data/val-sim_metric.json',
    dataset_path='data/saiapr_tc-12.zip',
    model_checkpoint='cache/20211220_191132_refclef_32_512_resnet50_8_6_8_0.1_0.1_0.1_0_0.0001_0.0_12_4_90_1_0_0_0/best.ckpt'
)
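# quick sanity check outside the UI (assumes sample id 0 exists in the
# dataframe and uses a made-up referring expression; adjust as needed):
# prober.probe(0, 'the man on the left').save('probe_0.png')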
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            a = gr.Number(label="sample_id")
            preview = gr.Image(label="image")
            with gr.Row():
                re = gr.Textbox(label="referring expression")
                probe_btn = gr.Button("Probe")
        output = gr.Image(label="prediction")
    # components can't be evaluated at build time, so the preview is
    # refreshed whenever the sample id changes
    a.change(fn=prober.preview_image, inputs=a, outputs=preview)
    probe_btn.click(fn=prober.probe, inputs=[a, re], outputs=output)
# demo = gr.Interface(fn=prober.probe, inputs=["number", "text"], outputs="image", live=True)
demo.queue(concurrency_count=10)
demo.launch(debug=True)