File size: 4,624 Bytes
544e017
 
3a18b3b
 
 
 
 
 
 
 
7a0f405
1dfec92
544e017
 
3a18b3b
bef8623
3a18b3b
 
1dfec92
 
3a18b3b
7bc4048
 
 
 
 
 
3a18b3b
 
 
1dfec92
 
3a18b3b
 
 
 
1dfec92
 
3a18b3b
70da837
 
 
1dfec92
 
70da837
3a18b3b
544e017
1dfec92
 
 
 
 
ef107e3
1dfec92
 
 
 
 
544e017
1dfec92
6502e85
7a0f405
 
 
 
339c131
7a0f405
 
6502e85
71494c3
3da96bb
71494c3
 
 
3527591
71494c3
3a18b3b
 
6502e85
3a18b3b
 
544e017
3da96bb
 
544e017
6502e85
544e017
 
6502e85
544e017
 
 
 
 
 
 
 
3a18b3b
 
 
 
 
 
 
544e017
1dfec92
 
 
 
 
 
81e83c9
1dfec92
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torchaudio
import torch
from transformers import (
    WhisperProcessor,
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    AutoModelForCTC,
    Wav2Vec2Processor,
    Wav2Vec2ForCTC
)
import numpy as np
import util

# Load processor and model
# Registry of available ASR models, keyed by the display name that callers
# pass to transcribe() as `model_id`.
# Every processor and model below is created with from_pretrained() at import
# time, so importing this module loads all of them up front (presumably
# downloading/caching checkpoints from the Hugging Face Hub -- confirm).
# Per-entry fields:
#   "processor":     feature extractor + tokenizer used by transcribe().
#   "model":         the loaded transformers model.
#   "ctc_model":     True  -> transcribe() decodes by argmax over the logits;
#                    False -> transcribe() calls model.generate() (seq2seq).
#   "arabic_script": True if the model's raw output is Arabic script; when
#                    False the output is treated as Latin script. transcribe()
#                    derives the other script via util.
models_info = {
    "OpenAI-Whisper-Uzbek": {
        # Stock openai/whisper-small checkpoint; only the processor is
        # configured for Uzbek. NOTE(review): transcribe() does not pass a
        # language/task to generate(), so it is unclear whether "uzbek" is
        # actually enforced during decoding -- confirm.
        "processor": WhisperProcessor.from_pretrained("openai/whisper-small", language="uzbek", task="transcribe"),
        "model": AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small"),
        "ctc_model": False,
        "arabic_script": False
    },
    "Meta-MMS-Uyghur": {
        # target_lang selects the Uyghur (Arabic-script) adapter of MMS-1b-all.
        # ignore_mismatched_sizes is presumably needed because the adapter's
        # output head differs in size from the base checkpoint -- confirm.
        "processor": AutoProcessor.from_pretrained("facebook/mms-1b-all", target_lang='uig-script_arabic'),
        "model": AutoModelForCTC.from_pretrained("facebook/mms-1b-all", target_lang='uig-script_arabic', ignore_mismatched_sizes=True),
        "ctc_model": True,
        "arabic_script": True
    },
    "ixxan/whisper-small-thugy20": {
        "processor": AutoProcessor.from_pretrained("ixxan/whisper-small-thugy20"),
        "model": AutoModelForSpeechSeq2Seq.from_pretrained("ixxan/whisper-small-thugy20"),
        "ctc_model": False,
        "arabic_script": False
    },
    "ixxan/whisper-small-uyghur-common-voice": {
        "processor": AutoProcessor.from_pretrained("ixxan/whisper-small-uyghur-common-voice"),
        "model": AutoModelForSpeechSeq2Seq.from_pretrained("ixxan/whisper-small-uyghur-common-voice"),
        "ctc_model": False,
        "arabic_script": False
    },
    "ixxan/wav2vec2-large-mms-1b-uyghur-latin": {
        # Wav2Vec2 CTC model whose output is treated as Latin-script Uyghur.
        "processor": Wav2Vec2Processor.from_pretrained("ixxan/wav2vec2-large-mms-1b-uyghur-latin", target_lang='uig-script_latin'),
        "model": Wav2Vec2ForCTC.from_pretrained("ixxan/wav2vec2-large-mms-1b-uyghur-latin", target_lang='uig-script_latin'),
        "ctc_model": True,
        "arabic_script": False
    },
}

# def transcribe(audio_data, model_id) -> str:
#     if model_id == "Compare All Models":
#         return transcribe_all_models(audio_data)
#     else:
#         return transcribe_with_model(audio_data, model_id)

# def transcribe_all_models(audio_data) -> dict:
#     transcriptions = {}
#     for model_id in models_info.keys():
#         transcriptions[model_id] = transcribe_with_model(audio_data, model_id)
#     return transcriptions

def _load_audio(audio_data):
    """Normalize a microphone tuple or file path into a float waveform tensor.

    Args:
        audio_data: Either a ``(sampling_rate, samples)`` tuple (microphone
            input, ``samples`` being a numpy array shaped ``(samples,)`` for
            mono or ``(samples, channels)`` for multi-channel audio) or a path
            to an audio file readable by ``torchaudio.load``.

    Returns:
        ``(waveform, sampling_rate)`` where ``waveform`` is a float32 torch
        tensor laid out as ``(channels, samples)``, or ``None`` when
        ``audio_data`` is neither a tuple nor a string.
    """
    if isinstance(audio_data, tuple):
        # Microphone input.
        sampling_rate, samples = audio_data
        samples = np.asarray(samples)
        if np.issubdtype(samples.dtype, np.integer):
            # Integer PCM -> [-1, 1); for int16 this divides by 32768.
            scale = np.float32(np.iinfo(samples.dtype).max + 1)
            samples = samples.astype(np.float32) / scale
        else:
            # Already floating point: rescaling would effectively mute it.
            samples = samples.astype(np.float32)
        waveform = torch.from_numpy(samples)
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)
        else:
            # (samples, channels) -> (channels, samples), matching torchaudio.
            waveform = waveform.T.contiguous()
        return waveform, sampling_rate
    if isinstance(audio_data, str):
        # File upload: torchaudio.load already returns (channels, frames).
        waveform, sampling_rate = torchaudio.load(audio_data)
        return waveform, sampling_rate
    return None


def transcribe(audio_data, model_id) -> tuple:
    """Transcribe a short audio clip with the model registered under ``model_id``.

    Args:
        audio_data: A ``(sampling_rate, samples)`` microphone tuple or a path
            to an audio file.
        model_id: Key into ``models_info``.

    Returns:
        ``(transcription_arabic, transcription_latin)`` on success. On invalid
        input or audio longer than 10 seconds, returns ``(error_message, None)``.
    """
    loaded = _load_audio(audio_data)
    if loaded is None:
        return "<<ERROR: Invalid Audio Input Instance: {}>>".format(type(audio_data)), None
    waveform, sampling_rate = loaded

    # Check audio duration; samples are on the last axis for every input kind.
    duration = waveform.shape[-1] / sampling_rate
    if duration > 10:
        return f"<<ERROR: Audio duration ({duration:.2f}s) exceeds 10 seconds. Please upload a shorter audio clip for faster processing.>>", None

    model_info = models_info[model_id]
    model = model_info["model"]
    processor = model_info["processor"]
    target_sr = processor.feature_extractor.sampling_rate
    ctc_model = model_info["ctc_model"]

    # Downmix to mono so multi-channel audio is not treated as a batch of
    # separate clips (of which only the first would be transcribed).
    audio_input = waveform.mean(dim=0)

    # Resample if needed
    if sampling_rate != target_sr:
        resampler = torchaudio.transforms.Resample(sampling_rate, target_sr)
        audio_input = resampler(audio_input)
        sampling_rate = target_sr

    # Preprocess the audio input
    inputs = processor(audio_input.numpy(), sampling_rate=sampling_rate, return_tensors="pt")

    # Move model to GPU if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    inputs = {key: val.to(device) for key, val in inputs.items()}

    # Generate transcription
    with torch.no_grad():
        if ctc_model:
            # CTC: greedy decoding over the per-frame logits.
            logits = model(**inputs).logits
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = processor.batch_decode(predicted_ids)[0]
        else:
            # NOTE(review): no language/task is passed to generate(); confirm
            # that the processor's language setting is honored for Whisper.
            generated_ids = model.generate(inputs["input_features"], max_length=225)
            transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    if model_info["arabic_script"]:
        transcription_arabic = transcription
        transcription_latin = util.ug_arab_to_latn(transcription)
    else:  # Latin script output
        transcription_arabic = util.ug_latn_to_arab(transcription)
        transcription_latin = transcription
    print(model_id, transcription_arabic, transcription_latin)
    return transcription_arabic, transcription_latin