File size: 4,236 Bytes
3ef85e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use
from pdb import set_trace as bb
import os
import torch
import torch.optim as optim
import torchvision.transforms as tvf
from tools import common, trainer
from datasets import *
from core.conv_mixer import ConvMixer
from core.losses import *
def parse_args(argv=None):
    """Parse the command-line options of the PUMP training script.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with the attributes ``pretrained``, ``save_path``,
        ``epochs``, ``batch_size``, ``learning_rate``, ``weight_decay``,
        ``threads`` and ``device``.
    """
    import argparse
    # The text is a description; passed positionally it would become `prog`
    # and replace the program name in the usage line.
    parser = argparse.ArgumentParser(description="Script to train PUMP")

    parser.add_argument("--pretrained", type=str, default="", help='pretrained model path')
    parser.add_argument("--save-path", type=str, required=True, help='directory to save model')

    parser.add_argument("--epochs", type=int, default=50, help='number of training epochs')
    parser.add_argument("--batch-size", "--bs", type=int, default=16, help="batch size")
    # Must be a float: a string value would be rejected by the optimizer.
    parser.add_argument("--learning-rate", "--lr", type=float, default=1e-4)
    parser.add_argument("--weight-decay", "--wd", type=float, default=5e-4)

    parser.add_argument("--threads", type=int, default=8, help='number of worker threads')
    parser.add_argument("--device", default='cuda')

    return parser.parse_args(argv)
def _build_data_loader(args):
    """Build the shuffled training DataLoader over the mixed synthetic / SfM120k pair database."""
    # NOTE(review): 3125 / 4875 / 8000 are presumably the number of pairs drawn
    # from each source per epoch -- confirm in BalancedCatImagePairs.
    db = BalancedCatImagePairs(
        3125, SyntheticImagePairs(RandomWebImages(0, 52), distort='RandomTilting(0.5)'),
        4875, SyntheticImagePairs(SfM120k_Images(), distort='RandomTilting(0.5)'),
        8000, SfM120k_Pairs())

    db = FastPairLoader(db,
        crop=256, transform='RandomRotation(20), RandomScale(256,1536,ar=1.3,can_upscale=True), PixelNoise(25)',
        p_swap=0.5, p_flip=0.5, scale_jitter=0.5)

    print("Training image database =", db)
    return torch.utils.data.DataLoader(db, batch_size=args.batch_size, shuffle=True,
        num_workers=args.threads, collate_fn=collate_ordered, pin_memory=False, drop_last=True,
        worker_init_fn=WorkerWithRngInit())


def _build_trainer(args, device):
    """Create the ConvMixer net, its loss and an Adam optimizer, wrapped in a MyTrainer moved to *device*."""
    net = ConvMixer(output_dim=128, hidden_dim=512, depth=7, patch_size=4, kernel_size=9)
    print(f"\n>> Creating {type(net).__name__} net ( Model size: {common.model_size(net)/1e6:.1f}M parameters )")

    loss = MultiLoss(alpha=0.3,
        loss_sup=PixelAPLoss(nq=20, inner_bw=True, sampler=NghSampler(ngh=7)),
        loss_unsup=DeepMatchingLoss(eps=0.03))

    # Only trainable parameters are optimized. The lr is coerced to float so a
    # string value coming from the command line cannot reach the optimizer.
    optimizer = optim.Adam([p for p in net.parameters() if p.requires_grad],
                           lr=float(args.learning_rate), weight_decay=args.weight_decay)
    return MyTrainer(net, loss, optimizer).to(device)


def main(args):
    """Train the PUMP ConvMixer model and write its final weights to ``<save_path>/model.pt``.

    If ``model.pt`` already exists, nothing is trained. Otherwise weights are
    loaded from ``--pretrained`` when given, or else from the resumable
    checkpoint ``model.pt.last`` when present. Training then runs until
    ``train.epoch`` reaches ``args.epochs``, saving ``model.pt.last`` after
    each call to the trainer.

    Args:
        args: namespace produced by ``parse_args()``.
    """
    device = args.device
    common.mkdir_for(args.save_path)
    # Make sure the save directory itself exists before files are written into it.
    if args.save_path:
        os.makedirs(args.save_path, exist_ok=True)

    data_loader = _build_data_loader(args)
    train = _build_trainer(args, device)

    # initialization
    final_model_path = os.path.join(args.save_path, 'model.pt')
    last_model_path = os.path.join(args.save_path, 'model.pt.last')
    if os.path.exists(final_model_path):
        print('Already trained, nothing to do!')
        return
    elif args.pretrained:
        train.load(args.pretrained)
    elif os.path.exists(last_model_path):
        train.load(last_model_path)

    # Move again after a possible checkpoint load.
    train = train.to(device)
    # More than one visible GPU: let the trainer distribute the network.
    if ',' in os.environ.get('CUDA_VISIBLE_DEVICES', ''):
        train.distribute()

    # Training loop #
    while train.epoch < args.epochs:
        # shuffle dataset (select new pairs)
        data_loader.dataset.set_epoch(train.epoch)
        train(data_loader)
        train.save(last_model_path)

    # save final model (the `with` block guarantees the file handle is closed)
    with open(final_model_path, 'wb') as f:
        torch.save(train.model.state_dict(), f)
# Per-channel normalization statistics (the standard ImageNet mean / std).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Image -> tensor conversion via common.ToTensor(), followed by normalization
# with the statistics above.
totensor = tvf.Compose([common.ToTensor(), tvf.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)])
class MyTrainer (trainer.Trainer):
    """PUMP-specific trainer.

    The only override is ``forward_backward``, which defines how one batch is
    passed through the network and how the resulting loss is back-propagated.
    """

    def forward_backward(self, inputs):
        """Run the network on both images of a batch, back-propagate, and return the loss details.

        Args:
            inputs: ``((img1, img2), labels)``; ``labels`` is a mapping whose
                entries are forwarded to the loss as keyword arguments.

        Returns:
            The ``details`` part of ``trainer.get_loss(...)``.
        """
        # Must be called with autograd enabled and the net in training mode.
        assert torch.is_grad_enabled() and self.net.training
        images, labels = inputs
        img1, img2 = images

        feats1, feats2 = [self.net(totensor(img)) for img in (img1, img2)]

        raw_loss = self.loss(feats1, feats2, img1=img1, img2=img2, **labels)
        total, details = trainer.get_loss(raw_loss)
        trainer.backward(total)
        return details
if __name__ == '__main__':
    # Script entry point: parse the CLI options, then launch training.
    cli_args = parse_args()
    main(cli_args)
|