"""Train the FBank cross-entropy network and checkpoint the best epoch."""

import argparse
import multiprocessing
import time

import numpy as np
import torch
import tqdm
from torch import optim
from torch.utils.data import DataLoader

from data_proc.cross_entropy_dataset import FBanksCrossEntropyDataset
from models.cross_entropy_model import FBankCrossEntropyNetV2
from trainer.cross_entropy_train import train, test
from utils.pt_util import restore_objects, save_model, save_objects, restore_model


def main(args):
    """Run the training loop described by the parsed command-line ``args``.

    Reads ``args.num_layers``, ``args.train_folder``, ``args.test_folder``,
    ``args.batch_size``, ``args.lr`` and ``args.epochs``.

    Model weights and bookkeeping objects are restored from (and saved to)
    ``saved_models_cross_entropy/<num_layers>/``. After every epoch the model
    is evaluated on the test loader; a checkpoint is written whenever the
    test accuracy exceeds the best accuracy recorded so far.
    """
    model_path = f"saved_models_cross_entropy/{args.num_layers}/"

    # Derive the flag from the actual hardware instead of hard-coding True,
    # so that pinned memory is only requested when a CUDA device exists.
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    print('using device', device)

    num_cpus = multiprocessing.cpu_count()
    print('num cpus:', num_cpus)

    # Worker count is kept regardless of device; only pin_memory depends on
    # CUDA availability.
    kwargs = {'num_workers': num_cpus, 'pin_memory': use_cuda}

    train_dataset = FBanksCrossEntropyDataset(args.train_folder)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)

    test_dataset = FBanksCrossEntropyDataset(args.test_folder)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)

    model = FBankCrossEntropyNetV2(num_layers=args.num_layers, reduction='mean').to(device)
    model = restore_model(model, model_path)

    # The tuple passed here is presumably the fallback used when nothing has
    # been saved yet -- confirm against utils.pt_util.restore_objects.
    (last_epoch, max_accuracy, train_losses, test_losses,
     train_accuracies, test_accuracies) = restore_objects(model_path, (0, 0, [], [], [], []))

    # Continue after the last saved epoch only if a previous run recorded a
    # positive accuracy; otherwise start from epoch 0.
    start = last_epoch + 1 if max_accuracy > 0 else 0

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    for epoch in range(start, args.epochs):
        # NOTE(review): 500 is presumably the logging interval in batches --
        # confirm against trainer.cross_entropy_train.train.
        train_loss, train_accuracy = train(model, device, train_loader, optimizer, epoch, 500)
        test_loss, test_accuracy = test(model, device, test_loader)
        print('After epoch: {}, train_loss: {}, test loss is: {}, train_accuracy: {}, '
              'test_accuracy: {}'.format(epoch, train_loss, test_loss, train_accuracy, test_accuracy))

        train_losses.append(train_loss)
        test_losses.append(test_loss)
        train_accuracies.append(train_accuracy)
        test_accuracies.append(test_accuracy)

        # Checkpoint only when this epoch improves on the best test accuracy.
        if test_accuracy > max_accuracy:
            max_accuracy = test_accuracy
            save_model(model, epoch, model_path)
            save_objects(
                (epoch, max_accuracy, train_losses, test_losses, train_accuracies, test_accuracies),
                epoch,
                model_path,
            )
            print('saved epoch: {} as checkpoint'.format(epoch))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='FBank Cross Entropy Training Script')
    parser.add_argument('--num_layers', type=int, default=2, help='Number of layers in the model')
    parser.add_argument('--train_folder', type=str, default='fbanks_train', help='Training dataset folder')
    parser.add_argument('--test_folder', type=str, default='fbanks_test', help='Testing dataset folder')
    parser.add_argument('--epochs', type=int, default=20, help='Number of epochs to train')
    parser.add_argument('--batch_size', type=int, default=64, help='Batch size for training')
    parser.add_argument('--lr', type=float, default=0.0005, help='Learning rate for the optimizer')

    args = parser.parse_args()
    main(args)