import os
import cv2
import time
import torch
import gradio as gr
import numpy as np
# Local imports from this project: the model factory and configuration.
from model import create_model
from config import NUM_CLASSES, DEVICE, CLASSES
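
# `create_model` is expected to return a torchvision-style detector. A minimal
# sketch of what model.py might contain (an assumption based on the RetinaNet
# title, not this project's actual code):
#
#   from functools import partial
#   from torch import nn
#   from torchvision.models.detection import retinanet_resnet50_fpn_v2
#   from torchvision.models.detection.retinanet import RetinaNetClassificationHead
#
#   def create_model(num_classes):
#       model = retinanet_resnet50_fpn_v2(weights="DEFAULT")
#       model.head.classification_head = RetinaNetClassificationHead(
#           in_channels=model.backbone.out_channels,
#           num_anchors=model.head.classification_head.num_anchors,
#           num_classes=num_classes,
#           norm_layer=partial(nn.GroupNorm, 32),
#       )
#       return model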
# ----------------------------------------------------------------
# GLOBAL SETUP
# ----------------------------------------------------------------
# Create the model and load the best weights.
model = create_model(num_classes=NUM_CLASSES)
checkpoint = torch.load("outputs/best_model_79.pth", map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(DEVICE).eval()
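
# Optional sanity check: one dummy forward pass so that a weights/architecture
# mismatch fails at startup rather than on the first user request. Assumes the
# detector accepts a batched float tensor, as the inference code below does.
with torch.no_grad():
    _ = model(torch.zeros(1, 3, 64, 64, device=DEVICE))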
# Create one color per class index.
# (length matches len(CLASSES), including the background class)
COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))
# COLORS = [
# (255, 255, 0), # Cyan - background
# (50, 0, 255), # Red - buffalo
# (147, 20, 255), # Pink - elephant
# (0, 255, 0), # Green - rhino
# (238, 130, 238), # Violet - zebra
# ]
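# The random palette changes on every run; seed NumPy first
# (e.g. np.random.seed(42)) for stable per-class colors, or use a fixed
# palette like the commented one above.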
# ----------------------------------------------------------------
# HELPER FUNCTIONS
# ----------------------------------------------------------------
def inference_on_image(orig_image: np.ndarray, resize_dim=None, threshold=0.25):
"""
    Runs inference on a single image (NumPy array, OpenCV BGR order).
    - resize_dim: if not None, resize to (resize_dim, resize_dim) for inference
    - threshold: detection confidence threshold
    Returns: (image with bounding boxes drawn, FPS of the forward pass).
    """
    # Record the original size up front; it is needed for box rescaling and
    # for the FPS overlay position even when resize_dim is None.
    h_orig, w_orig = orig_image.shape[:2]
    image = orig_image.copy()
    # Optionally resize for inference.
    if resize_dim is not None:
        # Gradio sliders can deliver floats; OpenCV needs an integer size.
        resize_dim = int(resize_dim)
        image = cv2.resize(image, (resize_dim, resize_dim))
# Convert BGR to RGB, normalize [0..1]
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
# Move channels to front (C,H,W)
image_tensor = torch.tensor(image_rgb.transpose(2, 0, 1), dtype=torch.float).unsqueeze(0).to(DEVICE)
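    # Torchvision-style detectors take float tensors in [0, 1] and apply their
    # own mean/std normalization and resizing internally.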
start_time = time.time()
# Inference
with torch.no_grad():
outputs = model(image_tensor)
end_time = time.time()
# Get the current fps.
fps = 1 / (end_time - start_time)
fps_text = f"FPS: {fps:.2f}"
# Move outputs to CPU numpy
outputs = [{k: v.cpu() for k, v in t.items()} for t in outputs]
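    # Assuming a torchvision-style detector, each per-image dict holds "boxes"
    # as (xmin, ymin, xmax, ymax), "scores" sorted in descending order, and
    # integer "labels".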
boxes = outputs[0]["boxes"].numpy()
scores = outputs[0]["scores"].numpy()
labels = outputs[0]["labels"].numpy().astype(int)
    # Filter out detections below the confidence threshold.
    valid_idx = np.where(scores >= threshold)[0]
    boxes = boxes[valid_idx]
    labels = labels[valid_idx]
    # If we resized for inference, rescale boxes back to the original size.
    # Boxes stay float here so the rescaling does not truncate coordinates.
    if resize_dim is not None:
        h_new, w_new = resize_dim, resize_dim
        boxes[:, [0, 2]] = (boxes[:, [0, 2]] / w_new) * w_orig
        boxes[:, [1, 3]] = (boxes[:, [1, 3]] / h_new) * h_orig
    boxes = boxes.astype(int)
    # Draw bounding boxes and class labels.
    for box, label_idx in zip(boxes, labels):
        class_name = CLASSES[label_idx] if 0 <= label_idx < len(CLASSES) else str(label_idx)
        # Cast to a plain tuple of Python ints; OpenCV expects scalar colors.
        color = tuple(int(c) for c in COLORS[label_idx % len(COLORS)][::-1])  # BGR
        cv2.rectangle(orig_image, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), color, 5)
        cv2.putText(orig_image, class_name, (int(box[0]), int(box[1]) - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 3)
cv2.putText(
orig_image,
fps_text,
(int((w_orig / 2) - 50), 30),
cv2.FONT_HERSHEY_SIMPLEX,
0.8,
(0, 255, 0),
2,
cv2.LINE_AA,
)
return orig_image, fps
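
# Example standalone usage (a sketch; "some_image.jpg" is a placeholder path):
#   img = cv2.imread("some_image.jpg")
#   annotated, fps = inference_on_image(img, resize_dim=640, threshold=0.5)
#   cv2.imwrite("annotated.jpg", annotated)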
def inference_on_frame(frame: np.ndarray, resize_dim=None, threshold=0.25):
"""
Same as inference_on_image but for a single video frame.
Returns the processed frame with bounding boxes.
"""
return inference_on_image(frame, resize_dim, threshold)
# ----------------------------------------------------------------
# GRADIO FUNCTIONS
# ----------------------------------------------------------------
def img_inf(image_path, resize_dim, threshold):
"""
Gradio function for image inference.
:param image_path: File path from Gradio (uploaded image).
:param model_name: Selected model from Radio (not used if only one model).
Returns: A NumPy image array with bounding boxes.
"""
if image_path is None:
return None # No image provided
orig_image = cv2.imread(image_path) # BGR
if orig_image is None:
return None # Error reading image
result_image, _ = inference_on_image(orig_image, resize_dim=resize_dim, threshold=threshold)
# Return the image in RGB for Gradio's display
result_image_rgb = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)
return result_image_rgb
def vid_inf(video_path, resize_dim, threshold):
"""
Gradio function for video inference.
Processes each frame, draws bounding boxes, and writes to an output video.
Returns: (last_processed_frame, output_video_file_path)
"""
if video_path is None:
return None, None # No video provided
# Prepare input capture
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
return None, None
# Create an output file path
os.makedirs("inference_outputs/videos", exist_ok=True)
out_video_path = os.path.join("inference_outputs/videos", "video_output.mp4")
# out_video_path = "video_output.mp4"
# Get video properties
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)
fourcc = cv2.VideoWriter_fourcc(*"mp4v") # or 'XVID'
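    # "mp4v" (MPEG-4 Part 2) may not play inline in some browsers; if the
    # output does not render in Gradio, try the "avc1" (H.264) fourcc where
    # your OpenCV build supports it.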
    # Some containers report an FPS of 0; fall back to a sane default.
    if fps <= 0:
        fps = 20.0
out_writer = cv2.VideoWriter(out_video_path, fourcc, fps, (width, height))
frame_count = 0
total_fps = 0
while True:
ret, frame = cap.read()
if not ret:
break
# Inference on frame
processed_frame, frame_fps = inference_on_frame(frame, resize_dim=resize_dim, threshold=threshold)
total_fps += frame_fps
frame_count += 1
# Write the processed frame
out_writer.write(processed_frame)
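        # Yielding inside the loop streams each processed frame to the Gradio
        # UI as it is ready; the video path slot stays None until the end.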
yield cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB), None
    cap.release()
    out_writer.release()
    # Guard against empty videos to avoid a ZeroDivisionError.
    avg_fps = total_fps / frame_count if frame_count > 0 else 0.0
    print(f"Average FPS: {avg_fps:.3f}")
    yield None, out_video_path
# ----------------------------------------------------------------
# BUILD THE GRADIO INTERFACES
# ----------------------------------------------------------------
# Each Gradio interface needs its own component instances, so the sliders
# below are created separately for the image and video tabs.
resize_dim = gr.Slider(100, 1024, value=640, label="Resize Dimension", info="Resize image to this dimension")
threshold = gr.Slider(0, 1, value=0.5, label="Threshold", info="Confidence threshold for detection")
inputs_image = gr.Image(type="filepath", label="Input Image")
outputs_image = gr.Image(type="numpy", label="Output Image")
interface_image = gr.Interface(
fn=img_inf,
inputs=[inputs_image, resize_dim, threshold],
outputs=outputs_image,
title="Image Inference",
description="Upload your photo, select a model, and see the results!",
examples=[["examples/buffalo.jpg"], ["examples/zebra.jpg"]],
cache_examples=False,
)
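
# Note: cache_examples=False runs the example inputs through the model on
# click instead of precomputing outputs when the app starts.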
resize_dim = gr.Slider(100, 1024, value=640, label="Resize Dimension", info="Resize image to this dimension")
threshold = gr.Slider(0, 1, value=0.5, label="Threshold", info="Confidence threshold for detection")
input_video = gr.Video(label="Input Video")
# Two outputs: a streaming frame preview and the final output video file.
output_frame = gr.Image(type="numpy", label="Output (Last Processed Frame)")
output_video_file = gr.Video(format="mp4", label="Output Video")
interface_video = gr.Interface(
fn=vid_inf,
inputs=[input_video, resize_dim, threshold],
outputs=[output_frame, output_video_file],
title="Video Inference",
description="Upload your video and see the processed output!",
examples=[["examples/elephants.mp4"], ["examples/rhino.mp4"]],
cache_examples=False,
)
# Combine them in a Tabbed Interface
demo = gr.TabbedInterface(
    [interface_image, interface_video],
    tab_names=["Image", "Video"],
    title="Fine-Tuning RetinaNet for Wildlife Animal Detection",
    theme="gstaff/xkcd",
)

if __name__ == "__main__":
    demo.queue().launch()