henry000 commited on
Commit
95ca62f
·
1 Parent(s): 31cab2b

🔨 [Add] Dataset to train.py, todo: Dataloader

Browse files
Files changed (3) hide show
  1. train.py +2 -0
  2. utils/dataargument.py +2 -4
  3. utils/dataloader.py +9 -7
train.py CHANGED
@@ -4,11 +4,13 @@ from loguru import logger
4
  from config.config import Config
5
  from model.yolo import get_model
6
  from tools.log_helper import custom_logger
 
7
  from utils.get_dataset import prepare_dataset
8
 
9
 
10
  @hydra.main(config_path="config", config_name="config", version_base=None)
11
  def main(cfg: Config):
 
12
  if cfg.download.auto:
13
  prepare_dataset(cfg.download)
14
 
 
4
  from config.config import Config
5
  from model.yolo import get_model
6
  from tools.log_helper import custom_logger
7
+ from utils.dataloader import YoloDataset
8
  from utils.get_dataset import prepare_dataset
9
 
10
 
11
  @hydra.main(config_path="config", config_name="config", version_base=None)
12
  def main(cfg: Config):
13
+ dataset = YoloDataset(cfg)
14
  if cfg.download.auto:
15
  prepare_dataset(cfg.download)
16
 
utils/dataargument.py CHANGED
@@ -7,8 +7,9 @@ from torchvision.transforms import functional as TF
7
  class Compose:
8
  """Composes several transforms together."""
9
 
10
- def __init__(self, transforms):
11
  self.transforms = transforms
 
12
 
13
  for transform in self.transforms:
14
  if hasattr(transform, "set_parent"):
@@ -19,9 +20,6 @@ class Compose:
19
  image, boxes = transform(image, boxes)
20
  return image, boxes
21
 
22
- def get_more_data(self):
23
- raise NotImplementedError("This method should be overridden by subclass instances!")
24
-
25
 
26
  class RandomHorizontalFlip:
27
  """Randomly horizontally flips the image along with the bounding boxes."""
 
7
  class Compose:
8
  """Composes several transforms together."""
9
 
10
+ def __init__(self, transforms, image_size: int = 640):
11
  self.transforms = transforms
12
+ self.image_size = image_size
13
 
14
  for transform in self.transforms:
15
  if hasattr(transform, "set_parent"):
 
20
  image, boxes = transform(image, boxes)
21
  return image, boxes
22
 
 
 
 
23
 
24
  class RandomHorizontalFlip:
25
  """Randomly horizontally flips the image along with the bounding boxes."""
utils/dataloader.py CHANGED
@@ -5,22 +5,25 @@ import diskcache as dc
5
  import hydra
6
  import numpy as np
7
  import torch
8
- from dataargument import Compose, Mosaic, RandomHorizontalFlip
9
- from drawer import draw_bboxes
10
  from loguru import logger
11
  from PIL import Image
12
  from torch.utils.data import Dataset
13
  from tqdm.rich import tqdm
14
 
 
 
 
15
 
16
  class YoloDataset(Dataset):
17
- def __init__(self, dataset_cfg: dict, phase: str = "train", image_size: int = 640, transform=None):
 
 
18
  phase_name = dataset_cfg.get(phase, phase)
19
  self.image_size = image_size
20
 
21
- self.transform = transform
 
22
  self.transform.get_more_data = self.get_more_data
23
- self.transform.image_size = self.image_size
24
  self.data = self.load_data(dataset_cfg.path, phase_name)
25
 
26
  def load_data(self, dataset_path, phase_name):
@@ -129,8 +132,7 @@ class YoloDataset(Dataset):
129
 
130
  @hydra.main(config_path="../config", config_name="config", version_base=None)
131
  def main(cfg):
132
- transform = Compose([eval(aug)(prob) for aug, prob in cfg.augmentation.items()])
133
- dataset = YoloDataset(cfg.data, transform=transform)
134
  draw_bboxes(*dataset[0])
135
 
136
 
 
5
  import hydra
6
  import numpy as np
7
  import torch
 
 
8
  from loguru import logger
9
  from PIL import Image
10
  from torch.utils.data import Dataset
11
  from tqdm.rich import tqdm
12
 
13
+ from utils.dataargument import Compose, Mosaic, RandomHorizontalFlip
14
+ from utils.drawer import draw_bboxes
15
+
16
 
17
  class YoloDataset(Dataset):
18
+ def __init__(self, config: dict, phase: str = "train", image_size: int = 640):
19
+ dataset_cfg = config.data
20
+ augment_cfg = config.augmentation
21
  phase_name = dataset_cfg.get(phase, phase)
22
  self.image_size = image_size
23
 
24
+ transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
25
+ self.transform = Compose(transforms, self.image_size)
26
  self.transform.get_more_data = self.get_more_data
 
27
  self.data = self.load_data(dataset_cfg.path, phase_name)
28
 
29
  def load_data(self, dataset_path, phase_name):
 
132
 
133
  @hydra.main(config_path="../config", config_name="config", version_base=None)
134
  def main(cfg):
135
+ dataset = YoloDataset(cfg)
 
136
  draw_bboxes(*dataset[0])
137
 
138