techysanoj committed
Commit bebfcb3 · 1 parent: 4cbee43

Update app.py

Files changed (1)
app.py +37 -27
app.py CHANGED
@@ -1,45 +1,55 @@
import gradio as gr
import torch
from PIL import Image
- from transformers import DetrImageProcessor, DetrForObjectDetection
import cv2
import numpy as np

- # Load the pre-trained DETR model
- processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
- model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
- model.eval()

- # Function for image object detection
- def image_object_detection(image_pil):
-     # Process the image with the DETR model
-     inputs = processor(images=image_pil, return_tensors="pt")
-     outputs = model(**inputs)

-     # Convert the image to numpy array for drawing bounding boxes
-     image_np = cv2.cvtColor(cv2.cvtColor(cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR), cv2.COLOR_BGR2RGB), cv2.COLOR_RGB2BGR)

-     # convert outputs (bounding boxes and class logits) to COCO API
-     # let's only keep detections with score > 0.9
-     target_sizes = torch.tensor([image_pil.size[::-1]])
-     results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]

-     # Draw bounding boxes on the image
-     for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
-         box = [int(round(i)) for i in box.tolist()]
-         cv2.rectangle(image_np, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
-         label_text = f"{model.config.id2label[label.item()]}: {round(score.item(), 3)}"
-         cv2.putText(image_np, label_text, (box[0], box[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

-     return image_np

# Define the Gradio interface
iface = gr.Interface(
-     fn=image_object_detection,
-     inputs=gr.Image(type="pil", label="Upload an image for object detection"),
-     outputs="image",
    live=True,
)

- # Launch the Gradio interface
iface.launch()
 
import gradio as gr
import torch
+ import torchvision.transforms as transforms
+ from torchvision.models.detection import detr
from PIL import Image
import cv2
import numpy as np

+ # Load the pretrained DETR model
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = detr.DETR(resnet50=True)
+ model = model.to(device).eval()

+ # Define the transformation for the input image
+ transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Resize((800, 800)),
+ ])

+ # Define the object detection function
+ def detect_objects(frame):
+     # Convert the frame to PIL image
+     image = Image.fromarray(frame)

+     # Apply the transformation
+     image = transform(image).unsqueeze(0).to(device)

+     # Perform object detection
+     with torch.no_grad():
+         outputs = model(image)

+     # Get the bounding boxes and labels
+     boxes = outputs['pred_boxes'][0].cpu().numpy()
+     labels = outputs['pred_classes'][0].cpu().numpy()
+
+     # Draw bounding boxes on the frame
+     for box, label in zip(boxes, labels):
+         box = [int(coord) for coord in box]
+         frame = cv2.rectangle(frame, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
+         frame = cv2.putText(frame, f'Class: {label}', (box[0], box[1] - 10),
+                             cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2, cv2.LINE_AA)
+
+     return frame

# Define the Gradio interface
iface = gr.Interface(
+     fn=detect_objects,
+     inputs=gr.Video(),
+     outputs="video",
    live=True,
+     capture_session=True,
)

+ # Launch the Gradio app
iface.launch()
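
A note on the new model-loading code: torchvision does not ship a detr submodule, so "from torchvision.models.detection import detr" and "detr.DETR(resnet50=True)" will fail with an ImportError, and a raw DETR forward pass returns "pred_logits" plus normalized (cx, cy, w, h) boxes rather than a "pred_classes" key with pixel coordinates. Below is a minimal sketch of one working alternative, assuming the official facebookresearch/detr torch.hub entry point is acceptable; the helper names box_cxcywh_to_xyxy and detect are illustrative and not part of this commit.

import torch
import torchvision.transforms as T
from PIL import Image

# Assumption: load DETR from the official facebookresearch/detr hub entry point
# instead of the nonexistent torchvision module.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.hub.load("facebookresearch/detr", "detr_resnet50", pretrained=True)
model = model.to(device).eval()

# DETR expects ImageNet-normalized input; 800 px is the scale used at training time.
transform = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

def box_cxcywh_to_xyxy(box):
    # Convert normalized (cx, cy, w, h) boxes to (x0, y0, x1, y1).
    cx, cy, w, h = box.unbind(-1)
    return torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1)

def detect(image: Image.Image, threshold: float = 0.9):
    tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(tensor)
    # Drop the trailing "no object" class before thresholding.
    probas = outputs["pred_logits"].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > threshold
    # Boxes come back normalized to [0, 1]; rescale to the original image size.
    boxes = box_cxcywh_to_xyxy(outputs["pred_boxes"][0, keep].cpu())
    w, h = image.size
    boxes = boxes * torch.tensor([w, h, w, h], dtype=torch.float32)
    labels = probas[keep].argmax(-1)
    scores = probas[keep].max(-1).values
    return boxes, labels, scores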
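
Separately, with inputs=gr.Video() Gradio passes the path of the uploaded video file to the function, not a numpy frame, so Image.fromarray(frame) would receive a string; capture_session has also been removed from recent Gradio releases. A rough sketch of per-frame processing under those assumptions follows; annotate_frame is a placeholder for the detection-and-drawing step and is not part of this commit.

import tempfile

import cv2
import gradio as gr

def annotate_frame(frame):
    # Placeholder for the per-frame detection and box drawing; returning the
    # frame unchanged keeps the sketch runnable on its own.
    return frame

def detect_objects_in_video(video_path):
    # gr.Video() hands the function a file path, so frames are read manually.
    capture = cv2.VideoCapture(video_path)
    fps = capture.get(cv2.CAP_PROP_FPS) or 25
    width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))

    out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
    writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))

    while True:
        ok, frame = capture.read()
        if not ok:
            break
        writer.write(annotate_frame(frame))

    capture.release()
    writer.release()
    # Returning the file path lets the "video" output component play the result.
    return out_path

iface = gr.Interface(
    fn=detect_objects_in_video,
    inputs=gr.Video(),
    outputs="video",
)
iface.launch()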