import os
import argparse

from lang_list import LANGUAGE_NAME_TO_CODE, WHISPER_LANGUAGES

from tqdm import tqdm
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline


def get_language_dict():
    """Map language names to the codes used by the transcriber (Whisper) and the translator."""
    language_dict = {}
    for language_name, language_code in LANGUAGE_NAME_TO_CODE.items():
        # Whisper expects the bare code before any underscore suffix, lowercased.
        lang_code = language_code.split('_')[0].lower()
        if lang_code in WHISPER_LANGUAGES:
            language_dict[language_name] = {
                "transcriber": lang_code,
                "translator": language_code
            }
    return language_dict


def transcribe(audio_file, language, device, chunk_length_s=30, stride_length_s=5):
    """
    Transcribe an audio file using a Whisper model.

    Args:
        audio_file (str): Path to the audio file.
        language (str): Whisper language code for transcription.
        device (str): Device to use for inference ('cuda' or 'cpu').
        chunk_length_s (int): Length of audio chunks in seconds.
        stride_length_s (int): Stride length between chunks in seconds.
    """
    output_folder = "transcriptions"
    os.makedirs(output_folder, exist_ok=True)

    audio_filename = os.path.basename(audio_file)
    filename_without_ext = os.path.splitext(audio_filename)[0]
    output_file = os.path.join(output_folder, f"{filename_without_ext}.srt")

    # Half precision is only reliably supported on GPU, so choose the dtype from the
    # requested device rather than from torch.cuda.is_available().
    torch_dtype = torch.float16 if device.startswith("cuda") else torch.float32

    model_id = "openai/whisper-large-v3-turbo"
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True
    )
    model.to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
        chunk_length_s=chunk_length_s,
        stride_length_s=stride_length_s,
        return_timestamps=True
    )

    result = pipe(
        audio_file,
        return_timestamps=True,
        generate_kwargs={
            "language": language,
            "task": "transcribe",
            "use_cache": True,
            "num_beams": 1
        }
    )

    print(result)
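    # `output_file` is computed above but the pipeline output is only printed.
    # A minimal sketch for also writing the timestamped segments as SubRip (.srt)
    # subtitles, assuming the result contains a "chunks" list of
    # {"timestamp": (start, end), "text": ...} entries.
    def format_srt_time(seconds):
        # Convert seconds to the SRT timestamp format HH:MM:SS,mmm.
        milliseconds = int(round((seconds or 0) * 1000))
        hours, remainder = divmod(milliseconds, 3_600_000)
        minutes, remainder = divmod(remainder, 60_000)
        secs, millis = divmod(remainder, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

    with open(output_file, 'w', encoding='utf-8') as srt_file:
        for index, chunk in enumerate(result.get("chunks", []), start=1):
            start, end = chunk["timestamp"]
            if end is None:
                # The last chunk can come back without an end timestamp; fall back to its start.
                end = start
            srt_file.write(f"{index}\n")
            srt_file.write(f"{format_srt_time(start)} --> {format_srt_time(end)}\n")
            srt_file.write(f"{chunk['text'].strip()}\n\n")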


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Transcribe audio files')
    parser.add_argument('input_files', help='Text file listing the audio files to transcribe, one path per line')
    parser.add_argument('language', help='Language of the audio files')
    parser.add_argument('num_speakers', help='Number of speakers in the audio file (not used by this script)')
    parser.add_argument('device', help='Device to use for PyTorch inference ("cuda" or "cpu")')
    args = parser.parse_args()

    chunks_folder = "chunks"

    with open(args.input_files, 'r') as f:
        inputs = f.read().splitlines()

    # Build the language lookup once rather than on every loop iteration.
    language_dict = get_language_dict()

    progress_bar = tqdm(total=len(inputs), desc="Transcribe audio files progress")
    for input_path in inputs:
        # Look up the .mp3 version of each listed file in the chunks folder.
        input_name = os.path.splitext(os.path.basename(input_path))[0]
        extension = "mp3"
        file = f'{chunks_folder}/{input_name}.{extension}'
        transcribe(file, language_dict[args.language]["transcriber"], args.device)
        progress_bar.update(1)