Tamerstito committed on
Commit
d07df4a
·
verified ·
1 Parent(s): df7a732

Upload 2 files

Browse files
Files changed (1) hide show
  1. app.py +56 -76
app.py CHANGED
@@ -1,90 +1,70 @@
1
- from transformers import WhisperProcessor, WhisperForConditionalGeneration
2
- import torchaudio
3
- import torchaudio.transforms as T
4
  import torch
5
- import os
 
6
  import gradio as gr
7
  from pydub import AudioSegment
8
- import traceback
9
 
10
- # Lazy-load components
11
- model = None
12
- processor = None
13
- forced_decoder_ids = None
 
 
 
 
 
 
 
14
 
15
  def translate_audio(filepath):
16
- global model, processor, forced_decoder_ids
17
- try:
18
- print("Received filepath:", filepath)
19
-
20
- if filepath is None or not os.path.exists(filepath):
21
- return "No audio file received or file does not exist."
22
-
23
- # Load Whisper
24
- if model is None:
25
- print("Loading Whisper model...")
26
- model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
27
- processor = WhisperProcessor.from_pretrained("openai/whisper-small")
28
- forced_decoder_ids = processor.get_decoder_prompt_ids(
29
- task="translate", language="es"
30
- )
31
- print("Model and processor ready.")
32
-
33
- audio = AudioSegment.from_file(filepath).set_channels(1)
34
- chunk_length_ms = 30 * 1000
35
- chunks = [audio[i:i + chunk_length_ms] for i in range(0, len(audio), chunk_length_ms)]
36
-
37
- full_translation = ""
38
-
39
- for i, chunk in enumerate(chunks):
40
- chunk_path = f"chunk_{i}.wav"
41
- chunk.export(chunk_path, format="wav")
42
- waveform, sample_rate = torchaudio.load(chunk_path)
43
-
44
- if sample_rate != 16000:
45
- waveform = T.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
46
-
47
- waveform = waveform.mean(dim=0)
48
- inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
49
-
50
- with torch.no_grad():
51
- generated_ids = model.generate(
52
- inputs["input_features"],
53
- forced_decoder_ids=forced_decoder_ids,
54
- suppress_tokens=[]
55
- )
56
-
57
- translation = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
58
- full_translation += translation + " "
59
- os.remove(chunk_path)
60
-
61
- return full_translation.strip()
62
-
63
- except Exception as e:
64
- print("ERROR:", str(e))
65
- traceback.print_exc()
66
- return f"An error occurred: {str(e)}"
67
-
68
- mic_transcribe = gr.Interface(
69
  fn=translate_audio,
70
  inputs=gr.Audio(sources="microphone", type="filepath"),
71
- outputs=gr.Textbox(label="Translation (English to Spanish)", lines=3),
72
- allow_flagging="never"
73
  )
74
 
75
- file_transcribe = gr.Interface(
76
  fn=translate_audio,
77
  inputs=gr.Audio(sources="upload", type="filepath"),
78
- outputs=gr.Textbox(label="Translation (English to Spanish)", lines=3),
79
- allow_flagging="never"
80
  )
81
 
82
- demo = gr.Blocks()
83
- with demo:
84
- gr.TabbedInterface(
85
- [mic_transcribe, file_transcribe],
86
- ["Translate Microphone", "Translate Audio File"]
87
- )
88
-
89
- server_port = int(os.environ.get("PORT", 7860))
90
- demo.launch(share=True, server_port=server_port)
 
1
+
 
 
2
  import torch
3
+ import torchaudio
4
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor
5
  import gradio as gr
6
  from pydub import AudioSegment
7
+ import os
8
 
9
# One checkpoint name is shared by the processor and the model.
model_id = "openai/whisper-small"
processor = WhisperProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id)

# Inference only: turn autograd off process-wide and put the model in eval mode.
torch.set_grad_enabled(False)
model.eval()

# Decoder prompt ids for Whisper's "translate" task with language "es".
# NOTE(review): per the Whisper documentation the translate task always emits
# English and `language` names the spoken language, so this configures
# Spanish speech -> English text -- confirm this is the intended direction.
forced_decoder_ids = processor.get_decoder_prompt_ids(
    task="translate",
    language="es",
)
20
 
21
def translate_audio(filepath):
    """Translate the speech in an audio file with Whisper and return the text.

    The audio is downmixed to mono and cut into 30-second chunks. Each chunk
    is written to a private temporary directory, loaded with torchaudio,
    resampled to 16 kHz when necessary and decoded with the module-level
    ``model``, ``processor`` and ``forced_decoder_ids``.

    Args:
        filepath: Path to the audio file (as passed by ``gr.Audio`` with
            ``type="filepath"``), or ``None`` when no audio was provided.

    Returns:
        The per-chunk translations joined by single spaces, or the message
        ``"No audio file received."`` when ``filepath`` is ``None`` or does
        not point to an existing file.
    """
    # Local import keeps this block self-contained; only this function needs it.
    import tempfile

    if filepath is None or not os.path.exists(filepath):
        return "No audio file received."

    audio = AudioSegment.from_file(filepath).set_channels(1)
    chunk_length_ms = 30 * 1000
    chunks = [
        audio[start:start + chunk_length_ms]
        for start in range(0, len(audio), chunk_length_ms)
    ]

    # The decoder's total length budget includes the start token and the
    # forced prompt tokens, so leave room for them instead of asking for the
    # full 448 new tokens (newer transformers releases reject the overshoot).
    prompt_len = 1 + len(forced_decoder_ids)
    max_new_tokens = model.config.max_target_positions - prompt_len

    pieces = []
    # A per-call temporary directory keeps concurrent requests from clobbering
    # each other's chunk files and guarantees cleanup even if decoding fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        for idx, chunk in enumerate(chunks):
            chunk_path = os.path.join(tmp_dir, f"chunk_{idx}.wav")
            # export() hands back the open output file; close it right away.
            chunk.export(chunk_path, format="wav").close()

            waveform, sr = torchaudio.load(chunk_path)

            # Whisper's feature extractor expects 16 kHz input.
            if sr != 16000:
                resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
                waveform = resampler(waveform)

            waveform = waveform.mean(dim=0)  # collapse the channel axis to mono
            inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")

            predicted_ids = model.generate(
                inputs["input_features"],
                forced_decoder_ids=forced_decoder_ids,
                max_new_tokens=max_new_tokens,
            )

            result = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            pieces.append(result.strip())

    return " ".join(pieces).strip()
56
+
57
# Two front ends over the same translate_audio pipeline; only the audio source
# differs. The output label states Spanish -> English because the decoder
# prompt is built with task="translate" and language="es", and Whisper's
# translate task produces English text from speech in the given language.
mic_ui = gr.Interface(
    fn=translate_audio,
    inputs=gr.Audio(sources="microphone", type="filepath"),
    outputs=gr.Textbox(label="Translated Text (Spanish to English)"),
)

file_ui = gr.Interface(
    fn=translate_audio,
    inputs=gr.Audio(sources="upload", type="filepath"),
    outputs=gr.Textbox(label="Translated Text (Spanish to English)"),
)

# One tab per input source, then start the Gradio server.
app = gr.TabbedInterface([mic_ui, file_ui], ["Microphone Input", "Upload File"])
app.launch()