import torch |
import torch.quantization |
from tqdm import tqdm |
import cv2 |
import os |
import numpy as np |
import pandas as pd |
from datetime import datetime |
from typing import Tuple |
from PIL import Image |
from utils import plot_predictions, mp4_to_png, vid_stitcher |
from transformers import Owlv2Processor, Owlv2ForObjectDetection |
def preprocess_text(text_prompt: str, num_prompts: int = 1):
    """
    Split a comma-separated prompt string into one query list per image.

    Each comma-separated piece is stripped of surrounding whitespace, and the
    resulting list is repeated ``num_prompts`` times, e.g.
    ``preprocess_text("a, b, c", 2) -> [["a", "b", "c"], ["a", "b", "c"]]``.

    Every inner list is an independent copy, so changing the queries of one
    image cannot silently change the queries of every other image.
    A ``num_prompts`` of zero (or less) yields an empty list.
    """
    queries = [s.strip() for s in text_prompt.split(",")]
    return [list(queries) for _ in range(num_prompts)]
def owl_batch_prediction(
    images: list,
    text_queries : list[list[str]],
    threshold: float,
    processor,
    model,
    device: str = 'cuda'
):
    """Run OWLv2 text-conditioned detection on a batch of images.

    Args:
        images: PIL images (anything whose ``.size`` is ``(width, height)``).
        text_queries: one list of query strings per image, e.g. as built by
            ``preprocess_text``.
        threshold: minimum score for a detection to be kept.
        processor: the OWLv2 processor used to build model inputs and to
            post-process outputs.
        model: the OWLv2 detection model; presumably already on ``device``.
        device: device the processed inputs and target sizes are moved to.

    Returns:
        Whatever ``processor.post_process_object_detection`` returns: one
        result entry per image, with boxes in the original image's pixel
        coordinates.
    """
    inputs = processor(text=text_queries, images=images, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # OWLv2 pads each image to a square (bottom/right) before inference, so the
    # predicted boxes are normalised to that padded square. Rescaling both axes
    # by the square's side, max(width, height), maps them back onto the original
    # unpadded image (its top-left corner is the square's). Passing (height,
    # width) instead squashes boxes along the shorter axis of non-square frames.
    # A square (side, side) target is also correct on transformers versions
    # that already compensate for the padding themselves, since max(side, side)
    # is still side.
    sides = [max(img.size) for img in images]
    target_sizes = torch.Tensor([[side, side] for side in sides]).to(device)
    results = processor.post_process_object_detection(
        outputs=outputs, target_sizes=target_sizes, threshold=threshold
    )
    return results
def _natural_sort_key(name: str):
    """Sort key that orders embedded integers numerically ("f2" < "f10")."""
    import re

    # re.split with a capturing group alternates text / digit-run, so odd
    # positions are always digit runs and each position compares like-typed values.
    parts = re.split(r"(\d+)", name)
    return [int(part) if i % 2 else part for i, part in enumerate(parts)]


def _plan_frames(frames_dir: str, frame_filenames: list, fps_processed: int):
    """Select every ``fps_processed``-th frame for inference.

    Returns ``(frame_paths, annotation_guide)``. ``frame_paths`` lists the
    selected frames in order. ``annotation_guide`` maps each selected frame
    to the skipped frames that follow it, so they can reuse its detections.
    """
    frame_paths = []
    annotation_guide = {}
    key_path = None
    for i, name in enumerate(frame_filenames):
        path = os.path.join(frames_dir, name)
        if i % fps_processed == 0:  # always true for i == 0, so key_path is set first
            key_path = path
            frame_paths.append(path)
            annotation_guide[path] = []
        else:
            annotation_guide[key_path].append(path)
    return frame_paths, annotation_guide


def _load_images(paths: list) -> list:
    """Read each image fully into memory and close its file handle."""
    images = []
    for path in paths:
        with Image.open(path) as img:
            images.append(img.copy())  # copy() loads pixel data before the file closes
    return images


def _ids_to_labels(results, text: list) -> list:
    """Map each result's label ids to query strings; ``None`` for no detections."""
    labels = []
    for entry in results:
        ids = entry['labels']
        labels.append([text[k] for k in ids.tolist()] if ids.numel() > 0 else None)
    return labels


def owl_full_video(
    vid_path: str,
    text_prompt: str,
    threshold: float,
    fps_processed: int = 1,
    scaling_factor: float = 0.5,
    device: str = 'cuda',
    batch_size: int = 6,
):
    """ Same as owl_video, but processes the entire video regardless of detection bool.
    Saves results per frame to a df.

    The video is split into frames under ``temp/<name>_<HHMMSS>/frames``.
    Inference runs on every ``fps_processed``-th frame, in frame order. Each
    processed frame with detections is overwritten with its annotations. The
    skipped frames that follow it are overwritten with the same detections,
    drawn at opacity 0.3. The frames are then stitched back into
    ``output.mp4``.

    Args:
        vid_path: path of the input video.
        text_prompt: comma-separated object queries, e.g. ``"cat, dog"``.
        threshold: minimum detection score to keep.
        fps_processed: stride between processed frames (must be >= 1).
        scaling_factor: forwarded to ``mp4_to_png``; presumably a frame
            resize factor — TODO confirm.
        device: device the model and its inputs are placed on.
        batch_size: number of frames per inference batch (must be >= 1).

    Returns:
        ``(csv_path, save_path)``. ``csv_path`` is a CSV with one row
        (frame, boxes, scores, labels) per processed frame. ``save_path`` is
        whatever ``vid_stitcher`` returns.

    Raises:
        ValueError: if ``fps_processed`` or ``batch_size`` is < 1, or no frames
            were extracted from the video.
    """
    if fps_processed < 1:
        raise ValueError(f"fps_processed must be >= 1, got {fps_processed}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
    model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").half().to(device)

    filename = os.path.splitext(os.path.basename(vid_path))[0]
    results_dir = f'temp/{filename}_{datetime.now().strftime("%H%M%S")}'
    frames_dir = os.path.join(results_dir, "frames")
    os.makedirs(frames_dir, exist_ok=True)  # also creates results_dir

    fps = mp4_to_png(vid_path, frames_dir, scaling_factor)
    # os.listdir order is arbitrary; frame sampling and propagation need real order.
    frame_filenames = sorted(os.listdir(frames_dir), key=_natural_sort_key)
    if not frame_filenames:
        raise ValueError(f"No frames were extracted from {vid_path!r}")

    frame_paths, annotation_guide = _plan_frames(frames_dir, frame_filenames, fps_processed)

    rows = []
    detections = {}  # processed frame path -> (labels, scores, boxes)
    for start in tqdm(range(0, len(frame_paths), batch_size), desc="Running batches"):
        batch_paths = frame_paths[start:start + batch_size]
        images = _load_images(batch_paths)
        text_queries = preprocess_text(text_prompt, len(batch_paths))
        results = owl_batch_prediction(images, text_queries, threshold, processor, model, device)
        labels = _ids_to_labels(results, text_queries[0])

        for path, result, frame_labels in zip(batch_paths, results, labels):
            boxes = result['boxes'].cpu().numpy()
            scores = result['scores'].cpu().numpy()
            rows.append({"frame": path, "boxes": boxes, "scores": scores, "labels": frame_labels})
            detections[path] = (frame_labels, scores, boxes)
            if frame_labels is not None:
                annotated_frame = plot_predictions(path, frame_labels, scores, boxes)
                cv2.imwrite(path, annotated_frame)

    # Reuse each processed frame's detections on the skipped frames that follow it.
    for key_path, followers in annotation_guide.items():
        frame_labels, scores, boxes = detections[key_path]
        if not frame_labels:
            continue
        for frame in followers:
            annotated_frame = plot_predictions(frame, frame_labels, scores, boxes, opacity=0.3)
            cv2.imwrite(frame, annotated_frame)

    df = pd.DataFrame(rows, columns=["frame", "boxes", "scores", "labels"])
    csv_path = f"{results_dir}/{filename}_{threshold}.csv"
    df.to_csv(csv_path, index=False)
    save_path = vid_stitcher(frames_dir, output_path=os.path.join(results_dir, "output.mp4"), fps=fps)
    return csv_path, save_path