File size: 3,671 Bytes
aa547ad
86a1f13
 
550cf61
 
b86a6f7
550cf61
aa547ad
9a9ac31
550cf61
 
aa547ad
 
550cf61
b86a6f7
aa547ad
550cf61
 
 
ee53092
550cf61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa547ad
550cf61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee53092
550cf61
 
 
 
b86a6f7
550cf61
86a1f13
b86a6f7
 
550cf61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b86a6f7
aa547ad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import spaces
import torch
import gradio as gr
import whisperx
from transformers.pipelines.audio_utils import ffmpeg_read
import tempfile
import gc
import os

# Constants
# Run on GPU when available; WhisperX falls back to CPU otherwise.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4  # reduce if low on GPU mem
COMPUTE_TYPE = "float32"  # change to "int8" if low on GPU mem
# NOTE(review): declared but never enforced anywhere in this file — either
# wire it into an upload-size check or remove it. TODO confirm intent.
FILE_LIMIT_MB = 1000

@spaces.GPU
def transcribe_audio(inputs, task):
    """Transcribe (or translate) an audio file with WhisperX and label speakers.

    Pipeline: Whisper transcription -> word-level alignment -> speaker
    diarization -> speaker label assignment.

    Args:
        inputs: Path to an audio file. Gradio's ``type="filepath"`` supplies a
            path for both microphone and upload sources.
        task: ``"transcribe"`` or ``"translate"`` (translate targets English).

    Returns:
        str: One ``"SPEAKER: text"`` line per segment.

    Raises:
        gr.Error: If no audio was provided, HF_TOKEN is missing, or any
            processing step fails.
    """
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    try:
        # type="filepath" yields a path for mic and upload alike, so a single
        # load path suffices (the original's isinstance branches were identical).
        audio = whisperx.load_audio(inputs)

        # 1. Transcribe with the base Whisper model.
        # BUG FIX: `task` was previously ignored, so selecting "translate" in
        # the UI silently behaved exactly like "transcribe".
        model = whisperx.load_model("large-v3", device=DEVICE, compute_type=COMPUTE_TYPE, task=task)
        result = model.transcribe(audio, batch_size=BATCH_SIZE)

        # Free GPU memory before loading the alignment model.
        del model
        gc.collect()
        torch.cuda.empty_cache()

        # 2. Align Whisper output to get word-level timestamps.
        # NOTE(review): alignment uses the detected source language even for
        # task="translate" (English output) — verify alignment quality there.
        model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=DEVICE)
        result = whisperx.align(result["segments"], model_a, metadata, audio, DEVICE, return_char_alignments=False)

        # Free GPU memory before loading the diarization model.
        del model_a
        gc.collect()
        torch.cuda.empty_cache()

        # 3. Diarize. Fail with a clear message if the token is absent —
        # os.environ["HF_TOKEN"] would raise an opaque KeyError instead.
        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            raise gr.Error("HF_TOKEN environment variable is not set; it is required for speaker diarization.")
        diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=DEVICE)
        diarize_segments = diarize_model(audio)

        # 4. Assign speaker labels to the aligned segments.
        result = whisperx.assign_word_speakers(diarize_segments, result)

        # Format output: one "SPEAKER: text" line per segment.
        # str.join avoids the quadratic `+=` string build-up.
        lines = []
        for segment in result['segments']:
            speaker = segment.get('speaker', 'Unknown Speaker')
            lines.append(f"{speaker}: {segment['text']}\n")
        return "".join(lines)

    except gr.Error:
        # User-facing errors pass through untouched (no double-wrapping).
        raise
    except Exception as e:
        # Chain the cause so the original traceback survives in server logs.
        raise gr.Error(f"Error processing audio: {str(e)}") from e

    finally:
        # Always release GPU memory, success or failure.
        gc.collect()
        torch.cuda.empty_cache()

# Create Gradio interface
# NOTE: component creation order inside the nested `with` contexts determines
# the rendered layout — do not reorder these statements.
demo = gr.Blocks(theme=gr.themes.Ocean())

with demo:
    gr.Markdown("# WhisperX: Advanced Speech Recognition with Speaker Diarization")
    
    with gr.Row():
        # Left column: inputs and the submit button.
        with gr.Column():
            # type="filepath" means transcribe_audio receives a path string
            # for both microphone recordings and uploaded files.
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="Audio Input (Microphone or File Upload)"
            )
            task = gr.Radio(
                ["transcribe", "translate"],
                label="Task",
                value="transcribe"
            )
            submit_button = gr.Button("Process Audio")
        
        # Right column: transcription output.
        with gr.Column():
            output_text = gr.Textbox(
                label="Transcription with Speaker Diarization",
                lines=10,
                placeholder="Transcribed text will appear here..."
            )
    
    gr.Markdown("""
    ### Features:
    - High-accuracy transcription using WhisperX
    - Automatic speaker diarization
    - Support for both microphone recording and file upload
    - GPU-accelerated processing
    
    ### Note:
    Processing may take a few moments depending on the audio length and system resources.
    """)
    
    # Wire the button to the processing function.
    submit_button.click(
        fn=transcribe_audio,
        inputs=[audio_input, task],
        outputs=output_text
    )

# queue() serializes GPU requests; ssr_mode=False disables server-side
# rendering (required for some Hugging Face Spaces configurations).
demo.queue().launch(ssr_mode=False)