techysanoj committed on
Commit
6c68290
·
1 Parent(s): f7b8ab6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -3,11 +3,14 @@ import torch
3
  from PIL import Image
4
  from torchvision.transforms import functional as F
5
  from transformers import DetrImageProcessor, DetrForObjectDetection
 
 
6
 
7
  # Load the pretrained DETR model
8
  processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
9
  model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
10
 
 
11
  def detect_objects(frame):
12
  # Convert the frame to PIL image
13
  image = Image.fromarray(frame)
@@ -25,16 +28,17 @@ def detect_objects(frame):
25
  # Draw bounding boxes on the frame
26
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
27
  box = [round(i, 2) for i in box.tolist()]
28
- frame = gr.draw_box(frame, box, label=model.config.id2label[label.item()], color=(0, 255, 0))
 
 
29
 
30
- # Convert frame back to numpy array for Gradio
31
- return np.array(frame)
32
 
33
  # Define the Gradio interface
34
  iface = gr.Interface(
35
  fn=detect_objects,
36
  inputs=gr.Video(),
37
- outputs="video",
38
  live=True,
39
  )
40
 
 
3
  from PIL import Image
4
  from torchvision.transforms import functional as F
5
  from transformers import DetrImageProcessor, DetrForObjectDetection
6
+ import cv2
7
+ import numpy as np
8
 
9
  # Load the pretrained DETR model
10
  processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
11
  model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
12
 
13
+ # Define the object detection function
14
  def detect_objects(frame):
15
  # Convert the frame to PIL image
16
  image = Image.fromarray(frame)
 
28
  # Draw bounding boxes on the frame
29
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
30
  box = [round(i, 2) for i in box.tolist()]
31
+ frame = cv2.rectangle(frame, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 255, 0), 2)
32
+ frame = cv2.putText(frame, f'{model.config.id2label[label.item()]}: {round(score.item(), 3)}',
33
+ (int(box[0]), int(box[1]) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2, cv2.LINE_AA)
34
 
35
+ return frame
 
36
 
37
  # Define the Gradio interface
38
  iface = gr.Interface(
39
  fn=detect_objects,
40
  inputs=gr.Video(),
41
+ outputs="numpy_image",
42
  live=True,
43
  )
44