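"""Training and evaluation entry point for Faster R-CNN pedestrian detection.

Trains on MOTSynth splits or MOT17 and evaluates on MOT17, saving checkpoints,
loss plots, a model summary and COCO-style metrics under
OUTPUT_DIR/detection_logs/<output-dir>.
"""
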
from typing import List
import datetime
import logging
import os
import os.path as osp
import time

import coloredlogs
import torch
import torch.utils.data
import torchvision
from torchinfo import summary
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn_v2
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

from configs.path_cfg import MOTSYNTH_ROOT, MOTCHA_ROOT, OUTPUT_DIR
from src.detection.vision.mot_data import MOTObjDetect
from src.detection.model_factory import ModelFactory
from src.detection.graph_utils import save_train_loss_plot
import src.detection.vision.presets as presets
import src.detection.vision.utils as utils
from src.detection.vision.engine import train_one_epoch, evaluate
from src.detection.vision.group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
from src.detection.mot_dataset import get_mot_dataset

coloredlogs.install(level='DEBUG')
logger = logging.getLogger(__name__)

def get_args_parser(add_help=True):
    import argparse
    parser = argparse.ArgumentParser(
        description="PyTorch Detection Training", add_help=add_help)

    # Output directory used to save model, plots and summary
    parser.add_argument("--output-dir", default='fasterrcnn_training',
                        type=str, help="Path to save outputs (default: fasterrcnn_training)")

    # Dataset params
    parser.add_argument("--train-dataset", default="motsynth_split1",
                        type=str, help="Dataset name. Please select one of the following: motsynth_split1, motsynth_split2, motsynth_split3, motsynth_split4, MOT17 (default: motsynth_split1)")
    parser.add_argument("--val-dataset", default="MOT17",
                        type=str, help="Dataset name. Please select one of the following: MOT17 (default: MOT17)")

    # Transforms params
    parser.add_argument(
        "--data-augmentation", default="hflip", type=str, help="Data augmentation policy (default: hflip)"
    )

    # Data loader params
    parser.add_argument(
        "-b", "--batch-size", default=3, type=int, help="Images per GPU (default: 3)"
    )
    parser.add_argument(
        "-j", "--workers", default=0, type=int, metavar="N", help="Number of data loading workers (default: 0)"
    )
    parser.add_argument("--aspect-ratio-group-factor", default=3,
                        type=int, help="Aspect ratio group factor (default: 3)")

    # Model params
    parser.add_argument(
        "--model", default="fasterrcnn_resnet50_fpn", type=str, help="Model name (default: fasterrcnn_resnet50_fpn)")
    parser.add_argument(
        "--weights", default="DEFAULT", type=str, help="Model weights (default: DEFAULT)"
    )
    parser.add_argument(
        "--backbone", default='resnet50', type=str, help="Type of backbone (default: resnet50)"
    )
    parser.add_argument(
        "--trainable-backbone-layers", default=3, type=int, help="Number of trainable backbone layers (default: 3)"
    )
    parser.add_argument(
        "--backbone-weights", default="DEFAULT", type=str, help="Backbone weights (default: DEFAULT)"
    )

    # Device param
    parser.add_argument("--device", default="cuda", type=str,
                        help="Device (default: cuda)")

    # Test mode params
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--model-eval", type=str, help="Model path for test-only mode"
    )

    # Optimizer params
    parser.add_argument(
        "--lr",
        default=0.0025,
        type=float,
        help="Learning rate (default: 0.0025)",
    )
    parser.add_argument("--momentum", default=0.9,
                        type=float, metavar="M", help="Momentum (default: 0.9)")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="Weight decay (default: 1e-4)",
        dest="weight_decay",
    )

    # LR scheduler params
    parser.add_argument(
        "--lr-scheduler", default="multisteplr", type=str, help="Name of lr scheduler (default: multisteplr)"
    )
    parser.add_argument(
        "--lr-steps",
        default=[16, 22],
        nargs="+",
        type=int,
        help="Decrease lr every step-size epochs (multisteplr scheduler only)",
    )
    parser.add_argument(
        "--lr-gamma", default=0.1, type=float, help="Decrease lr by a factor of lr-gamma (multisteplr scheduler only)"
    )

    # Mixed precision training params
    parser.add_argument("--amp", action="store_true",
                        help="Use torch.cuda.amp for mixed precision training")

    # Resume training params
    parser.add_argument("--resume", default="", type=str,
                        help="Path of checkpoint")

    # Training params
    parser.add_argument("--start_epoch", default=0,
                        type=int, help="Start epoch")
    parser.add_argument("--epochs", default=30, type=int,
                        metavar="N", help="Number of total epochs to run")
    parser.add_argument("--print-freq", default=20,
                        type=int, help="Print frequency")

    return parser

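# Example invocation (illustrative; the script name and values are
# placeholders, adjust them to your setup):
#   python train_detector.py --train-dataset motsynth_split1 --val-dataset MOT17 \
#       --batch-size 3 --epochs 30 --lr 0.0025 --output-dir fasterrcnn_training
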
def get_transform(train, data_augmentation):
    if train:
        return presets.DetectionPresetTrain(data_augmentation)
    else:
        return presets.DetectionPresetEval()

def get_motsynth_dataset(ds_name: str, transforms):
    data_path = osp.join(MOTSYNTH_ROOT, 'comb_annotations', f"{ds_name}.json")
    dataset = get_mot_dataset(MOTSYNTH_ROOT, data_path, transforms=transforms)
    return dataset


def get_MOT17_dataset(split: str, split_seqs: List, transforms):
    data_path = osp.join(MOTCHA_ROOT, "MOT17", "train")
    dataset = MOTObjDetect(
        data_path, transforms=transforms, split_seqs=split_seqs)
    return dataset

def create_dataset(ds_name: str, transforms, split=None):
    if ds_name.startswith("motsynth"):
        return get_motsynth_dataset(ds_name, transforms)
    elif ds_name.startswith("MOT17"):
        if split == "train":
            split_seqs = ['MOT17-02-FRCNN', 'MOT17-04-FRCNN',
                          'MOT17-11-FRCNN', 'MOT17-13-FRCNN']
        elif split == "test":
            split_seqs = ['MOT17-09-FRCNN', 'MOT17-10-FRCNN', 'MOT17-05-FRCNN']
        else:
            raise ValueError(
                f"Invalid split '{split}' for MOT17. Select 'train' or 'test'.")
        return get_MOT17_dataset(split, split_seqs, transforms)
    else:
        logger.error(
            "Please provide a valid dataset as argument. Select one of the following: motsynth_split1, motsynth_split2, motsynth_split3, motsynth_split4, MOT17.")
        raise ValueError(ds_name)

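# Training uses random sampling; when aspect_ratio_group_factor >= 0, images
# with similar aspect ratios are batched together via GroupedBatchSampler,
# which reduces padding within a batch. Evaluation uses sequential sampling
# with batch size 1.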
def create_data_loader(dataset, split: str, batch_size, workers, aspect_ratio_group_factor=-1):
    data_loader = None
    if split == "train":
        # Random sampling on the training dataset
        train_sampler = torch.utils.data.RandomSampler(dataset)
        if aspect_ratio_group_factor >= 0:
            group_ids = create_aspect_ratio_groups(
                dataset, k=aspect_ratio_group_factor)
            train_batch_sampler = GroupedBatchSampler(
                train_sampler, group_ids, batch_size)
        else:
            train_batch_sampler = torch.utils.data.BatchSampler(
                train_sampler, batch_size, drop_last=True)
        data_loader = torch.utils.data.DataLoader(
            dataset, batch_sampler=train_batch_sampler, num_workers=workers, collate_fn=utils.collate_fn
        )
    elif split == "test":
        # Sequential sampling on the eval dataset
        test_sampler = torch.utils.data.SequentialSampler(dataset)
        data_loader = torch.utils.data.DataLoader(
            dataset, batch_size=1, sampler=test_sampler, num_workers=workers, collate_fn=utils.collate_fn
        )
    return data_loader

def create_optimizer(model, lr, momentum, weight_decay):
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(
        params, lr=lr, momentum=momentum, weight_decay=weight_decay)
    return optimizer


def create_lr_scheduler(optimizer, lr_scheduler_type, lr_steps, lr_gamma, epochs):
    if lr_scheduler_type == "multisteplr":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=lr_steps, gamma=lr_gamma)
        logger.debug(
            f"lr_scheduler: {lr_scheduler_type}, milestones: {lr_steps}, gamma: {lr_gamma}")
    elif lr_scheduler_type == "cosineannealinglr":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=epochs)
        logger.debug(
            f"lr_scheduler: {lr_scheduler_type}, T_max: {epochs}")
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{lr_scheduler_type}'. Only MultiStepLR and CosineAnnealingLR are supported."
        )
    return lr_scheduler

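# Checkpoints bundle model, optimizer and lr_scheduler state together with the
# last completed epoch (plus the AMP scaler state when --amp is set), so
# resuming restores all of them and continues from epoch + 1.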
def resume_training(model, optimizer, lr_scheduler, scaler, args):
    checkpoint = torch.load(args.resume, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
    args.start_epoch = checkpoint["epoch"] + 1
    if args.amp:
        scaler.load_state_dict(checkpoint["scaler"])


def save_model_checkpoint(model, optimizer, lr_scheduler, epoch, scaler, output_dir, args):
    if output_dir:
        checkpoint = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
            "args": args,
            "epoch": epoch,
        }
        if args.amp:
            checkpoint["scaler"] = scaler.state_dict()
        utils.save_on_master(checkpoint, os.path.join(
            output_dir, f"model_{epoch}.pth"))
        utils.save_on_master(checkpoint, os.path.join(
            output_dir, "checkpoint.pth"))

def save_plots(losses_dict, batch_loss_dict, output_dir):
    if not losses_dict:
        for name in batch_loss_dict:
            losses_dict[name] = []
    for name, metric in batch_loss_dict.items():
        losses_dict[name].extend(metric)
    save_train_loss_plot(losses_dict, output_dir)

def save_model_summary(model, output_dir, batch_size):
    with open(osp.join(output_dir, "summary.txt"), 'w', encoding="utf-8") as f:
        print(summary(model,
                      # (batch_size, color_channels, height, width)
                      input_size=(batch_size, 3, 1080, 1920),
                      verbose=0,
                      col_names=["input_size", "output_size",
                                 "num_params", "kernel_size", "trainable"],
                      col_width=20,
                      row_settings=["var_names"]), file=f)


def save_args(output_dir, args):
    with open(osp.join(output_dir, "args.txt"), 'w', encoding="utf-8") as f:
        print(args, file=f)

def save_evaluate_summary(stats, output_dir):
    # The standard COCO detection metrics, reported in percent
    # (nan when a metric is undefined for the evaluated data)
    metrics = ["AP", "AP50", "AP75", "APs", "APm", "APl"]
    results = {
        metric: float(stats[idx] * 100) if stats[idx] >= 0 else float("nan")
        for idx, metric in enumerate(metrics)
    }
    with open(osp.join(output_dir, "evaluate.txt"), 'w', encoding="utf-8") as f:
        print(results, file=f)

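# End-to-end flow: build datasets and loaders, create the model (or load one
# for --test-only), set up the optimizer, lr scheduler and optional AMP
# scaler, then train, evaluating and checkpointing after every epoch.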
def main(args):
    output_dir = None
    if args.output_dir:
        output_dir = osp.join(
            OUTPUT_DIR, 'detection_logs', args.output_dir)
        utils.mkdir(output_dir)
        output_plots_dir = osp.join(output_dir, "plots")
        utils.mkdir(output_plots_dir)
    logger.debug("COMMAND LINE ARGUMENTS")
    logger.debug(args)
    save_args(output_dir, args)

    device = torch.device(args.device)
    logger.debug(f"DEVICE: {device}")

    logger.debug("CREATE DATASETS")
    ds_train_name = args.train_dataset
    ds_val_name = args.val_dataset
    data_augmentation = args.data_augmentation
    dataset_train = create_dataset(
        ds_train_name, get_transform(True, data_augmentation), "train")
    dataset_test = create_dataset(
        ds_val_name, get_transform(False, data_augmentation), "test")

    logger.debug("CREATE DATA LOADERS")
    batch_size = args.batch_size
    workers = args.workers
    aspect_ratio_group_factor = args.aspect_ratio_group_factor
    data_loader_train = create_data_loader(
        dataset_train, "train", batch_size, workers, aspect_ratio_group_factor)
    data_loader_test = create_data_loader(
        dataset_test, "test", batch_size, workers)

    if args.test_only:
        logger.debug("TEST ONLY")
        model = fasterrcnn_resnet50_fpn_v2()
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
        checkpoint = torch.load(args.model_eval, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
        model.to(device)
        coco_evaluator = evaluate(model, data_loader_test,
                                  device=device, iou_types=['bbox'])
        save_evaluate_summary(
            coco_evaluator.coco_eval['bbox'].stats, output_dir)
        return

    logger.debug("CREATE MODEL")
    model_name = args.model
    weights = args.weights
    backbone = args.backbone
    backbone_weights = args.backbone_weights
    trainable_backbone_layers = args.trainable_backbone_layers
    model = ModelFactory.get_model(
        model_name, weights, backbone, backbone_weights, trainable_backbone_layers)
    save_model_summary(model, output_dir, batch_size)

    logger.debug("CREATE OPTIMIZER")
    lr = args.lr
    momentum = args.momentum
    weight_decay = args.weight_decay
    optimizer = create_optimizer(
        model, lr, momentum, weight_decay)

    logger.debug("CREATE LR SCHEDULER")
    epochs = args.epochs
    lr_scheduler_type = args.lr_scheduler.lower()
    lr_steps = args.lr_steps
    lr_gamma = args.lr_gamma
    lr_scheduler = create_lr_scheduler(
        optimizer, lr_scheduler_type, lr_steps, lr_gamma, epochs)

    logger.debug("CONFIGURE SCALER FOR AMP")
    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    if args.resume:
        logger.debug("RESUME TRAINING")
        resume_training(model, optimizer, lr_scheduler,
                        scaler, args)

    logger.debug("START TRAINING")
    print_freq = args.print_freq
    start_epoch = args.start_epoch
    losses_dict = {}
    start_time = time.time()
    for epoch in range(start_epoch, epochs):
        _, batch_loss_dict = train_one_epoch(model, optimizer, data_loader_train, device,
                                             epoch, print_freq, scaler)
        lr_scheduler.step()
        save_plots(losses_dict, batch_loss_dict,
                   output_dir=output_plots_dir)
        coco_evaluator = evaluate(model, data_loader_test,
                                  device=device, iou_types=['bbox'])
        save_evaluate_summary(
            coco_evaluator.coco_eval['bbox'].stats, output_dir)
        save_model_checkpoint(
            model, optimizer, lr_scheduler, epoch, scaler, output_dir, args)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.debug(f"TRAINING TIME: {total_time_str}")


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)
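
# Minimal sketch of reusing a saved checkpoint outside this script (assumes a
# file written by save_model_checkpoint; the path is illustrative):
#   model = fasterrcnn_resnet50_fpn_v2()
#   in_features = model.roi_heads.box_predictor.cls_score.in_features
#   model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
#   ckpt = torch.load("checkpoint.pth", map_location="cpu")
#   model.load_state_dict(ckpt["model"])
#   model.eval()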