# tools/train_detector.py
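"""Train and evaluate a Faster R-CNN detector on MOTSynth / MOT17.

The script builds the datasets and data loaders, constructs the model via
ModelFactory, and runs the train/evaluate loop, saving checkpoints, loss plots,
a torchinfo summary and COCO-style metrics under OUTPUT_DIR/detection_logs.

Illustrative invocation (assumes the data is available under MOTSYNTH_ROOT and
MOTCHA_ROOT as configured in configs/path_cfg.py):

    python tools/train_detector.py --train-dataset motsynth_split1 \
        --val-dataset MOT17 --batch-size 3 --epochs 30 --output-dir fasterrcnn_training
"""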
from typing import List
from configs.path_cfg import MOTSYNTH_ROOT, MOTCHA_ROOT, OUTPUT_DIR
import datetime
import os.path as osp
import os
import time
import coloredlogs
import logging
from torchinfo import summary
import torch
import torch.utils.data
from src.detection.vision.mot_data import MOTObjDetect
from src.detection.model_factory import ModelFactory
from src.detection.graph_utils import save_train_loss_plot
import src.detection.vision.presets as presets
import src.detection.vision.utils as utils
from src.detection.vision.engine import train_one_epoch, evaluate
from src.detection.vision.group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
from src.detection.mot_dataset import get_mot_dataset
import torchvision
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn_v2
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
coloredlogs.install(level='DEBUG')
logger = logging.getLogger(__name__)
def get_args_parser(add_help=True):
import argparse
parser = argparse.ArgumentParser(
description="PyTorch Detection Training", add_help=add_help)
# Output directory used to save model, plots and summary
parser.add_argument("--output-dir", default='fasterrcnn_training',
type=str, help="Path to save outputs (default: fasterrcnn_training)")
# Dataset params
parser.add_argument("--train-dataset", default="motsynth_split1",
type=str, help="Dataset name. Please select one of the following: motsynth_split1, motsynth_split2, motsynth_split3, motsynth_split4, MOT17 (default: motsynth_split1)")
parser.add_argument("--val-dataset", default="MOT17",
type=str, help="Dataset name. Please select one of the following: MOT17 (default: MOT17)")
# Transforms params
parser.add_argument(
"--data-augmentation", default="hflip", type=str, help="Data augmentation policy (default: hflip)"
)
# Data Loaders params
    parser.add_argument(
        "-b", "--batch-size", default=3, type=int, help="Images per GPU (default: 3)"
    )
parser.add_argument(
"-j", "--workers", default=0, type=int, metavar="N", help="Number of data loading workers (default: 0)"
)
parser.add_argument("--aspect-ratio-group-factor", default=3,
type=int, help="Aspect ration group factor (default:3)")
# Model param
parser.add_argument(
"--model", default="fasterrcnn_resnet50_fpn", type=str, help="Model name (default: fasterrcnn_resnet50_fpn)")
parser.add_argument(
"--weights", default="DEFAULT", type=str, help="Model weights (default: DEFAULT)"
)
parser.add_argument(
"--backbone", default='resnet50', type=str, help="Type of backbone (default: resnet50)"
)
parser.add_argument(
"--trainable-backbone-layers", default=3, type=int, help="Number of trainable layers of backbone (default: 3)"
)
parser.add_argument(
"--backbone-weights", default="DEFAULT", type=str, help="Backbone weights (default: DEFAULT)"
)
# Device param
parser.add_argument("--device", default="cuda", type=str,
help="device (default: cuda)")
# Test mode param
parser.add_argument(
"--test-only",
dest="test_only",
help="Only test the model",
action="store_true",
)
    parser.add_argument(
        "--model-eval", type=str, help="Model checkpoint path for test-only mode"
    )
# Optimizer params
parser.add_argument(
"--lr",
default=0.0025,
type=float,
help="Learning rate (default: 0.0025)",
)
parser.add_argument("--momentum", default=0.9,
type=float, metavar="M", help="Momentum (default: 0.9")
parser.add_argument(
"--wd",
"--weight-decay",
default=1e-4,
type=float,
metavar="W",
help="Weight decay (default: 1e-4)",
dest="weight_decay",
)
# Lr Scheduler params
parser.add_argument(
"--lr-scheduler", default="multisteplr", type=str, help="Name of lr scheduler (default: multisteplr)"
)
parser.add_argument(
"--lr-steps",
default=[16, 22],
nargs="+",
type=int,
help="Decrease lr every step-size epochs (multisteplr scheduler only)",
)
parser.add_argument(
"--lr-gamma", default=0.1, type=float, help="Decrease lr by a factor of lr-gamma (multisteplr scheduler only)"
)
# Mixed precision training params
parser.add_argument("--amp", action="store_true",
help="Use torch.cuda.amp for mixed precision training")
# Resume training params
parser.add_argument("--resume", default="", type=str,
help="path of checkpoint")
    # Training params
    parser.add_argument("--start_epoch", default=0,
                        type=int, help="Start epoch (default: 0)")
    parser.add_argument("--epochs", default=30, type=int,
                        metavar="N", help="Number of total epochs to run (default: 30)")
    parser.add_argument("--print-freq", default=20,
                        type=int, help="Print frequency (default: 20)")
return parser
def get_transform(train, data_augmentation):
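    """Return the training preset (with the given augmentation policy) or the eval preset."""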
if train:
return presets.DetectionPresetTrain(data_augmentation)
else:
return presets.DetectionPresetEval()
def get_motsynth_dataset(ds_name: str, transforms):
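    """Load a MOTSynth split from its combined annotation file under MOTSYNTH_ROOT."""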
data_path = osp.join(MOTSYNTH_ROOT, 'comb_annotations', f"{ds_name}.json")
dataset = get_mot_dataset(MOTSYNTH_ROOT, data_path, transforms=transforms)
return dataset
def get_MOT17_dataset(split: str, split_seqs: List, transforms):
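    """Build a MOTObjDetect dataset over the given MOT17 training sequences."""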
data_path = osp.join(MOTCHA_ROOT, "MOT17", "train")
dataset = MOTObjDetect(
data_path, transforms=transforms, split_seqs=split_seqs)
return dataset
def create_dataset(ds_name: str, transforms, split=None):
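    """Create a dataset by name: a MOTSynth split or a fixed train/test split of MOT17 sequences."""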
    if ds_name.startswith("motsynth"):
        return get_motsynth_dataset(ds_name, transforms)
    elif ds_name.startswith("MOT17"):
        if split == "train":
            split_seqs = ['MOT17-02-FRCNN', 'MOT17-04-FRCNN',
                          'MOT17-11-FRCNN', 'MOT17-13-FRCNN']
        elif split == "test":
            split_seqs = ['MOT17-09-FRCNN', 'MOT17-10-FRCNN', 'MOT17-05-FRCNN']
        else:
            raise ValueError(f"Invalid split '{split}'. Please use 'train' or 'test'.")
        return get_MOT17_dataset(split, split_seqs, transforms)
    else:
        logger.error(
            "Please provide a valid dataset as argument. Select one of the following: "
            "motsynth_split1, motsynth_split2, motsynth_split3, motsynth_split4, MOT17.")
        raise ValueError(ds_name)
def create_data_loader(dataset, split: str, batch_size, workers, aspect_ratio_group_factor=-1):
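    """Build a DataLoader for the given split: randomly sampled (optionally
    aspect-ratio grouped) batches for training, sequential batches of size 1
    for evaluation."""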
data_loader = None
if split == "train":
# random sampling on training dataset
train_sampler = torch.utils.data.RandomSampler(dataset)
if aspect_ratio_group_factor >= 0:
group_ids = create_aspect_ratio_groups(
dataset, k=aspect_ratio_group_factor)
train_batch_sampler = GroupedBatchSampler(
train_sampler, group_ids, batch_size)
else:
train_batch_sampler = torch.utils.data.BatchSampler(
train_sampler, batch_size, drop_last=True)
data_loader = torch.utils.data.DataLoader(
dataset, batch_sampler=train_batch_sampler, num_workers=workers, collate_fn=utils.collate_fn
)
elif split == "test":
# sequential sampling on eval dataset
test_sampler = torch.utils.data.SequentialSampler(dataset)
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=1, sampler=test_sampler, num_workers=workers, collate_fn=utils.collate_fn
)
return data_loader
def create_optimizer(model, lr, momentum, weight_decay):
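    """Create an SGD optimizer over the model's trainable parameters."""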
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
params, lr=lr, momentum=momentum, weight_decay=weight_decay)
return optimizer
def create_lr_scheduler(optimizer, lr_scheduler_type, lr_steps, lr_gamma, epochs):
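    """Create a MultiStepLR or CosineAnnealingLR learning rate scheduler."""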
if lr_scheduler_type == "multisteplr":
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, milestones=lr_steps, gamma=lr_gamma)
logger.debug(
f"lr_scheduler: {lr_scheduler_type}, milestones: {lr_steps}, gamma: {lr_gamma}")
elif lr_scheduler_type == "cosineannealinglr":
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=epochs)
logger.debug(
f"lr_scheduler: {lr_scheduler_type}, T_max: {epochs}")
else:
raise RuntimeError(
f"Invalid lr scheduler '{lr_scheduler_type}'. Only MultiStepLR and CosineAnnealingLR are supported."
)
return lr_scheduler
def resume_training(model, optimizer, lr_scheduler, scaler, args):
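    """Restore model, optimizer, lr scheduler and (with --amp) scaler state from the
    checkpoint at args.resume, and update args.start_epoch accordingly."""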
checkpoint = torch.load(args.resume, map_location="cpu")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
args.start_epoch = checkpoint["epoch"] + 1
if args.amp:
scaler.load_state_dict(checkpoint["scaler"])
def save_model_checkpoint(model, optimizer, lr_scheduler, epoch, scaler, output_dir, args):
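    """Save the training state both as a per-epoch file (model_{epoch}.pth) and as a
    rolling checkpoint.pth in output_dir."""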
if output_dir:
checkpoint = {
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
"lr_scheduler": lr_scheduler.state_dict(),
"args": args,
"epoch": epoch,
}
if args.amp:
checkpoint["scaler"] = scaler.state_dict()
utils.save_on_master(checkpoint, os.path.join(
output_dir, f"model_{epoch}.pth"))
utils.save_on_master(checkpoint, os.path.join(
output_dir, "checkpoint.pth"))
def save_plots(losses_dict, batch_loss_dict, output_dir):
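    """Accumulate the per-batch losses of the last epoch into losses_dict and save the
    updated training loss plots to output_dir."""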
if not losses_dict:
for name, metric in batch_loss_dict.items():
losses_dict[name] = []
for name, metric in batch_loss_dict.items():
losses_dict[name].extend(metric)
save_train_loss_plot(losses_dict, output_dir)
def save_model_summary(model, output_dir, batch_size):
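    """Write a torchinfo summary of the model (for a batch of 1080x1920 RGB images) to
    summary.txt in output_dir."""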
with open(osp.join(output_dir, "summary.txt"), 'w', encoding="utf-8") as f:
print(summary(model,
# (batch_size, color_channels, height, width)
input_size=(batch_size, 3, 1080, 1920),
verbose=0,
col_names=["input_size", "output_size",
"num_params", "kernel_size", "trainable"],
col_width=20,
row_settings=["var_names"]), file=f)
def save_args(output_dir, args):
with open(osp.join(output_dir, "args.txt"), 'w', encoding="utf-8") as f:
print(args, file=f)
def save_evaluate_summary(stats, output_dir):
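    """Write the COCO bbox metrics (AP, AP50, AP75, APs, APm, APl) to evaluate.txt in output_dir."""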
metrics = ["AP", "AP50", "AP75", "APs", "APm", "APl"]
# the standard metrics
results = {
metric: float(stats[idx] *
100 if stats[idx] >= 0 else "nan")
for idx, metric in enumerate(metrics)
}
with open(osp.join(output_dir, "evaluate.txt"), 'w', encoding="utf-8") as f:
print(results, file=f)
def main(args):
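    """Entry point: build datasets, loaders, model, optimizer and scheduler, then run the
    train/evaluate loop (or evaluation only with --test-only)."""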
output_dir = None
if args.output_dir:
output_dir = osp.join(
OUTPUT_DIR, 'detection_logs', args.output_dir)
utils.mkdir(output_dir)
output_plots_dir = osp.join(output_dir, "plots")
utils.mkdir(output_plots_dir)
logger.debug("COMMAND LINE ARGUMENTS")
logger.debug(args)
save_args(output_dir, args)
device = torch.device(args.device)
logger.debug(f"DEVICE: {device}")
logger.debug("CREATE DATASETS")
ds_train_name = args.train_dataset
ds_val_name = args.val_dataset
data_augmentation = args.data_augmentation
dataset_train = create_dataset(
ds_train_name, get_transform(True, data_augmentation), "train")
dataset_test = create_dataset(
ds_val_name, get_transform(False, data_augmentation), "test")
logger.debug("CREATE DATA LOADERS")
batch_size = args.batch_size
workers = args.workers
aspect_ratio_group_factor = args.aspect_ratio_group_factor
data_loader_train = create_data_loader(
dataset_train, "train", batch_size, workers, aspect_ratio_group_factor)
data_loader_test = create_data_loader(
dataset_test, "test", batch_size, workers)
if args.test_only:
logger.debug("TEST ONLY")
model = fasterrcnn_resnet50_fpn_v2()
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
        # load on CPU first, then move to the requested device
        checkpoint = torch.load(args.model_eval, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.to(device)
coco_evaluator = evaluate(model, data_loader_test,
device=device, iou_types=['bbox'])
save_evaluate_summary(
coco_evaluator.coco_eval['bbox'].stats, output_dir)
return
logger.debug("CREATE MODEL")
model_name = args.model
weights = args.weights
backbone = args.backbone
backbone_weights = args.backbone_weights
trainable_backbone_layers = args.trainable_backbone_layers
model = ModelFactory.get_model(
model_name, weights, backbone, backbone_weights, trainable_backbone_layers)
save_model_summary(model, output_dir, batch_size)
logger.debug("CREATE OPTIMIZER")
lr = args.lr
momentum = args.momentum
weight_decay = args.weight_decay
optimizer = create_optimizer(
model, lr, momentum, weight_decay)
logger.debug("CREATE LR SCHEDULER")
epochs = args.epochs
lr_scheduler_type = args.lr_scheduler.lower()
lr_steps = args.lr_steps
lr_gamma = args.lr_gamma
lr_scheduler = create_lr_scheduler(
optimizer, lr_scheduler_type, lr_steps, lr_gamma, epochs)
logger.debug("CONFIGURE SCALER FOR amp")
scaler = torch.cuda.amp.GradScaler() if args.amp else None
if args.resume:
logger.debug("RESUME TRAINING")
resume_training(model, optimizer, lr_scheduler,
scaler, args)
logger.debug("START TRAINING")
print_freq = args.print_freq
start_epoch = args.start_epoch
losses_dict = {}
start_time = time.time()
for epoch in range(start_epoch, epochs):
_, batch_loss_dict = train_one_epoch(model, optimizer, data_loader_train, device,
epoch, print_freq, scaler)
lr_scheduler.step()
save_plots(losses_dict, batch_loss_dict,
output_dir=output_plots_dir)
coco_evaluator = evaluate(model, data_loader_test,
device=device, iou_types=['bbox'])
save_evaluate_summary(
coco_evaluator.coco_eval['bbox'].stats, output_dir)
save_model_checkpoint(
model, optimizer, lr_scheduler, epoch, scaler, output_dir, args)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logger.debug(f"TRAINING TIME: {total_time_str}")
if __name__ == "__main__":
args = get_args_parser().parse_args()
main(args)