henry000 committed on
Commit
e9629ef
·
1 Parent(s): d5ba31a

✨ [Add] Dataset for loading image, labels

Browse files
Files changed (1) hide show
  1. utils/dataloader.py +105 -0
utils/dataloader.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from os import path
3
+ import os
4
+ import hydra
5
+ import numpy as np
6
+ import torch
7
+ from torch.utils.data import Dataset
8
+ from loguru import logger
9
+ from tqdm.rich import tqdm
10
+ import diskcache as dc
11
+
12
+
13
class YoloDataset(Dataset):
    """Dataset yielding ``(PIL image, bounding-box tensor)`` pairs.

    Images are read from ``<path>/<phase>/images`` and their labels from
    ``<path>/<phase>/labels``. The list of valid ``(image path, boxes)``
    records is built once and stored in a ``diskcache`` cache located at
    ``<path>/.cache``, keyed by the phase name.
    """

    def __init__(self, dataset_cfg: dict, phase="train", transform=None, mixup=None):
        """Load (or build) the record list for one phase of the dataset.

        Args:
            dataset_cfg: Mapping with a ``"path"`` entry (dataset root) and,
                optionally, an entry named after ``phase`` giving the folder
                name used for that phase.
            phase: Phase key; also used as the folder name when
                ``dataset_cfg`` has no entry for it.
            transform: Stored on the instance; not applied by this class yet.
            mixup: Stored on the instance; not applied by this class yet.
        """
        phase_name = dataset_cfg.get(phase, phase)

        self.transform = transform
        self.mixup = mixup
        # Item access works for plain dicts as well as attribute-style configs.
        self.data = self.load_data(dataset_cfg["path"], phase_name)

    def load_data(self, dataset_path, phase_name):
        """Return the cached record list for ``phase_name``, building it if absent.

        Args:
            dataset_path: Dataset root directory; the cache lives in its
                ``.cache`` sub-directory.
            phase_name: Folder name of the phase, also used as the cache key.

        Returns:
            List of ``(image path, boxes tensor)`` tuples.
        """
        # The context manager closes the cache even when building it raises.
        with dc.Cache(path.join(dataset_path, ".cache")) as cache:
            if phase_name not in cache:
                logger.info("Generate {} Cache", phase_name)

                images_path = path.join(dataset_path, phase_name, "images")
                labels_path = path.join(dataset_path, phase_name, "labels")

                cache[phase_name] = self.filter_data(images_path, labels_path)

            logger.info("Load {} Cache", phase_name)
            return cache[phase_name]

    def filter_data(self, images_path, labels_path):
        """Pair every image with its label file and keep the usable ones.

        An entry is kept only when it has a ``.jpg``/``.jpeg``/``.png``
        extension, a same-named ``.txt`` file exists in ``labels_path`` and
        that file yields at least one valid box.

        Args:
            images_path: Directory containing the image files.
            labels_path: Directory containing the label files.

        Returns:
            List of ``(image path, boxes tensor)`` tuples, ordered by file name.
        """
        images_list = sorted(os.listdir(images_path))
        data = []
        for image_name in tqdm(images_list):
            if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
                continue
            base_name, _ = path.splitext(image_name)
            label_path = path.join(labels_path, base_name + ".txt")

            # Images without a label file are skipped silently.
            if not path.isfile(label_path):
                continue

            labels = self.load_valid_labels(label_path)
            if labels is not None:
                data.append((path.join(images_path, image_name), labels))
        # Reuse the listing taken above instead of reading the directory again.
        logger.info("Finish Record {}/{}", len(data), len(images_list))
        return data

    def load_valid_labels(self, label_path):
        """Parse one label file into a tensor of boxes.

        Each non-blank line is expected to hold a class id followed by an even
        number of coordinates (at least two ``x y`` points). Points with a
        coordinate outside ``[0, 1]`` are discarded before the box is taken
        as the extent of the remaining points.

        Args:
            label_path: Path of the label text file.

        Returns:
            Tensor with one row ``[cls, x_max, y_max, x_min, y_min]`` per valid
            line, or ``None`` when the file contains no valid box.
        """
        bboxes = []
        with open(label_path, "r") as file:
            for line in file:
                tokens = line.split()
                if not tokens:
                    # Blank lines (e.g. a trailing newline) carry no label.
                    continue
                try:
                    segment = [float(token) for token in tokens]
                except ValueError:
                    logger.warning(f"Warning: Format error in {label_path}")
                    continue
                # Ensure parts length is odd and more than two points
                if len(segment) % 2 != 1 or len(segment) < 5:
                    logger.warning(f"Warning: Format error in {label_path}")
                    continue
                cls = segment[0]
                points = np.array(segment[1:]).reshape(-1, 2)  # change points to n x 2
                # A point is valid only when BOTH coordinates lie inside [0, 1].
                valid_idx = np.all((points >= 0) & (points <= 1), axis=1)
                points = points[valid_idx]  # only keep valid points
                if len(points) == 0:
                    # Nothing left to span a box; max/min would raise on empty input.
                    continue

                bbox = torch.tensor([cls, *points.max(axis=0), *points.min(axis=0)])
                bboxes.append(bbox)
        if not bboxes:
            logger.warning(f"Warning: No valid BBox in {label_path}")
            return None
        return torch.stack(bboxes)

    def __getitem__(self, idx):
        """Return the RGB image and the boxes tensor of record ``idx``."""
        img_path, bboxes = self.data[idx]
        img = Image.open(img_path).convert("RGB")

        return img, bboxes

    def __len__(self):
        """Return the number of valid records."""
        return len(self.data)
91
+
92
+
93
@hydra.main(config_path="../config/data", config_name="coco", version_base=None)
def main(cfg):
    """Construct a ``YoloDataset`` from the config that Hydra passes in.

    The dataset object is not used further here; constructing it is the
    whole effect of this entry point.
    """
    YoloDataset(cfg)
96
+
97
+
98
if __name__ == "__main__":
    # Script entry point: extend the import path, set up logging, run main().
    import sys

    # "./" is added so the `tools` package below can be imported
    # (presumably the script is launched from the repository root — confirm).
    sys.path.append("./")
    from tools.log_helper import custom_logger

    custom_logger()
    main()