|
import sys |
|
from functools import partial |
|
|
|
from typing import Callable |
|
from typing import Dict |
|
from typing import Tuple |
|
from typing import Union |
|
from argparse import Namespace |
|
|
|
sys.path.append("vision/references/segmentation") |
|
|
|
import presets |
|
import torch |
|
import torch.utils.data |
|
import torchvision |
|
import utils |
|
from torch import nn |
|
from common import flops_calculation_function |
|
from common import NanSafeConfusionMatrix as ConfusionMatrix |
|
from common import get_coco |
|
|
|
|
|
def get_dataset(args: Namespace, is_train: bool, transform: Callable = None) -> Tuple[torch.utils.data.Dataset, int]:
    """Build the segmentation dataset selected by ``args``.

    Args:
        args: parsed CLI namespace; reads ``data_path``, ``use_v2`` and,
            when present, ``dataset`` (defaults to ``"coco_orig"`` for
            backward compatibility with the previous hard-coded choice).
        is_train: selects the "train" split when True, "val" otherwise.
        transform: optional transform pipeline; built via ``get_transform``
            when omitted.

    Returns:
        Tuple of ``(dataset, num_classes)``.

    Raises:
        ValueError: if ``args.dataset`` names an unknown dataset.
    """

    def sbd(*args, **kwargs):
        # SBDataset does not accept use_v2; drop it before forwarding.
        kwargs.pop("use_v2")
        return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs)

    def voc(*args, **kwargs):
        # VOCSegmentation does not accept use_v2 either.
        kwargs.pop("use_v2")
        return torchvision.datasets.VOCSegmentation(*args, **kwargs)

    paths = {
        "voc": (args.data_path, voc, 21),
        "voc_aug": (args.data_path, sbd, 21),
        "coco": (args.data_path, get_coco, 21),
        "coco_orig": (args.data_path, partial(get_coco, use_orig=True), 81),
    }
    # BUG FIX: the lookup key was hard-coded to "coco_orig", silently
    # ignoring every other entry in `paths`. Honor args.dataset when the
    # attribute exists, keeping "coco_orig" as the default so existing
    # callers see identical behavior.
    dataset_name = getattr(args, "dataset", "coco_orig")
    try:
        p, ds_fn, num_classes = paths[dataset_name]
    except KeyError:
        raise ValueError(f"Unknown dataset {dataset_name!r}; expected one of {sorted(paths)}") from None

    if transform is None:
        transform = get_transform(is_train, args)
    image_set = "train" if is_train else "val"
    ds = ds_fn(p, image_set=image_set, transforms=transform, use_v2=args.use_v2)
    return ds, num_classes
|
|
|
|
|
def get_transform(is_train: bool, args: Namespace) -> Callable:
    """Build the transform pipeline for this evaluation script.

    Note: ``is_train`` is accepted for interface compatibility, but the
    evaluation preset is returned regardless of its value.
    """
    preset = presets.SegmentationPresetEval(
        base_size=520,
        backend=args.backend,
        use_v2=args.use_v2,
    )
    return preset
|
|
|
|
|
def criterion(inputs: Dict[str, torch.Tensor], target: torch.Tensor) -> torch.Tensor:
    """Multi-head cross-entropy loss for segmentation outputs.

    Args:
        inputs: mapping of head name ("out", optionally "aux") to logits
            of shape (N, C, H, W).
        target: ground-truth class indices of shape (N, H, W); pixels
            labelled 255 are ignored.

    Returns:
        The "out" loss alone, or ``out + 0.5 * aux`` when an auxiliary
        head is present.
    """
    # NOTE: annotation fixed — `target` is passed directly to
    # cross_entropy, so it is a tensor of class indices, not a dict.
    losses = {
        name: nn.functional.cross_entropy(x, target, ignore_index=255)
        for name, x in inputs.items()
    }

    if len(losses) == 1:
        return losses["out"]

    return losses["out"] + 0.5 * losses["aux"]
|
|
|
|
|
def evaluate(
    model: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,
    device: Union[str, torch.device],
    num_classes: int,
    criterion: Callable,
) -> Tuple[ConfusionMatrix, float]:
    """Run one evaluation pass and accumulate a confusion matrix.

    Args:
        model: segmentation model whose forward returns a dict containing
            an "out" key of per-pixel logits.
        data_loader: yields (image, target) batches.
        device: device the batches are moved to.
        num_classes: number of classes for the confusion matrix.
        criterion: loss callable applied to the raw model output dict.

    Returns:
        Tuple of (confusion matrix reduced across processes, global
        average loss).
    """
    model.eval()
    confmat = ConfusionMatrix(num_classes)
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = "Test:"
    # Cleanup: dropped an unused enumerate index and a dead
    # `num_processed_samples` accumulator that was never read.
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, 100, header):
            image, target = image.to(device), target.to(device)
            output = model(image)
            loss = criterion(output, target)
            # Only the primary "out" head contributes to the metrics.
            output = output["out"]

            confmat.update(target.flatten(), output.argmax(1).flatten())

            metric_logger.update(loss=loss.item())

    confmat.reduce_from_all_processes()

    return confmat, metric_logger.loss.global_avg
|
|
|
|
|
def main(args):
    """Load a checkpointed segmentation model and evaluate it on the val split.

    Reads checkpoint path, device, backend and loader settings from
    ``args`` (see ``get_args_parser``); prints the parsed args, model
    FLOPs, the confusion matrix, and the average loss.
    """
    if args.backend.lower() != "pil" and not args.use_v2:
        raise ValueError("Use --use-v2 if you want to use the tv_tensor or tensor backend.")
    # BUG FIX: this guard previously raised for *every* --use-v2 run,
    # which together with the check above made the tensor/tv_tensor
    # backends unreachable. Reject v2 only for non-"coco" datasets, as
    # the message states. The dataset name defaults to "coco_orig" (this
    # script's only dataset), so existing invocations behave as before.
    if args.use_v2 and getattr(args, "dataset", "coco_orig") != "coco":
        raise ValueError("v2 is only supported for coco dataset for now.")

    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    dataset_test, num_classes = get_dataset(args, is_train=False)

    # Sequential sampling: evaluation order must be stable.
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
    )

    # SECURITY NOTE: torch.load unpickles arbitrary objects (the
    # checkpoint stores the whole model object) — only load checkpoints
    # from trusted sources. map_location="cpu" lets a checkpoint saved on
    # another device load here before the explicit .to(device).
    checkpoint = torch.load(args.model_path, map_location="cpu")
    model = checkpoint["model"]
    model.to(device)
    model_flops = flops_calculation_function(model=model, input_sample=next(iter(data_loader_test))[0].to(device))
    print(f"Model Flops: {model_flops}M")

    # Force deterministic cuDNN behaviour for the evaluation pass
    # (intentionally overrides the benchmark setting chosen above).
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    confmat, loss = evaluate(
        model=model,
        data_loader=data_loader_test,
        device=device,
        num_classes=num_classes,
        criterion=criterion,
    )
    print(confmat)
    # Previously the average loss was computed but silently discarded.
    print(f"Average loss: {loss}")
|
|
|
def get_args_parser(add_help=True):
    """Construct the argument parser for this evaluation script."""
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Segmentation Training", add_help=add_help)

    # Dataset location and hardware selection.
    parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument("-b", "--batch-size", default=8, type=int, help="images per gpu, the total batch size is $NGPU x batch_size")
    parser.add_argument("--epochs", default=30, type=int, metavar="N", help="number of total epochs to run")

    # Data loading and determinism.
    parser.add_argument("-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)")
    parser.add_argument("--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only.")

    # Transform backend and checkpoint location. Note type=str.lower
    # normalizes the backend flag to lowercase at parse time.
    parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
    parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
    parser.add_argument("--model-path", default=None, help="Path to model checkpoint.")
    return parser
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the CLI and run the evaluation.
    main(get_args_parser().parse_args())
|
|