import io
from random import choice

from PIL import Image
import gradio as gr
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for server-side rendering
import matplotlib.pyplot as plt
from transformers import pipeline

# Initialize the object-detection pipelines (model weights are downloaded on first run)
detector50 = pipeline("object-detection", model="facebook/detr-resnet-50")
detector101 = pipeline("object-detection", model="facebook/detr-resnet-101")
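# Each detector call returns a list of predictions shaped like
#   {"score": 0.97, "label": "dog", "box": {"xmin": 10, "ymin": 20, "xmax": 110, "ymax": 220}}
# (values here are illustrative); get_figure() below draws one box per entry.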
# Colors for bounding boxes and a font dictionary for labels
COLORS = [
    "#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
    "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
    "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f",
]
fdic = {
    "family": "Impact",
    "style": "italic",
    "size": 15,
    "color": "yellow",
    "weight": "bold",
}
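# Note: if the "Impact" font is not installed in the runtime, matplotlib falls back to its default font.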

def get_figure(in_pil_img, in_results):
    # Draw the image with bounding boxes and labels on a matplotlib figure
    plt.figure(figsize=(16, 10))
    plt.imshow(in_pil_img)
    ax = plt.gca()

    # Add one rectangle and label per detection
    for prediction in in_results:
        selected_color = choice(COLORS)
        x, y = prediction["box"]["xmin"], prediction["box"]["ymin"]
        w, h = prediction["box"]["xmax"] - x, prediction["box"]["ymax"] - y
        ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
        ax.text(x, y, f"{prediction['label']}: {round(prediction['score'] * 100, 1)}%", fontdict=fdic)

    plt.axis("off")
    plt.tight_layout()

    # Render the figure to a PNG buffer, close it to avoid leaking figures, and return a PIL image
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight")
    plt.close()
    buf.seek(0)
    return Image.open(buf)

def infer(model, in_pil_img):
    # Run the selected DETR model on the input image and return the annotated figure
    results = detector101(in_pil_img) if model == "detr-resnet-101" else detector50(in_pil_img)
    return get_figure(in_pil_img, results)
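# Optional: a minimal smoke test of infer() outside Gradio (assumes a local
# "image1.jpg" exists, as in the examples below); uncomment to try it:
#   test_img = Image.open("image1.jpg")
#   infer("detr-resnet-50", test_img).save("annotated.png")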

# Gradio interface with local example images
with gr.Blocks() as demo:
    gr.Markdown("## DETR Object Detection")
    model = gr.Radio(
        ["detr-resnet-50", "detr-resnet-101"],
        value="detr-resnet-50",
        label="Model name",
    )
    input_image = gr.Image(label="Input image", type="pil")
    output_image = gr.Image(label="Output image")

    # Wire the local example files to input_image so clicking one populates the input
    gr.Examples(
        examples=[
            ["image1.jpg"],
            ["image2.jpg"],
        ],
        inputs=[input_image],
        label="Try these example images",
    )

    send_btn = gr.Button("Infer")
    # Trigger inference on button click
    send_btn.click(fn=infer, inputs=[model, input_image], outputs=output_image)

demo.launch()