"""Train the FRBNN ResNet classifier and export ONNX checkpoints.

Usage: python <script> <seed_index>
    seed_index selects one of five fixed seeds for reproducible runs.

Side effects: reads datasets from /mnt/buf1/pma/frbnn/{train,valid}_ready,
writes .pt checkpoints and .onnx exports into models/.
"""
import pickle  # kept from original file-level imports (may be used elsewhere)
import sys

import numpy as np  # kept from original file-level imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader  # Dataset kept from original
from tqdm import tqdm

from resnet_model import ResidualBlock, ResNet
from utils import Convert_ONNX, CustomDataset, preproc, transform

# --- reproducibility: seed chosen by CLI index into a fixed list ------------
ind = int(sys.argv[1])
seeds = [1, 42, 7109, 2002, 32]
seed = seeds[ind]
print("using seed: ", seed)
torch.manual_seed(seed)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_gpus = torch.cuda.device_count()
print(num_gpus)

# --- data -------------------------------------------------------------------
data_dir = '/mnt/buf1/pma/frbnn/train_ready'
dataset = CustomDataset(data_dir, transform=transform)
valid_data_dir = '/mnt/buf1/pma/frbnn/valid_ready'
valid_dataset = CustomDataset(valid_data_dir, transform=transform)
num_classes = 2
trainloader = DataLoader(dataset, batch_size=420, shuffle=True, num_workers=32)
validloader = DataLoader(valid_dataset, batch_size=512, shuffle=True,
                         num_workers=32)

# --- model ------------------------------------------------------------------
model = ResNet(24, ResidualBlock, [3, 4, 6, 3],
               num_classes=num_classes).to(device)
model = nn.DataParallel(model)
model = model.to(device)
params = sum(p.numel() for p in model.parameters())
print("num params ", params)

# Sanity round-trip: confirm the state dict serializes and reloads cleanly
# before spending time on training.
torch.save(model.state_dict(), 'models/test.pt')
model.load_state_dict(torch.load('models/test.pt'))

# Export the untrained network and the preprocessing graph once up front so a
# broken ONNX path fails fast. Mock input shapes: (1, 24, 192, 256) for the
# classifier, (32, 192, 2048) for preproc -- presumably matching the dataset
# output; confirm against CustomDataset/transform.
preproc_model = preproc()
Convert_ONNX(model.module, 'models/test.onnx',
             input_data_mock=torch.randn(1, 24, 192, 256).to(device))
Convert_ONNX(preproc_model, 'models/preproc.onnx',
             input_data_mock=torch.randn(32, 192, 2048).to(device))

# --- optimizer / loss / LR schedule -----------------------------------------
# BUG FIX: CrossEntropyLoss requires a floating-point class-weight tensor;
# torch.tensor([1, 1]) is int64 and raises a dtype error at loss time.
criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 1.0]).to(device))
optimizer = optim.Adam(model.parameters(), lr=0.0001)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

# --- training loop ----------------------------------------------------------
epochs = 10000
for epoch in range(epochs):
    running_loss = 0.0
    correct_train = 0
    total_train = 0
    with tqdm(trainloader, unit="batch") as tepoch:
        model.train()
        for i, (images, labels) in enumerate(tepoch):
            inputs, labels = images.to(device), labels.to(device).float()
            optimizer.zero_grad()
            # DataParallel gathers outputs onto `device`; no extra .to() needed.
            outputs = model(inputs, return_mask=False)
            # Targets as one-hot float probabilities (labels already on device).
            new_label = F.one_hot(labels.type(torch.int64),
                                  num_classes=2).type(torch.float32)
            loss = criterion(outputs, new_label)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            # Training accuracy from the argmax over class logits.
            _, predicted = torch.max(outputs.data, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

    # --- validation (no gradients; removed stray optimizer.zero_grad()) -----
    val_loss = 0.0
    correct_valid = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for images, labels in validloader:
            inputs, labels = images.to(device), labels.to(device).float()
            outputs = model(inputs, return_mask=False)
            new_label = F.one_hot(labels.type(torch.int64),
                                  num_classes=2).type(torch.float32)
            loss = criterion(outputs, new_label)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct_valid += (predicted == labels).sum().item()

    # NOTE(review): val_loss is the SUM over batches, so the plateau threshold
    # implicitly scales with the number of validation batches -- kept as-is to
    # preserve the original schedule behavior.
    scheduler.step(val_loss)

    train_accuracy = 100 * correct_train / total_train
    val_accuracy = correct_valid / total * 100.0

    # Checkpoint + ONNX export every epoch, named by epoch and val accuracy.
    # `inputs` here is the last validation batch, reused as the ONNX mock input.
    torch.save(model.state_dict(),
               'models/model-' + str(epoch) + '-' + str(val_accuracy) + '.pt')
    Convert_ONNX(model.module,
                 'models/model-' + str(epoch) + '-' + str(val_accuracy) + '.onnx',
                 input_data_mock=inputs)

    print("===========================")
    print('accuracy: ', epoch, train_accuracy, val_accuracy)
    # NOTE(review): get_last_lr() on ReduceLROnPlateau requires a recent
    # PyTorch (older releases only expose it on _LRScheduler subclasses).
    print('learning rate: ', scheduler.get_last_lr())
    print("===========================")

    # Stop once the plateau schedule has decayed the LR below 1e-6.
    if scheduler.get_last_lr()[0] < 1e-6:
        break