henry000 commited on
Commit
d8aafaa
·
1 Parent(s): 7c11918

✨ [Add] A instance of dataaugments

Browse files
config/config.yaml CHANGED
@@ -5,5 +5,6 @@ hydra:
5
  defaults:
6
  - data: coco
7
  - download: ../data/download
 
8
  - model: v7-base
9
  - _self_
 
5
  defaults:
6
  - data: coco
7
  - download: ../data/download
8
+ - augmentation: ../data/augmentation
9
  - model: v7-base
10
  - _self_
config/data/augmentation.yaml ADDED
@@ -0,0 +1 @@
 
 
1
+ RandomHorizontalFlip: 0.5
utils/dataargument.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision.transforms import functional as TF
3
+
4
+
5
+ class Compose:
6
+ """Composes several transforms together."""
7
+
8
+ def __init__(self, transforms):
9
+ self.transforms = transforms
10
+
11
+ def __call__(self, image, boxes):
12
+ for t in self.transforms:
13
+ image, boxes = t(image, boxes)
14
+ return image, boxes
15
+
16
+
17
+ class RandomHorizontalFlip:
18
+ """Randomly horizontally flips the image along with the bounding boxes."""
19
+
20
+ def __init__(self, p=0.5):
21
+ self.p = p
22
+
23
+ def __call__(self, image, boxes):
24
+ if torch.rand(1) < self.p:
25
+ image = TF.hflip(image)
26
+ # Assuming boxes are in the format [cls, xmin, ymin, xmax, ymax]
27
+ boxes[:, [1, 3]] = 1 - boxes[:, [3, 1]]
28
+ return image, boxes
utils/dataloader.py CHANGED
@@ -1,6 +1,6 @@
1
  from PIL import Image
2
- from os import path
3
- import os
4
  import hydra
5
  import numpy as np
6
  import torch
@@ -8,91 +8,119 @@ from torch.utils.data import Dataset
8
  from loguru import logger
9
  from tqdm.rich import tqdm
10
  import diskcache as dc
 
 
 
11
 
12
 
13
  class YoloDataset(Dataset):
14
- def __init__(self, dataset_cfg: dict, phase="train", transform=None, mixup=None):
15
  phase_name = dataset_cfg.get(phase, phase)
16
 
17
  self.transform = transform
18
- self.mixup = mixup
19
  self.data = self.load_data(dataset_cfg.path, phase_name)
20
 
21
  def load_data(self, dataset_path, phase_name):
22
- cache = dc.Cache(path.join(dataset_path, ".cache"))
23
-
24
- if phase_name not in cache:
25
- logger.info("Generate {} Cache", phase_name)
26
-
 
 
 
 
 
 
 
 
 
 
 
27
  images_path = path.join(dataset_path, phase_name, "images")
28
  labels_path = path.join(dataset_path, phase_name, "labels")
 
 
29
 
30
- cache[phase_name] = self.filter_data(images_path, labels_path)
31
-
32
- logger.info("Load {} Cache", phase_name)
33
- data = cache[phase_name]
34
  cache.close()
35
-
 
36
  return data
37
 
38
- def filter_data(self, images_path, labels_path):
 
 
 
 
 
 
 
 
 
 
39
  data = []
40
- valid_input = 0
41
- images_list = os.listdir(images_path)
42
- images_list.sort()
43
- for image_name in tqdm(images_list):
44
  if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
45
  continue
 
46
  img_path = path.join(images_path, image_name)
47
  base_name, _ = path.splitext(image_name)
48
- label_name = base_name + ".txt"
49
- label_path = path.join(labels_path, label_name)
50
 
51
- if not path.isfile(label_path):
52
- # logger.warning(f"Warning: No label file for {label_path}")
53
- continue
 
 
54
 
55
- labels = self.load_valid_labels(label_path)
56
- if labels is not None:
57
- data.append((img_path, labels))
58
- valid_input += 1
59
- logger.info("Finish Record {}/{}", valid_input, len(os.listdir(images_path)))
60
  return data
61
 
62
- def load_valid_labels(self, label_path):
 
 
 
 
 
 
 
 
 
63
  bboxes = []
64
  with open(label_path, "r") as file:
65
  for line in file:
66
- segment = list(map(float, line.strip().split()))
67
- cls = segment[0]
68
- # Ensure parts length is odd and more than two points
69
- if len(segment) % 2 != 1 or len(segment) < 5:
70
- logger.warning(f"Warning: Format error in {label_path}")
71
- continue
72
- points = np.array(segment[1:]).reshape(-1, 2) # change points to n x 2
73
- valid_idx = np.any((points <= 1) | (points >= 0), axis=1) # filter outlier points
74
- points = points[valid_idx] # only keep valid points
75
-
76
- bbox = torch.tensor([cls, *points.max(axis=0), *points.min(axis=0)])
77
- bboxes.append(bbox)
78
- if not bboxes:
79
- logger.warning(f"Warning: No valid BBox in {label_path}")
80
  return None
81
- return torch.stack(bboxes)
82
 
83
- def __getitem__(self, idx):
84
  img_path, bboxes = self.data[idx]
85
  img = Image.open(img_path).convert("RGB")
86
-
 
87
  return img, bboxes
88
 
89
- def __len__(self):
90
- return len(self.images)
91
 
92
 
93
- @hydra.main(config_path="../config/data", config_name="coco", version_base=None)
94
  def main(cfg):
95
- dataset = YoloDataset(cfg)
 
 
96
 
97
 
98
  if __name__ == "__main__":
 
1
  from PIL import Image
2
+ from os import path, listdir
3
+
4
  import hydra
5
  import numpy as np
6
  import torch
 
8
  from loguru import logger
9
  from tqdm.rich import tqdm
10
  import diskcache as dc
11
+ from typing import Union
12
+ from drawer import draw_bboxes
13
+ from dataargument import Compose, RandomHorizontalFlip
14
 
15
 
16
  class YoloDataset(Dataset):
17
+ def __init__(self, dataset_cfg: dict, phase: str = "train", transform=None):
18
  phase_name = dataset_cfg.get(phase, phase)
19
 
20
  self.transform = transform
 
21
  self.data = self.load_data(dataset_cfg.path, phase_name)
22
 
23
  def load_data(self, dataset_path, phase_name):
24
+ """
25
+ Loads data from a cache or generates a new cache for a specific dataset phase.
26
+
27
+ Parameters:
28
+ dataset_path (str): The root path to the dataset directory.
29
+ phase_name (str): The specific phase of the dataset (e.g., 'train', 'test') to load or generate data for.
30
+
31
+ Returns:
32
+ dict: The loaded data from the cache for the specified phase.
33
+ """
34
+ cache_path = path.join(dataset_path, ".cache")
35
+ cache = dc.Cache(cache_path)
36
+ data = cache.get(phase_name)
37
+
38
+ if data is None:
39
+ logger.info("Generating {} cache", phase_name)
40
  images_path = path.join(dataset_path, phase_name, "images")
41
  labels_path = path.join(dataset_path, phase_name, "labels")
42
+ data = self.filter_data(images_path, labels_path)
43
+ cache[phase_name] = data
44
 
 
 
 
 
45
  cache.close()
46
+ logger.info("Loaded {} cache", phase_name)
47
+ data = cache[phase_name]
48
  return data
49
 
50
+ def filter_data(self, images_path: str, labels_path: str) -> list:
51
+ """
52
+ Filters and collects dataset information by pairing images with their corresponding labels.
53
+
54
+ Parameters:
55
+ images_path (str): Path to the directory containing image files.
56
+ labels_path (str): Path to the directory containing label files.
57
+
58
+ Returns:
59
+ list: A list of tuples, each containing the path to an image file and its associated labels as a tensor.
60
+ """
61
  data = []
62
+ valid_inputs = 0
63
+ images_list = sorted(listdir(images_path))
64
+ for image_name in tqdm(images_list, desc="Filtering data"):
 
65
  if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
66
  continue
67
+
68
  img_path = path.join(images_path, image_name)
69
  base_name, _ = path.splitext(image_name)
70
+ label_path = path.join(labels_path, f"{base_name}.txt")
 
71
 
72
+ if path.isfile(label_path):
73
+ labels = self.load_valid_labels(label_path)
74
+ if labels is not None:
75
+ data.append((img_path, labels))
76
+ valid_inputs += 1
77
 
78
+ logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
 
 
 
 
79
  return data
80
 
81
+ def load_valid_labels(self, label_path: str) -> Union[torch.Tensor, None]:
82
+ """
83
+ Loads and validates bounding box data is [0, 1] from a label file.
84
+
85
+ Parameters:
86
+ label_path (str): The filepath to the label file containing bounding box data.
87
+
88
+ Returns:
89
+ torch.Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None.
90
+ """
91
  bboxes = []
92
  with open(label_path, "r") as file:
93
  for line in file:
94
+ parts = list(map(float, line.strip().split()))
95
+ cls = parts[0]
96
+ points = np.array(parts[1:]).reshape(-1, 2)
97
+ valid_points = points[(points >= 0) & (points <= 1)].reshape(-1, 2)
98
+ if valid_points.size > 1:
99
+ bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)])
100
+ bboxes.append(bbox)
101
+
102
+ if bboxes:
103
+ return torch.stack(bboxes)
104
+ else:
105
+ logger.warning("No valid BBox in {}", label_path)
 
 
106
  return None
 
107
 
108
+ def __getitem__(self, idx) -> Union[Image.Image, torch.Tensor]:
109
  img_path, bboxes = self.data[idx]
110
  img = Image.open(img_path).convert("RGB")
111
+ if self.transform:
112
+ img, bboxes = self.transform(img, bboxes)
113
  return img, bboxes
114
 
115
+ def __len__(self) -> int:
116
+ return len(self.data)
117
 
118
 
119
+ @hydra.main(config_path="../config", config_name="config", version_base=None)
120
  def main(cfg):
121
+ transform = Compose([eval(aug)(prob) for aug, prob in cfg.augmentation.items()])
122
+ dataset = YoloDataset(cfg.data, transform=transform)
123
+ draw_bboxes(*dataset[0])
124
 
125
 
126
  if __name__ == "__main__":