Spaces: Runtime error
Pavankalyan committed · Commit 8d5928a · 1 Parent(s): e1522cf
Update app.py

app.py CHANGED
@@ -3,7 +3,7 @@ from model import Wav2VecModel
 from dataset import S2IDataset, collate_fn
 import requests
 requests.packages.urllib3.disable_warnings()
-
+import gradio as gr
 import torch
 import torch.nn as nn
 import torchaudio
@@ -103,72 +103,12 @@ class LightningModel(pl.LightningModule):
         return predicted_class
 
 if __name__ == "__main__":
-
-    dataset = S2IDataset(
-        csv_path="./speech-to-intent/train.csv",
-        wav_dir_path="/home/development/pavan/Telesoft/speech-to-intent-dataset/baselines/speech-to-intent"
-    )
-
-    test_dataset = S2IDataset(
-        csv_path="./speech-to-intent/test.csv",
-        wav_dir_path="/home/development/pavan/Telesoft/speech-to-intent-dataset/baselines/speech-to-intent"
-    )
-
-    train_len = int(len(dataset) * 0.90)
-    val_len = len(dataset) - train_len
-    print(train_len, val_len)
-    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_len, val_len], generator=torch.Generator().manual_seed(SEED))
-    print(len(test_dataset))
-
-    trainloader = torch.utils.data.DataLoader(
-        train_dataset,
-        batch_size=4,
-        shuffle=True,
-        num_workers=4,
-        collate_fn = collate_fn,
-    )
-
-    valloader = torch.utils.data.DataLoader(
-        val_dataset,
-        batch_size=4,
-        num_workers=4,
-        collate_fn = collate_fn,
-    )
-
-    testloader = torch.utils.data.DataLoader(
-        test_dataset,
-        #batch_size=4,
-        num_workers=4,
-        collate_fn = collate_fn,
-    )
-
     print(torch.cuda.mem_get_info())
 
     model = LightningModel()
 
     run_name = "wav2vec"
-    logger = WandbLogger(
-        name=run_name,
-        project='S2I-baseline'
-    )
-
-    model_checkpoint_callback = ModelCheckpoint(
-        dirpath='checkpoints',
-        monitor='val/acc',
-        mode='max',
-        verbose=1,
-        filename=run_name + "-epoch={epoch}.ckpt")
 
-    trainer = Trainer(
-        fast_dev_run=False,
-        gpus=1,
-        max_epochs=5,
-        checkpoint_callback=True,
-        callbacks=[
-            model_checkpoint_callback,
-        ],
-        logger=logger,
-    )
     checkpoint_path = "./checkpoints/wav2vec-epoch=epoch=4.ckpt.ckpt"
     checkpoint = torch.load(checkpoint_path)
     model.load_state_dict(checkpoint['state_dict'])
@@ -187,6 +127,29 @@ if __name__ == "__main__":
     #with torch.no_grad():
     #    y_hat = model(wav_tensor)
 
+    def transcribe(audio):
+        sample_rate, wav_tensor = audio  # gr.Audio (numpy) hands the callback (sample_rate, data)
+        wav_tensor = resample(wav_tensor)
+        #model = model.to('cuda')
+        y_hat = model.predict(wav_tensor)
+        labels = {0: "branch_address : enquiry about bank branch location",
+                  1: "activate_card : enquiry about activating card products",
+                  2: "past_transactions : enquiry about past transactions in a specific time period",
+                  3: "dispatch_status : enquiry about the dispatch status of card products",
+                  4: "outstanding_balance : enquiry about outstanding balance on card products",
+                  5: "card_issue : report about an issue with using card products",
+                  6: "ifsc_code : enquiry about IFSC code of bank branch",
+                  7: "generate_pin : enquiry about changing or generating a new pin for their card product",
+                  8: "unauthorised_transaction : report about an unauthorised or fraudulent transaction",
+                  9: "loan_query : enquiry about different kinds of loans",
+                  10: "balance_enquiry : enquiry about bank account balance",
+                  11: "change_limit : enquiry about changing the limit for card products",
+                  12: "block : enquiry about blocking card or banking product",
+                  13: "lost : report about losing a card product"}
+        return labels[y_hat]
+
     print(y_hat)
+    get_intent = gr.Interface(fn=transcribe,
+                              inputs=gr.Audio(source="microphone"), outputs="text").launch()
 
 
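Note: the new transcribe() callback calls resample(), which is not part of this diff and is presumably defined elsewhere in app.py. A minimal sketch of what such a helper might look like, assuming the wav2vec 2.0 model expects 16 kHz mono float input; the name, signature, and default rates here are assumptions, not the Space's actual code:

    import numpy as np
    import torch
    import torchaudio

    def resample(data, orig_freq=48000, new_freq=16000):
        # Convert the raw microphone array from gr.Audio into a
        # (1, time) float tensor at the model's expected sample rate.
        wav = torch.from_numpy(np.asarray(data)).float()
        if wav.dim() == 2:
            wav = wav.mean(dim=1)    # (time, channels) -> mono (time,)
        if wav.abs().max() > 1:
            wav = wav / 32768.0      # int16 PCM -> [-1.0, 1.0]
        wav = wav.unsqueeze(0)       # (time,) -> (1, time)
        return torchaudio.functional.resample(wav, orig_freq=orig_freq, new_freq=new_freq)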
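Note: the unchanged lines kept by this commit still call print(torch.cuda.mem_get_info()) and torch.load(checkpoint_path) without a map_location, both of which fail on CPU-only hardware, one plausible explanation for the "Runtime error" badge above. A device-safe sketch of the checkpoint load, assuming the same LightningModel and checkpoint path as in the diff:

    import torch

    model = LightningModel()  # as constructed earlier in app.py
    # Deserialize a checkpoint saved on a GPU machine onto CPU-only hardware.
    checkpoint_path = "./checkpoints/wav2vec-epoch=epoch=4.ckpt.ckpt"
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()  # inference mode: disables dropout, etc.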
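For reference, a sketch of how the pieces at the bottom of app.py are presumably meant to fit together, assuming gradio 3.x, where gr.Audio(source="microphone") delivers a (sample_rate, numpy_array) tuple to the callback, and with the labels dict hoisted to module scope; the leftover print(y_hat) kept as a context line in the diff would still raise a NameError at startup and likely needs removing in a follow-up commit:

    import gradio as gr

    def transcribe(audio):
        sample_rate, data = audio                  # gradio's numpy audio format
        wav_tensor = resample(data, orig_freq=sample_rate)
        y_hat = model.predict(wav_tensor)          # integer class id
        return labels[y_hat]                       # human-readable intent string

    # launch() serves the UI; on Spaces this is the app's entry point.
    gr.Interface(fn=transcribe,
                 inputs=gr.Audio(source="microphone"),
                 outputs="text").launch()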