# bubbliiiing
# Create Code
# 19fe404
import argparse
import os
import re
from tqdm import tqdm
import pandas as pd
from vllm import LLM, SamplingParams
from utils.logger import logger
def parse_args():
    """Parse command-line arguments for the frame-caption summarization script.

    Returns:
        argparse.Namespace: The parsed arguments. ``--video_metadata_path`` and
        ``--saved_path`` are required; every other option has a default.
    """
    parser = argparse.ArgumentParser(description="Recaption the video frame.")
    parser.add_argument(
        "--video_metadata_path", type=str, required=True, help="The path to the video dataset metadata (csv/jsonl)."
    )
    parser.add_argument(
        "--video_path_column",
        type=str,
        default="video_path",
        help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).",
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="sampled_frame_caption",
        help="The column contains the sampled_frame_caption.",
    )
    parser.add_argument(
        "--remove_quotes",
        action="store_true",
        help="Whether to remove quotes from caption.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=10,
        required=False,
        help="The batch size for the video caption.",
    )
    parser.add_argument(
        "--summary_model_name",
        type=str,
        default="mistralai/Mistral-7B-Instruct-v0.2",
        help="The model name or path used to summarize the frame caption.",
    )
    parser.add_argument(
        "--summary_prompt",
        type=str,
        # Adjacent string literals are concatenated with no separator, so each
        # fragment must carry its own trailing space.
        default=(
            "You are a helpful video description generator. I'll give you a description of the middle frame of the video clip, "
            "which you need to summarize it into a description of the video clip. "
            "Please provide your video description following these requirements: "
            "1. Describe the basic and necessary information of the video in the third person, be as concise as possible. "
            "2. Output the video description directly. Begin with 'In this video'. "
            "3. Limit the video description within 100 words. "
            "Here is the mid-frame description: "
        ),
    )
    parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).")
    parser.add_argument(
        "--saved_freq", type=int, default=1000, help="The frequency (in batches) to save the output results."
    )
    args = parser.parse_args()
    return args
def _read_table(path):
    """Read a ``.csv`` or ``.jsonl`` file into a DataFrame.

    The caller is expected to have validated the suffix already; any path that
    does not end in ``.csv`` is read as JSON Lines.
    """
    if path.endswith(".csv"):
        return pd.read_csv(path)
    return pd.read_json(path, lines=True)


def _empty_result_dict():
    """Return a fresh, empty column buffer for summarized captions."""
    return {"video_path": [], "summary_model": [], "summary_caption": []}


def _append_results(result_dict, saved_path):
    """Append the buffered rows in ``result_dict`` to ``saved_path`` (csv or jsonl).

    For CSV output the header is written only when the file does not exist yet,
    so repeated appends produce a single header followed by data rows.
    """
    result_df = pd.DataFrame(result_dict)
    if saved_path.endswith(".csv"):
        result_df.to_csv(saved_path, header=not os.path.exists(saved_path), index=False, mode="a")
    elif saved_path.endswith(".jsonl"):
        result_df.to_json(saved_path, orient="records", lines=True, mode="a")


def main():
    """Summarize per-video frame captions with an LLM and append them to ``--saved_path``.

    Reads ``--video_metadata_path`` (csv/jsonl). If ``--saved_path`` already exists,
    the videos it lists are skipped (resume). For each remaining caption, a prompt is
    built from ``--summary_prompt``. The prompts are sent to vLLM in batches of
    ``--batch_size``, and results are appended every ``--saved_freq`` batches and
    once more at the end.

    Raises:
        ValueError: If either path does not end with ``.csv`` or ``.jsonl``.
    """
    args = parse_args()
    if not args.video_metadata_path.endswith((".csv", ".jsonl")):
        raise ValueError("The video_metadata_path must end with .csv or .jsonl.")
    video_metadata_df = _read_table(args.video_metadata_path)

    if not args.saved_path.endswith((".csv", ".jsonl")):
        raise ValueError("The saved_path must end with .csv or .jsonl.")

    resumed = os.path.exists(args.saved_path)
    if resumed:
        saved_metadata_df = _read_table(args.saved_path)
        # Rows are always written under the "video_path" column (see
        # _empty_result_dict), regardless of --video_path_column.
        saved_video_path_list = saved_metadata_df["video_path"].tolist()
        # A boolean filter preserves the original row order and keeps every path
        # aligned with its own caption, even when a path appears more than once.
        pending = ~video_metadata_df[args.video_path_column].isin(saved_video_path_list)
        video_metadata_df = video_metadata_df[pending]

    video_path_list = video_metadata_df[args.video_path_column].tolist()
    sampled_frame_caption_list = video_metadata_df[args.caption_column].tolist()
    if resumed:
        logger.info(f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed.")

    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256)
    summary_model = LLM(model=args.summary_model_name, trust_remote_code=True)
    # Removes every quoted span: the quote characters and the text between them.
    quote_pattern = re.compile(r'(["\']).*?\1')

    result_dict = _empty_result_dict()
    for i in tqdm(range(0, len(sampled_frame_caption_list), args.batch_size)):
        batch_video_path = video_path_list[i : i + args.batch_size]
        batch_caption = sampled_frame_caption_list[i : i + args.batch_size]
        batch_prompt = []
        for caption in batch_caption:
            # Convert to str first: pandas may load a missing caption as float NaN,
            # which re.sub would reject.
            caption = str(caption)
            if args.remove_quotes:
                caption = quote_pattern.sub("", caption)
            batch_prompt.append("user:" + args.summary_prompt + caption + "\n assistant:")
        batch_output = summary_model.generate(batch_prompt, sampling_params)
        result_dict["video_path"].extend(batch_video_path)
        result_dict["summary_model"].extend([args.summary_model_name] * len(batch_caption))
        result_dict["summary_caption"].extend([output.outputs[0].text.rstrip() for output in batch_output])

        # Save the metadata every args.saved_freq batches (i // batch_size is the batch index).
        if i != 0 and (i // args.batch_size) % args.saved_freq == 0:
            _append_results(result_dict, args.saved_path)
            logger.info(f"Save result to {args.saved_path}.")
            result_dict = _empty_result_dict()

    _append_results(result_dict, args.saved_path)
    logger.info(f"Save the final result to {args.saved_path}.")
if __name__ == "__main__":
    # Run the summarization pipeline only when executed as a script, not on import.
    main()