AnyaSchen committed
Commit eca4b03 · 1 Parent(s): 65ac0a4

fix language detection

Files changed (2)
  1. main.py +18 -16
  2. whisper_streaming_custom/backends.py +78 -1
main.py CHANGED
@@ -17,21 +17,19 @@ import tempfile
 
 from core import WhisperLiveKit
 from audio_processor import AudioProcessor
-from language_detector import LanguageDetector
 
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logging.getLogger().setLevel(logging.WARNING)
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
-kit = None
-language_detector = None
+audio_processor = None
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    global kit, language_detector
-    kit = WhisperLiveKit()
-    language_detector = LanguageDetector(model_name="large")
+    global audio_processor
+    kit = WhisperLiveKit(args=args)
+    audio_processor = AudioProcessor()
     yield
 
 app = FastAPI(lifespan=lifespan)
@@ -47,8 +45,6 @@ app.add_middleware(
 # Mount static files
 app.mount("/static", StaticFiles(directory="static"), name="static")
 
-
-
 @app.get("/")
 async def read_root():
     return FileResponse("static/index.html")
@@ -66,9 +62,16 @@ async def detect_language(file: UploadFile = File(...)):
         contents = await file.read()
         temp_file.write(contents)
 
-    # Use the language detector with the saved file
-    if language_detector:
-        detected_lang, confidence, probs = language_detector.detect_language_from_file(file_path)
+    # Use the audio processor for language detection
+    if audio_processor:
+        # Load audio using librosa
+        audio, sr = librosa.load(file_path, sr=16000)
+
+        # Convert to format expected by Whisper
+        audio = (audio * 32768).astype(np.int16)
+
+        # Detect language
+        detected_lang, confidence, probs = audio_processor.detect_language(audio)
 
     # Clean up - remove the temporary file
     os.remove(file_path)
@@ -80,7 +83,7 @@ async def detect_language(file: UploadFile = File(...)):
         })
     else:
         return JSONResponse(
-            {"error": "Language detector not initialized"},
+            {"error": "Audio processor not initialized"},
            status_code=500
        )
@@ -127,14 +130,15 @@ async def handle_websocket_results(websocket, results_generator):
 @app.websocket("/asr")
 async def websocket_endpoint(websocket: WebSocket):
     logger.info("New WebSocket connection request")
-    audio_processor = None
     websocket_task = None
 
     try:
         await websocket.accept()
         logger.info("WebSocket connection accepted")
 
-        audio_processor = AudioProcessor()
+        if not audio_processor:
+            raise RuntimeError("Audio processor not initialized")
+
         results_generator = await audio_processor.create_tasks()
         websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
 
@@ -155,8 +159,6 @@ async def websocket_endpoint(websocket: WebSocket):
         logger.error(f"Error in WebSocket endpoint: {e}")
         logger.error(f"Traceback: {traceback.format_exc()}")
     finally:
-        if audio_processor:
-            await audio_processor.cleanup()
         if websocket_task:
             websocket_task.cancel()
             try:
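For quick verification of the reworked endpoint, a client can post an audio file and inspect the JSON reply. A minimal sketch, assuming the route is mounted at /detect-language and that the response carries the detected language, confidence, and probabilities; the route decorator and exact response fields fall outside the hunks shown above:

import requests

# Endpoint path and response fields are assumptions; adjust to the actual route.
with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/detect-language",
        files={"file": ("sample.wav", f, "audio/wav")},
    )
print(resp.json())  # e.g. {"detected_language": "en", "confidence": 0.97, ...}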
whisper_streaming_custom/backends.py CHANGED
@@ -89,6 +89,42 @@ class WhisperTimestampedASR(ASRBase):
     def set_translate_task(self):
         self.transcribe_kargs["task"] = "translate"
 
+    def detect_language(self, audio):
+        """
+        Detect the language of the audio using Whisper's language detection.
+
+        Args:
+            audio (np.ndarray): Audio data as numpy array
+
+        Returns:
+            tuple: (detected_language, confidence, probabilities)
+                - detected_language (str): The detected language code
+                - confidence (float): Confidence score for the detected language
+                - probabilities (dict): Dictionary of language probabilities
+        """
+        import whisper
+        try:
+            # Ensure audio is in the correct format
+            if not isinstance(audio, np.ndarray):
+                audio = np.array(audio)
+
+            # Pad or trim audio to the correct length
+            audio = whisper.pad_or_trim(audio)
+
+            # Create mel spectrogram with correct dimensions
+            mel = whisper.log_mel_spectrogram(audio, n_mels=128).to(self.model.device)
+
+            # Detect language
+            _, probs = self.model.detect_language(mel)
+            detected_lang = max(probs, key=probs.get)
+            confidence = probs[detected_lang]
+
+            return detected_lang, confidence, probs
+
+        except Exception as e:
+            logger.error(f"Error in language detection: {e}")
+            raise
+
 
 class FasterWhisperASR(ASRBase):
     """Uses faster-whisper as the backend."""
@@ -147,6 +183,41 @@ class FasterWhisperASR(ASRBase):
     def set_translate_task(self):
         self.transcribe_kargs["task"] = "translate"
 
+    def detect_language(self, audio):
+        """
+        Detect the language of the audio using faster-whisper's language detection.
+
+        Args:
+            audio (np.ndarray): Audio data as numpy array
+
+        Returns:
+            tuple: (detected_language, confidence, probabilities)
+                - detected_language (str): The detected language code
+                - confidence (float): Confidence score for the detected language
+                - probabilities (dict): Dictionary of language probabilities
+        """
+        try:
+            # Ensure audio is in the correct format
+            if not isinstance(audio, np.ndarray):
+                audio = np.array(audio)
+
+            # Use faster-whisper's detect_language method
+            language, language_probability, all_language_probs = self.model.detect_language(
+                audio=audio,
+                vad_filter=False,  # Disable VAD for language detection
+                language_detection_segments=1,  # Use a single segment for detection
+                language_detection_threshold=0.5  # Default threshold
+            )
+
+            # Convert list of tuples to dictionary for consistent return format
+            probs = {lang: prob for lang, prob in all_language_probs}
+
+            return language, language_probability, probs
+
+        except Exception as e:
+            logger.error(f"Error in language detection: {e}")
+            raise
+
 
 class MLXWhisper(ASRBase):
     """
@@ -225,6 +296,9 @@ class MLXWhisper(ASRBase):
 
     def set_translate_task(self):
         self.transcribe_kargs["task"] = "translate"
+
+    def detect_language(self, audio):
+        raise NotImplementedError("MLX Whisper does not support language detection.")
 
 
 class OpenaiApiASR(ASRBase):
@@ -292,4 +366,7 @@ class OpenaiApiASR(ASRBase):
         self.use_vad_opt = True
 
     def set_translate_task(self):
-        self.task = "translate"
+        self.task = "translate"
+
+    def detect_language(self, audio):
+        raise NotImplementedError("The OpenAI API backend does not support language detection.")