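"""Training and evaluation helpers for torchvision object-detection models.

``train_one_epoch`` runs a single training epoch (with optional AMP via a
``GradScaler`` and a linear learning-rate warm-up on epoch 0) and returns the
metric logger together with a history of the running averages of the logged
losses. ``evaluate`` runs COCO-style evaluation (bbox, plus segm/keypoints
when the model supports them) and returns the populated ``CocoEvaluator``.
"""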
import math
import sys
import time
import torch
import torchvision.models.detection.faster_rcnn
from . import utils
from . import coco_eval
from . import coco_utils

def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
    """Train the model for one epoch and return (metric_logger, losses_dict)."""
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = f"Epoch: [{epoch}]"

    # Linear warm-up of the learning rate during the first epoch.
    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1.0 / 1000
        warmup_iters = min(1000, len(data_loader) - 1)
        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
        )

    # History of the running averages of each logged quantity.
    losses_dict = {
        "lr": [],
        "loss": [],
        # RPN losses
        "loss_objectness": [],
        "loss_rpn_box_reg": [],
        # RoI head losses
        "loss_classifier": [],
        "loss_box_reg": [],
    }

    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        try:
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            with torch.cuda.amp.autocast(enabled=scaler is not None):
                loss_dict = model(images, targets)
                losses = sum(loss for loss in loss_dict.values())

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = utils.reduce_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            loss_value = losses_reduced.item()

            # stop training if the loss has diverged (NaN/inf)
            if not math.isfinite(loss_value):
                print(f"Loss is {loss_value}, stopping training")
                print(loss_dict_reduced)
                sys.exit(1)
        except Exception as exp:
            # Save the offending batch to disk so it can be inspected later.
            print("ERROR", str(exp))
            torch.save({"img": images, "targets": targets}, "error_causing_batch.pth")
            raise RuntimeError("failed on batch saved to error_causing_batch.pth") from exp

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()
            optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        # Record the running averages so the caller can plot loss curves.
        for name, meter in metric_logger.meters.items():
            losses_dict[name].append(meter.global_avg)

    return metric_logger, losses_dict

def _get_iou_types(model):
    model_without_ddp = model
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model_without_ddp = model.module
    iou_types = ["bbox"]
    if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN):
        iou_types.append("segm")
    if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN):
        iou_types.append("keypoints")
    return iou_types

@torch.inference_mode()
def evaluate(model, data_loader, device, iou_types=None):
    """Run COCO-style evaluation on ``data_loader`` and return the CocoEvaluator."""
    n_threads = torch.get_num_threads()
    # FIXME remove this and make paste_masks_in_image run on the GPU
    torch.set_num_threads(1)
    cpu_device = torch.device("cpu")
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = "Test:"

    coco = coco_utils.get_coco_api_from_dataset(data_loader.dataset)
    if iou_types is None:
        iou_types = _get_iou_types(model)
    coco_evaluator = coco_eval.CocoEvaluator(coco, iou_types)

    for images, targets in metric_logger.log_every(data_loader, 100, header):
        images = list(img.to(device) for img in images)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        model_time = time.time()
        outputs = model(images)
        outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
        model_time = time.time() - model_time

        res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
        evaluator_time = time.time()
        coco_evaluator.update(res)
        evaluator_time = time.time() - evaluator_time
        metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    coco_evaluator.synchronize_between_processes()

    # accumulate predictions from all images and print the results table
    coco_evaluator.accumulate()
    coco_evaluator.summarize()
    torch.set_num_threads(n_threads)
    return coco_evaluator
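

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). `get_detection_model` and
# `build_data_loaders` are hypothetical placeholders for project-specific
# code; the optimizer/scheduler settings below are common defaults, not part
# of this module:
#
#   model = get_detection_model(num_classes).to(device)
#   params = [p for p in model.parameters() if p.requires_grad]
#   optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=5e-4)
#   lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
#   train_loader, val_loader = build_data_loaders()
#   for epoch in range(num_epochs):
#       _, loss_history = train_one_epoch(
#           model, optimizer, train_loader, device, epoch, print_freq=10
#       )
#       lr_scheduler.step()
#       evaluate(model, val_loader, device=device)
# ---------------------------------------------------------------------------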