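"""Recaption videos by captioning sampled frames with an image captioning model.

Depending on --image_caption_model_name, inference runs either through Hugging Face
Accelerate (Qwen-VL-Chat, internlm-xcomposer2-vl-7b) or through an SGLang runtime
(llava-v1.6-vicuna-7b). Captions are written to --saved_path as csv/jsonl.

Example invocation (the script name below is illustrative):
    accelerate launch caption_video_frame.py \
        --video_folder /path/to/videos \
        --output_dir /path/to/processed_marks \
        --saved_path /path/to/result.jsonl
"""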
import argparse
import copy
import os
import pandas as pd
from accelerate import PartialState
from accelerate.utils import gather_object
from natsort import natsorted
from tqdm import tqdm
from torch.utils.data import DataLoader
from utils.logger import logger
from utils.video_dataset import VideoDataset, collate_fn
from utils.video_utils import get_video_path_list, extract_frames
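
# Captioning models grouped by serving backend: the first list is served through
# Hugging Face Accelerate (see accelerate_inference), the second through SGLang
# (see sglang_inference).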
ACCELERATE_SUPPORTED_MODELS = ["Qwen-VL-Chat", "internlm-xcomposer2-vl-7b"]
SGLANG_SUPPORTED_MODELS = ["llava-v1.6-vicuna-7b"]


def parse_args():
    parser = argparse.ArgumentParser(description="Recaption sampled video frames.")
parser.add_argument("--video_folder", type=str, default="", help="The video folder.")
parser.add_argument(
"--video_metadata_path", type=str, default=None, help="The path to the video dataset metadata (csv/jsonl/txt)."
)
parser.add_argument(
"--video_path_column",
type=str,
default="video_path",
help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).",
)
parser.add_argument(
"--batch_size",
type=int,
default=10,
required=False,
help="The batch size for the video dataset.",
)
    parser.add_argument(
        "--frame_sample_method",
        type=str,
        choices=["mid", "uniform"],
        default="mid",
        help="The method used to sample frames from each video.",
    )
parser.add_argument(
"--num_sampled_frames",
type=int,
default=1,
help="num_sampled_frames",
)
parser.add_argument(
"--image_caption_model_name",
type=str,
choices=ACCELERATE_SUPPORTED_MODELS + SGLANG_SUPPORTED_MODELS,
default="internlm-xcomposer2-vl-7b",
)
    parser.add_argument(
        "--image_caption_model_quantized",
        # argparse's type=bool treats any non-empty string (e.g. "False") as True,
        # so parse the value explicitly.
        type=lambda x: str(x).lower() in ("true", "1"),
        default=True,
        help="Whether to use the quantized image caption model.",
    )
parser.add_argument(
"--image_caption_prompt",
type=str,
default="Describe this image and its style in a very detailed manner.",
)
parser.add_argument(
"--output_dir",
type=str,
required=True,
help="The directory to create the subfolder (named with the video name) to indicate the video has been processed.",
)
parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).")
parser.add_argument("--saved_freq", type=int, default=1000, help="The frequency to save the output results.")
args = parser.parse_args()
return args


def accelerate_inference(args, video_path_list):
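    """Caption sampled frames with Qwen-VL-Chat or InternLM-XComposer2 via Accelerate.

    video_path_list is split across processes; each process captions its own shard,
    and results are gathered to the main process before being saved.
    """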
from utils.image_captioner_awq import QwenVLChat, InternLMXComposer2
state = PartialState()
device = state.device
if state.num_processes == 1:
device = "cuda:0"
if args.image_caption_model_name == "internlm-xcomposer2-vl-7b":
image_caption_model = InternLMXComposer2(device=device, quantized=args.image_caption_model_quantized)
elif args.image_caption_model_name == "Qwen-VL-Chat":
image_caption_model = QwenVLChat(device=device, quantized=args.image_caption_model_quantized)
# The workaround can be removed after https://github.com/huggingface/accelerate/pull/2781 is released.
index = len(video_path_list) - len(video_path_list) % state.num_processes
logger.info(f"Drop {len(video_path_list) % state.num_processes} videos to avoid duplicates in state.split_between_processes.")
video_path_list = video_path_list[:index]
if state.is_main_process:
os.makedirs(args.output_dir, exist_ok=True)
result_list = []
with state.split_between_processes(video_path_list) as splitted_video_path_list:
        for i, video_path in enumerate(tqdm(splitted_video_path_list, desc=f"{state.device}")):
            # Save periodically at the top of the loop, before any `continue` below,
            # so that every process reaches the collective gather and no rank is left waiting.
            if i != 0 and i % args.saved_freq == 0:
                state.wait_for_everyone()
                gathered_result_list = gather_object(result_list)
                if state.is_main_process:
                    result_df = pd.DataFrame(gathered_result_list)
                    if args.saved_path.endswith(".csv"):
                        result_df.to_csv(args.saved_path, index=False)
                    elif args.saved_path.endswith(".jsonl"):
                        result_df.to_json(args.saved_path, orient="records", lines=True)
                    logger.info(f"Save the intermediate result to {args.saved_path}.")
            video_id = os.path.splitext(os.path.basename(video_path))[0]
            try:
                if not os.path.exists(video_path):
                    logger.warning(f"Video {video_id} does not exist. Skip it.")
                    continue
                sampled_frame_list, sampled_frame_idx_list = extract_frames(
                    video_path, sample_method=args.frame_sample_method, num_sampled_frames=args.num_sampled_frames
                )
            except Exception as e:
                logger.error(f"Failed to extract frames from video {video_id}. Error is {e}.")
                continue
            video_recaption_output_dir = os.path.join(args.output_dir, video_id)
            if os.path.exists(video_recaption_output_dir):
                logger.info(f"Video {video_id} has been processed. Skip it.")
                continue
            os.makedirs(video_recaption_output_dir)
caption_list = []
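            # Save each sampled frame to a temporary image, caption it, and delete the image.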
for frame, frame_idx in zip(sampled_frame_list, sampled_frame_idx_list):
frame_path = f"{args.output_dir}/{video_id}_{frame_idx}.png"
frame.save(frame_path)
                try:
                    response, _ = image_caption_model(args.image_caption_prompt, frame_path)
                    caption_list.append(response)
                except Exception as e:
                    logger.error(f"Failed to caption video {video_id}. Error is {e}.")
                finally:
                    os.remove(frame_path)
result_meta = {}
if args.video_folder == "":
result_meta[args.video_path_column] = video_path
else:
result_meta[args.video_path_column] = os.path.basename(video_path)
result_meta["image_caption_model"] = args.image_caption_model_name
result_meta["prompt"] = args.image_caption_prompt
result_meta["sampled_frame_idx"] = sampled_frame_idx_list
result_meta["sampled_frame_caption"] = caption_list
result_list.append(copy.deepcopy(result_meta))
# Wait for all processes to finish and gather the final result.
state.wait_for_everyone()
gathered_result_list = gather_object(result_list)
# Save the metadata in the main process.
if state.is_main_process:
result_df = pd.DataFrame(gathered_result_list)
if args.saved_path.endswith(".csv"):
result_df.to_csv(args.saved_path, index=False)
elif args.saved_path.endswith(".jsonl"):
result_df.to_json(args.saved_path, orient="records", lines=True)
print(f"Save the final result to {args.saved_path}.")


def sglang_inference(args, video_path_list):
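    """Caption sampled frames in batches with an SGLang-served LLaVA model.

    Frames are loaded through VideoDataset/DataLoader, and results are appended
    to args.saved_path every args.saved_freq batches.
    """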
from utils.image_captioner_sglang import LLaVASRT
if args.image_caption_model_name == "llava-v1.6-vicuna-7b":
image_caption_model = LLaVASRT()
    result_dict = {
        "video_path": [],
        "image_caption_model": [],
        "prompt": [],
        "sampled_frame_idx": [],
        "sampled_frame_caption": [],
    }
video_dataset = VideoDataset(
video_path_list=video_path_list,
sample_method=args.frame_sample_method,
num_sampled_frames=args.num_sampled_frames
)
video_loader = DataLoader(video_dataset, batch_size=args.batch_size, num_workers=16, collate_fn=collate_fn)
for idx, batch in enumerate(tqdm(video_loader)):
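        # A batch may be empty if every video in it failed to load.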
if len(batch) == 0:
continue
batch_video_path, batch_frame_idx = batch["video_path"], batch["sampled_frame_idx"]
# [batch_size, num_sampled_frames, H, W, C] => [batch_size * num_sampled_frames, H, W, C].
batch_frame = []
for item_sampled_frame in batch["sampled_frame"]:
batch_frame.extend([frame for frame in item_sampled_frame])
        try:
            response_list, _ = image_caption_model([args.image_caption_prompt] * len(batch_frame), batch_frame)
            # Regroup the flat caption list back to per-video lists of num_sampled_frames captions.
            response_list = [
                response_list[i:i + args.num_sampled_frames] for i in range(0, len(response_list), args.num_sampled_frames)
            ]
        except Exception as e:
            logger.error(f"Failed to caption videos {batch_video_path}. Error is {e}.")
            continue
result_dict["video_path"].extend(batch_video_path)
result_dict["image_caption_model"].extend([args.image_caption_model_name] * len(batch_video_path))
result_dict["prompt"].extend([args.image_caption_prompt] * len(batch_video_path))
result_dict["sampled_frame_idx"].extend(batch_frame_idx)
result_dict["sampled_frame_caption"].extend(response_list)
        # Save the metadata periodically and reset the accumulator.
        if idx != 0 and idx % args.saved_freq == 0:
            result_df = pd.DataFrame(result_dict)
            if args.saved_path.endswith(".csv"):
                header = not os.path.exists(args.saved_path)
                result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
            elif args.saved_path.endswith(".jsonl"):
                result_df.to_json(args.saved_path, orient="records", lines=True, mode="a")
            logger.info(f"Save the intermediate result to {args.saved_path}.")
            result_dict = {key: [] for key in result_dict}
if len(result_dict["video_path"]) != 0:
result_df = pd.DataFrame(result_dict)
if args.saved_path.endswith(".csv"):
            header = not os.path.exists(args.saved_path)
result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
elif args.saved_path.endswith(".jsonl"):
result_df.to_json(args.saved_path, orient="records", lines=True, mode="a")
logger.info(f"Save the final result to {args.saved_path}.")


def main():
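    """Build the video list, skip already-processed videos, and dispatch inference."""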
args = parse_args()
video_path_list = get_video_path_list(
video_folder=args.video_folder,
video_metadata_path=args.video_metadata_path,
video_path_column=args.video_path_column
)
if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")):
raise ValueError("The saved_path must end with .csv or .jsonl.")
    if os.path.exists(args.saved_path):
        if args.saved_path.endswith(".csv"):
            saved_metadata_df = pd.read_csv(args.saved_path)
        elif args.saved_path.endswith(".jsonl"):
            saved_metadata_df = pd.read_json(args.saved_path, lines=True)
        saved_video_path_list = saved_metadata_df[args.video_path_column].tolist()
        saved_video_path_list = [os.path.join(args.video_folder, path) for path in saved_video_path_list]
        video_path_list = list(set(video_path_list) - set(saved_video_path_list))
        logger.info(f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed.")
    # Sort so that every process sees the same order before the list is split.
    video_path_list = natsorted(video_path_list)
if args.image_caption_model_name in SGLANG_SUPPORTED_MODELS:
sglang_inference(args, video_path_list)
elif args.image_caption_model_name in ACCELERATE_SUPPORTED_MODELS:
accelerate_inference(args, video_path_list)
else:
raise ValueError(f"The {args.image_caption_model_name} is not supported.")


if __name__ == "__main__":
main()