# Page header captured together with this file (not part of the program):
#   camcounter / app.py
#   andrewgleave's picture
#   Working prototype
#   19a011f
#   raw | history blame | 2.11 kB
import io
import torch
import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForObjectDetection
# Checkpoint shared by the feature extractor and the detection model. Keeping
# the identifier in one place guarantees both are loaded from the same source.
MODEL_CHECKPOINT = "hustvl/yolos-tiny"

# Pre-processing / post-processing helper and the detector itself, loaded once
# at import time and reused for every frame.
extractor = AutoFeatureExtractor.from_pretrained(MODEL_CHECKPOINT)
model = AutoModelForObjectDetection.from_pretrained(MODEL_CHECKPOINT)

# Figures are only ever rendered into in-memory buffers (see fig2img), so
# select the non-interactive Agg backend, which needs no display.
plt.switch_backend("Agg")
# Palette used for the bounding-box outlines. Each entry is a three-float
# colour specification handed to matplotlib as ``color=``; detections walk
# through the palette in order.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]

# Score threshold passed to the extractor's detection post-processing.
PRED_THRESHOLD = 0.90
def fig2img(fig):
    """Render a figure into a PIL image without touching the filesystem.

    The figure is written to an in-memory byte buffer with ``savefig`` (using
    its default settings) and that buffer is then opened with PIL.

    Args:
        fig: Object exposing ``savefig(file_like)``; callers in this file pass
            a matplotlib figure.

    Returns:
        The PIL image opened from the in-memory buffer.
    """
    buffer = io.BytesIO()
    fig.savefig(buffer)
    # Rewind so PIL starts reading at the beginning of what was just written.
    buffer.seek(0)
    return Image.open(buffer)
def composite_predictions(img, processed_predictions):
    """Draw the kept detections on top of ``img`` and return the rendering.

    Only detections whose label id equals 1 are drawn (the class this app
    treats as "people"). Each one gets a coloured rectangle and a
    ``"<label>: <score>"`` caption.

    Args:
        img: Image to draw on; anything ``Axes.imshow`` accepts (``process``
            passes the PIL frame it received).
        processed_predictions: Mapping with ``"labels"``, ``"scores"`` and
            ``"boxes"`` entries that support boolean-mask indexing and
            ``.tolist()``. Each box is read as ``(xmin, ymin, xmax, ymax)``.

    Returns:
        A PIL image of the rendered figure, as produced by ``fig2img``.
    """
    # Only interested in people: label id 1.
    keep = processed_predictions["labels"] == 1
    boxes = processed_predictions["boxes"][keep].tolist()
    scores = processed_predictions["scores"][keep].tolist()
    label_ids = processed_predictions["labels"][keep].tolist()
    labels = [model.config.id2label[label_id] for label_id in label_ids]

    # Work on explicit figure/axes handles instead of pyplot's implicit
    # "current figure", and always close the figure, even when drawing or
    # rendering raises, so figures do not pile up across streamed frames.
    fig = plt.figure(figsize=(16, 10))
    try:
        axis = fig.gca()
        axis.imshow(img)
        detections = zip(scores, boxes, labels)
        for index, (score, (xmin, ymin, xmax, ymax), label) in enumerate(detections):
            # Cycle through the palette so any number of boxes can be drawn.
            color = COLORS[index % len(COLORS)]
            axis.add_patch(
                plt.Rectangle(
                    (xmin, ymin),
                    xmax - xmin,
                    ymax - ymin,
                    fill=False,
                    color=color,
                    linewidth=3,
                )
            )
            axis.text(
                xmin,
                ymin,
                f"{label}: {score:0.2f}",
                fontsize=15,
                bbox=dict(facecolor="yellow", alpha=0.5),
            )
        axis.axis("off")
        return fig2img(fig)
    finally:
        plt.close(fig)
def process(img):
    """Detect objects in one frame and return the annotated image.

    Args:
        img: Frame to analyse. The interface in this file is configured with
            ``type='pil'``, so this is a PIL image; ``None`` is tolerated.

    Returns:
        The image produced by ``composite_predictions`` for this frame, or
        ``None`` when no frame was supplied.
    """
    # Nothing to analyse (e.g. the webcam has not delivered a frame yet).
    if img is None:
        return None

    inputs = extractor(images=img, return_tensors="pt")
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports size as (width, height); reverse it to (height, width) for
    # the post-processing step, which rescales boxes to the original image.
    target_sizes = torch.tensor([img.size[::-1]])
    processed = extractor.post_process_object_detection(
        outputs, threshold=PRED_THRESHOLD, target_sizes=target_sizes
    )

    # Composite image and prediction bounding boxes + labels prediction.
    # A single image was submitted, so only the first result is relevant.
    return composite_predictions(img, processed[0])
# Live webcam interface: each streamed frame is handed to `process` as a PIL
# image and the returned picture is shown in the output image component.
demo = gr.Interface(
    fn=process,
    inputs=[gr.Image(source="webcam", streaming=True, type="pil")],
    outputs=["image"],
    live=True,
)
demo.launch()