# denoising/train.py
# Training script for the Demucs denoiser on the Valentini dataset.
# (source: BorisovMaksim, commit bd0a813 — "fixes")
import os
from torch.utils.tensorboard import SummaryWriter
import torch
from torch.nn import Sequential
from torch.utils.data import DataLoader
from datasets import Valentini
from datetime import datetime
from torchvision.transforms import RandomCrop
from utils import load_wav
from denoisers.demucs import Demucs
from pathlib import Path
# Pin training to the second GPU; must be set before CUDA is initialized.
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Demucs denoiser with hidden size 64 — assumes H is the channel width; TODO confirm against denoisers/demucs.py.
model = Demucs(H=64).to(device)
# Root of the Valentini (DS_10283_2791) noisy/clean speech corpus.
DATASET_PATH = Path('/media/public/datasets/denoising/DS_10283_2791/')
# Fixed reference clips (by subjective difficulty) logged to TensorBoard each epoch.
VALID_WAVS = {'hard': 'p257_171.wav',
'medium': 'p232_071.wav',
'easy': 'p232_284.wav'}
# Random-crop every sample to 3.2 s @ 16 kHz = 51 200 samples, padding shorter clips.
MAX_SECONDS = 3.2
SAMPLE_RATE = 16000
transform = Sequential(RandomCrop((1, int(MAX_SECONDS * SAMPLE_RATE)), pad_if_needed=True))
training_loader = DataLoader(Valentini(valid=False, transform=transform), batch_size=12, shuffle=True)
validation_loader = DataLoader(Valentini(valid=True, transform=transform), batch_size=12, shuffle=True)
# Plain SGD with momentum; waveform regression uses MSE between output and clean target.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
loss_fn = torch.nn.MSELoss()
def train_one_epoch(epoch_index, tb_writer):
    """Run one training pass over ``training_loader``.

    Every 1000 batches, logs the mean per-batch training loss to
    TensorBoard under ``Loss/train`` at a global step of
    ``epoch_index * len(training_loader) + batch``.

    Args:
        epoch_index: Zero-based epoch number, used for the global step.
        tb_writer: ``SummaryWriter`` receiving the scalar logs.

    Returns:
        The most recently logged mean loss (0.0 if fewer than 1000
        batches were processed).
    """
    running_loss = 0.
    last_loss = 0.
    for i, data in enumerate(training_loader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 1000 == 999:
            # BUG FIX: the accumulation window is 1000 batches (reset below),
            # so the mean must divide by 1000, not 100 — the original
            # over-reported the per-batch loss by 10x.
            last_loss = running_loss / 1000  # mean loss per batch
            print(' batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
    return last_loss
def train():
    """Train the denoiser for a fixed number of epochs.

    Side effects: writes TensorBoard logs under ``runs/denoising_trainer_<ts>``
    (scalars, reference audio, per-epoch model predictions) and saves the
    state dict of the best model — lowest validation loss so far — under
    ``checkpoints/``.
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    writer = SummaryWriter('runs/denoising_trainer_{}'.format(timestamp))
    epoch_number = 0
    EPOCHS = 5
    best_vloss = 1_000_000.

    # FIX: torch.save does not create directories; without this the first
    # checkpoint save crashes when checkpoints/ is missing.
    os.makedirs('checkpoints', exist_ok=True)

    # Log the raw noisy reference clips once, before training starts.
    for tag, wav_path in VALID_WAVS.items():
        wav = load_wav(DATASET_PATH / 'noisy_testset_wav' / wav_path)
        writer.add_audio(tag=tag, snd_tensor=wav, sample_rate=SAMPLE_RATE)
    writer.flush()

    for epoch in range(EPOCHS):
        print('EPOCH {}:'.format(epoch_number + 1))
        # Make sure gradient tracking is on, and do a pass over the data.
        model.train(True)
        avg_loss = train_one_epoch(epoch_number, writer)

        # Evaluation mode for reporting (idiomatic equivalent of train(False)).
        model.eval()
        running_vloss = 0.0
        num_vbatches = 0
        with torch.no_grad():
            for vdata in validation_loader:
                vinputs, vlabels = vdata
                vinputs, vlabels = vinputs.to(device), vlabels.to(device)
                voutputs = model(vinputs)
                # FIX: accumulate a plain float (.item()) instead of a tensor,
                # and count batches explicitly — the original divided by the
                # leaked loop variable `i`, a NameError on an empty loader.
                running_vloss += loss_fn(voutputs, vlabels).item()
                num_vbatches += 1
        avg_vloss = running_vloss / max(num_vbatches, 1)

        print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
        writer.add_scalars('Training vs. Validation Loss',
                           {'Training': avg_loss, 'Validation': avg_vloss},
                           epoch_number + 1)

        # Log the model's denoised output for each reference clip.
        # FIX: run under no_grad — the original built (and retained) autograd
        # graphs for these forward passes, which are only used for logging.
        with torch.no_grad():
            for tag, wav_path in VALID_WAVS.items():
                wav = load_wav(DATASET_PATH / 'noisy_testset_wav' / wav_path)
                wav = torch.reshape(wav, (1, 1, -1)).to(device)
                prediction = model(wav)
                writer.add_audio(tag=f"Model predicted {tag} on epoch {epoch}",
                                 snd_tensor=prediction,
                                 sample_rate=SAMPLE_RATE)
        writer.flush()

        # Keep only the best-so-far model by validation loss.
        if avg_vloss < best_vloss:
            best_vloss = avg_vloss
            model_path = 'checkpoints/model_{}_{}'.format(timestamp, epoch_number)
            torch.save(model.state_dict(), model_path)
        epoch_number += 1
# Script entry point: start training when executed directly.
if __name__ == '__main__':
    train()