mrdbourke committed (verified)
Commit 675f436 · Parent: 2d1114b

Uploading Trashify box detection model app.py

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/trashify_example_2.jpeg filter=lfs diff=lfs merge=lfs -text
.gradio/cached_examples/18/Image Output/66ec734ca428ae2384f6/image.webp ADDED
.gradio/cached_examples/18/Image Output/92cc1241b9494671fc05/image.webp ADDED
.gradio/cached_examples/18/log.csv ADDED
@@ -0,0 +1,3 @@
+Image Output,Text Output,timestamp
+"{""path"": "".gradio/cached_examples/18/Image Output/92cc1241b9494671fc05/image.webp"", ""url"": ""/gradio_api/file=/tmp/gradio/a00bd5b7c75100f6f600a22625949c9350d2827637ab3e454535b4f44376dde0/image.webp"", ""size"": null, ""orig_name"": ""image.webp"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}","'+1! Found the following items: ['trash', 'bin', 'hand', 'not_trash', 'bin'], thank you for cleaning up the area!",2024-11-16 13:55:27.991471
+"{""path"": "".gradio/cached_examples/18/Image Output/66ec734ca428ae2384f6/image.webp"", ""url"": ""/gradio_api/file=/tmp/gradio/b83f3584e66d5d7a6d26f3988d3b8c6cb39d94dd8433f94788676e9ec8c21327/image.webp"", ""size"": null, ""orig_name"": ""image.webp"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}","'+1! Found the following items: ['bin', 'trash', 'hand', 'not_trash', 'not_trash'], thank you for cleaning up the area!",2024-11-16 13:55:28.113367
app.py CHANGED
@@ -1,29 +1,44 @@
 import gradio as gr
 import torch
-from PIL import Image, ImageDraw
+from PIL import Image, ImageDraw, ImageFont
 
 from transformers import AutoImageProcessor
 from transformers import AutoModelForObjectDetection
 
-from PIL import Image
-
+# Note: Can load from Hugging Face or can load from local
 model_save_path = "mrdbourke/detr_finetuned_trashify_box_detector"
 
+# Load the model and preprocessor
 image_processor = AutoImageProcessor.from_pretrained(model_save_path)
 model = AutoModelForObjectDetection.from_pretrained(model_save_path)
 
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = model.to(device)
+
+# Get the id2label dictionary from the model
 id2label = model.config.id2label
-color_dict = {
-    "not_trash": "red",
+
+# Set up a colour dictionary for plotting boxes with different colours
+color_dict = {
     "bin": "green",
     "trash": "blue",
-    "hand": "purple"
+    "hand": "purple",
+    "trash_arm": "yellow",
+    "not_trash": "red",
+    "not_bin": "red",
+    "not_hand": "red",
 }
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-model = model.to(device)
+# Create helper functions for seeing if items from one list are in another
+def any_in_list(list_a, list_b):
+    "Returns True if any item from list_a is in list_b, otherwise False."
+    return any(item in list_b for item in list_a)
+
+def all_in_list(list_a, list_b):
+    "Returns True if all items from list_a are in list_b, otherwise False."
+    return all(item in list_b for item in list_a)
 
-def predict_on_image(image, conf_threshold=0.25):
+def predict_on_image(image, conf_threshold):
     with torch.no_grad():
         inputs = image_processor(images=[image], return_tensors="pt")
         outputs = model(**inputs.to(device))
@@ -43,6 +58,12 @@ def predict_on_image(image, conf_threshold=0.25):
     # Can return results as plotted on a PIL image (then display the image)
     draw = ImageDraw.Draw(image)
 
+    # Get a font from ImageFont
+    font = ImageFont.load_default(size=20)
+
+    # Get class names as text for print out
+    class_name_text_labels = []
+
     for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
         # Create coordinates
         x, y, x2, y2 = tuple(box.tolist())
@@ -50,6 +71,7 @@ def predict_on_image(image, conf_threshold=0.25):
         # Get label_name
         label_name = id2label[label.item()]
         targ_color = color_dict[label_name]
+        class_name_text_labels.append(label_name)
 
         # Draw the rectangle
         draw.rectangle(xy=(x, y, x2, y2),
@@ -62,23 +84,59 @@ def predict_on_image(image, conf_threshold=0.25):
         # Draw the text on the image
         draw.text(xy=(x, y),
                   text=text_string_to_show,
-                  fill="white")
+                  fill="white",
+                  font=font)
 
     # Remove the draw each time
     del draw
+
+    # Setup blank string to print out
+    return_string = ""
+
+    # Setup list of target items to discover
+    target_items = ["trash", "bin", "hand"]
+
+    # If no items detected or trash, bin, hand not in list, return notification
+    if (len(class_name_text_labels) == 0) or not (any_in_list(list_a=target_items, list_b=class_name_text_labels)):
+        return_string = f"No trash, bin or hand detected at confidence threshold {conf_threshold}. Try another image or lowering the confidence threshold."
+        return image, return_string
+
+    # If there are some missing, print the ones which are missing
+    elif not all_in_list(list_a=target_items, list_b=class_name_text_labels):
+        missing_items = []
+        for item in target_items:
+            if item not in class_name_text_labels:
+                missing_items.append(item)
+        return_string = f"Detected the following items: {class_name_text_labels}. But missing the following in order to get +1: {missing_items}. If this is an error, try altering the confidence threshold."
+
+    # If all 3 trash, bin, hand occur = +1
+    if all_in_list(list_a=target_items, list_b=class_name_text_labels):
+        return_string = f"+1! Found the following items: {class_name_text_labels}, thank you for cleaning up the area!"
+
+    print(return_string)
 
-    return image
+    return image, return_string
 
+# Create the interface
 demo = gr.Interface(
     fn=predict_on_image,
     inputs=[
-        gr.Image(type="pil", label="Upload Target Image"),
+        gr.Image(type="pil", label="Target Image"),
         gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold")
     ],
-    outputs=gr.Image(type="pil"),
+    outputs=[
+        gr.Image(type="pil", label="Image Output"),
+        gr.Text(label="Text Output")
+    ],
     title="🚮 Trashify Object Detection Demo",
-    description="Upload an image to detect whether there's a bin, a hand or trash in it."
+    description="Help clean up your local area! Upload an image and get +1 if there is all of the following items detected: trash, bin, hand.",
+    # Examples come in the form of a list of lists, where each inner list contains elements to prefill the `inputs` parameter with
+    examples=[
+        ["examples/trashify_example_1.jpeg", 0.25],
+        ["examples/trashify_example_2.jpeg", 0.25]
+    ],
+    cache_examples=True
 )
 
-if __name__ == "__main__":
-    demo.launch()
+# Launch the demo
+demo.launch()
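The hunk headers show that new lines 45-57 of app.py are unchanged and therefore never appear in the diff; that elided region is presumably where results (boxes, scores, labels) is produced from the raw model outputs. For reference only, here is a self-contained sketch of the same inference flow using the standard transformers post-processing API for DETR-style detectors (the detect helper name is hypothetical, not from the repo):

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForObjectDetection

def detect(image: Image.Image, conf_threshold: float = 0.25):
    # Same checkpoint the app loads
    model_save_path = "mrdbourke/detr_finetuned_trashify_box_detector"
    image_processor = AutoImageProcessor.from_pretrained(model_save_path)
    model = AutoModelForObjectDetection.from_pretrained(model_save_path)

    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
        outputs = model(**inputs)

    # Rescale boxes to the original image size and filter by confidence;
    # PIL's size is (width, height) but target_sizes expects (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = image_processor.post_process_object_detection(
        outputs, threshold=conf_threshold, target_sizes=target_sizes
    )[0]

    return [
        (model.config.id2label[label.item()], round(score.item(), 3), box.tolist())
        for box, score, label in zip(results["boxes"], results["scores"], results["labels"])
    ]

# e.g. detect(Image.open("examples/trashify_example_1.jpeg"))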
examples/trashify_example_1.jpeg ADDED
examples/trashify_example_2.jpeg ADDED

Git LFS Details

  • SHA256: 89ed8acec03b7890e5d2e6fa509c7e842e70a6dd9f6ad4e37d5d1431a1081be7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.07 MB
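Because the example images are LFS-tracked, a clone made without git lfs pull only contains the 132-byte pointer stub, not the 1.07 MB image itself. A small sketch (not part of the commit) to check which one a local checkout actually has:

from pathlib import Path

def is_lfs_pointer(path: str) -> bool:
    # Git LFS pointer stubs are small text files starting with this version line.
    head = Path(path).read_bytes()[:64]
    return head.startswith(b"version https://git-lfs.github.com/spec/v1")

print(is_lfs_pointer("examples/trashify_example_2.jpeg"))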
requirements.txt CHANGED
@@ -1,4 +1,4 @@
+timm
 gradio
 torch
 transformers
-timm
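The dependency set is unchanged; the edit only moves timm to the top of the file. timm is needed at all because transformers DETR checkpoints typically load their convolutional backbone through it by default, so AutoModelForObjectDetection.from_pretrained fails without it. A quick preflight sketch (my own addition, not in the repo):

import importlib.util

# Fail fast if any runtime dependency is missing before the Space boots.
for pkg in ("timm", "gradio", "torch", "transformers"):
    if importlib.util.find_spec(pkg) is None:
        raise ImportError(f"Missing dependency: {pkg}")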