subtify / transcribe.py
import os
import argparse
from lang_list import LANGUAGE_NAME_TO_CODE, WHISPER_LANGUAGES
from tqdm import tqdm
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline


def get_language_dict():
    language_dict = {}
    # Iterate over the LANGUAGE_NAME_TO_CODE dictionary
    for language_name, language_code in LANGUAGE_NAME_TO_CODE.items():
        # Extract the base language code (the part before the underscore)
        lang_code = language_code.split('_')[0].lower()
        # Check if the language code is present in WHISPER_LANGUAGES
        if lang_code in WHISPER_LANGUAGES:
            # Construct the entry for the resulting dictionary
            language_dict[language_name] = {
                "transcriber": lang_code,
                "translator": language_code
            }
    return language_dict
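
# Illustrative output (hedged: the exact codes depend on lang_list.LANGUAGE_NAME_TO_CODE
# and WHISPER_LANGUAGES, so the values shown are only a plausible example):
#   get_language_dict().get("English")  ->  {"transcriber": "en", "translator": "en_XX"}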


def transcribe(audio_file, language, device, chunk_length_s=30, stride_length_s=5):
    """
    Transcribe an audio file using a Whisper model.

    Args:
        audio_file (str): Path to the audio file
        language (str): Whisper language code for the transcription
        device (str): Device to use for inference ('cuda' or 'cpu')
        chunk_length_s (int): Length of audio chunks in seconds
        stride_length_s (int): Stride length between chunks in seconds
    """
    output_folder = "transcriptions"
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    # Build the output .srt path from the audio file name
    audio_filename = os.path.basename(audio_file)
    filename_without_ext = os.path.splitext(audio_filename)[0]
    output_file = os.path.join(output_folder, f"{filename_without_ext}.srt")

    # Use half precision only when running on a CUDA device
    torch_dtype = torch.float16 if str(device).startswith("cuda") and torch.cuda.is_available() else torch.float32

    # Load model and processor
    model_id = "openai/whisper-large-v3-turbo"
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True
    )
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)

    # Create pipeline with timestamp generation
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
        chunk_length_s=chunk_length_s,
        stride_length_s=stride_length_s,
        return_timestamps=True
    )

    # Run the transcription with chunk-level timestamps
    result = pipe(
        audio_file,
        return_timestamps=True,
        generate_kwargs={
            "language": language,
            "task": "transcribe",
            "use_cache": True,
            "num_beams": 1
        }
    )
    print(result)
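
    # Hedged addition (the original file only prints the raw pipeline output):
    # write the timestamped chunks to the .srt path built above. With
    # return_timestamps=True the pipeline returns a dict of the form
    # {"text": ..., "chunks": [{"timestamp": (start, end), "text": ...}, ...]}.
    def format_srt_time(seconds):
        # Convert seconds to the SRT timestamp format HH:MM:SS,mmm
        total_ms = int(round(seconds * 1000))
        hours, rest = divmod(total_ms, 3_600_000)
        minutes, rest = divmod(rest, 60_000)
        secs, millis = divmod(rest, 1_000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

    with open(output_file, "w") as srt_file:
        for index, chunk in enumerate(result["chunks"], start=1):
            start, end = chunk["timestamp"]
            end = end if end is not None else start  # the last chunk may lack an end time
            srt_file.write(f"{index}\n")
            srt_file.write(f"{format_srt_time(start)} --> {format_srt_time(end)}\n")
            srt_file.write(f"{chunk['text'].strip()}\n\n")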


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Transcribe audio files')
    parser.add_argument('input_files', help='Text file listing the input audio files, one per line')
    parser.add_argument('language', help='Language of the audio files')
    parser.add_argument('num_speakers', help='Number of speakers in the audio file (not used during transcription)')
    parser.add_argument('device', help='Device to use for PyTorch inference')
    args = parser.parse_args()

    chunks_folder = "chunks"
    with open(args.input_files, 'r') as f:
        inputs = f.read().splitlines()

    language_dict = get_language_dict()
    progress_bar = tqdm(total=len(inputs), desc="Transcribe audio files progress")
    for input_path in inputs:
        # Map each listed file to its mp3 chunk in the chunks folder
        input_name = os.path.splitext(os.path.basename(input_path))[0]
        extension = "mp3"
        file = f'{chunks_folder}/{input_name}.{extension}'
        transcribe(file, language_dict[args.language]["transcriber"], args.device)
        progress_bar.update(1)
    progress_bar.close()
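
# Example invocation (hedged: assumes "English" is a key of LANGUAGE_NAME_TO_CODE and
# that mp3 chunks such as chunks/audio_000.mp3 already exist in the chunks folder):
#   python transcribe.py inputs.txt English 1 cuda
# where inputs.txt lists one audio file path per line.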