NLPV committed
Commit f0b2a66 · verified · 1 Parent(s): a81460b

Update app.py

Files changed (1)
  1. app.py +13 -26
app.py CHANGED
@@ -64,43 +64,30 @@ def compare_hindi_sentences(expected, transcribed):
 
 def transcribe_audio(audio_path, original_text):
     try:
-        # 1. Load and pre-process audio
         waveform, sample_rate = torchaudio.load(audio_path)
+        # Convert to mono
         if waveform.shape[0] > 1:
             waveform = waveform.mean(dim=0, keepdim=True)
-        if sample_rate != 48000:
-            transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=48000)
+        # Resample to 16000 Hz for model
+        if sample_rate != 16000:
+            transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
             waveform = transform(waveform)
-
-        # Amplify voice intensity
-        GAIN = 1.5
-        waveform = waveform * GAIN
-        waveform = torch.clamp(waveform, -1.0, 1.0)
-
-        input_values = processor(waveform.squeeze().numpy(), sampling_rate=48000, return_tensors="pt").input_values
-
-        # 2. Transcribe with AI4Bharat model
+        # Normalize to [-1, 1]
+        waveform = waveform / waveform.abs().max()
+
+        input_values = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_values
+
         with torch.no_grad():
             logits = model(input_values).logits
         predicted_ids = torch.argmax(logits, dim=-1)
         transcription = processor.decode(predicted_ids[0])
 
-        # 3. Error analysis (as table)
-        errors = compare_hindi_sentences(original_text, transcription)
-        df_errors = pd.DataFrame(errors, columns=["बिगड़ा हुआ शब्द", "संभावित सही शब्द", "गलती का प्रकार"])
+        # ... rest of your error analysis
 
-        # Speaking speed
-        transcribed_words = transcription.strip().split()
-        duration = waveform.shape[1] / 48000
-        speed = round(len(transcribed_words) / duration, 2) if duration > 0 else 0
-
-        result = {
+        return {
             "📝 Transcribed Text": transcription,
-            "⏱️ Speaking Speed (words/sec)": speed,
-        }
-        # Return table as a separate output (Gradio Dataframe)
-        return result, df_errors
-
+            # etc.
+        }, df_errors
     except Exception as e:
         return {"error": str(e)}, pd.DataFrame(columns=["बिगड़ा हुआ शब्द", "संभावित सही शब्द", "गलती का प्रकार"])
 
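
For context, a minimal standalone sketch of the preprocessing path this commit switches to (mono mix-down, resampling to 16 kHz, peak normalization) followed by greedy CTC decoding. app.py only references module-level processor and model; the checkpoint name "ai4bharat/indicwav2vec-hindi" and the helper names load_for_ctc / transcribe below are illustrative assumptions, not part of the commit.

# Hedged sketch: mirrors the post-commit preprocessing (mono, 16 kHz, peak
# normalization). Checkpoint and helper names are assumptions for illustration.
import torch
import torchaudio
from transformers import AutoModelForCTC, AutoProcessor

processor = AutoProcessor.from_pretrained("ai4bharat/indicwav2vec-hindi")  # assumed checkpoint
model = AutoModelForCTC.from_pretrained("ai4bharat/indicwav2vec-hindi")    # assumed checkpoint
model.eval()

def load_for_ctc(audio_path: str) -> torch.Tensor:
    """Load audio, mix down to mono, resample to 16 kHz, peak-normalize."""
    waveform, sample_rate = torchaudio.load(audio_path)
    if waveform.shape[0] > 1:                      # stereo -> mono
        waveform = waveform.mean(dim=0, keepdim=True)
    if sample_rate != 16000:                       # wav2vec2-style models expect 16 kHz input
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform)
    peak = waveform.abs().max()
    if peak > 0:                                   # guard against an all-silent clip
        waveform = waveform / peak
    return waveform

def transcribe(audio_path: str) -> str:
    waveform = load_for_ctc(audio_path)
    inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        logits = model(inputs.input_values).logits  # CTC logits over the vocabulary
    predicted_ids = torch.argmax(logits, dim=-1)    # greedy decoding
    return processor.decode(predicted_ids[0])

Note the peak > 0 guard in the sketch: the committed waveform / waveform.abs().max() line divides by zero on a fully silent clip, so a similar guard may be worth a follow-up change.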