# BLADE_FRBNN/models/train-mask.py
# Provenance: uploaded to the Hugging Face Hub by peterma02 (commit f3972ea).
from utils import CustomDataset, transform, preproc, Convert_ONNX
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
from resnet_model_mask import ResidualBlock, ResNet
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
import pickle
import sys
# Reproducibility: the single CLI argument selects one of five fixed seeds.
SEED_CHOICES = (1, 42, 7109, 2002, 32)
seed_index = int(sys.argv[1])
seed = SEED_CHOICES[seed_index]
print("using seed: ", seed)
torch.manual_seed(seed)

# Run on the first CUDA device when one is available, otherwise on CPU.
have_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if have_cuda else "cpu")
gpu_count = torch.cuda.device_count()
print(gpu_count)
# Build the training and validation datasets from their on-disk directories.
train_dataset = CustomDataset('/mnt/buf1/pma/frbnn/train_ready', transform=transform)
valid_dataset = CustomDataset('/mnt/buf1/pma/frbnn/valid_ready', transform=transform)

# Two output classes for the classifier head.
num_classes = 2

# Shuffling, multi-worker loaders over both splits.
trainloader = DataLoader(train_dataset, batch_size=420, shuffle=True, num_workers=32)
validloader = DataLoader(valid_dataset, batch_size=512, shuffle=True, num_workers=32)
# Build the ResNet classifier (stage layout [3, 4, 6, 3]; first argument is
# presumably the input-channel count — confirm against resnet_model_mask) and
# parallelize it across all visible GPUs.
model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes)
# One .to(device) after wrapping suffices; the original moved the model twice.
model = nn.DataParallel(model).to(device)
params = sum(p.numel() for p in model.parameters())
print("num params ", params)

# Checkpoint smoke test: save and immediately reload the initial weights so
# any serialization problem surfaces before training starts.
torch.save(model.state_dict(), f'models_mask/test_{seed}.pt')
model.load_state_dict(torch.load(f'models_mask/test_{seed}.pt'))

# Export the untrained classifier (unwrapped via .module) and the
# preprocessing model to ONNX with fixed-shape mock inputs.
preproc_model = preproc()
Convert_ONNX(model.module, f'models_mask/test_{seed}.onnx',
             input_data_mock=torch.randn(1, 24, 192, 256).to(device))
Convert_ONNX(preproc_model, f'models_mask/preproc_{seed}.onnx',
             input_data_mock=torch.randn(32, 192, 2048).to(device))
# Define optimizer and loss function.
# The class weights must be floating point: CrossEntropyLoss multiplies them
# into the float per-class losses, and an integer (long) weight tensor such as
# the original torch.tensor([1, 1]) raises a dtype error at loss computation.
criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 1.0]).to(device))
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# Halve the learning rate whenever validation loss plateaus for 10 epochs.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
from tqdm import tqdm
# Training loop: train/validate each epoch, checkpoint + ONNX-export every
# epoch, and stop once the plateau scheduler decays the LR below 1e-6.
# (The original's indentation was flattened by a bad paste; reconstructed here.)
epochs = 10000
for epoch in range(epochs):
    running_loss = 0.0
    correct_train = 0
    total_train = 0

    # ---- Training pass ----
    with tqdm(trainloader, unit="batch") as tepoch:
        model.train()
        for i, (images, labels) in enumerate(tepoch):
            inputs, labels = images.to(device), labels.to(device).float()
            optimizer.zero_grad()
            outputs = model(inputs, return_mask=False)
            # The loss is fed one-hot (soft) targets rather than class indices;
            # labels are already on `device`, so the one-hot tensor is too.
            new_label = F.one_hot(labels.type(torch.int64), num_classes=2).type(torch.float32)
            loss = criterion(outputs, new_label)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Training accuracy from the arg-max predicted class.
            _, predicted = torch.max(outputs.data, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

    # ---- Validation pass ----
    val_loss = 0.0
    correct_valid = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for images, labels in validloader:
            inputs, labels = images.to(device), labels.to(device).float()
            # NOTE: the original called optimizer.zero_grad() here; inside
            # torch.no_grad() no gradients accumulate, so it was removed.
            outputs = model(inputs, return_mask=False)
            new_label = F.one_hot(labels.type(torch.int64), num_classes=2).type(torch.float32)
            loss = criterion(outputs, new_label)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct_valid += (predicted == labels).sum().item()

    # Drive the plateau scheduler on the summed (not mean) validation loss,
    # matching the original behavior.
    scheduler.step(val_loss)

    # Epoch metrics, checkpoint, and ONNX export (uses the last validation
    # batch `inputs` as the mock export input, as the original did).
    train_accuracy = 100 * correct_train / total_train
    val_accuracy = correct_valid / total * 100.0
    torch.save(model.state_dict(), 'models_mask/model-' + str(epoch) + '-' + str(val_accuracy) + f'_{seed}.pt')
    Convert_ONNX(model.module, 'models_mask/model-' + str(epoch) + '-' + str(val_accuracy) + f'_{seed}.onnx', input_data_mock=inputs)
    print("===========================")
    print('accuracy: ', epoch, train_accuracy, val_accuracy)
    print('learning rate: ', scheduler.get_last_lr())
    print("===========================")

    # Early stop once the LR has been reduced below the floor.
    if scheduler.get_last_lr()[0] < 1e-6:
        break