minjibi commited on
Commit
64c7322
·
1 Parent(s): 2ab0960
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Importing all the necessary packages
2
+ import nltk
3
+ import librosa
4
+ import torch
5
+ import gradio as gr
6
+ from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
7
+ nltk.download("punkt")
8
+
9
+ #Loading the pre-trained model and the tokenizer
10
+ model_name = "shizukanabasho/north2"
11
+ tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)
12
+ model = Wav2Vec2ForCTC.from_pretrained(model_name)
13
+
14
+ def load_data(input_file):
15
+
16
+ #reading the file
17
+ speech, sample_rate = librosa.load(input_file)
18
+ #make it 1-D
19
+ if len(speech.shape) > 1:
20
+ speech = speech[:,0] + speech[:,1]
21
+ #Resampling the audio at 16KHz
22
+ if sample_rate !=16000:
23
+ speech = librosa.resample(speech, sample_rate,16000)
24
+ return speech
25
+
26
+ def correct_casing(input_sentence):
27
+
28
+ sentences = nltk.sent_tokenize(input_sentence)
29
+ return (''.join([s.replace(s[0],s[0].capitalize(),1) for s in sentences]))
30
+
31
+ def asr_transcript(input_file):
32
+
33
+ speech = load_data(input_file)
34
+ #Tokenize
35
+ input_values = tokenizer(speech, return_tensors="pt").input_values
36
+ #Take logits
37
+ logits = model(input_values).logits
38
+ #Take argmax
39
+ predicted_ids = torch.argmax(logits, dim=-1)
40
+ #Get the words from predicted word ids
41
+ transcription = tokenizer.decode(predicted_ids[0])
42
+ #Correcting the letter casing
43
+ # transcription = correct_casing(transcription.lower())
44
+ return transcription
45
+
46
+ gr.Interface(
47
+ asr_transcript,
48
+ inputs=[
49
+ gr.Audio(source="microphone", type="filepath", optional=True),
50
+ gr.Audio(source="upload", type="filepath", optional=True),
51
+ title="ASR using Wav2Vec2.0",
52
+ description = "This application displays transcribed text for given audio input",
53
+ ],
54
+ outputs="text",
55
+ ).launch()