cdactvm committed · verified
Commit 82814b2 · 1 Parent(s): 50710a8

Update app.py

Files changed (1)
  1. app.py +62 -44
app.py CHANGED
@@ -9,9 +9,43 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
  # Load processor & model
  model_name = "cdactvm/w2v-bert-punjabi" # Change if using a Punjabi ASR model
  processor = Wav2Vec2BertProcessor.from_pretrained(model_name)
- model = Wav2Vec2BertForCTC.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
+ # Loading the original model.
+ original_model=Wav2Vec2BertForCTC.from_pretrained(model_name)
+ # Explicitly allow Wav2Vec2BertForCTC during unpickling
+ torch.serialization.add_safe_globals([Wav2Vec2BertForCTC])
+ # Load the full quantized model
+ quantized_model = torch.load("model_name", weights_only=False)
+ quantized_model.eval()
+
+ #####################################################
+ # recognize speech using original model
+ def transcribe_original_model(audio_path):
+     # Load audio file
+     waveform, sample_rate = torchaudio.load(audio_path)
+
+     # Convert stereo to mono (if needed)
+     if waveform.shape[0] > 1:
+         waveform = torch.mean(waveform, dim=0, keepdim=True)

- def transcribe(audio_path):
+     # Resample to 16kHz
+     if sample_rate != 16000:
+         waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
+
+     # Process audio
+     inputs = processor(waveform.squeeze(0), sampling_rate=16000, return_tensors="pt")
+     inputs = {key: val.to(device, dtype=torch.bfloat16) for key, val in inputs.items()}
+
+     # Get logits & transcribe
+     with torch.no_grad():
+         logits = original_model(**inputs).logits
+     predicted_ids = torch.argmax(logits, dim=-1)
+     transcription = processor.batch_decode(predicted_ids)[0]
+
+     return transcription
+
+
+ # recognize speech using quantized model.
+ def transcribe_quantized_model(audio_path):
      # Load audio file
      waveform, sample_rate = torchaudio.load(audio_path)
@@ -29,53 +63,37 @@ def transcribe(audio_path):

      # Get logits & transcribe
      with torch.no_grad():
-         logits = model(**inputs).logits
+         logits = quantized_model(**inputs).logits
      predicted_ids = torch.argmax(logits, dim=-1)
      transcription = processor.batch_decode(predicted_ids)[0]

      return transcription

- # Gradio Interface
- app = gr.Interface(
-     fn=transcribe,
-     inputs=gr.Audio(sources="upload", type="filepath"),
-     outputs="text",
-     title="Punjabi Speech-to-Text",
-     description="Upload an audio file and get the transcription in Punjabi."
- )
+ def select_lng(lng, mic=None, file=None):
+     if mic is not None:
+         audio = mic
+     elif file is not None:
+         audio = file
+     else:
+         return "You must either provide a mic recording or a file"
+
+     if lng == "original_model":
+         return transcribe_original_model(audio)
+     elif lng == "quantized_model":
+         return transcribe_quantized_model(audio)
+
+
+ # Gradio Interface
+ demo=gr.Interface(
+     fn=select_lng,
+     inputs=[
+         gr.Dropdown(["original_model","quantized_model"],label="Select Model"),
+         gr.Audio(sources=["microphone","upload"], type="filepath"),
+     ],
+     outputs=["textbox"],
+     title="Automatic Speech Recognition",
+     description = "Upload an audio file and get the transcription in Punjabi.",
+ )

  if __name__ == "__main__":
      app.launch()
-
-
- # import gradio as gr
- # import torch
- # from transformers import pipeline
-
- # # Set device
- # device = "cuda" if torch.cuda.is_available() else "cpu"
-
- # # Load ASR pipeline
- # asr_pipeline = pipeline(
- #     "automatic-speech-recognition",
- #     model="cdactvm/w2v-bert-punjabi", # Replace with a Punjabi ASR model if available
- #     torch_dtype=torch.bfloat16,
- #     device=0 if torch.cuda.is_available() else -1 # GPU (0) or CPU (-1)
- # )
-
- # def transcribe(audio_path):
- #     # Run inference
- #     result = asr_pipeline(audio_path)
- #     return result["text"]
-
- # # Gradio Interface
- # app = gr.Interface(
- #     fn=transcribe,
- #     inputs=gr.Audio(sources="upload", type="filepath"),
- #     outputs="text",
- #     title="Punjabi Speech-to-Text",
- #     description="Upload an audio file and get the transcription in Punjabi."
- # )
-
- # if __name__ == "__main__":
- #     app.launch()
 
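A note on the new loading path: the quantized checkpoint is unpickled as a full model object, so torch.load is called with weights_only=False (recent PyTorch releases default to weights_only=True and refuse arbitrary pickled classes), and Wav2Vec2BertForCTC is registered via torch.serialization.add_safe_globals, which only matters if the load is later switched back to weights_only=True. As committed, the literal string "model_name" is passed to torch.load as the checkpoint path. Below is a minimal sketch of how such a checkpoint could have been produced; the dynamic-quantization step and the output filename are assumptions, not part of this commit.

# Sketch only: one plausible way to produce the pickled, quantized checkpoint
# that app.py expects. The output filename is an assumption; the commit itself
# passes the literal string "model_name" to torch.load.
import torch
from transformers import Wav2Vec2BertForCTC

base = Wav2Vec2BertForCTC.from_pretrained("cdactvm/w2v-bert-punjabi")
base.eval()

# Dynamic quantization replaces the nn.Linear weights with int8 packed weights.
quantized = torch.quantization.quantize_dynamic(base, {torch.nn.Linear}, dtype=torch.qint8)

# Saving the whole module (not just a state_dict) is what later forces
# torch.load(..., weights_only=False) in app.py.
torch.save(quantized, "w2v_bert_punjabi_quantized.pt")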
 
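Both transcribe functions cast every processor output to torch.bfloat16 and move it to device, while the commit now loads original_model in its default float32 and never moves either model onto device. A hedged sketch of device- and dtype-consistent input preparation is below; the helper prepare_inputs is hypothetical and not part of the commit.

# Sketch only: take the target device and dtype from the model itself instead of
# hard-coding bfloat16, and leave integer tensors such as attention_mask uncast.
def prepare_inputs(processor, waveform, model):
    inputs = processor(waveform.squeeze(0), sampling_rate=16000, return_tensors="pt")
    param = next(model.parameters())
    return {
        key: (val.to(param.device, dtype=param.dtype) if val.is_floating_point()
              else val.to(param.device))
        for key, val in inputs.items()
    }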
 
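Two reading notes on the new interface wiring: gr.Interface passes the Dropdown value to lng and the single Audio filepath to mic, so the file parameter of select_lng is never populated; and the interface is now bound to demo while the unchanged tail still calls app.launch(), which will raise a NameError because app is no longer defined. A small sketch of a tail consistent with the added code follows; it is not part of this commit, and the .wav path in the commented-out call is a placeholder.

# Sketch only: launch under the name the new code actually binds, and optionally
# exercise the dispatcher directly.
if __name__ == "__main__":
    # print(select_lng("quantized_model", mic="example_clip.wav"))  # placeholder path
    demo.launch()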