# ObjectDetection / app.py
# Author: DexterSptizu
# Commit 7912672: "Rename app.py.txt to app.py" (2.17 kB)
import io
from random import choice
from PIL import Image
import gradio as gr
from transformers import pipeline
import matplotlib.pyplot as plt
# Build both DETR object-detection pipelines once, when the module is loaded,
# so each click of "Infer" only runs inference.
# The tuple is evaluated left to right: ResNet-50 loads first, then ResNet-101.
detector50, detector101 = (
    pipeline(model="facebook/detr-resnet-50"),
    pipeline(model="facebook/detr-resnet-101"),
)
# Light hex colors; get_figure picks one at random for each bounding box.
COLORS = [
    "#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
    "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
    "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f",
]

# Matplotlib font properties for the "label: score%" caption drawn at each
# box's top-left corner.
fdic = dict(
    family="Impact",
    style="italic",
    size=15,
    color="yellow",
    weight="bold",
)
def get_figure(in_pil_img, in_results):
    """Draw detection boxes and captions over an image and return the rendering.

    Args:
        in_pil_img: The image to annotate (the same object passed to the
            detector in ``infer``).
        in_results: Iterable of prediction dicts. Each has a ``'box'`` dict
            with ``xmin``/``ymin``/``xmax``/``ymax``, plus ``'label'`` and
            ``'score'`` entries.

    Returns:
        A PIL ``Image`` decoded from an in-memory PNG of the annotated figure.
    """
    # Use an explicit figure/axes pair rather than pyplot's implicit "current
    # figure", and always close it: pyplot keeps every figure it creates alive
    # until it is closed, so without this each request leaked a 16x10 figure.
    fig, ax = plt.subplots(figsize=(16, 10))
    try:
        ax.imshow(in_pil_img)
        # Add bounding boxes and labels to the image.
        for prediction in in_results:
            box = prediction['box']
            x, y = box['xmin'], box['ymin']
            # Rectangle wants a corner plus width/height, not two corners.
            w, h = box['xmax'] - x, box['ymax'] - y
            selected_color = choice(COLORS)
            ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
            ax.text(x, y, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontdict=fdic)
        ax.axis("off")
        fig.tight_layout()
        # Render to PNG in memory, then hand back a PIL image.
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def infer(model, in_pil_img):
    """Run DETR object detection on an image and return the annotated rendering.

    Args:
        model: Name of the model chosen in the UI. ``"detr-resnet-101"``
            selects the ResNet-101 pipeline; any other value falls back to
            ResNet-50.
        in_pil_img: Input image from the ``gr.Image(type="pil")`` component,
            or ``None`` when nothing has been uploaded.

    Returns:
        The annotated PIL image produced by ``get_figure``.

    Raises:
        gr.Error: If no input image was provided. Gradio shows the message
            to the user instead of a raw pipeline traceback.
    """
    if in_pil_img is None:
        raise gr.Error("Please upload an input image first.")
    detector = detector101 if model == "detr-resnet-101" else detector50
    results = detector(in_pil_img)
    return get_figure(in_pil_img, results)
# Gradio UI: choose a DETR checkpoint, upload an image, and press "Infer" to
# see the detections drawn on it.
with gr.Blocks() as demo:
    gr.Markdown(value="## DETR Object Detection")
    model = gr.Radio(
        choices=["detr-resnet-50", "detr-resnet-101"],
        value="detr-resnet-50",
        label="Model name",
    )
    input_image = gr.Image(type="pil", label="Input image")
    output_image = gr.Image(label="Output image")
    send_btn = gr.Button(value="Infer")
    # On click: infer(<selected model name>, <uploaded PIL image>) -> output_image.
    send_btn.click(
        fn=infer,
        inputs=[model, input_image],
        outputs=output_image,
    )
demo.launch()