|
import os
|
|
import shutil
|
|
from tqdm import tqdm
|
|
import cv2
|
|
import pandas as pd
|
|
import torch
|
|
from PIL import Image
|
|
from transformers import Owlv2Processor, Owlv2ForObjectDetection
|
|
import math
|
|
import zipfile
|
|
from utils import plot_predictions, mp4_to_png, vid_stitcher
|
|
|
|
def owl_batch_video(
    input_vids: list[str],
    target_prompt: list[str],
    species_prompt: str,
    threshold: float,
    fps_processed: int = 1,
    scaling_factor: float = 0.5,
    batch_size: int = 8,
    save_dir: str = "temp/"
):
    """Run ``owl_video_detection`` on several videos and zip everything up.

    Args:
        input_vids: Paths of the videos to process.
        target_prompt: Query strings that count as a positive detection.
        species_prompt: Comma-separated query string sent to the model.
        threshold: Score threshold forwarded to the detector.
        fps_processed: Frame subsampling step forwarded to the detector.
        scaling_factor: Frame resize factor forwarded to the detector.
        batch_size: Number of frames per model call.
        save_dir: Directory that receives per-video outputs, the summary CSV
            and the final zip archive.

    Returns:
        Path of the zip archive (``f"{save_dir}/results.zip"``). The archive
        lives inside ``save_dir``, the same directory that gets zipped.
    """
    # Make sure the output directory exists even when no videos are given,
    # otherwise writing the summary CSV below would fail.
    os.makedirs(save_dir, exist_ok=True)

    # Collect plain dicts and build the DataFrame once (avoids quadratic
    # re-concatenation and concat-with-empty-frame warnings).
    rows = []
    for vid in input_vids:
        detected = owl_video_detection(
            vid,
            target_prompt,
            species_prompt,
            threshold,
            fps_processed=fps_processed,
            scaling_factor=scaling_factor,
            batch_size=batch_size,
            save_dir=save_dir,
        )
        # The summary stores the outcome as the strings "True"/"False".
        rows.append({"video path": vid, "detection?": "True" if detected else "False"})

    df = pd.DataFrame(rows, columns=["video path", "detection?"])
    # os.path.join works whether or not save_dir ends with a separator.
    df.to_csv(os.path.join(save_dir, "detection_results.csv"))

    zip_file = f"{save_dir}/results.zip"
    zip_directory(save_dir, zip_file)

    return zip_file
|
|
|
|
|
|
|
|
def zip_directory(folder_path, output_zip_path):
    """Write every file under ``folder_path`` into a deflate-compressed zip.

    Archive member names are the files' paths relative to ``folder_path``.
    If ``output_zip_path`` itself lies inside ``folder_path`` (for example when
    a caller zips its own output directory), that file is skipped so the
    archive being written is never added into itself.

    Args:
        folder_path: Directory to archive recursively.
        output_zip_path: Destination of the zip file (overwritten if present).
    """
    output_abs = os.path.abspath(output_zip_path)
    with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, _dirs, files in os.walk(folder_path):
            for name in files:
                file_path = os.path.join(root, name)
                # The archive already exists on disk (opened in 'w' mode), so
                # os.walk can see it; never archive the archive itself.
                if os.path.abspath(file_path) == output_abs:
                    continue
                arcname = os.path.relpath(file_path, start=folder_path)
                zipf.write(file_path, arcname)
|
|
|
|
|
|
def preprocess_text(text_prompt: str, num_prompts: int = 1):
    """Split a comma-separated prompt string into one query list per image.

    Each comma-separated item is stripped of surrounding whitespace. The same
    queries are repeated ``num_prompts`` times, as independent lists.

    Example:
        preprocess_text("a, b, c", 2) -> [["a", "b", "c"], ["a", "b", "c"]]

    Args:
        text_prompt: Comma-separated query string.
        num_prompts: How many copies to return (typically the batch size).

    Returns:
        A list of ``num_prompts`` lists of query strings.
    """
    queries = [part.strip() for part in text_prompt.split(",")]
    # Give every image its own list; ``[queries] * n`` would alias one list
    # n times, so mutating any entry would silently change all of them.
    return [list(queries) for _ in range(num_prompts)]
|
|
|
|
def owl_batch_prediction(
    images: "list[Image.Image]",
    text_queries: list[list[str]],
    threshold: float,
    processor,
    model,
    device: str = 'cuda'
):
    """Run an OWLv2-style detector on a batch of images.

    Args:
        images: PIL-style images. ``img.size`` is read as ``(width, height)``
            (a ``torch.Tensor`` would not work here, because its ``.size`` is a method).
        text_queries: One list of query strings per image (see ``preprocess_text``).
        threshold: Minimum score for a box to survive post-processing.
        processor: Object that builds model inputs and post-processes outputs.
        model: Detection model, called with the processor's outputs; presumably
            it must already live on ``device``.
        device: Device that inputs and target sizes are moved to.

    Returns:
        The value of ``processor.post_process_object_detection``, with one entry
        per image. Callers read its ``"boxes"``, ``"scores"`` and ``"labels"``.
    """
    inputs = processor(text=text_queries, images=images, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); post-processing expects (height, width).
    # NOTE(review): OWLv2 pads inputs to a square before inference. Depending
    # on the transformers version, boxes may be relative to that padded square,
    # in which case (max(h, w), max(h, w)) would be the right target size. Confirm.
    target_sizes = torch.tensor(
        [img.size[::-1] for img in images], dtype=torch.float32, device=device
    )

    return processor.post_process_object_detection(
        outputs=outputs, target_sizes=target_sizes, threshold=threshold
    )
|
|
|
|
|
|
def count_pos(phrases: list[str], text_targets: list[str]) -> int:
    """
    Return how many entries of ``phrases`` exactly equal one of ``text_targets``.

    Duplicates in ``phrases`` are counted individually.

    Args:
        phrases: Candidate strings (e.g. predicted labels for one frame).
        text_targets: Strings that count as a match.

    Returns:
        The number of matching phrases; 0 if either input is empty.
    """
    if not len(phrases) or not len(text_targets):
        return 0

    wanted = frozenset(text_targets)
    matches = 0
    for candidate in phrases:
        if candidate in wanted:
            matches += 1
    return matches
|
|
|
|
|
|
_OWL_CHECKPOINT = "google/owlv2-base-patch16-ensemble"
# Default (processor, model) pairs keyed by device, loaded lazily on first use
# instead of at import time.
_OWL_DEFAULTS = {}


def _default_owl(device: str):
    """Return the default OWLv2 processor and model on ``device``, loading them once."""
    if device not in _OWL_DEFAULTS:
        processor = Owlv2Processor.from_pretrained(_OWL_CHECKPOINT)
        model = Owlv2ForObjectDetection.from_pretrained(_OWL_CHECKPOINT).to(device)
        _OWL_DEFAULTS[device] = (processor, model)
    return _OWL_DEFAULTS[device]


def _natural_key(name: str):
    """Sort key that orders embedded integers numerically ("f2" before "f10")."""
    import re

    # re.split with a capturing group alternates text / digit-run, so the odd
    # positions are always digit runs and can be compared as ints.
    parts = re.split(r"(\d+)", name)
    return [int(part) if i % 2 else part for i, part in enumerate(parts)]


def _load_images(paths):
    """Read each image fully into memory and close its file handle right away."""
    images = []
    for path in paths:
        with Image.open(path) as img:
            images.append(img.copy())
    return images


def _decode_labels(results, query_names):
    """Translate each result's label indices into the query strings they index."""
    return [[query_names[k] for k in entry["labels"].tolist()] for entry in results]


def owl_video_detection(
    vid_path: str,
    text_target: list[str],
    text_prompt: str,
    threshold: float,
    fps_processed: int = 1,
    scaling_factor: float = 0.5,
    processor=None,
    model=None,
    device: str = 'cuda',
    batch_size: int = 8,
    save_dir: str = "temp/",
):
    """
    Run OWLv2 over the frames of one video and report whether a target was seen.

    The video is split into frames in ``{save_dir}/{name}_frames``. Every
    ``fps_processed``-th frame, in natural filename order, is run through the
    model in batches. Frames with at least one target detection are overwritten
    with an annotated version. A batch is treated as a detection when more than
    ``ceil(2/3 * frames_in_batch)`` of its frames contain a target. Processing
    then stops and the frames are stitched into
    ``{save_dir}/positives/{name}_{threshold}.mp4``.

    Per-frame results are written to ``{save_dir}/positives/`` or
    ``{save_dir}/negatives/`` as ``{name}_{threshold}.csv``. The frames
    directory is always removed afterwards, even on error.

    Args:
        vid_path: Path of the input video.
        text_target: Query strings that count as a positive detection.
        text_prompt: Comma-separated query string sent to the model.
        threshold: Score threshold for keeping boxes.
        fps_processed: Subsampling step over the extracted frames.
        scaling_factor: Resize factor passed to the frame extractor.
        processor: OWLv2 processor. When ``None``, a shared default is loaded
            lazily.
        model: OWLv2 model. When ``None``, a shared default is loaded lazily
            onto ``device``.
        device: Device for model inputs (and the default model).
        batch_size: Number of frames per model call.
        save_dir: Output directory.

    Returns:
        True if a target was detected, False otherwise.
    """
    if processor is None or model is None:
        default_processor, default_model = _default_owl(device)
        if processor is None:
            processor = default_processor
        if model is None:
            model = default_model

    positives_dir = os.path.join(save_dir, "positives")
    negatives_dir = os.path.join(save_dir, "negatives")
    os.makedirs(positives_dir, exist_ok=True)
    os.makedirs(negatives_dir, exist_ok=True)

    filename = os.path.splitext(os.path.basename(vid_path))[0]
    frames_dir = os.path.join(save_dir, f"{filename}_frames")
    os.makedirs(frames_dir, exist_ok=True)
    stem = f"{filename}_{threshold}"

    rows = []
    detected = False
    try:
        fps = mp4_to_png(vid_path, frames_dir, scaling_factor)

        # os.listdir order is arbitrary; sort so frames are processed in video order.
        frame_names = sorted(os.listdir(frames_dir), key=_natural_key)
        frame_paths = [
            os.path.join(frames_dir, name)
            for i, name in enumerate(frame_names)
            if i % fps_processed == 0
        ]

        for start in tqdm(range(0, len(frame_paths), batch_size), desc="Running batches"):
            batch_paths = frame_paths[start:start + batch_size]
            images = _load_images(batch_paths)

            text_queries = preprocess_text(text_prompt, len(batch_paths))
            results = owl_batch_prediction(images, text_queries, threshold, processor, model, device)
            labels = _decode_labels(results, text_queries[0])

            batch_pos = 0
            for frame_path, result, frame_labels in zip(batch_paths, results, labels):
                boxes = result['boxes'].cpu().numpy()
                scores = result['scores'].cpu().numpy()
                count = count_pos(frame_labels, text_target)
                rows.append({
                    "frame": frame_path,
                    "boxes": boxes,
                    "scores": scores,
                    "labels": frame_labels,
                    "count": count,
                })

                if count > 0:
                    # Overwrite the frame so the stitched video shows the boxes.
                    annotated_frame = plot_predictions(frame_path, frame_labels, scores, boxes)
                    cv2.imwrite(frame_path, annotated_frame)
                    batch_pos += 1

            # Judge against the real batch length, so a short final batch (or a
            # short video) can still trigger a detection.
            if batch_pos > math.ceil(2 / 3 * len(batch_paths)):
                vid_stitcher(frames_dir, os.path.join(positives_dir, f"{stem}.mp4"), fps)
                detected = True
                break
    finally:
        # Temporary frames are not needed once results are recorded.
        shutil.rmtree(frames_dir, ignore_errors=True)

    out_dir = positives_dir if detected else negatives_dir
    df = pd.DataFrame(rows, columns=["frame", "boxes", "scores", "labels", "count"])
    df.to_csv(os.path.join(out_dir, f"{stem}.csv"), index=False)
    return detected
|
|
|
|
|