|
import os
import pickle
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from resnet_model import ResidualBlock, ResNet
from utils import CustomDataset, transform, preproc, Convert_ONNX
|
# Pick one of a fixed set of seeds via the command-line index, e.g. `python train.py 1`.
ind = int(sys.argv[1])
seeds = [1, 42, 7109, 2002, 32]
seed = seeds[ind]
print("using seed: ", seed)
torch.manual_seed(seed)
np.random.seed(seed)  # seed NumPy as well so any numpy-based transforms are reproducible
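# With num_workers > 0, PyTorch derives each DataLoader worker's seed from the
# base seed fixed above, so shuffling and torch-based augmentations in the
# workers stay reproducible across runs.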
|
|
|
|
|
# Train on the first GPU when available; nn.DataParallel below fans out to all of them.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_gpus = torch.cuda.device_count()
print("num gpus: ", num_gpus)
|
# Training and validation splits live in separate directories but share one transform.
data_dir = '/mnt/buf1/pma/frbnn/train_ready'
dataset = CustomDataset(data_dir, transform=transform)
valid_data_dir = '/mnt/buf1/pma/frbnn/valid_ready'
valid_dataset = CustomDataset(valid_data_dir, transform=transform)
|
num_classes = 2
trainloader = DataLoader(dataset, batch_size=420, shuffle=True, num_workers=32)
# No need to shuffle for evaluation; order does not affect the metrics.
validloader = DataLoader(valid_dataset, batch_size=512, shuffle=False, num_workers=32)
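# Optional sanity check (a sketch; assumes CustomDataset yields (image, label)
# pairs, as the training loop below does): pull one batch and confirm its shape
# matches the (N, 24, 192, 256) input the ResNet and ONNX export expect.
sample_images, sample_labels = next(iter(trainloader))
print("sample batch: ", sample_images.shape, sample_labels.shape)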
|
|
|
model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)
model = nn.DataParallel(model)
model = model.to(device)
params = sum(p.numel() for p in model.parameters())
print("num params: ", params)

# Round-trip the weights once up front so checkpointing failures surface early.
os.makedirs('models', exist_ok=True)
torch.save(model.state_dict(), 'models/test.pt')
model.load_state_dict(torch.load('models/test.pt'))
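# Because the state dict is saved through the DataParallel wrapper, every key is
# prefixed with 'module.'; loading it back into the same wrapped model (as above)
# matches those keys. Loading into a bare ResNet would require stripping the prefix.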
|
|
|
# Export the raw module (not the DataParallel wrapper) and the preprocessing
# model to ONNX, using mock inputs with the expected shapes.
preproc_model = preproc()
Convert_ONNX(model.module, 'models/test.onnx', input_data_mock=torch.randn(1, 24, 192, 256).to(device))
Convert_ONNX(preproc_model, 'models/preproc.onnx', input_data_mock=torch.randn(32, 192, 2048).to(device))
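# Optional check (a sketch, assuming onnxruntime is installed and Convert_ONNX
# wrote the graph to that path; skipped otherwise): run one mock batch through
# the exported classifier and confirm the output shape is (1, num_classes).
try:
    import onnxruntime as ort
    sess = ort.InferenceSession('models/test.onnx', providers=['CPUExecutionProvider'])
    mock = np.random.randn(1, 24, 192, 256).astype(np.float32)
    onnx_out = sess.run(None, {sess.get_inputs()[0].name: mock})[0]
    print("onnx output shape: ", onnx_out.shape)
except ImportError:
    pass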
|
|
|
|
|
|
|
# CrossEntropyLoss requires a floating-point weight tensor; integer weights raise a dtype error.
criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 1.0]).to(device))
optimizer = optim.Adam(model.parameters(), lr=0.0001)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
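# ReduceLROnPlateau halves the learning rate (factor=0.5) whenever the monitored
# quantity, the summed validation loss passed to scheduler.step() below, fails to
# improve for 10 consecutive epochs (patience=10).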
|
|
|
|
|
epochs = 10000
for epoch in range(epochs):
    running_loss = 0.0
    correct_train = 0
    total_train = 0
    with tqdm(trainloader, unit="batch") as tepoch:
        tepoch.set_description(f"epoch {epoch}")
        model.train()
        for i, (images, labels) in enumerate(tepoch):
            inputs, labels = images.to(device), labels.to(device).float()
            optimizer.zero_grad()
            outputs = model(inputs, return_mask=False)
            # CrossEntropyLoss accepts class probabilities, so train against one-hot targets.
            new_label = F.one_hot(labels.type(torch.int64), num_classes=2).type(torch.float32).to(device)
            loss = criterion(outputs, new_label)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

    # Validation pass: gradients disabled, no optimizer updates.
    val_loss = 0.0
    correct_valid = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for images, labels in validloader:
            inputs, labels = images.to(device), labels.to(device).float()
            outputs = model(inputs, return_mask=False)
            # The one-hot target must live on the same device as the model output.
            new_label = F.one_hot(labels.type(torch.int64), num_classes=2).type(torch.float32).to(device)
            loss = criterion(outputs, new_label)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct_valid += (predicted == labels).sum().item()
    scheduler.step(val_loss)

    train_accuracy = 100 * correct_train / total_train
    val_accuracy = correct_valid / total * 100.0
    # Checkpoint every epoch, tagging the files with the validation accuracy.
    torch.save(model.state_dict(), f'models/model-{epoch}-{val_accuracy}.pt')
    Convert_ONNX(model.module, f'models/model-{epoch}-{val_accuracy}.onnx', input_data_mock=inputs)

    # Read the LR from the optimizer directly; ReduceLROnPlateau has no
    # get_last_lr() on older PyTorch releases.
    current_lr = optimizer.param_groups[0]['lr']
    print("===========================")
    print('epoch: ', epoch, ' train acc: ', train_accuracy, ' val acc: ', val_accuracy)
    print('train loss: ', running_loss / len(trainloader), ' val loss: ', val_loss / len(validloader))
    print('learning rate: ', current_lr)
    print("===========================")
    # Stop once the scheduler has decayed the learning rate below the floor.
    if current_lr < 1e-6:
        break
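# To reuse a saved checkpoint later (a sketch; the filename is a placeholder),
# rebuild the model, wrap it in DataParallel so the 'module.'-prefixed keys match,
# then load and switch to eval mode:
#
#   model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes)
#   model = nn.DataParallel(model).to(device)
#   model.load_state_dict(torch.load('models/model-<epoch>-<val_accuracy>.pt'))
#   model.eval()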
|
|