File size: 4,935 Bytes
2d07fab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80bbb99
 
 
 
2d07fab
 
 
 
 
 
 
 
80bbb99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d07fab
 
 
 
 
 
f13173e
2d07fab
 
f13173e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from models import IntuitionKillingMachine
from transforms import undo_box_transforms_batch, ToTensor, Normalize, SquarePad, Resize, NormalizeBoxCoords
from torchvision.transforms import Compose
from encoders import get_tokenizer
from PIL import Image, ImageDraw
from zipfile import ZipFile
from copy import copy
import gradio as gr
import pandas as pd
import torch

def parse_model_args(model_path):
    """Decode the model hyper-parameters embedded in a checkpoint path.

    The path is split on ``'_'`` and its first 13 pieces are read by
    position. Pieces 0, 1, 9 and 10 are ignored. The others are:
    dataset, max_length, input_size, backbone, num_heads, num_layers,
    num_conv (pieces 2-8), mu (piece 11) and the mask-pooling flag
    (piece 12, enabled only when it equals ``'1'``).

    Raises ``ValueError`` when fewer than 13 pieces are present.

    Returns a dict mapping each parameter name to its typed value.
    """
    parts = model_path.split('_')[:13]
    (_, _, dataset, max_length, input_size, backbone,
     num_heads, num_layers, num_conv, _, _, mu, mask_pooling) = parts

    parsed = dict(dataset=dataset)
    parsed['max_length'] = int(max_length)
    parsed['input_size'] = int(input_size)
    parsed['backbone'] = backbone
    parsed['num_heads'] = int(num_heads)
    parsed['num_layers'] = int(num_layers)
    parsed['num_conv'] = int(num_conv)
    parsed['mu'] = float(mu)
    # the comparison already yields a bool
    parsed['mask_pooling'] = mask_pooling == '1'
    return parsed


class Prober:
    """Visual-grounding probe around a trained ``IntuitionKillingMachine``.

    Holds a dataframe of annotated samples, the zip archive containing their
    images and the model restored from a checkpoint. ``probe`` runs the model
    on one image plus a referring expression and draws the predicted box.
    """

    def __init__(self,
                 df_path=None,
                 dataset_path=None,
                 model_checkpoint=None):
        """Load annotations, the image archive, model weights and transforms.

        Args:
            df_path: JSON file read with ``pd.read_json``; must provide the
                columns ``sample_idx``, ``bbox``, ``file_path`` and ``sent``.
            dataset_path: zip archive with the images; member names are the
                ``file_path`` values with ``'refer/data/images/'`` removed.
            model_checkpoint: checkpoint path. Its name also encodes the
                model hyper-parameters (parsed by ``parse_model_args``).
        """
        params = parse_model_args(model_checkpoint)
        # Per-channel normalisation statistics (the standard ImageNet values).
        mean = [0.485, 0.456, 0.406]
        sdev = [0.229, 0.224, 0.225]
        self.tokenizer = get_tokenizer()
        # .copy(): work on an independent frame rather than a column subset
        # of the read_json result, so the column assignments below are
        # unambiguous (no chained-assignment warning).
        self.df = pd.read_json(df_path)[['sample_idx', 'bbox', 'file_path', 'sent']].copy()
        # Numeric image id = file name without its last 4 characters
        # (assumes a 3-letter extension such as '.jpg').
        self.df["image_id"] = self.df.file_path.apply(lambda x: int(x.split('/')[-1][:-4]))
        # Make paths match the member names used with self.zipfile.
        self.df["file_path"] = self.df.file_path.apply(lambda x: x.replace('refer/data/images/', ''))
        self.model = IntuitionKillingMachine(
            backbone=params['backbone'],
            pretrained=True,
            num_heads=params['num_heads'],
            num_layers=params['num_layers'],
            num_conv=params['num_conv'],
            # a positive mu enables the segmentation head
            segmentation_head=bool(params['mu'] > 0.0),
            mask_pooling=params['mask_pooling']
        )
        self.load_model(model_checkpoint)
        self.transform = Compose([
            ToTensor(),
            Normalize(mean, sdev),
            SquarePad(),
            Resize(size=(params['input_size'], params['input_size'])),
            NormalizeBoxCoords(),
        ])
        # Token budget encoded in the checkpoint name (previously a
        # hard-coded 30 that ignored the parsed value).
        self.max_length = params['max_length']
        # Kept open for the lifetime of the prober; images are read on demand.
        self.zipfile = ZipFile(dataset_path, 'r')

    def load_model(self, model_checkpoint):
        """Restore ``self.model`` from a (PyTorch Lightning) checkpoint.

        Keys of ``checkpoint['state_dict']`` that start with ``'model.'`` have
        that prefix removed; other keys are used unchanged. Loading is
        non-strict, but only segmentation-head weights (keys containing
        ``'segm'``) may be missing. The model is put in eval mode afterwards.

        Raises:
            RuntimeError: if any other model weight is missing from the
                checkpoint.
        """
        # map_location returning `storage` loads every tensor onto the CPU.
        checkpoint = torch.load(
            model_checkpoint, map_location=lambda storage, loc: storage
        )

        prefix = 'model.'
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in checkpoint['state_dict'].items()
        }

        missing, _ = self.model.load_state_dict(state_dict, strict=False)

        # Only segmentation-head weights are allowed to be absent; an
        # explicit raise (not assert) so the check survives `python -O`.
        unexpected_missing = [k for k in missing if 'segm' not in k]
        if unexpected_missing:
            raise RuntimeError(
                f"checkpoint is missing model weights: {unexpected_missing}"
            )

        self.model = self.model.eval()

    def _load_image(self, img_path):
        """Read member ``img_path`` from the zip archive as an RGB image."""
        with self.zipfile.open(img_path) as fh:
            # convert() decodes the pixels immediately, so the returned image
            # no longer needs the member file, which is closed here.
            return Image.open(fh).convert('RGB')

    def preview_image(self, idx):
        """Return the RGB image of the sample with dataframe index ``idx``."""
        return self._load_image(self.df.loc[idx, 'file_path'])

    @torch.no_grad()
    def probe(self, idx, re, search_by_sample_id: bool = True):
        """Draw the model's predicted box for expression ``re`` on an image.

        Args:
            idx: dataframe index label when ``search_by_sample_id`` is true,
                otherwise an ``image_id`` (the first matching row is used).
            re: referring expression. When empty (or ``None``) the image is
                returned without running the model.
            search_by_sample_id: selects how ``idx`` is interpreted.

        Returns:
            The RGB ``PIL.Image`` with the predicted box outlined.
        """
        if search_by_sample_id:
            img_path, target = self.df.loc[idx][['file_path', 'bbox']].values
        else:
            img_path, target = self.df[self.df.image_id == idx][['file_path', 'bbox']].values[0]
        img = self._load_image(img_path)
        if not re:
            return img

        W0, H0 = img.size
        sample = {
            'image': img,
            'image_size': (H0, W0),  # original (height, width)
            'bbox': torch.tensor([copy(target)]),
            'bbox_raw': torch.tensor([copy(target)]),
            'mask': torch.ones((1, H0, W0), dtype=torch.float32),  # visibility mask
            'mask_bbox': None,  # target bbox mask
        }
        sample = self.transform(sample)
        tok = self.tokenizer(re,
                             max_length=self.max_length,
                             return_tensors='pt',
                             truncation=True)
        # single-element batch
        inn = {'image': torch.stack([sample['image']]),
               'mask': torch.stack([sample['mask']]),
               'tok': tok}
        # NOTE(review): presumably maps the predicted box from the transformed
        # (padded/resized) frame back to original-image pixel coordinates.
        output = undo_box_transforms_batch(self.model(inn)[0],
                                           [sample['tr_param']]).numpy().tolist()[0]
        draw = ImageDraw.Draw(img)
        # Ground-truth box drawing is disabled:
        # draw.rectangle(target, outline="#0000FF00", width=3)
        draw.rectangle(output, outline="#00FF0000", width=3)
        return img
    
# Build the prober once at start-up, then serve `Prober.probe` through Gradio:
# the two inputs are (index, referring expression); the output is an image.
prober = Prober(
    model_checkpoint="cache/20211220_191132_refclef_32_512_resnet50_8_6_8_0.1_0.1_0.1_0_0.0001_0.0_12_4_90_1_0_0_0/best.ckpt",
    dataset_path="data/saiapr_tc-12.zip",
    df_path='data/val-sim_metric.json',
)

demo = gr.Interface(
    fn=prober.probe,
    inputs=["number", "text"],
    outputs="image",
)

# Queue incoming requests (concurrency_count=10), then start the server.
demo.queue(concurrency_count=10)
demo.launch()