Tony4 commited on
Commit
902e3e0
·
verified ·
1 Parent(s): d03b85c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -24
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
3
  import torch
@@ -5,52 +6,53 @@ import os
5
  import soundfile as sf
6
  from scipy.signal import resample
7
 
8
- # Define the model ID
9
  MODEL_ID = "WMRNORDIC/whisper-swedish-telephonic"
10
-
11
- # Load the Hugging Face token from the environment
12
  HF_API_TOKEN = os.getenv("HF_API_TOKEN")
 
13
  if not HF_API_TOKEN:
14
- raise ValueError("HF_API_TOKEN not found in environment variables. Please set it in the Space settings.")
15
 
16
  # Sample file path
17
  SAMPLE_FILE_PATH = "trimmed_resampled_audio.wav" # Update this path if necessary
18
 
19
- # Function to initialize the model and processor lazily
 
20
  def initialize_model():
21
- print("Loading model and processor...")
 
22
  processor = WhisperProcessor.from_pretrained(MODEL_ID, token=HF_API_TOKEN)
23
  model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID, token=HF_API_TOKEN)
24
- model = model.to("cuda" if torch.cuda.is_available() else "cpu") # Use GPU if available
25
- print("Model loaded successfully.")
 
26
  return processor, model
27
 
28
- # Function to resample audio to 16kHz
29
- def resample_audio(audio_data, original_rate, target_rate=16000):
30
- if original_rate != target_rate:
31
- num_samples = int(len(audio_data) * target_rate / original_rate)
32
- return resample(audio_data, num_samples)
33
- return audio_data
34
-
35
- # Transcription function
36
  def transcribe_audio(audio):
 
37
  try:
 
38
  global processor, model
39
  if 'processor' not in globals() or 'model' not in globals():
40
  processor, model = initialize_model()
41
-
42
- # Handle microphone input or uploaded file
43
  if isinstance(audio, tuple): # Microphone input
44
- audio_data = audio[1]
45
- sample_rate = audio[0]
46
- audio_data = resample_audio(audio_data, sample_rate)
47
  else: # Uploaded file
48
  audio_data, sample_rate = sf.read(audio)
49
- audio_data = resample_audio(audio_data, sample_rate)
50
 
51
- # Preprocess and perform inference
 
 
 
 
 
52
  device = "cuda" if torch.cuda.is_available() else "cpu"
53
  input_features = processor(audio_data, return_tensors="pt", sampling_rate=16000).input_features.to(device)
 
 
54
  with torch.no_grad():
55
  predicted_ids = model.generate(input_features)
56
 
@@ -61,7 +63,7 @@ def transcribe_audio(audio):
61
  except Exception as e:
62
  return f"Error during transcription: {str(e)}"
63
 
64
- # Gradio interface
65
  def create_demo():
66
  """Set up the Gradio app."""
67
  with gr.Blocks() as demo:
@@ -79,6 +81,7 @@ def create_demo():
79
  audio_input.change(transcribe_audio, inputs=audio_input, outputs=transcription_output)
80
  return demo
81
 
 
82
  # Initialize Gradio app
83
  demo = create_demo()
84
 
 
1
+ import spaces # Required for ZeroGPU compliance
2
  import gradio as gr
3
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
4
  import torch
 
6
  import soundfile as sf
7
  from scipy.signal import resample
8
 
9
+ # Model ID and Hugging Face Token
10
  MODEL_ID = "WMRNORDIC/whisper-swedish-telephonic"
 
 
11
  HF_API_TOKEN = os.getenv("HF_API_TOKEN")
12
+
13
  if not HF_API_TOKEN:
14
+ raise ValueError("HF_API_TOKEN not found. Set it in the environment variables.")
15
 
16
  # Sample file path
17
  SAMPLE_FILE_PATH = "trimmed_resampled_audio.wav" # Update this path if necessary
18
 
19
+
20
+ @spaces.GPU
21
  def initialize_model():
22
+ """Lazy initialization of model and processor with GPU allocation."""
23
+ print("Initializing model and processor...")
24
  processor = WhisperProcessor.from_pretrained(MODEL_ID, token=HF_API_TOKEN)
25
  model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID, token=HF_API_TOKEN)
26
+ device = "cuda" if torch.cuda.is_available() else "cpu"
27
+ model = model.to(device)
28
+ print(f"Model loaded on device: {device}")
29
  return processor, model
30
 
31
+ @spaces.GPU
 
 
 
 
 
 
 
32
  def transcribe_audio(audio):
33
+ """Transcription logic with ZeroGPU compliance."""
34
  try:
35
+ # Lazy-load model and processor
36
  global processor, model
37
  if 'processor' not in globals() or 'model' not in globals():
38
  processor, model = initialize_model()
39
+
40
+ # Handle audio input
41
  if isinstance(audio, tuple): # Microphone input
42
+ audio_data, sample_rate = audio[1], audio[0]
 
 
43
  else: # Uploaded file
44
  audio_data, sample_rate = sf.read(audio)
 
45
 
46
+ # Resample to 16kHz
47
+ if sample_rate != 16000:
48
+ num_samples = int(len(audio_data) * 16000 / sample_rate)
49
+ audio_data = resample(audio_data, num_samples)
50
+
51
+ # Prepare inputs for the model
52
  device = "cuda" if torch.cuda.is_available() else "cpu"
53
  input_features = processor(audio_data, return_tensors="pt", sampling_rate=16000).input_features.to(device)
54
+
55
+ # Generate transcription
56
  with torch.no_grad():
57
  predicted_ids = model.generate(input_features)
58
 
 
63
  except Exception as e:
64
  return f"Error during transcription: {str(e)}"
65
 
66
+ # Gradio Interface
67
  def create_demo():
68
  """Set up the Gradio app."""
69
  with gr.Blocks() as demo:
 
81
  audio_input.change(transcribe_audio, inputs=audio_input, outputs=transcription_output)
82
  return demo
83
 
84
+
85
  # Initialize Gradio app
86
  demo = create_demo()
87