piecurus committed
Commit da0005f · 1 Parent(s): 540f7a6

added functionality for long text

Files changed (1)
  1. app.py +163 -42
app.py CHANGED
@@ -1,66 +1,187 @@
- #References: 1. https://www.kdnuggets.com/2021/03/speech-text-wav2vec.html
- #2. https://www.youtube.com/watch?v=4CoVcsxZphE
- #3. https://www.analyticsvidhya.com/blog/2021/02/hugging-face-introduces-the-first-automatic-speech-recognition-model-wav2vec2/

  #Importing all the necessary packages
  import nltk
  import librosa
  import torch
  import gradio as gr
  from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
  nltk.download("punkt")

  #Loading the model and the tokenizer
  model_name = "facebook/wav2vec2-base-960h"
- tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)
  model = Wav2Vec2ForCTC.from_pretrained(model_name)

  def load_data(input_file):
-     """ Function for resampling to ensure that the speech input is sampled at 16KHz.
-     """
-     #read the file
-     speech, sample_rate = librosa.load(input_file)
-     #make it 1-D
-     if len(speech.shape) > 1:
-         speech = speech[:,0] + speech[:,1]
-     #Resampling at 16KHz since wav2vec2-base-960h is pretrained and fine-tuned on speech audio sampled at 16 KHz.
-     if sample_rate != 16000:
-         speech = librosa.resample(speech, sample_rate, 16000)
-     return speech

  def correct_casing(input_sentence):
-     """ This function is for correcting the casing of the generated transcribed text
-     """
-     sentences = nltk.sent_tokenize(input_sentence)
-     return (' '.join([s.replace(s[0], s[0].capitalize(), 1) for s in sentences]))

  def asr_transcript(input_file):
-     """This function generates transcripts for the provided audio input
-     """
-     speech = load_data(input_file)
-
-     #Tokenize
-     input_values = tokenizer(speech, return_tensors="pt").input_values
-     #Take logits
-     logits = model(input_values).logits
-     #Take argmax
-     predicted_ids = torch.argmax(logits, dim=-1)
-     #Get the words from predicted word ids
-     transcription = tokenizer.decode(predicted_ids[0])
-     #Output is all upper case
-     transcription = correct_casing(transcription.lower())
-     return transcription

- gr.Interface(asr_transcript,
-              inputs = gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Please record your voice"),
               outputs = gr.outputs.Textbox(label="Output Text"),
               title="ASR using Wav2Vec 2.0",
               description = "This application displays transcribed text for given audio input",
               examples = [["Test_File1.wav"], ["Test_File2.wav"], ["Test_File3.wav"]], theme="grass").launch()

+ #!/usr/bin/env python
+ # coding: utf-8

+ # In[ ]:

+ # convert mp3 to wav; -ar 16000 resamples to 16 kHz (-b:a would set the bitrate, not the sample rate)
+ # ffmpeg -i test_5.mp3 -ar 16000 test_5.wav

+ # In[1]:

  #Importing all the necessary packages
  import nltk
  import librosa
+ import IPython.display
  import torch
  import gradio as gr
  from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
  nltk.download("punkt")

+ # In[2]:

  #Loading the model and the tokenizer
  model_name = "facebook/wav2vec2-base-960h"

+ #model_name = "facebook/wav2vec2-large-xlsr-53"
+ tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)
  model = Wav2Vec2ForCTC.from_pretrained(model_name)
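+ # Note (editor's aside): recent transformers versions expose Wav2Vec2Processor as the
+ # preferred front-end for wav2vec2 ASR; Wav2Vec2Tokenizer still works here.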

+ # In[3]:

  def load_data(input_file):
+     """ Function for resampling to ensure that the speech input is sampled at 16KHz.
+     """
+     #read the file
+     speech, sample_rate = librosa.load(input_file)
+     #make it 1-D
+     if len(speech.shape) > 1:
+         speech = speech[:,0] + speech[:,1]
+     #Resampling at 16KHz since wav2vec2-base-960h is pretrained and fine-tuned on speech audio sampled at 16 KHz.
+     if sample_rate != 16000:
+         speech = librosa.resample(speech, sample_rate, 16000)
+     #speeches = librosa.effects.split(speech)
+     return speech
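+ # Usage sketch: load_data("Test_File1.wav") returns a 1-D float array resampled to 16 kHz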

+ # In[4]:

  def correct_casing(input_sentence):
+     """ This function is for correcting the casing of the generated transcribed text
+     """
+     sentences = nltk.sent_tokenize(input_sentence)
+     return (' '.join([s.replace(s[0], s[0].capitalize(), 1) for s in sentences]))
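+ # e.g. correct_casing("hello there. how are you") -> "Hello there. How are you"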

+ # In[5]:

  def asr_transcript(input_file):
+     """This function generates transcripts for the provided audio input
+     """
+     speech = load_data(input_file)
+     #Tokenize
+     input_values = tokenizer(speech, return_tensors="pt").input_values
+     #Take logits
+     logits = model(input_values).logits
+     #Take argmax
+     predicted_ids = torch.argmax(logits, dim=-1)
+     #Get the words from predicted word ids
+     transcription = tokenizer.decode(predicted_ids[0])
+     #Output is all upper case
+     transcription = correct_casing(transcription.lower())
+     return transcription
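+ # Note: taking the argmax over the logits is greedy CTC decoding; tokenizer.decode
+ # then collapses repeated ids and blank tokens into the final text.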

+ # In[6]:

+ def asr_transcript_long(input_file, tokenizer=tokenizer, model=model):
+     transcript = ""
+     # Get the file's native sample rate (resampled to 16 kHz below if needed)
+     sample_rate = librosa.get_samplerate(input_file)
+
+     # Stream over 30-second chunks rather than load the full file
+     stream = librosa.stream(
+         input_file,
+         block_length=30,
+         frame_length=sample_rate,  #16000
+         hop_length=sample_rate,  #16000
+     )
+
+     for speech in stream:
+         if len(speech.shape) > 1:
+             speech = speech[:, 0] + speech[:, 1]
+         if sample_rate != 16000:
+             speech = librosa.resample(speech, sample_rate, 16000)
+         input_values = tokenizer(speech, return_tensors="pt").input_values
+         logits = model(input_values).logits
+
+         predicted_ids = torch.argmax(logits, dim=-1)
+         transcription = tokenizer.decode(predicted_ids[0])
+         #transcript += correct_sentence(transcription.lower())
+         transcript += correct_casing(transcription.lower())
+         transcript += " "
+
+     return transcript
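+ # Note: with frame_length = hop_length = sample_rate, each streamed block holds
+ # frame_length + (block_length - 1) * hop_length = 30 * sample_rate samples,
+ # i.e. roughly 30 seconds of audio per chunk.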

+ from pydub import AudioSegment
+ from pydub.silence import split_on_silence
+ from pydub.playback import play
+
+ sound = AudioSegment.from_file("./test_2.wav", format="wav")
+ chunks = split_on_silence(
+     sound,
+     # split on silences longer than 5000 ms (5 sec)
+     min_silence_len=5000,
+     # anything under -32 dBFS is considered silence
+     silence_thresh=-32,
+     # keep 500 ms of leading/trailing silence
+     keep_silence=500
+ )
+
+ #read the file
+ speech, sample_rate = librosa.load('./test_2.wav')
+ #make it 1-D
+ if len(speech.shape) > 1:
+     speech = speech[:,0] + speech[:,1]
+ #Resampling at 16KHz since wav2vec2-base-960h is pretrained and fine-tuned on speech audio sampled at 16 KHz.
+ if sample_rate != 16000:
+     speech = librosa.resample(speech, sample_rate, 16000)
+ part_of_speech = librosa.effects.split(speech)
+ idx = -1
+ IPython.display.Audio(data=speech[part_of_speech[idx, 0]:part_of_speech[idx, 1]], rate=16000)
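+ # Note: librosa.effects.split returns (start, end) sample intervals of non-silent
+ # audio; idx = -1 selects the last interval for playback.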

+ # In[ ]:

+ gr.Interface(asr_transcript_long,
+              #inputs = gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Please record your voice"),
+              inputs = gr.inputs.Audio(source="upload", type="filepath", optional=True, label="Upload your file here"),
               outputs = gr.outputs.Textbox(label="Output Text"),
               title="ASR using Wav2Vec 2.0",
               description = "This application displays transcribed text for given audio input",
               examples = [["Test_File1.wav"], ["Test_File2.wav"], ["Test_File3.wav"]], theme="grass").launch()

+ # In[7]:

+ #temp = asr_transcript_long('./test_2.wav')