Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,394 Bytes
6fb3e63 8e370dd 6fb3e63 d2db553 6fb3e63 c0c2770 6fb3e63 d2db553 6fb3e63 c0c2770 6fb3e63 d3ed528 6fb3e63 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import spaces
import torch
import gradio as gr
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import tempfile
import os
# ---------------------------------------------------------------------------
# Model configuration (note: this model contains synthetic data)
# ---------------------------------------------------------------------------
MODEL_ID = "alakxender/whisper-small-dv-full"
BATCH_SIZE = 8
FILE_LIMIT_MB = 1000
CHUNK_LENGTH_S = 10
STRIDE_LENGTH_S = [3, 2]  # [left, right] stride in seconds for chunked inference

# ---------------------------------------------------------------------------
# Device / precision: CUDA device 0 in float16 when a GPU exists, else CPU float32
# ---------------------------------------------------------------------------
if torch.cuda.is_available():
    device = 0
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32

# ---------------------------------------------------------------------------
# Model, loaded with low_cpu_mem_usage and from safetensors weights
# ---------------------------------------------------------------------------
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_ID,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
)
model.to(device)

# Processor bundling the tokenizer and feature extractor used by the pipeline.
processor = AutoProcessor.from_pretrained(MODEL_ID)

# ---------------------------------------------------------------------------
# Single ASR pipeline built from the objects above
# ---------------------------------------------------------------------------
_pipeline_options = dict(
    chunk_length_s=CHUNK_LENGTH_S,
    stride_length_s=STRIDE_LENGTH_S,
    batch_size=BATCH_SIZE,
    torch_dtype=torch_dtype,
    device=device,
)
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    **_pipeline_options,
)
# Generation arguments passed to the ASR pipeline's `generate_kwargs`.
def get_generate_kwargs(is_short_audio=False, max_target_positions=None):
    """
    Return the ``generate_kwargs`` dict for the Whisper ASR pipeline.

    Two presets exist. The short-audio preset uses a more lenient
    compression-ratio threshold, an explicit no-speech threshold, a stronger
    repetition penalty and segment timestamps. The default preset is used
    for everything else.

    Args:
        is_short_audio: Select the short-audio preset when True.
        max_target_positions: Decoder context length of the model. When None
            (the default), it is read from the module-level
            ``model.config.max_target_positions``.

    Returns:
        A freshly built dict of keyword arguments for ``generate``; callers
        may mutate it without affecting later calls.
    """
    if max_target_positions is None:
        max_target_positions = model.config.max_target_positions

    common_kwargs = {
        # Leave headroom below the decoder context length for the prompt
        # tokens; presumably the 4-token Whisper start/lang/task/notimestamps
        # prefix -- TODO confirm for this checkpoint.
        "max_new_tokens": max_target_positions - 4,
        "num_beams": 4,
        "condition_on_prev_tokens": False,
    }

    if is_short_audio:
        return {
            **common_kwargs,
            # Higher than the default preset's 1.35, i.e. MORE tolerant of
            # repetitive (highly compressible) output before it is flagged.
            "compression_ratio_threshold": 1.5,
            # Lower than the 0.6 commonly used with Whisper, so segments are
            # flagged as silence more readily.
            # NOTE(review): transformers documents no_speech_threshold as
            # acting together with `logprob_threshold`, which is not set
            # here -- verify this has any effect.
            "no_speech_threshold": 0.4,
            # Stronger penalty on repeated tokens than the default preset.
            "repetition_penalty": 1.5,
            # Ask for segment timestamps.
            "return_timestamps": True,
        }

    return {
        **common_kwargs,
        # Stricter (lower) than the short-audio preset, so repetitive output
        # is flagged sooner.
        "compression_ratio_threshold": 1.35,
        # Light penalty on repeated tokens.
        "repetition_penalty": 1.2,
    }
# IMPORTANT: Fix for forced_decoder_ids error.
# Clear any legacy `forced_decoder_ids` carried by the checkpoint's generation
# config and model config, presumably so they do not clash with the
# generate_kwargs used at inference time (original note: "forced_decoder_ids
# error" -- exact failure not shown here). The pipeline above was constructed
# with this same `model` object (model=model), so mutating it now still
# affects the pipeline.
if getattr(model.generation_config, "forced_decoder_ids", None) is not None:
    # Only log/mutate when there is actually something to remove.
    print("Removing forced_decoder_ids from generation config")
    model.generation_config.forced_decoder_ids = None

# Also check if it's in the model config.
if hasattr(model.config, "forced_decoder_ids"):
    print("Removing forced_decoder_ids from model config")
    try:
        delattr(model.config, "forced_decoder_ids")
    except AttributeError:
        # hasattr() can be True for a value not stored on the instance
        # (e.g. inherited), in which case delattr fails; neutralise it instead
        # of crashing at import time.
        model.config.forced_decoder_ids = None
@spaces.GPU
def transcribe(audio_input):
    """
    Transcribe one submitted audio clip with the module-level ASR pipeline.

    Args:
        audio_input: Value delivered by the ``gr.Audio`` input, which is
            configured with ``type="filepath"`` (so presumably a file path
            string), or None when nothing was submitted.

    Returns:
        The transcription, taken from the pipeline result's ``"text"`` key.

    Raises:
        gr.Error: If no audio was submitted, or if transcription fails. In the
            latter case the original exception is chained as ``__cause__``.
    """
    if audio_input is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
    try:
        result = pipe(
            audio_input,
            generate_kwargs=get_generate_kwargs(),
        )
        return result["text"]
    except Exception as e:
        # Log the message plus the full traceback server-side; the UI only
        # gets a short, user-facing error.
        import traceback

        print(f"Detailed Error: {e}")
        traceback.print_exc()
        # Chain the cause so the root exception is preserved for debugging.
        raise gr.Error(f"Transcription failed: {e}") from e
# Styling for the transcript box: larger Thaana-capable fonts, taller line
# height and right-to-left text direction.
custom_css = """
.thaana-textbox textarea {
font-size: 18px !important;
font-family: 'MV_Faseyha', 'Faruma', 'A_Faruma', 'Noto Sans Thaana', 'MV Boli' !important;
line-height: 1.8 !important;
direction: rtl !important;
}
"""

# Outer Blocks container that hosts the tabbed UI and owns the custom CSS.
demo = gr.Blocks(css=custom_css)

# Input: upload or record audio, handed to `transcribe` as a file path.
audio_component = gr.Audio(
    sources=["upload", "microphone"],
    type="filepath",
    label="Audio file",
)
# Output: RTL textbox picking up the `.thaana-textbox` CSS rule above.
transcript_component = gr.Textbox(
    label="",
    lines=2,
    elem_classes=["thaana-textbox"],
    rtl=True,
)

intro_text = "Upload an audio file or record using your microphone to transcribe."
example_rows = [["sample.mp3"]]

file_transcribe = gr.Interface(
    fn=transcribe,
    inputs=[audio_component],
    outputs=transcript_component,
    title="Transcribe Dhivehi Audio",
    description=intro_text,
    flagging_mode="never",
    examples=example_rows,
    api_name=False,
    cache_examples=False,
)

# Mount the single interface as one tab inside the outer container.
with demo:
    gr.TabbedInterface([file_transcribe], ["Audio file"])

# Enable request queuing and start the server.
demo.queue().launch()
|