🔨 [Add] Dataset to train.py, todo: Dataloader
Browse files- train.py +2 -0
- utils/dataargument.py +2 -4
- utils/dataloader.py +9 -7
train.py
CHANGED
@@ -4,11 +4,13 @@ from loguru import logger
|
|
4 |
from config.config import Config
|
5 |
from model.yolo import get_model
|
6 |
from tools.log_helper import custom_logger
|
|
|
7 |
from utils.get_dataset import prepare_dataset
|
8 |
|
9 |
|
10 |
@hydra.main(config_path="config", config_name="config", version_base=None)
|
11 |
def main(cfg: Config):
|
|
|
12 |
if cfg.download.auto:
|
13 |
prepare_dataset(cfg.download)
|
14 |
|
|
|
4 |
from config.config import Config
|
5 |
from model.yolo import get_model
|
6 |
from tools.log_helper import custom_logger
|
7 |
+
from utils.dataloader import YoloDataset
|
8 |
from utils.get_dataset import prepare_dataset
|
9 |
|
10 |
|
11 |
@hydra.main(config_path="config", config_name="config", version_base=None)
|
12 |
def main(cfg: Config):
|
13 |
+
dataset = YoloDataset(cfg)
|
14 |
if cfg.download.auto:
|
15 |
prepare_dataset(cfg.download)
|
16 |
|
utils/dataargument.py
CHANGED
@@ -7,8 +7,9 @@ from torchvision.transforms import functional as TF
|
|
7 |
class Compose:
|
8 |
"""Composes several transforms together."""
|
9 |
|
10 |
-
def __init__(self, transforms):
|
11 |
self.transforms = transforms
|
|
|
12 |
|
13 |
for transform in self.transforms:
|
14 |
if hasattr(transform, "set_parent"):
|
@@ -19,9 +20,6 @@ class Compose:
|
|
19 |
image, boxes = transform(image, boxes)
|
20 |
return image, boxes
|
21 |
|
22 |
-
def get_more_data(self):
|
23 |
-
raise NotImplementedError("This method should be overridden by subclass instances!")
|
24 |
-
|
25 |
|
26 |
class RandomHorizontalFlip:
|
27 |
"""Randomly horizontally flips the image along with the bounding boxes."""
|
|
|
7 |
class Compose:
|
8 |
"""Composes several transforms together."""
|
9 |
|
10 |
+
def __init__(self, transforms, image_size: int = 640):
|
11 |
self.transforms = transforms
|
12 |
+
self.image_size = image_size
|
13 |
|
14 |
for transform in self.transforms:
|
15 |
if hasattr(transform, "set_parent"):
|
|
|
20 |
image, boxes = transform(image, boxes)
|
21 |
return image, boxes
|
22 |
|
|
|
|
|
|
|
23 |
|
24 |
class RandomHorizontalFlip:
|
25 |
"""Randomly horizontally flips the image along with the bounding boxes."""
|
utils/dataloader.py
CHANGED
@@ -5,22 +5,25 @@ import diskcache as dc
|
|
5 |
import hydra
|
6 |
import numpy as np
|
7 |
import torch
|
8 |
-
from dataargument import Compose, Mosaic, RandomHorizontalFlip
|
9 |
-
from drawer import draw_bboxes
|
10 |
from loguru import logger
|
11 |
from PIL import Image
|
12 |
from torch.utils.data import Dataset
|
13 |
from tqdm.rich import tqdm
|
14 |
|
|
|
|
|
|
|
15 |
|
16 |
class YoloDataset(Dataset):
|
17 |
-
def __init__(self,
|
|
|
|
|
18 |
phase_name = dataset_cfg.get(phase, phase)
|
19 |
self.image_size = image_size
|
20 |
|
21 |
-
|
|
|
22 |
self.transform.get_more_data = self.get_more_data
|
23 |
-
self.transform.image_size = self.image_size
|
24 |
self.data = self.load_data(dataset_cfg.path, phase_name)
|
25 |
|
26 |
def load_data(self, dataset_path, phase_name):
|
@@ -129,8 +132,7 @@ class YoloDataset(Dataset):
|
|
129 |
|
130 |
@hydra.main(config_path="../config", config_name="config", version_base=None)
|
131 |
def main(cfg):
|
132 |
-
|
133 |
-
dataset = YoloDataset(cfg.data, transform=transform)
|
134 |
draw_bboxes(*dataset[0])
|
135 |
|
136 |
|
|
|
5 |
import hydra
|
6 |
import numpy as np
|
7 |
import torch
|
|
|
|
|
8 |
from loguru import logger
|
9 |
from PIL import Image
|
10 |
from torch.utils.data import Dataset
|
11 |
from tqdm.rich import tqdm
|
12 |
|
13 |
+
from utils.dataargument import Compose, Mosaic, RandomHorizontalFlip
|
14 |
+
from utils.drawer import draw_bboxes
|
15 |
+
|
16 |
|
17 |
class YoloDataset(Dataset):
|
18 |
+
def __init__(self, config: dict, phase: str = "train", image_size: int = 640):
|
19 |
+
dataset_cfg = config.data
|
20 |
+
augment_cfg = config.augmentation
|
21 |
phase_name = dataset_cfg.get(phase, phase)
|
22 |
self.image_size = image_size
|
23 |
|
24 |
+
transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
|
25 |
+
self.transform = Compose(transforms, self.image_size)
|
26 |
self.transform.get_more_data = self.get_more_data
|
|
|
27 |
self.data = self.load_data(dataset_cfg.path, phase_name)
|
28 |
|
29 |
def load_data(self, dataset_path, phase_name):
|
|
|
132 |
|
133 |
@hydra.main(config_path="../config", config_name="config", version_base=None)
|
134 |
def main(cfg):
|
135 |
+
dataset = YoloDataset(cfg)
|
|
|
136 |
draw_bboxes(*dataset[0])
|
137 |
|
138 |
|