import os
import ast
import argparse
from lang_list import LANGUAGE_NAME_TO_CODE, WHISPER_LANGUAGES
from tqdm import tqdm
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers.utils import is_flash_attn_2_available
from time import time

TRANSCRIPTOR_WHISPER = "openai/whisper-large-v3-turbo" # Time to transcribe: 296.53 seconds ==> minutes: 4.94
TRANSCRIPTOR_DISTIL_WHISPER = "distil-whisper/distil-large-v3" # Time to transcribe: 242.82 seconds ==> minutes: 4.05
TRANSCRIPTOR = TRANSCRIPTOR_DISTIL_WHISPER


def get_language_dict():
    language_dict = {}
    # Iterate over the LANGUAGE_NAME_TO_CODE dictionary
    for language_name, language_code in LANGUAGE_NAME_TO_CODE.items():
        # Extract the language code (the first two characters before the underscore)
        lang_code = language_code.split('_')[0].lower()
        
        # Check if the language code is present in WHISPER_LANGUAGES
        if lang_code in WHISPER_LANGUAGES:
            # Construct the entry for the resulting dictionary
            language_dict[language_name] = {
                "transcriber": lang_code,
                "translator": language_code
            }
    return language_dict
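
# Illustrative example (hypothetical entries, assuming LANGUAGE_NAME_TO_CODE
# contained {"English": "en_XX"} and "en" appeared in WHISPER_LANGUAGES):
#   get_language_dict() -> {"English": {"transcriber": "en", "translator": "en_XX"}}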

def transcription_to_dict(transcription):
    """
    Converts a transcription given as a string into a structured dictionary.

    Args:
        transcription (str): String containing the transcription with timestamps

    Returns:
        dict: Dictionary with the full text and the chunks with their timestamps
    """
    try:
        # If the input is a string, parse it as a Python literal (safer than eval)
        if isinstance(transcription, str):
            transcription_dict = ast.literal_eval(transcription)
        else:
            transcription_dict = transcription

        # Validate the structure of the dictionary
        if not isinstance(transcription_dict, dict):
            raise ValueError("The transcription does not have the expected format")

        if 'text' not in transcription_dict or 'chunks' not in transcription_dict:
            raise ValueError("The transcription is missing the required fields (text and chunks)")

        # Drop empty chunks and validate timestamps
        cleaned_chunks = []
        for chunk in transcription_dict['chunks']:
            # Keep only chunks that have text and valid timestamps
            if (chunk.get('text') and 
                isinstance(chunk.get('timestamp'), (list, tuple)) and 
                len(chunk['timestamp']) == 2 and 
                chunk['timestamp'][0] is not None and 
                chunk['timestamp'][1] is not None):
                
                cleaned_chunks.append({
                    'start': float(chunk['timestamp'][0]),  # Convert to float
                    'end': float(chunk['timestamp'][1]),    # Convert to float
                    'text': chunk['text'].strip()
                })

        # Build the final cleaned dictionary
        result = {
            'text': transcription_dict['text'],
            'chunks': cleaned_chunks
        }

        return result

    except Exception as e:
        print(f"Error processing the transcription: {e}")
        return None
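
# Illustrative example of the shapes involved (values are made up; in this
# script the real input is the dict returned by the ASR pipeline, or its
# string representation):
#   transcription_to_dict({'text': ' Hello world',
#                          'chunks': [{'timestamp': (0.0, 1.5), 'text': ' Hello world'}]})
#   -> {'text': ' Hello world',
#       'chunks': [{'start': 0.0, 'end': 1.5, 'text': 'Hello world'}]}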

def transcribe(audio_file, language, device, chunk_length_s=30, stride_length_s=5):
    """
    Transcribe audio file using Whisper model.
    
    Args:
        audio_file (str): Path to audio file
        language (str): Language code for transcription
        device (str): Device to use for inference ('cuda' or 'cpu')
        chunk_length_s (int): Length of audio chunks in seconds
        stride_length_s (int): Stride length between chunks in seconds
    """
    output_folder = "transcriptions"
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    # Get output filename
    audio_filename = os.path.basename(audio_file)
    filename_without_ext = os.path.splitext(audio_filename)[0]
    output_file = os.path.join(output_folder, f"{filename_without_ext}.srt")

    device = torch.device(device)
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    # Load model and processor
    model_id = TRANSCRIPTOR
    t0 = time()
    
    # Use Flash Attention 2 if it is available
    print(f"Using Flash Attention 2: {is_flash_attn_2_available()}")
    if TRANSCRIPTOR == TRANSCRIPTOR_WHISPER:
        model_kwargs = {"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"}
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id, 
            torch_dtype=torch_dtype, 
            low_cpu_mem_usage=True, 
            use_safetensors=True,
            **model_kwargs
        )
    else:
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id, 
            torch_dtype=torch_dtype, 
            low_cpu_mem_usage=True, 
            use_safetensors=True,
        )    
    model.to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    # Request word-level timestamps for Distil-Whisper and segment-level timestamps for the full Whisper model
    if TRANSCRIPTOR == TRANSCRIPTOR_DISTIL_WHISPER:
        timestamp = "word"
    else:
        timestamp = True

    # Create pipeline with timestamp generation
    if TRANSCRIPTOR == TRANSCRIPTOR_WHISPER:
        pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
            chunk_length_s=chunk_length_s,
            stride_length_s=stride_length_s,
            return_timestamps=timestamp,
            max_new_tokens=128,
            batch_size=24,
            model_kwargs=model_kwargs 
        )
    else:
        pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
            chunk_length_s=chunk_length_s,
            stride_length_s=stride_length_s,
            return_timestamps=timestamp,
            max_new_tokens=128,
        )

    # Transcribe with timestamps and generate attention mask
    if TRANSCRIPTOR == TRANSCRIPTOR_WHISPER:
        result = pipe(
            audio_file,
            return_timestamps=timestamp,
            batch_size=24,
            generate_kwargs={
                "language": language,
                "task": "transcribe",
                "use_cache": True,
                "num_beams": 1
            }
        )
    else:
        result = pipe(
            audio_file,
            return_timestamps=timestamp,
            generate_kwargs={
                "language": language,
                "task": "transcribe",
                "use_cache": True,
                "num_beams": 1
            }
        )
    
    t = time()
    print(f"Time to transcribe: {t - t0:.2f} seconds")

    transcription_str = result
    transcription_dict = transcription_to_dict(transcription_str)

    return transcription_str, transcription_dict
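
# transcribe() computes `output_file` (an .srt path) but never writes it. The
# helper below is a minimal, assumed sketch of how the cleaned chunks returned
# by transcription_to_dict() could be serialized to SRT; the name
# `chunks_to_srt` and the formatting details are illustrative, not part of the
# original pipeline.
def chunks_to_srt(chunks):
    def fmt(seconds):
        # SRT timestamps use the form HH:MM:SS,mmm
        ms = int(round(seconds * 1000))
        hours, rem = divmod(ms, 3_600_000)
        minutes, rem = divmod(rem, 60_000)
        secs, millis = divmod(rem, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

    entries = []
    for index, chunk in enumerate(chunks, start=1):
        entries.append(
            f"{index}\n{fmt(chunk['start'])} --> {fmt(chunk['end'])}\n{chunk['text']}\n"
        )
    return "\n".join(entries)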

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Transcribe audio files')
    parser.add_argument('input_files', help='Input audio files')
    parser.add_argument('language', help='Language of the audio file')
    parser.add_argument('num_speakers', help='Number of speakers in the audio file')
    parser.add_argument('device', help='Device to use for PyTorch inference')
    args = parser.parse_args()

    chunks_folder = "chunks"

    with open(args.input_files, 'r') as f:
        inputs = f.read().splitlines()
    
    progress_bar = tqdm(total=len(inputs), desc="Transcribe audio files progress")
    for input_path in inputs:
        # Resolve the chunked audio file that corresponds to this input path
        input_name = os.path.splitext(os.path.basename(input_path))[0]
        extension = "mp3"
        file = f'{chunks_folder}/{input_name}.{extension}'
        language_dict = get_language_dict()
        # transcribe() expects (audio_file, language, device, ...); num_speakers is parsed but not used here
        transcribe(file, language_dict[args.language]["transcriber"], args.device)
        progress_bar.update(1)