import os
import cv2
import time
import torch
import gradio as gr
import numpy as np

# Make sure these are your local imports from your project.
from model import create_model
from config import NUM_CLASSES, DEVICE, CLASSES
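
# For reference, a hedged sketch of what model.py's create_model is assumed to
# provide: a torchvision RetinaNet whose classification head is rebuilt for
# NUM_CLASSES. Illustrative only; the actual project code may differ.
#
# from torchvision.models.detection import retinanet_resnet50_fpn_v2
# from torchvision.models.detection.retinanet import RetinaNetClassificationHead
#
# def create_model(num_classes):
#     model = retinanet_resnet50_fpn_v2(weights="DEFAULT")
#     num_anchors = model.head.classification_head.num_anchors
#     model.head.classification_head = RetinaNetClassificationHead(
#         in_channels=256, num_anchors=num_anchors, num_classes=num_classes
#     )
#     return model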

# ----------------------------------------------------------------
# GLOBAL SETUP
# ----------------------------------------------------------------
# Create the model and load the best weights.
model = create_model(num_classes=NUM_CLASSES)
checkpoint = torch.load("outputs/best_model_79.pth", map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(DEVICE).eval()

# Create one random color per class index
# (length matches len(CLASSES), including the background class).
COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))

# COLORS = [
#     (255, 255, 0),  # Cyan - background
#     (50, 0, 255),  # Red - buffalo
#     (147, 20, 255),  # Pink - elephant
#     (0, 255, 0),  # Green - rhino
#     (238, 130, 238),  # Violet - zebra
# ]


# ----------------------------------------------------------------
# HELPER FUNCTIONS
# ----------------------------------------------------------------
def inference_on_image(orig_image: np.ndarray, resize_dim=None, threshold=0.25):
    """
    Runs inference on a single image (OpenCV BGR or NumPy array).
    - resize_dim: if not None, we resize to (resize_dim, resize_dim)
    - threshold: detection confidence threshold
    Returns: (processed image with bounding boxes drawn, FPS of the forward pass).
    """
    image = orig_image.copy()
    # Optionally resize for inference.
    if resize_dim is not None:
        image = cv2.resize(image, (resize_dim, resize_dim))

    # Convert BGR to RGB, normalize [0..1]
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    # Move channels to front (C,H,W)
    image_tensor = torch.tensor(image_rgb.transpose(2, 0, 1), dtype=torch.float).unsqueeze(0).to(DEVICE)
    start_time = time.time()
    # Inference
    with torch.no_grad():
        outputs = model(image_tensor)
    end_time = time.time()
    # Get the current fps.
    fps = 1 / (end_time - start_time)
    fps_text = f"FPS: {fps:.2f}"
    # Move outputs to CPU numpy
    outputs = [{k: v.cpu() for k, v in t.items()} for t in outputs]
    boxes = outputs[0]["boxes"].numpy()
    scores = outputs[0]["scores"].numpy()
    labels = outputs[0]["labels"].numpy().astype(int)

    # Filter out detections below the confidence threshold
    valid_idx = np.where(scores >= threshold)[0]
    boxes = boxes[valid_idx]
    labels = labels[valid_idx]

    # If we resized for inference, rescale boxes back to the orig_image size
    # before casting them to integer pixel coordinates
    h_orig, w_orig = orig_image.shape[:2]
    if resize_dim is not None:
        boxes[:, [0, 2]] = (boxes[:, [0, 2]] / resize_dim) * w_orig
        boxes[:, [1, 3]] = (boxes[:, [1, 3]] / resize_dim) * h_orig
    boxes = boxes.astype(int)

    # Draw bounding boxes and class labels
    for box, label_idx in zip(boxes, labels):
        class_name = CLASSES[label_idx] if 0 <= label_idx < len(CLASSES) else str(label_idx)
        color = tuple(int(c) for c in COLORS[label_idx % len(COLORS)][::-1])  # RGB -> BGR
        cv2.rectangle(orig_image, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), color, 5)
        cv2.putText(orig_image, class_name, (int(box[0]), int(box[1]) - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 3)

    # Overlay the measured FPS once, near the top centre of the image
    cv2.putText(
        orig_image,
        fps_text,
        (int((w_orig / 2) - 50), 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.8,
        (0, 255, 0),
        2,
        cv2.LINE_AA,
    )
    return orig_image, fps


def inference_on_frame(frame: np.ndarray, resize_dim=None, threshold=0.25):
    """
    Same as inference_on_image but for a single video frame.
    Returns: (processed frame with bounding boxes drawn, FPS for that frame).
    """
    return inference_on_image(frame, resize_dim, threshold)
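
# Example standalone usage of inference_on_image (paths here are illustrative):
#
#   img = cv2.imread("examples/buffalo.jpg")
#   result, fps = inference_on_image(img, resize_dim=640, threshold=0.5)
#   cv2.imwrite("buffalo_result.jpg", result)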


# ----------------------------------------------------------------
# GRADIO FUNCTIONS
# ----------------------------------------------------------------


def img_inf(image_path, resize_dim, threshold):
    """
    Gradio function for image inference.
    :param image_path: File path from Gradio (uploaded image).
    :param resize_dim: Inference resize dimension from the slider.
    :param threshold: Confidence threshold from the slider.
    Returns: A NumPy image array (RGB) with bounding boxes drawn.
    """
    if image_path is None:
        return None  # No image provided
    orig_image = cv2.imread(image_path)  # BGR
    if orig_image is None:
        return None  # Error reading image

    result_image, _ = inference_on_image(orig_image, resize_dim=resize_dim, threshold=threshold)
    # Return the image in RGB for Gradio's display
    result_image_rgb = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)
    return result_image_rgb


def vid_inf(video_path, resize_dim, threshold):
    """
    Gradio function for video inference.
    Processes each frame, draws bounding boxes, and writes to an output video.
    Returns: (last_processed_frame, output_video_file_path)
    """
    if video_path is None:
        yield None, None  # No video provided
        return

    # Prepare input capture
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        yield None, None  # Could not open the video file
        return

    # Create an output file path
    os.makedirs("inference_outputs/videos", exist_ok=True)
    out_video_path = os.path.join("inference_outputs/videos", "video_output.mp4")
    # out_video_path = "video_output.mp4"

    # Get video properties
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # or 'XVID'

    # If the FPS is reported as 0 (some containers do this), fall back to a sane default
    if fps <= 0:
        fps = 20.0

    out_writer = cv2.VideoWriter(out_video_path, fourcc, fps, (width, height))

    frame_count = 0
    total_fps = 0

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Inference on frame
        processed_frame, frame_fps = inference_on_frame(frame, resize_dim=resize_dim, threshold=threshold)
        total_fps += frame_fps
        frame_count += 1

        # Write the processed frame
        out_writer.write(processed_frame)
        yield cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB), None

    avg_fps = total_fps / frame_count if frame_count > 0 else 0.0

    cap.release()
    out_writer.release()
    print(f"Average FPS: {avg_fps:.3f}")
    yield None, out_video_path


# ----------------------------------------------------------------
# BUILD THE GRADIO INTERFACES
# ----------------------------------------------------------------

# Shared controls for the image tab: inference resize dimension and confidence threshold.
# (Separate slider instances are created again below for the video tab.)
resize_dim = gr.Slider(100, 1024, value=640, label="Resize Dimension", info="Resize image to this dimension")
threshold = gr.Slider(0, 1, value=0.5, label="Threshold", info="Confidence threshold for detection")
inputs_image = gr.Image(type="filepath", label="Input Image")
outputs_image = gr.Image(type="numpy", label="Output Image")

interface_image = gr.Interface(
    fn=img_inf,
    inputs=[inputs_image, resize_dim, threshold],
    outputs=outputs_image,
    title="Image Inference",
    description="Upload your photo, adjust the resize dimension and confidence threshold, and see the results!",
    examples=[["examples/buffalo.jpg", 640, 0.5], ["examples/zebra.jpg", 640, 0.5]],
    cache_examples=False,
)

resize_dim = gr.Slider(100, 1024, value=640, label="Resize Dimension", info="Resize image to this dimension")
threshold = gr.Slider(0, 1, value=0.5, label="Threshold", info="Confidence threshold for detection")
input_video = gr.Video(label="Input Video")

# Output is a pair: (last_processed_frame, output_video_path)
output_frame = gr.Image(type="numpy", label="Output (Last Processed Frame)")
output_video_file = gr.Video(format="mp4", label="Output Video")

interface_video = gr.Interface(
    fn=vid_inf,
    inputs=[input_video, resize_dim, threshold],
    outputs=[output_frame, output_video_file],
    title="Video Inference",
    description="Upload your video and see the processed output!",
    examples=[["examples/elephants.mp4", 640, 0.5], ["examples/rhino.mp4", 640, 0.5]],
    cache_examples=False,
)

# Combine them in a Tabbed Interface
demo = gr.TabbedInterface(
    [interface_image, interface_video],
    tab_names=["Image", "Video"],
    title="FineTuning RetinaNet for Wildlife Animal Detection",
    theme="gstaff/xkcd",
)

if __name__ == "__main__":
    demo.queue().launch()
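
# To run the demo locally (assuming this file is saved as app.py next to
# model.py, config.py, outputs/, and examples/):
#   python app.py
# Gradio prints a local URL (http://127.0.0.1:7860 by default) to open in a browser.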