# NOTE: web-page scrape residue converted to comments so this file parses as Python.
# Original listing header: "Update run.py" (commit e6b969d, verified).
# app.py
import gradio as gr
from utils import VideoProcessor, AzureAPI, GoogleAPI, AnthropicAPI, OpenAIAPI
from constraint import SYS_PROMPT, USER_PROMPT
from datasets import load_dataset
import tempfile
import requests
from huggingface_hub import hf_hub_download, snapshot_download
import pyarrow.parquet as pq
import hashlib
import os
import csv
import av
# pip install --no-cache-dir huggingface_hub[hf_transfer]
def single_download(repo, fname, token, endpoint):
    """Download one file from a Hugging Face dataset repo and return its local path.

    Turns on the hf_transfer fast-download path via environment variables
    before delegating to ``hf_hub_download``.
    """
    # NOTE(review): HF_HUB_ENABLE_HF_TRANSFER is commonly read when
    # huggingface_hub is first imported — setting it this late may have no
    # effect; confirm against the installed huggingface_hub version.
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
    os.environ["TOKIO_WORKER_THREADS"] = "32"
    local_path = hf_hub_download(
        repo_id=repo,
        filename=fname,
        token=token,
        endpoint=endpoint,
        repo_type="dataset",
    )
    return local_path
def load_hf_dataset(dataset_path, auth_token):
    """Load a Hugging Face dataset identified by *dataset_path* using *auth_token*."""
    return load_dataset(dataset_path, token=auth_token)
def fast_caption(sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_hf, video_hf_auth, parquet_index, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit):
    """Caption videos from a Hugging Face parquet shard and write results to CSV.

    Downloads shard ``data/<parquet_index>.parquet`` from the ``video_hf``
    dataset repo, decodes each embedded video, sends sampled frames to the
    Azure captioning API, and appends ``(md5, caption)`` rows to a CSV placed
    in /dev/shm.

    Returns the 3-tuple consumed by the Gradio click handler:
    ``(csv_path, progress_text, frame_preview)`` — the frame preview is always
    ``None`` in the current implementation. The OneDrive / Google Drive
    parameters are accepted but not yet used.
    """
    # Hard cap on videos per run so a single invocation stays bounded.
    # (Was a bare magic number 86 in the original.)
    VIDEO_CAP = 86
    progress_info = []
    processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
    api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
    ind = 0
    with tempfile.TemporaryDirectory() as temp_dir:
        # CSV lives in /dev/shm (fast tmpfs) and is handed back to Gradio.
        csv_filename = os.path.join('/dev/shm', str(parquet_index).zfill(6) + '_gpt4o_caption.csv')
        with open(csv_filename, mode='w', newline='') as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=['md5', 'caption'])
            writer.writeheader()
            if video_hf and video_hf_auth:
                progress_info.append('Begin processing Hugging Face dataset.')
                os.environ["TOKIO_WORKER_THREADS"] = "8"
                os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
                pqfile = hf_hub_download(
                    repo_id=video_hf,
                    filename='data/' + str(parquet_index).zfill(6) + '.parquet',
                    repo_type="dataset",
                    local_dir="/dev/shm",
                    token=video_hf_auth,
                )
                pf = pq.ParquetFile(pqfile)
                for batch in pf.iter_batches(1):
                    df = batch.to_pandas()
                    for binary in df["video"]:
                        ind += 1
                        if not binary:
                            continue
                        # Write the video bytes inside the managed temp dir so
                        # they are reclaimed when the TemporaryDirectory exits.
                        # (The original used NamedTemporaryFile(delete=False)
                        # and never removed the files — a disk leak.)
                        video_path = os.path.join(temp_dir, f"{ind}.mp4")
                        with open(video_path, "wb") as f:
                            f.write(binary)
                        md5 = hashlib.md5(binary).hexdigest()
                        frames = processor._decode(video_path)
                        base64_list = processor.to_base64_list(frames)
                        caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
                        writer.writerow({'md5': md5, 'caption': caption})
                        if ind == VIDEO_CAP:
                            return csv_filename, "\n".join(progress_info), None
                # BUGFIX: a shard with fewer than VIDEO_CAP videos previously
                # fell off the end of this branch and implicitly returned None,
                # breaking the 3-output Gradio contract and discarding the
                # finished CSV. Return it explicitly.
                progress_info.append(f'Finished: processed {ind} videos.')
                return csv_filename, "\n".join(progress_info), None
            else:
                return "", "No video source selected.", None
# Top-level Gradio UI. Widget variables created here are wired into
# fast_caption via caption_button.click below, so their names must stay stable.
with gr.Blocks() as Core:
    with gr.Row(variant="panel"):
        with gr.Column(scale=6):
            # Debug panel: progress log plus an image slot for a frame preview
            # (fast_caption currently always returns None for the frame).
            with gr.Accordion("Debug", open=False):
                info = gr.Textbox(label="Info", interactive=False)
                frame = gr.Image(label="Frame", interactive=False)
            # Generation and frame-sampling parameters forwarded to the API.
            with gr.Accordion("Configuration", open=False):
                with gr.Row():
                    temp = gr.Slider(0, 1, 0.3, step=0.1, label="Temperature")
                    top_p = gr.Slider(0, 1, 0.75, step=0.1, label="Top-P")
                    max_tokens = gr.Slider(512, 4096, 1024, step=1, label="Max Tokens")
                with gr.Row():
                    # Frame format is fixed to JPEG (interactive=False).
                    frame_format = gr.Dropdown(label="Frame Format", value="JPEG", choices=["JPEG", "PNG"], interactive=False)
                    frame_limit = gr.Slider(1, 100, 10, step=1, label="Frame Limits")
            # Prompt editors, pre-filled from the constraint module.
            with gr.Tabs():
                with gr.Tab("User"):
                    usr_prompt = gr.Textbox(USER_PROMPT, label="User Prompt", lines=10, max_lines=100, show_copy_button=True)
                with gr.Tab("System"):
                    sys_prompt = gr.Textbox(SYS_PROMPT, label="System Prompt", lines=10, max_lines=100, show_copy_button=True)
            # Per-provider result boxes; none of these are currently wired to
            # the click handler (its outputs are csv_link, info, frame).
            with gr.Tabs():
                with gr.Tab("Azure"):
                    result = gr.Textbox(label="Result", lines=15, max_lines=100, show_copy_button=True, interactive=False)
                with gr.Tab("Google"):
                    result_gg = gr.Textbox(label="Result", lines=15, max_lines=100, show_copy_button=True, interactive=False)
                with gr.Tab("Anthropic"):
                    result_ac = gr.Textbox(label="Result", lines=15, max_lines=100, show_copy_button=True, interactive=False)
                with gr.Tab("OpenAI"):
                    result_oai = gr.Textbox(label="Result", lines=15, max_lines=100, show_copy_button=True, interactive=False)
        with gr.Column(scale=2):
            with gr.Column():
                # Credentials per provider; only the Azure tab's widgets
                # (model, key, endpoint) are passed to fast_caption.
                with gr.Accordion("Model Provider", open=True):
                    with gr.Tabs():
                        with gr.Tab("Azure"):
                            model = gr.Dropdown(label="Model", value="GPT-4o", choices=["GPT-4o", "GPT-4v"], interactive=False)
                            key = gr.Textbox(label="Azure API Key")
                            endpoint = gr.Textbox(label="Azure Endpoint")
                        with gr.Tab("Google"):
                            model_gg = gr.Dropdown(label="Model", value="Gemini-1.5-Flash", choices=["Gemini-1.5-Flash", "Gemini-1.5-Pro"], interactive=False)
                            key_gg = gr.Textbox(label="Gemini API Key")
                            endpoint_gg = gr.Textbox(label="Gemini API Endpoint")
                        with gr.Tab("Anthropic"):
                            model_ac = gr.Dropdown(label="Model", value="Claude-3-Opus", choices=["Claude-3-Opus", "Claude-3-Sonnet"], interactive=False)
                            key_ac = gr.Textbox(label="Anthropic API Key")
                            endpoint_ac = gr.Textbox(label="Anthropic Endpoint")
                        with gr.Tab("OpenAI"):
                            model_oai = gr.Dropdown(label="Model", value="GPT-4o", choices=["GPT-4o", "GPT-4v"], interactive=False)
                            key_oai = gr.Textbox(label="OpenAI API Key")
                            endpoint_oai = gr.Textbox(label="OpenAI Endpoint")
                # Data sources; fast_caption only consumes the HF tab's values —
                # OneDrive / Google Drive fields are passed through unused.
                with gr.Accordion("Data Source", open=True):
                    with gr.Tabs():
                        with gr.Tab("HF"):
                            video_hf = gr.Text(label="Huggingface File Path")
                            video_hf_auth = gr.Text(label="Huggingface Token")
                            parquet_index = gr.Text(label="Parquet Index")
                        with gr.Tab("Onedrive"):
                            video_od = gr.Text("Microsoft Onedrive")
                            video_od_auth = gr.Text(label="Microsoft Onedrive Token")
                        with gr.Tab("Google Drive"):
                            video_gd = gr.Text()
                            video_gd_auth = gr.Text(label="Google Drive Access Token")
                caption_button = gr.Button("Caption", variant="primary", size="lg")
                csv_link = gr.File(label="Download CSV", interactive=False)
    # Wire the button: the inputs list order must match fast_caption's
    # positional signature exactly.
    caption_button.click(
        fast_caption,
        inputs=[sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_hf, video_hf_auth, parquet_index, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit],
        outputs=[csv_link, info, frame]
    )
if __name__ == "__main__":
    Core.launch()