Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from io import BytesIO | |
import os | |
import cv2 | |
import gradio as gr | |
import numpy as np | |
import requests | |
from PIL import Image | |
import gradio as gr | |
import cv2 | |
import tempfile | |
import numpy as np | |
import torch | |
from torchvision import transforms | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
from io import BytesIO | |
# Load the YOLO model (TorchScript export) used by detect_objects_in_image.
# NOTE(review): DetectMultiBackend is not used below; kept so the import-time
# dependency on the YOLOv5 `models` package is unchanged — remove if not needed.
from models.common import DetectMultiBackend

weights_path = "./best.torchscript"
# The demo runs on CPU only.
device = torch.device("cpu")
# map_location places every tensor in the archive on `device`; without it a
# model exported from a GPU session cannot be loaded on a CPU-only host.
model = torch.jit.load(weights_path, map_location=device)
model.eval()  # inference mode

# Preprocessing for a PIL image: stretch (no letterboxing) to the 640x640
# network input and convert to a float CHW tensor scaled to [0, 1].
transform = transforms.Compose([
    transforms.Resize((640, 640)),
    transforms.ToTensor(),
])

# Class names, indexed by the class id the model predicts.
OBJECT_NAMES = ['enemies']
def _to_rgb_pil(image):
    """Return *image* (NumPy array or PIL image) as a 3-channel RGB PIL image."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Drop alpha / expand grayscale so the model always receives 3 channels.
    return image.convert("RGB")


def _decode_predictions(raw_output, conf_thres):
    """Flatten raw model output and keep rows whose score exceeds *conf_thres*.

    Returns ``(boxes, scores, class_ids)``; *boxes* are (cx, cy, w, h) in the
    640x640 network-input coordinate system.
    """
    # The exported module may return the detection tensor directly or as the
    # first item of a tuple/list; accept both.
    if isinstance(raw_output, (tuple, list)):
        raw_output = raw_output[0]
    pred = raw_output.detach().cpu().numpy()
    # Assumes YOLOv5 row layout [cx, cy, w, h, obj_conf, cls_0, cls_1, ...]
    # (TODO confirm against the export); collapse the batch axis so each row
    # is one candidate box.
    pred = pred.reshape(-1, pred.shape[-1])
    if pred.shape[1] > 5:
        class_probs = pred[:, 5:]
        class_ids = class_probs.argmax(axis=1)
        # YOLOv5 scores a box as obj_conf * best class probability.
        scores = pred[:, 4] * class_probs.max(axis=1)
    else:
        # No class columns: fall back to objectness and class 0.
        class_ids = np.zeros(len(pred), dtype=np.int64)
        scores = pred[:, 4]
    keep = scores > conf_thres
    return pred[keep, :4], scores[keep], class_ids[keep]


def _to_corner_boxes(boxes_cxcywh, orig_w, orig_h):
    """Convert (cx, cy, w, h) boxes in 640x640 input space to clipped (x1, y1, x2, y2) in original pixels."""
    # 640 must match transforms.Resize((640, 640)) used for preprocessing.
    scale = np.array([orig_w / 640, orig_h / 640], dtype=np.float32)
    centers = boxes_cxcywh[:, :2] * scale
    half_sizes = boxes_cxcywh[:, 2:4] * scale / 2
    corners = np.concatenate([centers - half_sizes, centers + half_sizes], axis=1)
    return np.clip(corners, 0, [orig_w, orig_h, orig_w, orig_h])


def _draw_detections(img_array, boxes, scores, class_ids, kept):
    """Draw a labelled green box on *img_array* (in place) for each index in *kept*."""
    for i in kept:
        x1, y1, x2, y2 = map(int, boxes[i])
        cls = int(class_ids[i])
        object_name = OBJECT_NAMES[cls] if cls < len(OBJECT_NAMES) else f"Unknown ({cls})"
        cv2.rectangle(img_array, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Keep the label inside the image when the box starts near the top edge.
        cv2.putText(img_array, f"{object_name}: {scores[i]:.2f}", (x1, max(y1 - 10, 10)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)


def detect_objects_in_image(image, *, conf_thres=0.25, iou_thres=0.5):
    """Run the YOLO model on *image* and return a copy annotated with detections.

    Args:
        image: Input image as a NumPy array (as delivered by ``gr.Image``) or a
            PIL image. ``None`` (an empty input) is returned unchanged.
        conf_thres: Minimum score (objectness x class probability) a box must
            exceed to be kept.
        iou_thres: IoU threshold used by non-maximum suppression.

    Returns:
        A PIL RGB image with a green box and ``"<name>: <score>"`` label drawn
        for every detection that survives NMS, or ``None`` if *image* is None.
    """
    if image is None:
        return None

    image = _to_rgb_pil(image)
    orig_w, orig_h = image.size
    img_tensor = transform(image).unsqueeze(0)

    with torch.no_grad():
        raw_output = model(img_tensor)

    boxes, scores, class_ids = _decode_predictions(raw_output, conf_thres)
    img_array = np.array(image)
    if len(scores) == 0:
        return Image.fromarray(img_array)

    boxes = _to_corner_boxes(boxes, orig_w, orig_h)
    # cv2.dnn.NMSBoxes takes (x, y, width, height) rectangles, not corner
    # coordinates, so derive width/height from the corners.
    rects = np.concatenate([boxes[:, :2], boxes[:, 2:] - boxes[:, :2]], axis=1)
    kept = cv2.dnn.NMSBoxes(rects.tolist(), scores.tolist(), conf_thres, iou_thres)
    # Normalize to a flat index array (OpenCV versions differ in the shape returned).
    kept = np.asarray(kept, dtype=np.int64).reshape(-1)

    _draw_detections(img_array, boxes, scores, class_ids, kept)
    return Image.fromarray(img_array)
# def generate_vehicle_count_graph(object_counts): | |
# color_palette = ['#4C9ACD', '#88B8A3', '#7F9C9C', '#D1A3B5', '#A1C6EA', '#FFB6C1', '#F0E68C', '#D3B0D8', '#F8A5D1', '#B8B8D1'] | |
# fig, ax = plt.subplots(figsize=(8, 5)) | |
# labels = list(object_counts.keys()) | |
# values = list(object_counts.values()) | |
# ax.bar(labels, values, color=color_palette[:len(labels)]) | |
# ax.set_xlabel("Vehicle Categories", fontsize=12, fontweight='bold') | |
# ax.set_ylabel("Number of Vehicles", fontsize=12, fontweight='bold') | |
# ax.set_title("Detected Vehicles in Image", fontsize=14, fontweight='bold') | |
# plt.xticks(rotation=45, ha='right', fontsize=10) | |
# plt.yticks(fontsize=10) | |
# plt.tight_layout() | |
# buf = BytesIO() | |
# plt.savefig(buf, format='png') | |
# buf.seek(0) | |
# return Image.open(buf) | |
# demo = gr.Interface(fn=greet, inputs="text", outputs="text") | |
from urllib.request import urlretrieve

# Example images fetched from GitHub (use the image's raw address when adding more).
EXAMPLE_URLS = {
    "clip2_-1450-_jpg.jpg": "https://github.com/SamDaaLamb/ValorantTracker/blob/main/clip2_-1450-_jpg.jpg?raw=true",
    "clip2_-539-_jpg.jpg": "https://github.com/SamDaaLamb/ValorantTracker/blob/main/clip2_-539-_jpg.jpg?raw=true",
}


def _download_examples(urls):
    """Download each example not already on disk; return the filenames that are available."""
    available = []
    for filename, url in urls.items():
        if not os.path.exists(filename):
            try:
                urlretrieve(url, filename)
            except OSError as err:  # URLError/HTTPError derive from OSError
                # A missing example should not prevent the demo from starting.
                print(f"Could not download example {filename!r}: {err}")
                continue
        available.append(filename)
    return available


# The example cache must be deleted manually every time new examples are added.
examples = [[path] for path in _download_examples(EXAMPLE_URLS)]

# define app features and run
title = "Valorant Tracker Demo"
# NOTE(review): "[email protected]" looks like an obfuscation placeholder left
# by a page scrape, not a real address — confirm the intended contact.
description = (
    "<p style='text-align: center'>Gradio demo for a YOLO model architecture trained on the "
    "custom made dataset. To use it, simply add your image, or click one of the examples to "
    "load them. Since this demo is run on CPU only, please allow additional time for "
    "processing. I would like it to be known that the results from this virtual space are "
    "much worse than the same model on my computer. For an unknown reason this model performs "
    "worse in this space. If anyone does know the reason feel free to contact: "
    "[email protected] .</p>"
)
# NOTE(review): this links to Nano1337/SpecLab while the examples come from
# SamDaaLamb/ValorantTracker — confirm which repository is intended.
article = "<p style='text-align: center'><a href='https://github.com/Nano1337/SpecLab'>Github Repo</a></p>"
# Element ids must be valid CSS identifiers (which cannot start with a digit)
# for these id selectors to match.
css = "#input-image {object-fit: contain;} #output-image {object-fit: contain;}"

demo = gr.Interface(
    fn=detect_objects_in_image,
    title=title,
    description=description,
    article=article,
    inputs=gr.Image(elem_id="input-image", show_label=False),
    outputs=gr.Image(elem_id="output-image", show_label=False),
    css=css,
    # Only offer (and pre-compute) examples that were actually downloaded.
    examples=examples or None,
    cache_examples=bool(examples),
    allow_flagging='never',
)

if __name__ == "__main__":
    demo.launch()