jhj0517 committed
Commit 501c404 · Parent(s): 37be773

Update model usage

modules/whisper/faster_whisper_inference.py CHANGED
@@ -62,7 +62,7 @@ class FasterWhisperInference(WhisperBase):
         """
         start_time = time.time()
 
-        params = TranscriptionPipelineGradioComponents.as_value(*whisper_params)
+        params = WhisperParams.from_list(list(whisper_params))
 
         if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
             self.update_model(params.model_size, params.compute_type, progress)
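
Each backend's transcribe() now rebuilds its typed parameters from the flat positional values Gradio passes through, rather than going via the Gradio-component helper. Below is a minimal sketch of the from_list/to_list contract this change assumes; the field set, defaults, and dataclass base are illustrative only, not the actual definitions in modules/whisper/data_classes.py:

from dataclasses import astuple, dataclass, fields
from typing import Any, List, Optional

@dataclass
class WhisperParams:
    # Illustrative subset of fields; the real class defines many more.
    model_size: str = "large-v2"
    lang: Optional[str] = None
    compute_type: str = "float16"

    @classmethod
    def from_list(cls, values: List[Any]) -> "WhisperParams":
        # Zip the flat positional values (e.g. Gradio component outputs)
        # back onto the declared fields, in declaration order.
        names = [f.name for f in fields(cls)]
        return cls(**dict(zip(names, values)))

    def to_list(self) -> List[Any]:
        # Inverse of from_list: flatten back into positional *args form.
        return list(astuple(self))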
modules/whisper/insanely_fast_whisper_inference.py CHANGED
@@ -61,7 +61,7 @@ class InsanelyFastWhisperInference(WhisperBase):
             elapsed time for transcription
         """
         start_time = time.time()
-        params = TranscriptionPipelineGradioComponents.as_value(*whisper_params)
+        params = WhisperParams.from_list(list(whisper_params))
 
         if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
             self.update_model(params.model_size, params.compute_type, progress)
modules/whisper/whisper_Inference.py CHANGED
@@ -51,7 +51,7 @@ class WhisperInference(WhisperBase):
             elapsed time for transcription
         """
         start_time = time.time()
-        params = TranscriptionPipelineGradioComponents.as_value(*whisper_params)
+        params = WhisperParams.from_list(list(whisper_params))
 
         if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
             self.update_model(params.model_size, params.compute_type, progress)
modules/whisper/whisper_base.py CHANGED
@@ -74,7 +74,7 @@ class WhisperBase(ABC):
             audio: Union[str, BinaryIO, np.ndarray],
             progress: gr.Progress = gr.Progress(),
             add_timestamp: bool = True,
-            *whisper_params,
+            *pipeline_params,
             ) -> Tuple[List[dict], float]:
         """
         Run transcription with conditional pre-processing and post-processing.
@@ -89,8 +89,8 @@
             Indicator to show progress directly in gradio.
         add_timestamp: bool
             Whether to add a timestamp at the end of the filename.
-        *whisper_params: tuple
-            Parameters related with whisper. This will be dealt with "WhisperParameters" data class
+        *pipeline_params: tuple
+            Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class
 
         Returns
         ----------
@@ -99,28 +99,29 @@
         elapsed_time: float
             elapsed time for running
         """
-        params = TranscriptionPipelineGradioComponents.as_value(*whisper_params)
+        params = TranscriptionPipelineParams.from_list(list(pipeline_params))
+        bgm_params, vad_params, whisper_params, diarization_params = params.bgm_separation, params.vad, params.whisper, params.diarization
 
         self.cache_parameters(
-            whisper_params=params,
+            params=params,
             add_timestamp=add_timestamp
         )
 
-        if params.lang is None:
+        if whisper_params.lang is None:
             pass
-        elif params.lang == AUTOMATIC_DETECTION:
-            params.lang = None
+        elif whisper_params.lang == AUTOMATIC_DETECTION:
+            whisper_params.lang = None
         else:
             language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
-            params.lang = language_code_dict[params.lang]
+            whisper_params.lang = language_code_dict[params.lang]
 
-        if params.is_bgm_separate:
+        if bgm_params.is_separate_bgm:
             music, audio, _ = self.music_separator.separate(
                 audio=audio,
-                model_name=params.uvr_model_size,
-                device=params.uvr_device,
-                segment_size=params.uvr_segment_size,
-                save_file=params.uvr_save_file,
+                model_name=bgm_params.model_size,
+                device=bgm_params.device,
+                segment_size=bgm_params.segment_size,
+                save_file=bgm_params.save_file,
                 progress=progress
             )
 
@@ -132,20 +133,20 @@
             origin_sample_rate = self.music_separator.audio_info.sample_rate
             audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)
 
-            if params.uvr_enable_offload:
+            if bgm_params.enable_offload:
                 self.music_separator.offload()
 
-        if params.vad_filter:
+        if vad_params.vad_filter:
             # Explicit value set for float('inf') from gr.Number()
-            if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
-                params.max_speech_duration_s = float('inf')
+            if vad_params.max_speech_duration_s is None or vad_params.max_speech_duration_s >= 9999:
+                vad_params.max_speech_duration_s = float('inf')
 
             vad_options = VadOptions(
-                threshold=params.threshold,
-                min_speech_duration_ms=params.min_speech_duration_ms,
-                max_speech_duration_s=params.max_speech_duration_s,
-                min_silence_duration_ms=params.min_silence_duration_ms,
-                speech_pad_ms=params.speech_pad_ms
+                threshold=vad_params.threshold,
+                min_speech_duration_ms=vad_params.min_speech_duration_ms,
+                max_speech_duration_s=vad_params.max_speech_duration_s,
+                min_silence_duration_ms=vad_params.min_silence_duration_ms,
+                speech_pad_ms=vad_params.speech_pad_ms
             )
 
             audio, speech_chunks = self.vad.run(
@@ -157,20 +158,21 @@
         result, elapsed_time = self.transcribe(
             audio,
             progress,
-            *astuple(params)
+            *whisper_params.to_list()
         )
 
-        if params.vad_filter:
+        if vad_params.vad_filter:
             result = self.vad.restore_speech_timestamps(
                 segments=result,
-                speech_chunks=speech_chunks,
+                speech_chunks=vad_params.speech_chunks,
             )
 
-        if params.is_diarize:
+        if diarization_params.is_diarize:
             result, elapsed_time_diarization = self.diarizer.run(
                 audio=audio,
-                use_auth_token=params.hf_token,
+                use_auth_token=diarization_params.hf_token,
                 transcribed_result=result,
+                device=diarization_params.device
             )
             elapsed_time += elapsed_time_diarization
         return result, elapsed_time
@@ -181,7 +183,7 @@
                         file_format: str = "SRT",
                         add_timestamp: bool = True,
                         progress=gr.Progress(),
-                        *whisper_params,
+                        *params,
                         ) -> list:
         """
         Write subtitle file from Files
@@ -199,8 +201,8 @@
             Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
         progress: gr.Progress
             Indicator to show progress directly in gradio.
-        *whisper_params: tuple
-            Parameters related with whisper. This will be dealt with "WhisperParameters" data class
+        *params: tuple
+            Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class
 
         Returns
         ----------
@@ -223,7 +225,7 @@
                     file,
                     progress,
                     add_timestamp,
-                    *whisper_params,
+                    *params,
                 )
 
                 file_name, file_ext = os.path.splitext(os.path.basename(file))
@@ -514,13 +516,14 @@
 
     @staticmethod
     def cache_parameters(
-            whisper_params: TranscriptionPipelineParams,
+            params: TranscriptionPipelineParams,
             add_timestamp: bool
     ):
         """cache parameters to the yaml file"""
         cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
-        cached_whisper_param = whisper_params.to_yaml()
-        cached_yaml = {**cached_params, **cached_whisper_param}
+        param_to_cache = params.to_dict()
+
+        cached_yaml = {**cached_params, **param_to_cache}
         cached_yaml["whisper"]["add_timestamp"] = add_timestamp
 
         save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
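
The whisper_base.py change replaces the single flat parameter object with a container of per-stage groups that run() unpacks up front. A hedged sketch of that shape follows; class and field names are taken from the diff, while the types, defaults, and stub fields are assumptions:

from dataclasses import dataclass, field

@dataclass
class WhisperParams:              # stub; see the sketch above
    model_size: str = "large-v2"

@dataclass
class VadParams:
    vad_filter: bool = False

@dataclass
class BGMSeparationParams:
    is_separate_bgm: bool = False
    enable_offload: bool = True

@dataclass
class DiarizationParams:
    is_diarize: bool = False

@dataclass
class TranscriptionPipelineParams:
    whisper: WhisperParams = field(default_factory=WhisperParams)
    vad: VadParams = field(default_factory=VadParams)
    bgm_separation: BGMSeparationParams = field(default_factory=BGMSeparationParams)
    diarization: DiarizationParams = field(default_factory=DiarizationParams)

params = TranscriptionPipelineParams()
# The same unpacking run() performs after from_list():
bgm_params, vad_params, whisper_params, diarization_params = (
    params.bgm_separation, params.vad, params.whisper, params.diarization
)

Grouping options per stage lets BGM separation, VAD, transcription, and diarization each read only their own struct, instead of sharing one flat object with uvr_-prefixed fields.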
tests/test_transcription.py CHANGED
@@ -1,5 +1,5 @@
 from modules.whisper.whisper_factory import WhisperFactory
-from modules.whisper.data_classes import TranscriptionPipelineParams
+from modules.whisper.data_classes import *
 from modules.utils.paths import WEBUI_DIR
 from test_config import *
 
@@ -38,13 +38,21 @@ def test_transcribe(
     )
 
     hparams = TranscriptionPipelineParams(
-        model_size=TEST_WHISPER_MODEL,
-        vad_filter=vad_filter,
-        is_bgm_separate=bgm_separation,
-        compute_type=whisper_inferencer.current_compute_type,
-        uvr_enable_offload=True,
-        is_diarize=diarization,
-    ).as_list()
+        whisper=WhisperParams(
+            model_size=TEST_WHISPER_MODEL,
+            compute_type=whisper_inferencer.current_compute_type
+        ),
+        vad=VadParams(
+            vad_filter=vad_filter
+        ),
+        bgm_separation=BGMSeparationParams(
+            is_separate_bgm=bgm_separation,
+            enable_offload=True
+        ),
+        diarization=DiarizationParams(
+            is_diarize=diarization
+        ),
+    ).to_list()
 
     subtitle_str, file_path = whisper_inferencer.transcribe_file(
         [audio_path],
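
The updated test builds the nested container explicitly and flattens it with to_list() for the positional transcribe_file(..., *params) plumbing. It implicitly depends on to_list() and from_list() round-tripping; a hedged illustration against the classes named in the diff (behavior assumed, not verified here):

# Values set here must survive the flatten/rebuild cycle run() performs.
hparams = TranscriptionPipelineParams(
    whisper=WhisperParams(model_size="tiny"),
).to_list()
restored = TranscriptionPipelineParams.from_list(list(hparams))
assert restored.whisper.model_size == "tiny"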