lalalalalalalalalala commited on
Commit
352a2c1
·
verified ·
1 Parent(s): 4003c7a

Update run.py

Browse files
Files changed (1) hide show
  1. run.py +2 -2
run.py CHANGED
@@ -26,6 +26,8 @@ def load_hf_dataset(dataset_path, auth_token):
26
 
27
  def fast_caption(sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_hf, video_hf_auth, parquet_index, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit):
28
  progress_info = []
 
 
29
  with tempfile.TemporaryDirectory() as temp_dir:
30
  # temp_dir = '/opt/run'
31
  csv_filename = os.path.join('/dev/shm', str(parquet_index).zfill(6) + '_gpt4o_caption.csv')
@@ -59,10 +61,8 @@ def fast_caption(sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, en
59
  _chunk.append(_v.name)
60
  md5 = hashlib.md5(binary).hexdigest()
61
 
62
- processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
63
  frames = processor._decode(_v.name)
64
  base64_list = processor.to_base64_list(frames)
65
- api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
66
  caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
67
  writer.writerow({'md5': md5, 'caption': caption})
68
  # writer.writerow({'md5': md5, 'caption': 'caption'})
 
26
 
27
  def fast_caption(sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_hf, video_hf_auth, parquet_index, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit):
28
  progress_info = []
29
+ processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
30
+ api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
31
  with tempfile.TemporaryDirectory() as temp_dir:
32
  # temp_dir = '/opt/run'
33
  csv_filename = os.path.join('/dev/shm', str(parquet_index).zfill(6) + '_gpt4o_caption.csv')
 
61
  _chunk.append(_v.name)
62
  md5 = hashlib.md5(binary).hexdigest()
63
 
 
64
  frames = processor._decode(_v.name)
65
  base64_list = processor.to_base64_list(frames)
 
66
  caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
67
  writer.writerow({'md5': md5, 'caption': caption})
68
  # writer.writerow({'md5': md5, 'caption': 'caption'})