LAP-DEV commited on
Commit
b7531d2
·
verified ·
1 Parent(s): a4d8f97

Update modules/whisper/faster_whisper_inference.py

Browse files
modules/whisper/faster_whisper_inference.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import time
 
3
  import numpy as np
4
  import torch
5
  from typing import BinaryIO, Union, Tuple, List
@@ -12,11 +13,11 @@ import gradio as gr
12
  from argparse import Namespace
13
 
14
  from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
15
- from modules.whisper.whisper_parameter import *
16
- from modules.whisper.whisper_base import WhisperBase
17
 
18
 
19
- class FasterWhisperInference(WhisperBase):
20
  def __init__(self,
21
  model_dir: str = FASTER_WHISPER_MODELS_DIR,
22
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
@@ -35,14 +36,12 @@ class FasterWhisperInference(WhisperBase):
35
  self.model_paths = self.get_model_paths()
36
  self.device = self.get_device()
37
  self.available_models = self.model_paths.keys()
38
- self.available_compute_types = ctranslate2.get_supported_compute_types(
39
- "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
40
 
41
  def transcribe(self,
42
  audio: Union[str, BinaryIO, np.ndarray],
43
  progress: gr.Progress = gr.Progress(),
44
  *whisper_params,
45
- ) -> Tuple[List[dict], float]:
46
  """
47
  transcribe method for faster-whisper.
48
 
@@ -57,32 +56,22 @@ class FasterWhisperInference(WhisperBase):
57
 
58
  Returns
59
  ----------
60
- segments_result: List[dict]
61
- list of dicts that includes start, end timestamps and transcribed text
62
  elapsed_time: float
63
  elapsed time for transcription
64
  """
65
  start_time = time.time()
66
 
67
- params = WhisperParameters.as_value(*whisper_params)
68
 
69
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
70
  self.update_model(params.model_size, params.compute_type, progress)
71
 
72
- # None parameters with Textboxes: https://github.com/gradio-app/gradio/issues/8723
73
- if not params.initial_prompt:
74
- params.initial_prompt = None
75
- if not params.prefix:
76
- params.prefix = None
77
- if not params.hotwords:
78
- params.hotwords = None
79
-
80
- params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens)
81
-
82
  segments, info = self.model.transcribe(
83
  audio=audio,
84
  language=params.lang,
85
- task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
86
  beam_size=params.beam_size,
87
  log_prob_threshold=params.log_prob_threshold,
88
  no_speech_threshold=params.no_speech_threshold,
@@ -109,16 +98,12 @@ class FasterWhisperInference(WhisperBase):
109
  language_detection_segments=params.language_detection_segments,
110
  prompt_reset_on_temperature=params.prompt_reset_on_temperature,
111
  )
112
- progress(0, desc="Loading audio...")
113
 
114
  segments_result = []
115
  for segment in segments:
116
- progress(segment.start / info.duration, desc="Transcribing...")
117
- segments_result.append({
118
- "start": segment.start,
119
- "end": segment.end,
120
- "text": segment.text
121
- })
122
 
123
  elapsed_time = time.time() - start_time
124
  return segments_result, elapsed_time
@@ -134,21 +119,43 @@ class FasterWhisperInference(WhisperBase):
134
  Parameters
135
  ----------
136
  model_size: str
137
- Size of whisper model
 
138
  compute_type: str
139
  Compute type for transcription.
140
  see more info : https://opennmt.net/CTranslate2/quantization.html
141
  progress: gr.Progress
142
  Indicator to show progress directly in gradio.
143
  """
144
- progress(0, desc="Initializing Model...")
145
- self.current_model_size = self.model_paths[model_size]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  self.current_compute_type = compute_type
147
  self.model = faster_whisper.WhisperModel(
148
  device=self.device,
149
  model_size_or_path=self.current_model_size,
150
  download_root=self.model_dir,
151
- compute_type=self.current_compute_type
 
152
  )
153
 
154
  def get_model_paths(self):
@@ -163,7 +170,7 @@ class FasterWhisperInference(WhisperBase):
163
  faster_whisper_prefix = "models--Systran--faster-whisper-"
164
 
165
  existing_models = os.listdir(self.model_dir)
166
- wrong_dirs = [".locks"]
167
  existing_models = list(set(existing_models) - set(wrong_dirs))
168
 
169
  for model_name in existing_models:
@@ -189,4 +196,4 @@ class FasterWhisperInference(WhisperBase):
189
  raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
190
  return suppress_tokens
191
  except Exception as e:
192
- raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
 
1
  import os
2
  import time
3
+ import huggingface_hub
4
  import numpy as np
5
  import torch
6
  from typing import BinaryIO, Union, Tuple, List
 
13
  from argparse import Namespace
14
 
15
  from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
16
+ from modules.whisper.data_classes import *
17
+ from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
18
 
19
 
20
+ class FasterWhisperInference(BaseTranscriptionPipeline):
21
  def __init__(self,
22
  model_dir: str = FASTER_WHISPER_MODELS_DIR,
23
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
 
36
  self.model_paths = self.get_model_paths()
37
  self.device = self.get_device()
38
  self.available_models = self.model_paths.keys()
 
 
39
 
40
  def transcribe(self,
41
  audio: Union[str, BinaryIO, np.ndarray],
42
  progress: gr.Progress = gr.Progress(),
43
  *whisper_params,
44
+ ) -> Tuple[List[Segment], float]:
45
  """
46
  transcribe method for faster-whisper.
47
 
 
56
 
57
  Returns
58
  ----------
59
+ segments_result: List[Segment]
60
+ list of Segment that includes start, end timestamps and transcribed text
61
  elapsed_time: float
62
  elapsed time for transcription
63
  """
64
  start_time = time.time()
65
 
66
+ params = WhisperParams.from_list(list(whisper_params))
67
 
68
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
69
  self.update_model(params.model_size, params.compute_type, progress)
70
 
 
 
 
 
 
 
 
 
 
 
71
  segments, info = self.model.transcribe(
72
  audio=audio,
73
  language=params.lang,
74
+ task="translate" if params.is_translate else "transcribe",
75
  beam_size=params.beam_size,
76
  log_prob_threshold=params.log_prob_threshold,
77
  no_speech_threshold=params.no_speech_threshold,
 
98
  language_detection_segments=params.language_detection_segments,
99
  prompt_reset_on_temperature=params.prompt_reset_on_temperature,
100
  )
101
+ progress(0, desc="Loading audio..")
102
 
103
  segments_result = []
104
  for segment in segments:
105
+ progress(segment.start / info.duration, desc="Transcribing..")
106
+ segments_result.append(Segment.from_faster_whisper(segment))
 
 
 
 
107
 
108
  elapsed_time = time.time() - start_time
109
  return segments_result, elapsed_time
 
119
  Parameters
120
  ----------
121
  model_size: str
122
+ Size of whisper model. If you enter the huggingface repo id, it will try to download the model
123
+ automatically from huggingface.
124
  compute_type: str
125
  Compute type for transcription.
126
  see more info : https://opennmt.net/CTranslate2/quantization.html
127
  progress: gr.Progress
128
  Indicator to show progress directly in gradio.
129
  """
130
+ progress(0, desc="Initializing Model..")
131
+
132
+ model_size_dirname = model_size.replace("/", "--") if "/" in model_size else model_size
133
+ if model_size not in self.model_paths and model_size_dirname not in self.model_paths:
134
+ print(f"Model is not detected. Trying to download \"{model_size}\" from huggingface to "
135
+ f"\"{os.path.join(self.model_dir, model_size_dirname)} ...")
136
+ huggingface_hub.snapshot_download(
137
+ model_size,
138
+ local_dir=os.path.join(self.model_dir, model_size_dirname),
139
+ )
140
+ self.model_paths = self.get_model_paths()
141
+ gr.Info(f"Model is downloaded with the name \"{model_size_dirname}\"")
142
+
143
+ self.current_model_size = self.model_paths[model_size_dirname]
144
+
145
+ local_files_only = False
146
+ hf_prefix = "models--Systran--faster-whisper-"
147
+ official_model_path = os.path.join(self.model_dir, hf_prefix+model_size)
148
+ if ((os.path.isdir(self.current_model_size) and os.path.exists(self.current_model_size)) or
149
+ (model_size in faster_whisper.available_models() and os.path.exists(official_model_path))):
150
+ local_files_only = True
151
+
152
  self.current_compute_type = compute_type
153
  self.model = faster_whisper.WhisperModel(
154
  device=self.device,
155
  model_size_or_path=self.current_model_size,
156
  download_root=self.model_dir,
157
+ compute_type=self.current_compute_type,
158
+ local_files_only=local_files_only
159
  )
160
 
161
  def get_model_paths(self):
 
170
  faster_whisper_prefix = "models--Systran--faster-whisper-"
171
 
172
  existing_models = os.listdir(self.model_dir)
173
+ wrong_dirs = [".locks", "faster_whisper_models_will_be_saved_here"]
174
  existing_models = list(set(existing_models) - set(wrong_dirs))
175
 
176
  for model_name in existing_models:
 
196
  raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
197
  return suppress_tokens
198
  except Exception as e:
199
+ raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")