lalalalalalalalalala committed on
Commit
c9e0b0f
·
verified ·
1 Parent(s): 8327237

Update run.py

Browse files
Files changed (1) hide show
  1. run.py +6 -16
run.py CHANGED
@@ -13,31 +13,27 @@ import os
13
 
14
  def load_hf_dataset(dataset_path, auth_token):
15
  dataset = load_dataset(dataset_path, token=auth_token)
16
-
17
  video_paths = dataset
18
-
19
  return video_paths
20
 
21
  def fast_caption(sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_src, video_hf, video_hf_auth, parquet_index, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit):
 
22
  if video_src:
23
  video = video_src
24
  processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
25
  frames = processor._decode(video)
26
-
27
  base64_list = processor.to_base64_list(frames)
28
  debug_image = processor.concatenate(frames)
29
-
30
  if not key or not endpoint:
31
  return "", f"API key or endpoint is missing. Processed {len(frames)} frames.", debug_image
32
-
33
  api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
34
  caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
35
- return f"{caption}", f"Using model '{model}' with {len(frames)} frames extracted.", debug_image
 
36
  elif video_hf and video_hf_auth:
37
  current_file_path = os.path.abspath(__file__)
38
  current_directory = os.path.dirname(current_file_path)
39
- print('begin video_hf')
40
- # Process all videos in the dataset
41
  all_captions = []
42
  temp_parquet_file = hf_hub_download(
43
  repo_id=video_hf,
@@ -45,28 +41,22 @@ def fast_caption(sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, en
45
  repo_type="dataset",
46
  token=video_hf_auth,
47
  )
48
- print(temp_parquet_file)
49
  parquet_file = pq.ParquetFile(temp_parquet_file)
50
-
51
  for batch in parquet_file.iter_batches(batch_size=1):
52
  df = batch.to_pandas()
53
  video = df['video'][0]
54
-
55
  md5 = hashlib.md5(video).hexdigest()
56
- print(md5)
57
  with tempfile.NamedTemporaryFile(dir=current_directory) as temp_file:
58
  temp_file.write(video)
59
  video_path = temp_file.name
60
-
61
  processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
62
  frames = processor._decode(video_path)
63
  base64_list = processor.to_base64_list(frames)
64
  api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
65
  caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
66
  all_captions.append(caption)
67
-
68
- return "\n\n\n".join(all_captions), f"Processed {len(video_paths)} videos.", None
69
-
70
  else:
71
  return "", "No video source selected.", None
72
 
 
13
 
14
  def load_hf_dataset(dataset_path, auth_token):
15
  dataset = load_dataset(dataset_path, token=auth_token)
 
16
  video_paths = dataset
 
17
  return video_paths
18
 
19
  def fast_caption(sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_src, video_hf, video_hf_auth, parquet_index, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit):
20
+ progress_info = []
21
  if video_src:
22
  video = video_src
23
  processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
24
  frames = processor._decode(video)
 
25
  base64_list = processor.to_base64_list(frames)
26
  debug_image = processor.concatenate(frames)
 
27
  if not key or not endpoint:
28
  return "", f"API key or endpoint is missing. Processed {len(frames)} frames.", debug_image
 
29
  api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
30
  caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
31
+ progress_info.append(f"Using model '{model}' with {len(frames)} frames extracted.")
32
+ return f"{caption}", "\n".join(progress_info), debug_image
33
  elif video_hf and video_hf_auth:
34
  current_file_path = os.path.abspath(__file__)
35
  current_directory = os.path.dirname(current_file_path)
36
+ progress_info.append('Begin processing Hugging Face dataset.')
 
37
  all_captions = []
38
  temp_parquet_file = hf_hub_download(
39
  repo_id=video_hf,
 
41
  repo_type="dataset",
42
  token=video_hf_auth,
43
  )
 
44
  parquet_file = pq.ParquetFile(temp_parquet_file)
 
45
  for batch in parquet_file.iter_batches(batch_size=1):
46
  df = batch.to_pandas()
47
  video = df['video'][0]
 
48
  md5 = hashlib.md5(video).hexdigest()
 
49
  with tempfile.NamedTemporaryFile(dir=current_directory) as temp_file:
50
  temp_file.write(video)
51
  video_path = temp_file.name
 
52
  processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
53
  frames = processor._decode(video_path)
54
  base64_list = processor.to_base64_list(frames)
55
  api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
56
  caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
57
  all_captions.append(caption)
58
+ progress_info.append(f"Processed video with MD5: {md5}")
59
+ return "\n\n\n".join(all_captions), "\n".join(progress_info), None
 
60
  else:
61
  return "", "No video source selected.", None
62