import os
import cv2
import time
import torch
import gradio as gr
import numpy as np
# Make sure these are your local imports from your project.
from model import create_model
from config import NUM_CLASSES, DEVICE, CLASSES
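# For reference, `config.py` is assumed to expose something like the following.
# This is an illustrative sketch only; the real class list and device logic
# live in that file:
#
#   import torch
#   CLASSES = ["__background__", "buffalo", "elephant", "rhino", "zebra"]
#   NUM_CLASSES = len(CLASSES)
#   DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")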
# ----------------------------------------------------------------
# GLOBAL SETUP
# ----------------------------------------------------------------
# Create the model and load the best weights.
model = create_model(num_classes=NUM_CLASSES)
checkpoint = torch.load("outputs/best_model_79.pth", map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(DEVICE).eval()
# Create a colors array for each class index.
# (length matches len(CLASSES), including background if you wish).
COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))
# COLORS = [
# (255, 255, 0), # Cyan - background
# (50, 0, 255), # Red - buffalo
# (147, 20, 255), # Pink - elephant
# (0, 255, 0), # Green - rhino
# (238, 130, 238), # Violet - zebra
# ]
# ----------------------------------------------------------------
# HELPER FUNCTIONS
# ----------------------------------------------------------------
def inference_on_image(orig_image: np.ndarray, resize_dim=None, threshold=0.25):
    """
    Runs inference on a single image (OpenCV BGR NumPy array).
    - resize_dim: if not None, resize to (resize_dim, resize_dim) for inference
    - threshold: detection confidence threshold
    Returns: (image with bounding boxes drawn, FPS of the forward pass).
    """
    image = orig_image.copy()
    # Keep the original resolution; it is needed both for rescaling boxes
    # and for placing the FPS text, even when no resize is requested.
    h_orig, w_orig = orig_image.shape[:2]
    # Optionally resize for inference.
    if resize_dim is not None:
        resize_dim = int(resize_dim)  # The Gradio slider may pass a float.
        image = cv2.resize(image, (resize_dim, resize_dim))
    # Convert BGR to RGB and normalize to [0..1].
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    # Move channels to the front (C, H, W) and add a batch dimension.
    image_tensor = torch.tensor(image_rgb.transpose(2, 0, 1), dtype=torch.float).unsqueeze(0).to(DEVICE)
    start_time = time.time()
    # Inference.
    with torch.no_grad():
        outputs = model(image_tensor)
    end_time = time.time()
    # Get the current FPS.
    fps = 1 / (end_time - start_time)
    fps_text = f"FPS: {fps:.2f}"
    # Move outputs to CPU.
    outputs = [{k: v.cpu() for k, v in t.items()} for t in outputs]
    boxes = outputs[0]["boxes"].numpy()
    scores = outputs[0]["scores"].numpy()
    labels = outputs[0]["labels"].numpy().astype(int)
    # Filter out boxes with low confidence.
    valid_idx = np.where(scores >= threshold)[0]
    boxes = boxes[valid_idx].astype(int)
    labels = labels[valid_idx]
    # If we resized for inference, rescale boxes back to the orig_image size.
    if resize_dim is not None:
        h_new, w_new = resize_dim, resize_dim
        boxes[:, [0, 2]] = (boxes[:, [0, 2]] / w_new) * w_orig
        boxes[:, [1, 3]] = (boxes[:, [1, 3]] / h_new) * h_orig
    # Draw bounding boxes.
    for box, label_idx in zip(boxes, labels):
        class_name = CLASSES[label_idx] if 0 <= label_idx < len(CLASSES) else str(label_idx)
        # Convert the NumPy color row to a plain BGR tuple for OpenCV.
        color = tuple(int(c) for c in COLORS[label_idx % len(COLORS)][::-1])
        cv2.rectangle(orig_image, (box[0], box[1]), (box[2], box[3]), color, 5)
        cv2.putText(orig_image, class_name, (box[0], box[1] - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 3)
    cv2.putText(
        orig_image,
        fps_text,
        (int((w_orig / 2) - 50), 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.8,
        (0, 255, 0),
        2,
        cv2.LINE_AA,
    )
    return orig_image, fps
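# Standalone usage sketch (file path here is illustrative):
#   img = cv2.imread("examples/zebra.jpg")
#   result, fps = inference_on_image(img, resize_dim=640, threshold=0.5)
#   cv2.imwrite("zebra_result.jpg", result)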
def inference_on_frame(frame: np.ndarray, resize_dim=None, threshold=0.25):
    """
    Same as inference_on_image but for a single video frame.
    Returns the processed frame with bounding boxes and the frame FPS.
    """
    return inference_on_image(frame, resize_dim, threshold)
# ----------------------------------------------------------------
# GRADIO FUNCTIONS
# ----------------------------------------------------------------
def img_inf(image_path, resize_dim, threshold):
    """
    Gradio function for image inference.
    :param image_path: File path from Gradio (uploaded image).
    :param resize_dim: Dimension the image is resized to before inference.
    :param threshold: Confidence threshold for detections.
    Returns: A NumPy image array with bounding boxes.
    """
    if image_path is None:
        return None  # No image provided.
    orig_image = cv2.imread(image_path)  # BGR
    if orig_image is None:
        return None  # Error reading image.
    result_image, _ = inference_on_image(orig_image, resize_dim=resize_dim, threshold=threshold)
    # Return the image in RGB for Gradio's display.
    result_image_rgb = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)
    return result_image_rgb
def vid_inf(video_path, resize_dim, threshold):
    """
    Gradio function for video inference.
    Processes each frame, draws bounding boxes, and writes to an output video.
    Yields: (processed_frame, None) while running, then (None, output_video_file_path).
    """
    if video_path is None:
        yield None, None  # No video provided.
        return
    # Prepare input capture.
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        yield None, None
        return
    # Create an output file path.
    os.makedirs("inference_outputs/videos", exist_ok=True)
    out_video_path = os.path.join("inference_outputs/videos", "video_output.mp4")
    # out_video_path = "video_output.mp4"
    # Get video properties.
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # or "XVID"
    # If FPS is 0 (some containers report it that way), fall back to a default.
    if fps <= 0:
        fps = 20.0
    out_writer = cv2.VideoWriter(out_video_path, fourcc, fps, (width, height))
    frame_count = 0
    total_fps = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Inference on the current frame.
        processed_frame, frame_fps = inference_on_frame(frame, resize_dim=resize_dim, threshold=threshold)
        total_fps += frame_fps
        frame_count += 1
        # Write the processed frame and stream it to the UI.
        out_writer.write(processed_frame)
        yield cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB), None
    # Guard against empty videos before computing the average FPS.
    avg_fps = total_fps / frame_count if frame_count > 0 else 0.0
    cap.release()
    out_writer.release()
    print(f"Average FPS: {avg_fps:.3f}")
    yield None, out_video_path
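# Note: "mp4v"-encoded MP4 files do not play inline in every browser. If your
# OpenCV build supports H.264, swapping the FourCC is one option (illustrative):
#   fourcc = cv2.VideoWriter_fourcc(*"avc1")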
# ----------------------------------------------------------------
# BUILD THE GRADIO INTERFACES
# ----------------------------------------------------------------
# Components for the image-inference tab.
resize_dim = gr.Slider(100, 1024, value=640, label="Resize Dimension", info="Resize image to this dimension")
threshold = gr.Slider(0, 1, value=0.5, label="Threshold", info="Confidence threshold for detection")
inputs_image = gr.Image(type="filepath", label="Input Image")
outputs_image = gr.Image(type="numpy", label="Output Image")
interface_image = gr.Interface(
    fn=img_inf,
    inputs=[inputs_image, resize_dim, threshold],
    outputs=outputs_image,
    title="Image Inference",
    description="Upload your photo, adjust the sliders, and see the results!",
    examples=[["examples/buffalo.jpg"], ["examples/zebra.jpg"]],
    cache_examples=False,
)
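# Depending on your Gradio version, each example may need one value per input
# component, e.g. [["examples/buffalo.jpg", 640, 0.5], ["examples/zebra.jpg", 640, 0.5]].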
resize_dim = gr.Slider(100, 1024, value=640, label="Resize Dimension", info="Resize image to this dimension")
threshold = gr.Slider(0, 1, value=0.5, label="Threshold", info="Confidence threshold for detection")
input_video = gr.Video(label="Input Video")
# Output is a pair: (last_processed_frame, output_video_path)
output_frame = gr.Image(type="numpy", label="Output (Last Processed Frame)")
output_video_file = gr.Video(format="mp4", label="Output Video")
interface_video = gr.Interface(
    fn=vid_inf,
    inputs=[input_video, resize_dim, threshold],
    outputs=[output_frame, output_video_file],
    title="Video Inference",
    description="Upload your video and see the processed output!",
    examples=[["examples/elephants.mp4"], ["examples/rhino.mp4"]],
    cache_examples=False,
)
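# Because vid_inf is a generator (it yields per frame), Gradio streams each
# processed frame to the "Output (Last Processed Frame)" image while the video
# is still being written; the finished file path is only yielded at the end.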
# Combine them in a tabbed interface.
demo = gr.TabbedInterface(
    [interface_image, interface_video],
    tab_names=["Image", "Video"],
    title="FineTuning RetinaNet for Wildlife Animal Detection",
    theme="gstaff/xkcd",
)
demo.queue().launch()
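# Optional: if this module is ever imported elsewhere, the launch call can be
# guarded so the server only starts when the file is run directly:
#   if __name__ == "__main__":
#       demo.queue().launch()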