Upload 9 files
- .gitattributes +4 -0
- app.py +235 -0
- config.py +25 -0
- examples/buffalo.jpg +3 -0
- examples/elephants.mp4 +3 -0
- examples/rhino.mp4 +3 -0
- examples/zebra.jpg +3 -0
- model.py +33 -0
- outputs/best_model_79.pth +3 -0
- requirements.txt +5 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/buffalo.jpg filter=lfs diff=lfs merge=lfs -text
+examples/elephants.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/rhino.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/zebra.jpg filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,235 @@
import os
import cv2
import time
import torch
import gradio as gr
import numpy as np

# Local imports from this project.
from model import create_model
from config import NUM_CLASSES, DEVICE, CLASSES

# ----------------------------------------------------------------
# GLOBAL SETUP
# ----------------------------------------------------------------
# Create the model and load the best weights.
model = create_model(num_classes=NUM_CLASSES)
checkpoint = torch.load("outputs/best_model_79.pth", map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(DEVICE).eval()

# One random color per class index
# (length matches len(CLASSES), including the background class).
COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))

# COLORS = [
#     (255, 255, 0),    # Cyan - background
#     (50, 0, 255),     # Red - buffalo
#     (147, 20, 255),   # Pink - elephant
#     (0, 255, 0),      # Green - rhino
#     (238, 130, 238),  # Violet - zebra
# ]


# ----------------------------------------------------------------
# HELPER FUNCTIONS
# ----------------------------------------------------------------
def inference_on_image(orig_image: np.ndarray, resize_dim=None, threshold=0.25):
    """
    Runs inference on a single image (OpenCV BGR NumPy array).
    - resize_dim: if not None, the image is resized to (resize_dim, resize_dim) for inference
    - threshold: detection confidence threshold
    Returns: (image with bounding boxes drawn, FPS of the forward pass).
    """
    image = orig_image.copy()
    h_orig, w_orig = orig_image.shape[:2]
    # Optionally resize for inference.
    if resize_dim is not None:
        image = cv2.resize(image, (resize_dim, resize_dim))

    # Convert BGR to RGB and normalize to [0, 1].
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    # Move channels to front (C, H, W) and add a batch dimension.
    image_tensor = torch.tensor(image_rgb.transpose(2, 0, 1), dtype=torch.float).unsqueeze(0).to(DEVICE)
    start_time = time.time()
    # Inference.
    with torch.no_grad():
        outputs = model(image_tensor)
    end_time = time.time()
    # FPS for this single forward pass.
    fps = 1 / (end_time - start_time)
    fps_text = f"FPS: {fps:.2f}"
    # Move outputs to CPU NumPy arrays.
    outputs = [{k: v.cpu() for k, v in t.items()} for t in outputs]
    boxes = outputs[0]["boxes"].numpy()
    scores = outputs[0]["scores"].numpy()
    labels = outputs[0]["labels"].numpy().astype(int)

    # Filter out boxes with low confidence.
    valid_idx = np.where(scores >= threshold)[0]
    boxes = boxes[valid_idx].astype(int)
    labels = labels[valid_idx]

    # If we resized for inference, rescale boxes back to the original image size.
    if resize_dim is not None:
        boxes[:, [0, 2]] = (boxes[:, [0, 2]] / resize_dim) * w_orig
        boxes[:, [1, 3]] = (boxes[:, [1, 3]] / resize_dim) * h_orig

    # Draw bounding boxes and class names on the original image.
    for box, label_idx in zip(boxes, labels):
        class_name = CLASSES[label_idx] if 0 <= label_idx < len(CLASSES) else str(label_idx)
        color = COLORS[label_idx % len(COLORS)][::-1]  # BGR color
        cv2.rectangle(orig_image, (box[0], box[1]), (box[2], box[3]), color, 5)
        cv2.putText(orig_image, class_name, (box[0], box[1] - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 3)
    # Overlay the FPS near the top center of the image.
    cv2.putText(
        orig_image,
        fps_text,
        (int((w_orig / 2) - 50), 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.8,
        (0, 255, 0),
        2,
        cv2.LINE_AA,
    )
    return orig_image, fps


def inference_on_frame(frame: np.ndarray, resize_dim=None, threshold=0.25):
    """
    Same as inference_on_image but for a single video frame.
    Returns the processed frame with bounding boxes and the FPS for that frame.
    """
    return inference_on_image(frame, resize_dim, threshold)


# ----------------------------------------------------------------
# GRADIO FUNCTIONS
# ----------------------------------------------------------------


def img_inf(image_path, resize_dim, threshold):
    """
    Gradio function for image inference.
    :param image_path: File path from Gradio (uploaded image).
    :param resize_dim: Dimension the image is resized to before inference.
    :param threshold: Confidence threshold for detections.
    Returns: A NumPy image array (RGB) with bounding boxes.
    """
    if image_path is None:
        return None  # No image provided.
    orig_image = cv2.imread(image_path)  # BGR
    if orig_image is None:
        return None  # Error reading image.

    result_image, _ = inference_on_image(orig_image, resize_dim=resize_dim, threshold=threshold)
    # Return the image in RGB for Gradio's display.
    result_image_rgb = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)
    return result_image_rgb


def vid_inf(video_path, resize_dim, threshold):
    """
    Gradio function for video inference.
    Processes each frame, draws bounding boxes, and writes to an output video.
    Yields: (processed_frame, None) for every frame, then (None, output_video_file_path).
    """
    if video_path is None:
        yield None, None  # No video provided.
        return

    # Prepare input capture.
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        yield None, None
        return

    # Create an output file path.
    os.makedirs("inference_outputs/videos", exist_ok=True)
    out_video_path = os.path.join("inference_outputs/videos", "video_output.mp4")
    # out_video_path = "video_output.mp4"

    # Get video properties.
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # or 'XVID'

    # If FPS is reported as 0 (some containers do this), fall back to a default.
    if fps <= 0:
        fps = 20.0

    out_writer = cv2.VideoWriter(out_video_path, fourcc, fps, (width, height))

    frame_count = 0
    total_fps = 0

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Inference on the frame.
        processed_frame, frame_fps = inference_on_frame(frame, resize_dim=resize_dim, threshold=threshold)
        total_fps += frame_fps
        frame_count += 1

        # Write the processed frame and stream it to the UI.
        out_writer.write(processed_frame)
        yield cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB), None

    cap.release()
    out_writer.release()

    if frame_count > 0:
        avg_fps = total_fps / frame_count
        print(f"Average FPS: {avg_fps:.3f}")
    yield None, out_video_path


# ----------------------------------------------------------------
# BUILD THE GRADIO INTERFACES
# ----------------------------------------------------------------

# Components for the image tab.
resize_dim = gr.Slider(100, 1024, value=640, label="Resize Dimension", info="Resize image to this dimension")
threshold = gr.Slider(0, 1, value=0.5, label="Threshold", info="Confidence threshold for detection")
inputs_image = gr.Image(type="filepath", label="Input Image")
outputs_image = gr.Image(type="numpy", label="Output Image")

interface_image = gr.Interface(
    fn=img_inf,
    inputs=[inputs_image, resize_dim, threshold],
    outputs=outputs_image,
    title="Image Inference",
    description="Upload your photo, adjust the resize dimension and threshold, and see the results!",
    examples=[["examples/buffalo.jpg"], ["examples/zebra.jpg"]],
    cache_examples=False,
)

# Fresh components for the video tab.
resize_dim = gr.Slider(100, 1024, value=640, label="Resize Dimension", info="Resize image to this dimension")
threshold = gr.Slider(0, 1, value=0.5, label="Threshold", info="Confidence threshold for detection")
input_video = gr.Video(label="Input Video")

# Output is a pair: (last processed frame, output video path).
output_frame = gr.Image(type="numpy", label="Output (Last Processed Frame)")
output_video_file = gr.Video(format="mp4", label="Output Video")

interface_video = gr.Interface(
    fn=vid_inf,
    inputs=[input_video, resize_dim, threshold],
    outputs=[output_frame, output_video_file],
    title="Video Inference",
    description="Upload your video and see the processed output!",
    examples=[["examples/elephants.mp4"], ["examples/rhino.mp4"]],
    cache_examples=False,
)

# Combine both interfaces in a tabbed app and launch it.
demo = (
    gr.TabbedInterface(
        [interface_image, interface_video],
        tab_names=["Image", "Video"],
        title="Fine-Tuning RetinaNet for Wildlife Animal Detection",
        theme="gstaff/xkcd",
    )
    .queue()
    .launch()
)
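The helpers above can also be exercised without the Gradio UI. Below is a minimal sketch (not part of the upload) that loads the same checkpoint and prints the detections for one of the bundled example images; it assumes it is run from the repository root and repeats the preprocessing from inference_on_image instead of importing app, since importing app would also build and launch the interface.

# Standalone sketch (assumption: run from the repo root; not part of the upload).
import cv2
import torch
import numpy as np

from model import create_model
from config import NUM_CLASSES, DEVICE, CLASSES

model = create_model(num_classes=NUM_CLASSES)
checkpoint = torch.load("outputs/best_model_79.pth", map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(DEVICE).eval()

image = cv2.imread("examples/buffalo.jpg")  # BGR, as in app.py
rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
tensor = torch.tensor(rgb.transpose(2, 0, 1)).unsqueeze(0).to(DEVICE)

with torch.no_grad():
    out = model(tensor)[0]

# Print class name, score, and box for detections above a 0.5 threshold.
keep = out["scores"] >= 0.5
for box, label, score in zip(out["boxes"][keep], out["labels"][keep], out["scores"][keep]):
    print(CLASSES[int(label)], f"{score:.2f}", [int(v) for v in box])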
config.py
ADDED
@@ -0,0 +1,25 @@
import torch

BATCH_SIZE = 8  # Increase / decrease according to GPU memory.
RESIZE_TO = 640  # Resize the image for training and transforms.
NUM_EPOCHS = 60  # Number of epochs to train for.
NUM_WORKERS = 4  # Number of parallel workers for data loading.

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Training images and labels files directory.
TRAIN_DIR = "data/train"
# Validation images and labels files directory.
VALID_DIR = "data/valid"

# Classes: index 0 is reserved for the background.
CLASSES = ["__background__", "buffalo", "elephant", "rhino", "zebra"]

NUM_CLASSES = len(CLASSES)

# Whether to visualize images after creating the data loaders.
VISUALIZE_TRANSFORMED_IMAGES = True

# Location to save the model and plots.
OUT_DIR = "outputs"
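Because index 0 of CLASSES is the background, a detector label ID indexes directly into this list. A tiny illustrative sketch (the label IDs here are hypothetical, not part of the upload):

from config import CLASSES, NUM_CLASSES

print(NUM_CLASSES)  # 5: "__background__" plus the four animal classes

label_ids = [1, 4, 2]  # hypothetical detector output
print([CLASSES[i] for i in label_ids])  # ['buffalo', 'zebra', 'elephant']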
examples/buffalo.jpg
ADDED
(Image stored via Git LFS; preview omitted.)
examples/elephants.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f7cd66a941d8883505826e1b191c6c45f21a2f9cad05301f1ce62da676b431a3
size 3617117
examples/rhino.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c151ff03aa5a01c4604ccc7cca9dd1518eb8faf807a320001f7efb2598effcef
size 9404729
examples/zebra.jpg
ADDED
(Image stored via Git LFS; preview omitted.)
model.py
ADDED
@@ -0,0 +1,33 @@
import torchvision
import torch

from functools import partial
from torchvision.models.detection import RetinaNet_ResNet50_FPN_V2_Weights
from torchvision.models.detection.retinanet import RetinaNetClassificationHead
from config import NUM_CLASSES


def create_model(num_classes=91):
    """
    Creates a RetinaNet-ResNet50-FPN v2 model pre-trained on COCO.
    Replaces the classification head for the required number of classes.
    """
    model = torchvision.models.detection.retinanet_resnet50_fpn_v2(weights=RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1)
    num_anchors = model.head.classification_head.num_anchors

    # Replace the classification head.
    model.head.classification_head = RetinaNetClassificationHead(
        in_channels=256, num_anchors=num_anchors, num_classes=num_classes, norm_layer=partial(torch.nn.GroupNorm, 32)
    )
    return model


if __name__ == "__main__":
    model = create_model(num_classes=NUM_CLASSES)
    print(model)
    # Total parameters.
    total_params = sum(p.numel() for p in model.parameters())
    print(f"{total_params:,} total parameters.")
    # Trainable parameters.
    total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{total_trainable_params:,} training parameters.")
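A quick way to sanity-check the replaced classification head is a dummy forward pass in eval mode: torchvision detection models take a list of image tensors and return one dict per image with boxes, scores, and labels. A minimal sketch (not part of the upload; note that create_model downloads the COCO-pretrained backbone weights on first use):

import torch

from model import create_model
from config import NUM_CLASSES

model = create_model(num_classes=NUM_CLASSES).eval()

# One dummy 3x640x640 image; detection models expect a list of CHW tensors.
dummy = [torch.rand(3, 640, 640)]
with torch.no_grad():
    preds = model(dummy)

print(list(preds[0].keys()))  # ['boxes', 'scores', 'labels']
if preds[0]["labels"].numel() > 0:
    print(int(preds[0]["labels"].max()))  # label IDs stay below NUM_CLASSES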
outputs/best_model_79.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:82a9d5634acb4adeecdf8417cae2d26c6389ca7279472bc65fbc907a28300047
size 146001704
requirements.txt
ADDED
@@ -0,0 +1,5 @@
opencv-python==4.11.0.86
torch==2.6.0
torchvision==0.21.0
torchaudio==2.6.0
gradio==5.18.0