# ObjectDetection / app.py
# Last commit 489c028 (verified) by DexterSptizu: "Update app.py" — 3.18 kB.
import io
from random import choice
from PIL import Image
import gradio as gr
from transformers import pipeline
import matplotlib.pyplot as plt
import requests
# Build both DETR detection pipelines once, when the module is imported, so
# every request reuses them. ResNet-50 is loaded first, then ResNet-101.
detector50, detector101 = (
    pipeline(model=checkpoint)
    for checkpoint in ("facebook/detr-resnet-50", "facebook/detr-resnet-101")
)
# Define colors and font dictionary for bounding boxes and labels
COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
"#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
"#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
fdic = {
"family": "Impact",
"style": "italic",
"size": 15,
"color": "yellow",
"weight": "bold"
}
def get_figure(in_pil_img, in_results):
    """Draw detection boxes and captions over an image and return the rendering.

    Args:
        in_pil_img: The image the detections were computed on; it is drawn
            with ``imshow`` as the background.
        in_results: Iterable of detection dicts. Each is indexed as
            ``["box"]["xmin" | "ymin" | "xmax" | "ymax"]``, ``["label"]`` and
            ``["score"]`` (score is rendered as a percentage).

    Returns:
        A PIL image decoded from the PNG rendering of the annotated figure.
    """
    # Use the object-oriented Figure API instead of pyplot. pyplot keeps every
    # figure created by plt.figure() registered (and in memory) until
    # plt.close() is called, which the previous version never did, so each
    # request leaked a 16x10 figure. A standalone Figure is not registered
    # anywhere, so it is garbage-collected normally. It also does not touch
    # pyplot's global "current figure", which is shared across the threads
    # Gradio uses to run handlers.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(16, 10))
    ax = fig.subplots()
    ax.imshow(in_pil_img)

    for prediction in in_results:
        selected_color = choice(COLORS)
        box = prediction["box"]
        x, y = box["xmin"], box["ymin"]
        w, h = box["xmax"] - x, box["ymax"] - y
        ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
        ax.text(x, y, f"{prediction['label']}: {round(prediction['score'] * 100, 1)}%", fontdict=fdic)

    ax.axis("off")
    fig.tight_layout()

    # Serialize to an in-memory PNG and hand it back as a PIL image.
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    return Image.open(buf)
def infer(model, in_pil_img):
    """Run DETR object detection on an image and return the annotated result.

    Args:
        model: Model name chosen in the UI. ``"detr-resnet-101"`` selects the
            ResNet-101 detector; any other value uses the ResNet-50 detector.
        in_pil_img: Image from the Gradio input component, or ``None`` when
            the component is empty.

    Returns:
        The annotated image produced by ``get_figure``.

    Raises:
        gr.Error: If no input image was provided. Gradio shows this message in the UI.
    """
    # Clicking "Infer" with an empty image component passes None. Fail with a
    # readable message instead of an opaque error from inside the pipeline.
    if in_pil_img is None:
        raise gr.Error("Please provide an input image first.")
    detector = detector101 if model == "detr-resnet-101" else detector50
    results = detector(in_pil_img)
    return get_figure(in_pil_img, results)
def url_to_image(url, timeout=10):
    """Download an image from *url* and return it as a PIL image.

    Args:
        url: HTTP(S) URL of the image.
        timeout: Seconds to wait for the server (passed to ``requests.get``).
            Without a timeout, a stalled server would block the call forever.

    Returns:
        The decoded PIL image.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within *timeout*.
        PIL.UnidentifiedImageError: If the response body is not an image.
    """
    response = requests.get(url, timeout=timeout)
    # Stop on HTTP errors here. Otherwise an HTML error page would be passed to
    # Image.open and surface as a confusing "cannot identify image file".
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
# Define Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## DETR Object Detection")
    model = gr.Radio(["detr-resnet-50", "detr-resnet-101"], value="detr-resnet-50", label="Model name")
    input_image = gr.Image(label="Input image", type="pil")
    # Example images from the internet. They must be attached to
    # `input_image`, the component the "Infer" button reads. Wiring them to a
    # separate, anonymous gr.Image (as before) rendered an extra image box,
    # and clicking an example never reached inference. gr.Image accepts
    # URLs as example values directly, so no loader fn is needed.
    example_urls = [
        "https://www.quickanddirtytips.com/wp-content/uploads/2022/05/ezgif.com-gif-maker-3.jpg",
        "https://img.freepik.com/free-photo/people-posing-together-registration-day_23-2149096794.jpg",
        "https://www.shutterstock.com/shutterstock/photos/2077329079/display_1500/stock-photo-pushkar-india-nov-people-walking-in-a-busy-road-in-the-street-market-in-holy-city-2077329079.jpg",
        "https://static.autox.com/uploads/2023/01/mahindra-xuv400.jpg",
        "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQrmOeQJN9IzbR8ofJDpxcIHDSOSXg_5hzgfA&s",
    ]
    examples = gr.Examples(
        examples=example_urls,
        inputs=[input_image],
        label="Try these example images",
    )
    output_image = gr.Image(label="Output image")
    send_btn = gr.Button("Infer")
    # Trigger inference on button click
    send_btn.click(fn=infer, inputs=[model, input_image], outputs=output_image)

# Only start the server when run as a script (`python app.py`, as Spaces
# does). Importing the module, e.g. for tests or reload mode, then has no
# launch side effect.
if __name__ == "__main__":
    demo.launch()