import os
import time

import cv2
import gradio as gr
import numpy as np
import torch

from model import create_model
from config import NUM_CLASSES, DEVICE, CLASSES
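
# Build the detector and load the trained checkpoint once at startup so every
# request reuses the same weights. Note: on PyTorch >= 2.6, torch.load defaults
# to weights_only=True; pass weights_only=False if your checkpoint stores more
# than plain tensors.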
model = create_model(num_classes=NUM_CLASSES)
checkpoint = torch.load("outputs/best_model_79.pth", map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(DEVICE).eval()
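
# One fixed random color per class so each class is drawn consistently within a run.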
COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))


def inference_on_image(orig_image: np.ndarray, resize_dim=None, threshold=0.25):
    """
    Runs inference on a single image (an OpenCV BGR image as a NumPy array).

    - resize_dim: if not None, the image is resized to (resize_dim, resize_dim)
      before inference and the boxes are mapped back to the original resolution.
    - threshold: detection confidence threshold.

    Returns: (image with bounding boxes drawn, FPS of the forward pass).
    """
    h_orig, w_orig = orig_image.shape[:2]
    image = orig_image.copy()

    if resize_dim is not None:
        resize_dim = int(resize_dim)  # Gradio sliders deliver floats; cv2.resize needs ints
        image = cv2.resize(image, (resize_dim, resize_dim))
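
    # Preprocess: BGR -> RGB, scale to [0, 1], HWC -> CHW, add a batch dimension.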
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    image_tensor = torch.tensor(image_rgb.transpose(2, 0, 1), dtype=torch.float).unsqueeze(0).to(DEVICE)
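
    # Time only the forward pass. On CUDA this wall-clock timing is approximate
    # unless you synchronize (torch.cuda.synchronize) around the call.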
    start_time = time.time()
    with torch.no_grad():
        outputs = model(image_tensor)
    end_time = time.time()

    fps = 1 / (end_time - start_time)
    fps_text = f"FPS: {fps:.2f}"
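
    # Move predictions to the CPU and unpack the first (and only) batch element.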
    outputs = [{k: v.cpu() for k, v in t.items()} for t in outputs]
    boxes = outputs[0]["boxes"].numpy()
    scores = outputs[0]["scores"].numpy()
    labels = outputs[0]["labels"].numpy().astype(int)
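
    # Keep only detections above the confidence threshold.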
    valid_idx = np.where(scores >= threshold)[0]
    boxes = boxes[valid_idx].astype(int)
    labels = labels[valid_idx]
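
    # If inference ran on a resized copy, map the boxes back to original coordinates.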
    if resize_dim is not None:
        boxes[:, [0, 2]] = (boxes[:, [0, 2]] / resize_dim) * w_orig
        boxes[:, [1, 3]] = (boxes[:, [1, 3]] / resize_dim) * h_orig
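
    # Draw each detection; the [::-1] reverses the color channel order for OpenCV (BGR).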
    for box, label_idx in zip(boxes, labels):
        class_name = CLASSES[label_idx] if 0 <= label_idx < len(CLASSES) else str(label_idx)
        color = tuple(COLORS[label_idx % len(COLORS)][::-1])
        cv2.rectangle(orig_image, (box[0], box[1]), (box[2], box[3]), color, 5)
        cv2.putText(orig_image, class_name, (box[0], box[1] - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 3)

    # Overlay the FPS once, centered near the top of the image.
    cv2.putText(
        orig_image,
        fps_text,
        (int((w_orig / 2) - 50), 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.8,
        (0, 255, 0),
        2,
        cv2.LINE_AA,
    )
    return orig_image, fps


def inference_on_frame(frame: np.ndarray, resize_dim=None, threshold=0.25):
    """
    Same as inference_on_image, but named for use on a single video frame.

    Returns: (processed frame with bounding boxes, FPS of the forward pass).
    """
    return inference_on_image(frame, resize_dim, threshold)


def img_inf(image_path, resize_dim, threshold):
    """
    Gradio handler for image inference.

    :param image_path: File path from Gradio (uploaded image).
    :param resize_dim: Dimension the image is resized to before inference.
    :param threshold: Confidence threshold for detections.

    Returns: An RGB NumPy image array with bounding boxes drawn.
    """
    if image_path is None:
        return None
    orig_image = cv2.imread(image_path)
    if orig_image is None:
        return None

    result_image, _ = inference_on_image(orig_image, resize_dim=resize_dim, threshold=threshold)

    # OpenCV returns BGR; Gradio expects RGB.
    result_image_rgb = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)
    return result_image_rgb


def vid_inf(video_path, resize_dim, threshold):
    """
    Gradio generator for video inference.

    Processes each frame, draws bounding boxes, and writes an output video.

    Yields: (processed_frame, None) per frame, then (None, output_video_file_path).
    """
    # In a generator, `return None, None` would silently discard the values,
    # so yield the empty result before returning.
    if video_path is None:
        yield None, None
        return

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        yield None, None
        return

    os.makedirs("inference_outputs/videos", exist_ok=True)
    out_video_path = os.path.join("inference_outputs/videos", "video_output.mp4")
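
    # The writer must match the source resolution, since frames are annotated at full size.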
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")

    # Some containers report a non-positive FPS; fall back to a default.
    if fps <= 0:
        fps = 20.0

    out_writer = cv2.VideoWriter(out_video_path, fourcc, fps, (width, height))

    frame_count = 0
    total_fps = 0

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        processed_frame, frame_fps = inference_on_frame(frame, resize_dim=resize_dim, threshold=threshold)
        total_fps += frame_fps
        frame_count += 1

        out_writer.write(processed_frame)
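        # Stream the latest annotated frame to the UI while processing continues.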
        yield cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB), None

    cap.release()
    out_writer.release()

    # Guard against zero-frame videos before computing the average FPS.
    if frame_count > 0:
        print(f"Average FPS: {total_fps / frame_count:.3f}")
    yield None, out_video_path
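

# Build the Gradio UI: one tab for images, one for videos. Component instances
# are created per interface rather than shared between them.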
resize_dim = gr.Slider(100, 1024, value=640, label="Resize Dimension", info="Resize image to this dimension")
threshold = gr.Slider(0, 1, value=0.5, label="Threshold", info="Confidence threshold for detection")
inputs_image = gr.Image(type="filepath", label="Input Image")
outputs_image = gr.Image(type="numpy", label="Output Image")

interface_image = gr.Interface(
    fn=img_inf,
    inputs=[inputs_image, resize_dim, threshold],
    outputs=outputs_image,
    title="Image Inference",
    description="Upload your photo, adjust the sliders, and see the results!",
    # Each example row supplies a value for every input component.
    examples=[["examples/buffalo.jpg", 640, 0.5], ["examples/zebra.jpg", 640, 0.5]],
    cache_examples=False,
)

resize_dim = gr.Slider(100, 1024, value=640, label="Resize Dimension", info="Resize image to this dimension")
threshold = gr.Slider(0, 1, value=0.5, label="Threshold", info="Confidence threshold for detection")
input_video = gr.Video(label="Input Video")

output_frame = gr.Image(type="numpy", label="Output (Last Processed Frame)")
output_video_file = gr.Video(format="mp4", label="Output Video")

interface_video = gr.Interface(
    fn=vid_inf,
    inputs=[input_video, resize_dim, threshold],
    outputs=[output_frame, output_video_file],
    title="Video Inference",
    description="Upload your video and see the processed output!",
    examples=[["examples/elephants.mp4", 640, 0.5], ["examples/rhino.mp4", 640, 0.5]],
    cache_examples=False,
)

demo = gr.TabbedInterface(
    [interface_image, interface_video],
    tab_names=["Image", "Video"],
    title="Fine-Tuning RetinaNet for Wildlife Animal Detection",
    theme="gstaff/xkcd",
)
demo.queue().launch()