lucytuan committed
Commit de1ec48 · 1 Parent(s): 1dfe70c

✨ [Add] json can be directly read by dataloader.py


In this commit the default label format is a JSON file; this will be further modified in later commits.
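
For reference, the reworked loader resolves labels from one of two sources per phase. Below is a minimal sketch of that resolution using the directory names from the new find_labels_path helper added in this commit; the dataset root and phase name are illustrative values, not taken from the commit itself.

# Illustrative only: how find_labels_path() (added in this commit) picks a label source.
# "data/coco" and "train" are hypothetical example values.
from os import path

dataset_path, phase_name = "data/coco", "train"

# Checked first: a COCO-style JSON file -> returns (json_labels_path, "json")
json_labels_path = path.join(dataset_path, "annotations", f"instances_{phase_name}.json")
# e.g. data/coco/annotations/instances_train.json

# Fallback: a directory of YOLO-style .txt files -> returns (txt_labels_path, "txt")
txt_labels_path = path.join(dataset_path, "label", phase_name)
# e.g. data/coco/label/train/ containing one <image_name>.txt per image

# Images are always read from <dataset_path>/images/<phase_name>;
# if neither label source exists, find_labels_path raises FileNotFoundError.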

Files changed (1)
utils/dataloader.py +103 -30
utils/dataloader.py CHANGED
@@ -1,3 +1,6 @@
+import json
+import os
+from itertools import chain
 from os import listdir, path
 from typing import List, Tuple, Union
 
@@ -5,14 +8,59 @@ import diskcache as dc
 import hydra
 import numpy as np
 import torch
+from data_augment import Compose, HorizontalFlip, MixUp, Mosaic, VerticalFlip
+from drawer import draw_bboxes
 from loguru import logger
 from PIL import Image
 from torch.utils.data import DataLoader, Dataset
 from torchvision.transforms import functional as TF
 from tqdm.rich import tqdm
 
-from utils.data_augment import Compose, HorizontalFlip, MixUp, Mosaic, VerticalFlip
-from utils.drawer import draw_bboxes
+
+def find_labels_path(dataset_path, phase_name):
+    json_labels_path = path.join(dataset_path, "annotations", f"instances_{phase_name}.json")
+
+    txt_labels_path = path.join(dataset_path, "label", phase_name)
+
+    if path.isfile(json_labels_path):
+        return json_labels_path, "json"
+
+    elif path.isdir(txt_labels_path):
+        txt_files = [f for f in os.listdir(txt_labels_path) if f.endswith(".txt")]
+        if txt_files:
+            return txt_labels_path, "txt"
+
+    raise FileNotFoundError("No labels found in the specified dataset path and phase name.")
+
+
+def load_json_labels(json_labels_path):
+    with open(json_labels_path, "r") as file:
+        data = json.load(file)
+    return data
+
+
+def create_annotation_lookup(data):
+    annotation_lookup = {}
+    for anno in data["annotations"]:
+        if anno["iscrowd"] == 0:  # Exclude crowd annotations
+            image_id = anno["image_id"]
+            if image_id not in annotation_lookup:
+                annotation_lookup[image_id] = []
+            annotation_lookup[image_id].append(anno)
+    return annotation_lookup
+
+
+def process_annotations(annotations, image_id, image_dimensions):
+    ret_array = []
+    h, w = image_dimensions["height"], image_dimensions["width"]
+    for anno in annotations:
+        category_id = anno["category_id"]
+        flat_list = [item for sublist in anno["segmentation"] for item in sublist]
+        normalized_data = (np.array(flat_list).reshape(-1, 2) / [w, h]).tolist()
+        normalized_flat = list(chain(*normalized_data))
+        normalized_flat.insert(0, category_id)
+        ret_array.append(normalized_flat)
+    return ret_array
 
 
 class YoloDataset(Dataset):
@@ -44,9 +92,7 @@ class YoloDataset(Dataset):
 
         if data is None:
             logger.info("Generating {} cache", phase_name)
-            images_path = path.join(dataset_path, "images", phase_name)
-            labels_path = path.join(dataset_path, "label", phase_name)
-            data = self.filter_data(images_path, labels_path)
+            data = self.filter_data(dataset_path, phase_name)
            cache[phase_name] = data
 
         cache.close()
@@ -54,7 +100,7 @@ class YoloDataset(Dataset):
         data = cache[phase_name]
         return data
 
-    def filter_data(self, images_path: str, labels_path: str) -> list:
+    def filter_data(self, dataset_path: str, phase_name: str) -> list:
         """
         Filters and collects dataset information by pairing images with their corresponding labels.
 
@@ -63,29 +109,58 @@
            labels_path (str): Path to the directory containing label files.
 
         Returns:
-            list: A list of tuples, each containing the path to an image file and its associated labels as a tensor.
+            list: A list of tuples, each containing the path to an image file and its associated segmentation as a tensor.
         """
+        images_path = path.join(dataset_path, "images", phase_name)
+        labels_path, data_type = find_labels_path(dataset_path, phase_name)
+        images_list = sorted(os.listdir(images_path))
         data = []
        valid_inputs = 0
-        images_list = sorted(listdir(images_path))
-        for image_name in tqdm(images_list, desc="Filtering data"):
-            if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
-                continue
-
-            img_path = path.join(images_path, image_name)
-            base_name, _ = path.splitext(image_name)
-            label_path = path.join(labels_path, f"{base_name}.txt")
 
-            if path.isfile(label_path):
-                labels = self.load_valid_labels(label_path)
-                if labels is not None:
-                    data.append((img_path, labels))
-                    valid_inputs += 1
+        if data_type == "json":
+            labels_data = load_json_labels(labels_path)
+            annotations_lookup = create_annotation_lookup(labels_data)
+            image_info_dict = {path.splitext(img["file_name"])[0]: img for img in labels_data["images"]}
+
+            for image_name in tqdm(images_list, desc="Filtering data"):
+                if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
+                    continue
+                base_name, _ = path.splitext(image_name)
+                if base_name in image_info_dict:
+                    image_info = image_info_dict[base_name]
+                    annotations = annotations_lookup.get(image_info["id"], [])
+                    if annotations:
+                        processed_data = process_annotations(annotations, image_info["id"], image_info)
+                        if processed_data:
+                            img_path = path.join(images_path, image_name)
+                            labels = self.load_valid_labels(img_path, processed_data)
+                            if labels is not None:
+                                data.append((img_path, labels))
+                                valid_inputs += 1
+
+        elif data_type == "txt":
+            for image_name in tqdm(images_list, desc="Filtering data"):
+                if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
+                    continue
+                img_path = path.join(images_path, image_name)
+                base_name, _ = path.splitext(image_name)
+                label_path = path.join(labels_path, f"{base_name}.txt")
+
+                if path.isfile(label_path):
+                    seg_data_one_img = []
+                    with open(label_path, "r") as file:
+                        for line in file:
+                            parts = list(map(float, line.strip().split()))
+                            seg_data_one_img.append(parts)
+                    labels = self.load_valid_labels(label_path, seg_data_one_img)
+                    if labels is not None:
+                        data.append((img_path, labels))
+                        valid_inputs += 1
 
         logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
         return data
 
-    def load_valid_labels(self, label_path: str) -> Union[torch.Tensor, None]:
+    def load_valid_labels(self, label_path, seg_data_one_img) -> Union[torch.Tensor, None]:
         """
         Loads and validates bounding box data is [0, 1] from a label file.
 
@@ -96,15 +171,13 @@
             torch.Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None.
         """
         bboxes = []
-        with open(label_path, "r") as file:
-            for line in file:
-                parts = list(map(float, line.strip().split()))
-                cls = parts[0]
-                points = np.array(parts[1:]).reshape(-1, 2)
-                valid_points = points[(points >= 0) & (points <= 1)].reshape(-1, 2)
-                if valid_points.size > 1:
-                    bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)])
-                    bboxes.append(bbox)
+        for seg_data in seg_data_one_img:
+            cls = seg_data[0]
+            points = np.array(seg_data[1:]).reshape(-1, 2)
+            valid_points = points[(points >= 0) & (points <= 1)].reshape(-1, 2)
+            if valid_points.size > 1:
+                bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)])
+                bboxes.append(bbox)
 
         if bboxes:
             return torch.stack(bboxes)
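
As a follow-up, here is a hedged sanity-check sketch (not part of the commit) of what the new JSON helpers produce, using a hand-written COCO-style record reduced to the fields the code actually reads. The import path, sys.path tweak, and all numeric values are assumptions for illustration, and the snippet expects the repo's dependencies to be installed.

# Hedged sketch: put the repo's utils/ directory on sys.path so that
# dataloader.py's own top-level imports (data_augment, drawer, hydra, ...) resolve.
import sys
sys.path.insert(0, "utils")  # hypothetical location of dataloader.py

from dataloader import create_annotation_lookup, process_annotations

coco = {
    "images": [
        {"id": 1, "file_name": "000001.jpg", "height": 480, "width": 640},
    ],
    "annotations": [
        # entries with iscrowd == 1 would be skipped by create_annotation_lookup
        {"image_id": 1, "category_id": 0, "iscrowd": 0,
         "segmentation": [[32.0, 48.0, 96.0, 48.0, 96.0, 160.0, 32.0, 160.0]]},
    ],
}

lookup = create_annotation_lookup(coco)        # {1: [annotation dict]}
image_info = coco["images"][0]
rows = process_annotations(lookup[1], image_info["id"], image_info)
# rows -> [[0, 0.05, 0.1, 0.15, 0.1, 0.15, 0.333..., 0.05, 0.333...]]
# i.e. [category_id, x1, y1, x2, y2, ...] with points normalized by image
# width/height; load_valid_labels() then reduces each row to a
# [cls, x_min, y_min, x_max, y_max] bounding-box tensor.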