versae committed on
Commit
311ebef
·
verified ·
1 Parent(s): 3da5e49

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -5
app.py CHANGED
@@ -7,6 +7,12 @@ import pytube as pt
7
  import spaces
8
  from transformers import pipeline
9
  from huggingface_hub import model_info
 
 
 
 
 
 
10
 
11
  MODEL_NAME = "NbAiLab/nb-whisper-large"
12
  lang = "no"
@@ -25,14 +31,14 @@ def pipe(file, return_timestamps=False):
25
  device=device,
26
  token=auth_token,
27
  torch_dtype=torch.float16,
28
- model_kwargs={"attn_implementation": "flash_attention_2"} if args.flash else {"attn_implementation": "sdpa"},
29
  )
30
  asr.model.config.forced_decoder_ids = asr.tokenizer.get_decoder_prompt_ids(
31
  language=lang,
32
  task="transcribe",
33
  no_timestamps=not return_timestamps,
34
  )
35
- asr.model.config.no_timestamps_token_id = asr.tokenizer.encode("<|notimestamps|>", add_special_tokens=False)[0]
36
  return asr(file, return_timestamps=return_timestamps, batch_size=24)
37
 
38
  def transcribe(file, return_timestamps=False):
@@ -106,6 +112,12 @@ yt_transcribe = gr.Interface(
106
  )
107
 
108
  with demo:
109
- gr.TabbedInterface([mf_transcribe, yt_transcribe], ["Transcribe Audio", "Transcribe YouTube"])
110
-
111
- demo.launch(share=True).queue()
 
 
 
 
 
 
 
7
  import spaces
8
  from transformers import pipeline
9
  from huggingface_hub import model_info
10
+ try:
11
+ import flash_attn
12
+ FLASH_ATTENTION = True
13
+ except ImportError:
14
+ FLASH_ATTENTION = False
15
+
16
 
17
  MODEL_NAME = "NbAiLab/nb-whisper-large"
18
  lang = "no"
 
31
  device=device,
32
  token=auth_token,
33
  torch_dtype=torch.float16,
34
+ model_kwargs={"attn_implementation": "flash_attention_2"} if FLASH_ATTENTION else {"attn_implementation": "sdpa"},
35
  )
36
  asr.model.config.forced_decoder_ids = asr.tokenizer.get_decoder_prompt_ids(
37
  language=lang,
38
  task="transcribe",
39
  no_timestamps=not return_timestamps,
40
  )
41
+ # asr.model.config.no_timestamps_token_id = asr.tokenizer.encode("<|notimestamps|>", add_special_tokens=False)[0]
42
  return asr(file, return_timestamps=return_timestamps, batch_size=24)
43
 
44
  def transcribe(file, return_timestamps=False):
 
112
  )
113
 
114
  with demo:
115
+ gr.TabbedInterface([
116
+ mf_transcribe,
117
+ # yt_transcribe
118
+ ], [
119
+ "Transcribe Audio",
120
+ # "Transcribe YouTube"
121
+ ])
122
+
123
+ demo.launch(share=share).queue()