ankanpy committed
Commit b959f6e · verified · 1 Parent(s): 41fe7af

Upload 9 files

.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ examples/buffalo.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/elephants.mp4 filter=lfs diff=lfs merge=lfs -text
+ examples/rhino.mp4 filter=lfs diff=lfs merge=lfs -text
+ examples/zebra.jpg filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,235 @@
+ import os
+ import cv2
+ import time
+ import torch
+ import gradio as gr
+ import numpy as np
+
+ # Local imports from this project.
+ from model import create_model
+ from config import NUM_CLASSES, DEVICE, CLASSES
+
+ # ----------------------------------------------------------------
+ # GLOBAL SETUP
+ # ----------------------------------------------------------------
+ # Create the model and load the best weights.
+ model = create_model(num_classes=NUM_CLASSES)
+ checkpoint = torch.load("outputs/best_model_79.pth", map_location=DEVICE)
+ model.load_state_dict(checkpoint["model_state_dict"])
+ model.to(DEVICE).eval()
+
+ # Create one color per class index
+ # (length matches len(CLASSES), including the background class).
+ COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))
+
+ # COLORS = [
+ #     (255, 255, 0),    # Cyan - background
+ #     (50, 0, 255),     # Red - buffalo
+ #     (147, 20, 255),   # Pink - elephant
+ #     (0, 255, 0),      # Green - rhino
+ #     (238, 130, 238),  # Violet - zebra
+ # ]
+
+
+ # ----------------------------------------------------------------
+ # HELPER FUNCTIONS
+ # ----------------------------------------------------------------
+ def inference_on_image(orig_image: np.ndarray, resize_dim=None, threshold=0.25):
+     """
+     Runs inference on a single image (OpenCV BGR NumPy array).
+     - resize_dim: if not None, resize to (resize_dim, resize_dim) for inference
+     - threshold: detection confidence threshold
+     Returns: (image with bounding boxes drawn, FPS of this forward pass)
+     """
+     image = orig_image.copy()
+     h_orig, w_orig = orig_image.shape[:2]
+     # Optionally resize for inference.
+     if resize_dim is not None:
+         image = cv2.resize(image, (resize_dim, resize_dim))
+
+     # Convert BGR to RGB and normalize to [0, 1].
+     image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
+     # Move channels to the front (C, H, W) and add a batch dimension.
+     image_tensor = torch.tensor(image_rgb.transpose(2, 0, 1), dtype=torch.float).unsqueeze(0).to(DEVICE)
+     start_time = time.time()
+     # Inference.
+     with torch.no_grad():
+         outputs = model(image_tensor)
+     end_time = time.time()
+     # FPS of this single forward pass.
+     fps = 1 / (end_time - start_time)
+     fps_text = f"FPS: {fps:.2f}"
+     # Move outputs to CPU NumPy.
+     outputs = [{k: v.cpu() for k, v in t.items()} for t in outputs]
+     boxes = outputs[0]["boxes"].numpy()
+     scores = outputs[0]["scores"].numpy()
+     labels = outputs[0]["labels"].numpy().astype(int)
+
+     # Filter out boxes with low confidence.
+     valid_idx = np.where(scores >= threshold)[0]
+     boxes = boxes[valid_idx].astype(int)
+     labels = labels[valid_idx]
+
+     # If we resized for inference, rescale boxes back to the original image size.
+     if resize_dim is not None:
+         h_new, w_new = resize_dim, resize_dim
+         boxes[:, [0, 2]] = (boxes[:, [0, 2]] / w_new) * w_orig
+         boxes[:, [1, 3]] = (boxes[:, [1, 3]] / h_new) * h_orig
+
+     # Draw bounding boxes and class labels.
+     for box, label_idx in zip(boxes, labels):
+         class_name = CLASSES[label_idx] if 0 <= label_idx < len(CLASSES) else str(label_idx)
+         color = COLORS[label_idx % len(COLORS)][::-1]  # BGR color
+         cv2.rectangle(orig_image, (box[0], box[1]), (box[2], box[3]), color, 5)
+         cv2.putText(orig_image, class_name, (box[0], box[1] - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 3)
+     cv2.putText(
+         orig_image,
+         fps_text,
+         (int((w_orig / 2) - 50), 30),
+         cv2.FONT_HERSHEY_SIMPLEX,
+         0.8,
+         (0, 255, 0),
+         2,
+         cv2.LINE_AA,
+     )
+     return orig_image, fps
+
+
+ def inference_on_frame(frame: np.ndarray, resize_dim=None, threshold=0.25):
+     """
+     Same as inference_on_image but for a single video frame.
+     Returns the processed frame with bounding boxes and its FPS.
+     """
+     return inference_on_image(frame, resize_dim, threshold)
+
+
+ # ----------------------------------------------------------------
+ # GRADIO FUNCTIONS
+ # ----------------------------------------------------------------
+
+
+ def img_inf(image_path, resize_dim, threshold):
+     """
+     Gradio function for image inference.
+     :param image_path: File path from Gradio (uploaded image).
+     :param resize_dim: Dimension to resize the image to before inference.
+     :param threshold: Confidence threshold for keeping detections.
+     Returns: A NumPy image array (RGB) with bounding boxes.
+     """
+     if image_path is None:
+         return None  # No image provided
+     orig_image = cv2.imread(image_path)  # BGR
+     if orig_image is None:
+         return None  # Error reading image
+
+     result_image, _ = inference_on_image(orig_image, resize_dim=resize_dim, threshold=threshold)
+     # Return the image in RGB for Gradio's display.
+     result_image_rgb = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)
+     return result_image_rgb
+
+
+ def vid_inf(video_path, resize_dim, threshold):
+     """
+     Gradio generator function for video inference.
+     Processes each frame, draws bounding boxes, and writes to an output video.
+     Yields: (processed_frame, None) for every frame, then (None, output_video_file_path) at the end.
+     """
+     if video_path is None:
+         yield None, None  # No video provided
+         return
+
+     # Prepare input capture.
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         yield None, None
+         return
+
+     # Create an output file path.
+     os.makedirs("inference_outputs/videos", exist_ok=True)
+     out_video_path = os.path.join("inference_outputs/videos", "video_output.mp4")
+     # out_video_path = "video_output.mp4"
+
+     # Get video properties.
+     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # or 'XVID'
+
+     # Some containers report an FPS of 0; fall back to a sane default.
+     if fps <= 0:
+         fps = 20.0
+
+     out_writer = cv2.VideoWriter(out_video_path, fourcc, fps, (width, height))
+
+     frame_count = 0
+     total_fps = 0
+
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             break
+         # Inference on the current frame.
+         processed_frame, frame_fps = inference_on_frame(frame, resize_dim=resize_dim, threshold=threshold)
+         total_fps += frame_fps
+         frame_count += 1
+
+         # Write the processed frame and stream it to the UI.
+         out_writer.write(processed_frame)
+         yield cv2.cvtColor(processed_frame, cv2.COLOR_BGR2RGB), None
+
+     avg_fps = total_fps / frame_count if frame_count else 0.0
+
+     cap.release()
+     out_writer.release()
+     print(f"Average FPS: {avg_fps:.3f}")
+     yield None, out_video_path
+
+
+ # ----------------------------------------------------------------
+ # BUILD THE GRADIO INTERFACES
+ # ----------------------------------------------------------------
+
+ # Input and output components for the image tab.
+ resize_dim = gr.Slider(100, 1024, value=640, label="Resize Dimension", info="Resize image to this dimension")
+ threshold = gr.Slider(0, 1, value=0.5, label="Threshold", info="Confidence threshold for detection")
+ inputs_image = gr.Image(type="filepath", label="Input Image")
+ outputs_image = gr.Image(type="numpy", label="Output Image")
+
+ interface_image = gr.Interface(
+     fn=img_inf,
+     inputs=[inputs_image, resize_dim, threshold],
+     outputs=outputs_image,
+     title="Image Inference",
+     description="Upload your photo, adjust the resize dimension and threshold, and see the results!",
+     examples=[["examples/buffalo.jpg"], ["examples/zebra.jpg"]],
+     cache_examples=False,
+ )
+
+ # Fresh input components for the video tab.
+ resize_dim = gr.Slider(100, 1024, value=640, label="Resize Dimension", info="Resize image to this dimension")
+ threshold = gr.Slider(0, 1, value=0.5, label="Threshold", info="Confidence threshold for detection")
+ input_video = gr.Video(label="Input Video")
+
+ # Output is a pair: (last processed frame, output video path).
+ output_frame = gr.Image(type="numpy", label="Output (Last Processed Frame)")
+ output_video_file = gr.Video(format="mp4", label="Output Video")
+
+ interface_video = gr.Interface(
+     fn=vid_inf,
+     inputs=[input_video, resize_dim, threshold],
+     outputs=[output_frame, output_video_file],
+     title="Video Inference",
+     description="Upload your video and see the processed output!",
+     examples=[["examples/elephants.mp4"], ["examples/rhino.mp4"]],
+     cache_examples=False,
+ )
+
+ # Combine both interfaces in a tabbed layout and launch the app.
+ demo = (
+     gr.TabbedInterface(
+         [interface_image, interface_video],
+         tab_names=["Image", "Video"],
+         title="FineTuning RetinaNet for Wildlife Animal Detection",
+         theme="gstaff/xkcd",
+     )
+     .queue()
+     .launch()
+ )
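Because app.py builds and launches the Gradio demo at import time, a quick way to sanity-check the checkpoint without starting the server is a small standalone script. The following is a minimal sketch, not part of this commit; it assumes the repo root as the working directory, that the LFS files have been pulled, and it simply repeats the model setup from the top of app.py:

import cv2
import numpy as np
import torch

from model import create_model
from config import NUM_CLASSES, DEVICE, CLASSES

# Same setup as app.py: build the model and load the fine-tuned weights.
model = create_model(num_classes=NUM_CLASSES)
checkpoint = torch.load("outputs/best_model_79.pth", map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(DEVICE).eval()

# Read one of the bundled examples (BGR) and convert it to a normalized CHW float tensor.
image = cv2.imread("examples/zebra.jpg")
rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
tensor = torch.tensor(rgb.transpose(2, 0, 1)).unsqueeze(0).to(DEVICE)

with torch.no_grad():
    outputs = model(tensor)

# Print class names and scores for detections above a 0.5 confidence threshold.
keep = outputs[0]["scores"] >= 0.5
for label, score in zip(outputs[0]["labels"][keep], outputs[0]["scores"][keep]):
    print(CLASSES[int(label)], f"{float(score):.2f}")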
config.py ADDED
@@ -0,0 +1,25 @@
+ import torch
+
+ BATCH_SIZE = 8  # Increase / decrease according to GPU memory.
+ RESIZE_TO = 640  # Resize the image for training and transforms.
+ NUM_EPOCHS = 60  # Number of epochs to train for.
+ NUM_WORKERS = 4  # Number of parallel workers for data loading.
+
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+ # Training images and labels files directory.
+ TRAIN_DIR = "data/train"
+ # Validation images and labels files directory.
+ VALID_DIR = "data/valid"
+
+ # Classes: index 0 is reserved for the background.
+ CLASSES = ["__background__", "buffalo", "elephant", "rhino", "zebra"]
+
+
+ NUM_CLASSES = len(CLASSES)
+
+ # Whether to visualize images after creating the data loaders.
+ VISUALIZE_TRANSFORMED_IMAGES = True
+
+ # Location to save model and plots.
+ OUT_DIR = "outputs"
examples/buffalo.jpg ADDED

Git LFS Details

  • SHA256: 183a9c85a77a44ff66ab80bdb3ccdb32d34b89e8e089865e67496dfb08e7443a
  • Pointer size: 131 Bytes
  • Size of remote file: 324 kB
examples/elephants.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f7cd66a941d8883505826e1b191c6c45f21a2f9cad05301f1ce62da676b431a3
+ size 3617117
examples/rhino.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c151ff03aa5a01c4604ccc7cca9dd1518eb8faf807a320001f7efb2598effcef
+ size 9404729
examples/zebra.jpg ADDED

Git LFS Details

  • SHA256: 5f2610182725e882f3033ef7c763b2eb04641df44e1e4554f587ff5faa09cb0c
  • Pointer size: 131 Bytes
  • Size of remote file: 121 kB
model.py ADDED
@@ -0,0 +1,33 @@
+ import torchvision
+ import torch
+
+ from functools import partial
+ from torchvision.models.detection import RetinaNet_ResNet50_FPN_V2_Weights
+ from torchvision.models.detection.retinanet import RetinaNetClassificationHead
+ from config import NUM_CLASSES
+
+
+ def create_model(num_classes=91):
+     """
+     Creates a RetinaNet-ResNet50-FPN v2 model pre-trained on COCO.
+     Replaces the classification head for the required number of classes.
+     """
+     model = torchvision.models.detection.retinanet_resnet50_fpn_v2(weights=RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1)
+     num_anchors = model.head.classification_head.num_anchors
+
+     # Replace the classification head.
+     model.head.classification_head = RetinaNetClassificationHead(
+         in_channels=256, num_anchors=num_anchors, num_classes=num_classes, norm_layer=partial(torch.nn.GroupNorm, 32)
+     )
+     return model
+
+
+ if __name__ == "__main__":
+     model = create_model(num_classes=NUM_CLASSES)
+     print(model)
+     # Total parameters.
+     total_params = sum(p.numel() for p in model.parameters())
+     print(f"{total_params:,} total parameters.")
+     # Trainable parameters.
+     total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     print(f"{total_trainable_params:,} trainable parameters.")
outputs/best_model_79.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:82a9d5634acb4adeecdf8417cae2d26c6389ca7279472bc65fbc907a28300047
+ size 146001704
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ opencv-python==4.11.0.86
+ torch==2.6.0
+ torchvision==0.21.0
+ torchaudio==2.6.0
+ gradio==5.18.0