Whisper-WebUI / modules /whisper /insanely_fast_whisper_inference.py
jhj0517
Enable fine-tuned models
9e5ed74
import os
import time
import numpy as np
from typing import BinaryIO, Union, Tuple, List
import torch
from transformers import pipeline
from transformers.utils import is_flash_attn_2_available
import gradio as gr
from huggingface_hub import hf_hub_download
import whisper
from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn
from argparse import Namespace
from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
from modules.whisper.data_classes import *
from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
class InsanelyFastWhisperInference(BaseTranscriptionPipeline):
    """Transcription pipeline that runs Whisper through a transformers "automatic-speech-recognition" pipeline."""

    def __init__(self,
                 model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR,
                 diarization_model_dir: str = DIARIZATION_MODELS_DIR,
                 uvr_model_dir: str = UVR_MODELS_DIR,
                 output_dir: str = OUTPUT_DIR,
                 ):
        super().__init__(
            model_dir=model_dir,
            output_dir=output_dir,
            diarization_model_dir=diarization_model_dir,
            uvr_model_dir=uvr_model_dir
        )
        self.model_dir = model_dir
        os.makedirs(self.model_dir, exist_ok=True)
        self.available_models = self.get_model_paths()

    def transcribe(self,
                   audio: Union[str, np.ndarray, torch.Tensor],
                   progress: gr.Progress = gr.Progress(),
                   *whisper_params,
                   ) -> Tuple[List[Segment], float]:
        """
        transcribe method for insanely-fast-whisper.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            Audio path or audio array, passed unchanged as ``inputs`` to the transformers pipeline.
        progress: gr.Progress
            Indicator to show progress directly in gradio. Only an initial message is shown here;
            no incremental progress is reported to gradio during transcription.
        *whisper_params: tuple
            Parameters related with whisper. This will be dealt with "WhisperParams" data class

        Returns
        ----------
        segments_result: List[Segment]
            list of Segment that includes start, end timestamps and transcribed text
        elapsed_time: float
            elapsed time for transcription
        """
        start_time = time.time()
        params = WhisperParams.from_list(list(whisper_params))

        # Rebuild the pipeline only when nothing is loaded or the requested size / compute type changed.
        if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
            self.update_model(params.model_size, params.compute_type, progress)

        progress(0, desc="Transcribing...Progress is not shown in insanely-fast-whisper.")
        # The rich console bar gets its own name so it does not shadow the gradio ``progress`` argument.
        with Progress(
            TextColumn("[progress.description]{task.description}"),
            BarColumn(style="yellow1", pulse_style="white"),
            TimeElapsedColumn(),
        ) as console_progress:
            # total=None: no known total, so the bar only indicates that work is ongoing.
            console_progress.add_task("[yellow]Transcribing...", total=None)

            kwargs = {
                "no_speech_threshold": params.no_speech_threshold,
                "temperature": params.temperature,
                "compression_ratio_threshold": params.compression_ratio_threshold,
                "logprob_threshold": params.log_prob_threshold,
            }
            # English-only (".en") checkpoints are called without language / task generate arguments.
            if not self.current_model_size.endswith(".en"):
                kwargs["language"] = params.lang
                kwargs["task"] = "translate" if params.is_translate else "transcribe"

            segments = self.model(
                inputs=audio,
                return_timestamps=True,
                chunk_length_s=params.chunk_length,
                batch_size=params.batch_size,
                generate_kwargs=kwargs
            )

        segments_result = []
        for item in segments["chunks"]:
            start, end = item["timestamp"][0], item["timestamp"][1]
            # The pipeline can report None as a chunk's end timestamp; collapse such a chunk to zero length.
            if end is None:
                end = start
            segments_result.append(Segment(
                text=item["text"],
                start=start,
                end=end
            ))

        elapsed_time = time.time() - start_time
        return segments_result, elapsed_time

    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress = gr.Progress(),
                     ):
        """
        Update current model setting.

        Downloads the model into ``model_dir/model_size`` when that directory is missing or empty,
        then builds a transformers "automatic-speech-recognition" pipeline from that directory.

        Parameters
        ----------
        model_size: str
            Size of whisper model, or the directory name of a fine-tuned model inside ``model_dir``.
        compute_type: str
            Compute type for transcription. Passed to the transformers pipeline as ``torch_dtype``.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        """
        progress(0, desc="Initializing Model..")
        model_path = os.path.join(self.model_dir, model_size)
        # Download only when there is no local copy, so models already present in model_dir are used as-is.
        if not os.path.isdir(model_path) or not os.listdir(model_path):
            self.download_model(
                model_size=model_size,
                download_root=model_path,
                progress=progress
            )

        self.current_compute_type = compute_type
        self.current_model_size = model_size
        # Use FlashAttention-2 when transformers reports it available, otherwise fall back to "sdpa".
        attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
        self.model = pipeline(
            "automatic-speech-recognition",
            model=model_path,
            torch_dtype=self.current_compute_type,
            device=self.device,
            model_kwargs={"attn_implementation": attn_implementation},
        )

    def get_model_paths(self):
        """
        Get available models from models path including fine-tuned model.

        Returns
        ----------
        List of model names: the built-in OpenAI Whisper and distil-whisper names first, followed by
        the entries found in ``model_dir``, with duplicates removed in first-seen order.
        """
        openai_models = whisper.available_models()
        distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
        default_models = openai_models + distil_models

        existing_models = os.listdir(self.model_dir)
        # Entries of model_dir that are not models (presumably huggingface_hub's lock directory).
        wrong_dirs = [".locks"]

        available_models = [model for model in default_models + existing_models if model not in wrong_dirs]
        # Order-preserving de-duplication in O(n) (dicts keep insertion order).
        return list(dict.fromkeys(available_models))

    @staticmethod
    def download_model(
        model_size: str,
        download_root: str,
        progress: gr.Progress
    ):
        """
        Download the files needed to build a transformers Whisper pipeline from the Hugging Face Hub.

        Parameters
        ----------
        model_size: str
            Model name. Names starting with "distil" are fetched from ``distil-whisper/<model_size>``,
            all others from ``openai/whisper-<model_size>``.
        download_root: str
            Local directory the files are written into; created if it does not exist.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        """
        progress(0, 'Initializing model..')
        print(f'Downloading {model_size} to "{download_root}"....')
        os.makedirs(download_root, exist_ok=True)

        download_list = [
            "model.safetensors",
            "config.json",
            "generation_config.json",
            "preprocessor_config.json",
            "tokenizer.json",
            "tokenizer_config.json",
            "added_tokens.json",
            "special_tokens_map.json",
            "vocab.json",
        ]

        if model_size.startswith("distil"):
            repo_id = f"distil-whisper/{model_size}"
        else:
            repo_id = f"openai/whisper-{model_size}"

        for item in download_list:
            hf_hub_download(repo_id=repo_id, filename=item, local_dir=download_root)