henry000 commited on
Commit
849d290
Β·
1 Parent(s): 6c61696

🚚 [Move] the dataset helper func to tools/dataset

Browse files
Files changed (2) hide show
  1. tools/dataset_helper.py +103 -0
  2. utils/dataloader.py +10 -104
tools/dataset_helper.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from itertools import chain
4
+ from os import path
5
+ from typing import Any, Dict, List, Optional, Tuple
6
+
7
+ import numpy as np
8
+
9
+
10
def find_labels_path(dataset_path: str, phase_name: str) -> Tuple[str, str]:
    """
    Find the label source for a specified dataset and phase (e.g. training).

    A COCO-style json annotation file is preferred; a directory of per-image
    txt label files is used as the fallback.

    Args:
        dataset_path (str): The path to the root directory of the dataset.
        phase_name (str): The name of the phase for which labels are being
            searched (e.g., "train", "val", "test").

    Returns:
        Tuple[str, str]: The path to the labels and the format, "json" or "txt".

    Raises:
        FileNotFoundError: If neither label source exists for the given phase.
    """
    json_labels_path = path.join(dataset_path, "annotations", f"instances_{phase_name}.json")
    if path.isfile(json_labels_path):
        return json_labels_path, "json"

    txt_labels_path = path.join(dataset_path, "label", phase_name)
    # any() short-circuits instead of materializing the whole directory listing
    # just to check that at least one .txt label exists.
    if path.isdir(txt_labels_path) and any(f.endswith(".txt") for f in os.listdir(txt_labels_path)):
        return txt_labels_path, "txt"

    raise FileNotFoundError("No labels found in the specified dataset path and phase name.")
34
+
35
+
36
def create_image_info_dict(labels_path: str) -> Tuple[Dict[str, List], Dict[str, Dict]]:
    """
    Load a COCO-style annotation file and build two lookups from it.

    Args:
        labels_path (str): The path to the annotation json file.

    Returns:
        - annotations_index: image id -> list of that image's annotations.
        - image_info_dict: image file name (without extension) -> image info dict.
    """
    with open(labels_path, "r") as annotation_file:
        coco_data = json.load(annotation_file)

    annotations_index = index_annotations_by_image(coco_data)

    image_info_dict = {}
    for image_entry in coco_data["images"]:
        stem, _ = path.splitext(image_entry["file_name"])
        image_info_dict[stem] = image_entry

    return annotations_index, image_info_dict
52
+
53
+
54
def index_annotations_by_image(data: Dict[str, Any]) -> Dict[int, List[Dict[str, Any]]]:
    """
    Group annotations by the image they belong to.

    Args:
        data (Dict[str, Any]): COCO-style annotation data containing an
            "annotations" list.

    Returns:
        Dict[int, List[Dict[str, Any]]]: A dictionary where keys are image IDs
        and values are lists of annotations. Annotations with "iscrowd" set to
        a truthy value are excluded from the index.
    """
    annotation_lookup: Dict[int, List[Dict[str, Any]]] = {}
    for anno in data["annotations"]:
        # Crowd regions are not usable as individual instances; skip them.
        if anno["iscrowd"]:
            continue
        # setdefault replaces the manual "if key not in dict: dict[key] = []" dance.
        annotation_lookup.setdefault(anno["image_id"], []).append(anno)
    return annotation_lookup
74
+
75
+
76
def get_scaled_segmentation(
    annotations: List[Dict[str, Any]], image_dimensions: Dict[str, int]
) -> Optional[List[List[float]]]:
    """
    Normalize segmentation coordinates by the image dimensions.

    Args:
        annotations (List[Dict[str, Any]]): A list of annotation dictionaries.
        image_dimensions (Dict[str, int]): Image dimensions ("height" and "width").

    Returns:
        Optional[List[List[float]]]: One row per annotation: the category_id
        followed by the flattened, scaled (x, y) coordinates. None when
        annotations is None.
    """
    if annotations is None:
        return None

    height = image_dimensions["height"]
    width = image_dimensions["width"]

    scaled_rows = []
    for annotation in annotations:
        # Flatten the list-of-polygons into one flat coordinate list.
        flat_points = list(chain.from_iterable(annotation["segmentation"]))
        # Pair up as (x, y) and divide by (width, height) to get [0, 1] coords.
        scaled_pairs = (np.array(flat_points).reshape(-1, 2) / [width, height]).tolist()
        row = [annotation["category_id"]]
        for pair in scaled_pairs:
            row.extend(pair)
        scaled_rows.append(row)

    return scaled_rows
utils/dataloader.py CHANGED
@@ -1,116 +1,24 @@
1
- import json
2
  import os
3
- from itertools import chain
4
- from os import listdir, path
5
- from typing import Any, Dict, List, Optional, Tuple, Union
6
 
7
  import diskcache as dc
8
  import hydra
9
  import numpy as np
10
  import torch
11
- from data_augment import Compose, HorizontalFlip, MixUp, Mosaic, VerticalFlip
12
- from drawer import draw_bboxes
13
  from loguru import logger
14
  from PIL import Image
15
  from torch.utils.data import DataLoader, Dataset
16
  from torchvision.transforms import functional as TF
17
  from tqdm.rich import tqdm
18
 
19
-
20
- def find_labels_path(dataset_path: str, phase_name: str):
21
- """
22
- Find the path to label files for a specified dataset and phase(e.g. training).
23
-
24
- Args:
25
- dataset_path (str): The path to the root directory of the dataset.
26
- phase_name (str): The name of the phase for which labels are being searched (e.g., "train", "val", "test").
27
-
28
- Returns:
29
- Tuple[str, str]: A tuple containing the path to the labels file and the file format ("json" or "txt").
30
- """
31
- json_labels_path = path.join(dataset_path, "annotations", f"instances_{phase_name}.json")
32
-
33
- txt_labels_path = path.join(dataset_path, "label", phase_name)
34
-
35
- if path.isfile(json_labels_path):
36
- return json_labels_path, "json"
37
-
38
- elif path.isdir(txt_labels_path):
39
- txt_files = [f for f in os.listdir(txt_labels_path) if f.endswith(".txt")]
40
- if txt_files:
41
- return txt_labels_path, "txt"
42
-
43
- raise FileNotFoundError("No labels found in the specified dataset path and phase name.")
44
-
45
-
46
- def create_image_info_dict(labels_path: str) -> Tuple[Dict[str, List], Dict[str, Dict]]:
47
- """
48
- Create a dictionary containing image information and annotations indexed by image ID.
49
-
50
- Args:
51
- labels_path (str): The path to the annotation json file.
52
-
53
- Returns:
54
- - annotations_index: A dictionary where keys are image IDs and values are lists of annotations.
55
- - image_info_dict: A dictionary where keys are image file names without extension and values are image information dictionaries.
56
- """
57
- with open(labels_path, "r") as file:
58
- labels_data = json.load(file)
59
- annotations_index = index_annotations_by_image(labels_data) # check lookup is a good name?
60
- image_info_dict = {path.splitext(img["file_name"])[0]: img for img in labels_data["images"]}
61
- return annotations_index, image_info_dict
62
-
63
-
64
- def index_annotations_by_image(data: Dict[str, Any]):
65
- """
66
- Use image index to lookup every annotations
67
- Args:
68
- data (Dict[str, Any]): A dictionary containing annotation data.
69
-
70
- Returns:
71
- Dict[int, List[Dict[str, Any]]]: A dictionary where keys are image IDs and values are lists of annotations.
72
- Annotations with "iscrowd" set to True are excluded from the index.
73
-
74
- """
75
- annotation_lookup = {}
76
- for anno in data["annotations"]:
77
- if anno["iscrowd"]:
78
- continue
79
- image_id = anno["image_id"]
80
- if image_id not in annotation_lookup:
81
- annotation_lookup[image_id] = []
82
- annotation_lookup[image_id].append(anno)
83
- return annotation_lookup
84
-
85
-
86
- def get_scaled_segmentation(
87
- annotations: List[Dict[str, Any]], image_dimensions: Dict[str, int]
88
- ) -> Optional[List[List[float]]]:
89
- """
90
- Scale the segmentation data based on image dimensions and return a list of scaled segmentation data.
91
-
92
- Args:
93
- annotations (List[Dict[str, Any]]): A list of annotation dictionaries.
94
- image_dimensions (Dict[str, int]): A dictionary containing image dimensions (height and width).
95
-
96
- Returns:
97
- Optional[List[List[float]]]: A list of scaled segmentation data, where each sublist contains category_id followed by scaled (x, y) coordinates.
98
- """
99
- if annotations is None:
100
- return None
101
-
102
- seg_array_with_cat = []
103
- h, w = image_dimensions["height"], image_dimensions["width"]
104
- for anno in annotations:
105
- category_id = anno["category_id"]
106
- seg_list = [item for sublist in anno["segmentation"] for item in sublist]
107
- scaled_seg_data = (
108
- np.array(seg_list).reshape(-1, 2) / [w, h]
109
- ).tolist() # make the list group in x, y pairs and scaled with image width, height
110
- scaled_flat_seg_data = [category_id] + list(chain(*scaled_seg_data)) # flatten the scaled_seg_data list
111
- seg_array_with_cat.append(scaled_flat_seg_data)
112
-
113
- return seg_array_with_cat
114
 
115
 
116
  class YoloDataset(Dataset):
@@ -188,9 +96,7 @@ class YoloDataset(Dataset):
188
  if not path.isfile(label_path):
189
  continue
190
  with open(label_path, "r") as file:
191
- image_seg_annotations = [
192
- list(map(float, line.strip().split())) for line in file
193
- ] # add a comment for this line, complicated, do you need "list", im not sure
194
 
195
  labels = self.load_valid_labels(image_id, image_seg_annotations)
196
  if labels is not None:
 
 
1
  import os
2
+ from os import path
3
+ from typing import List, Tuple, Union
 
4
 
5
  import diskcache as dc
6
  import hydra
7
  import numpy as np
8
  import torch
 
 
9
  from loguru import logger
10
  from PIL import Image
11
  from torch.utils.data import DataLoader, Dataset
12
  from torchvision.transforms import functional as TF
13
  from tqdm.rich import tqdm
14
 
15
+ from tools.dataset_helper import (
16
+ create_image_info_dict,
17
+ find_labels_path,
18
+ get_scaled_segmentation,
19
+ )
20
+ from utils.data_augment import Compose, HorizontalFlip, MixUp, Mosaic, VerticalFlip
21
+ from utils.drawer import draw_bboxes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
 
24
  class YoloDataset(Dataset):
 
96
  if not path.isfile(label_path):
97
  continue
98
  with open(label_path, "r") as file:
99
+ image_seg_annotations = [list(map(float, line.strip().split())) for line in file]
 
 
100
 
101
  labels = self.load_valid_labels(image_id, image_seg_annotations)
102
  if labels is not None: