peterkros commited on
Commit
d2753e9
·
verified ·
1 Parent(s): 9e71ecb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -8
app.py CHANGED
@@ -1,14 +1,32 @@
1
  import gradio as gr
2
- from transformers import pipeline
 
3
 
4
- # Load Whisper model from Hugging Face
5
- # This uses the `transformers` library's pipeline to load the model
6
- transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3")
 
7
 
 
 
 
 
 
8
  def transcribe(audio):
9
- # Transcribe the audio using the Whisper model
10
- result = transcriber(audio)["text"]
11
- return result
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # Create a Gradio Interface
14
  interface = gr.Interface(
@@ -20,4 +38,4 @@ interface = gr.Interface(
20
  )
21
 
22
  # Launch the interface as an API
23
- interface.launch()
 
1
  import gradio as gr
2
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline
3
+ import torch
4
 
5
+ # Load Whisper model and processor from Hugging Face
6
+ model_name = "openai/whisper-large-v3"
7
+ processor = WhisperProcessor.from_pretrained(model_name)
8
+ model = WhisperForConditionalGeneration.from_pretrained(model_name)
9
 
10
+ # Ensure the model is using the correct device (GPU or CPU)
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ model.to(device)
13
+
14
+ # Function to handle transcription with language set to English by default
15
  def transcribe(audio):
16
+ # Load audio
17
+ input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features.to(device)
18
+
19
+ # Generate transcription with attention_mask and correct input_features
20
+ attention_mask = torch.ones(input_features.shape, dtype=torch.long, device=device)
21
+ generated_ids = model.generate(
22
+ input_features=input_features,
23
+ attention_mask=attention_mask,
24
+ language="en" # Force translation to English
25
+ )
26
+
27
+ # Decode transcription
28
+ transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
29
+ return transcription
30
 
31
  # Create a Gradio Interface
32
  interface = gr.Interface(
 
38
  )
39
 
40
  # Launch the interface as an API
41
+ interface.launch()