File size: 3,600 Bytes
544e017
 
3a18b3b
 
 
 
 
 
 
 
7a0f405
544e017
 
3a18b3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
544e017
3a18b3b
544e017
 
 
7a0f405
544e017
 
 
 
 
7a0f405
 
 
 
 
 
 
 
339c131
7a0f405
 
 
 
 
3da96bb
544e017
3a18b3b
 
 
 
 
544e017
3da96bb
 
544e017
 
 
3a18b3b
544e017
 
 
 
 
 
 
 
3a18b3b
 
 
 
 
 
 
544e017
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torchaudio
import torch
from transformers import (
    WhisperProcessor,
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    AutoModelForCTC,
    Wav2Vec2Processor,
    Wav2Vec2ForCTC
)
import numpy as np

# Registry of the speech-recognition checkpoints offered by this module.
# Every entry is loaded eagerly at import time and holds three fields:
#   "processor" - the matching pre/post-processor,
#   "model"     - the network itself,
#   "ctc_model" - True when the model is decoded from CTC logits,
#                 False when text is produced with `generate()`.
models_info = {}

# Base Whisper checkpoint; its processor is configured for Uzbek transcription.
models_info["openai/whisper-small-uzbek"] = dict(
    processor=WhisperProcessor.from_pretrained(
        "openai/whisper-small",
        language="uzbek",
        task="transcribe",
    ),
    model=AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small"),
    ctc_model=False,
)

models_info["ixxan/whisper-small-thugy20"] = dict(
    processor=AutoProcessor.from_pretrained("ixxan/whisper-small-thugy20"),
    model=AutoModelForSpeechSeq2Seq.from_pretrained("ixxan/whisper-small-thugy20"),
    ctc_model=False,
)

models_info["ixxan/whisper-small-uyghur-common-voice"] = dict(
    processor=AutoProcessor.from_pretrained("ixxan/whisper-small-uyghur-common-voice"),
    model=AutoModelForSpeechSeq2Seq.from_pretrained("ixxan/whisper-small-uyghur-common-voice"),
    ctc_model=False,
)

# MMS checkpoint loaded with the Uyghur (Arabic script) target language.
models_info["facebook/mms-1b-all"] = dict(
    processor=AutoProcessor.from_pretrained(
        "facebook/mms-1b-all",
        target_lang='uig-script_arabic',
    ),
    model=AutoModelForCTC.from_pretrained(
        "facebook/mms-1b-all",
        target_lang='uig-script_arabic',
        ignore_mismatched_sizes=True,
    ),
    ctc_model=True,
)

# Disabled entry, kept for reference:
# models_info["ixxan/wav2vec2-large-mms-1b-uyghur-latin"] = dict(
#     processor=Wav2Vec2Processor.from_pretrained("ixxan/wav2vec2-large-mms-1b-uyghur-latin", target_lang='uig-script_latin'),
#     model=Wav2Vec2ForCTC.from_pretrained("ixxan/wav2vec2-large-mms-1b-uyghur-latin"),
#     ctc_model=True,
# )

def _pcm_to_mono_float32(samples):
    """Convert raw microphone samples to a mono float32 NumPy array.

    Signed-integer PCM is scaled by the full range of its dtype (32768 for
    int16); any other dtype is only cast to float32. Arrays with more than
    one dimension are assumed to be laid out as (samples, channels) -- the
    layout Gradio documents for numpy audio -- and are averaged to mono.
    """
    samples = np.asarray(samples)
    if np.issubdtype(samples.dtype, np.signedinteger):
        full_scale = float(np.iinfo(samples.dtype).max) + 1.0
        samples = (samples / full_scale).astype(np.float32)
    else:
        # NOTE(review): non-integer input is assumed to be already normalised.
        samples = samples.astype(np.float32)
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    return np.ascontiguousarray(samples, dtype=np.float32)


def transcribe(audio_data, model_id) -> str:
    """
    Transcribes audio to text with one of the models registered in `models_info`.
    Args:
    - audio_data: Gradio audio input; either a `(sampling_rate, samples)` tuple
      (microphone) or a path to an audio file (file upload).
    - model_id: key of the entry in `models_info` to transcribe with.
    Returns:
    - str: The transcription of the audio, or an "<<ERROR: ...>>" message when
      the audio input is empty or of an unsupported type.
    Raises:
    - KeyError: if `model_id` is not a key of `models_info`.
    """
    empty_input_error = "<<ERROR: Empty Audio Input>>"

    if not audio_data:
        return empty_input_error

    if isinstance(audio_data, tuple):
        # Microphone: (sampling_rate, numpy samples).
        sampling_rate, raw_samples = audio_data
        samples = _pcm_to_mono_float32(raw_samples)
        if samples.size == 0:
            return empty_input_error
        # Resampling below operates on torch tensors, not NumPy arrays.
        waveform = torch.from_numpy(samples)

    elif isinstance(audio_data, str):
        # File upload: torchaudio returns a (channels, frames) tensor.
        waveform, sampling_rate = torchaudio.load(audio_data)
        if waveform.numel() == 0:
            return empty_input_error
        # Average the channels so the processor receives one mono clip.
        waveform = waveform.mean(dim=0)

    else:
        return "<<ERROR: Invalid Audio Input Instance: {}>>".format(type(audio_data))

    model_entry = models_info[model_id]
    model = model_entry["model"]
    processor = model_entry["processor"]
    ctc_model = model_entry["ctc_model"]
    target_sr = processor.feature_extractor.sampling_rate

    # Resample if needed
    if sampling_rate != target_sr:
        resampler = torchaudio.transforms.Resample(sampling_rate, target_sr)
        waveform = resampler(waveform)

    # Preprocess the audio input. Only CTC models are padded to the longest
    # clip; Whisper processors keep their own default padding, because the
    # Whisper encoder requires the fixed-length window that default produces.
    padding_kwargs = {"padding": True} if ctc_model else {}
    inputs = processor(
        waveform.numpy(),
        sampling_rate=target_sr,
        return_tensors="pt",
        **padding_kwargs,
    )

    # Move model to GPU if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    inputs = {key: val.to(device) for key, val in inputs.items()}

    # Generate transcription
    with torch.no_grad():
        if ctc_model:
            # CTC: take the most likely token per frame, then collapse on decode.
            logits = model(**inputs).logits
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = processor.batch_decode(predicted_ids)[0]
        else:
            generated_ids = model.generate(inputs["input_features"], max_length=225)
            transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return transcription