Spaces:
Running
Running
import gradio as gr | |
import torch | |
from PIL import Image, ImageDraw, ImageFont | |
from transformers import AutoImageProcessor | |
from transformers import AutoModelForObjectDetection | |
# The checkpoint can be a Hugging Face Hub repo id (as here) or a local directory.
model_save_path = "mrdbourke/detr_finetuned_trashify_box_detector"

# Prefer the GPU when CUDA is available; otherwise everything runs on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Preprocessor and detection model both come from the same checkpoint.
image_processor = AutoImageProcessor.from_pretrained(model_save_path)
model = AutoModelForObjectDetection.from_pretrained(model_save_path).to(device)

# Lookup from the model's integer class ids to readable label names.
id2label = model.config.id2label
# Box outline colour keyed by predicted class name; the three "not_*" classes
# all share a single colour (red).
color_dict = dict(
    bin="green",
    trash="blue",
    hand="purple",
    trash_arm="yellow",
)
color_dict.update(dict.fromkeys(("not_trash", "not_bin", "not_hand"), "red"))
# Helper functions for checking whether items from one list appear in another
def any_in_list(list_a, list_b):
    """Return True if at least one item of list_a is also in list_b, else False.

    Stops at the first match; an empty list_a always yields False.
    """
    for candidate in list_a:
        if candidate in list_b:
            return True
    return False
def all_in_list(list_a, list_b):
    """Return True if every item of list_a is also in list_b, else False.

    Stops at the first missing item; an empty list_a is vacuously True.
    """
    for candidate in list_a:
        if candidate not in list_b:
            return False
    return True
def _draw_detections(image, results):
    """Draw every detected box and its "label (score)" caption onto ``image``.

    Args:
        image: PIL image to annotate; it is modified in place.
        results: one post-processed detection dict with "boxes", "scores" and
            "labels" tensors (already on the CPU).

    Returns:
        The label name of each drawn detection, in drawing order (one entry
        per box, duplicates kept).
    """
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default(size=20)
    detected_labels = []
    for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
        x, y, x2, y2 = box.tolist()
        label_name = id2label[label.item()]
        detected_labels.append(label_name)
        # Fall back to red (the colour of the "not_*" classes) so a label missing
        # from color_dict does not fail the whole request with a KeyError.
        box_color = color_dict.get(label_name, "red")
        draw.rectangle(xy=(x, y, x2, y2), outline=box_color, width=3)
        draw.text(xy=(x, y),
                  text=f"{label_name} ({round(score.item(), 3)})",
                  fill="white",
                  font=font)
    del draw
    return detected_labels


def _summarize_detections(detected_labels, conf_threshold):
    """Build the user-facing text verdict for a list of detected label names.

    A "+1" is awarded only when "trash", "bin" and "hand" all appear.
    """
    target_items = ["trash", "bin", "hand"]
    if not any_in_list(list_a=target_items, list_b=detected_labels):
        return (f"No trash, bin or hand detected at confidence threshold {conf_threshold}. "
                "Try another image or lowering the confidence threshold.")
    missing_items = [item for item in target_items if item not in detected_labels]
    if missing_items:
        return (f"Detected the following items: {detected_labels}. "
                f"But missing the following in order to get +1: {missing_items}. "
                "If this is an error, try another image or altering the confidence threshold. "
                "Otherwise, the model may need to be updated with better data.")
    return f"+1! Found the following items: {detected_labels}, thank you for cleaning up the area!"


def predict_on_image(image, conf_threshold):
    """Detect trash/bin/hand in ``image``, draw the boxes and summarise the result.

    Args:
        image: input PIL image, or None when nothing was uploaded.
        conf_threshold: score threshold passed to
            ``image_processor.post_process_object_detection``.

    Returns:
        A tuple ``(annotated_image, summary_text)``. ``annotated_image`` is a new
        RGB image (the input is not modified), or None when no image was given.
    """
    if image is None:
        return None, "No image provided. Upload an image (or pick an example) and try again."

    # Work on an RGB copy: drawing never mutates the caller's image, and the
    # processor always receives 3-channel input (e.g. an alpha channel is dropped).
    image = image.convert("RGB")

    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
        outputs = model(**inputs.to(device))
        # PIL's size is (width, height); post-processing wants (height, width).
        target_sizes = torch.tensor([[image.size[1], image.size[0]]])
        results = image_processor.post_process_object_detection(outputs,
                                                                 threshold=conf_threshold,
                                                                 target_sizes=target_sizes)[0]

    # Move every tensor to the CPU so it can be converted to Python values below.
    results = {key: value.cpu() if isinstance(value, torch.Tensor) else value
               for key, value in results.items()}

    detected_labels = _draw_detections(image, results)
    return_string = _summarize_detections(detected_labels, conf_threshold)
    print(return_string)
    return image, return_string
# Build the Gradio UI: an image and a confidence threshold go in; the annotated
# image and a text summary come out.
demo = gr.Interface(
    fn=predict_on_image,
    inputs=[
        gr.Image(type="pil", label="Target Image"),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold"),
    ],
    outputs=[
        gr.Image(type="pil", label="Image Output"),
        gr.Text(label="Text Output"),
    ],
    # 🚮 is U+1F6AE ("put litter in its place").
    title="🚮 Trashify Object Detection Demo V1",
    description="Help clean up your local area! Upload an image and get +1 if there is all of the following items detected: trash, bin, hand.",
    # Each inner list pre-fills the `inputs` components, in order: image path, threshold
    examples=[
        ["examples/trashify_example_1.jpeg", 0.25],
        ["examples/trashify_example_2.jpeg", 0.25],
        ["examples/trashify_example_3.jpeg", 0.25],
    ],
    cache_examples=True,
)

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module does not start a server.
if __name__ == "__main__":
    demo.launch()