techysanoj committed
Commit bf51623 · 1 Parent(s): ac00638

Update app.py

Files changed (1):
  1. app.py  +34 -11
app.py CHANGED
@@ -9,34 +9,57 @@ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revisi
 model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
 model.eval()
 
+# Function for image object detection
+def image_object_detection(image_pil):
+    # Process the image with the DETR model
+    inputs = processor(images=image_pil, return_tensors="pt")
+    outputs = model(**inputs)
+
+    # Convert the image to numpy array for drawing bounding boxes
+    image_np = cv2.cvtColor(cv2.cvtColor(cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR), cv2.COLOR_BGR2RGB), cv2.COLOR_RGB2BGR)
+
+    # convert outputs (bounding boxes and class logits) to COCO API
+    # let's only keep detections with score > 0.9
+    target_sizes = torch.tensor([image_pil.size[::-1]])
+    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
+
+    # Draw bounding boxes on the image
+    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+        box = [int(round(i)) for i in box.tolist()]
+        cv2.rectangle(image_np, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
+        label_text = f"{model.config.id2label[label.item()]}: {round(score.item(), 3)}"
+        cv2.putText(image_np, label_text, (box[0], box[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
+
+    return image_np
+
 # Function for live object detection from the camera
 def live_object_detection(image_pil):
-    # Convert the frame to PIL Image
-    frame_pil = Image.fromarray(cv2.cvtColor(image_pil, cv2.COLOR_BGR2RGB))
-
     # Process the frame with the DETR model
-    inputs = processor(images=frame_pil, return_tensors="pt")
+    inputs = processor(images=image_pil, return_tensors="pt")
     outputs = model(**inputs)
 
     # convert outputs (bounding boxes and class logits) to COCO API
     # let's only keep detections with score > 0.9
-    target_sizes = torch.tensor([frame_pil.size[::-1]])
+    target_sizes = torch.tensor([image_pil.size[::-1]])
     results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
 
-    # Draw bounding boxes on the frame
+    # Draw bounding boxes on the image
     for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
         box = [int(round(i)) for i in box.tolist()]
         cv2.rectangle(image_pil, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
-        cv2.putText(image_pil, f"{model.config.id2label[label.item()]}: {round(score.item(), 3)}",
-                    (box[0], box[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
+        label_text = f"{model.config.id2label[label.item()]}: {round(score.item(), 3)}"
+        cv2.putText(image_pil, label_text, (box[0], box[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
 
     return image_pil
 
 # Define the Gradio interface
 iface = gr.Interface(
-    fn=live_object_detection,
-    inputs="image",
-    outputs="image",
+    fn=[image_object_detection, live_object_detection],
+    inputs=[
+        gr.Image(type="pil", label="Upload an image for object detection", hover=True),
+        "webcam",
+    ],
+    outputs=["image", "image"],
     live=True,
 )
 
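A few review notes on this change. First, in the added `image_object_detection`, the three chained `cv2.cvtColor` calls cancel down to a single RGB-to-BGR conversion, so the function draws on and returns a BGR array; Gradio renders numpy outputs as RGB, which swaps the red and blue channels on display. A minimal sketch of the intended round trip, assuming `image_pil` is a PIL RGB image (the helper names are illustrative, not part of the commit):

```python
import cv2
import numpy as np
from PIL import Image

def to_bgr(image_pil: Image.Image) -> np.ndarray:
    # One conversion does everything the three chained cvtColor calls net out to.
    return cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)

def to_rgb(image_bgr: np.ndarray) -> np.ndarray:
    # Convert back before returning, since Gradio renders numpy output as RGB.
    return cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
```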
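Second, both functions pass `image_pil.size[::-1]` as `target_sizes`, and the reversal is correct: PIL reports size as (width, height), while `post_process_object_detection` expects (height, width) and returns boxes as absolute (xmin, ymin, xmax, ymax) pixel coordinates, which is why they can be fed straight to `cv2.rectangle`. A quick check of that step:

```python
from PIL import Image
import torch

img = Image.new("RGB", (640, 480))             # PIL size is (width, height)
target_sizes = torch.tensor([img.size[::-1]])  # tensor([[480, 640]]), i.e. (height, width)
```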
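Finally, the new `gr.Interface` call is unlikely to launch on recent Gradio: in 3.x/4.x, `fn` expects a single callable rather than a list, and `hover` is not a documented `gr.Image` parameter; `live_object_detection` also now calls `cv2.rectangle` directly on a PIL image, which OpenCV rejects (it needs a numpy array). One way to expose both modes is a pair of interfaces in tabs. This is a sketch, not the committed code; the tab titles are mine, and the `sources=` spelling is Gradio 4.x (3.x used `source="webcam"`):

```python
import gradio as gr

# Sketch: reuses image_object_detection and live_object_detection from app.py above.
upload_iface = gr.Interface(
    fn=image_object_detection,
    inputs=gr.Image(type="pil", label="Upload an image for object detection"),
    outputs="image",
)
webcam_iface = gr.Interface(
    fn=live_object_detection,
    inputs=gr.Image(type="pil", sources=["webcam"], label="Webcam"),
    outputs="image",
    live=True,
)
demo = gr.TabbedInterface([upload_iface, webcam_iface], ["Image upload", "Live camera"])

if __name__ == "__main__":
    demo.launch()
```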