File size: 4,166 Bytes
bd0a813
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
from torch.utils.tensorboard import SummaryWriter
import torch
from torch.nn import Sequential
from torch.utils.data import DataLoader
from datasets import Valentini
from datetime import datetime
from torchvision.transforms import RandomCrop
from utils import load_wav
from denoisers.demucs import Demucs
from pathlib import Path

# Expose only the GPU with index 1 to this process. The variable is set before
# the first CUDA query below, so 'cuda' then refers to that single device.
os.environ['CUDA_VISIBLE_DEVICES'] = "1"

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Network being trained. H=64 is forwarded to Demucs unchanged
# (presumably the hidden channel width -- confirm in denoisers/demucs).
model = Demucs(H=64).to(device)

# Root directory of the dataset on disk; train() reads the reference clips
# from its 'noisy_testset_wav' sub-directory.
DATASET_PATH = Path('/media/public/datasets/denoising/DS_10283_2791/')
# Fixed clips, keyed by a difficulty tag, that train() logs to TensorBoard as
# raw audio and again as model predictions after every epoch.
VALID_WAVS = {'hard': 'p257_171.wav',
              'medium': 'p232_071.wav',
              'easy': 'p232_284.wav'}
# Length of each training/validation excerpt: 3.2 s at 16 kHz = 51200 samples.
MAX_SECONDS = 3.2
SAMPLE_RATE = 16000

# Randomly crop every sample to a (1, 51200) window, padding inputs that are
# shorter than the window.
# NOTE(review): assumes each dataset item is laid out as (1, num_samples) -- confirm.
transform = Sequential(RandomCrop((1, int(MAX_SECONDS * SAMPLE_RATE)), pad_if_needed=True))

# Both loaders draw shuffled batches of 12 items and share the same transform.
training_loader = DataLoader(Valentini(valid=False, transform=transform), batch_size=12, shuffle=True)
validation_loader = DataLoader(Valentini(valid=True, transform=transform), batch_size=12, shuffle=True)

# Plain SGD with momentum, minimising the mean squared error between the model
# output and the label tensor yielded by the loader.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
loss_fn = torch.nn.MSELoss()


def train_one_epoch(epoch_index, tb_writer, log_every=1000):
    """Run one optimisation pass over ``training_loader``.

    Every batch is moved to ``device``, pushed through ``model``, scored with
    ``loss_fn`` and used for one ``optimizer`` step. The mean loss of each
    window of ``log_every`` batches is printed and written to TensorBoard
    under ``'Loss/train'``. A trailing window shorter than ``log_every`` is
    reported as well, so an epoch with fewer than ``log_every`` batches still
    produces a real loss value instead of ``0.0``.

    Args:
        epoch_index: Zero-based index of the current epoch; used to compute
            the global step for TensorBoard.
        tb_writer: Object exposing ``add_scalar(tag, value, step)``
            (a ``SummaryWriter`` in this script).
        log_every: Number of batches averaged per report. Defaults to 1000.

    Returns:
        The mean per-batch loss of the most recently reported window, or
        ``0.0`` when the loader yields no batches.

    Raises:
        ValueError: If ``log_every`` is smaller than 1.
    """
    if log_every < 1:
        raise ValueError('log_every must be a positive integer')

    def report_window(batch_index, loss_sum, batch_count):
        # Divide by the number of batches actually accumulated so the value
        # really is the mean loss per batch for this window.
        mean_loss = loss_sum / batch_count
        print('  batch {} loss: {}'.format(batch_index + 1, mean_loss))
        # Global step: batches seen across all epochs so far.
        tb_x = epoch_index * len(training_loader) + batch_index + 1
        tb_writer.add_scalar('Loss/train', mean_loss, tb_x)
        return mean_loss

    running_loss = 0.
    window_batches = 0
    last_loss = 0.
    last_index = -1

    for i, data in enumerate(training_loader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # Gradients accumulate by default, so clear them before each step.
        optimizer.zero_grad()

        outputs = model(inputs)

        loss = loss_fn(outputs, labels)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()
        window_batches += 1
        last_index = i

        if window_batches == log_every:
            last_loss = report_window(i, running_loss, window_batches)
            running_loss = 0.
            window_batches = 0

    # Report the batches left over after the last full window.
    if window_batches:
        last_loss = report_window(last_index, running_loss, window_batches)

    return last_loss


def train(epochs=5):
    """Train the module-level ``model`` and log progress to TensorBoard.

    For every epoch this runs ``train_one_epoch``, computes the mean loss over
    ``validation_loader``, logs both losses, logs the model output for each
    clip in ``VALID_WAVS`` and saves the model's ``state_dict`` whenever the
    validation loss improves on the best value seen so far.

    Logs go to ``runs/denoising_trainer_<timestamp>`` and checkpoints to
    ``checkpoints/model_<timestamp>_<epoch>`` (zero-based epoch), both
    relative to the current working directory.

    Args:
        epochs: Number of passes over the training data. Defaults to 5.
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    writer = SummaryWriter('runs/denoising_trainer_{}'.format(timestamp))

    try:
        # torch.save does not create missing parent directories.
        os.makedirs('checkpoints', exist_ok=True)

        best_vloss = float('inf')

        # Read the reference clips once; they are reused unchanged after
        # every epoch (torch.reshape below does not modify them).
        reference_wavs = {
            tag: load_wav(DATASET_PATH / 'noisy_testset_wav' / wav_path)
            for tag, wav_path in VALID_WAVS.items()
        }
        for tag, wav in reference_wavs.items():
            writer.add_audio(tag=tag, snd_tensor=wav, sample_rate=SAMPLE_RATE)
        writer.flush()

        for epoch in range(epochs):
            print('EPOCH {}:'.format(epoch + 1))

            # Make sure gradient tracking is on, and do a pass over the data
            model.train(True)
            avg_loss = train_one_epoch(epoch, writer)

            # We don't need gradients on to do reporting
            model.train(False)

            running_vloss = 0.0
            num_vbatches = 0
            with torch.no_grad():
                for vinputs, vlabels in validation_loader:
                    vinputs, vlabels = vinputs.to(device), vlabels.to(device)
                    voutputs = model(vinputs)
                    # .item() keeps the running total a plain Python float
                    # instead of a tensor living on the training device.
                    running_vloss += loss_fn(voutputs, vlabels).item()
                    num_vbatches += 1

                # An empty validation loader yields NaN, which never compares
                # as an improvement, so no checkpoint is written for it.
                if num_vbatches:
                    avg_vloss = running_vloss / num_vbatches
                else:
                    avg_vloss = float('nan')
                print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

                writer.add_scalars('Training vs. Validation Loss',
                                   {'Training': avg_loss, 'Validation': avg_vloss},
                                   epoch + 1)
                for tag, wav in reference_wavs.items():
                    # Reshape to three dimensions (1, 1, n) before feeding
                    # the clip to the model.
                    batch = torch.reshape(wav, (1, 1, -1)).to(device)
                    prediction = model(batch)
                    writer.add_audio(tag=f"Model predicted {tag} on epoch {epoch}",
                                     snd_tensor=prediction,
                                     sample_rate=SAMPLE_RATE)
                writer.flush()

            if avg_vloss < best_vloss:
                best_vloss = avg_vloss
                model_path = 'checkpoints/model_{}_{}'.format(timestamp, epoch)
                torch.save(model.state_dict(), model_path)
    finally:
        # Flush pending events and release the event file even on failure.
        writer.close()


# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    train()