lalalalalalalalalala commited on
Commit
c3ccbbe
·
verified ·
1 Parent(s): e13fa07

Update run.py

Browse files
Files changed (1) hide show
  1. run.py +33 -26
run.py CHANGED
@@ -5,6 +5,10 @@ from constraint import SYS_PROMPT, USER_PROMPT
5
  from datasets import load_dataset
6
  import tempfile
7
  import requests
 
 
 
 
8
 
9
  def load_hf_dataset(dataset_path, auth_token):
10
  dataset = load_dataset(dataset_path, token=auth_token)
@@ -13,7 +17,7 @@ def load_hf_dataset(dataset_path, auth_token):
13
 
14
  return video_paths
15
 
16
- def fast_caption(sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_src, video_hf, video_hf_auth, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit):
17
  if video_src:
18
  video = video_src
19
  processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
@@ -29,31 +33,36 @@ def fast_caption(sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, en
29
  caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
30
  return f"{caption}", f"Using model '{model}' with {len(frames)} frames extracted.", debug_image
31
  elif video_hf and video_hf_auth:
32
- # Handle Hugging Face dataset
33
- video_paths = load_hf_dataset(video_hf, video_hf_auth)
34
- video_paths = video_paths["train"]
35
  # Process all videos in the dataset
36
  all_captions = []
37
- for video_path_url in video_paths:
38
- video_path_url = video_path_url["id"]
39
- # 使用requests下载文件到临时文件
40
- response = requests.get(video_path_url, stream=True)
41
- if response.status_code == 200:
42
- with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video_file:
43
- temp_video_file.write(response.content)
 
 
 
 
 
 
 
 
 
 
44
  video_path = temp_video_file.name
45
- else:
46
- raise Exception(f"Failed to download video, status code: {response.status_code}")
47
-
48
- if video_path.endswith('.mp4'): # 假设我们只处理.mp4文件
49
- processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
50
- frames = processor._decode(video_path)
51
- base64_list = processor.to_base64_list(frames)
52
- api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
53
- caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
54
- all_captions.append(caption)
55
  return "\n\n\n".join(all_captions), f"Processed {len(video_paths)} videos.", None
56
- # ... (Handle other sources)
57
  else:
58
  return "", "No video source selected.", None
59
 
@@ -113,9 +122,7 @@ with gr.Blocks() as Core:
113
  with gr.Tab("HF"):
114
  video_hf = gr.Text(label="Huggingface File Path")
115
  video_hf_auth = gr.Text(label="Huggingface Token")
116
- with gr.Tab("Parquet_index"):
117
- video_hf = gr.Text(label="Parquet_index")
118
- video_hf_auth = gr.Text(label="Huggingface Token")
119
  with gr.Tab("Onedrive"):
120
  video_od = gr.Text("Microsoft Onedrive")
121
  video_od_auth = gr.Text(label="Microsoft Onedrive Token")
@@ -125,7 +132,7 @@ with gr.Blocks() as Core:
125
  caption_button = gr.Button("Caption", variant="primary", size="lg")
126
  caption_button.click(
127
  fast_caption,
128
- inputs=[sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_src, video_hf, video_hf_auth, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit],
129
  outputs=[result, info, frame]
130
  )
131
 
 
5
  from datasets import load_dataset
6
  import tempfile
7
  import requests
8
+ from huggingface_hub import hf_hub_download, snapshot_download
9
+ import pyarrow.parquet as pq
10
+ import hashlib
11
+
12
 
13
  def load_hf_dataset(dataset_path, auth_token):
14
  dataset = load_dataset(dataset_path, token=auth_token)
 
17
 
18
  return video_paths
19
 
20
+ def fast_caption(sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_src, video_hf, video_hf_auth, parquet_index, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit):
21
  if video_src:
22
  video = video_src
23
  processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
 
33
  caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
34
  return f"{caption}", f"Using model '{model}' with {len(frames)} frames extracted.", debug_image
35
  elif video_hf and video_hf_auth:
 
 
 
36
  # Process all videos in the dataset
37
  all_captions = []
38
+ with tempfile.NamedTemporaryFile(mode='w+t', delete=True) as temp_parquet_file:
39
+ temp_parquet_file = hf_hub_download(
40
+ repo_id="OpenVideo/pexels-raw",
41
+ filename="data/“ + str(number).zfill(6) + “.parquet",
42
+ repo_type="dataset",
43
+ token=video_hf_auth,
44
+ )
45
+ parquet_path = temp_parquet_file.name
46
+ parquet_file = pq.ParquetFile(parquet_path)
47
+
48
+ for batch in parquet_file.iter_batches(batch_size=1):
49
+ df = batch.to_pandas()
50
+ video = df['video'][0]
51
+
52
+ md5 = hashlib.md5(video).hexdigest()
53
+ with tempfile.NamedTemporaryFile(mode='w+t', delete=True) as temp_video_file:
54
+ temp_video_file.write(video)
55
  video_path = temp_video_file.name
56
+
57
+ processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
58
+ frames = processor._decode(video_path)
59
+ base64_list = processor.to_base64_list(frames)
60
+ api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
61
+ caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
62
+ all_captions.append(caption)
63
+
 
 
64
  return "\n\n\n".join(all_captions), f"Processed {len(video_paths)} videos.", None
65
+
66
  else:
67
  return "", "No video source selected.", None
68
 
 
122
  with gr.Tab("HF"):
123
  video_hf = gr.Text(label="Huggingface File Path")
124
  video_hf_auth = gr.Text(label="Huggingface Token")
125
+ parquet_index = gr.Text(label="Parquet Index")
 
 
126
  with gr.Tab("Onedrive"):
127
  video_od = gr.Text("Microsoft Onedrive")
128
  video_od_auth = gr.Text(label="Microsoft Onedrive Token")
 
132
  caption_button = gr.Button("Caption", variant="primary", size="lg")
133
  caption_button.click(
134
  fast_caption,
135
+ inputs=[sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_src, video_hf, video_hf_auth, parquet_index, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit],
136
  outputs=[result, info, frame]
137
  )
138