techysanoj commited on
Commit
b496423
·
1 Parent(s): bebfcb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -25
app.py CHANGED
@@ -1,44 +1,32 @@
1
  import gradio as gr
2
  import torch
3
- import torchvision.transforms as transforms
4
- from torchvision.models.detection import detr
5
  from PIL import Image
6
- import cv2
7
- import numpy as np
8
 
9
  # Load the pretrained DETR model
10
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
- model = detr.DETR(resnet50=True)
12
- model = model.to(device).eval()
13
-
14
- # Define the transformation for the input image
15
- transform = transforms.Compose([
16
- transforms.ToTensor(),
17
- transforms.Resize((800, 800)),
18
- ])
19
 
20
  # Define the object detection function
21
  def detect_objects(frame):
22
  # Convert the frame to PIL image
23
  image = Image.fromarray(frame)
24
 
25
- # Apply the transformation
26
- image = transform(image).unsqueeze(0).to(device)
27
 
28
  # Perform object detection
29
- with torch.no_grad():
30
- outputs = model(image)
31
 
32
- # Get the bounding boxes and labels
33
- boxes = outputs['pred_boxes'][0].cpu().numpy()
34
- labels = outputs['pred_classes'][0].cpu().numpy()
35
 
36
  # Draw bounding boxes on the frame
37
- for box, label in zip(boxes, labels):
38
- box = [int(coord) for coord in box]
39
- frame = cv2.rectangle(frame, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
40
- frame = cv2.putText(frame, f'Class: {label}', (box[0], box[1] - 10),
41
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2, cv2.LINE_AA)
42
 
43
  return frame
44
 
 
1
  import gradio as gr
2
  import torch
 
 
3
  from PIL import Image
4
+ from torchvision.transforms import functional as F
5
+ from transformers import DetrImageProcessor, DetrForObjectDetection
6
 
7
  # Load the pretrained DETR model
8
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
9
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
 
 
 
 
 
 
 
10
 
11
  # Define the object detection function
12
  def detect_objects(frame):
13
  # Convert the frame to PIL image
14
  image = Image.fromarray(frame)
15
 
16
+ # Preprocess the image
17
+ inputs = processor(images=image, return_tensors="pt")
18
 
19
  # Perform object detection
20
+ outputs = model(**inputs)
 
21
 
22
+ # Convert outputs to COCO API format
23
+ target_sizes = torch.tensor([image.size[::-1]])
24
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
25
 
26
  # Draw bounding boxes on the frame
27
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
28
+ box = [round(i, 2) for i in box.tolist()]
29
+ frame = gr.draw_box(frame, box, label=model.config.id2label[label.item()], color=(0, 255, 0))
 
 
30
 
31
  return frame
32