File size: 4,194 Bytes
f3972ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from utils import CustomDataset, transform, preproc, Convert_ONNX
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
from resnet_model_mask import  ResidualBlock, ResNet
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm 
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
import pickle
import sys

# Select this run's RNG seed from a fixed list. The list index is the first
# command-line argument, e.g. `python <script> 0`.
seeds = [1,42,7109,2002,32]
try:
    ind = int(sys.argv[1])
    seed = seeds[ind]
except (IndexError, ValueError):
    # A missing, non-integer or out-of-range index gets a usage message
    # instead of a bare traceback.
    sys.exit(f"usage: {sys.argv[0]} SEED_INDEX  (integer index into {seeds})")
print("using seed: ",seed)
# Seed torch's global RNG (used for weight init and DataLoader shuffling).
# NOTE(review): numpy's global RNG is not seeded here; if CustomDataset or its
# transform draws from np.random in the main process, runs are not fully
# reproducible -- confirm.
torch.manual_seed(seed)

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_gpus = torch.cuda.device_count()
print(num_gpus)

import os

# Create the training and validation datasets from their prepared directories.
data_dir = '/mnt/buf1/pma/frbnn/train_ready'
dataset = CustomDataset(data_dir, transform=transform)
valid_data_dir = '/mnt/buf1/pma/frbnn/valid_ready'
valid_dataset = CustomDataset(valid_data_dir, transform=transform)


num_classes = 2
trainloader = DataLoader(dataset, batch_size=420, shuffle=True, num_workers=32)
# NOTE(review): shuffling does not affect validation metrics. It is kept because
# the sampler draws from torch's RNG, so removing it would change the random
# stream of seeded runs.
validloader = DataLoader(valid_dataset, batch_size=512, shuffle=True, num_workers=32)

# The first ResNet argument (24) appears to match the 24-channel mock input used
# for the ONNX export below -- confirm against resnet_model_mask.ResNet.
model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)
# Wrap for multi-GPU data parallelism. The wrapped network is already on
# `device`, so no second .to(device) is needed.
model = nn.DataParallel(model)
params = sum(p.numel() for p in model.parameters())
print("num params ",params)

# Every checkpoint and ONNX file in this script goes under models_mask/.
# Create the directory so the first torch.save does not fail on a fresh checkout.
os.makedirs('models_mask', exist_ok=True)
# Round-trip the freshly initialised weights through disk. This appears to be
# an early check that checkpoint save/load works before training starts.
torch.save(model.state_dict(), f'models_mask/test_{seed}.pt')
model.load_state_dict(torch.load(f'models_mask/test_{seed}.pt'))

preproc_model = preproc()
# Export the untrained classifier (unwrapped from DataParallel) and the
# preprocessing module to ONNX. The mock tensors presumably fix the traced
# input shapes inside Convert_ONNX -- confirm in utils.
Convert_ONNX(model.module,f'models_mask/test_{seed}.onnx', input_data_mock=torch.randn(1, 24, 192, 256).to(device))
Convert_ONNX(preproc_model,f'models_mask/preproc_{seed}.onnx', input_data_mock=torch.randn(32, 192, 2048).to(device))

# Define optimizer and loss function

# Per-class weights for the 2-way cross-entropy (equal weighting as written).
# The tensor is float32 so its dtype matches the float logits; integer weight
# tensors are not accepted by every cross-entropy code path.
criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 1.0], dtype=torch.float32).to(device))
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# Halve the learning rate when the metric passed to scheduler.step() has not
# improved for more than 10 consecutive epochs. The training loop passes the
# summed validation loss.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)


from tqdm import tqdm
# Training loop. Each epoch it:
#   - trains for one pass over trainloader,
#   - evaluates on validloader,
#   - saves a checkpoint plus an ONNX export,
#   - stops early once the plateau scheduler has decayed the LR below 1e-6.
epochs = 10000
for epoch in range(epochs):
    running_loss = 0.0
    correct_train = 0
    total_train = 0
    num_train_batches = 0
    model.train()
    with tqdm(trainloader, unit="batch") as tepoch:
        for images, labels in tepoch:
            inputs, labels = images.to(device), labels.to(device).float()
            optimizer.zero_grad()
            # DataParallel gathers outputs onto `device`, so no extra .to() is needed.
            outputs = model(inputs, return_mask=False)
            # One-hot encode the integer class labels as float probability
            # targets for CrossEntropyLoss.
            new_label = F.one_hot(labels.type(torch.int64), num_classes=2).type(torch.float32)
            loss = criterion(outputs, new_label)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_train_batches += 1
            # Calculate training accuracy
            _, predicted = torch.max(outputs.detach(), 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

    # Validation pass: no gradients, and the optimizer is not touched.
    val_loss = 0.0
    correct_valid = 0
    total = 0
    num_val_batches = 0
    model.eval()
    with torch.no_grad():
        for images, labels in validloader:
            inputs, labels = images.to(device), labels.to(device).float()
            outputs = model(inputs, return_mask=False)
            new_label = F.one_hot(labels.type(torch.int64), num_classes=2).type(torch.float32)
            loss = criterion(outputs, new_label)
            val_loss += loss.item()
            num_val_batches += 1
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct_valid += (predicted == labels).sum().item()
    if total_train == 0 or total == 0:
        raise RuntimeError("training or validation DataLoader yielded no samples")

    # The summed (not averaged) validation loss drives the plateau scheduler.
    scheduler.step(val_loss)
    # Read the LR straight from the optimizer. This equals
    # scheduler.get_last_lr() after step(), but also works on PyTorch versions
    # where ReduceLROnPlateau has no get_last_lr().
    current_lrs = [group['lr'] for group in optimizer.param_groups]

    # Per-epoch metrics
    train_accuracy = 100 * correct_train / total_train
    val_accuracy = correct_valid / total * 100.0
    train_loss = running_loss / num_train_batches
    mean_val_loss = val_loss / num_val_batches

    ckpt_stem = f'models_mask/model-{epoch}-{val_accuracy}_{seed}'
    torch.save(model.state_dict(), ckpt_stem + '.pt')
    # NOTE(review): the mock input here is the last validation batch. Its batch
    # size can vary between epochs and differs from the batch-of-1 export at
    # startup; confirm whether Convert_ONNX marks the batch axis dynamic.
    Convert_ONNX(model.module, ckpt_stem + '.onnx', input_data_mock=inputs)

    print("===========================")
    print('accuracy: ', epoch, train_accuracy,  val_accuracy)
    print('loss: ', epoch, train_loss, mean_val_loss)
    print('learning rate: ', current_lrs)
    print("===========================")
    if current_lrs[0] < 1e-6:
        break