henry000 commited on
Commit
23db031
·
1 Parent(s): e802523

✨ [Finish] Dataloder and get_dataloader

Browse files
config/config.yaml CHANGED
@@ -7,4 +7,5 @@ defaults:
7
  - download: ../data/download
8
  - augmentation: ../data/augmentation
9
  - model: v7-base
 
10
  - _self_
 
7
  - download: ../data/download
8
  - augmentation: ../data/augmentation
9
  - model: v7-base
10
+ - hyper: default
11
  - _self_
config/hyper/default.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ data:
2
+ batch_size: 4
3
+ shuffle: True
4
+ num_workers: 4
5
+ pin_memory: True
utils/dataloader.py CHANGED
@@ -1,5 +1,5 @@
1
  from os import listdir, path
2
- from typing import Union
3
 
4
  import diskcache as dc
5
  import hydra
@@ -7,16 +7,11 @@ import numpy as np
7
  import torch
8
  from loguru import logger
9
  from PIL import Image
10
- from torch.utils.data import Dataset
 
11
  from tqdm.rich import tqdm
12
 
13
- from utils.data_augment import (
14
- Compose,
15
- MixUp,
16
- Mosaic,
17
- RandomHorizontalFlip,
18
- RandomVerticalFlip,
19
- )
20
  from utils.drawer import draw_bboxes
21
 
22
 
@@ -130,16 +125,55 @@ class YoloDataset(Dataset):
130
  img, bboxes = self.get_data(idx)
131
  if self.transform:
132
  img, bboxes = self.transform(img, bboxes)
 
133
  return img, bboxes
134
 
135
  def __len__(self) -> int:
136
  return len(self.data)
137
 
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  @hydra.main(config_path="../config", config_name="config", version_base=None)
140
  def main(cfg):
141
- dataset = YoloDataset(cfg)
142
- draw_bboxes(*dataset[0])
143
 
144
 
145
  if __name__ == "__main__":
 
1
  from os import listdir, path
2
+ from typing import List, Tuple, Union
3
 
4
  import diskcache as dc
5
  import hydra
 
7
  import torch
8
  from loguru import logger
9
  from PIL import Image
10
+ from torch.utils.data import DataLoader, Dataset
11
+ from torchvision.transforms import functional as TF
12
  from tqdm.rich import tqdm
13
 
14
+ from utils.data_augment import Compose, HorizontalFlip, MixUp, Mosaic, VerticalFlip
 
 
 
 
 
 
15
  from utils.drawer import draw_bboxes
16
 
17
 
 
125
  img, bboxes = self.get_data(idx)
126
  if self.transform:
127
  img, bboxes = self.transform(img, bboxes)
128
+ img = TF.to_tensor(img)
129
  return img, bboxes
130
 
131
  def __len__(self) -> int:
132
  return len(self.data)
133
 
134
 
135
+ class YoloDataLoader(DataLoader):
136
+ def __init__(self, config: dict):
137
+ """Initializes the YoloDataLoader with hydra-config files."""
138
+ hyper = config.hyper.data
139
+ dataset = YoloDataset(config)
140
+
141
+ super().__init__(
142
+ dataset,
143
+ batch_size=hyper.batch_size,
144
+ shuffle=hyper.shuffle,
145
+ num_workers=hyper.num_workers,
146
+ pin_memory=hyper.pin_memory,
147
+ collate_fn=self.collate_fn,
148
+ )
149
+
150
+ def collate_fn(self, batch: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
151
+ """
152
+ A collate function to handle batching of images and their corresponding targets.
153
+
154
+ Args:
155
+ batch (list of tuples): Each tuple contains:
156
+ - image (torch.Tensor): The image tensor.
157
+ - labels (torch.Tensor): The tensor of labels for the image.
158
+
159
+ Returns:
160
+ Tuple[torch.Tensor, List[torch.Tensor]]: A tuple containing:
161
+ - A tensor of batched images.
162
+ - A list of tensors, each corresponding to bboxes for each image in the batch.
163
+ """
164
+ images = torch.stack([item[0] for item in batch])
165
+ targets = [item[1] for item in batch]
166
+ return images, targets
167
+
168
+
169
+ def get_dataloader(config):
170
+ return YoloDataLoader(config)
171
+
172
+
173
  @hydra.main(config_path="../config", config_name="config", version_base=None)
174
  def main(cfg):
175
+ dataloader = get_dataloader(cfg)
176
+ draw_bboxes(next(iter(dataloader)))
177
 
178
 
179
  if __name__ == "__main__":
utils/drawer.py CHANGED
@@ -1,23 +1,31 @@
 
 
 
 
1
  from PIL import Image, ImageDraw, ImageFont
 
2
 
3
 
4
- def draw_bboxes(img, bboxes):
5
  """
6
  Draw bounding boxes on an image.
7
 
8
  Args:
9
- - image_path (str): Path to the image file.
10
- - bboxes (list of lists/tuples): Bounding boxes with [x_min, y_min, x_max, y_max, class_id].
 
11
  """
12
- # Load an image
13
- draw = ImageDraw.Draw(img)
 
 
 
 
 
14
 
15
- # Font for class_id (optional)
16
- try:
17
- font = ImageFont.truetype("arial.ttf", 30)
18
- except IOError:
19
- font = ImageFont.load_default(30)
20
  width, height = img.size
 
21
 
22
  for bbox in bboxes:
23
  class_id, x_min, y_min, x_max, y_max = bbox
@@ -26,7 +34,8 @@ def draw_bboxes(img, bboxes):
26
  y_min = y_min * height
27
  y_max = y_max * height
28
  shape = [(x_min, y_min), (x_max, y_max)]
29
- draw.rectangle(shape, outline="red", width=2)
30
  draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
31
 
32
- img.save("output.jpg")
 
 
1
+ from typing import List, Union
2
+
3
+ import torch
4
+ from loguru import logger
5
  from PIL import Image, ImageDraw, ImageFont
6
+ from torchvision.transforms.functional import to_pil_image
7
 
8
 
9
+ def draw_bboxes(img: Union[Image.Image, torch.Tensor], bboxes: List[List[Union[int, float]]]):
10
  """
11
  Draw bounding boxes on an image.
12
 
13
  Args:
14
+ - img (PIL Image or torch.Tensor): Image on which to draw the bounding boxes.
15
+ - bboxes (List of Lists/Tensors): Bounding boxes with [class_id, x_min, y_min, x_max, y_max],
16
+ where coordinates are normalized [0, 1].
17
  """
18
+ # Convert tensor image to PIL Image if necessary
19
+ if isinstance(img, torch.Tensor):
20
+ if img.dim() > 3:
21
+ logger.info("Multi-frame tensor detected, using the first image.")
22
+ img = img[0]
23
+ bboxes = bboxes[0]
24
+ img = to_pil_image(img)
25
 
26
+ draw = ImageDraw.Draw(img)
 
 
 
 
27
  width, height = img.size
28
+ font = ImageFont.load_default(30)
29
 
30
  for bbox in bboxes:
31
  class_id, x_min, y_min, x_max, y_max = bbox
 
34
  y_min = y_min * height
35
  y_max = y_max * height
36
  shape = [(x_min, y_min), (x_max, y_max)]
37
+ draw.rectangle(shape, outline="red", width=3)
38
  draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
39
 
40
+ img.save("visualize.jpg") # Save the image with annotations
41
+ logger.info("Saved visualize image at visualize.png")