ajchri5 committed on
Commit
89c1bb5
·
verified ·
1 Parent(s): e40d9da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -22
app.py CHANGED
@@ -1,31 +1,47 @@
 
 
1
  import re
2
  import gradio as gr
3
  from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline
4
  import torch
5
  import numpy as np
 
 
6
 
7
  # Load Whisper model for transcription
8
  whisper_model_name = "openai/whisper-large"
9
  processor = WhisperProcessor.from_pretrained(whisper_model_name)
10
  model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)
11
 
12
- # Initialize the language detection model
13
  lang_detect_model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
14
 
 
 
 
15
  # Function to transcribe audio to text using Whisper model
16
  def transcribe_audio(audio_file):
 
 
 
 
17
  # Check if audio_file is a list (Gradio returns a list when multiple clips are recorded)
18
  if isinstance(audio_file, list):
19
- audio = np.concatenate(audio_file) # Concatenate the list of arrays into a single 1D array
 
20
  else:
21
  audio = np.array(audio_file) # Ensure it's a 1D array
22
-
23
- # Ensure the shape is 1D (if the shape is (2, N), we flatten it)
24
- if len(audio.shape) > 1:
25
- audio = audio.flatten()
 
 
 
 
26
 
27
  # Prepare input features for Whisper (sampling rate should be 16000 for Whisper)
28
- input_features = processor(audio, return_tensors="pt", sampling_rate=16000)
29
 
30
  # Generate transcription
31
  generated_ids = model.generate(input_features["input_features"])
@@ -33,35 +49,58 @@ def transcribe_audio(audio_file):
33
 
34
  return transcription
35
 
36
- # Function to detect the language of the transcription using zero-shot classification
37
- def detect_language(text):
38
- result = lang_detect_model(text, candidate_labels=["en", "fr", "es", "de", "it", "pt", "zh", "ja", "ar", "hi"])
39
- return result['labels'][0], result['scores'][0] # Return the detected language and score
 
 
 
 
 
 
 
 
40
 
41
  # Cleanup function to remove filler words and clean the transcription
42
  def cleanup_text(text):
43
- # Remove filler words like "uh", "um", etc.
 
 
 
 
44
  text = re.sub(r'\b(uh|um|like|you know|so|actually|basically)\b', '', text, flags=re.IGNORECASE)
45
- # Remove extra spaces
46
- text = re.sub(r'\s+', ' ', text)
47
- # Strip leading and trailing spaces
 
 
 
 
 
 
48
  text = text.strip()
49
- # Capitalize the first letter
 
50
  text = text.capitalize()
 
51
  return text
52
 
53
- # Main function to process the audio and detect language
54
  def process_audio(audio_file):
55
  try:
56
  transcription = transcribe_audio(audio_file) # Transcribe audio to text
 
57
  if not transcription.strip(): # If transcription is empty or just whitespace
58
  raise ValueError("Transcription is empty.")
59
-
60
- lang, score = detect_language(transcription) # Detect the language of the transcription
 
 
61
  cleaned_text = cleanup_text(transcription) # Clean up the transcription
62
-
63
- return cleaned_text, lang, score # Return cleaned transcription, language, and confidence score
64
-
65
  except Exception as e:
66
  # If any error occurs, return the error message
67
  return f"Error: {str(e)}", "", ""
 
1
# Dependency note: SpeechBrain must be installed before this module is
# imported, e.g. by listing
#   git+https://github.com/speechbrain/speechbrain.git@develop
# in requirements.txt. The former `!pip install ...` line was IPython/Jupyter
# shell magic, which is a SyntaxError in a plain .py file.

import re

import gradio as gr
import numpy as np
import torch
import torchaudio
from speechbrain.inference.classifiers import EncoderClassifier
from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline

# Load the Whisper processor and model used for transcription.
whisper_model_name = "openai/whisper-large"
processor = WhisperProcessor.from_pretrained(whisper_model_name)
model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)

# Zero-shot classification pipeline.
# NOTE(review): language detection in this file now goes through the
# SpeechBrain model below; confirm whether anything still uses this pipeline
# before removing it.
lang_detect_model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

# Load the SpeechBrain spoken-language identification model; its files are
# stored under the local "tmp" directory.
language_id = EncoderClassifier.from_hparams(source="speechbrain/lang-id-voxlingua107-ecapa", savedir="tmp")
21
+
22
  # Function to transcribe audio to text using Whisper model
23
  def transcribe_audio(audio_file):
24
+ """
25
+ Function to transcribe audio to text using Whisper model.
26
+ Handles both file input and live audio input.
27
+ """
28
  # Check if audio_file is a list (Gradio returns a list when multiple clips are recorded)
29
  if isinstance(audio_file, list):
30
+ # Convert each recorded clip to an array, skip missing (None) clips, and join them end to end
31
+ audio = np.concatenate([np.array(a) for a in audio_file if a is not None])
32
  else:
33
  audio = np.array(audio_file) # Ensure it's a 1D array
34
+
35
+ # If audio is stereo (2D array with shape (2, N)), mix the channels by averaging them
36
+ if audio.ndim > 1:
37
+ audio = audio.mean(axis=0) # Mix the stereo channels into a mono signal
38
+
39
+ # Ensure the audio is a 1D array (e.g., [N])
40
+ if audio.ndim != 1:
41
+ raise ValueError("The audio input must be a 1D array (mono).")
42
 
43
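  # Prepare input features for Whisper. Whisper expects 16 kHz audio, so the sampling_rate passed below should be 16000 (resample the audio first if it was recorded at another rate); it currently passes 48000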
44
+ input_features = processor(audio, return_tensors="pt", sampling_rate=48000)
45
 
46
  # Generate transcription
47
  generated_ids = model.generate(input_features["input_features"])
 
49
 
50
  return transcription
51
 
52
+ # Function to detect language using SpeechBrain's language ID model
53
def detect_language_speechbrain(audio_file):
    """Identify the spoken language of an audio file with SpeechBrain.

    Args:
        audio_file: Path or file-like object readable by ``torchaudio.load``.

    Returns:
        tuple: ``(language, confidence)`` where ``language`` is the label the
        classifier assigns to the top-scoring language and ``confidence`` is
        that score converted from log scale to linear scale, as a float.
    """
    # Sample rate the language-ID model expects for tensors passed to
    # classify_batch (see the model card).
    target_sample_rate = 16000

    # torchaudio returns the waveform as (channels, samples) plus the file's
    # own sample rate.
    signal, sample_rate = torchaudio.load(audio_file)

    # classify_batch treats the first dimension as the batch, so a stereo
    # file would be scored as two utterances and `.item()` below would fail.
    # Average the channels into a single mono row instead.
    if signal.dim() > 1 and signal.size(0) > 1:
        signal = signal.mean(dim=0, keepdim=True)

    # Unlike classify_file, classify_batch does no audio normalization, so
    # resample here when the file was not recorded at the expected rate.
    if sample_rate != target_sample_rate:
        signal = torchaudio.functional.resample(signal, sample_rate, target_sample_rate)

    # Use SpeechBrain to classify the language of the audio.
    prediction = language_id.classify_batch(signal)

    # prediction[3] holds the text labels and prediction[1] the log-scale
    # score of the best class; take the single utterance in the batch.
    language = prediction[3][0]
    confidence = prediction[1].exp()  # log scale -> linear scale
    return language, confidence.item()
64
 
65
  # Cleanup function to remove filler words and clean the transcription
66
  def cleanup_text(text):
67
+ """
68
+ Function to clean the transcription text by removing filler words, unnecessary spaces,
69
+ non-alphabetic characters, and ensuring proper capitalization.
70
+ """
71
+ # Step 1: Remove filler words like "uh", "um", etc.
72
  text = re.sub(r'\b(uh|um|like|you know|so|actually|basically)\b', '', text, flags=re.IGNORECASE)
73
+
74
+ # Step 2: Remove unwanted characters (e.g., non-alphabetical characters except punctuation)
75
+ text = re.sub(r'[^a-zA-Z0-9\s,.\'?!]', '', text)
76
+
77
+ # Step 3: Remove extra spaces and ensure proper spacing around punctuation
78
+ text = re.sub(r'\s+', ' ', text) # Replace multiple spaces with a single space
79
+ text = re.sub(r'\s([?.!.,])', r'\1', text) # Remove space before punctuation
80
+
81
+ # Step 4: Normalize the whitespace (remove leading/trailing spaces)
82
  text = text.strip()
83
+
84
+ # Step 5: Capitalize the first letter of the transcription
85
  text = text.capitalize()
86
+
87
  return text
88
 
89
+ # Main function to process the audio, transcribe it, and detect the language
90
  def process_audio(audio_file):
91
  try:
92
  transcription = transcribe_audio(audio_file) # Transcribe audio to text
93
+
94
  if not transcription.strip(): # If transcription is empty or just whitespace
95
  raise ValueError("Transcription is empty.")
96
+
97
+ # Detect language using SpeechBrain's model
98
+ language, confidence = detect_language_speechbrain(audio_file)
99
+
100
  cleaned_text = cleanup_text(transcription) # Clean up the transcription
101
+
102
+ return cleaned_text, language, confidence # Return cleaned transcription, language, and confidence score
103
+
104
  except Exception as e:
105
  # If any error occurs, return the error message
106
  return f"Error: {str(e)}", "", ""