# Spaces: Runtime error
from model import Wav2VecModel | |
from dataset import S2IDataset, collate_fn | |
import requests | |
requests.packages.urllib3.disable_warnings() | |
import torch | |
import torch.nn as nn | |
import torchaudio | |
import torch.nn.functional as F | |
import pytorch_lightning as pl | |
from pytorch_lightning import Trainer | |
from pytorch_lightning.callbacks import ModelCheckpoint | |
from pytorch_lightning.loggers import WandbLogger | |
# Runtime environment and reproducibility setup.
import os

# Configure W&B and GPU selection first so the variables are in place before
# anything initializes CUDA or starts a W&B run.
os.environ['WANDB_MODE'] = 'online'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# SEED
SEED = 100
# Use the public entry point rather than the internal
# ``pl.utilities.seed`` module path. ``seed_everything`` also seeds torch,
# so a separate ``torch.manual_seed(SEED)`` call is not needed.
pl.seed_everything(SEED)
class LightningModel(pl.LightningModule):
    """Lightning wrapper that trains and evaluates ``Wav2VecModel`` as a classifier.

    Batches are ``(x, y)`` pairs. ``y`` is flattened to 1-D class indices and
    the model output is treated as logits with the class axis at ``dim=1``.
    # NOTE(review): the exact shape of ``x`` depends on ``collate_fn`` and
    # ``Wav2VecModel``, neither of which is visible here -- confirm.
    """

    def __init__(self,):
        super().__init__()
        self.model = Wav2VecModel()

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def configure_optimizers(self):
        """Return a single Adam optimizer over all parameters (lr=1e-5)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return [optimizer]

    def loss_fn(self, prediction, targets):
        """Return the cross-entropy loss between logits and class indices."""
        # Functional form: avoids building a new nn.CrossEntropyLoss module
        # on every call while computing the same default-reduction loss.
        return F.cross_entropy(prediction, targets)

    def _shared_step(self, batch):
        """Compute ``(loss, accuracy)`` for one batch; shared by all step hooks."""
        x, y = batch
        y = y.view(-1)
        logits = self(x)
        loss = self.loss_fn(logits, y)
        predictions = logits.argmax(dim=1)
        # Fraction of samples in the batch whose arg-max class equals the target.
        acc = (predictions == y).sum().float() / float(logits.size(0))
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Run one training batch and log epoch-level ``train/loss`` and ``train/acc``."""
        loss, acc = self._shared_step(batch)
        self.log('train/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
        # Kept from the original implementation.
        # NOTE(review): presumably added to relieve GPU memory pressure;
        # calling it every step slows training -- confirm it is still needed.
        torch.cuda.empty_cache()
        return {
            'loss': loss,
            'acc': acc,
        }

    def validation_step(self, batch, batch_idx):
        """Run one validation batch and log epoch-level ``val/loss`` and ``val/acc``."""
        loss, acc = self._shared_step(batch)
        self.log('val/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
        return {
            'val_loss': loss,
            'val_acc': acc,
        }

    def test_step(self, batch, batch_idx):
        """Run one test batch and log epoch-level ``test/loss`` and ``test/acc``.

        Test metrics get their own ``test/*`` keys so they do not overwrite
        the ``val/*`` metrics logged by ``validation_step``.
        """
        loss, acc = self._shared_step(batch)
        self.log('test/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('test/acc', acc, on_step=False, on_epoch=True, prog_bar=True)
        return {
            'test_loss': loss,
            'test_acc': acc,
        }

    def predict(self, wav):
        """Return the arg-max class index (over ``dim=1``) for ``wav``.

        Switches the module to eval mode (it is left in eval mode afterwards)
        and runs the forward pass without gradient tracking. Tensor inputs are
        moved to the module's device first, so a CPU tensor can be passed to a
        model that lives on the GPU.
        """
        self.eval()
        if torch.is_tensor(wav):
            wav = wav.to(self.device)
        with torch.no_grad():
            output = self(wav)
        return torch.argmax(output, dim=1)
if __name__ == "__main__":
    # Script entry point: build the datasets/loaders, restore a trained
    # checkpoint, and classify a single wav file.

    # Both datasets read audio from the same directory.
    wav_dir_path = "/home/development/pavan/Telesoft/speech-to-intent-dataset/baselines/speech-to-intent"

    dataset = S2IDataset(
        csv_path="./speech-to-intent/train.csv",
        wav_dir_path=wav_dir_path,
    )
    test_dataset = S2IDataset(
        csv_path="./speech-to-intent/test.csv",
        wav_dir_path=wav_dir_path,
    )

    # 90/10 train/validation split, seeded so it is reproducible.
    train_len = int(len(dataset) * 0.90)
    val_len = len(dataset) - train_len
    print(train_len, val_len)
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset,
        [train_len, val_len],
        generator=torch.Generator().manual_seed(SEED),
    )
    print(len(test_dataset))

    trainloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=4,
        shuffle=True,
        num_workers=4,
        collate_fn=collate_fn,
    )
    valloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=4,
        num_workers=4,
        collate_fn=collate_fn,
    )
    # No batch_size given, so the DataLoader default (1) applies.
    testloader = torch.utils.data.DataLoader(
        test_dataset,
        num_workers=4,
        collate_fn=collate_fn,
    )

    print(torch.cuda.mem_get_info())

    model = LightningModel()

    run_name = "wav2vec"
    logger = WandbLogger(
        name=run_name,
        project='S2I-baseline',
    )
    # NOTE: Lightning formats "{epoch}" as "epoch=<n>" and appends ".ckpt"
    # itself, so this template yields names like
    # "wav2vec-epoch=epoch=4.ckpt.ckpt". The template is left unchanged
    # because checkpoint_path below refers to a file saved under that name.
    model_checkpoint_callback = ModelCheckpoint(
        dirpath='checkpoints',
        monitor='val/acc',
        mode='max',
        verbose=True,
        filename=run_name + "-epoch={epoch}.ckpt",
    )
    trainer = Trainer(
        fast_dev_run=False,
        gpus=1,
        max_epochs=5,
        checkpoint_callback=True,
        callbacks=[
            model_checkpoint_callback,
        ],
        logger=logger,
    )

    # Restore trained weights. Loading onto the CPU first makes this
    # independent of the device the checkpoint was saved from; the model is
    # moved to the GPU further down.
    checkpoint_path = "./checkpoints/wav2vec-epoch=epoch=4.ckpt.ckpt"
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint['state_dict'])

    # NOTE(review): this rebinding discards the configured trainer above. If
    # fit is re-enabled below it would run without the logger and checkpoint
    # callback -- confirm this is intended.
    trainer = Trainer(
        gpus=1
    )
    #trainer.fit(model, train_dataloader=trainloader, val_dataloaders=valloader)
    #trainer.test(model,dataloaders=testloader,verbose=True)

    # Single-file inference.
    wav_path = "./speech-to-intent/wav_audios/92145547-3ab6-44e0-9245-085642fc4318.wav"
    target_sample_rate = 16000
    wav_tensor, sample_rate = torchaudio.load(wav_path)
    # Resample from the file's actual rate instead of assuming 8 kHz; for
    # 8 kHz input this is the same 8000 -> 16000 conversion as before.
    if sample_rate != target_sample_rate:
        resample = torchaudio.transforms.Resample(sample_rate, target_sample_rate)
        wav_tensor = resample(wav_tensor)

    model = model.to('cuda')
    # The input must be on the same device as the model; otherwise the
    # forward pass raises a device-mismatch RuntimeError.
    wav_tensor = wav_tensor.to(model.device)
    y_hat = model.predict(wav_tensor)
    print(y_hat)