danielaruizl1 commited on
Commit
98c0b3a
·
verified ·
1 Parent(s): 048f4e5

Upload gradio_demo.py

Browse files

megadetectorv6 updates

Files changed (1) hide show
  1. gradio_demo.py +346 -0
gradio_demo.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """ Gradio Demo for image detection"""
5
+
6
+ # Importing necessary basic libraries and modules
7
+ import os
8
+
9
+ # PyTorch imports
10
+ import torch
11
+ from torch.utils.data import DataLoader
12
+
13
+ # Importing the model, dataset, transformations and utility functions from PytorchWildlife
14
+ from PytorchWildlife.models import detection as pw_detection
15
+ from PytorchWildlife import utils as pw_utils
16
+
17
+ # Importing basic libraries
18
+ import shutil
19
+ import time
20
+ from PIL import Image
21
+ import supervision as sv
22
+ import gradio as gr
23
+ from zipfile import ZipFile
24
+ import numpy as np
25
+ import ast
26
+
27
+ # Importing the models, dataset, transformations, and utility functions from PytorchWildlife
28
+ from PytorchWildlife.models import classification as pw_classification
29
+ from PytorchWildlife.data import transforms as pw_trans
30
+ from PytorchWildlife.data import datasets as pw_data
31
+
32
+ # Setting the device to use for computations ('cuda' indicates GPU)
33
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
34
+ # Initializing a supervision box annotator for visualizing detections
35
+ dot_annotator = sv.DotAnnotator(radius=6)
36
+ box_annotator = sv.BoxAnnotator(thickness=4)
37
+ lab_annotator = sv.LabelAnnotator(text_color=sv.Color.BLACK, text_thickness=4, text_scale=2)
38
+ # Create a temp folder
39
+ os.makedirs(os.path.join("..","temp"), exist_ok=True) # ASK: Why do we need this?
40
+
41
+ # Initializing the detection and classification models
42
+ detection_model = None
43
+ classification_model = None
44
+
45
+ # Defining functions for different detection scenarios
46
+ def load_models(det, version, clf, wpath=None, wclass=None):
47
+
48
+ global detection_model, classification_model
49
+ if det != "None":
50
+ if det == "HerdNet General":
51
+ detection_model = pw_detection.HerdNet(device=DEVICE)
52
+ elif det == "HerdNet Ennedi":
53
+ detection_model = pw_detection.HerdNet(device=DEVICE, version="ennedi")
54
+ else:
55
+ detection_model = pw_detection.__dict__[det](device=DEVICE, pretrained=True, version=version)
56
+ else:
57
+ detection_model = None
58
+ return "NO MODEL LOADED!!"
59
+
60
+ if clf != "None":
61
+ # Create an exception for custom weights
62
+ if clf == "CustomWeights":
63
+ if (wpath is not None) and (wclass is not None):
64
+ wclass = ast.literal_eval(wclass)
65
+ classification_model = pw_classification.__dict__[clf](weights=wpath, class_names=wclass, device=DEVICE)
66
+ else:
67
+ classification_model = pw_classification.__dict__[clf](device=DEVICE, pretrained=True)
68
+ else:
69
+ classification_model = None
70
+
71
+ return "Loaded Detector: {}. Version: {}. Loaded Classifier: {}".format(det, version, clf)
72
+
73
+
74
+ def single_image_detection(input_img, det_conf_thres, clf_conf_thres, img_index=None):
75
+ """Performs detection on a single image and returns an annotated image.
76
+
77
+ Args:
78
+ input_img (PIL.Image): Input image in PIL.Image format defaulted by Gradio.
79
+ det_conf_thres (float): Confidence threshold for detection.
80
+ clf_conf_thres (float): Confidence threshold for classification.
81
+ img_index: Image index identifier.
82
+ Returns:
83
+ annotated_img (PIL.Image.Image): Annotated image with bounding box instances.
84
+ """
85
+
86
+ input_img = np.array(input_img)
87
+ # If the detection model is HerdNet, use dot annotator, else use box annotator
88
+ if detection_model.__class__.__name__.__contains__("HerdNet"):
89
+ annotator = dot_annotator
90
+ # Herdnet receives both clf and det confidence thresholds
91
+ results_det = detection_model.single_image_detection(input_img,
92
+ img_path=img_index,
93
+ det_conf_thres=det_conf_thres,
94
+ clf_conf_thres=clf_conf_thres)
95
+ else:
96
+ annotator = box_annotator
97
+ results_det = detection_model.single_image_detection(input_img,
98
+ img_path=img_index,
99
+ det_conf_thres = det_conf_thres)
100
+
101
+ if classification_model is not None:
102
+ labels = []
103
+ for i, (xyxy, det_id) in enumerate(zip(results_det["detections"].xyxy, results_det["detections"].class_id)):
104
+ # Only run classifier when detection class is animal
105
+ if det_id == 0:
106
+ cropped_image = sv.crop_image(image=input_img, xyxy=xyxy)
107
+ results_clf = classification_model.single_image_classification(cropped_image)
108
+ labels.append("{} {:.2f}".format(results_clf["prediction"] if results_clf["confidence"] > clf_conf_thres else "Unknown",
109
+ results_clf["confidence"]))
110
+ else:
111
+ labels.append(results_det["labels"][i])
112
+ else:
113
+ labels = results_det["labels"]
114
+
115
+ annotated_img = lab_annotator.annotate(
116
+ scene=annotator.annotate(
117
+ scene=input_img,
118
+ detections=results_det["detections"],
119
+ ),
120
+ detections=results_det["detections"],
121
+ labels=labels,
122
+ )
123
+ return annotated_img
124
+
125
+ def batch_detection(zip_file, timelapse, det_conf_thres):
126
+ """Perform detection on a batch of images from a zip file and return path to results JSON.
127
+
128
+ Args:
129
+ zip_file (File): Zip file containing images.
130
+ det_conf_thres (float): Confidence threshold for detection.
131
+ timelapse (boolean): Flag to output JSON for timelapse.
132
+ clf_conf_thres (float): Confidence threshold for classification.
133
+
134
+ Returns:
135
+ json_save_path (str): Path to the JSON file containing detection results.
136
+ """
137
+ # Clean the temp folder if it contains files
138
+ extract_path = os.path.join("..","temp","zip_upload")
139
+ if os.path.exists(extract_path):
140
+ shutil.rmtree(extract_path)
141
+ os.makedirs(extract_path)
142
+
143
+ json_save_path = os.path.join(extract_path, "results.json")
144
+ with ZipFile(zip_file.name) as zfile:
145
+ zfile.extractall(extract_path)
146
+ # Check the contents of the extracted folder
147
+ extracted_files = os.listdir(extract_path)
148
+
149
+ if len(extracted_files) == 1 and os.path.isdir(os.path.join(extract_path, extracted_files[0])):
150
+ tgt_folder_path = os.path.join(extract_path, extracted_files[0])
151
+ else:
152
+ tgt_folder_path = extract_path
153
+ # If the detection model is HerdNet set batch_size to 1
154
+ if detection_model.__class__.__name__.__contains__("HerdNet"):
155
+ det_results = detection_model.batch_image_detection(tgt_folder_path, batch_size=1, det_conf_thres=det_conf_thres, id_strip=tgt_folder_path)
156
+ else:
157
+ det_results = detection_model.batch_image_detection(tgt_folder_path, batch_size=16, det_conf_thres=det_conf_thres, id_strip=tgt_folder_path)
158
+
159
+ if classification_model is not None:
160
+ clf_dataset = pw_data.DetectionCrops(
161
+ det_results,
162
+ transform=pw_trans.Classification_Inference_Transform(target_size=224),
163
+ path_head=tgt_folder_path
164
+ )
165
+ clf_loader = DataLoader(clf_dataset, batch_size=32, shuffle=False,
166
+ pin_memory=True, num_workers=4, drop_last=False)
167
+ clf_results = classification_model.batch_image_classification(clf_loader, id_strip=tgt_folder_path)
168
+ if timelapse:
169
+ json_save_path = json_save_path.replace(".json", "_timelapse.json")
170
+ pw_utils.save_detection_classification_timelapse_json(det_results=det_results,
171
+ clf_results=clf_results,
172
+ det_categories=detection_model.CLASS_NAMES,
173
+ clf_categories=classification_model.CLASS_NAMES,
174
+ output_path=json_save_path)
175
+ else:
176
+ pw_utils.save_detection_classification_json(det_results=det_results,
177
+ clf_results=clf_results,
178
+ det_categories=detection_model.CLASS_NAMES,
179
+ clf_categories=classification_model.CLASS_NAMES,
180
+ output_path=json_save_path)
181
+ else:
182
+ if timelapse:
183
+ json_save_path = json_save_path.replace(".json", "_timelapse.json")
184
+ pw_utils.save_detection_timelapse_json(det_results, json_save_path, categories=detection_model.CLASS_NAMES)
185
+ elif detection_model.__class__.__name__.__contains__("HerdNet"):
186
+ pw_utils.save_detection_json_as_dots(det_results, json_save_path, categories=detection_model.CLASS_NAMES)
187
+ else:
188
+ pw_utils.save_detection_json(det_results, json_save_path, categories=detection_model.CLASS_NAMES)
189
+
190
+ return json_save_path
191
+
192
+ def batch_path_detection(tgt_folder_path, det_conf_thres):
193
+ """Perform detection on a batch of images from a zip file and return path to results JSON.
194
+
195
+ Args:
196
+ tgt_folder_path (str): path to the folder containing the images.
197
+ det_conf_thres (float): Confidence threshold for detection.
198
+ Returns:
199
+ json_save_path (str): Path to the JSON file containing detection results.
200
+ """
201
+
202
+ json_save_path = os.path.join(tgt_folder_path, "results.json")
203
+ det_results = detection_model.batch_image_detection(tgt_folder_path, det_conf_thres=det_conf_thres, id_strip=tgt_folder_path)
204
+ if detection_model.__class__.__name__.__contains__("HerdNet"):
205
+ pw_utils.save_detection_json_as_dots(det_results, json_save_path, categories=detection_model.CLASS_NAMES)
206
+ else:
207
+ pw_utils.save_detection_json(det_results, json_save_path, categories=detection_model.CLASS_NAMES)
208
+
209
+ return json_save_path
210
+
211
+
212
+ def video_detection(video, det_conf_thres, clf_conf_thres, target_fps, codec):
213
+ """Perform detection on a video and return path to processed video.
214
+
215
+ Args:
216
+ video (str): Video source path.
217
+ det_conf_thres (float): Confidence threshold for detection.
218
+ clf_conf_thres (float): Confidence threshold for classification.
219
+
220
+ """
221
+ def callback(frame, index):
222
+ annotated_frame = single_image_detection(frame,
223
+ img_index=index,
224
+ det_conf_thres=det_conf_thres,
225
+ clf_conf_thres=clf_conf_thres)
226
+ return annotated_frame
227
+
228
+ target_path = os.path.join("..","temp","video_detection.mp4")
229
+ pw_utils.process_video(source_path=video, target_path=target_path,
230
+ callback=callback, target_fps=int(target_fps), codec=codec)
231
+ return target_path
232
+
233
+ # Building Gradio UI
234
+
235
+ with gr.Blocks() as demo:
236
+ gr.Markdown("# Pytorch-Wildlife Demo.")
237
+ with gr.Row():
238
+ det_drop = gr.Dropdown(
239
+ ["None", "MegaDetectorV5", "MegaDetectorV6", "HerdNet General", "HerdNet Ennedi"],
240
+ label="Detection model",
241
+ info="Will add more detection models!",
242
+ value="None" # Default
243
+ )
244
+ det_version = gr.Dropdown(
245
+ ["None"],
246
+ label="Model version",
247
+ info="Select the version of the model",
248
+ value="None",
249
+ )
250
+
251
+ with gr.Column():
252
+ clf_drop = gr.Dropdown(
253
+ ["None", "AI4GOpossum", "AI4GAmazonRainforest", "AI4GSnapshotSerengeti", "CustomWeights"],
254
+ interactive=True,
255
+ label="Classification model",
256
+ info="Will add more classification models!",
257
+ visible=False,
258
+ value="None"
259
+ )
260
+ custom_weights_path = gr.Textbox(label="Custom Weights Path", visible=False, interactive=True, placeholder="./weights/my_weight.pt")
261
+ custom_weights_class = gr.Textbox(label="Custom Weights Class", visible=False, interactive=True, placeholder="{1:'ocelot', 2:'cow', 3:'bear'}")
262
+ load_but = gr.Button("Load Models!")
263
+ load_out = gr.Text("NO MODEL LOADED!!", label="Loaded models:")
264
+
265
+ def update_ui_elements(det_model):
266
+ if det_model == "MegaDetectorV6":
267
+ return gr.Dropdown(choices=["MDV6-yolov9-c", "MDV6-yolov9-e", "MDV6-yolov10-c", "MDV6-yolov10-e", "MDV6-rtdetr-c"], interactive=True, label="Model version", value="MDV6-yolov9e"), gr.update(visible=True)
268
+ elif det_model == "MegaDetectorV5":
269
+ return gr.Dropdown(choices=["a", "b"], interactive=True, label="Model version", value="a"), gr.update(visible=True)
270
+ else:
271
+ return gr.Dropdown(choices=["None"], interactive=True, label="Model version", value="None"), gr.update(value="None", visible=False)
272
+
273
+ det_drop.change(update_ui_elements, det_drop, [det_version, clf_drop])
274
+
275
+ def toggle_textboxes(model):
276
+ if model == "CustomWeights":
277
+ return gr.update(visible=True), gr.update(visible=True)
278
+ else:
279
+ return gr.update(visible=False), gr.update(visible=False)
280
+
281
+ clf_drop.change(
282
+ toggle_textboxes,
283
+ clf_drop,
284
+ [custom_weights_path, custom_weights_class]
285
+ )
286
+
287
+ with gr.Tab("Single Image Process"):
288
+ with gr.Row():
289
+ with gr.Column():
290
+ sgl_in = gr.Image(type="pil")
291
+ sgl_conf_sl_det = gr.Slider(0, 1, label="Detection Confidence Threshold", value=0.2)
292
+ sgl_conf_sl_clf = gr.Slider(0, 1, label="Classification Confidence Threshold", value=0.7, visible=True)
293
+ sgl_out = gr.Image()
294
+ sgl_but = gr.Button("Detect Animals!")
295
+ with gr.Tab("Folder Separation"):
296
+ with gr.Row():
297
+ with gr.Column():
298
+ inp_path = gr.Textbox(label="Input path", placeholder="./data/")
299
+ out_path = gr.Textbox(label="Output path", placeholder="./output/")
300
+ bth_conf_fs = gr.Slider(0, 1, label="Detection Confidence Threshold", value=0.2)
301
+ process_btn = gr.Button("Process Files")
302
+ bth_out2 = gr.File(label="Detection Results JSON.", height=200)
303
+ with gr.Column():
304
+ process_files_button = gr.Button("Separate files")
305
+ process_result = gr.Text("Click on 'Separate files' once you see the JSON file", label="Separated files:")
306
+ process_btn.click(batch_path_detection, inputs=[inp_path, bth_conf_fs], outputs=bth_out2)
307
+ process_files_button.click(pw_utils.detection_folder_separation, inputs=[bth_out2, inp_path, out_path, bth_conf_fs], outputs=process_result)
308
+ with gr.Tab("Batch Image Process"):
309
+ with gr.Row():
310
+ with gr.Column():
311
+ bth_in = gr.File(label="Upload zip file.")
312
+ # The timelapse checkbox is only visible when the detection model is not HerdNet
313
+ chck_timelapse = gr.Checkbox(label="Generate timelapse JSON", visible=False)
314
+ bth_conf_sl = gr.Slider(0, 1, label="Detection Confidence Threshold", value=0.2)
315
+ bth_out = gr.File(label="Detection Results JSON.", height=200)
316
+ bth_but = gr.Button("Detect Animals!")
317
+ with gr.Tab("Single Video Process"):
318
+ with gr.Row():
319
+ with gr.Column():
320
+ vid_in = gr.Video(label="Upload a video.")
321
+ vid_conf_sl_det = gr.Slider(0, 1, label="Detection Confidence Threshold", value=0.2)
322
+ vid_conf_sl_clf = gr.Slider(0, 1, label="Classification Confidence Threshold", value=0.7)
323
+ vid_fr = gr.Dropdown([5, 10, 30], label="Output video framerate", value=30)
324
+ vid_enc = gr.Dropdown(
325
+ ["mp4v", "avc1"],
326
+ label="Video encoder",
327
+ info="mp4v is default, av1c is faster (needs conda install opencv)",
328
+ value="mp4v"
329
+ )
330
+ vid_out = gr.Video()
331
+ vid_but = gr.Button("Detect Animals!")
332
+
333
+ # Show timelapsed checkbox only when detection model is not HerdNet
334
+ det_drop.change(
335
+ lambda model: gr.update(visible=True) if "HerdNet" not in model else gr.update(visible=False),
336
+ det_drop,
337
+ [chck_timelapse]
338
+ )
339
+
340
+ load_but.click(load_models, inputs=[det_drop, det_version, clf_drop, custom_weights_path, custom_weights_class], outputs=load_out)
341
+ sgl_but.click(single_image_detection, inputs=[sgl_in, sgl_conf_sl_det, sgl_conf_sl_clf], outputs=sgl_out)
342
+ bth_but.click(batch_detection, inputs=[bth_in, chck_timelapse, bth_conf_sl], outputs=bth_out)
343
+ vid_but.click(video_detection, inputs=[vid_in, vid_conf_sl_det, vid_conf_sl_clf, vid_fr, vid_enc], outputs=vid_out)
344
+
345
+ if __name__ == "__main__":
346
+ demo.launch(share=True)