import os
import argparse
from lang_list import LANGUAGE_NAME_TO_CODE, WHISPER_LANGUAGES
from tqdm import tqdm
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline


def get_language_dict():
    language_dict = {}
    # Iterate over the LANGUAGE_NAME_TO_CODE dictionary
    for language_name, language_code in LANGUAGE_NAME_TO_CODE.items():
        # Extract the base language code (the part before the underscore)
        lang_code = language_code.split('_')[0].lower()
        
        # Check if the language code is present in WHISPER_LANGUAGES
        if lang_code in WHISPER_LANGUAGES:
            # Construct the entry for the resulting dictionary
            language_dict[language_name] = {
                "transcriber": lang_code,
                "translator": language_code
            }
    return language_dict

def transcribe(audio_file, language, device, chunk_length_s=30, stride_length_s=5):
    """
    Transcribe audio file using Whisper model.
    
    Args:
        audio_file (str): Path to audio file
        language (str): Language code for transcription
        device (str): Device to use for inference ('cuda' or 'cpu')
        chunk_length_s (int): Length of audio chunks in seconds
        stride_length_s (int): Stride length between chunks in seconds
    """
    output_folder = "transcriptions"
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    # Get output filename
    audio_filename = os.path.basename(audio_file)
    filename_without_ext = os.path.splitext(audio_filename)[0]
    output_file = os.path.join(output_folder, f"{filename_without_ext}.srt")

    # Use half precision only when the requested device is CUDA
    torch_dtype = torch.float16 if str(device).startswith("cuda") else torch.float32

    # Load model and processor
    model_id = "openai/whisper-large-v3-turbo"
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, 
        torch_dtype=torch_dtype, 
        low_cpu_mem_usage=True, 
        use_safetensors=True
    )
    model.to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    # Create pipeline with timestamp generation
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
        chunk_length_s=chunk_length_s,
        stride_length_s=stride_length_s,
        return_timestamps=True
    )

    # Transcribe with timestamps
    result = pipe(
        audio_file,
        return_timestamps=True,
        generate_kwargs={
            "language": language,
            "task": "transcribe",
            "use_cache": True,
            "num_beams": 1
        }
    )

    print(result)
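    # Sketch (not in the original script): write the transcription to the SRT
    # path prepared above. This assumes the usual transformers pipeline output,
    # a dict with a "chunks" list of {"timestamp": (start, end), "text": ...}.
    def format_srt_time(seconds):
        # Convert seconds to the SRT timestamp format HH:MM:SS,mmm
        milliseconds = int(round((seconds or 0.0) * 1000))
        hours, milliseconds = divmod(milliseconds, 3_600_000)
        minutes, milliseconds = divmod(milliseconds, 60_000)
        secs, milliseconds = divmod(milliseconds, 1_000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}"

    with open(output_file, 'w', encoding='utf-8') as srt_file:
        for index, chunk in enumerate(result.get("chunks", []), start=1):
            start, end = chunk["timestamp"]
            srt_file.write(f"{index}\n")
            srt_file.write(f"{format_srt_time(start)} --> {format_srt_time(end)}\n")
            srt_file.write(f"{chunk['text'].strip()}\n\n")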

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Transcribe audio files')
    parser.add_argument('input_files', help='Text file listing the input audio files, one path per line')
    parser.add_argument('language', help='Language of the audio file')
    parser.add_argument('num_speakers', help='Number of speakers in the audio file (accepted but not used by this script)')
    parser.add_argument('device', help='Device to use for PyTorch inference')
    args = parser.parse_args()

    chunks_folder = "chunks"

    with open(args.input_files, 'r') as f:
        inputs = f.read().splitlines()
    
    language_dict = get_language_dict()
    progress_bar = tqdm(total=len(inputs), desc="Transcribe audio files progress")
    for input_path in inputs:
        # Rebuild the chunk path from the base filename (chunks are assumed to be mp3)
        input_name = os.path.splitext(os.path.basename(input_path))[0]
        extension = "mp3"
        file = f'{chunks_folder}/{input_name}.{extension}'
        transcribe(file, language_dict[args.language]["transcriber"], args.device)
        progress_bar.update(1)
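
# Example invocation (hypothetical file names; the language argument must be a
# name defined in lang_list.LANGUAGE_NAME_TO_CODE, e.g. "English"):
#   python transcribe.py inputs.txt English 2 cuda
# where inputs.txt lists one audio path per line and the matching mp3 chunks
# live under the chunks/ folder.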