import tempfile
from io import BytesIO

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
# Load the YOLOv5 model through its DetectMultiBackend wrapper
from models.common import DetectMultiBackend

weights_path = "./last.pt"
device = torch.device("cpu")  # this Space runs on CPU only
model = DetectMultiBackend(weights_path, device=device)
model.eval()
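# Optional smoke test (a hedged sketch, not required by the app): run a zero
# tensor shaped like the 512x640 model input and check the raw output, which
# should be (1, num_boxes, 5 + num_classes) for a YOLOv5-style head.
# with torch.no_grad():
#     dummy = torch.zeros(1, 3, 512, 640, device=device)
#     print(model(dummy)[0].shape)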
# Inputs are normalized to PIL images before this runs; Resize accepts PIL
# directly (ToPILImage would raise a TypeError on a PIL input, so it is omitted)
transform = transforms.Compose([
    transforms.Resize((512, 640)),  # (height, width) expected by the model
    transforms.ToTensor(),          # PIL image -> float tensor in [0, 1]
])
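# Example (a hedged sketch): the transform maps any PIL image to a fixed-size
# tensor, e.g. transform(Image.new("RGB", (1920, 1080))).shape
# evaluates to torch.Size([3, 512, 640]).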
OBJECT_NAMES = ['enemies']
def detect_objects_in_image(image):
    """Run the detector on one image and return (annotated image, count graph)."""
    # Normalize the input to a PIL image (Gradio passes a numpy array by default)
    if isinstance(image, torch.Tensor):
        image = transforms.ToPILImage()(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    if isinstance(image, Image.Image):
        orig_w, orig_h = image.size  # PIL size is (width, height)
    else:
        raise TypeError(f"Expected a PIL Image but got {type(image)}")

    # Preprocess and run inference
    img_tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        pred = model(img_tensor)[0]
    if isinstance(pred[0], torch.Tensor):
        pred = [p.cpu().numpy() for p in pred]
    pred = np.concatenate(pred, axis=0)

    # Drop low-confidence predictions
    conf_thres = 0.25
    mask = pred[:, 4] > conf_thres
    pred = pred[mask]
    if len(pred) == 0:
        return Image.fromarray(np.array(image)), None  # no detections: image unchanged, no graph

    boxes, scores, class_probs = pred[:, :4], pred[:, 4], pred[:, 5:]
    class_ids = np.argmax(class_probs, axis=1)

    # Convert (cx, cy, w, h) to (x1, y1, x2, y2); columns 2 and 3 still hold
    # w and h when the second pair of assignments runs
    boxes[:, 0] = boxes[:, 0] - (boxes[:, 2] / 2)
    boxes[:, 1] = boxes[:, 1] - (boxes[:, 3] / 2)
    boxes[:, 2] = boxes[:, 0] + boxes[:, 2]
    boxes[:, 3] = boxes[:, 1] + boxes[:, 3]

    # Map back to the original resolution; the model input is 640 wide by 512 high,
    # so y-coordinates scale by orig_h / 512, not orig_h / 640
    boxes[:, [0, 2]] *= orig_w / 640
    boxes[:, [1, 3]] *= orig_h / 512
    boxes = np.clip(boxes, 0, [orig_w, orig_h, orig_w, orig_h])

    # cv2.dnn.NMSBoxes expects (x, y, w, h) boxes, so convert before NMS
    nms_boxes = boxes.copy()
    nms_boxes[:, 2] -= nms_boxes[:, 0]
    nms_boxes[:, 3] -= nms_boxes[:, 1]
    indices = cv2.dnn.NMSBoxes(nms_boxes.tolist(), scores.tolist(), conf_thres, 0.5)

    object_counts = {name: 0 for name in OBJECT_NAMES}
    img_array = np.array(image)
    if len(indices) > 0:
        for i in indices.flatten():
            x1, y1, x2, y2 = map(int, boxes[i])
            cls = class_ids[i]
            object_name = OBJECT_NAMES[cls] if cls < len(OBJECT_NAMES) else f"Unknown ({cls})"
            if object_name in object_counts:
                object_counts[object_name] += 1
            cv2.rectangle(img_array, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(img_array, f"{object_name}: {scores[i]:.2f}", (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    graph_image = generate_vehicle_count_graph(object_counts)
    return Image.fromarray(img_array), graph_image
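# Example usage (a hedged sketch; "frame.jpg" is a hypothetical local file):
# annotated, graph = detect_objects_in_image(Image.open("frame.jpg"))
# annotated.save("frame_annotated.jpg")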
def generate_vehicle_count_graph(object_counts):
    """Render a bar chart of per-class detection counts and return it as a PIL image."""
    color_palette = ['#4C9ACD', '#88B8A3', '#7F9C9C', '#D1A3B5', '#A1C6EA',
                     '#FFB6C1', '#F0E68C', '#D3B0D8', '#F8A5D1', '#B8B8D1']
    fig, ax = plt.subplots(figsize=(8, 5))
    labels = list(object_counts.keys())
    values = list(object_counts.values())
    ax.bar(labels, values, color=color_palette[:len(labels)])
    ax.set_xlabel("Vehicle Categories", fontsize=12, fontweight='bold')
    ax.set_ylabel("Number of Vehicles", fontsize=12, fontweight='bold')
    ax.set_title("Detected Vehicles in Image", fontsize=14, fontweight='bold')
    plt.xticks(rotation=45, ha='right', fontsize=10)
    plt.yticks(fontsize=10)
    plt.tight_layout()
    buf = BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    plt.close(fig)  # close the figure to free memory
    return Image.open(buf)
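# Example usage (a hedged sketch with made-up counts):
# chart = generate_vehicle_count_graph({"enemies": 3})
# chart.save("counts.png")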
def detect_objects_in_video(video_input):
    """Run the detector on every frame and return (annotated video path, last count graph)."""
    cap = cv2.VideoCapture(video_input)
    if not cap.isOpened():
        return "Error: Cannot open video file.", None  # second value keeps the output count consistent

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    temp_video_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
    out = cv2.VideoWriter(temp_video_output, cv2.VideoWriter_fourcc(*'mp4v'), fps,
                          (frame_width, frame_height))

    graph_image = None  # defined up front so an empty video still returns cleanly
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        # Annotate the frame; keep the count graph of the most recent frame
        frame_with_boxes, graph_image = detect_objects_in_image(image)
        # Convert back to BGR for the video writer
        out.write(cv2.cvtColor(np.array(frame_with_boxes), cv2.COLOR_RGB2BGR))

    cap.release()
    out.release()
    return temp_video_output, graph_image
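# Example usage (a hedged sketch; "clip.mp4" is a hypothetical local file):
# out_path, last_graph = detect_objects_in_video("clip.mp4")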
from urllib.request import urlretrieve

# Fetch example images from GitHub (use the raw-image URL, i.e. ?raw=true,
# when copying an image address from GitHub)
urlretrieve("https://github.com/SamDaaLamb/ValorantTracker/blob/main/clip2_-1450-_jpg.jpg?raw=true", "clip2_-1450-_jpg.jpg")
urlretrieve("https://github.com/SamDaaLamb/ValorantTracker/blob/main/clip2_-539-_jpg.jpg?raw=true", "clip2_-539-_jpg.jpg")

examples = [  # the examples cache must be cleared manually every time new examples are added
    ["clip2_-1450-_jpg.jpg"],
    ["clip2_-539-_jpg.jpg"],
]
# Define app features and run
title = "SpecLab Demo"
description = "<p style='text-align: center'>Gradio demo for a YOLO model trained to detect enemies in gameplay frames. To use it, simply upload your image, or click one of the examples to load them. Since this demo runs on CPU only, please allow additional time for processing.</p>"
article = "<p style='text-align: center'><a href='https://github.com/Nano1337/SpecLab'>Github Repo</a></p>"
css = "#0 {object-fit: contain;} #1 {object-fit: contain;}"
demo = gr.Interface(fn=detect_objects_in_image,
                    title=title,
                    description=description,
                    article=article,
                    inputs=gr.Image(elem_id=0, show_label=False),
                    # detect_objects_in_image returns two values, so expose two
                    # outputs: the annotated image and the per-class count graph
                    outputs=[gr.Image(elem_id=1, show_label=False),
                             gr.Image(show_label=False)],
                    css=css,
                    examples=examples,
                    cache_examples=True,
                    allow_flagging='never')
demo.launch()