cdactvm committed · Commit 9568e5e · verified · 1 Parent(s): ed97bcc

Update app.py

Files changed (1):
  app.py  +61 -61

app.py CHANGED
@@ -1,72 +1,39 @@
-# import torch
-# import torchaudio
-# import gradio as gr
-# from transformers import Wav2Vec2BertProcessor, Wav2Vec2BertForCTC
-
-# # Set device
-# device = "cuda" if torch.cuda.is_available() else "cpu"
-
-# # Load processor & model
-# model_name = "cdactvm/w2v-bert-punjabi"  # Change if using a Punjabi ASR model
-# processor = Wav2Vec2BertProcessor.from_pretrained(model_name)
-# model = Wav2Vec2BertForCTC.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
-
-# def transcribe(audio_path):
-#     # Load audio file
-#     waveform, sample_rate = torchaudio.load(audio_path)
-
-#     # Convert stereo to mono (if needed)
-#     if waveform.shape[0] > 1:
-#         waveform = torch.mean(waveform, dim=0, keepdim=True)
-
-#     # Resample to 16kHz
-#     if sample_rate != 16000:
-#         waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
-
-#     # Process audio
-#     inputs = processor(waveform.squeeze(0), sampling_rate=16000, return_tensors="pt")
-#     inputs = {key: val.to(device, dtype=torch.bfloat16) for key, val in inputs.items()}
-
-#     # Get logits & transcribe
-#     with torch.no_grad():
-#         logits = model(**inputs).logits
-#         predicted_ids = torch.argmax(logits, dim=-1)
-#         transcription = processor.batch_decode(predicted_ids)[0]
-
-#     return transcription
-
-# # Gradio Interface
-# app = gr.Interface(
-#     fn=transcribe,
-#     inputs=gr.Audio(sources="upload", type="filepath"),
-#     outputs="text",
-#     title="Punjabi Speech-to-Text",
-#     description="Upload an audio file and get the transcription in Punjabi."
-# )
-
-# if __name__ == "__main__":
-#     app.launch()
-
-
-import gradio as gr
-import torch
-from transformers import pipeline
-
-# Set device
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-# Load ASR pipeline
-asr_pipeline = pipeline(
-    "automatic-speech-recognition",
-    model="cdactvm/w2v-bert-punjabi",  # Replace with a Punjabi ASR model if available
-    torch_dtype=torch.bfloat16,
-    device=0 if torch.cuda.is_available() else -1  # GPU (0) or CPU (-1)
-)
-
-def transcribe(audio_path):
-    # Run inference
-    result = asr_pipeline(audio_path)
-    return result["text"]
+import torch
+import torchaudio
+import gradio as gr
+from transformers import Wav2Vec2BertProcessor, Wav2Vec2BertForCTC
+
+# Set device
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Load processor & model
+model_name = "cdactvm/w2v-bert-punjabi"  # Change if using a Punjabi ASR model
+processor = Wav2Vec2BertProcessor.from_pretrained(model_name)
+model = Wav2Vec2BertForCTC.from_pretrained(model_name, torch_dtype=torch.float16).to(device)
+
+def transcribe(audio_path):
+    # Load audio file
+    waveform, sample_rate = torchaudio.load(audio_path)
+
+    # Convert stereo to mono (if needed)
+    if waveform.shape[0] > 1:
+        waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+    # Resample to 16kHz
+    if sample_rate != 16000:
+        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
+
+    # Process audio (cast to float16 to match the model's dtype)
+    inputs = processor(waveform.squeeze(0), sampling_rate=16000, return_tensors="pt")
+    inputs = {key: val.to(device, dtype=torch.float16) for key, val in inputs.items()}
+
+    # Get logits & transcribe
+    with torch.no_grad():
+        logits = model(**inputs).logits
+        predicted_ids = torch.argmax(logits, dim=-1)
+        transcription = processor.batch_decode(predicted_ids)[0]
+
+    return transcription
 
 # Gradio Interface
 app = gr.Interface(
@@ -79,3 +46,36 @@ app = gr.Interface(
 
 if __name__ == "__main__":
     app.launch()
+
+
+# import gradio as gr
+# import torch
+# from transformers import pipeline
+
+# # Set device
+# device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# # Load ASR pipeline
+# asr_pipeline = pipeline(
+#     "automatic-speech-recognition",
+#     model="cdactvm/w2v-bert-punjabi",  # Replace with a Punjabi ASR model if available
+#     torch_dtype=torch.bfloat16,
+#     device=0 if torch.cuda.is_available() else -1  # GPU (0) or CPU (-1)
+# )
+
+# def transcribe(audio_path):
+#     # Run inference
+#     result = asr_pipeline(audio_path)
+#     return result["text"]
+
+# # Gradio Interface
+# app = gr.Interface(
+#     fn=transcribe,
+#     inputs=gr.Audio(sources="upload", type="filepath"),
+#     outputs="text",
+#     title="Punjabi Speech-to-Text",
+#     description="Upload an audio file and get the transcription in Punjabi."
+# )
+
+# if __name__ == "__main__":
+#     app.launch()
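
Note: the new transcribe() casts every processor output to float16, which would also convert the attention mask if the feature extractor returns one (Wav2Vec2BertProcessor typically does). A more defensive variant, shown below as a sketch rather than what the commit does, is a drop-in replacement for the two inputs = ... lines that casts only floating-point tensors to the model's dtype:

    # Sketch: cast only floating-point tensors (e.g. input_features) to the
    # model's dtype; leave integer tensors such as attention_mask unchanged.
    inputs = processor(waveform.squeeze(0), sampling_rate=16000, return_tensors="pt")
    inputs = {
        key: val.to(device, dtype=model.dtype) if val.is_floating_point() else val.to(device)
        for key, val in inputs.items()
    }

This also keeps the cast in sync with whatever torch_dtype was passed to from_pretrained(), since model.dtype reflects the loaded weights.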
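
For a quick sanity check of the new code path, transcribe() can also be exercised directly, without launching the Gradio UI. A minimal sketch, assuming the file above is saved as app.py and "sample.wav" is a placeholder for any local audio file (not something shipped with this repo):

    # Sketch: call transcribe() directly. Importing app loads the processor
    # and model at module level, so the first import downloads the weights.
    from app import transcribe

    # Any WAV path works: transcribe() mixes stereo down to mono and
    # resamples to 16 kHz before inference.
    print(transcribe("sample.wav"))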