Commit
·
28210ca
1
Parent(s):
41eedc7
add detr panoptic resnet101
Browse files- data/assets/download.png +0 -0
- detr/detr.py +25 -9
- detr/main_gradio.py +7 -3
- requirements.txt +3 -1
data/assets/download.png
ADDED
![]() |
detr/detr.py
CHANGED
@@ -1,12 +1,16 @@
|
|
1 |
from functools import cache
|
|
|
|
|
2 |
import torch
|
3 |
import torchvision.transforms as T
|
4 |
import os
|
5 |
import numpy as np
|
|
|
6 |
from torch import nn
|
7 |
from torchvision.models import resnet50
|
8 |
-
|
9 |
-
from supervision import Detections, BoxAnnotator
|
|
|
10 |
|
11 |
torch.set_grad_enabled(False)
|
12 |
|
@@ -199,25 +203,37 @@ class SimpleDetr:
|
|
199 |
class PanopticDetrResenet101:
|
200 |
@cache
|
201 |
def __init__(self):
|
202 |
-
model, postprocessor = torch.hub.load(
|
203 |
"facebookresearch/detr",
|
204 |
"detr_resnet101_panoptic",
|
205 |
pretrained=True,
|
206 |
return_postprocessor=True,
|
207 |
num_classes=250,
|
208 |
)
|
209 |
-
model.eval()
|
210 |
|
211 |
def detect(self, image, conf):
|
212 |
# mean-std normalize the input image (batch-size: 1)
|
213 |
img = normalize_img(image)
|
214 |
|
215 |
outputs = self.model(img)
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
|
222 |
|
223 |
# COCO classes
|
|
|
1 |
from functools import cache
|
2 |
+
import io
|
3 |
+
import itertools
|
4 |
import torch
|
5 |
import torchvision.transforms as T
|
6 |
import os
|
7 |
import numpy as np
|
8 |
+
import seaborn as sns
|
9 |
from torch import nn
|
10 |
from torchvision.models import resnet50
|
11 |
+
from panopticapi.utils import id2rgb, rgb2id
|
12 |
+
from supervision import Detections, BoxAnnotator, MaskAnnotator
|
13 |
+
from PIL import Image
|
14 |
|
15 |
torch.set_grad_enabled(False)
|
16 |
|
|
|
203 |
class PanopticDetrResenet101:
|
204 |
@cache
|
205 |
def __init__(self):
|
206 |
+
self.model, self.postprocessor = torch.hub.load(
|
207 |
"facebookresearch/detr",
|
208 |
"detr_resnet101_panoptic",
|
209 |
pretrained=True,
|
210 |
return_postprocessor=True,
|
211 |
num_classes=250,
|
212 |
)
|
213 |
+
self.model.eval()
|
214 |
|
215 |
def detect(self, image, conf):
|
216 |
# mean-std normalize the input image (batch-size: 1)
|
217 |
img = normalize_img(image)
|
218 |
|
219 |
outputs = self.model(img)
|
220 |
+
result = self.postprocessor(
|
221 |
+
outputs, torch.as_tensor(img.shape[-2:]).unsqueeze(0)
|
222 |
+
)[0]
|
223 |
+
print(result.keys())
|
224 |
+
palette = itertools.cycle(sns.color_palette())
|
225 |
+
|
226 |
+
# The segmentation is stored in a special-format png
|
227 |
+
panoptic_seg = Image.open(io.BytesIO(result["png_string"]))
|
228 |
+
panoptic_seg = np.array(panoptic_seg, dtype=np.uint8).copy()
|
229 |
+
# We retrieve the ids corresponding to each mask
|
230 |
+
panoptic_seg_id = rgb2id(panoptic_seg)
|
231 |
+
|
232 |
+
# Finally we color each mask individually
|
233 |
+
panoptic_seg[:, :, :] = 0
|
234 |
+
for id in range(panoptic_seg_id.max() + 1):
|
235 |
+
panoptic_seg[panoptic_seg_id == id] = np.asarray(next(palette)) * 255
|
236 |
+
return panoptic_seg
|
237 |
|
238 |
|
239 |
# COCO classes
|
detr/main_gradio.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
import gradio as gr
|
2 |
import supervision as sv
|
3 |
import os
|
|
|
|
|
4 |
from detr import SimpleDetr, PanopticDetrResenet101
|
5 |
|
6 |
ASSETS_DIR = os.path.abspath(os.curdir) + "/data/assets"
|
@@ -10,15 +12,17 @@ print("Assets:", ASSETS_DIR)
|
|
10 |
|
11 |
def run_inference(image, confidence, model_name, progress=gr.Progress(track_tqdm=True)):
|
12 |
progress(0.1, "loading model..")
|
13 |
-
|
14 |
if model_name == "detr_demo_boxes":
|
15 |
model = SimpleDetr()
|
16 |
else:
|
17 |
model = PanopticDetrResenet101()
|
|
|
18 |
progress(0.1, "Inference..")
|
19 |
|
20 |
annotated_img = model.detect(image, confidence)
|
21 |
-
|
|
|
22 |
|
23 |
|
24 |
with gr.Blocks() as inference_gradio:
|
@@ -47,7 +51,7 @@ with gr.Blocks() as inference_gradio:
|
|
47 |
examples=[
|
48 |
[path]
|
49 |
for path in sv.list_files_with_extensions(
|
50 |
-
directory=ASSETS_DIR, extensions=["jpeg", "jpg"]
|
51 |
)
|
52 |
],
|
53 |
inputs=[img_file],
|
|
|
1 |
import gradio as gr
|
2 |
import supervision as sv
|
3 |
import os
|
4 |
+
from time import perf_counter
|
5 |
+
|
6 |
from detr import SimpleDetr, PanopticDetrResenet101
|
7 |
|
8 |
ASSETS_DIR = os.path.abspath(os.curdir) + "/data/assets"
|
|
|
12 |
|
13 |
def run_inference(image, confidence, model_name, progress=gr.Progress(track_tqdm=True)):
|
14 |
progress(0.1, "loading model..")
|
15 |
+
t0 = perf_counter()
|
16 |
if model_name == "detr_demo_boxes":
|
17 |
model = SimpleDetr()
|
18 |
else:
|
19 |
model = PanopticDetrResenet101()
|
20 |
+
t1 = perf_counter()
|
21 |
progress(0.1, "Inference..")
|
22 |
|
23 |
annotated_img = model.detect(image, confidence)
|
24 |
+
t2 = perf_counter()
|
25 |
+
return annotated_img, {"load_model": t1 - t0, "inference": t2 - t1}, None
|
26 |
|
27 |
|
28 |
with gr.Blocks() as inference_gradio:
|
|
|
51 |
examples=[
|
52 |
[path]
|
53 |
for path in sv.list_files_with_extensions(
|
54 |
+
directory=ASSETS_DIR, extensions=["jpeg", "jpg", "png"]
|
55 |
)
|
56 |
],
|
57 |
inputs=[img_file],
|
requirements.txt
CHANGED
@@ -3,4 +3,6 @@ torch==2.1.1
|
|
3 |
numpy
|
4 |
matplotlib
|
5 |
torchvision
|
6 |
-
supervision==0.17.1
|
|
|
|
|
|
3 |
numpy
|
4 |
matplotlib
|
5 |
torchvision
|
6 |
+
supervision==0.17.1
|
7 |
+
git+https://github.com/cocodataset/panopticapi.git
|
8 |
+
seaborn
|