崔浩堂 commited on
Commit
6c61696
·
unverified ·
2 Parent(s): 61ddf44 ac7f3c1

🔀 [Merge] pull request #14 from LucyTuan/DATASET

Browse files

[Merge][Add] json annotation data readable in dataloader.py

Files changed (1) hide show
  1. utils/dataloader.py +143 -29
utils/dataloader.py CHANGED
@@ -1,18 +1,116 @@
 
 
 
1
  from os import listdir, path
2
- from typing import List, Tuple, Union
3
 
4
  import diskcache as dc
5
  import hydra
6
  import numpy as np
7
  import torch
 
 
8
  from loguru import logger
9
  from PIL import Image
10
  from torch.utils.data import DataLoader, Dataset
11
  from torchvision.transforms import functional as TF
12
  from tqdm.rich import tqdm
13
 
14
- from utils.data_augment import Compose, HorizontalFlip, MixUp, Mosaic, VerticalFlip
15
- from utils.drawer import draw_bboxes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
  class YoloDataset(Dataset):
@@ -44,9 +142,7 @@ class YoloDataset(Dataset):
44
 
45
  if data is None:
46
  logger.info("Generating {} cache", phase_name)
47
- images_path = path.join(dataset_path, "images", phase_name)
48
- labels_path = path.join(dataset_path, "labels", phase_name)
49
- data = self.filter_data(images_path, labels_path)
50
  cache[phase_name] = data
51
 
52
  cache.close()
@@ -54,7 +150,7 @@ class YoloDataset(Dataset):
54
  data = cache[phase_name]
55
  return data
56
 
57
- def filter_data(self, images_path: str, labels_path: str) -> list:
58
  """
59
  Filters and collects dataset information by pairing images with their corresponding labels.
60
 
@@ -63,29 +159,49 @@ class YoloDataset(Dataset):
63
  labels_path (str): Path to the directory containing label files.
64
 
65
  Returns:
66
- list: A list of tuples, each containing the path to an image file and its associated labels as a tensor.
67
  """
 
 
 
 
 
 
68
  data = []
69
  valid_inputs = 0
70
- images_list = sorted(listdir(images_path))
71
  for image_name in tqdm(images_list, desc="Filtering data"):
72
  if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
73
  continue
74
-
75
- img_path = path.join(images_path, image_name)
76
- base_name, _ = path.splitext(image_name)
77
- label_path = path.join(labels_path, f"{base_name}.txt")
78
-
79
- if path.isfile(label_path):
80
- labels = self.load_valid_labels(label_path)
81
- if labels is not None:
82
- data.append((img_path, labels))
83
- valid_inputs += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
86
  return data
87
 
88
- def load_valid_labels(self, label_path: str) -> Union[torch.Tensor, None]:
89
  """
90
  Loads and validates bounding box data is [0, 1] from a label file.
91
 
@@ -96,15 +212,13 @@ class YoloDataset(Dataset):
96
  torch.Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None.
97
  """
98
  bboxes = []
99
- with open(label_path, "r") as file:
100
- for line in file:
101
- parts = list(map(float, line.strip().split()))
102
- cls = parts[0]
103
- points = np.array(parts[1:]).reshape(-1, 2)
104
- valid_points = points[(points >= 0) & (points <= 1)].reshape(-1, 2)
105
- if valid_points.size > 1:
106
- bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)])
107
- bboxes.append(bbox)
108
 
109
  if bboxes:
110
  return torch.stack(bboxes)
 
1
+ import json
2
+ import os
3
+ from itertools import chain
4
  from os import listdir, path
5
+ from typing import Any, Dict, List, Optional, Tuple, Union
6
 
7
  import diskcache as dc
8
  import hydra
9
  import numpy as np
10
  import torch
11
+ from data_augment import Compose, HorizontalFlip, MixUp, Mosaic, VerticalFlip
12
+ from drawer import draw_bboxes
13
  from loguru import logger
14
  from PIL import Image
15
  from torch.utils.data import DataLoader, Dataset
16
  from torchvision.transforms import functional as TF
17
  from tqdm.rich import tqdm
18
 
19
+
20
+ def find_labels_path(dataset_path: str, phase_name: str):
21
+ """
22
+ Find the path to label files for a specified dataset and phase(e.g. training).
23
+
24
+ Args:
25
+ dataset_path (str): The path to the root directory of the dataset.
26
+ phase_name (str): The name of the phase for which labels are being searched (e.g., "train", "val", "test").
27
+
28
+ Returns:
29
+ Tuple[str, str]: A tuple containing the path to the labels file and the file format ("json" or "txt").
30
+ """
31
+ json_labels_path = path.join(dataset_path, "annotations", f"instances_{phase_name}.json")
32
+
33
+ txt_labels_path = path.join(dataset_path, "label", phase_name)
34
+
35
+ if path.isfile(json_labels_path):
36
+ return json_labels_path, "json"
37
+
38
+ elif path.isdir(txt_labels_path):
39
+ txt_files = [f for f in os.listdir(txt_labels_path) if f.endswith(".txt")]
40
+ if txt_files:
41
+ return txt_labels_path, "txt"
42
+
43
+ raise FileNotFoundError("No labels found in the specified dataset path and phase name.")
44
+
45
+
46
+ def create_image_info_dict(labels_path: str) -> Tuple[Dict[str, List], Dict[str, Dict]]:
47
+ """
48
+ Create a dictionary containing image information and annotations indexed by image ID.
49
+
50
+ Args:
51
+ labels_path (str): The path to the annotation json file.
52
+
53
+ Returns:
54
+ - annotations_index: A dictionary where keys are image IDs and values are lists of annotations.
55
+ - image_info_dict: A dictionary where keys are image file names without extension and values are image information dictionaries.
56
+ """
57
+ with open(labels_path, "r") as file:
58
+ labels_data = json.load(file)
59
+ annotations_index = index_annotations_by_image(labels_data) # check lookup is a good name?
60
+ image_info_dict = {path.splitext(img["file_name"])[0]: img for img in labels_data["images"]}
61
+ return annotations_index, image_info_dict
62
+
63
+
64
+ def index_annotations_by_image(data: Dict[str, Any]):
65
+ """
66
+ Use image index to lookup every annotations
67
+ Args:
68
+ data (Dict[str, Any]): A dictionary containing annotation data.
69
+
70
+ Returns:
71
+ Dict[int, List[Dict[str, Any]]]: A dictionary where keys are image IDs and values are lists of annotations.
72
+ Annotations with "iscrowd" set to True are excluded from the index.
73
+
74
+ """
75
+ annotation_lookup = {}
76
+ for anno in data["annotations"]:
77
+ if anno["iscrowd"]:
78
+ continue
79
+ image_id = anno["image_id"]
80
+ if image_id not in annotation_lookup:
81
+ annotation_lookup[image_id] = []
82
+ annotation_lookup[image_id].append(anno)
83
+ return annotation_lookup
84
+
85
+
86
+ def get_scaled_segmentation(
87
+ annotations: List[Dict[str, Any]], image_dimensions: Dict[str, int]
88
+ ) -> Optional[List[List[float]]]:
89
+ """
90
+ Scale the segmentation data based on image dimensions and return a list of scaled segmentation data.
91
+
92
+ Args:
93
+ annotations (List[Dict[str, Any]]): A list of annotation dictionaries.
94
+ image_dimensions (Dict[str, int]): A dictionary containing image dimensions (height and width).
95
+
96
+ Returns:
97
+ Optional[List[List[float]]]: A list of scaled segmentation data, where each sublist contains category_id followed by scaled (x, y) coordinates.
98
+ """
99
+ if annotations is None:
100
+ return None
101
+
102
+ seg_array_with_cat = []
103
+ h, w = image_dimensions["height"], image_dimensions["width"]
104
+ for anno in annotations:
105
+ category_id = anno["category_id"]
106
+ seg_list = [item for sublist in anno["segmentation"] for item in sublist]
107
+ scaled_seg_data = (
108
+ np.array(seg_list).reshape(-1, 2) / [w, h]
109
+ ).tolist() # make the list group in x, y pairs and scaled with image width, height
110
+ scaled_flat_seg_data = [category_id] + list(chain(*scaled_seg_data)) # flatten the scaled_seg_data list
111
+ seg_array_with_cat.append(scaled_flat_seg_data)
112
+
113
+ return seg_array_with_cat
114
 
115
 
116
  class YoloDataset(Dataset):
 
142
 
143
  if data is None:
144
  logger.info("Generating {} cache", phase_name)
145
+ data = self.filter_data(dataset_path, phase_name)
 
 
146
  cache[phase_name] = data
147
 
148
  cache.close()
 
150
  data = cache[phase_name]
151
  return data
152
 
153
+ def filter_data(self, dataset_path: str, phase_name: str) -> list:
154
  """
155
  Filters and collects dataset information by pairing images with their corresponding labels.
156
 
 
159
  labels_path (str): Path to the directory containing label files.
160
 
161
  Returns:
162
+ list: A list of tuples, each containing the path to an image file and its associated segmentation as a tensor.
163
  """
164
+ images_path = path.join(dataset_path, "images", phase_name)
165
+ labels_path, data_type = find_labels_path(dataset_path, phase_name)
166
+ images_list = sorted(os.listdir(images_path))
167
+ if data_type == "json":
168
+ annotations_index, image_info_dict = create_image_info_dict(labels_path)
169
+
170
  data = []
171
  valid_inputs = 0
 
172
  for image_name in tqdm(images_list, desc="Filtering data"):
173
  if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
174
  continue
175
+ image_id, _ = path.splitext(image_name)
176
+
177
+ if data_type == "json":
178
+ image_info = image_info_dict.get(image_id, None)
179
+ if image_info is None:
180
+ continue
181
+ annotations = annotations_index.get(image_info["id"], [])
182
+ image_seg_annotations = get_scaled_segmentation(annotations, image_info)
183
+ if not image_seg_annotations:
184
+ continue
185
+
186
+ elif data_type == "txt":
187
+ label_path = path.join(labels_path, f"{image_id}.txt")
188
+ if not path.isfile(label_path):
189
+ continue
190
+ with open(label_path, "r") as file:
191
+ image_seg_annotations = [
192
+ list(map(float, line.strip().split())) for line in file
193
+ ] # add a comment for this line, complicated, do you need "list", im not sure
194
+
195
+ labels = self.load_valid_labels(image_id, image_seg_annotations)
196
+ if labels is not None:
197
+ img_path = path.join(images_path, image_name)
198
+ data.append((img_path, labels))
199
+ valid_inputs += 1
200
 
201
  logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
202
  return data
203
 
204
+ def load_valid_labels(self, label_path, seg_data_one_img) -> Union[torch.Tensor, None]:
205
  """
206
  Loads and validates bounding box data is [0, 1] from a label file.
207
 
 
212
  torch.Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None.
213
  """
214
  bboxes = []
215
+ for seg_data in seg_data_one_img:
216
+ cls = seg_data[0]
217
+ points = np.array(seg_data[1:]).reshape(-1, 2)
218
+ valid_points = points[(points >= 0) & (points <= 1)].reshape(-1, 2)
219
+ if valid_points.size > 1:
220
+ bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)])
221
+ bboxes.append(bbox)
 
 
222
 
223
  if bboxes:
224
  return torch.stack(bboxes)