cdactvm committed
Commit 966d76f · verified · 1 Parent(s): a1377ca

Update app.py

Files changed (1)
  1. app.py +28 -26
app.py CHANGED
@@ -1,46 +1,48 @@
 import torch
-import gradio as gr
 import torchaudio
-from transformers import Wav2Vec2ForCTC, AutoProcessor
-from quanto import qint8, quantize, freeze
+import gradio as gr
+from transformers import Wav2Vec2BertProcessor, Wav2Vec2BertForCTC
+
+# Set device
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# Load and quantize the model
-model_name = "cdactvm/w2v-bert-punjabi"
-model = Wav2Vec2ForCTC.from_pretrained(model_name)  # Ensure it's a CTC model
-processor = AutoProcessor.from_pretrained(model_name)
-
-# Quantization
-quantize(model, weights=qint8, activations=None)
-freeze(model)
+# Load processor & model
+model_name = "cdactvm/w2v-bert-punjabi"  # Change if using a Punjabi ASR model
+processor = Wav2Vec2BertProcessor.from_pretrained(model_name)
+model = Wav2Vec2BertForCTC.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
 
-# Audio transcription function
-def transcribe(audio):
-    waveform, sample_rate = torchaudio.load(audio)
-
-    # Ensure 16kHz sample rate
+def transcribe(audio_path):
+    # Load audio file
+    waveform, sample_rate = torchaudio.load(audio_path)
+
+    # Convert stereo to mono (if needed)
+    if waveform.shape[0] > 1:
+        waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+    # Resample to 16kHz
     if sample_rate != 16000:
-        waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
+        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
 
     # Process audio
     inputs = processor(waveform.squeeze(0), sampling_rate=16000, return_tensors="pt")
+    inputs = {key: val.to(device, dtype=torch.bfloat16) for key, val in inputs.items()}
 
-    # Run inference
+    # Get logits & transcribe
     with torch.no_grad():
-        logits = model(**inputs).logits  # Ensure model has 'logits'
-
-    # Decode transcription
+        logits = model(**inputs).logits
     predicted_ids = torch.argmax(logits, dim=-1)
     transcription = processor.batch_decode(predicted_ids)[0]
 
     return transcription
 
-# Gradio UI
-iface = gr.Interface(
+# Gradio Interface
+app = gr.Interface(
     fn=transcribe,
-    inputs=gr.Audio(sources="upload", type="filepath"),
+    inputs=gr.Audio(source="upload", type="filepath"),
     outputs="text",
-    title="Punjabi Speech Recognition",
-    description="Upload an audio file and get a Punjabi transcription using a quantized model.",
+    title="Punjabi Speech-to-Text",
+    description="Upload an audio file and get the transcription in Punjabi."
 )
 
-iface.launch()
+if __name__ == "__main__":
+    app.launch()
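
Reviewer note on the new input cast: the added dict comprehension sends every processor output to bfloat16, which would also cast an integer attention mask if the processor returns one. A minimal dtype-safe sketch, assuming the same processor outputs (the helper name move_inputs is illustrative, not part of this commit):

import torch

def move_inputs(inputs, device, dtype=torch.bfloat16):
    # Move every tensor to the target device, but cast only the
    # floating-point ones (e.g. input_features); integer tensors
    # such as an attention mask keep their original dtype.
    return {
        key: val.to(device, dtype=dtype) if torch.is_floating_point(val) else val.to(device)
        for key, val in inputs.items()
    }

On CPU-only hardware, bfloat16 inference can also be slow or unsupported, so picking the dtype from the device (torch.bfloat16 on CUDA, torch.float32 otherwise) is a common fallback.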
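
Reviewer note on gr.Audio: the commit changes sources="upload" to source="upload". The singular source parameter is the Gradio 3.x spelling; Gradio 4.x renamed it to sources and expects a list. Which spelling runs depends on the gradio version pinned for this Space, which is not shown here, so treat the 4.x form below as an assumption:

import gradio as gr

# Gradio 4.x spelling; under Gradio 3.x use gr.Audio(source="upload", type="filepath") instead.
audio_input = gr.Audio(sources=["upload"], type="filepath")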
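
Design note: moving app.launch() behind the if __name__ == "__main__": guard means importing app.py no longer starts the Gradio server, so the transcribe function can be exercised directly. A quick local check might look like this (sample_punjabi.wav is a hypothetical test clip):

from app import transcribe  # loads the model at import time, but does not launch the UI

print(transcribe("sample_punjabi.wav"))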