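"""Evaluate a semantic segmentation checkpoint on a held-out split.

Loads a fully serialized model from --model-path, reports its FLOPs, and
computes a confusion matrix plus the average cross-entropy loss over the
validation set. Built on top of torchvision's segmentation reference scripts
(presets/utils are imported from that checkout).
"""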

import sys
from argparse import Namespace
from functools import partial
from typing import Callable
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import Union

# presets and utils live in the torchvision reference-scripts checkout.
sys.path.append("vision/references/segmentation")

import presets
import torch
import torch.utils.data
import torchvision
import utils
from torch import nn

from common import flops_calculation_function
from common import NanSafeConfusionMatrix as ConfusionMatrix
from common import get_coco


def get_dataset(args: Namespace, is_train: bool, transform: Optional[Callable] = None) -> Tuple[torch.utils.data.Dataset, int]:
    def sbd(*args, **kwargs):
        # SBDataset does not accept the use_v2 flag, so drop it before forwarding.
        kwargs.pop("use_v2")
        return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs)

    def voc(*args, **kwargs):
        # VOCSegmentation does not accept use_v2 either.
        kwargs.pop("use_v2")
        return torchvision.datasets.VOCSegmentation(*args, **kwargs)

    # dataset name -> (root path, builder, number of classes)
    paths = {
        "voc": (args.data_path, voc, 21),
        "voc_aug": (args.data_path, sbd, 21),
        "coco": (args.data_path, get_coco, 21),
        "coco_orig": (args.data_path, partial(get_coco, use_orig=True), 81),
    }
    # Select by --dataset; the original hardcoded paths["coco_orig"], which left
    # the dict above dead code and contradicted the v2 check in main().
    p, ds_fn, num_classes = paths[args.dataset]

    if transform is None:
        transform = get_transform(is_train, args)
    image_set = "train" if is_train else "val"
    ds = ds_fn(p, image_set=image_set, transforms=transform, use_v2=args.use_v2)
    return ds, num_classes
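
# Note on the class counts above: "coco" remaps annotations onto the 21 VOC
# classes (the torchvision reference behavior), while "coco_orig" with
# use_orig=True presumably keeps the original COCO categories
# (80 + background = 81).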


def get_transform(is_train: bool, args: Namespace) -> Callable:
    # This is an evaluation-only script: the eval preset is returned even when
    # is_train is True.
    return presets.SegmentationPresetEval(base_size=520, backend=args.backend, use_v2=args.use_v2)


def criterion(inputs: Dict[str, torch.Tensor], target: torch.Tensor) -> torch.Tensor:
    losses = {}
    for name, x in inputs.items():
        # Pixels labeled 255 are "don't care" and excluded from the loss.
        losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)

    if len(losses) == 1:
        return losses["out"]

    # Down-weight the auxiliary head, as in the torchvision reference recipe.
    return losses["out"] + 0.5 * losses["aux"]


def evaluate(
    model: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,
    device: Union[str, torch.device],
    num_classes: int,
    criterion: Callable,
) -> Tuple[ConfusionMatrix, float]:
    model.eval()
    confmat = ConfusionMatrix(num_classes)
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = "Test:"
    num_processed_samples = 0
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, 100, header):
            image, target = image.to(device), target.to(device)
            output = model(image)
            loss = criterion(output, target)
            output = output["out"]
            confmat.update(target.flatten(), output.argmax(1).flatten())
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            num_processed_samples += image.shape[0]
            metric_logger.update(loss=loss.item())

    confmat.reduce_from_all_processes()
    return confmat, metric_logger.loss.global_avg
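
# Note: reduce_from_all_processes is what keeps the metric correct under
# distributed evaluation; it aggregates the per-rank confusion matrices before
# any scores are derived. The returned loss is the global running average
# tracked by the MetricLogger.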


def main(args):
    if args.backend.lower() != "pil" and not args.use_v2:
        # TODO: Support tensor backend in V1?
        raise ValueError("Use --use-v2 if you want to use the tv_tensor or tensor backend.")
    if args.use_v2 and args.dataset != "coco":
        # The original raised unconditionally on --use-v2, which, combined with
        # the check above, made non-PIL backends impossible to use at all.
        raise ValueError("v2 is only supported for coco dataset for now.")

    print(args)
    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    dataset_test, num_classes = get_dataset(args, is_train=False)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
    )

    checkpoint = torch.load(args.model_path)
    model = checkpoint["model"]
    model.to(device)

    model_flops = flops_calculation_function(model=model, input_sample=next(iter(data_loader_test))[0].to(device))
    print(f"Model Flops: {model_flops}M")

    # We disable cudnn benchmarking because it can noticeably affect the accuracy.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    confmat, loss = evaluate(
        model=model,
        data_loader=data_loader_test,
        device=device,
        num_classes=num_classes,
        criterion=criterion,
    )
    print(confmat)
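
# Note: the checkpoint is expected to hold the full serialized model under the
# "model" key (not just a state_dict), and flops_calculation_function (from the
# local `common` module) is assumed to report FLOPs in millions, hence the "M"
# suffix in the print above.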


def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Segmentation Evaluation", add_help=add_help)
    parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
    parser.add_argument(
        "--dataset", default="coco_orig", type=str, help="dataset name: voc, voc_aug, coco, or coco_orig"
    )
    parser.add_argument("--device", default="cuda", type=str, help="device to use (cuda or cpu; default: cuda)")
    # Note: --batch-size and --epochs are not used by this script; evaluation
    # always runs with batch_size=1.
    parser.add_argument(
        "-b", "--batch-size", default=8, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=30, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )
    parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
    parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
    parser.add_argument("--model-path", default=None, help="Path to model checkpoint.")
    return parser


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)
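
# Example invocation (script name and paths are illustrative):
#   python evaluate.py --dataset coco_orig --data-path /path/to/COCO \
#       --model-path checkpoint.pth --device cuda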