Update app.py
app.py CHANGED
@@ -74,6 +74,10 @@ class Prober:
 
         self.model = self.model.eval()
 
+    def preview_image(self, idx):
+        img_path, target = self.df.loc[idx][['file_path', 'bbox']].values
+        img = Image.open(self.zipfile.open(img_path)).convert('RGB')
+        return img
 
     @torch.no_grad()
     def probe(self, idx, re, search_by_sample_id: bool = True):
@@ -82,38 +86,40 @@
         else:
             img_path, target = self.df[self.df.image_id == idx][['file_path', 'bbox']].values[0]
         img = Image.open(self.zipfile.open(img_path)).convert('RGB')
-        [… 25 deleted lines: the previous, unconditional probe body; their content is not shown in the source diff view …]
+        if re != "":
+            W0, H0 = img.size
+            sample = {
+                'image': img,
+                'image_size': (H0, W0),  # original image size
+                'bbox': torch.tensor([copy(target)]),
+                'bbox_raw': torch.tensor([copy(target)]),
+                'mask': torch.ones((1, H0, W0), dtype=torch.float32),  # visibility mask
+                'mask_bbox': None,  # target bbox mask
+            }
+            sample = self.transform(sample)
+            tok = self.tokenizer(re,
+                                 max_length=30,
+                                 return_tensors='pt',
+                                 truncation=True)
+            inn = {'image': torch.stack([sample['image']]),
+                   'mask': torch.stack([sample['mask']]),
+                   'tok': tok}
+            output = undo_box_transforms_batch(self.model(inn)[0],
+                                               [sample['tr_param']]).numpy().tolist()[0]
+            img1 = ImageDraw.Draw(img)
+            # img1.rectangle(target, outline="#0000FF00", width=3)
+            img1.rectangle(output, outline="#00FF0000", width=3)
+            return img
+        else:
+            return img
+
 
 prober = Prober(
     df_path = 'data/val-sim_metric.json',
     dataset_path = "data/saiapr_tc-12.zip",
     model_checkpoint = "cache/20211220_191132_refclef_32_512_resnet50_8_6_8_0.1_0.1_0.1_0_0.0001_0.0_12_4_90_1_0_0_0/best.ckpt"
 )
 
-demo = gr.Interface(fn=prober.probe, inputs=["number", "text"
+demo = gr.Interface(fn=prober.probe, inputs=["number", "text"], outputs="image", live=True)
 
 demo.queue(concurrency_count=10)
 demo.launch(debug=True)
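
With live=True, Gradio re-runs probe on every input change, including the initial render while the textbox is still empty, so the new `if re != ""` guard keeps those calls cheap by skipping tokenization and the model forward pass and returning the raw image. A minimal usage sketch of the updated entry point, assuming the data files and checkpoint configured above are present locally (the index and referring expression are illustrative):

# Minimal usage sketch; assumes data/val-sim_metric.json, data/saiapr_tc-12.zip
# and the checkpoint above resolve. Index 0 and the expression are illustrative.
plain = prober.probe(0, "")                   # empty expression: guard returns the raw image
boxed = prober.probe(0, "man on the left")    # runs the model, draws the predicted box in green
boxed.save("probe_demo.png")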
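The coordinate round trip is the subtle part: self.transform resizes and normalizes the image for the model, so the predicted box lives in the transformed input space, and undo_box_transforms_batch uses the stored sample['tr_param'] to map it back to original pixel coordinates before drawing. That function is not shown in this diff; for a plain scale-then-pad preprocessing step, the inverse mapping would look like the purely illustrative helper below (undo_scale_pad and its parameters are hypothetical, not the project's API):

# Hypothetical illustration only, NOT the project's undo_box_transforms_batch:
# inverts a "scale by `scale`, then pad by (pad_x, pad_y)" transform for one xyxy box.
def undo_scale_pad(box, scale, pad_x, pad_y):
    x1, y1, x2, y2 = box
    return [(x1 - pad_x) / scale, (y1 - pad_y) / scale,
            (x2 - pad_x) / scale, (y2 - pad_y) / scale]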