import os
from torch.utils.tensorboard import SummaryWriter
import torch
from torch.nn import Sequential
from torch.utils.data import DataLoader
from datasets import Valentini
from datetime import datetime
from torchvision.transforms import RandomCrop
from utils import load_wav
from denoisers.demucs import Demucs
from pathlib import Path
# Pin training to the second GPU before torch initializes CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Demucs denoiser; H is the hidden-channel width of the first encoder layer.
model = Demucs(H=64).to(device)
# Root of the Valentini (DS_10283_2791) noisy/clean speech dataset.
DATASET_PATH = Path('/media/public/datasets/denoising/DS_10283_2791/')
# Fixed reference clips from the noisy test set, logged to TensorBoard so
# model predictions can be compared across epochs by perceived difficulty.
VALID_WAVS = {'hard': 'p257_171.wav',
'medium': 'p232_071.wav',
'easy': 'p232_284.wav'}
# Every training example is randomly cropped/padded to a fixed 3.2 s window.
MAX_SECONDS = 3.2
SAMPLE_RATE = 16000
transform = Sequential(RandomCrop((1, int(MAX_SECONDS * SAMPLE_RATE)), pad_if_needed=True))
training_loader = DataLoader(Valentini(valid=False, transform=transform), batch_size=12, shuffle=True)
validation_loader = DataLoader(Valentini(valid=True, transform=transform), batch_size=12, shuffle=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Plain sample-wise MSE between denoised output and clean target waveform.
loss_fn = torch.nn.MSELoss()
def train_one_epoch(epoch_index, tb_writer):
    """Run one full pass over ``training_loader``, optimizing ``model``.

    Args:
        epoch_index: Zero-based epoch counter, used to compute the global
            step for TensorBoard scalars.
        tb_writer: ``SummaryWriter`` receiving the per-window training loss.

    Returns:
        The average per-batch loss of the last reporting window (0. if the
        loader produced fewer than 1000 batches).
    """
    running_loss = 0.
    last_loss = 0.
    for i, data in enumerate(training_loader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Report every 1000 batches.
        if i % 1000 == 999:
            # BUGFIX: the window is 1000 batches, so divide by 1000
            # (previously /100, inflating the reported loss 10x).
            last_loss = running_loss / 1000  # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
    return last_loss
def train():
    """Train the denoiser for a fixed number of epochs with validation.

    Each epoch: one optimization pass over the training set, a no-grad MSE
    evaluation over the validation set, TensorBoard logging of losses and
    denoised reference clips, and a checkpoint whenever the validation loss
    improves.
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    writer = SummaryWriter('runs/denoising_trainer_{}'.format(timestamp))
    epoch_number = 0
    EPOCHS = 5
    best_vloss = 1_000_000.
    # Log the raw noisy reference clips once, so per-epoch predictions
    # below can be compared against the unprocessed input.
    for tag, wav_path in VALID_WAVS.items():
        wav = load_wav(DATASET_PATH / 'noisy_testset_wav' / wav_path)
        writer.add_audio(tag=tag, snd_tensor=wav, sample_rate=SAMPLE_RATE)
        writer.flush()
    for epoch in range(EPOCHS):
        print('EPOCH {}:'.format(epoch_number + 1))
        # Make sure gradient tracking is on, and do a pass over the data
        model.train(True)
        avg_loss = train_one_epoch(epoch_number, writer)
        # Switch to inference mode for reporting (disables dropout/batchnorm updates).
        model.eval()
        running_vloss = 0.0
        num_vbatches = 0
        with torch.no_grad():
            for vdata in validation_loader:
                vinputs, vlabels = vdata
                vinputs, vlabels = vinputs.to(device), vlabels.to(device)
                voutputs = model(vinputs)
                vloss = loss_fn(voutputs, vlabels)
                # .item() keeps the accumulator a plain float and avoids
                # retaining the loss tensors across iterations.
                running_vloss += vloss.item()
                num_vbatches += 1
        # max(..., 1) guards against an empty validation loader (the
        # original referenced the loop index after the loop: NameError).
        avg_vloss = running_vloss / max(num_vbatches, 1)
        print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
        writer.add_scalars('Training vs. Validation Loss',
            {'Training': avg_loss, 'Validation': avg_vloss},
            epoch_number + 1)
        # Log the model's denoised version of each reference clip.
        for tag, wav_path in VALID_WAVS.items():
            wav = load_wav(DATASET_PATH / 'noisy_testset_wav' / wav_path)
            wav = torch.reshape(wav, (1, 1, -1)).to(device)
            prediction = model(wav)
            writer.add_audio(tag=f"Model predicted {tag} on epoch {epoch}",
                snd_tensor=prediction,
                sample_rate=SAMPLE_RATE)
            writer.flush()
        # Checkpoint only when validation improves.
        if avg_vloss < best_vloss:
            best_vloss = avg_vloss
            # torch.save does not create directories; ensure it exists.
            os.makedirs('checkpoints', exist_ok=True)
            model_path = 'checkpoints/model_{}_{}'.format(timestamp, epoch_number)
            torch.save(model.state_dict(), model_path)
        epoch_number += 1
# Script entry point: run the full training loop.
if __name__ == '__main__':
    train()