# graph2plan/Network/train.py
import os
import json
import sys
import shutil
import random
import pickle
import datetime
import argparse
import pathlib as path
import tqdm
import logging
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from tensorboardX import SummaryWriter
from model.utils import int_tuple, str_tuple, bool_flag
from model.metrics import iou,MetricAverage,image_acc,image_acc_ignore,binary_image_acc
from model.model import Model
from model.floorplan import FloorPlanDataset,floorplan_collate_fn
from model.loss import *
from model.box_utils import *
from model.utils import *
from ignite.contrib.handlers.tensorboard_logger import *
from ignite.contrib.handlers import *
from ignite.contrib.metrics import *
from ignite.metrics.accuracy import _BaseClassification
from ignite.engine import *
from ignite.handlers import *
from ignite.metrics import *
def parse_args():
    parser = argparse.ArgumentParser()
    ''' Dataset '''
    parser.add_argument('--dataset_dir', default='./data', type=str)
    parser.add_argument('--image_size', default='128,128', type=int_tuple)
    parser.add_argument('--input_dim', default=3, type=int)
    parser.add_argument('--with_house', default='0', type=bool_flag)
    parser.add_argument('--pos_dim', default=25, type=int)
    parser.add_argument('--area_dim', default=10, type=int)
    ''' Dataloader '''
    parser.add_argument('--batch_size', default=20, type=int)
    parser.add_argument('--workers', default=8, type=int)
    parser.add_argument('--train_shuffle', default='1', type=bool_flag)
    ''' Model '''
    # architecture
    parser.add_argument('--gene_layout', default='1', type=bool_flag)
    parser.add_argument('--box_refine', default='1', type=bool_flag)
    # input
    parser.add_argument('--embedding_dim', default=128, type=int)
    # refine
    parser.add_argument('--refinement_dims', default='1024, 512, 256, 128, 64', type=int_tuple)
    # box refine
    parser.add_argument('--box_refine_arch', default='I15,C3-64-2,C3-128-2,C3-256-2', type=str)
    parser.add_argument('--roi_cat_feature', default='1', type=bool_flag)
    # control
    parser.add_argument('--gt_box', default=0, type=bool_flag)
    parser.add_argument('--relative', default=1, type=bool_flag)
    ''' Loss '''
    parser.add_argument('--mutex', default=1, type=bool_flag)
    parser.add_argument('--inside', default=1, type=bool_flag)
    parser.add_argument('--coverage', default=1, type=bool_flag)
    parser.add_argument('--render', default=1, type=bool_flag)
    parser.add_argument('--nsample', default=100, type=int)
    parser.add_argument('--loss_refine', default=0, type=bool_flag)
    parser.add_argument('--render_refine', default=0, type=bool_flag)
    ''' Optimizer '''
    parser.add_argument('--optimizer', default='Adam', type=str)
    parser.add_argument('--scheduler', default='plateau', type=str)
    parser.add_argument('--learning_rate', default=1e-4, type=float)
    parser.add_argument('--decay_rate', default=1e-4, type=float)
    parser.add_argument('--step_size', default=10, type=float)
    parser.add_argument('--step_rate', default=0.5, type=float)
    ''' Checkpoints '''
    parser.add_argument('--save_interval', default=5, type=int)
    parser.add_argument('--n_saved', default=20, type=int)
    parser.add_argument('--pretrain', default=None, type=str)
    parser.add_argument('--skip_train', default=0, type=bool_flag)
    ''' Trainer '''
    parser.add_argument('--seed', default=74269, type=int)
    parser.add_argument('--epoch', default=101, type=int)
    parser.add_argument('--start_epoch', default=None, type=int)
    ''' Others '''
    parser.add_argument('--gpu', default='0', type=str)
    parser.add_argument('--multi_gpu', default=None, type=str)
    parser.add_argument('--suffix', default=None, type=str)
    parser.add_argument('--debug', default=0, type=bool_flag)
    parser.add_argument('--test', default=0, type=bool_flag)
    return parser.parse_args()
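# Illustrative invocations only (paths and flag values are examples, not requirements):
#   python train.py --dataset_dir ./data --gpu 0
#   python train.py --skip_train 1 --pretrain <path/to/checkpoint.pth>   # evaluation only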
def check_manual_seed(args):
    seed = args.seed or random.randint(1, 10000)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
def get_model(args):
    return Model(embedding_dim=args.embedding_dim,
                 image_size=args.image_size,
                 input_dim=args.input_dim,
                 attribute_dim=args.pos_dim + args.area_dim,
                 refinement_dims=args.refinement_dims if args.gene_layout else None,
                 box_refine_arch=args.box_refine_arch if args.box_refine else None,
                 roi_cat_feature=args.roi_cat_feature)
def get_dataset(args, split='valid'):
    return FloorPlanDataset(f'{args.dataset_dir}/data_{split}.mat')
def get_dataloader(args, dataset, split):
    print(f"{split}, shuffle:", split == 'train' and args.train_shuffle and (not args.debug))
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True if split == 'train' and args.train_shuffle and (not args.debug) else False,
        num_workers=args.workers,
        drop_last=True if split == 'train' else False,
        collate_fn=floorplan_collate_fn
    )
def get_data_loaders(args):
    train_dataset = get_dataset(args, 'train' if not args.debug else 'valid') if not args.skip_train else None
    valid_dataset = get_dataset(args, 'valid')
    test_dataset = get_dataset(args, 'test')
    train_loader = get_dataloader(args, train_dataset, 'train') if not args.skip_train else None
    valid_loader = get_dataloader(args, valid_dataset, 'valid')
    test_loader = get_dataloader(args, test_dataset, 'test')
    return train_loader, valid_loader, test_loader
def get_optimizer(model, args):
    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    elif args.optimizer == 'AdamW':
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.decay_rate
        )
    return optimizer
def get_scheduler(optimizer, args):
    if args.scheduler == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.step_rate)
    else:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=args.step_rate, patience=args.step_size, threshold=0.005, verbose=True)
    return scheduler
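# get_losses builds a dict of loss modules keyed by name: 'box_mse' (Smooth L1 on the
# predicted boxes) is always present; 'gene_ce' (layout cross-entropy with categories
# 13/14 zero-weighted), 'box_ref_mse', and the geometric terms 'mutex' / 'inside' /
# 'coverage' / 'render' are added according to the corresponding flags.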
def get_losses(args):
    loss = {}
    weight = torch.ones(15).cuda()
    weight[13] = weight[14] = 0  # ignore unused category
    if args.gene_layout:
        loss['gene_ce'] = torch.nn.CrossEntropyLoss(weight=weight)
    loss['box_mse'] = torch.nn.SmoothL1Loss()
    if args.box_refine:
        loss['box_ref_mse'] = torch.nn.SmoothL1Loss()
    if args.mutex:
        loss['mutex'] = MutexLoss(nsample=args.nsample)
    if args.inside:
        loss['inside'] = InsideLoss(nsample=args.nsample)
    if args.coverage:
        loss['coverage'] = CoverageLoss(nsample=args.nsample)
    if args.render:
        loss['render'] = BoxRenderLoss(nsample=args.nsample)
    return loss
def batch_cuda(batch):
    batch = list(batch)
    for i in range(len(batch)):
        if isinstance(batch[i], torch.Tensor):
            batch[i] = batch[i].cuda()
        elif isinstance(batch[i], list) and isinstance(batch[i][0], torch.Tensor):
            batch[i] = [e.cuda() for e in batch[i]]
    return batch
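# Training schedule implemented in update() below: box regression ('box_mse') runs from
# epoch 1, layout generation ('gene_ce') is switched on after epoch 1 and box refinement
# ('box_ref_mse') after epoch 2, each ramped in with step weights 0.1 / 0.5 / 1.0.
# From epoch 2 on, the 'mutex' / 'inside' / 'coverage' terms are scaled by 0.1, while
# the 'render' term is used at full weight.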
def main(args):
    args.epoch = args.epoch if not args.debug else 6
    print("Create dir...")
    start_date = str(datetime.datetime.now().strftime('%Y-%m-%d')) + ("" if not args.debug else "_debug") + ("" if not args.test else "_test")
    if not os.path.exists('../experiment'):
        os.mkdir('../experiment')
    experiment_dir = path.Path(f'../experiment/{start_date}')
    experiment_dir.mkdir(exist_ok=True)
    start_time = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')) + ('' if args.suffix is None else args.suffix)
    file_dir = path.Path(f'{experiment_dir}/DeepLayout_{start_time}')
    file_dir.mkdir(exist_ok=True)
    checkpoints_dir = file_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = file_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    shutil.copy(__file__, log_dir / 'train.py')
    shutil.copytree('./model', log_dir / 'model')
    output_dir = file_dir.joinpath('output/')
    output_dir.mkdir(exist_ok=True)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(str(log_dir) + '/log.txt')
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    if args.skip_train:
        logger.info(f'python {args.argv}')
    else:
        logger.info(f'python {args.argv} --skip_train 1 --pretrain ')
    logger.info(args)
    logger.info('---------------------------------------------------TRAINING---------------------------------------------------')
    logger.info(f'Use seed: {args.seed}')
    # check_manual_seed(args)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu if args.multi_gpu is None else args.multi_gpu
    print("Create dataloader...")
    train_loader, valid_loader, test_loader = get_data_loaders(args)
    print("Create model...")
    model = get_model(args)
    print("Gene:", model.refinement_net is not None and args.gene_layout)
    print("Refine:", args.box_refine)
    print("Cat feat:", args.roi_cat_feature)
    print("GT BOX:", args.gt_box)
    print("Inside Loss:", args.inside)
    print("Coverage Loss:", args.coverage)
    print("Mutex Loss:", args.mutex)
    print("Render Loss:", args.render)
    logger.info(argparse.Namespace(embedding_dim=args.embedding_dim,
                                   image_size=args.image_size,
                                   input_dim=args.input_dim,
                                   attribute_dim=args.pos_dim + args.area_dim,
                                   refinement_dims=args.refinement_dims if args.gene_layout else None,
                                   box_refine_arch=args.box_refine_arch if args.box_refine else None,
                                   roi_cat_feature=args.roi_cat_feature))
    logger.info(str(model))
    optimizer = get_optimizer(model, args)
    scheduler = get_scheduler(optimizer, args)
    loss = get_losses(args)
    if args.pretrain is not None:
        model.load_state_dict(torch.load(args.pretrain))
    print("Cuda...")
    model.cuda()
    def update(engine, batch):
        model.train()
        optimizer.zero_grad()
        boundary, inside_box, objs, attrs, triples, layout, boxes, inside_coords, obj_to_img, triple_to_img, name = batch_cuda(batch)
        if args.relative: boxes = box_rel2abs(boxes, inside_box, obj_to_img)
        model_out = model(
            objs,
            triples,
            boundary,
            obj_to_img=obj_to_img,
            attributes=attrs,
            boxes_gt=boxes if args.gt_box else None,
            generate=args.gene_layout and engine.state.epoch > 1,
            refine=args.box_refine and engine.state.epoch > 2,
            relative=args.relative,
            inside_box=inside_box if args.relative else None,
        )
        boxes_pred, gene_layout, boxes_refine = model_out
        total_loss = 0
        loss_items = {}
        epoch = engine.state.epoch
        # per-stage ramp-up weights: 0.1 -> 0.5 -> 1.0
        step_weight = [0.1, 0.5, 1.0]
        for name in loss:
            l = None
            if name == 'box_mse':
                l = loss[name](boxes_pred, boxes)
            else:
                if epoch > 1:
                    if name == 'gene_ce':
                        l = step_weight[epoch - 2 if epoch <= 3 else -1] * loss[name](gene_layout, layout)
                    elif name == 'mutex':
                        l = 0.1 * loss[name](boxes_pred, obj_to_img, objs)
                        if args.box_refine and args.loss_refine and epoch > 2: l += loss[name](boxes_refine, obj_to_img, objs)
                    elif name == 'inside':
                        l = 0.1 * loss[name](boxes_pred, inside_box, obj_to_img)
                        if args.box_refine and args.loss_refine and epoch > 2: l += loss[name](boxes_refine, inside_box, obj_to_img)
                    elif name == 'coverage':
                        l = 0.1 * loss[name](boxes_pred, inside_coords, obj_to_img)
                        if args.box_refine and args.loss_refine and epoch > 2: l += loss[name](boxes_refine, inside_coords, obj_to_img)
                    elif name == 'render':
                        l = loss[name](boxes_pred, boxes)
                        if args.box_refine and args.loss_refine and epoch > 2: l += loss[name](boxes_refine, boxes)
                if epoch > 2:
                    if name == 'box_ref_mse':
                        l = step_weight[epoch - 3 if epoch <= 4 else -1] * loss[name](boxes_refine, boxes)
            if l is not None:
                total_loss += l
                loss_items[name] = l.item()
        loss_items['total_loss'] = total_loss.item()
        total_loss.backward()
        optimizer.step()
        return loss_items
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            boundary, inside_box, objs, attrs, triples, layout, boxes, inside_coords, obj_to_img, triple_to_img, name = batch_cuda(batch)
            if args.relative: boxes = box_rel2abs(boxes, inside_box, obj_to_img)
            model_out = model(
                objs,
                triples,
                boundary,
                obj_to_img=obj_to_img,
                attributes=attrs,
                boxes_gt=boxes if args.gt_box else None,
                generate=args.gene_layout,
                refine=args.box_refine,
                relative=args.relative,
                inside_box=inside_box if args.relative else None,
            )
            boxes_pred, gene_layout, boxes_refine = model_out
            total_loss = 0
            loss_items = {}
            for name in loss:
                l = None
                if name == 'box_mse':
                    l = loss[name](boxes_pred, boxes)
                if engine.state.epoch > 1:
                    if name == 'gene_ce':
                        l = loss[name](gene_layout, layout)
                    elif name == 'mutex':
                        l = 0.1 * loss[name](boxes_pred, obj_to_img, objs)
                        if args.box_refine and args.loss_refine: l += 0.1 * loss[name](boxes_refine, obj_to_img, objs)
                    elif name == 'inside':
                        l = 0.1 * loss[name](boxes_pred, inside_box, obj_to_img)
                        if args.box_refine and args.loss_refine: l += 0.1 * loss[name](boxes_refine, inside_box, obj_to_img)
                    elif name == 'coverage':
                        l = 0.1 * loss[name](boxes_pred, inside_coords, obj_to_img)
                        if args.box_refine and args.loss_refine: l += 0.1 * loss[name](boxes_refine, inside_coords, obj_to_img)
                    elif name == 'render':
                        l = loss[name](boxes_pred, boxes)
                        if args.box_refine and args.loss_refine: l += loss[name](boxes_refine, boxes)
                if engine.state.epoch > 2:
                    if name == 'box_ref_mse':
                        l = loss[name](boxes_refine, boxes)
                if l is not None:
                    total_loss += l
                    loss_items[name] = l.item()
            loss_items['total_loss'] = total_loss.item()
            # boxes pred
            boxes_pred = boxes_pred.detach()
            boxes_pred = centers_to_extents(boxes_pred)
            if args.gene_layout:
                gene_layout = gene_layout * boundary[:, :1]
            # boxes refine
            if args.box_refine:
                boxes_refine = boxes_refine.detach()
                boxes_refine = centers_to_extents(boxes_refine)
            # gt
            boxes = centers_to_extents(boxes)
            return {
                'loss': loss_items,
                'pred': [
                    boxes_pred,
                    gene_layout.detach() if args.gene_layout else None,
                    boxes_refine if args.box_refine else None,
                ],
                'gt': [layout, boxes]
            }
print("Create trainer...")
optimizer.step()
scheduler.step(0)
trainer = Engine(update)
valid_evaluator = Engine(inference)
if args.start_epoch is not None:
@trainer.on(Events.STARTED)
def set_up_state(engine):
engine.state.epoch = args.start_epoch
total_func = lambda e:(e.state.metrics['box_iou']+(e.state.metrics['gene_acc'] if args.gene_layout else 0)+(e.state.metrics['box_refine_iou'] if args.box_refine else 0))
@valid_evaluator.on(Events.COMPLETED)
def schedual(engine):
optimizer.step()
if args.scheduler == 'step':
scheduler.step()
else:
scheduler.step(total_func(engine))
@trainer.on(Events.EPOCH_COMPLETED)
def evaluate(engine):
valid_evaluator.run(valid_loader)
# Metrics
MetricAverage(output_transform=lambda output:iou(output['pred'][0],output['gt'][1])).attach(valid_evaluator,'box_iou')
if args.gene_layout:
MetricAverage(output_transform=lambda output:image_acc_ignore(output['pred'][1],output['gt'][0],13)).attach(valid_evaluator,'gene_acc')
if args.box_refine:
MetricAverage(output_transform=lambda output:iou(output['pred'][2],output['gt'][1])).attach(valid_evaluator,'box_refine_iou')
metrics = ['img_acc','box_iou','mask_acc']
# TQDM
ProgressBar(persist=True).attach(trainer, output_transform=lambda o:{'loss':o['total_loss']}, metric_names='all')
ProgressBar(persist=False).attach(valid_evaluator, output_transform=lambda o:{'loss':o['loss']['total_loss']},metric_names='all')
# Tensorboard
tb_logger = TensorboardLogger(log_dir=log_dir)
tb_logger.attach(trainer,
log_handler=OutputHandler(tag="train",output_transform=lambda o: o,metric_names='all'),
event_name=Events.ITERATION_COMPLETED)
tb_logger.attach(trainer,
log_handler=OptimizerParamsHandler(optimizer),
event_name=Events.ITERATION_STARTED)
tb_logger.attach(valid_evaluator,
log_handler=OutputHandler(tag="valid",output_transform=lambda o:o['loss'],metric_names='all', global_step_transform=global_step_from_engine(trainer)),
event_name=Events.EPOCH_COMPLETED)
# Logging
@trainer.on(Events.EPOCH_COMPLETED)
def log_results(engine):
logging.info(f'Train, Epoch{engine.state.epoch}, Loss: {str(engine.state.output)}')
@valid_evaluator.on(Events.EPOCH_COMPLETED)
def log_results(engine):
loss = engine.state.output['loss']
metrics = engine.state.metrics
logging.info(f'Valid, Epoch{engine.state.epoch}, Loss: {str(loss)}')
logging.info(f'Valid, Epoch{engine.state.epoch}, Metrics: {str(metrics)}')
# Checkpoint
epoch_saver = ModelCheckpoint(checkpoints_dir, 'epoch',save_interval=args.save_interval,n_saved=args.n_saved, require_empty=False, create_dir=True)
latest_saver = ModelCheckpoint(checkpoints_dir, 'latest',score_function=lambda e:e.state.epoch,n_saved=1, require_empty=False, create_dir=True)
loss_saver = ModelCheckpoint(checkpoints_dir, 'loss',score_function=lambda e:-e.state.output['loss']['total_loss'],n_saved=1, require_empty=False, create_dir=True)
trainer.add_event_handler(Events.EPOCH_COMPLETED, latest_saver, {'model': model,'opt':optimizer})
trainer.add_event_handler(Events.EPOCH_COMPLETED, epoch_saver, {'model': model,'opt':optimizer})
valid_evaluator.add_event_handler(Events.COMPLETED, loss_saver, {'model': model})
if not args.skip_train:
trainer.run(train_loader,max_epochs=args.epoch)
tb_logger.close()
    output = {}
    def test(engine, batch):
        model.eval()
        with torch.no_grad():
            boundary, inside_box, objs, attrs, triples, layout, boxes, inside_coords, obj_to_img, triple_to_img, name = batch_cuda(batch)
            model_out = model(
                objs,
                triples,
                boundary,
                obj_to_img=obj_to_img,
                attributes=attrs,
                boxes_gt=boxes if args.gt_box else None,
                generate=args.gene_layout,
                refine=args.box_refine,
                relative=args.relative,
                inside_box=inside_box if args.relative else None,
            )
            boxes_pred, gene_layout, boxes_refine = model_out
            ''' box: x_c,y_c,w,h -> x0,y0,x1,y1 '''
            # boxes pred
            boxes_pred = boxes_pred.detach()
            boxes_pred = centers_to_extents(boxes_pred)
            # boxes refine
            if args.box_refine:
                boxes_refine = boxes_refine.detach()
                boxes_refine = centers_to_extents(boxes_refine)
            # gt
            if args.relative: boxes = box_rel2abs(boxes, inside_box, obj_to_img)
            boxes = centers_to_extents(boxes)
            ''' layout: B*C*H*W -> B*H*W '''
            if args.gene_layout:
                gene_layout = gene_layout * boundary[:, :1]
                gene_preds = torch.argmax(gene_layout.softmax(1).detach(), dim=1)
            ''' layout with outside '''
            for i in range(len(layout)):
                mask = boundary[i, 0] == 0
                if args.gene_layout:
                    gene_preds[i][mask] = 13
            ''' metrics '''
            # box iou
            box_ious = iou(boxes_pred, boxes)
            box_refine_ious = None
            if args.box_refine:
                box_refine_ious = iou(boxes_refine, boxes)
            gene_acc_all = None
            gene_acc_fg = None
            if args.gene_layout:
                gene_acc_all = image_acc(gene_preds, layout)
                gene_acc_fg = image_acc_ignore(gene_preds, layout, 13)
            ''' save output '''
            for i in range(len(layout)):
                ''' objs '''
                obj = objs[obj_to_img == i].cpu().numpy()
                ''' box '''
                box_pred = boxes_pred[obj_to_img == i]
                box_pred = box_pred.cpu().numpy()
                box_iou = box_ious[obj_to_img == i].view(-1).cpu().numpy()
                box_refine = None
                if args.box_refine:
                    box_refine = boxes_refine[obj_to_img == i].cpu().numpy()
                    box_refine_iou = box_refine_ious[obj_to_img == i].view(-1).cpu().numpy()
                ''' layout '''
                if args.gene_layout:
                    gene_pred = gene_preds[i].cpu().numpy().astype('uint8')
                output[name[i]] = {
                    'obj': obj,
                    'box_gt': boxes[obj_to_img == i].cpu().numpy(),
                    'box_pred': box_pred,
                    'box_iou': box_iou,
                    'box_refine': box_refine if args.box_refine else None,
                    'box_refine_iou': box_refine_iou if args.box_refine else None,
                    'gene_pred': gene_pred if args.gene_layout else None,
                    'gene_acc_all': gene_acc_all[i].item() if args.gene_layout else None,
                    'gene_acc_fg': gene_acc_fg[i].item() if args.gene_layout else None
                }
            return {
                'pred': [
                    boxes_pred,  # 0
                    gene_preds if args.gene_layout else None,  # 1
                    boxes_refine if args.box_refine else None,  # 2
                ],
                'gt': [layout, boxes]
            }
    test_evaluator = Engine(test)
    MetricAverage(output_transform=lambda output: iou(output['pred'][0], output['gt'][1])).attach(test_evaluator, 'box_iou')
    if args.gene_layout:
        MetricAverage(output_transform=lambda output: image_acc_ignore(output['pred'][1], output['gt'][0], 13)).attach(test_evaluator, 'gene_acc')
        MetricAverage(output_transform=lambda output: image_acc(output['pred'][1], output['gt'][0])).attach(test_evaluator, 'gene_acc_all')
    if args.box_refine:
        MetricAverage(output_transform=lambda output: iou(output['pred'][2], output['gt'][1])).attach(test_evaluator, 'box_refine_iou')
    ProgressBar(persist=False).attach(test_evaluator)
    @test_evaluator.on(Events.COMPLETED)
    def save_metrics(engine):
        metrics = engine.state.metrics
        with open(f'{output_dir}/output_{start_time}_metrics.json', 'w') as f:
            f.write(str(metrics))
    if not args.skip_train:
        test_evaluator.run(valid_loader)
    else:
        test_evaluator.run(test_loader)
    with open(f'{output_dir}/output_{start_time}.pkl', 'wb') as f:
        pickle.dump(output, f, pickle.HIGHEST_PROTOCOL)
if __name__ == "__main__":
    args = parse_args()
    args.argv = ' '.join(sys.argv)
    main(args)