Tamerstito commited on
Commit
df7a732
·
verified ·
1 Parent(s): 283d3e6

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +7 -24
  2. requirements.txt +2 -2
app.py CHANGED
@@ -20,60 +20,44 @@ def translate_audio(filepath):
20
  if filepath is None or not os.path.exists(filepath):
21
  return "No audio file received or file does not exist."
22
 
23
- # Lazy-load model and processor
24
  if model is None:
25
  print("Loading Whisper model...")
26
  model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
27
  processor = WhisperProcessor.from_pretrained("openai/whisper-small")
28
  forced_decoder_ids = processor.get_decoder_prompt_ids(
29
- task="translate", language="en"
30
  )
31
- print("Model loaded and decoder ids set.")
32
 
33
  audio = AudioSegment.from_file(filepath).set_channels(1)
34
- print("Audio loaded. Duration (ms):", len(audio))
35
-
36
  chunk_length_ms = 30 * 1000
37
  chunks = [audio[i:i + chunk_length_ms] for i in range(0, len(audio), chunk_length_ms)]
38
- print(f"Audio split into {len(chunks)} chunks.")
39
 
40
  full_translation = ""
41
 
42
  for i, chunk in enumerate(chunks):
43
  chunk_path = f"chunk_{i}.wav"
44
  chunk.export(chunk_path, format="wav")
45
- print(f"Exported chunk {i} to {chunk_path}")
46
-
47
  waveform, sample_rate = torchaudio.load(chunk_path)
48
 
49
- # Resample if necessary
50
  if sample_rate != 16000:
51
- print(f"Resampling from {sample_rate} Hz to 16000 Hz")
52
- resampler = T.Resample(orig_freq=sample_rate, new_freq=16000)
53
- waveform = resampler(waveform)
54
 
55
- # Convert to mono
56
  waveform = waveform.mean(dim=0)
57
-
58
- inputs = processor(
59
- waveform,
60
- sampling_rate=16000,
61
- return_tensors="pt"
62
- )
63
 
64
  with torch.no_grad():
65
  generated_ids = model.generate(
66
  inputs["input_features"],
67
- forced_decoder_ids=forced_decoder_ids
 
68
  )
69
 
70
  translation = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
71
- print(f"Chunk {i} translation:", translation)
72
-
73
  full_translation += translation + " "
74
  os.remove(chunk_path)
75
 
76
- print("Full translation done.")
77
  return full_translation.strip()
78
 
79
  except Exception as e:
@@ -81,7 +65,6 @@ def translate_audio(filepath):
81
  traceback.print_exc()
82
  return f"An error occurred: {str(e)}"
83
 
84
- # Gradio UI
85
  mic_transcribe = gr.Interface(
86
  fn=translate_audio,
87
  inputs=gr.Audio(sources="microphone", type="filepath"),
 
20
  if filepath is None or not os.path.exists(filepath):
21
  return "No audio file received or file does not exist."
22
 
23
+ # Load Whisper
24
  if model is None:
25
  print("Loading Whisper model...")
26
  model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
27
  processor = WhisperProcessor.from_pretrained("openai/whisper-small")
28
  forced_decoder_ids = processor.get_decoder_prompt_ids(
29
+ task="translate", language="es"
30
  )
31
+ print("Model and processor ready.")
32
 
33
  audio = AudioSegment.from_file(filepath).set_channels(1)
 
 
34
  chunk_length_ms = 30 * 1000
35
  chunks = [audio[i:i + chunk_length_ms] for i in range(0, len(audio), chunk_length_ms)]
 
36
 
37
  full_translation = ""
38
 
39
  for i, chunk in enumerate(chunks):
40
  chunk_path = f"chunk_{i}.wav"
41
  chunk.export(chunk_path, format="wav")
 
 
42
  waveform, sample_rate = torchaudio.load(chunk_path)
43
 
 
44
  if sample_rate != 16000:
45
+ waveform = T.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
 
 
46
 
 
47
  waveform = waveform.mean(dim=0)
48
+ inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
 
 
 
 
 
49
 
50
  with torch.no_grad():
51
  generated_ids = model.generate(
52
  inputs["input_features"],
53
+ forced_decoder_ids=forced_decoder_ids,
54
+ suppress_tokens=[]
55
  )
56
 
57
  translation = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
 
58
  full_translation += translation + " "
59
  os.remove(chunk_path)
60
 
 
61
  return full_translation.strip()
62
 
63
  except Exception as e:
 
65
  traceback.print_exc()
66
  return f"An error occurred: {str(e)}"
67
 
 
68
  mic_transcribe = gr.Interface(
69
  fn=translate_audio,
70
  inputs=gr.Audio(sources="microphone", type="filepath"),
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
  --extra-index-url https://download.pytorch.org/whl/cpu
2
- transformers==4.36.2
3
  torch
4
  torchaudio
5
- pydub
6
  gradio
 
 
1
  --extra-index-url https://download.pytorch.org/whl/cpu
 
2
  torch
3
  torchaudio
4
+ transformers==4.36.2
5
  gradio
6
+ pydub