yuangpeng committed on
Commit
cba5002
·
1 Parent(s): a04b1de

feat(data): support custom dataset cache (#1584)

README.md CHANGED
@@ -122,9 +122,9 @@ python -m yolox.tools.train -n yolox-s -d 8 -b 64 --fp16 -o [--cache]
122
  * -d: number of gpu devices
123
  * -b: total batch size, the recommended number for -b is num-gpu * 8
124
  * --fp16: mixed precision training
125
- * --cache: caching imgs into RAM to accelarate training, which need large system RAM.
 
126
 
127
-
128
 
129
  When using -f, the above commands are equivalent to:
130
  ```shell
@@ -140,7 +140,8 @@ We also support multi-nodes training. Just add the following args:
140
  * --num\_machines: num of your total training nodes
141
  * --machine\_rank: specify the rank of each node
142
 
143
- Suppose you want to train YOLOX on 2 machines, and your master machines's IP is 123.123.123.123, use port 12312 and TCP.
 
144
  On master machine, run
145
  ```shell
146
  python tools/train.py -n yolox-s -b 128 --dist-url tcp://123.123.123.123:12312 --num_machines 2 --machine_rank 0
@@ -163,7 +164,8 @@ python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o [--cache] --logger wandb w
163
 
164
  An example wandb dashboard is available [here](https://wandb.ai/manan-goel/yolox-nano/runs/3pzfeom0)
165
 
166
- **Others**
 
167
  See more information with the following command:
168
  ```shell
169
  python -m yolox.tools.train --help
@@ -202,6 +204,7 @@ python -m yolox.tools.eval -n yolox-s -c yolox_s.pth -b 1 -d 1 --conf 0.001 --f
202
  <summary>Tutorials</summary>
203
 
204
  * [Training on custom data](docs/train_custom_data.md)
 
205
  * [Manipulating training image size](docs/manipulate_training_image_size.md)
206
  * [Freezing model](docs/freeze_module.md)
207
 
 
122
  * -d: number of gpu devices
123
  * -b: total batch size, the recommended number for -b is num-gpu * 8
124
  * --fp16: mixed precision training
125
+ * --cache: caching imgs into RAM to accelerate training, which needs large system RAM.
126
+
127
 
 
128
 
129
  When using -f, the above commands are equivalent to:
130
  ```shell
 
140
  * --num\_machines: num of your total training nodes
141
  * --machine\_rank: specify the rank of each node
142
 
143
+ Suppose you want to train YOLOX on 2 machines, your master machine's IP is 123.123.123.123, and you use port 12312 over TCP.
144
+
145
  On master machine, run
146
  ```shell
147
  python tools/train.py -n yolox-s -b 128 --dist-url tcp://123.123.123.123:12312 --num_machines 2 --machine_rank 0
 
164
 
165
  An example wandb dashboard is available [here](https://wandb.ai/manan-goel/yolox-nano/runs/3pzfeom0)
166
 
167
+ **Others**
168
+
169
  See more information with the following command:
170
  ```shell
171
  python -m yolox.tools.train --help
 
204
  <summary>Tutorials</summary>
205
 
206
  * [Training on custom data](docs/train_custom_data.md)
207
+ * [Caching for custom data](docs/cache.md)
208
  * [Manipulating training image size](docs/manipulate_training_image_size.md)
209
  * [Freezing model](docs/freeze_module.md)
210
 
docs/cache.md ADDED
@@ -0,0 +1,97 @@
1
+ # Cache Custom Data
2
+
3
+ The caching feature is aimed primarily at users with ample RAM. We also offer the option to cache data to disk, but disk performance varies and may not deliver the same benefit. RAM caching for a custom dataset is also more straightforward to implement than disk caching. With a few simple modifications, training speed can nearly double compared to an uncached dataset.
4
+
5
+ This page explains how to cache your own custom data with YOLOX.
6
+
7
+ ## 0. Before you start
8
+
9
+ **Step1** Clone this repo and follow the [README](../README.md) to install YOLOX.
10
+
11
+ **Step2** Read the [Training on custom data](./train_custom_data.md) tutorial to understand how to prepare your custom data.
12
+
13
+ ## 1. Inherit from `CacheDataset`
14
+
15
+
16
+ **Step1** Create a custom dataset that inherits from the `CacheDataset` class. Note that whether you inherit from `Dataset` or `CacheDataset`, the `__init__()` method of your custom dataset should take the following keyword arguments: `input_dimension`, `cache`, and `cache_type`. Also call `super().__init__()` and pass in `input_dimension`, `num_imgs`, `cache`, and `cache_type`, where `num_imgs` is the size of the dataset.
17
+
18
+ **Step2** Implement the abstract method `read_img(self, index, use_cache=True)` of the parent class and decorate it with `@cache_read_img(use_cache=True)`. This function takes an `index` and returns an `image`; the returned image is what gets cached. It is recommended to put all repetitive, fixed post-processing of the image in this function so that it runs once at caching time rather than on every training iteration.
19
+
20
+ ```python
21
+ # CustomDataset.py
22
+ from yolox.data.datasets import CacheDataset, cache_read_img
23
+
24
+ class CustomDataset(CacheDataset):
25
+ def __init__(self, input_dimension, cache, cache_type, *args, **kwargs):
26
+ # num_imgs is the size of your dataset; determine it before calling super().__init__()
27
+ super().__init__(
28
+ input_dimension=input_dimension,
29
+ num_imgs=num_imgs,
30
+ cache=cache,
31
+ cache_type=cache_type
32
+ )
33
+ # ...
34
+
35
+ @cache_read_img(use_cache=True)
36
+ def read_img(self, index, use_cache=True):
37
+ # get image ...
38
+ # (optional) repetitive and fixed post-processing operations for image
39
+ return image
40
+ ```
41
+
42
+ ## 2. Create your Exp file and return your custom dataset
43
+
44
+ **Step1** Create a new class that inherits from the `Exp` class provided by `yolox_base.py`. Override the `get_dataset()` and `get_eval_dataset()` methods to return an instance of your custom dataset.
45
+
46
+ **Step2** Implement your own `get_evaluator` method to return an instance of your custom evaluator.
47
+
48
+ ```python
49
+ # CustomExp.py
50
+ from yolox.exp import Exp as MyExp
51
+
52
+ class Exp(MyExp):
53
+ def get_dataset(self, cache, cache_type: str = "ram"):
54
+ return CustomDataset(
55
+ input_dimension=self.input_size,
56
+ cache=cache,
57
+ cache_type=cache_type
58
+ )
59
+
60
+ def get_eval_dataset(self):
61
+ return CustomDataset(
62
+ input_dimension=self.input_size,
63
+ )
64
+
65
+ def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False):
66
+ return CustomEvaluator(
67
+ dataloader=self.get_eval_loader(batch_size, is_distributed, testdev=testdev, legacy=legacy),
68
+ img_size=self.test_size,
69
+ confthre=self.test_conf,
70
+ nmsthre=self.nmsthre,
71
+ num_classes=self.num_classes,
72
+ testdev=testdev,
73
+ )
74
+ ```
75
+
76
+ **(Optional)** `get_data_loader` and `get_eval_loader` are now implemented by default in `yolox_base.py` and generally do not need to be changed. If you do have to override `get_data_loader`, you need to add the following code at the beginning.
77
+
78
+ ```python
79
+ # CustomExp.py
80
+ from yolox.exp import Exp as MyExp
+ from yolox.utils import wait_for_the_master
81
+
82
+ class Exp(MyExp):
83
+ def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img: str = None):
84
+ if self.dataset is None:
85
+ with wait_for_the_master():
86
+ assert cache_img is None
87
+ self.dataset = self.get_dataset(cache=False, cache_type=cache_img)
88
+ # ...
89
+
90
+ ```
91
+
92
+ ## 3. Cache to Disk
93
+ It's important to note that `cache_type` can be `"ram"` or `"disk"`, depending on where you want to cache your dataset. If you choose `"disk"`, you need to pass additional parameters to `super().__init__()` of `CustomDataset`: `data_dir`, `cache_dir_name`, and `path_filename` (see the sketch after the list below).
94
+
95
+ - `data_dir`: the root directory of the dataset, e.g. `/path/to/COCO`.
96
+ - `cache_dir_name`: the name of the directory used for the disk cache, e.g. `"custom_cache"`; the files cached to disk will then be saved under `/path/to/COCO/custom_cache`.
97
+ - `path_filename`: a list of paths to the data relative to `data_dir`, e.g. if you have data `/path/to/COCO/train/1.jpg`, `/path/to/COCO/train/2.jpg`, then `path_filename = ['train/1.jpg', 'train/2.jpg']`.
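
For illustration, here is a minimal sketch of a dataset that caches to disk along these lines. The directory layout, the `train/*.jpg` glob, and the `custom_cache` name are assumptions for the example, not part of YOLOX:

```python
# Hypothetical CustomDataset that caches resized images to disk.
import glob
import os

import cv2

from yolox.data.datasets import CacheDataset, cache_read_img


class CustomDataset(CacheDataset):
    def __init__(self, data_dir, input_dimension, cache=False, cache_type="disk"):
        self.data_dir = data_dir
        # Paths relative to data_dir, e.g. ['train/1.jpg', 'train/2.jpg'] (assumed layout).
        self.path_filename = sorted(
            os.path.relpath(p, data_dir)
            for p in glob.glob(os.path.join(data_dir, "train", "*.jpg"))
        )
        super().__init__(
            input_dimension=input_dimension,
            num_imgs=len(self.path_filename),
            data_dir=data_dir,                # e.g. /path/to/COCO
            cache_dir_name="custom_cache",    # cached files go under /path/to/COCO/custom_cache
            path_filename=self.path_filename,
            cache=cache,
            cache_type=cache_type,
        )

    @cache_read_img(use_cache=True)
    def read_img(self, index):
        # Whatever this returns is what gets cached (here: the resized image).
        img = cv2.imread(os.path.join(self.data_dir, self.path_filename[index]))
        return cv2.resize(img, (self.input_dim[1], self.input_dim[0]))
```

With `cache=True` and `cache_type="disk"`, the first run writes the images as `.npy` files under the cache directory, which are reused on later runs as described above.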
exps/example/yolox_voc/yolox_voc_s.py CHANGED
@@ -1,9 +1,6 @@
1
  # encoding: utf-8
2
  import os
3
 
4
- import torch
5
- import torch.distributed as dist
6
-
7
  from yolox.data import get_yolox_datadir
8
  from yolox.exp import Exp as MyExp
9
 
@@ -24,115 +21,40 @@ class Exp(MyExp):
24
 
25
  self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
26
 
27
- def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img=False):
28
- from yolox.data import (
29
- VOCDetection,
30
- TrainTransform,
31
- YoloBatchSampler,
32
- DataLoader,
33
- InfiniteSampler,
34
- MosaicDetection,
35
- worker_init_reset_seed,
36
- )
37
- from yolox.utils import (
38
- wait_for_the_master,
39
- get_local_rank,
40
- )
41
- local_rank = get_local_rank()
42
 
43
- with wait_for_the_master(local_rank):
44
- dataset = VOCDetection(
45
- data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"),
46
- image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
47
- img_size=self.input_size,
48
- preproc=TrainTransform(
49
- max_labels=50,
50
- flip_prob=self.flip_prob,
51
- hsv_prob=self.hsv_prob),
52
- cache=cache_img,
53
- )
54
-
55
- dataset = MosaicDetection(
56
- dataset,
57
- mosaic=not no_aug,
58
  img_size=self.input_size,
59
  preproc=TrainTransform(
60
- max_labels=120,
61
  flip_prob=self.flip_prob,
62
  hsv_prob=self.hsv_prob),
63
- degrees=self.degrees,
64
- translate=self.translate,
65
- mosaic_scale=self.mosaic_scale,
66
- mixup_scale=self.mixup_scale,
67
- shear=self.shear,
68
- enable_mixup=self.enable_mixup,
69
- mosaic_prob=self.mosaic_prob,
70
- mixup_prob=self.mixup_prob,
71
- )
72
-
73
- self.dataset = dataset
74
-
75
- if is_distributed:
76
- batch_size = batch_size // dist.get_world_size()
77
-
78
- sampler = InfiniteSampler(
79
- len(self.dataset), seed=self.seed if self.seed else 0
80
  )
81
 
82
- batch_sampler = YoloBatchSampler(
83
- sampler=sampler,
84
- batch_size=batch_size,
85
- drop_last=False,
86
- mosaic=not no_aug,
87
- )
88
-
89
- dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
90
- dataloader_kwargs["batch_sampler"] = batch_sampler
91
-
92
- # Make sure each process has different random seed, especially for 'fork' method
93
- dataloader_kwargs["worker_init_fn"] = worker_init_reset_seed
94
-
95
- train_loader = DataLoader(self.dataset, **dataloader_kwargs)
96
-
97
- return train_loader
98
-
99
- def get_eval_loader(self, batch_size, is_distributed, testdev=False, legacy=False):
100
  from yolox.data import VOCDetection, ValTransform
 
101
 
102
- valdataset = VOCDetection(
103
  data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"),
104
  image_sets=[('2007', 'test')],
105
  img_size=self.test_size,
106
  preproc=ValTransform(legacy=legacy),
107
  )
108
 
109
- if is_distributed:
110
- batch_size = batch_size // dist.get_world_size()
111
- sampler = torch.utils.data.distributed.DistributedSampler(
112
- valdataset, shuffle=False
113
- )
114
- else:
115
- sampler = torch.utils.data.SequentialSampler(valdataset)
116
-
117
- dataloader_kwargs = {
118
- "num_workers": self.data_num_workers,
119
- "pin_memory": True,
120
- "sampler": sampler,
121
- }
122
- dataloader_kwargs["batch_size"] = batch_size
123
- val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs)
124
-
125
- return val_loader
126
-
127
  def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False):
128
  from yolox.evaluators import VOCEvaluator
129
 
130
- val_loader = self.get_eval_loader(batch_size, is_distributed, testdev, legacy)
131
- evaluator = VOCEvaluator(
132
- dataloader=val_loader,
133
  img_size=self.test_size,
134
  confthre=self.test_conf,
135
  nmsthre=self.nmsthre,
136
  num_classes=self.num_classes,
137
  )
138
- return evaluator
 
1
  # encoding: utf-8
2
  import os
3
 
 
 
 
4
  from yolox.data import get_yolox_datadir
5
  from yolox.exp import Exp as MyExp
6
 
 
21
 
22
  self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
23
 
24
+ def get_dataset(self, cache: bool, cache_type: str = "ram"):
25
+ from yolox.data import VOCDetection, TrainTransform
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ return VOCDetection(
28
+ data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"),
29
+ image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
 
 
 
 
 
 
 
 
 
 
 
 
30
  img_size=self.input_size,
31
  preproc=TrainTransform(
32
+ max_labels=50,
33
  flip_prob=self.flip_prob,
34
  hsv_prob=self.hsv_prob),
35
+ cache=cache,
36
+ cache_type=cache_type,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  )
38
 
39
+ def get_eval_dataset(self, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  from yolox.data import VOCDetection, ValTransform
41
+ legacy = kwargs.get("legacy", False)
42
 
43
+ return VOCDetection(
44
  data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"),
45
  image_sets=[('2007', 'test')],
46
  img_size=self.test_size,
47
  preproc=ValTransform(legacy=legacy),
48
  )
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False):
51
  from yolox.evaluators import VOCEvaluator
52
 
53
+ return VOCEvaluator(
54
+ dataloader=self.get_eval_loader(batch_size, is_distributed,
55
+ testdev=testdev, legacy=legacy),
56
  img_size=self.test_size,
57
  confthre=self.test_conf,
58
  nmsthre=self.nmsthre,
59
  num_classes=self.num_classes,
60
  )
 
tools/train.py CHANGED
@@ -131,7 +131,7 @@ if __name__ == "__main__":
131
  assert num_gpu <= get_num_devices()
132
 
133
  if args.cache is not None:
134
- exp.create_cache_dataset(args.cache)
135
 
136
  dist_url = "auto" if args.dist_url is None else args.dist_url
137
  launch(
 
131
  assert num_gpu <= get_num_devices()
132
 
133
  if args.cache is not None:
134
+ exp.dataset = exp.get_dataset(cache=True, cache_type=args.cache)
135
 
136
  dist_url = "auto" if args.dist_url is None else args.dist_url
137
  launch(
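
Taken together with the `yolox_base.py` changes further down, the hunk above means that when `--cache` is passed, the dataset is created (and cached) once in the main process before `launch()`, while without `--cache` it is created later by `Exp.get_data_loader()`. A simplified, hypothetical restatement (not the actual `tools/train.py` code):

```python
# Hypothetical helper mirroring the logic added above.
def prepare_dataset(exp, cache_arg):
    """cache_arg is what --cache would carry: None, "ram" or "disk"."""
    if cache_arg is not None:
        # Caching requested: build and cache the dataset before launch().
        exp.dataset = exp.get_dataset(cache=True, cache_type=cache_arg)
    # Otherwise exp.dataset stays None and Exp.get_data_loader() creates it
    # lazily, without caching, inside the training processes.
```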
yolox/data/datasets/__init__.py CHANGED
@@ -4,6 +4,6 @@
4
 
5
  from .coco import COCODataset
6
  from .coco_classes import COCO_CLASSES
7
- from .datasets_wrapper import ConcatDataset, Dataset, MixConcatDataset
8
  from .mosaicdetection import MosaicDetection
9
  from .voc import VOCDetection
 
4
 
5
  from .coco import COCODataset
6
  from .coco_classes import COCO_CLASSES
7
+ from .datasets_wrapper import CacheDataset, ConcatDataset, Dataset, MixConcatDataset
8
  from .mosaicdetection import MosaicDetection
9
  from .voc import VOCDetection
yolox/data/datasets/coco.py CHANGED
@@ -3,18 +3,13 @@
3
  # Copyright (c) Megvii, Inc. and its affiliates.
4
  import copy
5
  import os
6
- import random
7
- from multiprocessing.pool import ThreadPool
8
- import psutil
9
- from loguru import logger
10
- from tqdm import tqdm
11
 
12
  import cv2
13
  import numpy as np
14
  from pycocotools.coco import COCO
15
 
16
  from ..dataloading import get_yolox_datadir
17
- from .datasets_wrapper import Dataset
18
 
19
 
20
  def remove_useless_info(coco):
@@ -36,7 +31,7 @@ def remove_useless_info(coco):
36
  anno.pop("segmentation", None)
37
 
38
 
39
- class COCODataset(Dataset):
40
  """
41
  COCO dataset class.
42
  """
@@ -60,7 +55,6 @@ class COCODataset(Dataset):
60
  img_size (int): target image size after pre-processing
61
  preproc: data augmentation strategy
62
  """
63
- super().__init__(img_size)
64
  if data_dir is None:
65
  data_dir = os.path.join(get_yolox_datadir(), "COCO")
66
  self.data_dir = data_dir
@@ -77,85 +71,21 @@ class COCODataset(Dataset):
77
  self.img_size = img_size
78
  self.preproc = preproc
79
  self.annotations = self._load_coco_annotations()
80
- self.imgs = None
81
- self.cache = cache
82
- self.cache_type = cache_type
83
 
84
- if self.cache:
85
- self._cache_images()
86
-
87
- def _cache_images(self):
88
- mem = psutil.virtual_memory()
89
- mem_required = self.cal_cache_ram()
90
- gb = 1 << 30
91
-
92
- if self.cache_type == "ram" and mem_required > mem.available:
93
- self.cache = False
94
- else:
95
- logger.info(
96
- f"{mem_required / gb:.1f}GB RAM required, "
97
- f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB RAM available, "
98
- f"Since the first thing we do is cache, "
99
- f"there is no guarantee that the remaining memory space is sufficient"
100
- )
101
-
102
- if self.imgs is None:
103
- if self.cache_type == 'ram':
104
- self.imgs = [None] * self.num_imgs
105
- logger.info("You are using cached images in RAM to accelerate training!")
106
- else: # 'disk'
107
- self.cache_dir = os.path.join(
108
- self.data_dir,
109
- f"{self.name}_cache{self.img_size[0]}x{self.img_size[1]}"
110
- )
111
- if not os.path.exists(self.cache_dir):
112
- os.mkdir(self.cache_dir)
113
- logger.warning(
114
- f"\n*******************************************************************\n"
115
- f"You are using cached images in DISK to accelerate training.\n"
116
- f"This requires large DISK space.\n"
117
- f"Make sure you have {mem_required / gb:.1f} "
118
- f"available DISK space for training COCO.\n"
119
- f"*******************************************************************\\n"
120
- )
121
- else:
122
- logger.info("Found disk cache!")
123
- return
124
-
125
- logger.info(
126
- "Caching images for the first time. "
127
- "This might take about 15 minutes for COCO"
128
- )
129
-
130
- num_threads = min(8, max(1, os.cpu_count() - 1))
131
- b = 0
132
- load_imgs = ThreadPool(num_threads).imap(self.load_resized_img, range(self.num_imgs))
133
- pbar = tqdm(enumerate(load_imgs), total=self.num_imgs)
134
- for i, x in pbar: # x = self.load_resized_img(self, i)
135
- if self.cache_type == 'ram':
136
- self.imgs[i] = x
137
- else: # 'disk'
138
- cache_filename = f'{self.annotations[i]["filename"].split(".")[0]}.npy'
139
- np.save(os.path.join(self.cache_dir, cache_filename), x)
140
- b += x.nbytes
141
- pbar.desc = f'Caching images ({b / gb:.1f}/{mem_required / gb:.1f}GB {self.cache})'
142
- pbar.close()
143
-
144
- def cal_cache_ram(self):
145
- cache_bytes = 0
146
- num_samples = min(self.num_imgs, 32)
147
- for _ in range(num_samples):
148
- img = self.load_resized_img(random.randint(0, self.num_imgs - 1))
149
- cache_bytes += img.nbytes
150
- mem_required = cache_bytes * self.num_imgs / num_samples
151
- return mem_required
152
 
153
  def __len__(self):
154
  return self.num_imgs
155
 
156
- def __del__(self):
157
- del self.imgs
158
-
159
  def _load_coco_annotations(self):
160
  return [self.load_anno_from_ids(_ids) for _ids in self.ids]
161
 
@@ -220,20 +150,18 @@ class COCODataset(Dataset):
220
 
221
  return img
222
 
 
 
 
 
223
  def pull_item(self, index):
224
  id_ = self.ids[index]
225
- label, origin_image_size, _, filename = self.annotations[index]
226
-
227
- if self.cache and self.cache_type == 'ram':
228
- img = self.imgs[index]
229
- elif self.cache and self.cache_type == 'disk':
230
- img = np.load(os.path.join(self.cache_dir, f"{filename.split('.')[0]}.npy"))
231
- else:
232
- img = self.load_resized_img(index)
233
 
234
- return copy.deepcopy(img), copy.deepcopy(label), origin_image_size, np.array([id_])
235
 
236
- @Dataset.mosaic_getitem
237
  def __getitem__(self, index):
238
  """
239
  One image / label pair for the given index is picked up and pre-processed.
 
3
  # Copyright (c) Megvii, Inc. and its affiliates.
4
  import copy
5
  import os
 
 
 
 
 
6
 
7
  import cv2
8
  import numpy as np
9
  from pycocotools.coco import COCO
10
 
11
  from ..dataloading import get_yolox_datadir
12
+ from .datasets_wrapper import CacheDataset, cache_read_img
13
 
14
 
15
  def remove_useless_info(coco):
 
31
  anno.pop("segmentation", None)
32
 
33
 
34
+ class COCODataset(CacheDataset):
35
  """
36
  COCO dataset class.
37
  """
 
55
  img_size (int): target image size after pre-processing
56
  preproc: data augmentation strategy
57
  """
 
58
  if data_dir is None:
59
  data_dir = os.path.join(get_yolox_datadir(), "COCO")
60
  self.data_dir = data_dir
 
71
  self.img_size = img_size
72
  self.preproc = preproc
73
  self.annotations = self._load_coco_annotations()
 
 
 
74
 
75
+ path_filename = [os.path.join(name, anno[3]) for anno in self.annotations]
76
+ super().__init__(
77
+ input_dimension=img_size,
78
+ num_imgs=self.num_imgs,
79
+ data_dir=data_dir,
80
+ cache_dir_name=f"cache_{name}",
81
+ path_filename=path_filename,
82
+ cache=cache,
83
+ cache_type=cache_type
84
+ )
 
 
 
 
 
 
 
 
 
85
 
86
  def __len__(self):
87
  return self.num_imgs
88
 
 
 
 
89
  def _load_coco_annotations(self):
90
  return [self.load_anno_from_ids(_ids) for _ids in self.ids]
91
 
 
150
 
151
  return img
152
 
153
+ @cache_read_img(use_cache=True)
154
+ def read_img(self, index):
155
+ return self.load_resized_img(index)
156
+
157
  def pull_item(self, index):
158
  id_ = self.ids[index]
159
+ label, origin_image_size, _, _ = self.annotations[index]
160
+ img = self.read_img(index)
 
 
 
 
 
 
161
 
162
+ return img, copy.deepcopy(label), origin_image_size, np.array([id_])
163
 
164
+ @CacheDataset.mosaic_getitem
165
  def __getitem__(self, index):
166
  """
167
  One image / label pair for the given index is picked up and pre-processed.
yolox/data/datasets/datasets_wrapper.py CHANGED
@@ -3,7 +3,17 @@
3
  # Copyright (c) Megvii, Inc. and its affiliates.
4
 
5
  import bisect
6
- from functools import wraps
 
 
 
 
 
 
 
 
 
 
7
 
8
  from torch.utils.data.dataset import ConcatDataset as torchConcatDataset
9
  from torch.utils.data.dataset import Dataset as torchDataset
@@ -112,3 +122,179 @@ class Dataset(torchDataset):
112
  return ret_val
113
 
114
  return wrapper
 
 
3
  # Copyright (c) Megvii, Inc. and its affiliates.
4
 
5
  import bisect
6
+ import copy
7
+ import os
8
+ import random
9
+ from abc import ABCMeta, abstractmethod
10
+ from functools import partial, wraps
11
+ from multiprocessing.pool import ThreadPool
12
+ import psutil
13
+ from loguru import logger
14
+ from tqdm import tqdm
15
+
16
+ import numpy as np
17
 
18
  from torch.utils.data.dataset import ConcatDataset as torchConcatDataset
19
  from torch.utils.data.dataset import Dataset as torchDataset
 
122
  return ret_val
123
 
124
  return wrapper
125
+
126
+
127
+ class CacheDataset(Dataset, metaclass=ABCMeta):
128
+ """ This class is a subclass of the base :class:`yolox.data.datasets.Dataset`,
129
+ that supports caching images to RAM or disk.
130
+
131
+ Args:
132
+ input_dimension (tuple): (width,height) tuple with default dimensions of the network
133
+ num_imgs (int): dataset size
134
+ data_dir (str): the root directory of the dataset, e.g. `/path/to/COCO`.
135
+ cache_dir_name (str): the name of the directory to cache to disk,
136
+ e.g. `"custom_cache"`. The files cached to disk will be saved
137
+ under `/path/to/COCO/custom_cache`.
138
+ path_filename (str): a list of paths to the data relative to the `data_dir`,
139
+ e.g. if you have data `/path/to/COCO/train/1.jpg`, `/path/to/COCO/train/2.jpg`,
140
+ then `path_filename = ['train/1.jpg', 'train/2.jpg']`.
141
+ cache (bool): whether to cache the images to ram or disk.
142
+ cache_type (str): the type of cache,
143
+ "ram" : Caching imgs to ram for fast training.
144
+ "disk": Caching imgs to disk for fast training.
145
+ """
146
+
147
+ def __init__(
148
+ self,
149
+ input_dimension,
150
+ num_imgs=None,
151
+ data_dir=None,
152
+ cache_dir_name=None,
153
+ path_filename=None,
154
+ cache=False,
155
+ cache_type="ram",
156
+ ):
157
+ super().__init__(input_dimension)
158
+ self.cache = cache
159
+ self.cache_type = cache_type
160
+
161
+ if self.cache and self.cache_type == "disk":
162
+ self.cache_dir = os.path.join(data_dir, cache_dir_name)
163
+ self.path_filename = path_filename
164
+
165
+ if self.cache and self.cache_type == "ram":
166
+ self.imgs = None
167
+
168
+ if self.cache:
169
+ self.cache_images(
170
+ num_imgs=num_imgs,
171
+ data_dir=data_dir,
172
+ cache_dir_name=cache_dir_name,
173
+ path_filename=path_filename,
174
+ )
175
+
176
+ def __del__(self):
177
+ if self.cache and self.cache_type == "ram":
178
+ del self.imgs
179
+
180
+ @abstractmethod
181
+ def read_img(self, index):
182
+ """
183
+ Given index, return the corresponding image
184
+
185
+ Args:
186
+ index (int): image index
187
+ """
188
+ raise NotImplementedError
189
+
190
+ def cache_images(
191
+ self,
192
+ num_imgs=None,
193
+ data_dir=None,
194
+ cache_dir_name=None,
195
+ path_filename=None,
196
+ ):
197
+ assert num_imgs is not None, "num_imgs must be specified as the size of the dataset"
198
+ if self.cache_type == "disk":
199
+ assert (data_dir and cache_dir_name and path_filename) is not None, \
200
+ "data_dir, cache_name and path_filename must be specified if cache_type is disk"
201
+ self.path_filename = path_filename
202
+
203
+ mem = psutil.virtual_memory()
204
+ mem_required = self.cal_cache_occupy(num_imgs)
205
+ gb = 1 << 30
206
+
207
+ if self.cache_type == "ram":
208
+ if mem_required > mem.available:
209
+ self.cache = False
210
+ else:
211
+ logger.info(
212
+ f"{mem_required / gb:.1f}GB RAM required, "
213
+ f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB RAM available, "
214
+ f"Since the first thing we do is cache, "
215
+ f"there is no guarantee that the remaining memory space is sufficient"
216
+ )
217
+
218
+ if self.cache and self.imgs is None:
219
+ if self.cache_type == 'ram':
220
+ self.imgs = [None] * num_imgs
221
+ logger.info("You are using cached images in RAM to accelerate training!")
222
+ else: # 'disk'
223
+ if not os.path.exists(self.cache_dir):
224
+ os.mkdir(self.cache_dir)
225
+ logger.warning(
226
+ f"\n*******************************************************************\n"
227
+ f"You are using cached images in DISK to accelerate training.\n"
228
+ f"This requires large DISK space.\n"
229
+ f"Make sure you have {mem_required / gb:.1f} "
230
+ f"available DISK space for training your dataset.\n"
231
+ f"*******************************************************************\\n"
232
+ )
233
+ else:
234
+ logger.info(f"Found disk cache at {self.cache_dir}")
235
+ return
236
+
237
+ logger.info(
238
+ "Caching images...\n"
239
+ "This might take some time for your dataset"
240
+ )
241
+
242
+ num_threads = min(8, max(1, os.cpu_count() - 1))
243
+ b = 0
244
+ load_imgs = ThreadPool(num_threads).imap(
245
+ partial(self.read_img, use_cache=False),
246
+ range(num_imgs)
247
+ )
248
+ pbar = tqdm(enumerate(load_imgs), total=num_imgs)
249
+ for i, x in pbar: # x = self.read_img(self, i, use_cache=False)
250
+ if self.cache_type == 'ram':
251
+ self.imgs[i] = x
252
+ else: # 'disk'
253
+ cache_filename = f'{self.path_filename[i].split(".")[0]}.npy'
254
+ cache_path_filename = os.path.join(self.cache_dir, cache_filename)
255
+ os.makedirs(os.path.dirname(cache_path_filename), exist_ok=True)
256
+ np.save(cache_path_filename, x)
257
+ b += x.nbytes
258
+ pbar.desc = \
259
+ f'Caching images ({b / gb:.1f}/{mem_required / gb:.1f}GB {self.cache_type})'
260
+ pbar.close()
261
+
262
+ def cal_cache_occupy(self, num_imgs):
263
+ cache_bytes = 0
264
+ num_samples = min(num_imgs, 32)
265
+ for _ in range(num_samples):
266
+ img = self.read_img(index=random.randint(0, num_imgs - 1), use_cache=False)
267
+ cache_bytes += img.nbytes
268
+ mem_required = cache_bytes * num_imgs / num_samples
269
+ return mem_required
270
+
271
+
272
+ def cache_read_img(use_cache=True):
273
+ def decorator(read_img_fn):
274
+ """
275
+ Decorate the read_img function to cache the image
276
+
277
+ Args:
278
+ read_img_fn: read_img function
279
+ use_cache (bool, optional): For the decorated read_img function,
280
+ whether to read the image from cache.
281
+ Defaults to True.
282
+ """
283
+ @wraps(read_img_fn)
284
+ def wrapper(self, index, use_cache=use_cache):
285
+ cache = self.cache and use_cache
286
+ if cache:
287
+ if self.cache_type == "ram":
288
+ img = self.imgs[index]
289
+ img = copy.deepcopy(img)
290
+ elif self.cache_type == "disk":
291
+ img = np.load(
292
+ os.path.join(
293
+ self.cache_dir, f"{self.path_filename[index].split('.')[0]}.npy"))
294
+ else:
295
+ raise ValueError(f"Unknown cache type: {self.cache_type}")
296
+ else:
297
+ img = read_img_fn(self, index)
298
+ return img
299
+ return wrapper
300
+ return decorator
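
Since `cache_read_img` is a decorator factory (calling it returns the actual decorator), it is applied with parentheses, as in `coco.py` above. Below is a minimal sketch of the intended usage and call path; `ToyDataset` and its synthetic images are hypothetical:

```python
import numpy as np

from yolox.data.datasets import CacheDataset, cache_read_img


class ToyDataset(CacheDataset):
    def __init__(self, num_imgs=4, cache=False, cache_type="ram"):
        self.num_imgs = num_imgs
        super().__init__(
            input_dimension=(32, 32),
            num_imgs=num_imgs,
            cache=cache,
            cache_type=cache_type,
        )

    # Called with parentheses: cache_read_img(use_cache=True) returns the decorator.
    @cache_read_img(use_cache=True)
    def read_img(self, index):
        # Stand-in for "load and resize an image".
        return np.full((32, 32, 3), index, dtype=np.uint8)


dataset = ToyDataset(cache=True, cache_type="ram")   # caches all images up front
img = dataset.read_img(0)                    # served from self.imgs (a deepcopy)
raw = dataset.read_img(0, use_cache=False)   # bypasses the cache, calls the raw function
```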
yolox/data/datasets/voc.py CHANGED
@@ -10,14 +10,13 @@ import os
10
  import os.path
11
  import pickle
12
  import xml.etree.ElementTree as ET
13
- from loguru import logger
14
 
15
  import cv2
16
  import numpy as np
17
 
18
  from yolox.evaluators.voc_eval import voc_eval
19
 
20
- from .datasets_wrapper import Dataset
21
  from .voc_classes import VOC_CLASSES
22
 
23
 
@@ -80,7 +79,7 @@ class AnnotationTransform(object):
80
  return res, img_info
81
 
82
 
83
- class VOCDetection(Dataset):
84
 
85
  """
86
  VOC Detection Dataset Object
@@ -108,8 +107,8 @@ class VOCDetection(Dataset):
108
  target_transform=AnnotationTransform(),
109
  dataset_name="VOC0712",
110
  cache=False,
 
111
  ):
112
- super().__init__(img_size)
113
  self.root = data_dir
114
  self.image_set = image_sets
115
  self.img_size = img_size
@@ -131,66 +130,29 @@ class VOCDetection(Dataset):
131
  os.path.join(rootpath, "ImageSets", "Main", name + ".txt")
132
  ):
133
  self.ids.append((rootpath, line.strip()))
 
134
 
135
  self.annotations = self._load_coco_annotations()
136
- self.imgs = None
137
- if cache:
138
- self._cache_images()
139
 
140
- def __len__(self):
141
- return len(self.ids)
142
-
143
- def _load_coco_annotations(self):
144
- return [self.load_anno_from_ids(_ids) for _ids in range(len(self.ids))]
145
-
146
- def _cache_images(self):
147
- logger.warning(
148
- "\n********************************************************************************\n"
149
- "You are using cached images in RAM to accelerate training.\n"
150
- "This requires large system RAM.\n"
151
- "Make sure you have 60G+ RAM and 19G available disk space for training VOC.\n"
152
- "********************************************************************************\n"
153
  )
154
- max_h = self.img_size[0]
155
- max_w = self.img_size[1]
156
- cache_file = os.path.join(self.root, f"img_resized_cache_{self.name}.array")
157
- if not os.path.exists(cache_file):
158
- logger.info(
159
- "Caching images for the first time. This might take about 3 minutes for VOC"
160
- )
161
- self.imgs = np.memmap(
162
- cache_file,
163
- shape=(len(self.ids), max_h, max_w, 3),
164
- dtype=np.uint8,
165
- mode="w+",
166
- )
167
- from tqdm import tqdm
168
- from multiprocessing.pool import ThreadPool
169
 
170
- NUM_THREADs = min(8, os.cpu_count())
171
- loaded_images = ThreadPool(NUM_THREADs).imap(
172
- lambda x: self.load_resized_img(x),
173
- range(len(self.annotations)),
174
- )
175
- pbar = tqdm(enumerate(loaded_images), total=len(self.annotations))
176
- for k, out in pbar:
177
- self.imgs[k][: out.shape[0], : out.shape[1], :] = out.copy()
178
- self.imgs.flush()
179
- pbar.close()
180
- else:
181
- logger.warning(
182
- "You are using cached imgs! Make sure your dataset is not changed!!\n"
183
- "Everytime the self.input_size is changed in your exp file, you need to delete\n"
184
- "the cached data and re-generate them.\n"
185
- )
186
 
187
- logger.info("Loading cached imgs...")
188
- self.imgs = np.memmap(
189
- cache_file,
190
- shape=(len(self.ids), max_h, max_w, 3),
191
- dtype=np.uint8,
192
- mode="r+",
193
- )
194
 
195
  def load_anno_from_ids(self, index):
196
  img_id = self.ids[index]
@@ -227,6 +189,10 @@ class VOCDetection(Dataset):
227
 
228
  return img
229
 
 
 
 
 
230
  def pull_item(self, index):
231
  """Returns the original image and target at an index for mixup
232
 
@@ -238,17 +204,12 @@ class VOCDetection(Dataset):
238
  Return:
239
  img, target
240
  """
241
- if self.imgs is not None:
242
- target, img_info, resized_info = self.annotations[index]
243
- pad_img = self.imgs[index]
244
- img = pad_img[: resized_info[0], : resized_info[1], :].copy()
245
- else:
246
- img = self.load_resized_img(index)
247
- target, img_info, _ = self.annotations[index]
248
 
249
  return img, target, img_info, index
250
 
251
- @Dataset.mosaic_getitem
252
  def __getitem__(self, index):
253
  img, target, img_info, img_id = self.pull_item(index)
254
 
 
10
  import os.path
11
  import pickle
12
  import xml.etree.ElementTree as ET
 
13
 
14
  import cv2
15
  import numpy as np
16
 
17
  from yolox.evaluators.voc_eval import voc_eval
18
 
19
+ from .datasets_wrapper import CacheDataset, cache_read_img
20
  from .voc_classes import VOC_CLASSES
21
 
22
 
 
79
  return res, img_info
80
 
81
 
82
+ class VOCDetection(CacheDataset):
83
 
84
  """
85
  VOC Detection Dataset Object
 
107
  target_transform=AnnotationTransform(),
108
  dataset_name="VOC0712",
109
  cache=False,
110
+ cache_type="ram",
111
  ):
 
112
  self.root = data_dir
113
  self.image_set = image_sets
114
  self.img_size = img_size
 
130
  os.path.join(rootpath, "ImageSets", "Main", name + ".txt")
131
  ):
132
  self.ids.append((rootpath, line.strip()))
133
+ self.num_imgs = len(self.ids)
134
 
135
  self.annotations = self._load_coco_annotations()
 
 
 
136
 
137
+ path_filename = [
138
+ (self._imgpath % self.ids[i]).split(self.root + "/")[1]
139
+ for i in range(self.num_imgs)
140
+ ]
141
+ super().__init__(
142
+ input_dimension=img_size,
143
+ num_imgs=self.num_imgs,
144
+ data_dir=self.root,
145
+ cache_dir_name=f"cache_{self.name}",
146
+ path_filename=path_filename,
147
+ cache=cache,
148
+ cache_type=cache_type
 
149
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
+ def __len__(self):
152
+ return self.num_imgs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
+ def _load_coco_annotations(self):
155
+ return [self.load_anno_from_ids(_ids) for _ids in range(self.num_imgs)]
 
 
 
 
 
156
 
157
  def load_anno_from_ids(self, index):
158
  img_id = self.ids[index]
 
189
 
190
  return img
191
 
192
+ @cache_read_img(use_cache=True)
193
+ def read_img(self, index, use_cache=True):
194
+ return self.load_resized_img(index)
195
+
196
  def pull_item(self, index):
197
  """Returns the original image and target at an index for mixup
198
 
 
204
  Return:
205
  img, target
206
  """
207
+ target, img_info, _ = self.annotations[index]
208
+ img = self.read_img(index)
 
 
 
 
 
209
 
210
  return img, target, img_info, index
211
 
212
+ @CacheDataset.mosaic_getitem
213
  def __getitem__(self, index):
214
  img, target, img_info, img_id = self.pull_item(index)
215
 
yolox/evaluators/coco_evaluator.py CHANGED
@@ -90,8 +90,8 @@ class COCOEvaluator:
90
  nmsthre: float,
91
  num_classes: int,
92
  testdev: bool = False,
93
- per_class_AP: bool = False,
94
- per_class_AR: bool = False,
95
  ):
96
  """
97
  Args:
@@ -101,8 +101,8 @@ class COCOEvaluator:
101
  confthre: confidence threshold ranging from 0 to 1, which
102
  is defined in the config file.
103
  nmsthre: IoU threshold of non-max supression ranging from 0 to 1.
104
- per_class_AP: Show per class AP during evalution or not. Default to False.
105
- per_class_AR: Show per class AR during evalution or not. Default to False.
106
  """
107
  self.dataloader = dataloader
108
  self.img_size = img_size
@@ -188,6 +188,9 @@ class COCOEvaluator:
188
 
189
  statistics = torch.cuda.FloatTensor([inference_time, nms_time, n_samples])
190
  if distributed:
 
 
 
191
  data_list = gather(data_list, dst=0)
192
  output_data = gather(output_data, dst=0)
193
  data_list = list(itertools.chain(*data_list))
 
90
  nmsthre: float,
91
  num_classes: int,
92
  testdev: bool = False,
93
+ per_class_AP: bool = True,
94
+ per_class_AR: bool = True,
95
  ):
96
  """
97
  Args:
 
101
  confthre: confidence threshold ranging from 0 to 1, which
102
  is defined in the config file.
103
  nmsthre: IoU threshold of non-max supression ranging from 0 to 1.
104
+ per_class_AP: Show per class AP during evalution or not. Default to True.
105
+ per_class_AR: Show per class AR during evalution or not. Default to True.
106
  """
107
  self.dataloader = dataloader
108
  self.img_size = img_size
 
188
 
189
  statistics = torch.cuda.FloatTensor([inference_time, nms_time, n_samples])
190
  if distributed:
191
+ # different process/device might have different speed,
192
+ # to make sure the process will not be stuck, sync func is used here.
193
+ synchronize()
194
  data_list = gather(data_list, dst=0)
195
  output_data = gather(output_data, dst=0)
196
  data_list = list(itertools.chain(*data_list))
yolox/exp/base_exp.py CHANGED
@@ -22,11 +22,16 @@ class BaseExp(metaclass=ABCMeta):
22
  self.output_dir = "./YOLOX_outputs"
23
  self.print_interval = 100
24
  self.eval_interval = 10
 
25
 
26
  @abstractmethod
27
  def get_model(self) -> Module:
28
  pass
29
 
 
 
 
 
30
  @abstractmethod
31
  def get_data_loader(
32
  self, batch_size: int, is_distributed: bool
 
22
  self.output_dir = "./YOLOX_outputs"
23
  self.print_interval = 100
24
  self.eval_interval = 10
25
+ self.dataset = None
26
 
27
  @abstractmethod
28
  def get_model(self) -> Module:
29
  pass
30
 
31
+ @abstractmethod
32
+ def get_dataset(self, cache: bool = False, cache_type: str = "ram"):
33
+ pass
34
+
35
  @abstractmethod
36
  def get_data_loader(
37
  self, batch_size: int, is_distributed: bool
yolox/exp/yolox_base.py CHANGED
@@ -106,23 +106,6 @@ class Exp(BaseExp):
106
  self.test_conf = 0.01
107
  # nms threshold
108
  self.nmsthre = 0.65
109
- self.cache_dataset = None
110
- self.dataset = None
111
-
112
- def create_cache_dataset(self, cache_type: str = "ram"):
113
- from yolox.data import COCODataset, TrainTransform
114
- self.cache_dataset = COCODataset(
115
- data_dir=self.data_dir,
116
- json_file=self.train_ann,
117
- img_size=self.input_size,
118
- preproc=TrainTransform(
119
- max_labels=50,
120
- flip_prob=self.flip_prob,
121
- hsv_prob=self.hsv_prob
122
- ),
123
- cache=True,
124
- cache_type=cache_type,
125
- )
126
 
127
  def get_model(self):
128
  from yolox.models import YOLOX, YOLOPAFPN, YOLOXHead
@@ -144,6 +127,30 @@ class Exp(BaseExp):
144
  self.model.train()
145
  return self.model
146
 
 
 
 
147
  def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img: str = None):
148
  """
149
  Get dataloader according to cache_img parameter.
@@ -155,7 +162,6 @@ class Exp(BaseExp):
155
  None: Do not use cache, in this case cache_data is also None.
156
  """
157
  from yolox.data import (
158
- COCODataset,
159
  TrainTransform,
160
  YoloBatchSampler,
161
  DataLoader,
@@ -165,25 +171,16 @@ class Exp(BaseExp):
165
  )
166
  from yolox.utils import wait_for_the_master
167
 
168
- with wait_for_the_master():
169
- if self.cache_dataset is None:
170
- assert cache_img is None, "cache is True, but cache_dataset is None"
171
- dataset = COCODataset(
172
- data_dir=self.data_dir,
173
- json_file=self.train_ann,
174
- img_size=self.input_size,
175
- preproc=TrainTransform(
176
- max_labels=50,
177
- flip_prob=self.flip_prob,
178
- hsv_prob=self.hsv_prob),
179
- cache=False,
180
- cache_type=cache_img,
181
- )
182
- else:
183
- dataset = self.cache_dataset
184
 
185
  self.dataset = MosaicDetection(
186
- dataset,
187
  mosaic=not no_aug,
188
  img_size=self.input_size,
189
  preproc=TrainTransform(
@@ -298,10 +295,12 @@ class Exp(BaseExp):
298
  )
299
  return scheduler
300
 
301
- def get_eval_loader(self, batch_size, is_distributed, testdev=False, legacy=False):
302
  from yolox.data import COCODataset, ValTransform
 
 
303
 
304
- valdataset = COCODataset(
305
  data_dir=self.data_dir,
306
  json_file=self.val_ann if not testdev else self.test_ann,
307
  name="val2017" if not testdev else "test2017",
@@ -309,6 +308,9 @@ class Exp(BaseExp):
309
  preproc=ValTransform(legacy=legacy),
310
  )
311
 
 
 
 
312
  if is_distributed:
313
  batch_size = batch_size // dist.get_world_size()
314
  sampler = torch.utils.data.distributed.DistributedSampler(
@@ -330,16 +332,15 @@ class Exp(BaseExp):
330
  def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False):
331
  from yolox.evaluators import COCOEvaluator
332
 
333
- val_loader = self.get_eval_loader(batch_size, is_distributed, testdev, legacy)
334
- evaluator = COCOEvaluator(
335
- dataloader=val_loader,
336
  img_size=self.test_size,
337
  confthre=self.test_conf,
338
  nmsthre=self.nmsthre,
339
  num_classes=self.num_classes,
340
  testdev=testdev,
341
  )
342
- return evaluator
343
 
344
  def get_trainer(self, args):
345
  from yolox.core import Trainer
 
106
  self.test_conf = 0.01
107
  # nms threshold
108
  self.nmsthre = 0.65
 
 
 
 
109
 
110
  def get_model(self):
111
  from yolox.models import YOLOX, YOLOPAFPN, YOLOXHead
 
127
  self.model.train()
128
  return self.model
129
 
130
+ def get_dataset(self, cache: bool = False, cache_type: str = "ram"):
131
+ """
132
+ Get dataset according to cache and cache_type parameters.
133
+ Args:
134
+ cache (bool): Whether to cache imgs to ram or disk.
135
+ cache_type (str, optional): Defaults to "ram".
136
+ "ram" : Caching imgs to ram for fast training.
137
+ "disk": Caching imgs to disk for fast training.
138
+ """
139
+ from yolox.data import COCODataset, TrainTransform
140
+
141
+ return COCODataset(
142
+ data_dir=self.data_dir,
143
+ json_file=self.train_ann,
144
+ img_size=self.input_size,
145
+ preproc=TrainTransform(
146
+ max_labels=50,
147
+ flip_prob=self.flip_prob,
148
+ hsv_prob=self.hsv_prob
149
+ ),
150
+ cache=cache,
151
+ cache_type=cache_type,
152
+ )
153
+
154
  def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img: str = None):
155
  """
156
  Get dataloader according to cache_img parameter.
 
162
  None: Do not use cache, in this case cache_data is also None.
163
  """
164
  from yolox.data import (
 
165
  TrainTransform,
166
  YoloBatchSampler,
167
  DataLoader,
 
171
  )
172
  from yolox.utils import wait_for_the_master
173
 
174
+ # if cache is True, we will create self.dataset before launch
175
+ # else we will create self.dataset after launch
176
+ if self.dataset is None:
177
+ with wait_for_the_master():
178
+ assert cache_img is None, \
179
+ "cache_img must be None if you didn't create self.dataset before launch"
180
+ self.dataset = self.get_dataset(cache=False, cache_type=cache_img)
 
 
 
 
 
 
 
 
 
181
 
182
  self.dataset = MosaicDetection(
183
+ dataset=self.dataset,
184
  mosaic=not no_aug,
185
  img_size=self.input_size,
186
  preproc=TrainTransform(
 
295
  )
296
  return scheduler
297
 
298
+ def get_eval_dataset(self, **kwargs):
299
  from yolox.data import COCODataset, ValTransform
300
+ testdev = kwargs.get("testdev", False)
301
+ legacy = kwargs.get("legacy", False)
302
 
303
+ return COCODataset(
304
  data_dir=self.data_dir,
305
  json_file=self.val_ann if not testdev else self.test_ann,
306
  name="val2017" if not testdev else "test2017",
 
308
  preproc=ValTransform(legacy=legacy),
309
  )
310
 
311
+ def get_eval_loader(self, batch_size, is_distributed, **kwargs):
312
+ valdataset = self.get_eval_dataset(**kwargs)
313
+
314
  if is_distributed:
315
  batch_size = batch_size // dist.get_world_size()
316
  sampler = torch.utils.data.distributed.DistributedSampler(
 
332
  def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False):
333
  from yolox.evaluators import COCOEvaluator
334
 
335
+ return COCOEvaluator(
336
+ dataloader=self.get_eval_loader(batch_size, is_distributed,
337
+ testdev=testdev, legacy=legacy),
338
  img_size=self.test_size,
339
  confthre=self.test_conf,
340
  nmsthre=self.nmsthre,
341
  num_classes=self.num_classes,
342
  testdev=testdev,
343
  )
 
344
 
345
  def get_trainer(self, args):
346
  from yolox.core import Trainer