import spaces
import torch
import gradio as gr
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

# Model configuration (this model includes synthetic training data)
MODEL_ID = "alakxender/whisper-small-dv-full"
BATCH_SIZE = 8
FILE_LIMIT_MB = 1000  # intended upload size cap; not currently enforced below
CHUNK_LENGTH_S = 10       # split long audio into 10-second windows
STRIDE_LENGTH_S = [3, 2]  # seconds of left/right overlap between adjacent chunks

# Device and dtype setup (0 = first CUDA device)
device = 0 if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Initialize model with memory optimizations
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_ID, 
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True
)
model.to(device)

# Initialize processor
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Single pipeline initialization with all components; chunking plus striding
# enables long-form transcription
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    chunk_length_s=CHUNK_LENGTH_S,
    stride_length_s=STRIDE_LENGTH_S,
    batch_size=BATCH_SIZE,
    torch_dtype=torch_dtype,
    device=device,
)
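
# Optional smoke test (illustrative, commented out; assumes sample.mp3, the
# example clip referenced at the bottom of this script, sits next to it):
#   print(pipe("sample.mp3", generate_kwargs={"max_new_tokens": 128})["text"])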

# Generation arguments, tuned separately for short and long audio
def get_generate_kwargs(is_short_audio=False):
    """
    Get appropriate generation parameters based on audio length.
    Short audio transcription benefits from different parameters.
    """
    common_kwargs = {
        "max_new_tokens": model.config.max_target_positions-4,
        "num_beams": 4,
        "condition_on_prev_tokens": False,
    }
    
    if is_short_audio:
        # Parameters optimized for short audio:
        return {
            **common_kwargs,
            "compression_ratio_threshold": 1.5,     # Balanced setting to avoid repetition
            "no_speech_threshold": 0.4,             # Higher threshold to reduce hallucinations
            "repetition_penalty": 1.5,              # Add penalty for repeated tokens
            "return_timestamps": True,              # Get timestamps for better segmentation
        }
    else:
        # Parameters for longer audio:
        return {
            **common_kwargs,
            "compression_ratio_threshold": 1.35,    # Standard compression ratio for longer audio
            "repetition_penalty": 1.2,              # Light penalty for repeated tokens
        }
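
# Illustrative helper (an addition for clarity, not wired into transcribe()
# below): one way to decide when to use the short-audio settings, assuming
# soundfile is available (it ships as a transformers audio dependency).
def is_short_audio(audio_path, threshold_s=10.0):
    """Best-effort check that a file is shorter than threshold_s seconds."""
    import soundfile as sf
    try:
        return sf.info(audio_path).duration < threshold_s
    except RuntimeError:
        # Unreadable or unsupported container: fall back to long-audio settings
        return False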

# IMPORTANT: Fix for forced_decoder_ids error
# Remove forced_decoder_ids from the model's generation config
if hasattr(model.generation_config, 'forced_decoder_ids'):
    print("Removing forced_decoder_ids from generation config")
    model.generation_config.forced_decoder_ids = None

# Also check if it's in the model config
if hasattr(model.config, 'forced_decoder_ids'):
    print("Removing forced_decoder_ids from model config")
    delattr(model.config, 'forced_decoder_ids')
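
# Note (illustrative alternative): recent transformers versions let you set the
# language/task at call time instead of via forced_decoder_ids, e.g.:
#   pipe(audio, generate_kwargs={"language": "dv", "task": "transcribe"})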
    
@spaces.GPU
def transcribe(audio_input):
    if audio_input is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
    
    try:
        # Defaults to the long-audio settings; see get_generate_kwargs()
        result = pipe(
            audio_input,
            generate_kwargs=get_generate_kwargs()
        )
        return result["text"]
    except Exception as e:
        # More detailed error logging might be helpful here if issues persist
        print(f"Detailed Error: {e}") 
        raise gr.Error(f"Transcription failed: {str(e)}")

# Custom CSS with modern Gradio styling
custom_css = """
.thaana-textbox textarea {
    font-size: 18px !important;
    font-family: 'MV_Faseyha', 'Faruma', 'A_Faruma', 'Noto Sans Thaana', 'MV Boli' !important;
    line-height: 1.8 !important;
    direction: rtl !important;
}
"""

demo = gr.Blocks(css=custom_css)

file_transcribe = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio file"),
    ],
    outputs=gr.Textbox(
        label="",
        lines=2,
        elem_classes=["thaana-textbox"],
        rtl=True
    ),
    title="Transcribe Dhivehi Audio",
    description=(
        "Upload an audio file or record using your microphone to transcribe."
    ),
    flagging_mode="never",
    examples=[
        ["sample.mp3"]  
    ],
    api_name=False,
    cache_examples=False
)

with demo:
    gr.TabbedInterface([file_transcribe], ["Audio file"])

demo.queue().launch()
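
# To run locally (assuming this script is saved as app.py and the Space's
# requirements are installed): python app.py, then open the URL Gradio prints.
# On HF Spaces, @spaces.GPU requests a GPU for each call; outside Spaces the
# decorator is effectively a no-op.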