lalalalalalalalalala committed on
Commit a22ab2a · verified · Parent(s): c9e0b0f

Update run.py

Files changed (1):
  run.py +51 -45
run.py CHANGED
@@ -9,7 +9,7 @@ from huggingface_hub import hf_hub_download, snapshot_download
 import pyarrow.parquet as pq
 import hashlib
 import os
-
+import csv
 
 def load_hf_dataset(dataset_path, auth_token):
     dataset = load_dataset(dataset_path, token=auth_token)
@@ -18,47 +18,53 @@ def load_hf_dataset(dataset_path, auth_token):
 
 def fast_caption(sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_src, video_hf, video_hf_auth, parquet_index, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit):
     progress_info = []
-    if video_src:
-        video = video_src
-        processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
-        frames = processor._decode(video)
-        base64_list = processor.to_base64_list(frames)
-        debug_image = processor.concatenate(frames)
-        if not key or not endpoint:
-            return "", f"API key or endpoint is missing. Processed {len(frames)} frames.", debug_image
-        api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
-        caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
-        progress_info.append(f"Using model '{model}' with {len(frames)} frames extracted.")
-        return f"{caption}", "\n".join(progress_info), debug_image
-    elif video_hf and video_hf_auth:
-        current_file_path = os.path.abspath(__file__)
-        current_directory = os.path.dirname(current_file_path)
-        progress_info.append('Begin processing Hugging Face dataset.')
-        all_captions = []
-        temp_parquet_file = hf_hub_download(
-            repo_id=video_hf,
-            filename='data/' + str(parquet_index).zfill(6) + '.parquet',
-            repo_type="dataset",
-            token=video_hf_auth,
-        )
-        parquet_file = pq.ParquetFile(temp_parquet_file)
-        for batch in parquet_file.iter_batches(batch_size=1):
-            df = batch.to_pandas()
-            video = df['video'][0]
-            md5 = hashlib.md5(video).hexdigest()
-            with tempfile.NamedTemporaryFile(dir=current_directory) as temp_file:
-                temp_file.write(video)
-                video_path = temp_file.name
-                processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
-                frames = processor._decode(video_path)
-                base64_list = processor.to_base64_list(frames)
-                api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
-                caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
-                all_captions.append(caption)
-                progress_info.append(f"Processed video with MD5: {md5}")
-        return "\n\n\n".join(all_captions), "\n".join(progress_info), None
-    else:
-        return "", "No video source selected.", None
+    csv_filename = "captions.csv"
+    with open(csv_filename, mode='w', newline='') as csv_file:
+        fieldnames = ['md5', 'caption']
+        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
+        writer.writeheader()
+
+        if video_src:
+            video = video_src
+            processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
+            frames = processor._decode(video)
+            base64_list = processor.to_base64_list(frames)
+            debug_image = processor.concatenate(frames)
+            if not key or not endpoint:
+                return "", f"API key or endpoint is missing. Processed {len(frames)} frames.", debug_image
+            api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
+            caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
+            progress_info.append(f"Using model '{model}' with {len(frames)} frames extracted.")
+            writer.writerow({'md5': 'single_video', 'caption': caption})
+            return f"{caption}", "\n".join(progress_info), debug_image
+        elif video_hf and video_hf_auth:
+            current_file_path = os.path.abspath(__file__)
+            current_directory = os.path.dirname(current_file_path)
+            progress_info.append('Begin processing Hugging Face dataset.')
+            temp_parquet_file = hf_hub_download(
+                repo_id=video_hf,
+                filename='data/' + str(parquet_index).zfill(6) + '.parquet',
+                repo_type="dataset",
+                token=video_hf_auth,
+            )
+            parquet_file = pq.ParquetFile(temp_parquet_file)
+            for batch in parquet_file.iter_batches(batch_size=1):
+                df = batch.to_pandas()
+                video = df['video'][0]
+                md5 = hashlib.md5(video).hexdigest()
+                with tempfile.NamedTemporaryFile(dir=current_directory) as temp_file:
+                    temp_file.write(video)
+                    video_path = temp_file.name
+                    processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
+                    frames = processor._decode(video_path)
+                    base64_list = processor.to_base64_list(frames)
+                    api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
+                    caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
+                    writer.writerow({'md5': md5, 'caption': caption})
+                    progress_info.append(f"Processed video with MD5: {md5}")
+            return csv_filename, "\n".join(progress_info), None
+        else:
+            return "", "No video source selected.", None
 
 with gr.Blocks() as Core:
     with gr.Row(variant="panel"):
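
For context, the pattern this hunk swaps in is the standard csv.DictWriter flow: write a header, then one md5/caption row per video, and hand the filename back to the caller. A minimal standalone sketch of that flow, where caption_for is a hypothetical stand-in for the AzureAPI.get_caption(...) call in run.py:

import csv
import hashlib

def caption_for(video_bytes):
    # Hypothetical stand-in for AzureAPI.get_caption(...) in run.py.
    return f"caption for {len(video_bytes)} bytes of video"

def write_captions_csv(videos, csv_filename="captions.csv"):
    # videos: iterable of raw video bytes, as decoded from the parquet shard.
    with open(csv_filename, mode='w', newline='') as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=['md5', 'caption'])
        writer.writeheader()
        for video in videos:
            md5 = hashlib.md5(video).hexdigest()  # content hash keys each row
            writer.writerow({'md5': md5, 'caption': caption_for(video)})
    return csv_filename  # the path is what the UI layer receives
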
@@ -124,12 +130,12 @@ with gr.Blocks() as Core:
     video_gd = gr.Text()
     video_gd_auth = gr.Text(label="Google Drive Access Token")
     caption_button = gr.Button("Caption", variant="primary", size="lg")
+    csv_link = gr.File(label="Download CSV", interactive=False)
     caption_button.click(
         fast_caption,
         inputs=[sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_src, video_hf, video_hf_auth, parquet_index, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit],
-        outputs=[result, info, frame]
+        outputs=[csv_link, info, frame]
     )
 
 if __name__ == "__main__":
-    Core.launch()
-
+    Core.launch()
 
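
As a usage note on the UI change: in Gradio, a click handler that returns a file path to a gr.File output is what turns the path into a download link, which is how the commit's return csv_filename is consumed. A stripped-down sketch of that wiring, independent of the rest of run.py (component names mirror the commit; the handler body is illustrative only):

import gradio as gr

def handler():
    # Produce a small CSV on disk and return its path to the gr.File output.
    path = "captions.csv"
    with open(path, "w", newline="") as f:
        f.write("md5,caption\n")
    return path, "wrote 1 header row"

with gr.Blocks() as demo:
    button = gr.Button("Caption")
    csv_link = gr.File(label="Download CSV", interactive=False)
    info = gr.Textbox(label="Info")
    button.click(handler, inputs=[], outputs=[csv_link, info])

if __name__ == "__main__":
    demo.launch()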