amanmibra committed on
Commit
1f644df
·
1 Parent(s): 7c8662f

Change train progress bar, add test dataset to script

Browse files
Files changed (1) hide show
  1. train.py +7 -4
train.py CHANGED
@@ -16,6 +16,7 @@ EPOCHS = 100
16
  LEARNING_RATE = 0.001
17
 
18
  TRAIN_FILE="data/train"
 
19
  SAMPLE_RATE=16000
20
 
21
  def train(model, train_dataloader, loss_fn, optimizer, device, epochs, test_dataloader=None):
@@ -24,7 +25,7 @@ def train(model, train_dataloader, loss_fn, optimizer, device, epochs, test_data
24
  testing_acc = []
25
  testing_loss = []
26
 
27
- for i in tqdm(range(epochs), "Training model..."):
28
  print(f"Epoch {i + 1}")
29
 
30
  # train model
@@ -59,7 +60,7 @@ def train_epoch(model, train_dataloader, loss_fn, optimizer, device):
59
 
60
  model.train()
61
 
62
- for wav, target in train_dataloader:
63
  wav, target = wav.to(device), target.to(device)
64
 
65
  # calculate loss
@@ -87,7 +88,7 @@ def validate_epoch(model, test_dataloader, loss_fn, device):
87
  model.eval()
88
 
89
  with torch.no_grad():
90
- for wav, target in test_dataloader:
91
  wav, target = wav.to(device), target.to(device)
92
 
93
  output = model(wav)
@@ -116,7 +117,9 @@ if __name__ == "__main__":
116
  )
117
 
118
  train_dataset = VoiceDataset(TRAIN_FILE, mel_spectrogram, SAMPLE_RATE, device)
 
119
  train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
 
120
 
121
  # construct model
122
  model = CNNetwork().to(device)
@@ -128,7 +131,7 @@ if __name__ == "__main__":
128
 
129
 
130
  # train model
131
- train(model, train_dataloader, loss_fn, optimizer, device, EPOCHS)
132
 
133
  # save model
134
  now = datetime.now()
 
16
  LEARNING_RATE = 0.001
17
 
18
  TRAIN_FILE="data/train"
19
+ TEST_FILE="data/test"
20
  SAMPLE_RATE=16000
21
 
22
  def train(model, train_dataloader, loss_fn, optimizer, device, epochs, test_dataloader=None):
 
25
  testing_acc = []
26
  testing_loss = []
27
 
28
+ for i in range(epochs):
29
  print(f"Epoch {i + 1}")
30
 
31
  # train model
 
60
 
61
  model.train()
62
 
63
+ for wav, target in tqdm(train_dataloader, "Training batch..."):
64
  wav, target = wav.to(device), target.to(device)
65
 
66
  # calculate loss
 
88
  model.eval()
89
 
90
  with torch.no_grad():
91
+ for wav, target in tqdm(test_dataloader, "Testing batch..."):
92
  wav, target = wav.to(device), target.to(device)
93
 
94
  output = model(wav)
 
117
  )
118
 
119
  train_dataset = VoiceDataset(TRAIN_FILE, mel_spectrogram, SAMPLE_RATE, device)
120
+ test_dataset = VoiceDataset(TEST_FILE, mel_spectrogram, SAMPLE_RATE, device)
121
  train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
122
+ test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)
123
 
124
  # construct model
125
  model = CNNetwork().to(device)
 
131
 
132
 
133
  # train model
134
+ train(model, train_dataloader, loss_fn, optimizer, device, EPOCHS, test_dataloader=test_dataloader)
135
 
136
  # save model
137
  now = datetime.now()