lucytuan commited on
Commit
ac7f3c1
·
1 Parent(s): c8b07ff

✨ [Add] json can be directly read by dataloader.py

Browse files

The default format is json in this version. I also done refactored it.

Files changed (1) hide show
  1. utils/dataloader.py +98 -57
utils/dataloader.py CHANGED
@@ -2,7 +2,7 @@ import json
2
  import os
3
  from itertools import chain
4
  from os import listdir, path
5
- from typing import List, Tuple, Union
6
 
7
  import diskcache as dc
8
  import hydra
@@ -17,7 +17,17 @@ from torchvision.transforms import functional as TF
17
  from tqdm.rich import tqdm
18
 
19
 
20
- def find_labels_path(dataset_path, phase_name):
 
 
 
 
 
 
 
 
 
 
21
  json_labels_path = path.join(dataset_path, "annotations", f"instances_{phase_name}.json")
22
 
23
  txt_labels_path = path.join(dataset_path, "label", phase_name)
@@ -33,34 +43,74 @@ def find_labels_path(dataset_path, phase_name):
33
  raise FileNotFoundError("No labels found in the specified dataset path and phase name.")
34
 
35
 
36
- def load_json_labels(json_labels_path):
37
- with open(json_labels_path, "r") as file:
38
- data = json.load(file)
39
- return data
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
 
 
 
 
 
41
 
42
- def create_annotation_lookup(data):
 
 
 
 
43
  annotation_lookup = {}
44
  for anno in data["annotations"]:
45
- if anno["iscrowd"] == 0: # Exclude crowd annotations
46
- image_id = anno["image_id"]
47
- if image_id not in annotation_lookup:
48
- annotation_lookup[image_id] = []
49
- annotation_lookup[image_id].append(anno)
 
50
  return annotation_lookup
51
 
52
 
53
- def process_annotations(annotations, image_id, image_dimensions):
54
- ret_array = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  h, w = image_dimensions["height"], image_dimensions["width"]
56
  for anno in annotations:
57
  category_id = anno["category_id"]
58
- flat_list = [item for sublist in anno["segmentation"] for item in sublist]
59
- normalized_data = (np.array(flat_list).reshape(-1, 2) / [w, h]).tolist()
60
- normalized_flat = list(chain(*normalized_data))
61
- normalized_flat.insert(0, category_id)
62
- ret_array.append(normalized_flat)
63
- return ret_array
 
 
64
 
65
 
66
  class YoloDataset(Dataset):
@@ -114,48 +164,39 @@ class YoloDataset(Dataset):
114
  images_path = path.join(dataset_path, "images", phase_name)
115
  labels_path, data_type = find_labels_path(dataset_path, phase_name)
116
  images_list = sorted(os.listdir(images_path))
117
- data = []
118
- valid_inputs = 0
119
-
120
  if data_type == "json":
121
- labels_data = load_json_labels(labels_path)
122
- annotations_lookup = create_annotation_lookup(labels_data)
123
- image_info_dict = {path.splitext(img["file_name"])[0]: img for img in labels_data["images"]}
124
 
125
- for image_name in tqdm(images_list, desc="Filtering data"):
126
- if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
 
 
 
 
 
 
 
 
 
 
 
 
127
  continue
128
- base_name, _ = path.splitext(image_name)
129
- if base_name in image_info_dict:
130
- image_info = image_info_dict[base_name]
131
- annotations = annotations_lookup.get(image_info["id"], [])
132
- if annotations:
133
- processed_data = process_annotations(annotations, image_info["id"], image_info)
134
- if processed_data:
135
- img_path = path.join(images_path, image_name)
136
- labels = self.load_valid_labels(img_path, processed_data)
137
- if labels is not None:
138
- data.append((img_path, labels))
139
- valid_inputs += 1
140
-
141
- elif data_type == "txt":
142
- for image_name in tqdm(images_list, desc="Filtering data"):
143
- if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
144
  continue
 
 
 
 
 
 
 
145
  img_path = path.join(images_path, image_name)
146
- base_name, _ = path.splitext(image_name)
147
- label_path = path.join(labels_path, f"{base_name}.txt")
148
-
149
- if path.isfile(label_path):
150
- seg_data_one_img = []
151
- with open(label_path, "r") as file:
152
- for line in file:
153
- parts = list(map(float, line.strip().split()))
154
- seg_data_one_img.append(parts)
155
- labels = self.load_valid_labels(label_path, seg_data_one_img)
156
- if labels is not None:
157
- data.append((img_path, labels))
158
- valid_inputs += 1
159
 
160
  logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
161
  return data
 
2
  import os
3
  from itertools import chain
4
  from os import listdir, path
5
+ from typing import Any, Dict, List, Optional, Tuple, Union
6
 
7
  import diskcache as dc
8
  import hydra
 
17
  from tqdm.rich import tqdm
18
 
19
 
20
+ def find_labels_path(dataset_path: str, phase_name: str):
21
+ """
22
+ Find the path to label files for a specified dataset and phase(e.g. training).
23
+
24
+ Args:
25
+ dataset_path (str): The path to the root directory of the dataset.
26
+ phase_name (str): The name of the phase for which labels are being searched (e.g., "train", "val", "test").
27
+
28
+ Returns:
29
+ Tuple[str, str]: A tuple containing the path to the labels file and the file format ("json" or "txt").
30
+ """
31
  json_labels_path = path.join(dataset_path, "annotations", f"instances_{phase_name}.json")
32
 
33
  txt_labels_path = path.join(dataset_path, "label", phase_name)
 
43
  raise FileNotFoundError("No labels found in the specified dataset path and phase name.")
44
 
45
 
46
+ def create_image_info_dict(labels_path: str) -> Tuple[Dict[str, List], Dict[str, Dict]]:
47
+ """
48
+ Create a dictionary containing image information and annotations indexed by image ID.
49
+
50
+ Args:
51
+ labels_path (str): The path to the annotation json file.
52
+
53
+ Returns:
54
+ - annotations_index: A dictionary where keys are image IDs and values are lists of annotations.
55
+ - image_info_dict: A dictionary where keys are image file names without extension and values are image information dictionaries.
56
+ """
57
+ with open(labels_path, "r") as file:
58
+ labels_data = json.load(file)
59
+ annotations_index = index_annotations_by_image(labels_data) # check lookup is a good name?
60
+ image_info_dict = {path.splitext(img["file_name"])[0]: img for img in labels_data["images"]}
61
+ return annotations_index, image_info_dict
62
+
63
 
64
+ def index_annotations_by_image(data: Dict[str, Any]):
65
+ """
66
+ Use image index to lookup every annotations
67
+ Args:
68
+ data (Dict[str, Any]): A dictionary containing annotation data.
69
 
70
+ Returns:
71
+ Dict[int, List[Dict[str, Any]]]: A dictionary where keys are image IDs and values are lists of annotations.
72
+ Annotations with "iscrowd" set to True are excluded from the index.
73
+
74
+ """
75
  annotation_lookup = {}
76
  for anno in data["annotations"]:
77
+ if anno["iscrowd"]:
78
+ continue
79
+ image_id = anno["image_id"]
80
+ if image_id not in annotation_lookup:
81
+ annotation_lookup[image_id] = []
82
+ annotation_lookup[image_id].append(anno)
83
  return annotation_lookup
84
 
85
 
86
+ def get_scaled_segmentation(
87
+ annotations: List[Dict[str, Any]], image_dimensions: Dict[str, int]
88
+ ) -> Optional[List[List[float]]]:
89
+ """
90
+ Scale the segmentation data based on image dimensions and return a list of scaled segmentation data.
91
+
92
+ Args:
93
+ annotations (List[Dict[str, Any]]): A list of annotation dictionaries.
94
+ image_dimensions (Dict[str, int]): A dictionary containing image dimensions (height and width).
95
+
96
+ Returns:
97
+ Optional[List[List[float]]]: A list of scaled segmentation data, where each sublist contains category_id followed by scaled (x, y) coordinates.
98
+ """
99
+ if annotations is None:
100
+ return None
101
+
102
+ seg_array_with_cat = []
103
  h, w = image_dimensions["height"], image_dimensions["width"]
104
  for anno in annotations:
105
  category_id = anno["category_id"]
106
+ seg_list = [item for sublist in anno["segmentation"] for item in sublist]
107
+ scaled_seg_data = (
108
+ np.array(seg_list).reshape(-1, 2) / [w, h]
109
+ ).tolist() # make the list group in x, y pairs and scaled with image width, height
110
+ scaled_flat_seg_data = [category_id] + list(chain(*scaled_seg_data)) # flatten the scaled_seg_data list
111
+ seg_array_with_cat.append(scaled_flat_seg_data)
112
+
113
+ return seg_array_with_cat
114
 
115
 
116
  class YoloDataset(Dataset):
 
164
  images_path = path.join(dataset_path, "images", phase_name)
165
  labels_path, data_type = find_labels_path(dataset_path, phase_name)
166
  images_list = sorted(os.listdir(images_path))
 
 
 
167
  if data_type == "json":
168
+ annotations_index, image_info_dict = create_image_info_dict(labels_path)
 
 
169
 
170
+ data = []
171
+ valid_inputs = 0
172
+ for image_name in tqdm(images_list, desc="Filtering data"):
173
+ if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
174
+ continue
175
+ image_id, _ = path.splitext(image_name)
176
+
177
+ if data_type == "json":
178
+ image_info = image_info_dict.get(image_id, None)
179
+ if image_info is None:
180
+ continue
181
+ annotations = annotations_index.get(image_info["id"], [])
182
+ image_seg_annotations = get_scaled_segmentation(annotations, image_info)
183
+ if not image_seg_annotations:
184
  continue
185
+
186
+ elif data_type == "txt":
187
+ label_path = path.join(labels_path, f"{image_id}.txt")
188
+ if not path.isfile(label_path):
 
 
 
 
 
 
 
 
 
 
 
 
189
  continue
190
+ with open(label_path, "r") as file:
191
+ image_seg_annotations = [
192
+ list(map(float, line.strip().split())) for line in file
193
+ ] # add a comment for this line, complicated, do you need "list", im not sure
194
+
195
+ labels = self.load_valid_labels(image_id, image_seg_annotations)
196
+ if labels is not None:
197
  img_path = path.join(images_path, image_name)
198
+ data.append((img_path, labels))
199
+ valid_inputs += 1
 
 
 
 
 
 
 
 
 
 
 
200
 
201
  logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
202
  return data