annading committed on
Commit 692e2af · 1 Parent(s): aab5330

added batch abilities for local install

Files changed (3)
  1. .gitignore +6 -3
  2. app_batch.py +135 -0
  3. owl_batch.py +211 -0
.gitignore CHANGED
@@ -1,3 +1,6 @@
- *.pyc
- */__pycache__/**
- *.mp4
+ *.pyc
+ */__pycache__/**
+ *.mp4
+ *.png
+ *.csv
+ /.gradio/
app_batch.py ADDED
@@ -0,0 +1,135 @@
+ BATCH_SIZE = 8  # change this to your desired batch size
+ CUDA_PATH = "/usr/local/cuda-12.3/"  # change this to your CUDA path
+
+ import os
+ import time
+ from datetime import datetime
+
+ # set CUDA_HOME before importing any CUDA-dependent libraries
+ os.environ["CUDA_HOME"] = CUDA_PATH
+
+ import gradio as gr
+
+ from owl_batch import owl_batch_video
+
+ # global CSV_PATH # csv that contains video names and detection results
+ # global POS_ZIP # zip of positive videos and individual results
+ # global NEG_ZIP # zip of negative videos and individual results
+
+ def run_owl_batch(
+     input_vids: list[str] | str,
+     target_prompt: str,
+     species_prompt: str,
+     conf_threshold: float,
+     fps_processed: int,
+     scaling_factor: float,
+ ) -> str:
+     """
+     args:
+         input_vids: list of video paths (or a single path)
+         target_prompt: prompt(s) to search for
+         species_prompt: prompt describing all species in the dataset
+         conf_threshold: confidence threshold for detection
+         fps_processed: run detection on every nth frame
+         scaling_factor: factor to downscale the frames by
+     returns:
+         zip_path: path to a zip of the results CSV and annotated videos
+     """
+     start_time = time.time()
+     if isinstance(input_vids, str):
+         input_vids = [input_vids]
+
+     # make sure there are no spaces in the names, keeping the renamed paths
+     renamed_vids = []
+     for vid in input_vids:
+         new_input_vid = vid.replace(" ", "_")
+         if new_input_vid != vid:
+             os.rename(vid, new_input_vid)
+         renamed_vids.append(new_input_vid)
+     input_vids = renamed_vids
+
+     # species prompt has to contain the target prompt, otherwise add it
+     if target_prompt not in species_prompt:
+         species_prompt = f"{species_prompt}, {target_prompt}"
+
+     # turn the target prompt into a list of stripped phrases
+     target_prompt = [s.strip() for s in target_prompt.split(",")]
+
+     now = datetime.now()
+     timestamp = now.strftime("%Y-%m-%d_%H-%M")
+
+     zip_path = owl_batch_video(
+         input_vids,
+         target_prompt,
+         species_prompt,
+         conf_threshold,
+         fps_processed=fps_processed,
+         scaling_factor=1 / scaling_factor,
+         batch_size=BATCH_SIZE,
+         save_dir=f"temp_{timestamp}")
+
+     end_time = time.time()
+     print(f"Processing time: {end_time - start_time:.1f} seconds")
+     return zip_path
+
+
+ with gr.Blocks() as demo:
+     gr.HTML(
+         """
+         <h1 align="center" style="font-size:xxx-large">🦍 Primate Detection</h1>
+         """
+     )
+
+     with gr.Row():
+         with gr.Column():
+             input = gr.File(label="Upload Videos", file_types=['.mp4', '.mov'], file_count="multiple")
+             target_prompt = gr.Textbox(label="What do you want to detect? (Separate multiple species with commas)")
+             species_prompt = gr.Textbox(label="Which species are in your dataset? (Separate multiple species with commas)")
+             with gr.Accordion("Advanced Options", open=False):
+                 conf_threshold = gr.Slider(
+                     label="Confidence Threshold",
+                     info="Adjust the threshold to change the sensitivity of the model; lower thresholds are more sensitive.",
+                     minimum=0.0,
+                     maximum=1.0,
+                     value=0.3,
+                     step=0.05
+                 )
+                 fps_processed = gr.Slider(
+                     label="Frame Detection Rate",
+                     info="How often detection runs: a value of 120 runs detection every 120 frames, a value of 1 runs it on every frame. Note: the lower the number, the slower the processing.",
+                     minimum=1,
+                     maximum=120,
+                     value=10,
+                     step=1)
+                 scaling_factor = gr.Slider(
+                     label="Downsample Factor",
+                     info="How much to downscale each frame. Note: the higher the number, the faster the processing but the lower the accuracy.",
+                     minimum=1,
+                     maximum=10,
+                     value=4,
+                     step=1
+                 )
+             with gr.Row():
+                 clear_btn = gr.ClearButton(components=[input, target_prompt, species_prompt])
+                 run_btn = gr.Button(value="Run Detection", variant='primary')
+         with gr.Column():
+             download_file = gr.Files(label="CSV, Video Output", interactive=False)
+
+     run_btn.click(fn=run_owl_batch, inputs=[input, target_prompt, species_prompt, conf_threshold, fps_processed, scaling_factor], outputs=[download_file])
+
+     gr.DuplicateButton()
+
+     gr.Markdown(
+         """
+         ## Frequently Asked Questions
+
+         ##### How can I run the interface on my own computer?
+         By clicking on the three dots in the top right corner of the interface, you can clone the repository or run it with a Docker image on your local machine. \
+         For local setup instructions, please check the README file.
+         ##### The video is very slow to process, how can I speed it up?
+         You can speed up processing by raising the frame detection rate in the Advanced Options (detection then runs on fewer frames) or by increasing the downsample factor. \
+         You can also duplicate the space using the Duplicate button and choose a faster GPU.
+         """
+     )
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
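
For a quick check of the batch pipeline without the browser UI, run_owl_batch can also be called directly (with the __main__ guard above, importing app_batch builds the Blocks but launches nothing). A minimal sketch with hypothetical paths and prompts:

    from app_batch import run_owl_batch

    # detect gorillas in two local clips, checking every 10th frame at
    # quarter resolution; returns the path to the results zip
    zip_path = run_owl_batch(
        ["clips/camera_a.mp4", "clips/camera_b.mp4"],
        target_prompt="gorilla",
        species_prompt="gorilla, chimpanzee",
        conf_threshold=0.3,
        fps_processed=10,
        scaling_factor=4,
    )
    print(zip_path)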
owl_batch.py ADDED
@@ -0,0 +1,211 @@
+ import math
+ import os
+ import shutil
+ import zipfile
+
+ import cv2
+ import pandas as pd
+ import torch
+ from PIL import Image
+ from tqdm import tqdm
+ from transformers import Owlv2Processor, Owlv2ForObjectDetection
+
+ from utils import plot_predictions, mp4_to_png, vid_stitcher
+
+ def owl_batch_video(
+     input_vids: list[str],
+     target_prompt: list[str],
+     species_prompt: str,
+     threshold: float,
+     fps_processed: int = 1,
+     scaling_factor: float = 0.5,
+     batch_size: int = 8,
+     save_dir: str = "temp/"
+ ):
+     pos_preds = []
+     neg_preds = []
+
+     df = pd.DataFrame(columns=["video path", "detection?"])
+
+     for vid in input_vids:
+         detection = owl_video_detection(vid,
+                                         target_prompt,
+                                         species_prompt,
+                                         threshold,
+                                         fps_processed=fps_processed,
+                                         scaling_factor=scaling_factor,
+                                         batch_size=batch_size,
+                                         save_dir=save_dir)
+
+         if detection:
+             pos_preds.append(vid)
+         else:
+             neg_preds.append(vid)
+         row = pd.DataFrame({"video path": [vid], "detection?": [str(detection)]})
+         df = pd.concat([df, row], ignore_index=True)
+
+     # save the df (os.path.join handles save_dir with or without a trailing slash)
+     df.to_csv(os.path.join(save_dir, "detection_results.csv"))
+
+     # zip the save_dir
+     zip_file = os.path.join(save_dir, "results.zip")
+     zip_directory(save_dir, zip_file)
+
+     return zip_file
+
+
+ def zip_directory(folder_path, output_zip_path):
+     with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
+         for root, dirs, files in os.walk(folder_path):
+             for file in files:
+                 file_path = os.path.join(root, file)
+                 # skip the archive itself, since it is created inside folder_path
+                 if os.path.abspath(file_path) == os.path.abspath(output_zip_path):
+                     continue
+                 # write the file with a relative path to preserve folder structure
+                 arcname = os.path.relpath(file_path, start=folder_path)
+                 zipf.write(file_path, arcname)
+
+
+ def preprocess_text(text_prompt: str, num_prompts: int = 1):
+     """
+     Takes a comma-separated string of text prompts and returns one list of
+     prompts per image, i.e. text_prompt = "a, b, c" with num_prompts = 2
+     -> [["a", "b", "c"], ["a", "b", "c"]]
+     """
+     text_prompt = [s.strip() for s in text_prompt.split(",")]
+     text_queries = [text_prompt] * num_prompts
+     return text_queries
+
+ def owl_batch_prediction(
+     images: list[Image.Image],
+     text_queries: list[list[str]],  # every image is queried with the same text prompts
+     threshold: float,
+     processor,
+     model,
+     device: str = 'cuda'
+ ):
+     inputs = processor(text=text_queries, images=images, return_tensors="pt").to(device)
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     # target image sizes (height, width) to rescale box predictions [batch_size, 2]
+     target_sizes = torch.Tensor([img.size[::-1] for img in images]).to(device)
+     # convert outputs (bounding boxes and class logits) to COCO API, rescale to
+     # the original image sizes, and filter by threshold
+     results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes, threshold=threshold)
+
+     return results
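+
+ # A sketch of calling owl_batch_prediction on its own (hypothetical frame paths):
+ #   processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
+ #   model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to('cuda')
+ #   images = [Image.open("frame_000.png"), Image.open("frame_001.png")]
+ #   queries = preprocess_text("gorilla, chimpanzee", len(images))
+ #   results = owl_batch_prediction(images, queries, 0.3, processor, model)
+ #   # results[i] holds the "boxes", "scores", and "labels" tensors for image i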
+
+
+ def count_pos(phrases: list[str], text_targets: list[str]) -> int:
+     """
+     Counts how many phrases in the list match any of the target phrases.
+
+     Args:
+         phrases: A list of strings to evaluate.
+         text_targets: A list of target strings to match against.
+
+     Returns:
+         The number of phrases that match any of the targets.
+     """
+     if len(phrases) == 0 or len(text_targets) == 0:
+         return 0
+     target_set = set(text_targets)
+     return sum(1 for phrase in phrases if phrase in target_set)
+
+
+ def owl_video_detection(
+     vid_path: str,
+     text_target: list[str],
+     text_prompt: str,
+     threshold: float,
+     fps_processed: int = 1,
+     scaling_factor: float = 0.5,
+     processor=None,
+     model=None,
+     device: str = 'cuda',
+     batch_size: int = 8,
+     save_dir: str = "temp/",
+ ):
+     """
+     Runs owl on a video and saves the results to a dataframe.
+     Returns True if text_target is detected in the video, False otherwise.
+     Stops running owl once a batch of frames is mostly positive.
+     """
+     # load the model lazily rather than in the default arguments, so importing
+     # this module does not pull the weights onto the GPU
+     if processor is None:
+         processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
+     if model is None:
+         model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)
+     os.makedirs(save_dir, exist_ok=True)
+     os.makedirs(f"{save_dir}/positives", exist_ok=True)
+     os.makedirs(f"{save_dir}/negatives", exist_ok=True)
+
+     # set up df for results
+     df = pd.DataFrame(columns=["frame", "boxes", "scores", "labels", "count"])
+
+     # create new dirs and paths for results
+     filename = os.path.splitext(os.path.basename(vid_path))[0]
+     frames_dir = f"{save_dir}/{filename}_frames"
+     os.makedirs(frames_dir, exist_ok=True)
+
+     # process video and create a directory of video frames
+     fps = mp4_to_png(vid_path, frames_dir, scaling_factor)
+
+     # get all frame paths in order (os.listdir returns arbitrary order)
+     frame_filenames = sorted(os.listdir(frames_dir))
+
+     frame_paths = []  # list of frame paths to process based on fps_processed
+     for i, frame in enumerate(frame_filenames):
+         if i % fps_processed == 0:
+             frame_paths.append(os.path.join(frames_dir, frame))
+
+     # run owl in batches
+     for i in tqdm(range(0, len(frame_paths), batch_size), desc="Running batches"):
+         batch_paths = frame_paths[i:i+batch_size]  # paths for this batch
+         images = [Image.open(image_path) for image_path in batch_paths]
+
+         # run owl on this batch of frames
+         text_queries = preprocess_text(text_prompt, len(batch_paths))
+         results = owl_batch_prediction(images, text_queries, threshold, processor, model, device)
+
+         # get the label ids for each image in the batch (None when nothing was detected)
+         label_ids = []
+         for entry in results:
+             if entry['labels'].numel() > 0:
+                 label_ids.append(entry['labels'].tolist())
+             else:
+                 label_ids.append(None)
+
+         text = text_queries[0]  # assuming all images in the batch share the same queries
+         labels = []
+         # convert label ids to phrases; if there are no detections, append an empty list
+         for ids in label_ids:
+             if ids is not None:
+                 labels.append([text[k] for k in ids])
+             else:
+                 labels.append([])
+
+         batch_pos = 0
+         for j, image in enumerate(batch_paths):
+             boxes = results[j]['boxes'].cpu().numpy()
+             scores = results[j]['scores'].cpu().numpy()
+             count = count_pos(labels[j], text_target)
+             row = pd.DataFrame({"frame": [image], "boxes": [boxes], "scores": [scores], "labels": [labels[j]], "count": [count]})
+             df = pd.concat([df, row], ignore_index=True)
+
+             # if there are detections, overwrite the frame with its annotated version
+             if count > 0:
+                 annotated_frame = plot_predictions(image, labels[j], scores, boxes)
+                 cv2.imwrite(image, annotated_frame)
+                 batch_pos += 1
+
+         # if more than 2/3 of the frames in this batch are positive, return True
+         # (use len(batch_paths), since the final batch may be smaller than batch_size)
+         if batch_pos > math.ceil(2/3 * len(batch_paths)):
+             vid_stitcher(frames_dir, f"{save_dir}/positives/{filename}_{threshold}.mp4", fps)
+             shutil.rmtree(frames_dir)  # delete the frames to save space
+             df.to_csv(f"{save_dir}/positives/{filename}_{threshold}.csv", index=False)
+             return True
+
+     shutil.rmtree(frames_dir)  # delete the frames to save space
+     df.to_csv(f"{save_dir}/negatives/{filename}_{threshold}.csv", index=False)
+     return False
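
owl_batch_video can also be used headless, without Gradio. A minimal sketch, assuming a CUDA GPU is available and that clips/ holds the (hypothetical) input videos:

    from owl_batch import owl_batch_video

    # writes detection_results.csv, per-video CSVs, and annotated positive
    # videos under save_dir, then returns the path to a zip of it all
    zip_path = owl_batch_video(
        ["clips/camera_a.mp4", "clips/camera_b.mp4"],
        target_prompt=["gorilla"],
        species_prompt="gorilla, chimpanzee",
        threshold=0.3,
        fps_processed=10,     # detect on every 10th frame
        scaling_factor=0.25,  # quarter-resolution frames
        batch_size=8,
        save_dir="temp_run",
    )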