Spaces:
Build error
Build error
import math | |
import sys | |
import time | |
import torch | |
import torchvision.models.detection.faster_rcnn | |
from . import utils | |
from . import coco_eval | |
from . import coco_utils | |
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None): | |
model.train() | |
metric_logger = utils.MetricLogger(delimiter=" ") | |
metric_logger.add_meter("lr", utils.SmoothedValue( | |
window_size=1, fmt="{value:.6f}")) | |
header = f"Epoch: [{epoch}]" | |
lr_scheduler = None | |
if epoch == 0: | |
warmup_factor = 1.0 / 1000 | |
warmup_iters = min(1000, len(data_loader) - 1) | |
lr_scheduler = torch.optim.lr_scheduler.LinearLR( | |
optimizer, start_factor=warmup_factor, total_iters=warmup_iters | |
) | |
losses_dict = { | |
"lr": [], | |
"loss": [], | |
# loss rpn | |
"loss_objectness": [], | |
"loss_rpn_box_reg": [], | |
# roi heads | |
"loss_classifier": [], | |
"loss_box_reg": [], | |
} | |
for images, targets in metric_logger.log_every(data_loader, print_freq, header): | |
try: | |
images = list(image.to(device) for image in images) | |
targets = [{k: v.to(device) for k, v in t.items()} | |
for t in targets] | |
with torch.cuda.amp.autocast(enabled=scaler is not None): | |
loss_dict = model(images, targets) | |
losses = sum(loss for loss in loss_dict.values()) | |
# reduce losses over all GPUs for logging purposes | |
loss_dict_reduced = utils.reduce_dict(loss_dict) | |
losses_reduced = sum(loss for loss in loss_dict_reduced.values()) | |
loss_value = losses_reduced.item() | |
# if problem with loss see below | |
if not math.isfinite(loss_value): | |
print(f"Loss is {loss_value}, stopping training") | |
print(loss_dict_reduced) | |
sys.exit(1) | |
except Exception as exp: | |
print("ERROR", str(exp)) | |
torch.save({'img': images, 'targets': targets}, | |
'error_causing_batch.pth') | |
raise RuntimeError | |
optimizer.zero_grad() | |
if scaler is not None: | |
scaler.scale(losses).backward() | |
scaler.step(optimizer) | |
scaler.update() | |
else: | |
losses.backward() | |
optimizer.step() | |
if lr_scheduler is not None: | |
lr_scheduler.step() | |
metric_logger.update(loss=losses_reduced, **loss_dict_reduced) | |
metric_logger.update(lr=optimizer.param_groups[0]["lr"]) | |
for name, meter in metric_logger.meters.items(): | |
losses_dict[name].append(meter.global_avg) | |
return metric_logger, losses_dict | |
def _get_iou_types(model): | |
model_without_ddp = model | |
if isinstance(model, torch.nn.parallel.DistributedDataParallel): | |
model_without_ddp = model.module | |
iou_types = ["bbox"] | |
if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN): | |
iou_types.append("segm") | |
if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN): | |
iou_types.append("keypoints") | |
return iou_types | |
def evaluate(model, data_loader, device, iou_types=None): | |
n_threads = torch.get_num_threads() | |
# FIXME remove this and make paste_masks_in_image run on the GPU | |
torch.set_num_threads(1) | |
cpu_device = torch.device("cpu") | |
model.eval() | |
metric_logger = utils.MetricLogger(delimiter=" ") | |
header = "Test:" | |
coco = coco_utils.get_coco_api_from_dataset(data_loader.dataset) | |
if iou_types is None: | |
iou_types = _get_iou_types(model) | |
coco_evaluator = coco_eval.CocoEvaluator(coco, iou_types) | |
for images, targets in metric_logger.log_every(data_loader, 100, header): | |
images = list(img.to(device) for img in images) | |
if torch.cuda.is_available(): | |
torch.cuda.synchronize() | |
model_time = time.time() | |
outputs = model(images) | |
outputs = [{k: v.to(cpu_device) for k, v in t.items()} | |
for t in outputs] | |
model_time = time.time() - model_time | |
res = {target["image_id"].item(): output for target, | |
output in zip(targets, outputs)} | |
evaluator_time = time.time() | |
coco_evaluator.update(res) | |
evaluator_time = time.time() - evaluator_time | |
metric_logger.update(model_time=model_time, | |
evaluator_time=evaluator_time) | |
# gather the stats from all processes | |
metric_logger.synchronize_between_processes() | |
print("Averaged stats:", metric_logger) | |
coco_evaluator.synchronize_between_processes() | |
# accumulate predictions from all images and print table with results | |
coco_evaluator.accumulate() | |
coco_evaluator.summarize() | |
torch.set_num_threads(n_threads) | |
return coco_evaluator | |