mrdbourke commited on
Commit
439a441
·
verified ·
1 Parent(s): 4c663d9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image, ImageDraw
4
+
5
+ from transformers import AutoImageProcessor
6
+ from transformers import AutoModelForObjectDetection
7
+
8
+ from PIL import Image
9
+
10
+ model_save_path = "mrdbourke/detr_finetuned_trashify_box_detector"
11
+
12
+ image_processor = AutoImageProcessor.from_pretrained(model_save_path)
13
+ model = AutoModelForObjectDetection.from_pretrained(model_save_path)
14
+
15
+ id2label = model.config.id2label
16
+ color_dict = {
17
+ "not_trash": "red",
18
+ "bin": "green",
19
+ "trash": "blue",
20
+ "hand": "purple"
21
+ }
22
+
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
24
+ model = model.to(device)
25
+
26
+ def predict_on_image(image, conf_threshold=0.25):
27
+ with torch.no_grad():
28
+ inputs = image_processor(images=[image], return_tensors="pt")
29
+ outputs = model(**inputs.to(device))
30
+
31
+ target_sizes = torch.tensor([[image.size[1], image.size[0]]]) # height, width
32
+
33
+ results = image_processor.post_process_object_detection(outputs,
34
+ threshold=conf_threshold,
35
+ target_sizes=target_sizes)[0]
36
+ # Return all items in results to CPU
37
+ for key, value in results.items():
38
+ try:
39
+ results[key] = value.item().cpu() # can't get scalar as .item() so add try/except block
40
+ except:
41
+ results[key] = value.cpu()
42
+
43
+ # Can return results as plotted on a PIL image (then display the image)
44
+ draw = ImageDraw.Draw(image)
45
+
46
+ for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
47
+ # Create coordinates
48
+ x, y, x2, y2 = tuple(box.tolist())
49
+
50
+ # Get label_name
51
+ label_name = id2label[label.item()]
52
+ targ_color = color_dict[label_name]
53
+
54
+ # Draw the rectangle
55
+ draw.rectangle(xy=(x, y, x2, y2),
56
+ outline=targ_color,
57
+ width=3)
58
+
59
+ # Create a text string to display
60
+ text_string_to_show = f"{label_name} ({round(score.item(), 3)})"
61
+
62
+ # Draw the text on the image
63
+ draw.text(xy=(x, y),
64
+ text=text_string_to_show,
65
+ fill="white")
66
+
67
+ # Remove the draw each time
68
+ del draw
69
+
70
+ return image
71
+
72
+ demo = gr.Interface(
73
+ fn=predict_on_image,
74
+ inputs=[
75
+ gr.Image(type="pil", label="Upload Target Image"),
76
+ gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold")
77
+ ],
78
+ outputs=gr.Image(type="pil"),
79
+ title="🚮 Trashify Object Detection Demo",
80
+ description="Upload an image to detect whether there's a bin, a hand or trash in it."
81
+ )
82
+
83
+ if __name__ == "__main__":
84
+ demo.launch()