CuriousDolphin commited on
Commit
28210ca
·
1 Parent(s): 41eedc7

add detr panoptic resnet101

Browse files
data/assets/download.png ADDED
detr/detr.py CHANGED
@@ -1,12 +1,16 @@
1
  from functools import cache
 
 
2
  import torch
3
  import torchvision.transforms as T
4
  import os
5
  import numpy as np
 
6
  from torch import nn
7
  from torchvision.models import resnet50
8
-
9
- from supervision import Detections, BoxAnnotator
 
10
 
11
  torch.set_grad_enabled(False)
12
 
@@ -199,25 +203,37 @@ class SimpleDetr:
199
  class PanopticDetrResenet101:
200
  @cache
201
  def __init__(self):
202
- model, postprocessor = torch.hub.load(
203
  "facebookresearch/detr",
204
  "detr_resnet101_panoptic",
205
  pretrained=True,
206
  return_postprocessor=True,
207
  num_classes=250,
208
  )
209
- model.eval()
210
 
211
  def detect(self, image, conf):
212
  # mean-std normalize the input image (batch-size: 1)
213
  img = normalize_img(image)
214
 
215
  outputs = self.model(img)
216
- # keep only predictions with 0.7+ confidence
217
- # compute the scores, excluding the "no-object" class (the last one)
218
- scores = outputs["pred_logits"].softmax(-1)[..., :-1].max(-1)[0]
219
- # threshold the confidence
220
- keep = scores > conf
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
 
223
  # COCO classes
 
1
  from functools import cache
2
+ import io
3
+ import itertools
4
  import torch
5
  import torchvision.transforms as T
6
  import os
7
  import numpy as np
8
+ import seaborn as sns
9
  from torch import nn
10
  from torchvision.models import resnet50
11
+ from panopticapi.utils import id2rgb, rgb2id
12
+ from supervision import Detections, BoxAnnotator, MaskAnnotator
13
+ from PIL import Image
14
 
15
  torch.set_grad_enabled(False)
16
 
 
203
  class PanopticDetrResenet101:
204
  @cache
205
  def __init__(self):
206
+ self.model, self.postprocessor = torch.hub.load(
207
  "facebookresearch/detr",
208
  "detr_resnet101_panoptic",
209
  pretrained=True,
210
  return_postprocessor=True,
211
  num_classes=250,
212
  )
213
+ self.model.eval()
214
 
215
  def detect(self, image, conf):
216
  # mean-std normalize the input image (batch-size: 1)
217
  img = normalize_img(image)
218
 
219
  outputs = self.model(img)
220
+ result = self.postprocessor(
221
+ outputs, torch.as_tensor(img.shape[-2:]).unsqueeze(0)
222
+ )[0]
223
+ print(result.keys())
224
+ palette = itertools.cycle(sns.color_palette())
225
+
226
+ # The segmentation is stored in a special-format png
227
+ panoptic_seg = Image.open(io.BytesIO(result["png_string"]))
228
+ panoptic_seg = np.array(panoptic_seg, dtype=np.uint8).copy()
229
+ # We retrieve the ids corresponding to each mask
230
+ panoptic_seg_id = rgb2id(panoptic_seg)
231
+
232
+ # Finally we color each mask individually
233
+ panoptic_seg[:, :, :] = 0
234
+ for id in range(panoptic_seg_id.max() + 1):
235
+ panoptic_seg[panoptic_seg_id == id] = np.asarray(next(palette)) * 255
236
+ return panoptic_seg
237
 
238
 
239
  # COCO classes
detr/main_gradio.py CHANGED
@@ -1,6 +1,8 @@
1
  import gradio as gr
2
  import supervision as sv
3
  import os
 
 
4
  from detr import SimpleDetr, PanopticDetrResenet101
5
 
6
  ASSETS_DIR = os.path.abspath(os.curdir) + "/data/assets"
@@ -10,15 +12,17 @@ print("Assets:", ASSETS_DIR)
10
 
11
  def run_inference(image, confidence, model_name, progress=gr.Progress(track_tqdm=True)):
12
  progress(0.1, "loading model..")
13
-
14
  if model_name == "detr_demo_boxes":
15
  model = SimpleDetr()
16
  else:
17
  model = PanopticDetrResenet101()
 
18
  progress(0.1, "Inference..")
19
 
20
  annotated_img = model.detect(image, confidence)
21
- return annotated_img, None, None
 
22
 
23
 
24
  with gr.Blocks() as inference_gradio:
@@ -47,7 +51,7 @@ with gr.Blocks() as inference_gradio:
47
  examples=[
48
  [path]
49
  for path in sv.list_files_with_extensions(
50
- directory=ASSETS_DIR, extensions=["jpeg", "jpg"]
51
  )
52
  ],
53
  inputs=[img_file],
 
1
  import gradio as gr
2
  import supervision as sv
3
  import os
4
+ from time import perf_counter
5
+
6
  from detr import SimpleDetr, PanopticDetrResenet101
7
 
8
  ASSETS_DIR = os.path.abspath(os.curdir) + "/data/assets"
 
12
 
13
  def run_inference(image, confidence, model_name, progress=gr.Progress(track_tqdm=True)):
14
  progress(0.1, "loading model..")
15
+ t0 = perf_counter()
16
  if model_name == "detr_demo_boxes":
17
  model = SimpleDetr()
18
  else:
19
  model = PanopticDetrResenet101()
20
+ t1 = perf_counter()
21
  progress(0.1, "Inference..")
22
 
23
  annotated_img = model.detect(image, confidence)
24
+ t2 = perf_counter()
25
+ return annotated_img, {"load_model": t1 - t0, "inference": t2 - t1}, None
26
 
27
 
28
  with gr.Blocks() as inference_gradio:
 
51
  examples=[
52
  [path]
53
  for path in sv.list_files_with_extensions(
54
+ directory=ASSETS_DIR, extensions=["jpeg", "jpg", "png"]
55
  )
56
  ],
57
  inputs=[img_file],
requirements.txt CHANGED
@@ -3,4 +3,6 @@ torch==2.1.1
3
  numpy
4
  matplotlib
5
  torchvision
6
- supervision==0.17.1
 
 
 
3
  numpy
4
  matplotlib
5
  torchvision
6
+ supervision==0.17.1
7
+ git+https://github.com/cocodataset/panopticapi.git
8
+ seaborn