Update app.py
app.py
CHANGED
@@ -1,31 +1,47 @@
+!pip install git+https://github.com/speechbrain/speechbrain.git@develop
+
import re
import gradio as gr
from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline
import torch
import numpy as np
+import torchaudio
+from speechbrain.inference.classifiers import EncoderClassifier

# Load Whisper model for transcription
whisper_model_name = "openai/whisper-large"
processor = WhisperProcessor.from_pretrained(whisper_model_name)
model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)

-# Initialize the language detection model
+# Initialize the language detection model (using zero-shot classification for language detection)
lang_detect_model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

+# Load the SpeechBrain language ID model
+language_id = EncoderClassifier.from_hparams(source="speechbrain/lang-id-voxlingua107-ecapa", savedir="tmp")
+
# Function to transcribe audio to text using Whisper model
def transcribe_audio(audio_file):
+    """
+    Function to transcribe audio to text using Whisper model.
+    Handles both file input and live audio input.
+    """
    # Check if audio_file is a list (Gradio returns a list when multiple clips are recorded)
    if isinstance(audio_file, list):
-
+        # Ensure all elements in the list are of the same length before concatenating
+        audio = np.concatenate([np.array(a) for a in audio_file if a is not None])
    else:
        audio = np.array(audio_file)  # Ensure it's a 1D array
-
-    #
-    if
-        audio = audio.
+
+    # If audio is stereo (2D array with shape (2, N)), mix the channels by averaging them
+    if audio.ndim > 1:
+        audio = audio.mean(axis=0)  # Mix the stereo channels into a mono signal
+
+    # Ensure the audio is a 1D array (e.g., [N])
+    if audio.ndim != 1:
+        raise ValueError("The audio input must be a 1D array (mono).")

    # Prepare input features for Whisper (sampling rate should be 16000 for Whisper)
-    input_features = processor(audio, return_tensors="pt", sampling_rate=
+    input_features = processor(audio, return_tensors="pt", sampling_rate=48000)

    # Generate transcription
    generated_ids = model.generate(input_features["input_features"])
@@ -33,35 +49,58 @@ def transcribe_audio(audio_file):

    return transcription

-# Function to detect
-def
-
-
+# Function to detect language using SpeechBrain's language ID model
+def detect_language_speechbrain(audio_file):
+    # Load the audio using torchaudio
+    signal, sample_rate = torchaudio.load(audio_file)
+
+    # Use SpeechBrain to classify the language of the audio
+    prediction = language_id.classify_batch(signal)
+
+    # Extract the language ISO code and its confidence
+    language = prediction[3][0]  # Extracted language
+    confidence = prediction[1].exp()  # Linear scale of confidence
+    return language, confidence.item()

# Cleanup function to remove filler words and clean the transcription
def cleanup_text(text):
-
+    """
+    Function to clean the transcription text by removing filler words, unnecessary spaces,
+    non-alphabetic characters, and ensuring proper capitalization.
+    """
+    # Step 1: Remove filler words like "uh", "um", etc.
    text = re.sub(r'\b(uh|um|like|you know|so|actually|basically)\b', '', text, flags=re.IGNORECASE)
-
-
-
+
+    # Step 2: Remove unwanted characters (e.g., non-alphabetical characters except punctuation)
+    text = re.sub(r'[^a-zA-Z0-9\s,.\'?!]', '', text)
+
+    # Step 3: Remove extra spaces and ensure proper spacing around punctuation
+    text = re.sub(r'\s+', ' ', text)  # Replace multiple spaces with a single space
+    text = re.sub(r'\s([?.!.,])', r'\1', text)  # Remove space before punctuation
+
+    # Step 4: Normalize the whitespace (remove leading/trailing spaces)
    text = text.strip()
-
+
+    # Step 5: Capitalize the first letter of the transcription
    text = text.capitalize()
+
    return text

-# Main function to process the audio and detect language
+# Main function to process the audio, transcribe it, and detect the language
def process_audio(audio_file):
    try:
        transcription = transcribe_audio(audio_file)  # Transcribe audio to text
+
        if not transcription.strip():  # If transcription is empty or just whitespace
            raise ValueError("Transcription is empty.")
-
-
+
+        # Detect language using SpeechBrain's model
+        language, confidence = detect_language_speechbrain(audio_file)
+
        cleaned_text = cleanup_text(transcription)  # Clean up the transcription
-
-        return cleaned_text,
-
+
+        return cleaned_text, language, confidence  # Return cleaned transcription, language, and confidence score
+
    except Exception as e:
        # If any error occurs, return the error message
        return f"Error: {str(e)}", "", ""
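
A few review notes on the new version. First, !pip install is notebook (IPython) syntax, not Python: as the first line of app.py it raises a SyntaxError before anything loads, which by itself is enough to crash the Space at startup. On Spaces, the dependency belongs in a requirements.txt next to app.py; a sketch, with the packages beyond speechbrain inferred from the imports:

# requirements.txt (installed at build time by Spaces)
git+https://github.com/speechbrain/speechbrain.git@develop
transformers
torch
torchaudio
gradio
numpy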
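
Second, the new sampling_rate=48000 contradicts the comment directly above it: Whisper's feature extractor works at 16 kHz, and recent transformers releases reject a mismatched sampling_rate argument. A minimal sketch of resampling before feature extraction, assuming a 48 kHz mono clip (the rate the commit implies; in practice take the real rate from the Gradio input):

import numpy as np
import torch
import torchaudio

def to_whisper_rate(audio: np.ndarray, source_rate: int = 48_000) -> np.ndarray:
    # Resample a mono 1-D signal to the 16 kHz that Whisper expects.
    if source_rate == 16_000:
        return audio
    resample = torchaudio.transforms.Resample(orig_freq=source_rate, new_freq=16_000)
    return resample(torch.from_numpy(audio).float()).numpy()

# Then, inside transcribe_audio:
# input_features = processor(to_whisper_rate(audio), return_tensors="pt", sampling_rate=16_000)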
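
Third, on the SpeechBrain side: classify_batch returns a (log-posteriors, best score, best index, text labels) tuple, which is what the prediction[1] / prediction[3] indexing relies on, and the VoxLingua107 model likewise expects 16 kHz mono input. Note also that detect_language_speechbrain treats audio_file as a path for torchaudio.load, while transcribe_audio treats the same argument as raw samples; process_audio hands one value to both, so at most one of the two can receive what it expects. A hypothetical stand-alone check, with "clip.wav" as a placeholder path:

import torchaudio
from speechbrain.inference.classifiers import EncoderClassifier

language_id = EncoderClassifier.from_hparams(
    source="speechbrain/lang-id-voxlingua107-ecapa", savedir="tmp"
)
signal, sample_rate = torchaudio.load("clip.wav")  # placeholder file
out_prob, score, index, text_lab = language_id.classify_batch(signal)
print(text_lab[0])         # predicted language label
print(score.exp().item())  # best-class score mapped back to a linear scale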
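
Fourth, cleanup_text deserves a sanity check: deleting filler words leaves their surrounding punctuation behind, the character class [?.!.,] lists the dot twice (harmless but redundant), and str.capitalize() lowercases everything after the first character, acronyms included. With the cleanup_text from the diff in scope:

print(cleanup_text("Um, so this is, uh, a TEST!"))
# -> ", this is,, a test!"   (stray commas survive; "TEST" is lowercased)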
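
Finally, gradio is imported but never used inside the changed region, so the UI wiring lives outside these hunks. Below is a sketch of how process_audio could plausibly be hooked up; every component choice is an assumption, not taken from the commit. type="filepath" would suit detect_language_speechbrain, though transcribe_audio as written wants raw samples instead:

import gradio as gr

# Assumes app.py's process_audio is in scope; components are assumptions.
demo = gr.Interface(
    fn=process_audio,
    inputs=gr.Audio(type="filepath"),
    outputs=[
        gr.Textbox(label="Cleaned transcription"),
        gr.Textbox(label="Detected language"),
        gr.Number(label="Confidence"),
    ],
)
demo.launch()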