mrtroydev's picture
Upload folder using huggingface_hub
3883c60 verified
raw
history blame
3.53 kB
import gc
import os.path
from tempfile import NamedTemporaryFile
import torch
import whisper
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutomaticSpeechRecognitionPipeline
from gradio_client.client import DEFAULT_TEMP_DIR
# Lazily-populated module state: everything is None until load() succeeds
# and is reset to None by unload(). Annotated as optional to match that —
# the original annotations claimed non-optional types while holding None.
processor: WhisperProcessor | None = None
model: WhisperForConditionalGeneration | AutomaticSpeechRecognitionPipeline | None = None
device: str | None = None  # 'cuda' or 'cpu' once a model is loaded
loaded_model: str | None = None  # name of the currently loaded checkpoint
def get_official_models():
    """Return the list of official Whisper checkpoint names selectable in the UI.

    The '.en' variants are English-only; the rest are multilingual.
    """
    multilingual = ['tiny', 'small', 'base', 'medium']
    english_only = [f'{name}.en' for name in multilingual]
    return english_only + multilingual + ['large', 'large-v2']
def unload():
    """Drop the current model and release its memory.

    Clears all module-level state, forces a garbage-collection pass, and
    releases any cached CUDA allocations (a no-op on CPU-only installs).

    Returns:
        The status string 'Unloaded'.
    """
    global model, processor, device, loaded_model
    model = processor = device = loaded_model = None
    gc.collect()
    torch.cuda.empty_cache()
    return 'Unloaded'
def load(pretrained_model='openai/whisper-base', map_device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Load a Whisper checkpoint, replacing whatever is currently loaded.

    Args:
        pretrained_model: Checkpoint name accepted by ``whisper.load_model``
            (e.g. 'base', 'large-v2').
        map_device: Target device; defaults to 'cuda' when available
            (evaluated once at import time).

    Returns:
        A human-readable status string: 'Loaded <name>' on success,
        'Failed to load, <error>' on any exception (state is cleared first).
    """
    global model, processor, device, loaded_model
    try:
        # Reload when either the checkpoint OR the device changed. The
        # original only compared the model name, so asking for the same
        # model on a different device silently kept the old device.
        if loaded_model != pretrained_model or device != map_device:
            unload()
            model = whisper.load_model(pretrained_model, map_device, 'data/models/automatic-speech-recognition/whisper')
            loaded_model = pretrained_model
            device = map_device
        return f'Loaded {pretrained_model}'
    except Exception as e:
        # Leave the module in a clean "nothing loaded" state on failure.
        unload()
        return f'Failed to load, {e}'
def transcribe(wav, files) -> tuple[tuple[int, torch.Tensor], list[str]]:
    """Transcribe both a live recording and a batch of uploaded files.

    Delegates to transcribe_wav and transcribe_files and pairs their results.
    """
    wav_result = transcribe_wav(wav)
    file_results = transcribe_files(files)
    return wav_result, file_results
def transcribe_wav(wav):
    """Transcribe an in-memory recording (Gradio numpy-audio format).

    Args:
        wav: ``(sample_rate, samples)`` tuple as produced by a Gradio audio
            component, or None when nothing was recorded.

    Returns:
        The stripped transcript string, None when ``wav`` is None,
        'Exception: <error>' if transcription fails, or a "no model" notice
        when nothing is loaded.
    """
    global model, processor, device, loaded_model
    if loaded_model is None:
        return 'No model loaded! Please load a model.'
    if wav is None:
        return None
    sr, samples = wav
    import traceback
    try:
        if sr != 16000:
            # Whisper expects 16 kHz mono float audio, so resample first.
            import torchaudio.functional as F
            # assumes int16 PCM samples, hence the /32767 scaling — TODO confirm
            audio = torch.tensor(samples).to(device).float() / 32767.0
            if audio.dim() > 1:
                # Downmix multi-channel audio to mono. The original applied
                # mean(-1) unconditionally, which collapsed a 1-D mono signal
                # into a single scalar sample (bug).
                audio = audio.mean(-1)
            samples = F.resample(audio.flatten().unsqueeze(0), sr, 16000).flatten().cpu().detach().numpy()
            sr = 16000
        return whisper.transcribe(model, samples)['text'].strip()
    except Exception as e:
        traceback.print_exception(e)
        return f'Exception: {e}'
def transcribe_files(files: list) -> list[str]:
    """Transcribe uploaded audio files and write each transcript to a .txt file.

    Args:
        files: List of uploaded file objects (each with a ``.name`` path),
            or None/empty when nothing was uploaded.

    Returns:
        Paths of the written transcript files (empty when there is no input
        or no model is loaded).
    """
    if not files:
        return []
    global model, processor, device, loaded_model
    if loaded_model is None:
        return []
    out_list = []
    for f in files:
        filename = os.path.basename(f.name)
        print('Processing ', filename)
        filename_noext, fileext = os.path.splitext(filename)
        text = whisper.transcribe(model, f.name)['text'].strip()
        # delete=False: the file must outlive this handle so Gradio can serve it.
        # Use a context manager so the transcript is flushed and closed before
        # the path is returned — the original never closed the handle, so
        # buffered text could still be missing from disk when read (leak/bug).
        with NamedTemporaryFile(dir=DEFAULT_TEMP_DIR, mode='w', delete=False,
                                suffix='.txt', prefix=filename_noext, encoding='utf8') as out_file:
            out_file.write(text)
        out_list.append(out_file.name)
    return out_list