File size: 26,491 Bytes
938e515 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 |
# Copyright (c) Facebook, Inc. and its affiliates.
import itertools
import logging
import numpy as np
import operator
import pickle
from collections import OrderedDict, defaultdict
from typing import Any, Callable, Dict, List, Optional, Union
import torch
import torch.utils.data as torchdata
from tabulate import tabulate
from termcolor import colored
from detectron2.config import configurable
from detectron2.structures import BoxMode
from detectron2.utils.comm import get_world_size
from detectron2.utils.env import seed_all_rng
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import _log_api_usage, log_first_n
from .catalog import DatasetCatalog, MetadataCatalog
from .common import AspectRatioGroupedDataset, DatasetFromList, MapDataset, ToIterableDataset
from .dataset_mapper import DatasetMapper
from .detection_utils import check_metadata_consistency
from .samplers import (
InferenceSampler,
RandomSubsetTrainingSampler,
RepeatFactorTrainingSampler,
TrainingSampler,
)
"""
Default logic for building dataloaders for training and for testing.
"""
# Public API exported by ``from ... import *``.
__all__ = [
    "build_batch_data_loader",
    "build_detection_train_loader",
    "build_detection_test_loader",
    "get_detection_dataset_dicts",
    "load_proposals_into_dataset",
    "print_instances_class_histogram",
]
def filter_images_with_only_crowd_annotations(dataset_dicts):
    """
    Drop images that have no annotations, or whose annotations are all crowd
    regions (``iscrowd`` set to a non-zero value).

    A common training-time preprocessing step on the COCO dataset.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.

    Returns:
        list[dict]: the same format, but filtered.
    """
    count_before = len(dataset_dicts)

    def _has_non_crowd(annotations):
        # An annotation without an "iscrowd" key counts as non-crowd.
        return any(anno.get("iscrowd", 0) == 0 for anno in annotations)

    kept = [record for record in dataset_dicts if _has_non_crowd(record["annotations"])]
    count_after = len(kept)
    logging.getLogger(__name__).info(
        "Removed {} images with no usable annotations. {} images left.".format(
            count_before - count_after, count_after
        )
    )
    return kept
def filter_images_with_few_keypoints(dataset_dicts, min_keypoints_per_image):
    """
    Keep only images whose total number of visible keypoints is at least
    ``min_keypoints_per_image``.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
        min_keypoints_per_image (int): minimum number of keypoints with a
            positive visibility flag, summed over all annotations of an image.

    Returns:
        list[dict]: the same format as dataset_dicts, but filtered.
    """
    count_before = len(dataset_dicts)

    def _count_visible(record):
        # Each "keypoints" entry is a flat [x1, y1, v1, x2, y2, v2, ...] list;
        # every third value is the visibility flag, and > 0 means visible.
        total = 0
        for anno in record["annotations"]:
            if "keypoints" not in anno:
                continue
            total += (np.array(anno["keypoints"][2::3]) > 0).sum()
        return total

    kept = [
        record
        for record in dataset_dicts
        if _count_visible(record) >= min_keypoints_per_image
    ]
    count_after = len(kept)
    logging.getLogger(__name__).info(
        "Removed {} images with fewer than {} keypoints.".format(
            count_before - count_after, min_keypoints_per_image
        )
    )
    return kept
def load_proposals_into_dataset(dataset_dicts, proposal_file):
    """
    Load precomputed object proposals into the dataset.

    The proposal file should be a pickled dict with the following keys:

    - "ids": list[int] or list[str], the image ids
    - "boxes": list[np.ndarray], each is an Nx4 array of boxes corresponding to the image id
    - "objectness_logits": list[np.ndarray], each is an N sized array of objectness scores
      corresponding to the boxes.
    - "bbox_mode": the BoxMode of the boxes array. Defaults to ``BoxMode.XYXY_ABS``.

    Detectron1-style keys ("indexes", "scores") are also accepted and are
    renamed to "ids" and "objectness_logits".

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
        proposal_file (str): file path of pre-computed proposals, in pkl format.

    Returns:
        list[dict]: the same list, whose records are updated in place with the
        "proposal_boxes", "proposal_objectness_logits" and "proposal_bbox_mode"
        fields. Proposals of each image are sorted by descending objectness.

    Raises:
        KeyError: if an image in ``dataset_dicts`` has no proposals in the file.
    """
    logger = logging.getLogger(__name__)
    logger.info("Loading proposals from: {}".format(proposal_file))

    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only load proposal files that come from a trusted source.
    with PathManager.open(proposal_file, "rb") as f:
        proposals = pickle.load(f, encoding="latin1")

    # Rename the key names in D1 proposal files
    rename_keys = {"indexes": "ids", "scores": "objectness_logits"}
    for old_key, new_key in rename_keys.items():
        if old_key in proposals:
            proposals[new_key] = proposals.pop(old_key)

    # Fetch the indexes of all proposals that are in the dataset.
    # Convert image_id to str since they could be int.
    img_ids = {str(record["image_id"]) for record in dataset_dicts}
    id_to_index = {
        str(img_id): i for i, img_id in enumerate(proposals["ids"]) if str(img_id) in img_ids
    }

    # Assuming default bbox_mode of precomputed proposals are 'XYXY_ABS'
    bbox_mode = BoxMode(proposals["bbox_mode"]) if "bbox_mode" in proposals else BoxMode.XYXY_ABS

    for record in dataset_dicts:
        image_id = str(record["image_id"])
        try:
            i = id_to_index[image_id]
        except KeyError:
            raise KeyError(
                "Image id {} has no proposals in {}".format(image_id, proposal_file)
            ) from None
        boxes = proposals["boxes"][i]
        objectness_logits = proposals["objectness_logits"][i]
        # Sort the proposals in descending order of the scores
        inds = objectness_logits.argsort()[::-1]
        record["proposal_boxes"] = boxes[inds]
        record["proposal_objectness_logits"] = objectness_logits[inds]
        record["proposal_bbox_mode"] = bbox_mode

    return dataset_dicts
def print_instances_class_histogram(dataset_dicts, class_names):
    """
    Log a table of the number of non-crowd instances per category.

    Args:
        dataset_dicts (list[dict]): list of dataset dicts.
        class_names (list[str]): list of class names (zero-indexed).

    Raises:
        AssertionError: if a non-crowd annotation has a ``category_id`` outside
            ``[0, len(class_names))``.
    """
    num_classes = len(class_names)
    hist_bins = np.arange(num_classes + 1)
    histogram = np.zeros((num_classes,), dtype=int)
    for entry in dataset_dicts:
        annos = entry["annotations"]
        # Crowd annotations are excluded from the counts.
        classes = np.asarray(
            [x["category_id"] for x in annos if not x.get("iscrowd", 0)], dtype=int
        )
        if len(classes):
            assert classes.min() >= 0, f"Got an invalid category_id={classes.min()}"
            assert (
                classes.max() < num_classes
            ), f"Got an invalid category_id={classes.max()} for a dataset of {num_classes} classes"
            histogram += np.histogram(classes, bins=hist_bins)[0]

    if num_classes == 0:
        # No categories to tabulate; the column count below would be 0 and the
        # row-padding modulo would divide by zero.
        return

    # Each table row holds up to 3 (name, count) pairs.
    N_COLS = min(6, num_classes * 2)

    def short_name(x):
        # make long class names shorter. useful for lvis
        if len(x) > 13:
            return x[:11] + ".."
        return x

    data = list(
        itertools.chain(*[[short_name(class_names[i]), int(v)] for i, v in enumerate(histogram)])
    )
    total_num_instances = sum(data[1::2])
    # Pad the last row to N_COLS cells only if it is partial:
    # ``-len(data) % N_COLS`` is 0 when the last row is already full, so no
    # empty row gets inserted before the "total" entry.
    data.extend([None] * (-len(data) % N_COLS))
    if num_classes > 1:
        data.extend(["total", total_num_instances])
    data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)])
    table = tabulate(
        data,
        headers=["category", "#instances"] * (N_COLS // 2),
        tablefmt="pipe",
        numalign="left",
        stralign="center",
    )
    log_first_n(
        logging.INFO,
        "Distribution of instances among all {} categories:\n".format(num_classes)
        + colored(table, "cyan"),
        key="message",
    )
def get_detection_dataset_dicts(
    names,
    filter_empty=True,
    min_keypoints=0,
    proposal_files=None,
    check_consistency=True,
):
    """
    Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.

    Args:
        names (str or list[str]): a dataset name or a list of dataset names
        filter_empty (bool): whether to filter out images without instance annotations
            (images with no annotations, or with only crowd annotations)
        min_keypoints (int): filter out images with fewer keypoints than
            `min_keypoints`. Set to 0 to do nothing.
        proposal_files (list[str]): if given, a list of object proposal files
            that match each dataset in `names`.
        check_consistency (bool): whether to check if datasets have consistent metadata.

    Returns:
        list[dict]: a list of dicts following the standard dataset dict format.
        If the registered datasets are ``torch.utils.data.Dataset`` objects, they
        are returned as-is (wrapped in a ``ConcatDataset`` when there is more
        than one) and none of the filtering/proposal options are applied.
    """
    if isinstance(names, str):
        names = [names]
    assert len(names), names

    available_datasets = DatasetCatalog.keys()
    names_set = set(names)
    if not names_set.issubset(available_datasets):
        logger = logging.getLogger(__name__)
        logger.warning(
            "The following dataset names are not registered in the DatasetCatalog: "
            f"{names_set - available_datasets}. "
            f"Available datasets are {available_datasets}"
        )

    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in names]

    if isinstance(dataset_dicts[0], torchdata.Dataset):
        if len(dataset_dicts) > 1:
            # ConcatDataset does not work for iterable style dataset.
            # We could support concat for iterable as well, but it's often
            # not a good idea to concat iterables anyway.
            return torchdata.ConcatDataset(dataset_dicts)
        return dataset_dicts[0]

    for dataset_name, dicts in zip(names, dataset_dicts):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)

    if proposal_files is not None:
        assert len(names) == len(proposal_files)
        # load precomputed proposals from proposal files
        dataset_dicts = [
            load_proposals_into_dataset(dataset_i_dicts, proposal_file)
            for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
        ]

    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))

    has_instances = "annotations" in dataset_dicts[0]
    if filter_empty and has_instances:
        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
    if min_keypoints > 0 and has_instances:
        dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)

    if check_consistency and has_instances:
        # Only the metadata lookups are guarded: a missing "thing_classes" is an
        # expected condition, while an AttributeError from the histogram code
        # would be a real bug and must not be swallowed.
        try:
            class_names = MetadataCatalog.get(names[0]).thing_classes
            check_metadata_consistency("thing_classes", names)
        except AttributeError:  # class names are not available for this dataset
            pass
        else:
            print_instances_class_histogram(dataset_dicts, class_names)

    assert len(dataset_dicts), "No valid data found in {}.".format(",".join(names))
    return dataset_dicts
def build_batch_data_loader(
    dataset,
    sampler,
    total_batch_size,
    *,
    aspect_ratio_grouping=False,
    num_workers=0,
    collate_fn=None,
    drop_last: bool = True,
    single_gpu_batch_size=None,
    seed=None,
    **kwargs,
):
    """
    Build a batched dataloader. The main differences from `torch.utils.data.DataLoader` are:

    1. support aspect ratio grouping options
    2. use no "batch collation", because this is common for detection training

    Args:
        dataset (torch.utils.data.Dataset): a pytorch map-style or iterable dataset.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces indices.
            Must be provided iff. ``dataset`` is a map-style dataset.
        total_batch_size, aspect_ratio_grouping, num_workers, collate_fn: see
            :func:`build_detection_train_loader`.
        single_gpu_batch_size: You can specify either `single_gpu_batch_size` or `total_batch_size`.
            `single_gpu_batch_size` specifies the batch size that will be used for each gpu/process.
            `total_batch_size` allows you to specify the total aggregate batch size across gpus.
            It is an error to supply a value for both.
        drop_last (bool): if ``True``, the dataloader will drop incomplete batches.
            Must be ``True`` when ``aspect_ratio_grouping`` is enabled.
        seed (int or None): if given, seeds the ``torch.Generator`` passed to the
            ``DataLoader``.
        **kwargs: forwarded to ``torch.utils.data.DataLoader``.

    Returns:
        iterable[list]. Length of each list is the batch size of the current
        GPU. Each element in the list comes from the dataset.

    Raises:
        ValueError: if both ``total_batch_size`` and ``single_gpu_batch_size`` are given.
    """
    if single_gpu_batch_size:
        if total_batch_size:
            raise ValueError(
                "total_batch_size and single_gpu_batch_size are mutually incompatible. "
                "Please specify only one."
            )
        batch_size = single_gpu_batch_size
    else:
        world_size = get_world_size()
        assert (
            total_batch_size > 0 and total_batch_size % world_size == 0
        ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
            total_batch_size, world_size
        )
        batch_size = total_batch_size // world_size

    logger = logging.getLogger(__name__)
    logger.info("Making batched data loader with batch_size=%d", batch_size)

    if isinstance(dataset, torchdata.IterableDataset):
        assert sampler is None, "sampler must be None if dataset is IterableDataset"
    else:
        # Wrap (map-style dataset, sampler) as an iterable dataset so both kinds
        # of input go through the same DataLoader construction below.
        dataset = ToIterableDataset(dataset, sampler, shard_chunk_size=batch_size)

    generator = None
    if seed is not None:
        generator = torch.Generator()
        generator.manual_seed(seed)

    if aspect_ratio_grouping:
        assert drop_last, "Aspect ratio grouping will drop incomplete batches."
        # DataLoader's default batch_size is 1; itemgetter(0) unwraps that
        # single-element batch so the loader yields individual mapped dicts.
        data_loader = torchdata.DataLoader(
            dataset,
            num_workers=num_workers,
            collate_fn=operator.itemgetter(0),  # don't batch, but yield individual elements
            worker_init_fn=worker_init_reset_seed,
            generator=generator,
            **kwargs,
        )
        # Batches of ``batch_size`` are formed here instead, grouping samples
        # with similar aspect ratio.
        data_loader = AspectRatioGroupedDataset(data_loader, batch_size)
        if collate_fn is None:
            return data_loader
        return MapDataset(data_loader, collate_fn)

    return torchdata.DataLoader(
        dataset,
        batch_size=batch_size,
        drop_last=drop_last,
        num_workers=num_workers,
        collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
        worker_init_fn=worker_init_reset_seed,
        generator=generator,
        **kwargs,
    )
def _get_train_datasets_repeat_factors(cfg) -> Dict[str, float]:
repeat_factors = cfg.DATASETS.TRAIN_REPEAT_FACTOR
assert all(len(tup) == 2 for tup in repeat_factors)
name_to_weight = defaultdict(lambda: 1, dict(repeat_factors))
# The sampling weights map should only contain datasets in train config
unrecognized = set(name_to_weight.keys()) - set(cfg.DATASETS.TRAIN)
assert not unrecognized, f"unrecognized datasets: {unrecognized}"
logger = logging.getLogger(__name__)
logger.info(f"Found repeat factors: {list(name_to_weight.items())}")
# pyre-fixme[7]: Expected `Dict[str, float]` but got `DefaultDict[typing.Any, int]`.
return name_to_weight
def _build_weighted_sampler(cfg, enable_category_balance=False):
    """
    Build a :class:`RepeatFactorTrainingSampler` whose per-image repeat factors
    come from the per-dataset factors in ``cfg.DATASETS.TRAIN_REPEAT_FACTOR``
    and, optionally, from category frequency.

    Per-image factors are laid out in ``cfg.DATASETS.TRAIN`` order, the same
    order in which ``get_detection_dataset_dicts`` concatenates the training
    datasets. Each dataset is loaded with the same filtering options that
    ``_train_loader_from_config`` uses.

    Args:
        cfg: config node.
        enable_category_balance (bool): if True, multiply the dataset repeat
            factors element-wise by category-frequency repeat factors
            (``RepeatFactorTrainingSampler.repeat_factors_from_category_frequency``
            with ``cfg.DATALOADER.REPEAT_THRESHOLD``), then normalize so the
            smallest factor is 1.

    Returns:
        RepeatFactorTrainingSampler
    """
    dataset_repeat_factors = _get_train_datasets_repeat_factors(cfg)
    train_names = list(cfg.DATASETS.TRAIN)
    min_keypoints = (
        cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE if cfg.MODEL.KEYPOINT_ON else 0
    )
    proposal_files = cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None

    # Each dataset is loaded on its own, so it must receive only its own
    # proposal file: get_detection_dataset_dicts requires exactly one proposal
    # file per dataset name, matched positionally with the names.
    # OrderedDict to guarantee order of values() consistent with repeat factors.
    dataset_name_to_dicts = OrderedDict()
    for i, name in enumerate(train_names):
        dataset_name_to_dicts[name] = get_detection_dataset_dicts(
            [name],
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            min_keypoints=min_keypoints,
            proposal_files=[proposal_files[i]] if proposal_files is not None else None,
        )

    # Repeat factor for every sample in the dataset, in DATASETS.TRAIN order.
    # An explicit float dtype keeps the tensor floating-point even when all
    # configured factors are integers.
    repeat_factors = torch.tensor(
        list(
            itertools.chain.from_iterable(
                [dataset_repeat_factors[name]] * len(dataset_name_to_dicts[name])
                for name in train_names
            )
        ),
        dtype=torch.float32,
    )

    logger = logging.getLogger(__name__)
    if enable_category_balance:
        # 1. Calculate repeat factors using category frequency for each dataset
        #    and concatenate them (iterating train_names keeps the length equal
        #    to repeat_factors even if a name is listed twice).
        # 2. Multiply element-wise with the dataset frequency repeat factors to
        #    get the final repeat factors, normalized so the minimum is 1.
        category_repeat_factors = torch.cat(
            [
                torch.as_tensor(
                    RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
                        dataset_name_to_dicts[name], cfg.DATALOADER.REPEAT_THRESHOLD
                    ),
                    dtype=torch.float32,
                )
                for name in train_names
            ]
        )
        repeat_factors = category_repeat_factors * repeat_factors
        repeat_factors = repeat_factors / repeat_factors.min()
        logger.info(
            "Using WeightedCategoryTrainingSampler with repeat_factors={}".format(
                cfg.DATASETS.TRAIN_REPEAT_FACTOR
            )
        )
    else:
        logger.info(
            "Using WeightedTrainingSampler with repeat_factors={}".format(
                cfg.DATASETS.TRAIN_REPEAT_FACTOR
            )
        )

    return RepeatFactorTrainingSampler(repeat_factors)
def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
    """
    Translate a config node into keyword arguments for
    :func:`build_detection_train_loader`.

    Any of ``mapper``, ``dataset`` and ``sampler`` passed explicitly are used
    as given; the missing ones are built from ``cfg``. The sampler is chosen by
    ``cfg.DATALOADER.SAMPLER_TRAIN`` and is left as ``None`` for iterable datasets.

    Raises:
        ValueError: if ``cfg.DATALOADER.SAMPLER_TRAIN`` names an unknown sampler.
    """
    if dataset is None:
        keypoints_needed = (
            cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE if cfg.MODEL.KEYPOINT_ON else 0
        )
        proposals = cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None
        dataset = get_detection_dataset_dicts(
            cfg.DATASETS.TRAIN,
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            min_keypoints=keypoints_needed,
            proposal_files=proposals,
        )
        _log_api_usage("dataset." + cfg.DATASETS.TRAIN[0])

    if mapper is None:
        mapper = DatasetMapper(cfg, True)

    if sampler is None:
        chosen = cfg.DATALOADER.SAMPLER_TRAIN
        log = logging.getLogger(__name__)
        if isinstance(dataset, torchdata.IterableDataset):
            # Iterable datasets produce their own order; sampler stays None.
            log.info("Not using any sampler since the dataset is IterableDataset.")
        else:
            log.info("Using training sampler {}".format(chosen))
            # Factories are lazy so only the selected sampler is constructed.
            factories = {
                "TrainingSampler": lambda: TrainingSampler(len(dataset)),
                "RepeatFactorTrainingSampler": lambda: RepeatFactorTrainingSampler(
                    RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
                        dataset, cfg.DATALOADER.REPEAT_THRESHOLD
                    )
                ),
                "RandomSubsetTrainingSampler": lambda: RandomSubsetTrainingSampler(
                    len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO
                ),
                "WeightedTrainingSampler": lambda: _build_weighted_sampler(cfg),
                "WeightedCategoryTrainingSampler": lambda: _build_weighted_sampler(
                    cfg, enable_category_balance=True
                ),
            }
            make_sampler = factories.get(chosen)
            if make_sampler is None:
                raise ValueError("Unknown training sampler: {}".format(chosen))
            sampler = make_sampler()

    return {
        "dataset": dataset,
        "sampler": sampler,
        "mapper": mapper,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
    }
@configurable(from_config=_train_loader_from_config)
def build_detection_train_loader(
    dataset,
    *,
    mapper,
    sampler=None,
    total_batch_size,
    aspect_ratio_grouping=True,
    num_workers=0,
    collate_fn=None,
    **kwargs
):
    """
    Build a dataloader for object detection with some default features.

    Args:
        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
            or a pytorch dataset (either map-style or iterable). It can be obtained
            by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper (callable or None): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model.
            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
            If None, samples are used as-is.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
            indices to be applied on ``dataset``.
            If ``dataset`` is map-style, the default sampler is a :class:`TrainingSampler`,
            which coordinates an infinite random shuffle sequence across all workers.
            Sampler must be None if ``dataset`` is iterable.
        total_batch_size (int): total batch size across all workers.
        aspect_ratio_grouping (bool): whether to group images with similar
            aspect ratio for efficiency. When enabled, it requires each
            element in dataset be a dict with keys "width" and "height".
        num_workers (int): number of parallel data loading workers
        collate_fn: a function that determines how to do batching, same as the argument of
            `torch.utils.data.DataLoader`. Defaults to do no collation and return a list of
            data. No collation is OK for small batch size and simple data structures.
            If your batch size is large and each sample contains too many small tensors,
            it's more efficient to collate them in data loader.
        **kwargs: forwarded to :func:`build_batch_data_loader` (e.g. ``drop_last``,
            ``single_gpu_batch_size``, ``seed``, or extra ``DataLoader`` arguments).

    Returns:
        torch.utils.data.DataLoader (or, with ``aspect_ratio_grouping``, an
        iterable over batches):
            a dataloader. Each output from it is a ``list[mapped_element]`` of length
            ``total_batch_size / world_size`` (the per-process batch size), where
            ``mapped_element`` is produced by the ``mapper``.
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)

    if isinstance(dataset, torchdata.IterableDataset):
        assert sampler is None, "sampler must be None if dataset is IterableDataset"
    else:
        if sampler is None:
            sampler = TrainingSampler(len(dataset))
        assert isinstance(sampler, torchdata.Sampler), f"Expect a Sampler but got {type(sampler)}"

    return build_batch_data_loader(
        dataset,
        sampler,
        total_batch_size,
        aspect_ratio_grouping=aspect_ratio_grouping,
        num_workers=num_workers,
        collate_fn=collate_fn,
        **kwargs
    )
def _test_loader_from_config(cfg, dataset_name, mapper=None):
    """
    Build keyword arguments for :func:`build_detection_test_loader` from a config.

    The given ``dataset_name`` (one name or a list of names) is used instead of
    the names in ``cfg.DATASETS.TEST``, because the standard practice is to
    evaluate each test set individually rather than combining them. Proposal
    files, when enabled, are looked up at the position each name has in
    ``cfg.DATASETS.TEST``.
    """
    names = [dataset_name] if isinstance(dataset_name, str) else dataset_name

    proposal_files = None
    if cfg.MODEL.LOAD_PROPOSALS:
        # PROPOSAL_FILES_TEST is parallel to DATASETS.TEST.
        test_names = list(cfg.DATASETS.TEST)
        proposal_files = [
            cfg.DATASETS.PROPOSAL_FILES_TEST[test_names.index(name)] for name in names
        ]

    dataset = get_detection_dataset_dicts(
        names,
        filter_empty=False,
        proposal_files=proposal_files,
    )
    if mapper is None:
        mapper = DatasetMapper(cfg, False)

    # Iterable datasets cannot be indexed, so they get no sampler.
    sampler = (
        None if isinstance(dataset, torchdata.IterableDataset) else InferenceSampler(len(dataset))
    )
    return {
        "dataset": dataset,
        "mapper": mapper,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
        "sampler": sampler,
    }
@configurable(from_config=_test_loader_from_config)
def build_detection_test_loader(
    dataset: Union[List[Any], torchdata.Dataset],
    *,
    mapper: Callable[[Dict[str, Any]], Any],
    sampler: Optional[torchdata.Sampler] = None,
    batch_size: int = 1,
    num_workers: int = 0,
    collate_fn: Optional[Callable[[List[Any]], Any]] = None,
) -> torchdata.DataLoader:
    """
    Similar to `build_detection_train_loader`, with default batch size = 1,
    and sampler = :class:`InferenceSampler`. This sampler coordinates all workers
    to produce the exact set of all samples.

    Args:
        dataset: a list of dataset dicts,
            or a pytorch dataset (either map-style or iterable). They can be obtained
            by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper: a callable which takes a sample (dict) from dataset
            and returns the format to be consumed by the model.
            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
            If None, samples are used as-is.
        sampler: a sampler that produces
            indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,
            which splits the dataset across all workers. Sampler must be None
            if `dataset` is iterable.
        batch_size: the batch size of the data loader to be created.
            Default to 1 image per worker since this is the standard when reporting
            inference time in papers.
        num_workers: number of parallel data loading workers
        collate_fn: same as the argument of `torch.utils.data.DataLoader`.
            Defaults to do no collation and return a list of data.

    Returns:
        DataLoader: a torch DataLoader, that loads the given detection
        dataset, with test-time transformation and batching. Incomplete final
        batches are kept (``drop_last=False``).

    Examples:
    ::
        data_loader = build_detection_test_loader(
            DatasetCatalog.get("my_test"),
            mapper=DatasetMapper(...))

        # or, instantiate with a CfgNode:
        data_loader = build_detection_test_loader(cfg, "my_test")
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)

    if isinstance(dataset, torchdata.IterableDataset):
        assert sampler is None, "sampler must be None if dataset is IterableDataset"
    else:
        if sampler is None:
            sampler = InferenceSampler(len(dataset))

    return torchdata.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        drop_last=False,
        num_workers=num_workers,
        collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
    )
def trivial_batch_collator(batch):
    """
    Identity collate function: hand back ``batch`` (the list of samples)
    unchanged, without stacking or otherwise combining its elements.
    """
    collated = batch
    return collated
def worker_init_reset_seed(worker_id):
    """
    DataLoader ``worker_init_fn``: reseed a worker via ``seed_all_rng``.

    The seed passed on is the process's ``torch.initial_seed()`` reduced modulo
    2**31, offset by ``worker_id`` so that workers get different seeds.
    """
    base_seed = torch.initial_seed() % (1 << 31)
    seed_all_rng(base_seed + worker_id)
|