amanmibra commited on
Commit
070e16b
·
1 Parent(s): 6c41348

Add validation

Browse files
Files changed (1) hide show
  1. pipelines/train.py +51 -3
pipelines/train.py CHANGED
@@ -21,7 +21,7 @@ from cnn import CNNetwork
21
  # script defaults
22
  BATCH_SIZE = 128
23
  EPOCHS = 10
24
- LEARNING_RATE = 0.001
25
 
26
  TRAIN_FILE="data/train"
27
  TEST_FILE="data/test"
@@ -48,6 +48,7 @@ def train(
48
  optimizer,
49
  origin_device="cuda",
50
  epochs=10,
 
51
  ):
52
  import os
53
 
@@ -68,6 +69,8 @@ def train(
68
  # metrics
69
  training_acc = []
70
  training_loss = []
 
 
71
 
72
  wandb.init(project="void-training")
73
 
@@ -86,6 +89,18 @@ def train(
86
  now = time.time()
87
  print("Training Loss: {:.2f}, Training Accuracy: {:.4f}, Time: {:.2f}s".format(training_loss[i], training_acc[i], now - then))
88
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  print ("-------------------------------------------------------- \n")
90
 
91
  end = time.time()
@@ -96,7 +111,7 @@ def train(
96
  return model.to(origin_device)
97
 
98
  @stub.function(
99
- gpu="any",
100
  mounts=[
101
  Mount.from_local_file(local_path='dataset.py'),
102
  Mount.from_local_file(local_path='cnn.py'),
@@ -132,6 +147,36 @@ def train_epoch(model, train_dataloader, loss_fn, optimizer, device):
132
 
133
  return model, train_loss, train_acc
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  def save_model(model):
136
  now = time.strftime("%Y%m%d_%H%M%S")
137
  model_filename = f"models/void_{now}.pth"
@@ -163,6 +208,9 @@ def main():
163
  train_dataset = VoiceDataset(TRAIN_FILE, mel_spectrogram, device, time_limit_in_secs=3)
164
  train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
165
 
 
 
 
166
  # construct model
167
  model = CNNetwork()
168
 
@@ -171,7 +219,7 @@ def main():
171
  optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
172
 
173
  # train model
174
- model = train.call(model, train_dataloader, loss_fn, optimizer, device, 3)
175
 
176
  # save model
177
  save_model(model)
 
21
  # script defaults
22
  BATCH_SIZE = 128
23
  EPOCHS = 10
24
+ LEARNING_RATE = 0.01
25
 
26
  TRAIN_FILE="data/train"
27
  TEST_FILE="data/test"
 
48
  optimizer,
49
  origin_device="cuda",
50
  epochs=10,
51
+ test_dataloader=None,
52
  ):
53
  import os
54
 
 
69
  # metrics
70
  training_acc = []
71
  training_loss = []
72
+ testing_acc = []
73
+ testing_loss = []
74
 
75
  wandb.init(project="void-training")
76
 
 
89
  now = time.time()
90
  print("Training Loss: {:.2f}, Training Accuracy: {:.4f}, Time: {:.2f}s".format(training_loss[i], training_acc[i], now - then))
91
 
92
+ if test_dataloader:
93
+ # test model
94
+ test_epoch_loss, test_epoch_acc = validate_epoch.call(model, test_dataloader, loss_fn, modal_device)
95
+
96
+ # testing metrics
97
+ testing_loss.append(test_epoch_loss/len(test_dataloader))
98
+ testing_acc.append(test_epoch_acc/len(test_dataloader))
99
+
100
+ print("Testing Loss: {:.2f}, Testing Accuracy {:.2f}".format(testing_loss[i], testing_acc[i]))
101
+
102
+ wandb.log({'testing_loss': testing_loss[i], 'testing_acc': testing_acc[i]})
103
+
104
  print ("-------------------------------------------------------- \n")
105
 
106
  end = time.time()
 
111
  return model.to(origin_device)
112
 
113
  @stub.function(
114
+ gpu=gpu.A100(memory=20),
115
  mounts=[
116
  Mount.from_local_file(local_path='dataset.py'),
117
  Mount.from_local_file(local_path='cnn.py'),
 
147
 
148
  return model, train_loss, train_acc
149
 
150
+ @stub.function(
151
+ gpu="any",
152
+ mounts=[
153
+ Mount.from_local_file(local_path='dataset.py'),
154
+ Mount.from_local_file(local_path='cnn.py'),
155
+ ],
156
+ )
157
+ def validate_epoch(model, test_dataloader, loss_fn, device):
158
+ from tqdm import tqdm
159
+
160
+ test_loss = 0.0
161
+ test_acc = 0.0
162
+ total = 0.0
163
+
164
+ model.eval()
165
+
166
+ with torch.no_grad():
167
+ for wav, target in tqdm(test_dataloader, "Testing batch..."):
168
+ wav, target = wav.to(device), target.to(device)
169
+
170
+ output = model(wav)
171
+ loss = loss_fn(output, target)
172
+
173
+ test_loss += loss.item()
174
+ prediciton = torch.argmax(output, 1)
175
+ test_acc += (prediciton == target).sum().item()/len(prediciton)
176
+ total += 1
177
+
178
+ return test_loss, test_acc
179
+
180
  def save_model(model):
181
  now = time.strftime("%Y%m%d_%H%M%S")
182
  model_filename = f"models/void_{now}.pth"
 
208
  train_dataset = VoiceDataset(TRAIN_FILE, mel_spectrogram, device, time_limit_in_secs=3)
209
  train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
210
 
211
+ test_dataset = VoiceDataset(TEST_FILE, mel_spectrogram, device, time_limit_in_secs=3)
212
+ test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)
213
+
214
  # construct model
215
  model = CNNetwork()
216
 
 
219
  optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
220
 
221
  # train model
222
+ model = train.call(model, train_dataloader, loss_fn, optimizer, device, EPOCHS, test_dataloader)
223
 
224
  # save model
225
  save_model(model)