🔥 [Remove] utils and config again, move to yolo/
Deleted files:
- config/config.py +0 -72
- config/config.yaml +0 -11
- config/data/augmentation.yaml +0 -3
- config/data/download.yaml +0 -17
- config/hyper/default.yaml +0 -19
- utils/data_augment.py +0 -125
- utils/dataloader.py +0 -186
- utils/drawer.py +0 -41
- utils/get_dataset.py +0 -84
- utils/loss.py +0 -2
config/config.py
DELETED
@@ -1,72 +0,0 @@
from dataclasses import dataclass
from typing import Dict, List, Union


@dataclass
class Model:
    anchor: List[List[int]]
    model: Dict[str, List[Dict[str, Union[Dict, List, int]]]]


@dataclass
class Download:
    auto: bool
    path: str


@dataclass
class DataLoaderConfig:
    batch_size: int
    shuffle: bool
    num_workers: int
    pin_memory: bool


@dataclass
class OptimizerArgs:
    lr: float
    weight_decay: float


@dataclass
class OptimizerConfig:
    type: str
    args: OptimizerArgs


@dataclass
class SchedulerArgs:
    step_size: int
    gamma: float


@dataclass
class SchedulerConfig:
    type: str
    args: SchedulerArgs


@dataclass
class EMAConfig:
    enabled: bool
    decay: float


@dataclass
class TrainConfig:
    optimizer: OptimizerConfig
    scheduler: SchedulerConfig
    ema: EMAConfig


@dataclass
class HyperConfig:
    data: DataLoaderConfig
    train: TrainConfig


@dataclass
class Config:
    model: Model
    download: Download
    hyper: HyperConfig
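For context (not part of this commit): the dataclasses above mirror the Hydra YAML tree below field-for-field, so they can act as an OmegaConf schema. A minimal sketch, assuming the module were still importable under its pre-move path and using the values from config/hyper/default.yaml; the wiring itself is illustrative, not code from this repository.

from omegaconf import OmegaConf

from config.config import HyperConfig  # the dataclass deleted above (pre-move import path)

# Merging a plain config into a structured schema type-checks every field it sets.
schema = OmegaConf.structured(HyperConfig)
raw = OmegaConf.create({
    "data": {"batch_size": 4, "shuffle": True, "num_workers": 4, "pin_memory": True},
    "train": {
        "optimizer": {"type": "Adam", "args": {"lr": 0.001, "weight_decay": 0.0001}},
        "scheduler": {"type": "StepLR", "args": {"step_size": 10, "gamma": 0.1}},
        "ema": {"enabled": True, "decay": 0.995},
    },
})
typed = OmegaConf.merge(schema, raw)
print(typed.train.optimizer.args.lr)  # 0.001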
config/config.yaml
DELETED
@@ -1,11 +0,0 @@
hydra:
  run:
    dir: ./runs

defaults:
  - data: coco
  - download: ../data/download
  - augmentation: ../data/augmentation
  - model: v7-base
  - hyper: default
  - _self_
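For context (not part of this commit): the defaults list above can also be composed programmatically with Hydra's compose API, which is a quick way to confirm the config-group paths resolve. The relative config_path below is an assumption about where the snippet lives.

from hydra import compose, initialize

# Compose config/config.yaml without the @hydra.main decorator; config_path is
# interpreted relative to the calling file.
with initialize(config_path="config", version_base=None):
    cfg = compose(config_name="config")

print(cfg.hyper.data.batch_size)  # 4, supplied by hyper/default.yaml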
config/data/augmentation.yaml
DELETED
@@ -1,3 +0,0 @@
Mosaic: 1
# MixUp: 1
HorizontalFlip: 0.5
config/data/download.yaml
DELETED
@@ -1,17 +0,0 @@
auto: True
path: data/coco
images:
  base_url: http://images.cocodataset.org/zips/
  datasets:
    train:
      file_name: train2017.zip
      file_num: 118287
    val:
      file_name: val2017.zip
      num_files: 5000
    test:
      file_name: test2017.zip
      num_files: 40670
hydra:
  run:
    dir: ./runs
config/hyper/default.yaml
DELETED
@@ -1,19 +0,0 @@
data:
  batch_size: 4
  shuffle: True
  num_workers: 4
  pin_memory: True
train:
  optimizer:
    type: Adam
    args:
      lr: 0.001
      weight_decay: 0.0001
  scheduler:
    type: StepLR
    args:
      step_size: 10
      gamma: 0.1
  ema:
    enabled: true
    decay: 0.995
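For context (not part of this commit): the type/args layout above is the usual pattern for looking optimizers and schedulers up by name in torch.optim. A rough sketch of how a trainer might consume these nodes:

import torch
from torch import optim


def build_optimizer(model: torch.nn.Module, cfg) -> optim.Optimizer:
    # cfg is the train.optimizer node: a class name plus its keyword arguments.
    optim_cls = getattr(optim, cfg["type"])              # e.g. optim.Adam
    return optim_cls(model.parameters(), **cfg["args"])  # lr=0.001, weight_decay=0.0001


def build_scheduler(optimizer: optim.Optimizer, cfg):
    # cfg is the train.scheduler node, e.g. StepLR with step_size=10, gamma=0.1.
    sched_cls = getattr(optim.lr_scheduler, cfg["type"])
    return sched_cls(optimizer, **cfg["args"])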
utils/data_augment.py
DELETED
@@ -1,125 +0,0 @@
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import functional as TF


class Compose:
    """Composes several transforms together."""

    def __init__(self, transforms, image_size: int = 640):
        self.transforms = transforms
        self.image_size = image_size

        for transform in self.transforms:
            if hasattr(transform, "set_parent"):
                transform.set_parent(self)

    def __call__(self, image, boxes):
        for transform in self.transforms:
            image, boxes = transform(image, boxes)
        return image, boxes


class HorizontalFlip:
    """Randomly horizontally flips the image along with the bounding boxes."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, boxes):
        if torch.rand(1) < self.prob:
            image = TF.hflip(image)
            boxes[:, [1, 3]] = 1 - boxes[:, [3, 1]]
        return image, boxes


class VerticalFlip:
    """Randomly vertically flips the image along with the bounding boxes."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, boxes):
        if torch.rand(1) < self.prob:
            image = TF.vflip(image)
            boxes[:, [2, 4]] = 1 - boxes[:, [4, 2]]
        return image, boxes


class Mosaic:
    """Applies the Mosaic augmentation to a batch of images and their corresponding boxes."""

    def __init__(self, prob=0.5):
        self.prob = prob
        self.parent = None

    def set_parent(self, parent):
        self.parent = parent

    def __call__(self, image, boxes):
        if torch.rand(1) >= self.prob:
            return image, boxes

        assert self.parent is not None, "Parent is not set. Mosaic cannot retrieve image size."

        img_sz = self.parent.image_size  # Assuming `image_size` is defined in parent
        more_data = self.parent.get_more_data(3)  # get 3 more images randomly

        data = [(image, boxes)] + more_data
        mosaic_image = Image.new("RGB", (2 * img_sz, 2 * img_sz))
        vectors = np.array([(-1, -1), (0, -1), (-1, 0), (0, 0)])
        center = np.array([img_sz, img_sz])
        all_labels = []

        for (image, boxes), vector in zip(data, vectors):
            this_w, this_h = image.size
            coord = tuple(center + vector * np.array([this_w, this_h]))

            mosaic_image.paste(image, coord)
            xmin, ymin, xmax, ymax = boxes[:, 1], boxes[:, 2], boxes[:, 3], boxes[:, 4]
            xmin = (xmin * this_w + coord[0]) / (2 * img_sz)
            xmax = (xmax * this_w + coord[0]) / (2 * img_sz)
            ymin = (ymin * this_h + coord[1]) / (2 * img_sz)
            ymax = (ymax * this_h + coord[1]) / (2 * img_sz)

            adjusted_boxes = torch.stack([boxes[:, 0], xmin, ymin, xmax, ymax], dim=1)
            all_labels.append(adjusted_boxes)

        all_labels = torch.cat(all_labels, dim=0)
        mosaic_image = mosaic_image.resize((img_sz, img_sz))
        return mosaic_image, all_labels


class MixUp:
    """Applies the MixUp augmentation to a pair of images and their corresponding boxes."""

    def __init__(self, prob=0.5, alpha=1.0):
        self.alpha = alpha
        self.prob = prob
        self.parent = None

    def set_parent(self, parent):
        """Set the parent dataset object for accessing dataset methods."""
        self.parent = parent

    def __call__(self, image, boxes):
        if torch.rand(1) >= self.prob:
            return image, boxes

        assert self.parent is not None, "Parent is not set. MixUp cannot retrieve additional data."

        # Retrieve another image and its boxes randomly from the dataset
        image2, boxes2 = self.parent.get_more_data()[0]

        # Calculate the mixup lambda parameter
        lam = np.random.beta(self.alpha, self.alpha) if self.alpha > 0 else 0.5

        # Mix images
        image1, image2 = TF.to_tensor(image), TF.to_tensor(image2)
        mixed_image = lam * image1 + (1 - lam) * image2

        # Mix bounding boxes
        mixed_boxes = torch.cat([lam * boxes, (1 - lam) * boxes2])

        return TF.to_pil_image(mixed_image), mixed_boxes
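A rough usage sketch of the transform interface above: boxes are rows of [class, xmin, ymin, xmax, ymax] in normalized [0, 1] coordinates, and the import path is the pre-move one. Mosaic and MixUp additionally need a parent exposing get_more_data, which utils/dataloader.py attaches to the Compose.

import torch
from PIL import Image

from utils.data_augment import Compose, HorizontalFlip

image = Image.new("RGB", (640, 640))
boxes = torch.tensor([[0.0, 0.10, 0.20, 0.40, 0.60]])

# prob=1.0 forces the flip so the effect on the box is visible.
transform = Compose([HorizontalFlip(prob=1.0)], image_size=640)
image, boxes = transform(image, boxes)
print(boxes)  # x-range mirrored: [[0.0, 0.60, 0.20, 0.90, 0.60]]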
utils/dataloader.py
DELETED
@@ -1,186 +0,0 @@
from os import listdir, path
from typing import List, Tuple, Union

import diskcache as dc
import hydra
import numpy as np
import torch
from loguru import logger
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import functional as TF
from tqdm.rich import tqdm

from utils.data_augment import Compose, HorizontalFlip, MixUp, Mosaic, VerticalFlip
from utils.drawer import draw_bboxes


class YoloDataset(Dataset):
    def __init__(self, config: dict, phase: str = "train", image_size: int = 640):
        dataset_cfg = config.data
        augment_cfg = config.augmentation
        phase_name = dataset_cfg.get(phase, phase)
        self.image_size = image_size

        transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
        self.transform = Compose(transforms, self.image_size)
        self.transform.get_more_data = self.get_more_data
        self.data = self.load_data(dataset_cfg.path, phase_name)

    def load_data(self, dataset_path, phase_name):
        """
        Loads data from a cache or generates a new cache for a specific dataset phase.

        Parameters:
            dataset_path (str): The root path to the dataset directory.
            phase_name (str): The specific phase of the dataset (e.g., 'train', 'test') to load or generate data for.

        Returns:
            dict: The loaded data from the cache for the specified phase.
        """
        cache_path = path.join(dataset_path, ".cache")
        cache = dc.Cache(cache_path)
        data = cache.get(phase_name)

        if data is None:
            logger.info("Generating {} cache", phase_name)
            images_path = path.join(dataset_path, phase_name, "images")
            labels_path = path.join(dataset_path, phase_name, "labels")
            data = self.filter_data(images_path, labels_path)
            cache[phase_name] = data

        cache.close()
        logger.info("📦 Loaded {} cache", phase_name)
        data = cache[phase_name]
        return data

    def filter_data(self, images_path: str, labels_path: str) -> list:
        """
        Filters and collects dataset information by pairing images with their corresponding labels.

        Parameters:
            images_path (str): Path to the directory containing image files.
            labels_path (str): Path to the directory containing label files.

        Returns:
            list: A list of tuples, each containing the path to an image file and its associated labels as a tensor.
        """
        data = []
        valid_inputs = 0
        images_list = sorted(listdir(images_path))
        for image_name in tqdm(images_list, desc="Filtering data"):
            if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
                continue

            img_path = path.join(images_path, image_name)
            base_name, _ = path.splitext(image_name)
            label_path = path.join(labels_path, f"{base_name}.txt")

            if path.isfile(label_path):
                labels = self.load_valid_labels(label_path)
                if labels is not None:
                    data.append((img_path, labels))
                    valid_inputs += 1

        logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
        return data

    def load_valid_labels(self, label_path: str) -> Union[torch.Tensor, None]:
        """
        Loads and validates bounding box data is [0, 1] from a label file.

        Parameters:
            label_path (str): The filepath to the label file containing bounding box data.

        Returns:
            torch.Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None.
        """
        bboxes = []
        with open(label_path, "r") as file:
            for line in file:
                parts = list(map(float, line.strip().split()))
                cls = parts[0]
                points = np.array(parts[1:]).reshape(-1, 2)
                valid_points = points[(points >= 0) & (points <= 1)].reshape(-1, 2)
                if valid_points.size > 1:
                    bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)])
                    bboxes.append(bbox)

        if bboxes:
            return torch.stack(bboxes)
        else:
            logger.warning("No valid BBox in {}", label_path)
            return None

    def get_data(self, idx):
        img_path, bboxes = self.data[idx]
        img = Image.open(img_path).convert("RGB")
        return img, bboxes

    def get_more_data(self, num: int = 1):
        indices = torch.randint(0, len(self), (num,))
        return [self.get_data(idx) for idx in indices]

    def __getitem__(self, idx) -> Union[Image.Image, torch.Tensor]:
        img, bboxes = self.get_data(idx)
        if self.transform:
            img, bboxes = self.transform(img, bboxes)
        img = TF.to_tensor(img)
        return img, bboxes

    def __len__(self) -> int:
        return len(self.data)


class YoloDataLoader(DataLoader):
    def __init__(self, config: dict):
        """Initializes the YoloDataLoader with hydra-config files."""
        hyper = config.hyper.data
        dataset = YoloDataset(config)

        super().__init__(
            dataset,
            batch_size=hyper.batch_size,
            shuffle=hyper.shuffle,
            num_workers=hyper.num_workers,
            pin_memory=hyper.pin_memory,
            collate_fn=self.collate_fn,
        )

    def collate_fn(self, batch: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        A collate function to handle batching of images and their corresponding targets.

        Args:
            batch (list of tuples): Each tuple contains:
                - image (torch.Tensor): The image tensor.
                - labels (torch.Tensor): The tensor of labels for the image.

        Returns:
            Tuple[torch.Tensor, List[torch.Tensor]]: A tuple containing:
                - A tensor of batched images.
                - A list of tensors, each corresponding to bboxes for each image in the batch.
        """
        images = torch.stack([item[0] for item in batch])
        targets = [item[1] for item in batch]
        return images, targets


def get_dataloader(config):
    return YoloDataLoader(config)


@hydra.main(config_path="../config", config_name="config", version_base=None)
def main(cfg):
    dataloader = get_dataloader(cfg)
    draw_bboxes(next(iter(dataloader)))


if __name__ == "__main__":
    import sys

    sys.path.append("./")
    from tools.log_helper import custom_logger

    custom_logger()
    main()
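A rough sketch of consuming one batch (cfg is whatever Hydra composes from config/config.yaml; the stack in collate_fn assumes every image in the batch ends up the same size):

dataloader = get_dataloader(cfg)
images, targets = next(iter(dataloader))
# images:  a single stacked tensor of shape [batch_size, 3, H, W]
# targets: a list of per-image tensors, each [num_boxes, 5] as [class, xmin, ymin, xmax, ymax]
draw_bboxes(images, targets)  # writes visualize.jpg via utils/drawer.py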
utils/drawer.py
DELETED
@@ -1,41 +0,0 @@
from typing import List, Union

import torch
from loguru import logger
from PIL import Image, ImageDraw, ImageFont
from torchvision.transforms.functional import to_pil_image


def draw_bboxes(img: Union[Image.Image, torch.Tensor], bboxes: List[List[Union[int, float]]]):
    """
    Draw bounding boxes on an image.

    Args:
    - img (PIL Image or torch.Tensor): Image on which to draw the bounding boxes.
    - bboxes (List of Lists/Tensors): Bounding boxes with [class_id, x_min, y_min, x_max, y_max],
      where coordinates are normalized [0, 1].
    """
    # Convert tensor image to PIL Image if necessary
    if isinstance(img, torch.Tensor):
        if img.dim() > 3:
            logger.info("Multi-frame tensor detected, using the first image.")
            img = img[0]
            bboxes = bboxes[0]
        img = to_pil_image(img)

    draw = ImageDraw.Draw(img)
    width, height = img.size
    font = ImageFont.load_default(30)

    for bbox in bboxes:
        class_id, x_min, y_min, x_max, y_max = bbox
        x_min = x_min * width
        x_max = x_max * width
        y_min = y_min * height
        y_max = y_max * height
        shape = [(x_min, y_min), (x_max, y_max)]
        draw.rectangle(shape, outline="red", width=3)
        draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")

    img.save("visualize.jpg")  # Save the image with annotations
    logger.info("Saved visualize image at visualize.png")
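Minimal sketch of calling the drawer directly; the blank image and box values are placeholders, and the import path is the pre-move one.

from PIL import Image

from utils.drawer import draw_bboxes

img = Image.new("RGB", (640, 480), "white")
draw_bboxes(img, [[3, 0.25, 0.25, 0.75, 0.75]])  # saves visualize.jpg in the working directory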
utils/get_dataset.py
DELETED
@@ -1,84 +0,0 @@
import os
import zipfile

import hydra
import requests
from loguru import logger
from tqdm.rich import tqdm


def download_file(url, dest_path):
    """
    Downloads a file from a specified URL to a destination path with progress logging.
    """
    logger.info(f"Downloading {os.path.basename(dest_path)}...")
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        total_length = int(r.headers.get("content-length", 0))
        with open(dest_path, "wb") as f, tqdm(
            total=total_length, unit="iB", unit_scale=True, desc=os.path.basename(dest_path), leave=True
        ) as bar:
            for chunk in r.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
                bar.update(len(chunk))
    logger.info("Download complete!")


def unzip_file(zip_path, extract_to):
    """
    Unzips a ZIP file to a specified directory.
    """
    logger.info(f"Unzipping {os.path.basename(zip_path)}...")
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extract_to)
    os.remove(zip_path)
    logger.info(f"Removed {zip_path}")


def check_files(directory, expected_count):
    """
    Checks if the specified directory has the expected number of files.
    """
    num_files = len([name for name in os.listdir(directory) if os.path.isfile(os.path.join(directory, name))])
    return num_files == expected_count


@hydra.main(config_path="../config/data", config_name="download", version_base=None)
def prepare_dataset(download_cfg):
    data_dir = download_cfg.path
    base_url = download_cfg.images.base_url
    datasets = download_cfg.images.datasets

    for dataset_type in datasets:
        file_name, expected_files = datasets[dataset_type].values()
        url = f"{base_url}{file_name}"
        local_zip_path = os.path.join(data_dir, file_name)
        extract_to = os.path.join(data_dir, dataset_type, "images")

        # Ensure the extraction directory exists
        os.makedirs(extract_to, exist_ok=True)

        # Check if the correct number of files exists
        if check_files(extract_to, expected_files):
            logger.info(f"✅ Dataset {dataset_type: >4} already verified.")
            continue

        if os.path.exists(local_zip_path):
            logger.info(f"Dataset {dataset_type} already downloaded.")
        else:
            download_file(url, local_zip_path)

        unzip_file(local_zip_path, extract_to)

        print(os.path.exists(local_zip_path), check_files(extract_to, expected_files))

        # Additional verification post extraction
        if not check_files(extract_to, expected_files):
            logger.error(f"Error in verifying the {dataset_type} dataset after extraction.")


if __name__ == "__main__":
    from tools.log_helper import custom_logger

    custom_logger()
    prepare_dataset()
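The Hydra entry point above drives everything from config/data/download.yaml; the primitives can also be exercised directly. A sketch for the val split only, with the URL and paths taken from that config and the pre-move import path:

import os

from utils.get_dataset import check_files, download_file, unzip_file

os.makedirs("data/coco/val/images", exist_ok=True)
download_file("http://images.cocodataset.org/zips/val2017.zip", "data/coco/val2017.zip")
unzip_file("data/coco/val2017.zip", "data/coco/val/images")  # also deletes the zip afterwards
print(check_files("data/coco/val/images", 5000))  # compares the flat file count in that directory to 5000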
utils/loss.py
DELETED
@@ -1,2 +0,0 @@
def get_loss_function(*args, **kwargs):
    raise NotImplementedError