Spaces:
Runtime error
Runtime error
Change train progress bar, add test dataset to script
Browse files
train.py
CHANGED
@@ -16,6 +16,7 @@ EPOCHS = 100
|
|
16 |
LEARNING_RATE = 0.001
|
17 |
|
18 |
TRAIN_FILE="data/train"
|
|
|
19 |
SAMPLE_RATE=16000
|
20 |
|
21 |
def train(model, train_dataloader, loss_fn, optimizer, device, epochs, test_dataloader=None):
|
@@ -24,7 +25,7 @@ def train(model, train_dataloader, loss_fn, optimizer, device, epochs, test_data
|
|
24 |
testing_acc = []
|
25 |
testing_loss = []
|
26 |
|
27 |
-
for i in
|
28 |
print(f"Epoch {i + 1}")
|
29 |
|
30 |
# train model
|
@@ -59,7 +60,7 @@ def train_epoch(model, train_dataloader, loss_fn, optimizer, device):
|
|
59 |
|
60 |
model.train()
|
61 |
|
62 |
-
for wav, target in train_dataloader:
|
63 |
wav, target = wav.to(device), target.to(device)
|
64 |
|
65 |
# calculate loss
|
@@ -87,7 +88,7 @@ def validate_epoch(model, test_dataloader, loss_fn, device):
|
|
87 |
model.eval()
|
88 |
|
89 |
with torch.no_grad():
|
90 |
-
for wav, target in test_dataloader:
|
91 |
wav, target = wav.to(device), target.to(device)
|
92 |
|
93 |
output = model(wav)
|
@@ -116,7 +117,9 @@ if __name__ == "__main__":
|
|
116 |
)
|
117 |
|
118 |
train_dataset = VoiceDataset(TRAIN_FILE, mel_spectrogram, SAMPLE_RATE, device)
|
|
|
119 |
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
|
|
120 |
|
121 |
# construct model
|
122 |
model = CNNetwork().to(device)
|
@@ -128,7 +131,7 @@ if __name__ == "__main__":
|
|
128 |
|
129 |
|
130 |
# train model
|
131 |
-
train(model, train_dataloader, loss_fn, optimizer, device, EPOCHS)
|
132 |
|
133 |
# save model
|
134 |
now = datetime.now()
|
|
|
16 |
LEARNING_RATE = 0.001
|
17 |
|
18 |
TRAIN_FILE="data/train"
|
19 |
+
TEST_FILE="data/test"
|
20 |
SAMPLE_RATE=16000
|
21 |
|
22 |
def train(model, train_dataloader, loss_fn, optimizer, device, epochs, test_dataloader=None):
|
|
|
25 |
testing_acc = []
|
26 |
testing_loss = []
|
27 |
|
28 |
+
for i in range(epochs):
|
29 |
print(f"Epoch {i + 1}")
|
30 |
|
31 |
# train model
|
|
|
60 |
|
61 |
model.train()
|
62 |
|
63 |
+
for wav, target in tqdm(train_dataloader, "Training batch..."):
|
64 |
wav, target = wav.to(device), target.to(device)
|
65 |
|
66 |
# calculate loss
|
|
|
88 |
model.eval()
|
89 |
|
90 |
with torch.no_grad():
|
91 |
+
for wav, target in tqdm(test_dataloader, "Testing batch..."):
|
92 |
wav, target = wav.to(device), target.to(device)
|
93 |
|
94 |
output = model(wav)
|
|
|
117 |
)
|
118 |
|
119 |
train_dataset = VoiceDataset(TRAIN_FILE, mel_spectrogram, SAMPLE_RATE, device)
|
120 |
+
test_dataset = VoiceDataset(TEST_FILE, mel_spectrogram, SAMPLE_RATE, device)
|
121 |
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
122 |
+
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
123 |
|
124 |
# construct model
|
125 |
model = CNNetwork().to(device)
|
|
|
131 |
|
132 |
|
133 |
# train model
|
134 |
+
train(model, train_dataloader, loss_fn, optimizer, device, EPOCHS, test_dataloader=test_dataloader)
|
135 |
|
136 |
# save model
|
137 |
now = datetime.now()
|