versae committed on
Commit
35af703
·
verified ·
1 Parent(s): d2a22e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -24,6 +24,8 @@ def pipe(file, return_timestamps=False):
24
  chunk_length_s=30,
25
  device=device,
26
  token=auth_token,
 
 
27
  )
28
  asr.model.config.forced_decoder_ids = asr.tokenizer.get_decoder_prompt_ids(
29
  language=lang,
@@ -31,7 +33,7 @@ def pipe(file, return_timestamps=False):
31
  no_timestamps=not return_timestamps,
32
  )
33
  asr.model.config.no_timestamps_token_id = asr.tokenizer.encode("<|notimestamps|>", add_special_tokens=False)[0]
34
- return asr(file, return_timestamps=return_timestamps)
35
 
36
  def transcribe(file, return_timestamps=False):
37
  if not return_timestamps:
 
24
  chunk_length_s=30,
25
  device=device,
26
  token=auth_token,
27
+ torch_dtype=torch.float16,
28
+ model_kwargs={"attn_implementation": "flash_attention_2"} if args.flash else {"attn_implementation": "sdpa"},
29
  )
30
  asr.model.config.forced_decoder_ids = asr.tokenizer.get_decoder_prompt_ids(
31
  language=lang,
 
33
  no_timestamps=not return_timestamps,
34
  )
35
  asr.model.config.no_timestamps_token_id = asr.tokenizer.encode("<|notimestamps|>", add_special_tokens=False)[0]
36
+ return asr(file, return_timestamps=return_timestamps, batch_size=24)
37
 
38
  def transcribe(file, return_timestamps=False):
39
  if not return_timestamps: