NLPV commited on
Commit
e3afeb6
·
verified ·
1 Parent(s): 89b33e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -25
app.py CHANGED
@@ -1,24 +1,18 @@
1
  import gradio as gr
2
  from gtts import gTTS
3
  import tempfile
4
- import os
5
- import torch
6
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
7
- import torchaudio
8
  import difflib
9
  import pandas as pd
10
  from Levenshtein import distance as lev_distance
 
11
 
12
- # Load AI4Bharat Hindi model & processor
13
- MODEL_NAME = "ai4bharat/indicwav2vec-hindi"
14
- processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
15
- model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)
16
 
17
  def play_text(text):
18
  tts = gTTS(text=text, lang='hi', slow=False)
19
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
20
  tts.save(temp_file.name)
21
- # Return file for Gradio audio output
22
  return temp_file.name
23
 
24
  def get_error_type(asr_word, correct_word):
@@ -55,37 +49,43 @@ def compare_hindi_sentences(expected, transcribed):
55
  errors.append((transcribed_words[k], "", "Extra word"))
56
  return errors
57
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def transcribe_audio(audio_path, original_text):
59
  try:
60
- waveform, sample_rate = torchaudio.load(audio_path)
61
- if waveform.shape[0] > 1:
62
- waveform = waveform.mean(dim=0, keepdim=True)
63
- if sample_rate != 16000:
64
- transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
65
- waveform = transform(waveform)
66
- waveform = waveform / waveform.abs().max()
67
- input_values = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_values
68
- with torch.no_grad():
69
- logits = model(input_values).logits
70
- predicted_ids = torch.argmax(logits, dim=-1)
71
- transcription = processor.decode(predicted_ids[0])
72
  # Error analysis
73
  errors = compare_hindi_sentences(original_text, transcription)
74
  df_errors = pd.DataFrame(errors, columns=["बिगड़ा हुआ शब्द", "संभावित सही शब्द", "गलती का प्रकार"])
75
  # Speaking speed
76
  transcribed_words = transcription.strip().split()
77
- duration = waveform.shape[1] / 16000
78
  speed = round(len(transcribed_words) / duration, 2) if duration > 0 else 0
79
- result = {
 
 
80
  "📝 Transcribed Text": transcription,
81
  "⏱️ Speaking Speed (words/sec)": speed,
 
82
  }
83
- return result, df_errors
84
  except Exception as e:
85
  return {"error": str(e)}, pd.DataFrame(columns=["बिगड़ा हुआ शब्द", "संभावित सही शब्द", "गलती का प्रकार"])
86
 
87
  with gr.Blocks() as app:
88
- gr.Markdown("## 🗣️ Hindi Reading & Pronunciation Practice App (AI4Bharat Model)")
89
  with gr.Row():
90
  input_text = gr.Textbox(label="Paste Hindi Text Here", placeholder="यहाँ हिंदी टेक्स्ट लिखें...")
91
  play_button = gr.Button("🔊 Listen to Text")
 
1
  import gradio as gr
2
  from gtts import gTTS
3
  import tempfile
 
 
 
 
4
  import difflib
5
  import pandas as pd
6
  from Levenshtein import distance as lev_distance
7
+ import whisper
8
 
9
+ # Load Whisper model once (choose "small" or "medium" for better results)
10
+ model = whisper.load_model("small")
 
 
11
 
12
  def play_text(text):
13
  tts = gTTS(text=text, lang='hi', slow=False)
14
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
15
  tts.save(temp_file.name)
 
16
  return temp_file.name
17
 
18
  def get_error_type(asr_word, correct_word):
 
49
  errors.append((transcribed_words[k], "", "Extra word"))
50
  return errors
51
 
52
+ def calculate_accuracy(expected, transcribed):
53
+ expected_words = expected.strip().split()
54
+ transcribed_words = transcribed.strip().split()
55
+ matcher = difflib.SequenceMatcher(None, transcribed_words, expected_words)
56
+ correct = 0
57
+ total = len(expected_words)
58
+ for tag, i1, i2, j1, j2 in matcher.get_opcodes():
59
+ if tag == 'equal':
60
+ correct += (j2-j1)
61
+ accuracy = (correct / total) * 100 if total > 0 else 0
62
+ return round(accuracy, 2)
63
+
64
  def transcribe_audio(audio_path, original_text):
65
  try:
66
+ # Use Whisper for transcription
67
+ result = model.transcribe(audio_path, language='hi')
68
+ transcription = result['text'].strip()
 
 
 
 
 
 
 
 
 
69
  # Error analysis
70
  errors = compare_hindi_sentences(original_text, transcription)
71
  df_errors = pd.DataFrame(errors, columns=["बिगड़ा हुआ शब्द", "संभावित सही शब्द", "गलती का प्रकार"])
72
  # Speaking speed
73
  transcribed_words = transcription.strip().split()
74
+ duration = result['segments'][-1]['end'] if result.get('segments') else 1.0
75
  speed = round(len(transcribed_words) / duration, 2) if duration > 0 else 0
76
+ # Accuracy
77
+ accuracy = calculate_accuracy(original_text, transcription)
78
+ result_dict = {
79
  "📝 Transcribed Text": transcription,
80
  "⏱️ Speaking Speed (words/sec)": speed,
81
+ "✅ Reading Accuracy (%)": accuracy,
82
  }
83
+ return result_dict, df_errors
84
  except Exception as e:
85
  return {"error": str(e)}, pd.DataFrame(columns=["बिगड़ा हुआ शब्द", "संभावित सही शब्द", "गलती का प्रकार"])
86
 
87
  with gr.Blocks() as app:
88
+ gr.Markdown("## 🗣️ Hindi Reading & Pronunciation Practice App (OpenAI Whisper)")
89
  with gr.Row():
90
  input_text = gr.Textbox(label="Paste Hindi Text Here", placeholder="यहाँ हिंदी टेक्स्ट लिखें...")
91
  play_button = gr.Button("🔊 Listen to Text")