added batch abilities for local install
- .gitignore +6 -3
- app_batch.py +135 -0
- owl_batch.py +211 -0
.gitignore
CHANGED
@@ -1,3 +1,6 @@
-*.pyc
-*/__pycache__/**
-*.mp4
+*.pyc
+*/__pycache__/**
+*.mp4
+*.png
+*.csv
+/.gradio/
app_batch.py
ADDED
@@ -0,0 +1,135 @@
BATCH_SIZE = 8  # change this to your desired batch size
CUDA_PATH = "/usr/local/cuda-12.3/"  # change this to your CUDA path

from datetime import datetime
import os
import time

# set CUDA_HOME before importing anything that needs the toolkit
os.environ["CUDA_HOME"] = CUDA_PATH

import gradio as gr

from owl_batch import owl_batch_video

def run_owl_batch(
        input_vids: list[str] | str,
        target_prompt: str,
        species_prompt: str,
        conf_threshold: float,
        fps_processed: int,
        scaling_factor: float
) -> str:
    """
    args:
        input_vids: list of video paths (or a single path)
        target_prompt: prompt(s) to search for
        species_prompt: prompt(s) to query
        conf_threshold: confidence threshold for detection
        fps_processed: run detection once every fps_processed frames
        scaling_factor: factor to downsample the frames by
    returns:
        zip_path: path to a zip of the detection CSV and per-video results
    """
    start_time = time.time()
    if isinstance(input_vids, str):
        input_vids = [input_vids]

    # make sure there are no spaces in the file names,
    # and keep the renamed paths for the downstream call
    renamed_vids = []
    for vid in input_vids:
        new_input_vid = vid.replace(" ", "_")
        if new_input_vid != vid:
            os.rename(vid, new_input_vid)
        renamed_vids.append(new_input_vid)
    input_vids = renamed_vids

    # the species prompt has to contain the target prompt, otherwise add it
    if target_prompt not in species_prompt:
        species_prompt = f"{species_prompt}, {target_prompt}"

    # turn the target prompt into a list of phrases
    target_prompt = target_prompt.split(", ")

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M")

    zip_path = owl_batch_video(
        input_vids,
        target_prompt,
        species_prompt,
        conf_threshold,
        fps_processed=fps_processed,
        scaling_factor=1/scaling_factor,  # the UI exposes a downsample factor; owl_batch expects a scale fraction
        batch_size=BATCH_SIZE,
        save_dir=f"temp_{timestamp}")

    end_time = time.time()
    print(f"Processing time: {end_time - start_time:.1f} seconds")
    return zip_path


with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1 align="center" style="font-size:xxx-large">🦍 Primate Detection</h1>
        """
    )

    with gr.Row():
        with gr.Column():
            vid_input = gr.File(label="Upload Videos", file_types=['.mp4', '.mov'], file_count="multiple")
            target_prompt = gr.Textbox(label="What do you want to detect? (Separate multiple species with commas)")
            species_prompt = gr.Textbox(label="Which species are in your dataset? (Separate multiple species with commas)")
            with gr.Accordion("Advanced Options", open=False):
                conf_threshold = gr.Slider(
                    label="Confidence Threshold",
                    info="Adjust the threshold to change the sensitivity of the model; lower thresholds are more sensitive.",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.3,
                    step=0.05
                )
                fps_processed = gr.Slider(
                    label="Frame Detection Rate",
                    info="A value of 120 runs detection every 120 frames; a value of 1 runs detection on every frame. The lower the number, the slower the processing.",
                    minimum=1,
                    maximum=120,
                    value=10,
                    step=1)
                scaling_factor = gr.Slider(
                    label="Downsample Factor",
                    info="Higher values speed up processing but lower the accuracy.",
                    minimum=1,
                    maximum=10,
                    value=4,
                    step=1
                )
            with gr.Row():
                clear_btn = gr.ClearButton(components=[vid_input, target_prompt, species_prompt])
                run_btn = gr.Button(value="Run Detection", variant='primary')
        with gr.Column():
            download_file = gr.Files(label="CSV, Video Output", interactive=False)

    run_btn.click(fn=run_owl_batch, inputs=[vid_input, target_prompt, species_prompt, conf_threshold, fps_processed, scaling_factor], outputs=[download_file])

    gr.DuplicateButton()

    gr.Markdown(
        """
        ## Frequently Asked Questions

        ##### How can I run the interface on my own computer?
        By clicking on the three dots in the top-right corner of the interface, you can clone the repository or run it with a Docker image on your local machine. \
        For local setup instructions, please check the README file.
        ##### The video is very slow to process. How can I speed it up?
        You can speed up processing by raising the Frame Detection Rate and the Downsample Factor in the advanced options. You can also duplicate the space \
        using the Duplicate Button and choose a faster GPU.
        """
    )

demo.launch(share=True)
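
One nuance in run_owl_batch above is the prompt handling: the target prompt is folded into the species prompt if it is missing, then split on ", " into the list of target phrases that owl_batch counts as positives. A minimal standalone restatement of that logic (kept separate rather than imported, since loading app_batch also launches the Gradio demo; the example prompts are illustrative only):

# Standalone sketch of the prompt normalization in run_owl_batch;
# the prompts below are illustrative only.
target_prompt = "chimpanzee, gorilla"
species_prompt = "human, bird"

# the species prompt has to contain the target prompt, otherwise add it
if target_prompt not in species_prompt:
    species_prompt = f"{species_prompt}, {target_prompt}"

# the target prompt becomes the list of phrases counted as positive detections
target_phrases = target_prompt.split(", ")

print(species_prompt)   # human, bird, chimpanzee, gorilla
print(target_phrases)   # ['chimpanzee', 'gorilla']
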
owl_batch.py
ADDED
@@ -0,0 +1,211 @@
import math
import os
import shutil
import zipfile

import cv2
import pandas as pd
import torch
from PIL import Image
from tqdm import tqdm
from transformers import Owlv2Processor, Owlv2ForObjectDetection

from utils import plot_predictions, mp4_to_png, vid_stitcher

def owl_batch_video(
        input_vids: list[str],
        target_prompt: list[str],
        species_prompt: str,
        threshold: float,
        fps_processed: int = 1,
        scaling_factor: float = 0.5,
        batch_size: int = 8,
        save_dir: str = "temp/"
):
    df = pd.DataFrame(columns=["video path", "detection?"])

    for vid in input_vids:
        detection = owl_video_detection(vid,
                                        target_prompt,
                                        species_prompt,
                                        threshold,
                                        fps_processed=fps_processed,
                                        scaling_factor=scaling_factor,
                                        batch_size=batch_size,
                                        save_dir=save_dir)

        row = pd.DataFrame({"video path": [vid], "detection?": [str(detection)]})
        df = pd.concat([df, row], ignore_index=True)

    # save the per-video results (os.path.join handles save_dir with or without a trailing slash)
    df.to_csv(os.path.join(save_dir, "detection_results.csv"))

    # zip the save_dir
    zip_file = f"{save_dir}/results.zip"
    zip_directory(save_dir, zip_file)

    return zip_file


def zip_directory(folder_path, output_zip_path):
    with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, dirs, files in os.walk(folder_path):
            for file in files:
                file_path = os.path.join(root, file)
                # skip the archive itself, since it lives inside the folder being zipped
                if os.path.abspath(file_path) == os.path.abspath(output_zip_path):
                    continue
                # write the file with a relative path to preserve folder structure
                arcname = os.path.relpath(file_path, start=folder_path)
                zipf.write(file_path, arcname)


def preprocess_text(text_prompt: str, num_prompts: int = 1):
    """
    Takes a string of comma-separated text prompts and returns one list of prompts per image,
    i.e. text_prompt = "a, b, c" with num_prompts = 2 -> [["a", "b", "c"], ["a", "b", "c"]]
    """
    text_prompt = [s.strip() for s in text_prompt.split(",")]
    text_queries = [text_prompt] * num_prompts
    return text_queries


def owl_batch_prediction(
        images: list[Image.Image],
        text_queries: list[list[str]],  # assuming every image is queried with the same text prompt
        threshold: float,
        processor,
        model,
        device: str = 'cuda'
):
    inputs = processor(text=text_queries, images=images, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # target image sizes (height, width) to rescale box predictions [batch_size, 2]
    target_sizes = torch.Tensor([img.size[::-1] for img in images]).to(device)
    # convert outputs (bounding boxes and class logits) to COCO API, resize to the
    # original image size, and filter by threshold
    results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes, threshold=threshold)

    return results


def count_pos(phrases: list[str], text_targets: list[str]) -> int:
    """
    Counts how many phrases in the list match any of the target phrases.

    Args:
        phrases: A list of strings to evaluate.
        text_targets: A list of target strings to match against.

    Returns:
        The number of phrases that match any of the targets.
    """
    if len(phrases) == 0 or len(text_targets) == 0:
        return 0
    target_set = set(text_targets)
    return sum(1 for phrase in phrases if phrase in target_set)


def owl_video_detection(
        vid_path: str,
        text_target: list[str],
        text_prompt: str,
        threshold: float,
        fps_processed: int = 1,
        scaling_factor: float = 0.5,
        processor=None,
        model=None,
        device: str = 'cuda',
        batch_size: int = 8,
        save_dir: str = "temp/",
):
    """
    Runs OWLv2 on a video and saves the results to a dataframe.
    Returns True if text_target is detected in the video, False otherwise.
    Stops early once more than 2/3 of the frames in a batch are positive.
    """
    # load the model lazily rather than in the default arguments, so importing
    # this module does not pull the weights onto the GPU
    if processor is None:
        processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
    if model is None:
        model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)

    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(f"{save_dir}/positives", exist_ok=True)
    os.makedirs(f"{save_dir}/negatives", exist_ok=True)

    # set up df for results
    df = pd.DataFrame(columns=["frame", "boxes", "scores", "labels", "count"])

    # create new dirs and paths for results
    filename = os.path.splitext(os.path.basename(vid_path))[0]
    frames_dir = f"{save_dir}/{filename}_frames"
    os.makedirs(frames_dir, exist_ok=True)

    # split the video into a directory of frames
    fps = mp4_to_png(vid_path, frames_dir, scaling_factor)

    # get all frame paths in playback order
    frame_filenames = sorted(os.listdir(frames_dir))

    # keep every fps_processed-th frame
    frame_paths = []
    for i, frame in enumerate(frame_filenames):
        if i % fps_processed == 0:
            frame_paths.append(os.path.join(frames_dir, frame))

    # run owl in batches
    for i in tqdm(range(0, len(frame_paths), batch_size), desc="Running batches"):
        batch_paths = frame_paths[i:i+batch_size]  # paths for this batch
        images = [Image.open(image_path) for image_path in batch_paths]

        # run owl on this batch of frames
        text_queries = preprocess_text(text_prompt, len(batch_paths))
        results = owl_batch_prediction(images, text_queries, threshold, processor, model, device)

        # collect the label ids for this batch; None marks a frame with no detections
        label_ids = []
        for entry in results:
            if entry['labels'].numel() > 0:
                label_ids.append(entry['labels'].tolist())
            else:
                label_ids.append(None)

        text = text_queries[0]  # assuming all images are queried with the same text
        labels = []
        # convert label ids to phrases; if there are none, append an empty list
        for ids in label_ids:
            if ids is not None:
                labels.append([text[label_id] for label_id in ids])
            else:
                labels.append([])

        batch_pos = 0
        for j, image in enumerate(batch_paths):
            boxes = results[j]['boxes'].cpu().numpy()
            scores = results[j]['scores'].cpu().numpy()
            count = count_pos(labels[j], text_target)
            row = pd.DataFrame({"frame": [image], "boxes": [boxes], "scores": [scores], "labels": [labels[j]], "count": count})
            df = pd.concat([df, row], ignore_index=True)

            # if there are detections, overwrite the original frame with the annotated one
            if count > 0:
                annotated_frame = plot_predictions(image, labels[j], scores, boxes)
                cv2.imwrite(image, annotated_frame)
                batch_pos += 1

        # if more than 2/3 of the frames in this batch are positive, stitch the
        # annotated frames back into a video and stop early
        if batch_pos > math.ceil(2/3 * len(batch_paths)):
            vid_stitcher(frames_dir, f"{save_dir}/positives/{filename}_{threshold}.mp4", fps)
            shutil.rmtree(frames_dir)  # delete the frames to save space
            df.to_csv(f"{save_dir}/positives/{filename}_{threshold}.csv", index=False)
            return True

    shutil.rmtree(frames_dir)  # delete the frames to save space
    df.to_csv(f"{save_dir}/negatives/{filename}_{threshold}.csv", index=False)
    return False
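
For reference, owl_video_detection can also be driven directly on a single clip, outside the batch loop. A minimal sketch, assuming a CUDA GPU and the repo's utils module are available; the clip path and prompts are illustrative only:

# Hypothetical single-video run; the OWLv2 weights load lazily on the first
# call because processor/model default to None.
from owl_batch import owl_video_detection

detected = owl_video_detection(
    "videos/site_a_cam1.mp4",      # illustrative path
    ["chimpanzee"],                # target phrases counted as positives
    "chimpanzee, gorilla, human",  # full query prompt (should include the targets)
    0.3,                           # confidence threshold
    fps_processed=10,              # detect on every 10th extracted frame
    scaling_factor=0.25,           # fraction passed through to mp4_to_png
    save_dir="temp_demo",
)
print("positive" if detected else "negative")
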