File size: 5,979 Bytes
19fe404
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import argparse
import os
import re
from tqdm import tqdm

import pandas as pd
from vllm import LLM, SamplingParams

from utils.logger import logger


def parse_args():
    """Parse the command-line options for the caption summarization script.

    Returns:
        argparse.Namespace: The parsed options. ``--video_metadata_path`` and
        ``--saved_path`` are required; every other option has a default.
    """
    parser = argparse.ArgumentParser(description="Recaption the video frame.")
    parser.add_argument(
        "--video_metadata_path", type=str, required=True, help="The path to the video dataset metadata (csv/jsonl)."
    )
    parser.add_argument(
        "--video_path_column",
        type=str,
        default="video_path",
        help="The column containing the video path (an absolute path or a relative path w.r.t the video_folder).",
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="sampled_frame_caption",
        help="The column containing the sampled frame caption.",
    )
    parser.add_argument(
        "--remove_quotes",
        action="store_true",
        help="Whether to remove quoted spans (text enclosed in matching ' or \" quotes) from the caption.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=10,
        required=False,
        help="The number of captions sent to the summary model per generate() call.",
    )
    parser.add_argument(
        "--summary_model_name",
        type=str,
        default="mistralai/Mistral-7B-Instruct-v0.2",
        help="The model name or path passed to vLLM for summarization.",
    )
    parser.add_argument(
        "--summary_prompt",
        type=str,
        # Adjacent string literals are concatenated verbatim, so every fragment
        # must carry its own trailing space.
        default=(
            "You are a helpful video description generator. I'll give you a description of the middle frame of the video clip, "
            "which you need to summarize into a description of the video clip. "
            "Please provide your video description following these requirements: "
            "1. Describe the basic and necessary information of the video in the third person, be as concise as possible. "
            "2. Output the video description directly. Begin with 'In this video'. "
            "3. Limit the video description within 100 words. "
            "Here is the mid-frame description: "
        ),
        help="The instruction prepended to each frame caption.",
    )
    parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).")
    parser.add_argument(
        "--saved_freq",
        type=int,
        default=1000,
        help="Append the accumulated results to saved_path every this many batches (<= 0 saves only at the end).",
    )

    return parser.parse_args()


def _read_table(path, arg_name):
    """Load a csv or jsonl table from ``path``; raise ValueError naming ``arg_name`` otherwise."""
    if path.endswith(".csv"):
        return pd.read_csv(path)
    if path.endswith(".jsonl"):
        return pd.read_json(path, lines=True)
    raise ValueError(f"The {arg_name} must end with .csv or .jsonl.")


def _clean_caption(caption, remove_quotes):
    """Return ``caption`` as text, optionally with every quoted span removed."""
    # Stringify first: a missing CSV cell is read as float NaN, which re.sub rejects.
    caption = str(caption)
    if remove_quotes:
        # Drops each '...' or "..." span together with its quotes (non-greedy,
        # matching the same quote character that opened the span).
        caption = re.sub(r'(["\']).*?\1', "", caption)
    return caption


def _empty_results():
    """Return a fresh accumulator for output rows, keyed by output column name."""
    return {"video_path": [], "summary_model": [], "summary_caption": []}


def _append_results(result_dict, saved_path):
    """Append the accumulated rows to ``saved_path`` (csv/jsonl); no-op when there are none."""
    result_df = pd.DataFrame(result_dict)
    if result_df.empty:
        return
    if saved_path.endswith(".csv"):
        # Write the header only when creating the file, so chunks form one table.
        result_df.to_csv(saved_path, header=not os.path.exists(saved_path), index=False, mode="a")
    else:
        # mode="a" for to_json requires pandas >= 2.0.
        result_df.to_json(saved_path, orient="records", lines=True, mode="a")


def main():
    """Summarize per-frame captions into video-level captions with a vLLM model.

    Reads ``--video_metadata_path`` (csv/jsonl) and builds a prompt for each
    row's ``--caption_column`` text (``--summary_prompt`` + caption). The prompts
    go to ``--summary_model_name`` in batches of ``--batch_size``. Rows with the
    columns ``video_path``, ``summary_model`` and ``summary_caption`` are appended
    to ``--saved_path`` (csv/jsonl) every ``--saved_freq`` batches and at the end.
    If ``--saved_path`` already exists, videos already listed in it are skipped,
    so an interrupted run can be resumed.

    Raises:
        ValueError: If either path does not end with ``.csv`` or ``.jsonl``.
    """
    args = parse_args()

    video_metadata_df = _read_table(args.video_metadata_path, "video_metadata_path")

    if not args.saved_path.endswith((".csv", ".jsonl")):
        raise ValueError("The saved_path must end with .csv or .jsonl.")

    if os.path.exists(args.saved_path):
        saved_metadata_df = _read_table(args.saved_path, "saved_path")
        # The output file always stores paths under "video_path" (see
        # _empty_results), whatever --video_path_column is for the input.
        processed = set(saved_metadata_df["video_path"].tolist())
        # Boolean mask keeps input order and keeps paths aligned with their
        # captions (a set difference + .loc lookup would not, e.g. on duplicates).
        pending = ~video_metadata_df[args.video_path_column].isin(processed)
        video_metadata_df = video_metadata_df[pending]
        logger.info(
            f"Resume from {args.saved_path}: {len(saved_metadata_df)} processed "
            f"and {len(video_metadata_df)} to be processed."
        )

    video_path_list = video_metadata_df[args.video_path_column].tolist()
    sampled_frame_caption_list = video_metadata_df[args.caption_column].tolist()

    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256)
    summary_model = LLM(model=args.summary_model_name, trust_remote_code=True)

    result_dict = _empty_results()
    batch_starts = range(0, len(sampled_frame_caption_list), args.batch_size)
    for batch_idx, start in enumerate(tqdm(batch_starts)):
        end = start + args.batch_size
        batch_video_path = video_path_list[start:end]
        batch_caption = sampled_frame_caption_list[start:end]
        batch_prompt = [
            "user:" + args.summary_prompt + _clean_caption(caption, args.remove_quotes) + "\n assistant:"
            for caption in batch_caption
        ]
        batch_output = summary_model.generate(batch_prompt, sampling_params)

        result_dict["video_path"].extend(batch_video_path)
        result_dict["summary_model"].extend([args.summary_model_name] * len(batch_caption))
        result_dict["summary_caption"].extend([output.outputs[0].text.rstrip() for output in batch_output])

        # Flush every args.saved_freq batches; guard avoids modulo-by-zero.
        if args.saved_freq > 0 and (batch_idx + 1) % args.saved_freq == 0:
            _append_results(result_dict, args.saved_path)
            logger.info(f"Save result to {args.saved_path}.")
            result_dict = _empty_results()

    _append_results(result_dict, args.saved_path)
    logger.info(f"Save the final result to {args.saved_path}.")


if __name__ == "__main__":
    main()