# S2I / app.py
from model import Wav2VecModel
from dataset import S2IDataset, collate_fn

import requests
# silence urllib3's InsecureRequestWarning noise from unverified HTTPS requests
requests.packages.urllib3.disable_warnings()

import torch
import torch.nn as nn
import torchaudio

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

# SEED
SEED = 100
pl.seed_everything(SEED)  # seeds python, numpy and torch in one call
torch.manual_seed(SEED)

import os
os.environ['WANDB_MODE'] = 'online'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # make device IDs match nvidia-smi order
os.environ["CUDA_VISIBLE_DEVICES"] = "1"        # expose only GPU 1 to this script

class LightningModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = Wav2VecModel()

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return [optimizer]

    def loss_fn(self, prediction, targets):
        return nn.CrossEntropyLoss()(prediction, targets)
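    # Note: nn.CrossEntropyLoss applies log-softmax internally, so it must be
    # fed raw logits (not probabilities) together with integer class indices.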

    def training_step(self, batch, batch_idx):
        x, y = batch
        y = y.view(-1)
        logits = self(x)
        loss = self.loss_fn(logits, y)
        winners = logits.argmax(dim=1)
        corrects = (winners == y)
        acc = corrects.sum().float() / float(logits.size(0))
        self.log('train/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
        # releases cached GPU memory back to the driver (frees no live tensors)
        torch.cuda.empty_cache()
        return {'loss': loss, 'acc': acc}
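    # With on_step=False / on_epoch=True, Lightning accumulates the per-batch
    # loss and accuracy and logs a single averaged value at the end of each epoch.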

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y = y.view(-1)
        logits = self(x)
        loss = self.loss_fn(logits, y)
        winners = logits.argmax(dim=1)
        corrects = (winners == y)
        acc = corrects.sum().float() / float(logits.size(0))
        self.log('val/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
        return {'val_loss': loss, 'val_acc': acc}

    def test_step(self, batch, batch_idx):
        x, y = batch
        y = y.view(-1)
        logits = self(x)
        loss = self.loss_fn(logits, y)
        winners = logits.argmax(dim=1)
        corrects = (winners == y)
        acc = corrects.sum().float() / float(logits.size(0))
        # log under test/* so test metrics don't overwrite the validation ones
        self.log('test/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('test/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
        return {'test_loss': loss, 'test_acc': acc}

    def predict(self, wav):
        self.eval()
        with torch.no_grad():
            # move the waveform onto the model's device so a CPU tensor still
            # works after the model has been moved to the GPU
            output = self.forward(wav.to(self.device))
            predicted_class = torch.argmax(output, dim=1)
        return predicted_class
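    # Assumption: `predict` expects the same input format the dataloaders feed
    # to forward(), i.e. a mono 16 kHz waveform tensor of shape
    # (1, num_samples), and returns the index of the winning intent class.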

if __name__ == "__main__":
    dataset = S2IDataset(
        csv_path="./speech-to-intent/train.csv",
        wav_dir_path="/home/development/pavan/Telesoft/speech-to-intent-dataset/baselines/speech-to-intent"
    )
    test_dataset = S2IDataset(
        csv_path="./speech-to-intent/test.csv",
        wav_dir_path="/home/development/pavan/Telesoft/speech-to-intent-dataset/baselines/speech-to-intent"
    )

    # reproducible 90/10 train/validation split, seeded like everything else
    train_len = int(len(dataset) * 0.90)
    val_len = len(dataset) - train_len
    print(train_len, val_len)
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_len, val_len], generator=torch.Generator().manual_seed(SEED)
    )
    print(len(test_dataset))

    trainloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=4,
        shuffle=True,
        num_workers=4,
        collate_fn=collate_fn,
    )
    valloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=4,
        num_workers=4,
        collate_fn=collate_fn,
    )
    testloader = torch.utils.data.DataLoader(
        test_dataset,
        # batch_size=4,  # left commented out, so the test loader uses the default batch_size=1
        num_workers=4,
        collate_fn=collate_fn,
    )
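    # Assumption: collate_fn (from dataset.py) pads the variable-length
    # waveforms in a batch to a common length; without it, stacking raw
    # utterances of different durations into one tensor would fail.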

    print(torch.cuda.mem_get_info())

    model = LightningModel()

    run_name = "wav2vec"
    logger = WandbLogger(
        name=run_name,
        project='S2I-baseline'
    )

    model_checkpoint_callback = ModelCheckpoint(
        dirpath='checkpoints',
        monitor='val/acc',
        mode='max',
        verbose=True,
        filename=run_name + "-epoch={epoch}.ckpt",
    )
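    # Note: Lightning expands "{epoch}" in the template to "epoch=<n>" and
    # then appends its own ".ckpt" extension, so the files on disk come out
    # as e.g. "wav2vec-epoch=epoch=4.ckpt.ckpt", which is why checkpoint_path
    # below carries the doubled suffix.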

    trainer = Trainer(
        fast_dev_run=False,
        gpus=1,
        max_epochs=5,
        checkpoint_callback=True,
        callbacks=[
            model_checkpoint_callback,
        ],
        logger=logger,
    )

    # load the weights saved after the final training epoch
    checkpoint_path = "./checkpoints/wav2vec-epoch=epoch=4.ckpt.ckpt"
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['state_dict'])

    # a second, bare Trainer replaces the training one; from here on the
    # script only runs inference, so no logger or callbacks are needed
    trainer = Trainer(
        gpus=1
    )

    # trainer.fit(model, train_dataloader=trainloader, val_dataloaders=valloader)
    # trainer.test(model, dataloaders=testloader, verbose=True)

    # quick single-file sanity check: the source wav is 8 kHz and is resampled
    # to the 16 kHz rate that wav2vec 2.0 models are trained on
    wav_path = "./speech-to-intent/wav_audios/92145547-3ab6-44e0-9245-085642fc4318.wav"
    resample = torchaudio.transforms.Resample(8000, 16000)
    wav_tensor, _ = torchaudio.load(wav_path)
    wav_tensor = resample(wav_tensor)

    model = model.to('cuda')
    y_hat = model.predict(wav_tensor)
    # with torch.no_grad():
    #     y_hat = model(wav_tensor)
    print(y_hat)
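    # To map the predicted index back to an intent string, look it up in
    # whatever label encoding S2IDataset applied when reading train.csv
    # (an assumption: the mapping lives in dataset.py, not in this file),
    # e.g. something like: intent = id2label[int(y_hat)]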