import os
import argparse
import ast
from time import time

import torch
from tqdm import tqdm
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers.utils import is_flash_attn_2_available

from lang_list import LANGUAGE_NAME_TO_CODE, WHISPER_LANGUAGES

TRANSCRIPTOR_WHISPER = "openai/whisper-large-v3-turbo"          # Time to transcribe: 296.53 seconds ==> minutes: 4.94
TRANSCRIPTOR_DISTIL_WHISPER = "distil-whisper/distil-large-v3"  # Time to transcribe: 242.82 seconds ==> minutes: 4.05
TRANSCRIPTOR = TRANSCRIPTOR_DISTIL_WHISPER


def get_language_dict():
    language_dict = {}
    # Iterate over the LANGUAGE_NAME_TO_CODE dictionary
    for language_name, language_code in LANGUAGE_NAME_TO_CODE.items():
        # Extract the language code (the characters before the underscore)
        lang_code = language_code.split('_')[0].lower()
        # Check if the language code is present in WHISPER_LANGUAGES
        if lang_code in WHISPER_LANGUAGES:
            # Construct the entry for the resulting dictionary
            language_dict[language_name] = {
                "transcriber": lang_code,
                "translator": language_code
            }
    return language_dict


def transcription_to_dict(transcription):
    """
    Convert a transcription given as a string into a structured dictionary.

    Args:
        transcription (str): String containing the transcription with timestamps

    Returns:
        dict: Dictionary with the full text and the chunks with their timestamps
    """
    try:
        # If the input is a string, convert it to a dictionary
        if isinstance(transcription, str):
            # Parse the string as a Python literal (safer than eval)
            transcription_dict = ast.literal_eval(transcription)
        else:
            transcription_dict = transcription

        # Validate the structure of the dictionary
        if not isinstance(transcription_dict, dict):
            raise ValueError("The transcription does not have the expected format")

        if 'text' not in transcription_dict or 'chunks' not in transcription_dict:
            raise ValueError("The transcription is missing the required fields (text and chunks)")

        # Drop empty chunks and validate timestamps
        cleaned_chunks = []
        for chunk in transcription_dict['chunks']:
            # Keep the chunk only if it has text and two non-null timestamps
            if (chunk.get('text') and
                    isinstance(chunk.get('timestamp'), (list, tuple)) and
                    len(chunk['timestamp']) == 2 and
                    chunk['timestamp'][0] is not None and
                    chunk['timestamp'][1] is not None):
                cleaned_chunks.append({
                    'start': float(chunk['timestamp'][0]),  # Convert to float
                    'end': float(chunk['timestamp'][1]),    # Convert to float
                    'text': chunk['text'].strip()
                })

        # Build the final cleaned dictionary
        result = {
            'text': transcription_dict['text'],
            'chunks': cleaned_chunks
        }

        return result

    except Exception as e:
        print(f"Error processing the transcription: {e}")
        return None


def transcribe(audio_file, language, device, chunk_length_s=30, stride_length_s=5):
    """
    Transcribe audio file using Whisper model.
    Args:
        audio_file (str): Path to audio file
        language (str): Language code for transcription
        device (str): Device to use for inference ('cuda' or 'cpu')
        chunk_length_s (int): Length of audio chunks in seconds
        stride_length_s (int): Stride length between chunks in seconds
    """
    output_folder = "transcriptions"
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    # Get output filename
    audio_filename = os.path.basename(audio_file)
    filename_without_ext = os.path.splitext(audio_filename)[0]
    output_file = os.path.join(output_folder, f"{filename_without_ext}.srt")

    device = torch.device(device)
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    # Load model and processor
    model_id = TRANSCRIPTOR
    t0 = time()

    # Configure Flash Attention 2 if it is available
    print(f"Using Flash Attention 2: {is_flash_attn_2_available()}")
    if TRANSCRIPTOR == TRANSCRIPTOR_WHISPER:
        model_kwargs = {"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"}
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            **model_kwargs
        )
    else:
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        )
    model.to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    # distil-whisper is run with word-level timestamps; the full model uses segment-level ones
    if TRANSCRIPTOR == TRANSCRIPTOR_DISTIL_WHISPER:
        timestamp = "word"
    else:
        timestamp = True

    # Create pipeline with timestamp generation
    if TRANSCRIPTOR == TRANSCRIPTOR_WHISPER:
        pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
            chunk_length_s=chunk_length_s,
            stride_length_s=stride_length_s,
            return_timestamps=timestamp,
            max_new_tokens=128,
            batch_size=24,
            model_kwargs=model_kwargs
        )
    else:
        pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
            chunk_length_s=chunk_length_s,
            stride_length_s=stride_length_s,
            return_timestamps=timestamp,
            max_new_tokens=128,
        )

    # Transcribe with timestamps
    if TRANSCRIPTOR == TRANSCRIPTOR_WHISPER:
        result = pipe(
            audio_file,
            return_timestamps=timestamp,
            batch_size=24,
            generate_kwargs={
                "language": language,
                "task": "transcribe",
                "use_cache": True,
                "num_beams": 1
            }
        )
    else:
        result = pipe(
            audio_file,
            return_timestamps=timestamp,
            generate_kwargs={
                "language": language,
                "task": "transcribe",
                "use_cache": True,
                "num_beams": 1
            }
        )
    t = time()
    print(f"Time to transcribe: {t - t0:.2f} seconds")

    transcription_str = result
    transcription_dict = transcription_to_dict(transcription_str)
    return transcription_str, transcription_dict


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Transcribe audio files')
    parser.add_argument('input_files', help='Input audio files')
    parser.add_argument('language', help='Language of the audio file')
    parser.add_argument('num_speakers', help='Number of speakers in the audio file')
    parser.add_argument('device', help='Device to use for PyTorch inference')
    args = parser.parse_args()

    chunks_folder = "chunks"

    with open(args.input_files, 'r') as f:
        inputs = f.read().splitlines()

    progress_bar = tqdm(total=len(inputs), desc="Transcribe audio files progress")
    for input_path in inputs:
        # Derive the chunk filename from the input path (e.g. "folder/name.wav" -> "chunks/name.mp3")
        input_name = os.path.splitext(os.path.basename(input_path))[0]
        extension = "mp3"
        file = f'{chunks_folder}/{input_name}.{extension}'
        language_dict = get_language_dict()
        # transcribe() takes (audio_file, language, device); num_speakers is accepted by the CLI but not used here
        transcribe(file, language_dict[args.language]["transcriber"], args.device)
        progress_bar.update(1)
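# ---------------------------------------------------------------------------
# Example usage (a sketch; the script name "transcribe.py" and the input file
# "inputs.txt" are illustrative, not fixed by this module):
#
#   $ cat inputs.txt
#   audios/interview_part1.wav
#   audios/interview_part2.wav
#   $ python transcribe.py inputs.txt English 2 cuda
#
# Each listed file is expected to have a matching chunk at chunks/<name>.mp3.
# transcribe() returns the raw pipeline result plus the cleaned dictionary
# produced by transcription_to_dict(), shaped roughly like:
#
#   {
#       "text": "full transcription ...",
#       "chunks": [
#           {"start": 0.0, "end": 4.2, "text": "first segment"},
#           ...
#       ],
#   }
# ---------------------------------------------------------------------------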