Spaces:
Sleeping
Sleeping
# from dataclasses import dataclass, replace | |
# from functools import reduce | |
from io import BytesIO | |
import math | |
import os | |
from pprint import pprint | |
import tempfile | |
from PIL import Image, ImageDraw, ImageFont | |
import numpy as np | |
import cv2 | |
# import seaborn as sns | |
# import matplotlib.pyplot as plt | |
# %matplotlib inline | |
import torch | |
from torch.utils.data import Dataset | |
import torchvision | |
from torchvision import transforms | |
import roboflow | |
from roboflow import Roboflow | |
import supervision as sv | |
import albumentations as A | |
import gradio as gr | |
import requests | |
# from torchmetrics.detection.mean_ap import MeanAveragePrecision | |
# from torchmetrics.detection.iou import IntersectionOverUnion | |
# import evaluate | |
#from datasets import load_metric | |
from transformers import pipeline | |
from transformers import ( | |
AutoProcessor, | |
AutoImageProcessor, | |
AutoModel, | |
AutoModelForObjectDetection, | |
RTDetrForObjectDetection, | |
RTDetrImageProcessor, | |
TrainingArguments, | |
Trainer | |
) | |
from huggingface_hub import hf_hub_download | |
from safetensors.torch import load_file | |
#@title Utilities
# Class index -> drawing colour tuple and display name for every detector class.
PALETTE = {
    0: {"color": (255, 0, 0), "name": "Ambulance"},
    1: {"color": (0, 191, 0), "name": "Firetruck"},
    2: {"color": (0, 0, 255), "name": "Police"},
    3: {"color": (255, 0, 255), "name": "Non-EV"},
}
# Name <-> index lookup tables, both derived from PALETTE so they cannot drift.
label2id = {entry["name"]: class_idx for class_idx, entry in PALETTE.items()}
id2label = {class_idx: class_name for class_name, class_idx in label2id.items()}
print(label2id)
print(id2label)
def unnormalize_bbox(img_h, img_w, bbox):
    """Convert a normalized centre-format box into integer pixel corners.

    Args:
        img_h: Image height in pixels.
        img_w: Image width in pixels.
        bbox: Indexable ``(cx, cy, w, h)`` with every value relative to the
            image size; only the first four elements are read.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` as ints, truncated toward zero by
        ``int()`` (so corners outside the image may be negative).
    """
    half_w = bbox[2] / 2
    half_h = bbox[3] / 2
    # Scale each corner after shifting from the centre, matching the original
    # arithmetic order so float results are bit-identical.
    corners = (
        (bbox[0] - half_w) * img_w,
        (bbox[1] - half_h) * img_h,
        (bbox[0] + half_w) * img_w,
        (bbox[1] + half_h) * img_h,
    )
    return tuple(int(value) for value in corners)
def paint_bbox(
    image,
    annotations,
    normalize_labels=True,
    normalize_bbox=True,
):
    """Draw labelled bounding boxes onto a copy of an image array with OpenCV.

    Args:
        image: Array whose ``shape[0]`` is the height and ``shape[1]`` the
            width; it is copied, never modified in place.
        annotations: Mapping with ``"boxes"`` and ``"labels"`` entries that
            support ``.tolist()``, plus an optional ``"scores"`` entry. When
            scores are absent, no confidence is printed for any box.
        normalize_labels: If True, each label is shifted down by one before
            the ``PALETTE`` lookup (presumably because the source labels are
            1-based -- TODO confirm against the dataset).
        normalize_bbox: If True, boxes are normalized ``(cx, cy, w, h)`` and
            are converted with ``unnormalize_bbox``; otherwise they are taken
            as pixel ``(x_min, y_min, x_max, y_max)``.

    Returns:
        The annotated copy of ``image``.
    """
    bboxes = annotations["boxes"].tolist()
    class_ids = annotations["labels"].tolist()
    if "scores" in annotations:
        confidences = annotations["scores"].tolist()
    else:
        # -1 is the sentinel checked below for "do not print a confidence".
        confidences = [-1] * len(bboxes)

    img_h, img_w = image.shape[0], image.shape[1]  # loop-invariant
    painted_img = image.copy()  # draw on a copy; the caller's array stays untouched

    for bbox, label, confidence in zip(bboxes, class_ids, confidences):
        if normalize_labels:
            label -= 1
        if normalize_bbox:
            x_min, y_min, x_max, y_max = unnormalize_bbox(img_h, img_w, bbox)
        else:
            x_min, y_min, x_max, y_max = map(int, bbox)

        box_color = PALETTE[label]["color"]
        label_name = PALETTE[label]["name"]
        if confidence != -1:
            label_name = f"{label_name} ({confidence:.2f})"

        # Box outline.
        cv2.rectangle(painted_img,
                      (x_min, y_min),
                      (x_max, y_max),
                      color=box_color,
                      thickness=2)
        # Filled caption background; width is a rough ~10 px-per-character estimate.
        cv2.rectangle(painted_img,
                      (x_min, y_min),
                      (x_min + 5 + len(label_name) * 10, y_min + 17),
                      color=box_color,
                      thickness=-1)
        # White caption text inside the filled background.
        cv2.putText(painted_img,
                    label_name,
                    (x_min + 2, y_min + 12),
                    fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                    fontScale=0.5,
                    color=(255, 255, 255),
                    thickness=1)
    return painted_img
def calculate_iou(truth_bbx, pred_bbx):
    """Return the Intersection over Union of two corner-format boxes.

    Args:
        truth_bbx: ``(xmin, ymin, xmax, ymax)`` of the first box.
        pred_bbx: ``(xmin, ymin, xmax, ymax)`` of the second box.

    Returns:
        Intersection area divided by union area, or ``0`` when the union
        area is zero (e.g. two degenerate boxes).
    """
    ax1, ay1, ax2, ay2 = truth_bbx
    bx1, by1, bx2, by2 = pred_bbx

    # Overlap extents, clamped at zero when the boxes do not intersect.
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    inter = overlap_w * overlap_h

    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
    if union == 0:
        return 0
    return inter / union
# Object-detection pipeline for the fine-tuned model. It is built once at
# import time so every request reuses the same loaded model.
_MODEL_REPO_ID = "itsindrabudhik/finalProjectCV2425"
DETECTOR = pipeline("object-detection", model=_MODEL_REPO_ID)

# Raw checkpoint of the same repo, kept so the classification-head weights can
# be copied onto the pipeline model by hand if the pipeline does not load them.
# That override is currently disabled:
#   weights = load_file(tensor_file)
#   DETECTOR.model.class_labels_classifier.weight.data = weights["class_labels_classifier.weight"]
#   DETECTOR.model.class_labels_classifier.bias.data = weights["class_labels_classifier.bias"]
#   del weights
# NOTE(review): while the override stays disabled, tensor_file is unused but
# the download still runs at startup.
tensor_file = hf_hub_download(repo_id=_MODEL_REPO_ID,
                              filename="model.safetensors")
# Label fonts keyed by point size; filled lazily by _get_label_font.
_FONT_CACHE = {}


def _get_label_font(size=32):
    """Return a (cached) TrueType font for box captions, or Pillow's default.

    The DejaVuSans file under the OpenCV package directory is not guaranteed
    to exist (it depends on how OpenCV was packaged), so a missing file falls
    back to Pillow's built-in font instead of crashing inference. Caching
    avoids re-reading the font file on every video frame.
    """
    if size not in _FONT_CACHE:
        font_path = os.path.join(cv2.__path__[0], "qt", "fonts", "DejaVuSans.ttf")
        try:
            _FONT_CACHE[size] = ImageFont.truetype(font_path, size=size)
        except OSError:
            _FONT_CACHE[size] = ImageFont.load_default()
    return _FONT_CACHE[size]


def _load_image(image):
    """Return an RGB PIL copy of a URL (``http...``), file path, or PIL image."""
    if isinstance(image, str):
        if image.startswith("http"):
            response = requests.get(image, timeout=30)
            response.raise_for_status()
            img = Image.open(BytesIO(response.content))
        else:
            img = Image.open(image)
    else:
        img = image
    # convert() always returns a new image: drawing never mutates the caller's
    # object, and palette/RGBA inputs can take RGB colour tuples.
    return img.convert("RGB")


def _box_xyxy(box):
    """Return a pipeline ``{"xmin", "ymin", "xmax", "ymax"}`` box as a list."""
    return [box["xmin"], box["ymin"], box["xmax"], box["ymax"]]


def _non_max_suppression(results, confidence_threshold, iou_threshold):
    """Greedy, class-agnostic NMS over pipeline results, best score first.

    A detection is kept only if its IoU with every already-kept detection is
    at most ``iou_threshold``. Comparing each box against every other box
    would discard both members of an overlapping pair; this version keeps
    the highest-scoring box of each overlapping group.
    """
    candidates = sorted(
        (r for r in results if r["score"] >= confidence_threshold),
        key=lambda r: r["score"],
        reverse=True,
    )
    kept = []
    for candidate in candidates:
        box = _box_xyxy(candidate["box"])
        if all(calculate_iou(_box_xyxy(k["box"]), box) <= iou_threshold for k in kept):
            kept.append(candidate)
    return kept


def _label_color(label):
    """Return the PALETTE colour for a label name, or red for unknown labels."""
    class_id = label2id.get(label)
    return PALETTE.get(class_id, {}).get("color", (255, 0, 0))


def detect_ev_nev(image, confidence_threshold=0.5, iou_threshold=0.5):
    """Detect vehicles in one image, draw the kept boxes, and describe them.

    Args:
        image: URL (starting with ``http``), local file path, or PIL image.
        confidence_threshold: Minimum score for a detection to be considered.
        iou_threshold: A detection whose IoU with an already-kept,
            higher-scoring detection exceeds this value is dropped.

    Returns:
        ``(annotated_image, details_text)``: an RGB copy of the input with
        boxes and captions drawn on it, and one
        ``Label: ..., Confidence: ..., Box: (...)`` line per kept detection
        (an empty string when nothing is kept).
    """
    # Load once and feed the decoded image to the model; this avoids fetching
    # a URL twice and guarantees the boxes refer to the image we draw on.
    img = _load_image(image)
    # The pipeline applies its own score cut-off (0.5 by default); forward the
    # user's threshold so values below 0.5 can surface more detections.
    results = DETECTOR(img, threshold=confidence_threshold)

    font = _get_label_font()
    draw = ImageDraw.Draw(img)
    details = []  # collected for the text output
    for result in _non_max_suppression(results, confidence_threshold, iou_threshold):
        score = result["score"]
        label = result["label"]
        xmin, ymin, xmax, ymax = _box_xyxy(result["box"])
        color = _label_color(label)

        draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=3)
        text = f"{label} ({score:.2f})"
        # Only the caption height is needed, which is independent of position.
        text_bbox = draw.textbbox((0, 0), text, font=font)
        text_height = text_bbox[3] - text_bbox[1]
        # Place the caption above the box, clamped so it stays on the canvas.
        draw.text((xmin, max(0, ymin - text_height - 5)), text, fill=color, font=font)

        details.append({
            "Label": label,
            "Confidence": f"{score:.2f}",
            "Bounding Box": f"({xmin}, {ymin}, {xmax}, {ymax})",
        })

    details_text = "\n".join(
        f"Label: {d['Label']}, Confidence: {d['Confidence']}, Box: {d['Bounding Box']}"
        for d in details
    )
    return img, details_text
def detect_video(video, confidence_threshold=0.5, iou_threshold=0.5):
    """Run ``detect_ev_nev`` on every frame of a video and save an annotated copy.

    Args:
        video: Anything ``cv2.VideoCapture`` accepts (usually a file path), or
            an object exposing the path as ``.name`` (e.g. an uploaded temp file).
        confidence_threshold: Forwarded to ``detect_ev_nev``.
        iou_threshold: Forwarded to ``detect_ev_nev``.

    Returns:
        ``(output_path, summary)``: the path of an ``mp4v``-encoded ``.mp4``
        in the temp directory (not deleted automatically), and a text summary
        with frame counts followed by the details of each frame that had
        detections.

    Raises:
        ValueError: If the video cannot be opened.
    """
    source = getattr(video, "name", video)
    video_capture = cv2.VideoCapture(source)
    if not video_capture.isOpened():
        video_capture.release()
        raise ValueError(f"Could not open video: {source}")

    fps = video_capture.get(cv2.CAP_PROP_FPS)
    # CAP_PROP_FPS may come back as 0 (or NaN) when the container lacks the
    # metadata; `not fps > 0` catches both, and the writer needs a real rate.
    if not fps > 0:
        fps = 30.0
    frame_width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Only a unique file name is needed: close our handle immediately so it is
    # not leaked and does not stay open while VideoWriter writes the same path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_output:
        output_path = temp_output.name
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))

    details = []
    total_frames = 0
    detected_frames = 0
    try:
        while True:
            ret, frame = video_capture.read()
            if not ret:
                break
            total_frames += 1
            # OpenCV frames are BGR; detect_ev_nev works on RGB PIL images.
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            annotated_image, frame_details = detect_ev_nev(
                image, confidence_threshold, iou_threshold
            )
            if frame_details.strip():  # non-empty details mean detections
                detected_frames += 1
                details.append(f"Frame {total_frames}:\n{frame_details}")
            annotated_frame = cv2.cvtColor(np.array(annotated_image), cv2.COLOR_RGB2BGR)
            out.write(annotated_frame)
    finally:
        # Release both handles even if detection fails mid-video.
        video_capture.release()
        out.release()

    details_text = "\n".join(details)
    summary = (
        f"Total Frames: {total_frames}, Frames with Detections: {detected_frames}\n"
        + details_text
    )
    return output_path, summary
# Upload extensions (lower-case, with dot) routed to each pipeline.
_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"}
_VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov"}


def detect(file, confidence_threshold=0.5, iou_threshold=0.5):
    """Gradio entry point: route an uploaded file to image or video detection.

    Args:
        file: The upload, given either as a path (``str``/``os.PathLike``) or
            as an object exposing the path as ``.name``. Which one arrives
            depends on the Gradio version/``gr.File`` configuration.
        confidence_threshold: Forwarded to the detector.
        iou_threshold: Forwarded to the detector.

    Returns:
        ``(image, video_path, details)``. Exactly one of ``image`` and
        ``video_path`` is set, matching the uploaded media type.

    Raises:
        ValueError: If nothing was uploaded or the extension is unsupported.
    """
    if file is None:
        raise ValueError("Please upload an image or video.")
    # The detectors need a real path; they cannot open an upload wrapper object.
    path = os.fspath(file) if isinstance(file, (str, os.PathLike)) else file.name
    file_ext = os.path.splitext(path)[1].lower()

    if file_ext in _IMAGE_EXTENSIONS:
        annotated_image, details = detect_ev_nev(path, confidence_threshold, iou_threshold)
        return annotated_image, None, details
    if file_ext in _VIDEO_EXTENSIONS:
        processed_video, details = detect_video(path, confidence_threshold, iou_threshold)
        return None, processed_video, details
    raise ValueError("Unsupported file format. Please upload an image or video.")
# Gradio UI: one upload plus two threshold sliders go in. An image, a video and
# a text report come out; detect() fills whichever media slot matches the upload.
_upload_and_thresholds = [
    gr.File(label="Upload Image or Video",
            file_types=[".png", ".jpg", ".jpeg", ".mp4", ".avi", ".mov"]),
    gr.Slider(0, 1, value=0.5, label="Confidence Threshold"),
    gr.Slider(0, 1, value=0.5, label="IoU Threshold"),
]
_result_views = [
    gr.Image(label="Processed Image"),
    gr.Video(label="Generated Video"),
    gr.Text(label="Detection Details"),
]
interface = gr.Interface(
    fn=detect,
    inputs=_upload_and_thresholds,
    outputs=_result_views,
    title="RT-DETR Object Detection for Images and Videos",
    description=(
        "Upload an image or video to detect objects using the fine-tuned "
        "RT-DETR model. Results include the annotated image/video and "
        "detection details."
    ),
)
# Start the app; debug=True is passed straight through to Gradio's launch().
interface.launch(debug=True)