# Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. """ Gradio Demo for image detection""" # Importing necessary basic libraries and modules import os # PyTorch imports import torch from torch.utils.data import DataLoader # Importing the model, dataset, transformations and utility functions from PytorchWildlife from PytorchWildlife.models import detection as pw_detection from PytorchWildlife import utils as pw_utils # Importing basic libraries import shutil import time from PIL import Image import supervision as sv import gradio as gr from zipfile import ZipFile import numpy as np import ast # Importing the models, dataset, transformations, and utility functions from PytorchWildlife from PytorchWildlife.models import classification as pw_classification from PytorchWildlife.data import transforms as pw_trans from PytorchWildlife.data import datasets as pw_data # Setting the device to use for computations ('cuda' indicates GPU) DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # Initializing a supervision box annotator for visualizing detections dot_annotator = sv.DotAnnotator(radius=6) box_annotator = sv.BoxAnnotator(thickness=4) lab_annotator = sv.LabelAnnotator(text_color=sv.Color.BLACK, text_thickness=4, text_scale=2) # Create a temp folder os.makedirs(os.path.join("..","temp"), exist_ok=True) # ASK: Why do we need this? # Initializing the detection and classification models detection_model = None classification_model = None # Defining functions for different detection scenarios def load_models(det, version, clf, wpath=None, wclass=None): global detection_model, classification_model if det != "None": if det == "HerdNet General": detection_model = pw_detection.HerdNet(device=DEVICE) elif det == "HerdNet Ennedi": detection_model = pw_detection.HerdNet(device=DEVICE, version="ennedi") else: detection_model = pw_detection.__dict__[det](device=DEVICE, pretrained=True, version=version) else: detection_model = None return "NO MODEL LOADED!!" if clf != "None": # Create an exception for custom weights if clf == "CustomWeights": if (wpath is not None) and (wclass is not None): wclass = ast.literal_eval(wclass) classification_model = pw_classification.__dict__[clf](weights=wpath, class_names=wclass, device=DEVICE) else: classification_model = pw_classification.__dict__[clf](device=DEVICE, pretrained=True) else: classification_model = None return "Loaded Detector: {}. Version: {}. Loaded Classifier: {}".format(det, version, clf) def single_image_detection(input_img, det_conf_thres, clf_conf_thres, img_index=None): """Performs detection on a single image and returns an annotated image. Args: input_img (PIL.Image): Input image in PIL.Image format defaulted by Gradio. det_conf_thres (float): Confidence threshold for detection. clf_conf_thres (float): Confidence threshold for classification. img_index: Image index identifier. Returns: annotated_img (PIL.Image.Image): Annotated image with bounding box instances. """ input_img = np.array(input_img) # If the detection model is HerdNet, use dot annotator, else use box annotator if detection_model.__class__.__name__.__contains__("HerdNet"): annotator = dot_annotator # Herdnet receives both clf and det confidence thresholds results_det = detection_model.single_image_detection(input_img, img_path=img_index, det_conf_thres=det_conf_thres, clf_conf_thres=clf_conf_thres) else: annotator = box_annotator results_det = detection_model.single_image_detection(input_img, img_path=img_index, det_conf_thres = det_conf_thres) if classification_model is not None: labels = [] for i, (xyxy, det_id) in enumerate(zip(results_det["detections"].xyxy, results_det["detections"].class_id)): # Only run classifier when detection class is animal if det_id == 0: cropped_image = sv.crop_image(image=input_img, xyxy=xyxy) results_clf = classification_model.single_image_classification(cropped_image) labels.append("{} {:.2f}".format(results_clf["prediction"] if results_clf["confidence"] > clf_conf_thres else "Unknown", results_clf["confidence"])) else: labels.append(results_det["labels"][i]) else: labels = results_det["labels"] annotated_img = lab_annotator.annotate( scene=annotator.annotate( scene=input_img, detections=results_det["detections"], ), detections=results_det["detections"], labels=labels, ) return annotated_img def batch_detection(zip_file, timelapse, det_conf_thres): """Perform detection on a batch of images from a zip file and return path to results JSON. Args: zip_file (File): Zip file containing images. det_conf_thres (float): Confidence threshold for detection. timelapse (boolean): Flag to output JSON for timelapse. clf_conf_thres (float): Confidence threshold for classification. Returns: json_save_path (str): Path to the JSON file containing detection results. """ # Clean the temp folder if it contains files extract_path = os.path.join("..","temp","zip_upload") if os.path.exists(extract_path): shutil.rmtree(extract_path) os.makedirs(extract_path) json_save_path = os.path.join(extract_path, "results.json") # with ZipFile(zip_file.name) as zfile: # zfile.extractall(extract_path) # Check the contents of the extracted folder extracted_files = os.listdir(extract_path) if len(extracted_files) == 1 and os.path.isdir(os.path.join(extract_path, extracted_files[0])): tgt_folder_path = os.path.join(extract_path, extracted_files[0]) else: tgt_folder_path = extract_path # If the detection model is HerdNet set batch_size to 1 if detection_model.__class__.__name__.__contains__("HerdNet"): det_results = detection_model.batch_image_detection(tgt_folder_path, batch_size=1, det_conf_thres=det_conf_thres, id_strip=tgt_folder_path) else: det_results = detection_model.batch_image_detection(tgt_folder_path, batch_size=16, det_conf_thres=det_conf_thres, id_strip=tgt_folder_path) if classification_model is not None: clf_dataset = pw_data.DetectionCrops( det_results, transform=pw_trans.Classification_Inference_Transform(target_size=224), path_head=tgt_folder_path ) clf_loader = DataLoader(clf_dataset, batch_size=32, shuffle=False, pin_memory=True, num_workers=4, drop_last=False) clf_results = classification_model.batch_image_classification(clf_loader, id_strip=tgt_folder_path) if timelapse: json_save_path = json_save_path.replace(".json", "_timelapse.json") pw_utils.save_detection_classification_timelapse_json(det_results=det_results, clf_results=clf_results, det_categories=detection_model.CLASS_NAMES, clf_categories=classification_model.CLASS_NAMES, output_path=json_save_path) else: pw_utils.save_detection_classification_json(det_results=det_results, clf_results=clf_results, det_categories=detection_model.CLASS_NAMES, clf_categories=classification_model.CLASS_NAMES, output_path=json_save_path) else: if timelapse: json_save_path = json_save_path.replace(".json", "_timelapse.json") pw_utils.save_detection_timelapse_json(det_results, json_save_path, categories=detection_model.CLASS_NAMES) elif detection_model.__class__.__name__.__contains__("HerdNet"): pw_utils.save_detection_json_as_dots(det_results, json_save_path, categories=detection_model.CLASS_NAMES) else: pw_utils.save_detection_json(det_results, json_save_path, categories=detection_model.CLASS_NAMES) return json_save_path def batch_path_detection(tgt_folder_path, det_conf_thres): """Perform detection on a batch of images from a zip file and return path to results JSON. Args: tgt_folder_path (str): path to the folder containing the images. det_conf_thres (float): Confidence threshold for detection. Returns: json_save_path (str): Path to the JSON file containing detection results. """ json_save_path = os.path.join(tgt_folder_path, "results.json") det_results = detection_model.batch_image_detection(tgt_folder_path, det_conf_thres=det_conf_thres, id_strip=tgt_folder_path) if detection_model.__class__.__name__.__contains__("HerdNet"): pw_utils.save_detection_json_as_dots(det_results, json_save_path, categories=detection_model.CLASS_NAMES) else: pw_utils.save_detection_json(det_results, json_save_path, categories=detection_model.CLASS_NAMES) return json_save_path def video_detection(video, det_conf_thres, clf_conf_thres, target_fps, codec): """Perform detection on a video and return path to processed video. Args: video (str): Video source path. det_conf_thres (float): Confidence threshold for detection. clf_conf_thres (float): Confidence threshold for classification. """ def callback(frame, index): annotated_frame = single_image_detection(frame, img_index=index, det_conf_thres=det_conf_thres, clf_conf_thres=clf_conf_thres) return annotated_frame target_path = os.path.join("..","temp","video_detection.mp4") pw_utils.process_video(source_path=video, target_path=target_path, callback=callback, target_fps=int(target_fps), codec=codec) return target_path def wrap_bool_output(fn): def wrapped(*args, **kwargs): result = fn(*args, **kwargs) if isinstance(result, bool): return {"success": result} return result return wrapped # Building Gradio UI with gr.Blocks() as demo: gr.Markdown("# Pytorch-Wildlife Demo.") with gr.Row(): det_drop = gr.Dropdown( ["None", "MegaDetectorV5", "MegaDetectorV6", "HerdNet General", "HerdNet Ennedi"], label="Detection model", info="Will add more detection models!", value="None" # Default ) det_version = gr.Dropdown( ["None"], label="Model version", info="Select the version of the model", value="None", ) with gr.Column(): clf_drop = gr.Dropdown( ["None", "AI4GOpossum", "AI4GAmazonRainforest", "AI4GSnapshotSerengeti", "CustomWeights"], interactive=True, label="Classification model", info="Will add more classification models!", visible=False, value="None" ) custom_weights_path = gr.Textbox(label="Custom Weights Path", visible=False, interactive=True, placeholder="./weights/my_weight.pt") custom_weights_class = gr.Textbox(label="Custom Weights Class", visible=False, interactive=True, placeholder="{1:'ocelot', 2:'cow', 3:'bear'}") load_but = gr.Button("Load Models!") load_out = gr.Text("NO MODEL LOADED!!", label="Loaded models:") def update_ui_elements(det_model): if det_model == "MegaDetectorV6": return gr.Dropdown(choices=["MDV6-yolov9-c", "MDV6-yolov9-e", "MDV6-yolov10-c", "MDV6-yolov10-e", "MDV6-rtdetr-c"], interactive=True, label="Model version", value="MDV6-yolov9e"), gr.update(visible=True) elif det_model == "MegaDetectorV5": return gr.Dropdown(choices=["a", "b"], interactive=True, label="Model version", value="a"), gr.update(visible=True) else: return gr.Dropdown(choices=["None"], interactive=True, label="Model version", value="None"), gr.update(value="None", visible=False) det_drop.change(update_ui_elements, det_drop, [det_version, clf_drop]) def toggle_textboxes(model): if model == "CustomWeights": return gr.update(visible=True), gr.update(visible=True) else: return gr.update(visible=False), gr.update(visible=False) clf_drop.change( toggle_textboxes, clf_drop, [custom_weights_path, custom_weights_class] ) with gr.Tab("Single Image Process"): with gr.Row(): with gr.Column(): sgl_in = gr.Image(type="pil") sgl_conf_sl_det = gr.Slider(0, 1, label="Detection Confidence Threshold", value=0.2) sgl_conf_sl_clf = gr.Slider(0, 1, label="Classification Confidence Threshold", value=0.7, visible=True) sgl_out = gr.Image() sgl_but = gr.Button("Detect Animals!") with gr.Tab("Folder Separation"): with gr.Row(): with gr.Column(): inp_path = gr.Textbox(label="Input path", placeholder="./data/") out_path = gr.Textbox(label="Output path", placeholder="./output/") bth_conf_fs = gr.Slider(0, 1, label="Detection Confidence Threshold", value=0.2) process_btn = gr.Button("Process Files") bth_out2 = gr.File(label="Detection Results JSON.", height=200) with gr.Column(): process_files_button = gr.Button("Separate files") process_result = gr.Text("Click on 'Separate files' once you see the JSON file", label="Separated files:") process_btn.click(batch_path_detection, inputs=[inp_path, bth_conf_fs], outputs=bth_out2) process_files_button.click(wrap_bool_output(pw_utils.detection_folder_separation), inputs=[bth_out2, inp_path, out_path, bth_conf_fs], outputs=process_result) with gr.Tab("Batch Image Process"): with gr.Row(): with gr.Column(): bth_in = gr.File(label="Upload zip file.") # The timelapse checkbox is only visible when the detection model is not HerdNet chck_timelapse = gr.Checkbox(label="Generate timelapse JSON", visible=False) bth_conf_sl = gr.Slider(0, 1, label="Detection Confidence Threshold", value=0.2) bth_out = gr.File(label="Detection Results JSON.", height=200) bth_but = gr.Button("Detect Animals!") with gr.Tab("Single Video Process"): with gr.Row(): with gr.Column(): vid_in = gr.Video(label="Upload a video.") vid_conf_sl_det = gr.Slider(0, 1, label="Detection Confidence Threshold", value=0.2) vid_conf_sl_clf = gr.Slider(0, 1, label="Classification Confidence Threshold", value=0.7) vid_fr = gr.Dropdown([5, 10, 30], label="Output video framerate", value=30) vid_enc = gr.Dropdown( ["mp4v", "avc1"], label="Video encoder", info="mp4v is default, av1c is faster (needs conda install opencv)", value="mp4v" ) vid_out = gr.Video() vid_but = gr.Button("Detect Animals!") # Show timelapsed checkbox only when detection model is not HerdNet det_drop.change( lambda model: gr.update(visible=True) if "HerdNet" not in model else gr.update(visible=False), det_drop, [chck_timelapse] ) load_but.click(load_models, inputs=[det_drop, det_version, clf_drop, custom_weights_path, custom_weights_class], outputs=load_out) sgl_but.click(single_image_detection, inputs=[sgl_in, sgl_conf_sl_det, sgl_conf_sl_clf], outputs=sgl_out) bth_but.click(batch_detection, inputs=[bth_in, chck_timelapse, bth_conf_sl], outputs=bth_out) vid_but.click(video_detection, inputs=[vid_in, vid_conf_sl_det, vid_conf_sl_clf, vid_fr, vid_enc], outputs=vid_out) if __name__ == "__main__": #demo.queue() demo.launch(share=True)