Omkar008 commited on
Commit
b666b7d
·
verified ·
1 Parent(s): b3f2b54

Update services/whisper_service.py

Browse files
Files changed (1) hide show
  1. services/whisper_service.py +8 -41
services/whisper_service.py CHANGED
@@ -1,46 +1,13 @@
1
  import whisper
2
- import tempfile
3
- import os
4
  import torch
5
- from config import settings
6
 
 
 
7
 
8
- class WhisperService:
9
- def __init__(self):
10
- if settings.TORCH_DEVICE == "cuda" and not torch.cuda.is_available():
11
- print("WARNING: CUDA requested but not available. Falling back to CPU.")
12
- self.device = "cpu"
13
- else:
14
- self.device = settings.TORCH_DEVICE
15
 
16
- self.model = whisper.load_model(settings.WHISPER_MODEL)
17
- if settings.FORCE_FP32 or self.device == "cpu":
18
- self.model = self.model.float()
19
-
20
- async def transcribe(self, audio_file: bytes, output_language: str = None) -> dict:
21
- try:
22
- # Create a temporary file to store the uploaded audio
23
- with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as temp_audio:
24
- temp_audio.write(audio_file)
25
- temp_audio_path = temp_audio.name
26
-
27
- try:
28
- # Transcribe the audio
29
- transcription_options = {"fp16": not settings.FORCE_FP32 and self.device == "cuda"}
30
- if output_language:
31
- transcription_options["language"] = output_language
32
-
33
- result = self.model.transcribe(temp_audio_path, **transcription_options)
34
-
35
- return {
36
- "text": result["text"],
37
- "language": result.get("language"),
38
- "segments": result.get("segments")
39
- }
40
- finally:
41
- # Clean up the temporary file
42
- if os.path.exists(temp_audio_path):
43
- os.remove(temp_audio_path)
44
-
45
- except Exception as e:
46
- raise Exception(f"Transcription failed: {str(e)}")
 
1
import whisper
import torch

# Prefer an NVIDIA GPU when one is present; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Load the Whisper "base" model once at import time so every call to
# transcribe_audio() reuses the same in-memory model.
model = whisper.load_model("base", device=DEVICE)
def transcribe_audio(file_path: str, output_language: str = None) -> str:
    """Transcribe the given audio file and return the recognized text.

    Args:
        file_path: Path to the audio file to transcribe.
        output_language: Optional language code (e.g. "en") to force the
            transcription language. When None (the default), Whisper
            auto-detects the language — same behavior as before this
            parameter existed.

    Returns:
        The transcribed text.
    """
    # Restores the language-selection capability the previous
    # WhisperService.transcribe() implementation offered, without
    # changing the existing positional call signature.
    options = {}
    if output_language:
        options["language"] = output_language
    result = model.transcribe(file_path, **options)
    return result["text"]