In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('..')

In [35]:
import os
import matplotlib.pyplot as plt
from datetime import datetime
from IPython.display import Audio


# torch
import torch
import torchaudio
from torch.utils.data import DataLoader

# model training
from cnn import CNNetwork
from dataset import VoiceDataset
from train import train, validate_epoch

# api
from server.preprocess import process_from_url, librosa, wget

In [79]:
BATCH_SIZE = 128
EPOCHS = 10
LEARNING_RATE = 0.001

TRAIN_PATH = "../data/aisf/train"
TEST_PATH = "../data/aisf/test"
DEV_PATH = "../data/aisf/dev"
SAMPLE_RATE=48000

MEL_SPEC = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=2048,
    hop_length=512,
    n_mels=128,
)

# Dataset

In [99]:
torch.set_printoptions(profile="default")

In [6]:
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device.")

Using cpu device.


In [8]:
len(train_dataset)

21

In [80]:
# Datasets
train_dataset = VoiceDataset(TRAIN_PATH, MEL_SPEC, device, SAMPLE_RATE, time_limit_in_secs=3)
test_dataset = VoiceDataset(TEST_PATH, MEL_SPEC, device, SAMPLE_RATE, time_limit_in_secs=3)
dev_dataset = VoiceDataset(DEV_PATH, MEL_SPEC, device, SAMPLE_RATE, time_limit_in_secs=3)

# Dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)
dev_dataloader = DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=True)

model = CNNetwork().to(device)

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

history = train(
    model,
    test_dataloader,
    loss_fn,
    optimizer,
    device,
    10,
#     test_dataloader
)

Epoch 1/10


Training batch...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:37<00:00,  2.20s/it]


Training Loss: 1.18, Training Accuracy  0.37722854552780016
-------------------------------------------- 

Epoch 2/10


Training batch...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:33<00:00,  2.00s/it]


Training Loss: 1.17, Training Accuracy  0.3792871172441579
-------------------------------------------- 

Epoch 3/10


Training batch...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:36<00:00,  2.13s/it]


Training Loss: 1.17, Training Accuracy  0.3792871172441579
-------------------------------------------- 

Epoch 4/10


Training batch...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:34<00:00,  2.02s/it]


Training Loss: 1.17, Training Accuracy  0.3827495467365028
-------------------------------------------- 

Epoch 5/10


Training batch...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:35<00:00,  2.08s/it]


Training Loss: 1.17, Training Accuracy  0.3792871172441579
-------------------------------------------- 

Epoch 6/10


Training batch...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:32<00:00,  1.93s/it]


Training Loss: 1.17, Training Accuracy  0.3806720890410959
-------------------------------------------- 

Epoch 7/10


Training batch...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:32<00:00,  1.89s/it]


Training Loss: 1.17, Training Accuracy  0.3813645749395649
-------------------------------------------- 

Epoch 8/10


Training batch...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:34<00:00,  2.02s/it]


Training Loss: 1.17, Training Accuracy  0.3813645749395649
-------------------------------------------- 

Epoch 9/10


Training batch...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:35<00:00,  2.10s/it]


Training Loss: 1.17, Training Accuracy  0.3803258460918614
-------------------------------------------- 

Epoch 10/10


Training batch...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:33<00:00,  1.99s/it]

Training Loss: 1.17, Training Accuracy  0.38101833199033036
-------------------------------------------- 

---- Finished Training ----





In [69]:
validate_epoch(model, dev_dataloader, loss_fn, device)

Testing batch...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  5.97it/s]


(0.5514446496963501, 1.0)

In [70]:
now = datetime.now()
now = now.strftime("%Y%m%d_%H%M%S")
model_filename = f"../models/aisf/void_{now}.pth"
torch.save(model.state_dict(), model_filename)
print(f"Trained void model saved at {model_filename}")

Trained void model saved at ../models/aisf/void_20230517_113634.pth


In [61]:
url = "https://cdn.filestackcontent.com/i4bxGOmPSkCp7nIYrBFw"
filename = wget.download(url)
audio, sr = librosa.load(filename)

  0% [                                                                            ]      0 / 116590  7% [.....                                                                       ]   8192 / 116590 14% [..........                                                                  ]  16384 / 116590 21% [................                                                            ]  24576 / 116590 28% [.....................                                                       ]  32768 / 116590 35% [..........................                                                  ]  40960 / 116590 42% [................................                                            ]  49152 / 116590 49% [.....................................                                       ]  57344 / 116590 56% [..........................................                                  ]  65536 / 116590 63% [................................................                            ]  73728 / 116590

In [None]:
display(Audio(audio, rate=sr))

In [65]:
s = process_from_url(url)

  0% [                                                                            ]      0 / 116590  7% [.....                                                                       ]   8192 / 116590 14% [..........                                                                  ]  16384 / 116590 21% [................                                                            ]  24576 / 116590 28% [.....................                                                       ]  32768 / 116590 35% [..........................                                                  ]  40960 / 116590 42% [................................                                            ]  49152 / 116590 49% [.....................................                                       ]  57344 / 116590 56% [..........................................                                  ]  65536 / 116590 63% [................................................                            ]  73728 / 116590

In [66]:
inp = s.unsqueeze(0)
model(inp)

tensor([[0., 0., 1.]], grad_fn=<SoftmaxBackward0>)

In [52]:
len(test_dataset)

2121

In [90]:
model = CNNetwork()
sd = torch.load('../models/aisf/void_20230517_115313.pth')
model.load_state_dict(sd)

i = 2100
ex = test_dataset[i][0]
expected = test_dataset[i][1]
ex = ex.unsqueeze(0)
print(model(ex))
print(expected)


tensor([[0., 0., 1.]], grad_fn=<SoftmaxBackward0>)
2


In [95]:
og = test_dataset[0][0]

In [96]:
from server.preprocess import process_from_url

In [108]:
wav = process_from_url("https://cdn.filestackcontent.com/A6gvnfdrQyWG3sX5G6Tz")

  0% [                                                                            ]      0 / 280670  2% [..                                                                          ]   8192 / 280670  5% [....                                                                        ]  16384 / 280670  8% [......                                                                      ]  24576 / 280670 11% [........                                                                    ]  32768 / 280670 14% [...........                                                                 ]  40960 / 280670 17% [.............                                                               ]  49152 / 280670 20% [...............                                                             ]  57344 / 280670 23% [.................                                                           ]  65536 / 280670 26% [...................                                                         ]  73728 / 280670

UnpicklingError: unpickling stack underflow

In [109]:
og == wav

tensor([[[False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, Fa

In [100]:
torch.set_printoptions(profile="full")

In [110]:
og

tensor([[[0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9071e-07, 5.4189e-05,
          6.9105e-04, 9.4122e-04, 4.7590e-04, 5.3411e-04, 2.1923e-03,
          1.0346e-02, 2.1014e-02, 8.5096e-03, 1.6077e-03, 7.7999e-02,
          1.4926e-01, 5.6634e-02, 6.2833e-02, 6.1681e-02, 1.8395e-02,
          2.0951e-03, 1.8622e-03, 2.3255e-02, 4.4418e-02, 8.3534e-02,
          4.3211e-02, 3.8408e-03, 5.1468e-03, 1.2562e-02, 2.0300e-02,
          1.7053e-02, 1.8022e-02, 1.9366e-02, 2.6519e-02, 4.5200e-02,
          4.5033e-02, 5.8789e-02, 1.0408e-01, 3.8676e-01, 2.5994e-01,
          2.5000e-01, 2.2123e-01, 1.5145e-01, 3.8525e-02, 1.4709e-02,
          7.5309e-04, 4.2539e-02, 1.0235e-01, 1.0290e-01, 6.2771e-02,
          7.1301e-02, 1.5755e-01, 1.6258e-01, 5.9196e-02, 4.5670e-02,
          5.0529e-02, 1.3353e-01, 5.6818e-02, 9.7458e-02, 3.2249e-02,
          2.1035e-02, 2.1370e-02, 3.3990e-02, 2.5314e-02, 2.8278e-02,
          6.0121e-02, 1.5050e-02, 2.2090e-02, 6.5799e-02, 3.2295e-02,
          4.8375e-03

In [111]:
wav

tensor([[[2.9502e-14, 1.1655e-06, 2.8426e-05, 6.6369e-05, 4.8865e-05,
          5.5627e-04, 1.6141e-03, 1.1643e-03, 4.9479e-04, 8.1252e-04,
          6.3261e-04, 7.8284e-04, 1.1501e-03, 1.5271e-03, 8.0774e-04,
          4.1401e-04, 1.3984e-03, 2.1693e-04, 1.9381e-03, 1.1178e-03,
          1.0962e-03, 4.1937e-04, 5.9588e-05, 5.9074e-04, 7.9317e-05,
          1.2123e-03, 1.3510e-03, 1.4369e-03, 1.4651e-04, 8.8809e-04,
          2.5105e-05, 2.4291e-03, 3.2386e-04, 2.1175e-03, 2.2902e-04,
          4.5842e-04, 3.8593e-03, 7.3141e-03, 8.9032e-04, 2.5558e-03,
          6.2424e-04, 1.2219e-03, 9.1645e-04, 4.6009e-03, 4.8670e-03,
          8.5113e-04, 1.0307e-03, 3.6066e-03, 9.0430e-03, 1.3521e-02,
          1.8816e-02, 2.3763e-03, 1.8221e-02, 3.9029e-02, 4.7335e-02,
          1.0348e-02, 1.7014e-02, 2.4003e-02, 1.7942e-02, 3.5550e-03,
          9.5289e-04, 1.3781e-03, 1.4569e-03, 7.9573e-04, 1.3278e-03,
          2.0630e-03, 9.6402e-04, 9.1164e-05, 4.6493e-03, 1.0668e-03,
          7.6707e-03

In [113]:
def foo():
    for i in range(len(dev_dataset)):
        if torch.all(dev_dataset[i][0] == wav):
            return True

    return False

foo()

False