BasToTheMax committed on
Commit
3048b3b
·
verified ·
1 Parent(s): 9d47cfa

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import DetrImageProcessor, DetrForObjectDetection
2
+ import torch
3
+ from PIL import Image, ImageDraw
4
+ import gradio as gr
5
+ import requests
6
+ import random
7
+
8
_DETR_CHECKPOINT = "facebook/detr-resnet-50"

# Maps checkpoint name -> (processor, model) so the weights are loaded from
# disk/network only once per process instead of on every request.
_DETR_CACHE = {}


def _load_detr(checkpoint=_DETR_CHECKPOINT):
    """Return the ``(processor, model)`` pair for *checkpoint*, loading it once."""
    if checkpoint not in _DETR_CACHE:
        processor = DetrImageProcessor.from_pretrained(checkpoint)
        model = DetrForObjectDetection.from_pretrained(checkpoint)
        _DETR_CACHE[checkpoint] = (processor, model)
    return _DETR_CACHE[checkpoint]


def detect_objects(image, threshold=0.9):
    """Detect objects in a PIL image with DETR and return their label names.

    Args:
        image: A ``PIL.Image.Image``. It is converted to RGB when its mode is
            anything else (e.g. RGBA, L, P) before being fed to the processor.
        threshold: Minimum detection score to keep. Defaults to ``0.9``.

    Returns:
        A comma-separated string with one label name per kept detection, in
        the order the post-processor returns them. Repeated labels are kept.
        The string is empty when nothing passes the threshold.
    """
    processor, model = _load_detr()

    # The processor normalizes with 3-channel statistics, so hand it RGB data.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")

    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); the post-processor takes (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]

    return ",".join(
        model.config.id2label[label.item()] for label in results["labels"]
    )
33
+
34
def upload_image(file):
    """Open the uploaded file as an image and return the detected labels.

    Args:
        file: The value supplied by the Gradio ``"file"`` input. Either an
            object exposing the file path as ``.name`` or a plain path
            string is accepted. ``None`` means nothing was uploaded.

    Returns:
        The string produced by ``detect_objects`` for the image, or an empty
        string when no file was provided.
    """
    if file is None:
        # Submitting without an upload should not crash the handler.
        return ""

    # Fall back to the value itself when it has no ``.name`` attribute.
    path = getattr(file, "name", file)

    # Keep the file open only for the duration of the detection so the
    # handle is released even if detection raises.
    with Image.open(path) as image:
        return detect_objects(image)
38
+
39
# Gradio UI: a single file upload in, the comma-separated label string out.
iface = gr.Interface(
    fn=upload_image,
    inputs="file",
    outputs="text",
    title="Object Detection",
    description="Upload an image and detect objects using DETR model.",
    # Gradio documents this option as one of the strings "never", "auto" or
    # "manual"; "never" is the string equivalent of the former boolean False.
    allow_flagging="never",
)

# Launched at import time on purpose: this file is the app's entry script.
iface.launch()