cdactvm committed on
Commit
ed97bcc
·
verified ·
1 Parent(s): 8472c6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -28
app.py CHANGED
@@ -1,39 +1,72 @@
1
- import torch
2
- import torchaudio
3
- import gradio as gr
4
- from transformers import Wav2Vec2BertProcessor, Wav2Vec2BertForCTC
5
 
6
- # Set device
7
- device = "cuda" if torch.cuda.is_available() else "cpu"
8
 
9
- # Load processor & model
10
- model_name = "cdactvm/w2v-bert-punjabi" # Change if using a Punjabi ASR model
11
- processor = Wav2Vec2BertProcessor.from_pretrained(model_name)
12
- model = Wav2Vec2BertForCTC.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
13
 
14
- def transcribe(audio_path):
15
- # Load audio file
16
- waveform, sample_rate = torchaudio.load(audio_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- # Convert stereo to mono (if needed)
19
- if waveform.shape[0] > 1:
20
- waveform = torch.mean(waveform, dim=0, keepdim=True)
21
 
22
- # Resample to 16kHz
23
- if sample_rate != 16000:
24
- waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
 
 
 
 
 
25
 
26
- # Process audio
27
- inputs = processor(waveform.squeeze(0), sampling_rate=16000, return_tensors="pt")
28
- inputs = {key: val.to(device, dtype=torch.bfloat16) for key, val in inputs.items()}
29
 
30
- # Get logits & transcribe
31
- with torch.no_grad():
32
- logits = model(**inputs).logits
33
- predicted_ids = torch.argmax(logits, dim=-1)
34
- transcription = processor.batch_decode(predicted_ids)[0]
35
 
36
- return transcription
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  # Gradio Interface
39
  app = gr.Interface(
 
1
+ # import torch
2
+ # import torchaudio
3
+ # import gradio as gr
4
+ # from transformers import Wav2Vec2BertProcessor, Wav2Vec2BertForCTC
5
 
6
+ # # Set device
7
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
8
 
9
+ # # Load processor & model
10
+ # model_name = "cdactvm/w2v-bert-punjabi" # Change if using a Punjabi ASR model
11
+ # processor = Wav2Vec2BertProcessor.from_pretrained(model_name)
12
+ # model = Wav2Vec2BertForCTC.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
13
 
14
+ # def transcribe(audio_path):
15
+ # # Load audio file
16
+ # waveform, sample_rate = torchaudio.load(audio_path)
17
+
18
+ # # Convert stereo to mono (if needed)
19
+ # if waveform.shape[0] > 1:
20
+ # waveform = torch.mean(waveform, dim=0, keepdim=True)
21
+
22
+ # # Resample to 16kHz
23
+ # if sample_rate != 16000:
24
+ # waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
25
+
26
+ # # Process audio
27
+ # inputs = processor(waveform.squeeze(0), sampling_rate=16000, return_tensors="pt")
28
+ # inputs = {key: val.to(device, dtype=torch.bfloat16) for key, val in inputs.items()}
29
+
30
+ # # Get logits & transcribe
31
+ # with torch.no_grad():
32
+ # logits = model(**inputs).logits
33
+ # predicted_ids = torch.argmax(logits, dim=-1)
34
+ # transcription = processor.batch_decode(predicted_ids)[0]
35
 
36
+ # return transcription
 
 
37
 
38
+ # # Gradio Interface
39
+ # app = gr.Interface(
40
+ # fn=transcribe,
41
+ # inputs=gr.Audio(sources="upload", type="filepath"),
42
+ # outputs="text",
43
+ # title="Punjabi Speech-to-Text",
44
+ # description="Upload an audio file and get the transcription in Punjabi."
45
+ # )
46
 
47
+ # if __name__ == "__main__":
48
+ # app.launch()
 
49
 
 
 
 
 
 
50
 
51
+ import gradio as gr
52
+ import torch
53
+ from transformers import pipeline
54
+
55
# Select the inference device once; the same value is handed to the pipeline
# below so the CUDA availability check is not duplicated.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the speech-recognition pipeline around the Punjabi W2v-BERT checkpoint.
asr_pipeline = pipeline(
    "automatic-speech-recognition",
    model="cdactvm/w2v-bert-punjabi",
    torch_dtype=torch.bfloat16,
    device=device,  # pipeline() accepts a device string ("cuda" / "cpu")
)
65
+
66
def transcribe(audio_path):
    """Transcribe an audio file with the module-level ASR pipeline.

    Args:
        audio_path: Path of the audio file to transcribe. May be falsy
            (``None`` or an empty string) when the caller has no audio to
            offer, e.g. a form submitted without an upload.

    Returns:
        The ``"text"`` entry of the pipeline result, or an empty string when
        no audio path was supplied.
    """
    # Without this guard a missing upload is forwarded to the pipeline, which
    # raises instead of producing a transcription.
    if not audio_path:
        return ""

    result = asr_pipeline(audio_path)
    return result["text"]
70
 
71
  # Gradio Interface
72
  app = gr.Interface(