Spaces:
Sleeping
Sleeping
File size: 9,718 Bytes
df8b8a4 bb0d52f df8b8a4 bb0d52f df8b8a4 bb0d52f df8b8a4 bb0d52f df8b8a4 bb0d52f df8b8a4 bb0d52f df8b8a4 bb0d52f df8b8a4 bb0d52f df8b8a4 bb0d52f df8b8a4 bb0d52f df8b8a4 bb0d52f df8b8a4 bb0d52f df8b8a4 bb0d52f df8b8a4 bb0d52f df8b8a4 bb0d52f df8b8a4 bb0d52f d50f451 bb0d52f df8b8a4 bb0d52f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 |
import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoImageProcessor
from transformers import AutoModelForObjectDetection
# Model identifier: a Hugging Face Hub repo ID or a path to a local directory.
# Swap the "mrdbourke" prefix for your own username if the model lives on your account.
model_save_path = "mrdbourke/detr_finetuned_trashify_box_detector_with_data_aug"

# Prefer the GPU when one is visible, otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the preprocessor and the detection model, placing the model on the target device.
image_processor = AutoImageProcessor.from_pretrained(model_save_path)
model = AutoModelForObjectDetection.from_pretrained(model_save_path).to(device)

# Mapping of class index -> class name, as stored in the model config.
id2label = model.config.id2label

# Outline colour used for each class name when plotting boxes.
color_dict = dict(
    bin="green",
    trash="blue",
    hand="purple",
    trash_arm="yellow",
    not_trash="red",
    not_bin="red",
    not_hand="red",
)
# Create helper functions for seeing if items from one list are in another
def any_in_list(list_a, list_b):
    """Return True when at least one item of list_a is also in list_b, else False."""
    for candidate in list_a:
        if candidate in list_b:
            return True
    return False
def all_in_list(list_a, list_b):
    """Return True when every item of list_a is also in list_b, else False.

    An empty list_a yields True (there is nothing that could be missing).
    """
    for candidate in list_a:
        if candidate not in list_b:
            return False
    return True
def filter_highest_scoring_box_per_class(boxes, labels, scores):
    """
    Keep only the single highest scoring box for each class.

    Note: this is a simplified stand-in for NMS (Non-Max Suppression). Box
    overlap (IoU) is never looked at; exactly one box survives per class label
    present in `labels`.

    Args:
        boxes: tensor of shape (N, 4)
        labels: tensor of shape (N,)
        scores: tensor of shape (N,)

    Returns:
        boxes: tensor of shape (K, 4) filtered for max scoring item per class
        labels: tensor of shape (K,) filtered for max scoring item per class
        scores: tensor of shape (K,) filtered for max scoring item per class

        K is the number of unique values in `labels` (0 for empty input). The
        kept items stay in their original relative order.
    """
    # Start with an all-False keep mask and flip the winners to True.
    # The mask is created on the same device as `labels` so indexing does not
    # mix devices (assumes boxes/labels/scores all live on one device).
    keep_mask = torch.zeros(len(boxes), dtype=torch.bool, device=labels.device)

    # Every id from labels.unique() occurs at least once, so no emptiness check is needed.
    for class_id in labels.unique():
        # Positions (in the original tensors) of all boxes of this class
        class_indices = torch.where(labels == class_id)[0]

        # Position of the best score within the class, mapped back to the original index
        best_idx = class_indices[scores[class_indices].argmax()]
        keep_mask[best_idx] = True

    return boxes[keep_mask], labels[keep_mask], scores[keep_mask]
def create_return_string(list_of_predicted_labels, target_items=("trash", "bin", "hand"), conf_threshold=None):
    """Build the user-facing summary message for the labels predicted on one image.

    Args:
        list_of_predicted_labels: Class-name strings predicted for the image.
        target_items: Labels that must all be present to award "+1".
        conf_threshold: Confidence threshold the prediction was made with. It is
            only used to make the "nothing detected" message specific. It has to
            be passed in explicitly because this function has no other access to
            it; when None the message refers to the threshold generically.

    Returns:
        A string that is one of: a "nothing detected" notice (no labels at all,
        or none of the target items), a notice listing which target items are
        missing, or a "+1!" message when every target item was detected.
    """
    # Which of the required items did the model fail to find?
    missing_items = [item for item in target_items if item not in list_of_predicted_labels]

    # Nothing detected at all, or none of the target items among the detections
    if len(list_of_predicted_labels) == 0 or len(missing_items) == len(target_items):
        if conf_threshold is None:
            threshold_text = "the current confidence threshold"
        else:
            threshold_text = f"confidence threshold {conf_threshold}"
        return f"No trash, bin or hand detected at {threshold_text}. Try another image or lowering the confidence threshold."

    if missing_items:
        # Some (but not all) target items are missing
        return_string = f"Detected the following items: {list_of_predicted_labels} (total: {len(list_of_predicted_labels)}). But missing the following in order to get +1: {missing_items}. If this is an error, try another image or altering the confidence threshold. Otherwise, the model may need to be updated with better data."
    else:
        # All target items occur = +1
        return_string = f"+1! Found the following items: {list_of_predicted_labels} (total: {len(list_of_predicted_labels)}), thank you for cleaning up the area!"

    print(return_string)
    return return_string


def _draw_detections(image, boxes, labels, scores, font):
    """Draw one labelled rectangle per detection onto `image` (in place) and return the class names drawn."""
    draw = ImageDraw.Draw(image)
    label_names = []

    for box, label, score in zip(boxes, labels, scores):
        # Box corner coordinates
        x, y, x2, y2 = box.tolist()

        # Class index -> class name -> outline colour
        label_name = id2label[label.item()]
        targ_color = color_dict[label_name]
        label_names.append(label_name)

        draw.rectangle(xy=(x, y, x2, y2),
                       outline=targ_color,
                       width=3)

        # Class name plus score rounded to 3 decimals, placed at the box's (x, y) corner
        draw.text(xy=(x, y),
                  text=f"{label_name} ({round(score.item(), 3)})",
                  fill="white",
                  font=font)

    return label_names


def predict_on_image(image, conf_threshold):
    """Run object detection on a PIL image and plot the results with and without per-class filtering.

    Args:
        image: PIL image to predict on. The input itself is left untouched;
            boxes are drawn on copies.
        conf_threshold: Minimum score passed to the post-processing step as `threshold`.

    Returns:
        A 4-tuple of:
            - copy of the image with every detection drawn,
            - summary string for all detections,
            - copy of the image with only the highest scoring box per class drawn,
            - summary string for the filtered detections.
        If `image` is None, both image slots are None and both strings ask for an image.
    """
    # Guard against being called without an image (e.g. the UI submitted with no upload)
    if image is None:
        no_image_message = "No image provided. Please upload an image and try again."
        return None, no_image_message, None, no_image_message

    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
        outputs = model(**inputs.to(device))

        target_sizes = torch.tensor([[image.size[1], image.size[0]]])  # height, width

        results = image_processor.post_process_object_detection(outputs,
                                                                threshold=conf_threshold,
                                                                target_sizes=target_sizes)[0]

    # Move every result tensor to the CPU before plotting
    results = {key: value.cpu() for key, value in results.items()}

    # Get original boxes, scores, labels
    original_boxes = results["boxes"]
    original_labels = results["labels"]
    original_scores = results["scores"]

    # Filter boxes and only keep 1x of each label with highest score
    filtered_boxes, filtered_labels, filtered_scores = filter_highest_scoring_box_per_class(boxes=original_boxes,
                                                                                            labels=original_labels,
                                                                                            scores=original_scores)

    # Get a font from ImageFont
    font = ImageFont.load_default(size=20)

    # Draw on two separate copies so the caller's image is never modified
    image_all = image.copy()
    image_nms = image.copy()

    class_name_text_labels = _draw_detections(image=image_all,
                                              boxes=original_boxes,
                                              labels=original_labels,
                                              scores=original_scores,
                                              font=font)
    class_name_text_labels_nms = _draw_detections(image=image_nms,
                                                  boxes=filtered_boxes,
                                                  labels=filtered_labels,
                                                  scores=filtered_scores,
                                                  font=font)

    # Create the return strings (the threshold is forwarded for the "nothing detected" message)
    return_string = create_return_string(list_of_predicted_labels=class_name_text_labels,
                                         conf_threshold=conf_threshold)
    return_string_nms = create_return_string(list_of_predicted_labels=class_name_text_labels_nms,
                                             conf_threshold=conf_threshold)

    return image_all, return_string, image_nms, return_string_nms
# Assemble the pieces of the Gradio UI first, then wire them into the interface.

# Inputs: the image to predict on plus the confidence threshold slider
_demo_inputs = [
    gr.Image(type="pil", label="Target Image"),
    gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold"),
]

# Outputs: (image, text) for the unfiltered predictions followed by (image, text) for the filtered ones
_demo_outputs = [
    gr.Image(type="pil", label="Image Output (no filtering)"),
    gr.Text(label="Text Output (no filtering)"),
    gr.Image(type="pil", label="Image Output (with max score per class box filtering)"),
    gr.Text(label="Text Output (with max score per class box filtering)"),
]

_demo_description = """Help clean up your local area! Upload an image and get +1 if there is all of the following items detected: trash, bin, hand.
The model in V3 is [same model](https://huggingface.co/mrdbourke/detr_finetuned_trashify_box_detector_with_data_aug) as in [V2](https://huggingface.co/spaces/mrdbourke/trashify_demo_v2) (trained with data augmentation) but has an additional post-processing step (NMS or [Non Maximum Suppression](https://paperswithcode.com/method/non-maximum-suppression)) to filter classes for only the highest scoring box of each class.
"""

# Each inner list prefills the `inputs` components in order: [image path, confidence threshold]
_demo_examples = [
    ["examples/trashify_example_1.jpeg", 0.25],
    ["examples/trashify_example_2.jpeg", 0.25],
    ["examples/trashify_example_3.jpeg", 0.25],
]

demo = gr.Interface(
    fn=predict_on_image,
    inputs=_demo_inputs,
    outputs=_demo_outputs,
    title="🚮 Trashify Object Detection Demo V3",
    description=_demo_description,
    examples=_demo_examples,
    cache_examples=True,
)

# Start serving the demo
demo.launch()
|