File size: 8,373 Bytes
b7c7aa0
 
dc465b0
b7c7aa0
ee6d0d7
b0844d7
 
c3ccbbe
 
 
7ac4d3a
a22ab2a
a156b86
b7c7aa0
6bf53fb
 
 
 
 
 
 
ee6d0d7
31c66a7
ee6d0d7
 
b7c7aa0
12c9224
c9e0b0f
352a2c1
 
ce10c28
d05a27e
44005dd
5de2505
 
d05a27e
 
 
 
 
ef9fe17
d05a27e
6bf53fb
 
39ff54e
 
 
 
a684d1e
39ff54e
 
6bf53fb
 
 
 
d05a27e
6bf53fb
ce10c28
6bf53fb
 
 
 
 
7b64e33
6bf53fb
1eff00c
 
 
 
 
6a77982
e6b969d
ce10c28
c91a2e2
d05a27e
 
b7c7aa0
 
 
 
 
 
 
 
 
 
 
dc465b0
b7c7aa0
 
7bee43e
dc465b0
 
 
 
 
50e7d4a
 
 
 
 
 
 
 
 
 
b7c7aa0
 
 
 
 
 
 
 
50e7d4a
 
1416e63
 
50e7d4a
 
 
 
 
 
 
 
b7c7aa0
 
50e7d4a
1416e63
 
c3ccbbe
50e7d4a
1416e63
 
50e7d4a
 
1416e63
b7c7aa0
a22ab2a
b7c7aa0
dc465b0
62c053c
a22ab2a
b7c7aa0
 
 
a22ab2a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# app.py
import gradio as gr
from utils import VideoProcessor, AzureAPI, GoogleAPI, AnthropicAPI, OpenAIAPI
from constraint import SYS_PROMPT, USER_PROMPT
from datasets import load_dataset
import tempfile
import requests
from huggingface_hub import hf_hub_download, snapshot_download
import pyarrow.parquet as pq
import hashlib
import os
import csv
import av

# pip install --no-cache-dir huggingface_hub[hf_transfer]
def single_download(repo, fname, token, endpoint):
    """Fetch one file from a Hugging Face dataset repo and return its local path.

    Side effect: enables hf_transfer acceleration for the whole process by
    setting the relevant environment variables before downloading.
    """
    os.environ["TOKIO_WORKER_THREADS"] = "32"
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
    return hf_hub_download(
        repo_id=repo,
        filename=fname,
        token=token,
        endpoint=endpoint,
        repo_type="dataset",
    )

def load_hf_dataset(dataset_path, auth_token):
    """Load and return the Hugging Face dataset at *dataset_path* using *auth_token*."""
    return load_dataset(dataset_path, token=auth_token)

def fast_caption(sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_hf, video_hf_auth, parquet_index, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit):
    """Caption videos from a Hugging Face parquet shard via the Azure API.

    Downloads parquet shard ``parquet_index`` from the ``video_hf`` dataset
    repo, decodes each embedded video, asks the Azure model for a caption,
    and appends a ``(md5, caption)`` row per video to a CSV in /dev/shm.
    Only the Hugging Face source is implemented; the OneDrive / Google
    Drive parameters are accepted but currently unused.

    Returns:
        A ``(csv_path, progress_text, frame)`` triple matching the Gradio
        outputs; ``csv_path`` is ``""`` when no video source is selected.
    """
    # Cap the number of rows consumed per run so a single call stays bounded.
    max_videos = 86
    progress_info = []
    processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
    api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
    processed = 0
    with tempfile.TemporaryDirectory() as temp_dir:
        csv_filename = os.path.join('/dev/shm', str(parquet_index).zfill(6) + '_gpt4o_caption.csv')
        with open(csv_filename, mode='w', newline='') as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=['md5', 'caption'])
            writer.writeheader()

            if not (video_hf and video_hf_auth):
                return "", "No video source selected.", None

            progress_info.append('Begin processing Hugging Face dataset.')
            # Enable hf_transfer acceleration for the shard download.
            os.environ["TOKIO_WORKER_THREADS"] = "8"
            os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
            pqfile = hf_hub_download(
                repo_id=video_hf,
                filename='data/' + str(parquet_index).zfill(6) + '.parquet',
                repo_type="dataset",
                local_dir="/dev/shm",
                token=video_hf_auth,
            )

            pf = pq.ParquetFile(pqfile)
            for batch in pf.iter_batches(1):
                df = batch.to_pandas()
                for binary in df["video"]:
                    processed += 1
                    if binary:
                        # Write the raw bytes into the managed temp_dir so each
                        # video file is removed automatically (the original used
                        # NamedTemporaryFile(delete=False) and leaked one file
                        # per video).
                        video_path = os.path.join(temp_dir, f"{processed}.mp4")
                        with open(video_path, "wb") as f:
                            f.write(binary)
                        md5 = hashlib.md5(binary).hexdigest()

                        frames = processor._decode(video_path)
                        base64_list = processor.to_base64_list(frames)
                        caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
                        writer.writerow({'md5': md5, 'caption': caption})
                    if processed == max_videos:
                        return csv_filename, "\n".join(progress_info), None
            # Shard had fewer than max_videos rows: still return the CSV.
            # (The original fell through here and implicitly returned None,
            # which breaks the three-output Gradio contract.)
            return csv_filename, "\n".join(progress_info), None

# Gradio UI definition. Widget variables created here are wired into the
# caption_button.click() binding at the bottom, so their identity matters.
with gr.Blocks() as Core:
    with gr.Row(variant="panel"):
        # Left column: debug output, sampling configuration, prompts, results.
        with gr.Column(scale=6):
            with gr.Accordion("Debug", open=False):
                info = gr.Textbox(label="Info", interactive=False)
                frame = gr.Image(label="Frame", interactive=False)
            with gr.Accordion("Configuration", open=False):
                with gr.Row():
                    # Model sampling parameters passed through to the API wrapper.
                    temp = gr.Slider(0, 1, 0.3, step=0.1, label="Temperature")
                    top_p = gr.Slider(0, 1, 0.75, step=0.1, label="Top-P")
                    max_tokens = gr.Slider(512, 4096, 1024, step=1, label="Max Tokens")
                with gr.Row():
                    # Frame extraction settings for VideoProcessor; format is fixed to JPEG.
                    frame_format = gr.Dropdown(label="Frame Format", value="JPEG", choices=["JPEG", "PNG"], interactive=False)
                    frame_limit = gr.Slider(1, 100, 10, step=1, label="Frame Limits")
            with gr.Tabs():
                # Editable prompts, pre-filled from constraint.py defaults.
                with gr.Tab("User"):
                    usr_prompt = gr.Textbox(USER_PROMPT, label="User Prompt", lines=10, max_lines=100, show_copy_button=True)
                with gr.Tab("System"):
                    sys_prompt = gr.Textbox(SYS_PROMPT, label="System Prompt", lines=10, max_lines=100, show_copy_button=True)
            with gr.Tabs():
                # Per-provider result boxes; only the Azure one is wired to fast_caption.
                with gr.Tab("Azure"):
                    result = gr.Textbox(label="Result", lines=15, max_lines=100, show_copy_button=True, interactive=False)
                with gr.Tab("Google"):
                    result_gg = gr.Textbox(label="Result", lines=15, max_lines=100, show_copy_button=True, interactive=False)
                with gr.Tab("Anthropic"):
                    result_ac = gr.Textbox(label="Result", lines=15, max_lines=100, show_copy_button=True, interactive=False)
                with gr.Tab("OpenAI"):
                    result_oai = gr.Textbox(label="Result", lines=15, max_lines=100, show_copy_button=True, interactive=False)

        # Right column: provider credentials, data source selection, action button.
        with gr.Column(scale=2):
            with gr.Column():
                with gr.Accordion("Model Provider", open=True):
                    with gr.Tabs():
                        with gr.Tab("Azure"):
                            model = gr.Dropdown(label="Model", value="GPT-4o", choices=["GPT-4o", "GPT-4v"], interactive=False)
                            key = gr.Textbox(label="Azure API Key")
                            endpoint = gr.Textbox(label="Azure Endpoint")
                        with gr.Tab("Google"):
                            # NOTE(review): Google/Anthropic/OpenAI credentials are collected
                            # but not passed to fast_caption — presumably future work.
                            model_gg = gr.Dropdown(label="Model", value="Gemini-1.5-Flash", choices=["Gemini-1.5-Flash", "Gemini-1.5-Pro"], interactive=False)
                            key_gg = gr.Textbox(label="Gemini API Key")
                            endpoint_gg = gr.Textbox(label="Gemini API Endpoint")
                        with gr.Tab("Anthropic"):
                            model_ac = gr.Dropdown(label="Model", value="Claude-3-Opus", choices=["Claude-3-Opus", "Claude-3-Sonnet"], interactive=False)
                            key_ac = gr.Textbox(label="Anthropic API Key")
                            endpoint_ac = gr.Textbox(label="Anthropic Endpoint")
                        with gr.Tab("OpenAI"):
                            model_oai = gr.Dropdown(label="Model", value="GPT-4o", choices=["GPT-4o", "GPT-4v"], interactive=False)
                            key_oai = gr.Textbox(label="OpenAI API Key")
                            endpoint_oai = gr.Textbox(label="OpenAI Endpoint")
                with gr.Accordion("Data Source", open=True):
                    with gr.Tabs():
                        with gr.Tab("HF"):
                            video_hf = gr.Text(label="Huggingface File Path")
                            video_hf_auth = gr.Text(label="Huggingface Token")
                            parquet_index = gr.Text(label="Parquet Index")
                        with gr.Tab("Onedrive"):
                            # NOTE(review): OneDrive/Google Drive inputs feed fast_caption
                            # but are ignored inside it (only the HF path is implemented).
                            video_od = gr.Text("Microsoft Onedrive")
                            video_od_auth = gr.Text(label="Microsoft Onedrive Token")
                        with gr.Tab("Google Drive"):
                            video_gd = gr.Text()
                            video_gd_auth = gr.Text(label="Google Drive Access Token")
                caption_button = gr.Button("Caption", variant="primary", size="lg")
                csv_link = gr.File(label="Download CSV", interactive=False)
        # Wire the button: input order must match fast_caption's signature,
        # outputs map to (csv_link, info, frame).
        caption_button.click(
            fast_caption, 
            inputs=[sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_hf, video_hf_auth, parquet_index, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit], 
            outputs=[csv_link, info, frame]
        )

# Launch the Gradio server only when run as a script (not on import).
if __name__ == "__main__":
    Core.launch()