[Fix] Dataset autodownload bug, new dataset format

Changed files:
- config/config.py           +19 -0
- config/data/download.yaml  +16 -12
- utils/get_dataset.py       +60 -60
config/config.py
CHANGED
@@ -14,6 +14,25 @@ class Download:
     path: str
 
 
+@dataclass
+class Dataset:
+    file_name: str
+    num_files: int
+
+
+@dataclass
+class Datasets:
+    base_url: str
+    images: Dict[str, Dataset]
+
+
+@dataclass
+class Download:
+    auto: bool
+    save_path: str
+    datasets: Datasets
+
+
 @dataclass
 class Config:
     model: Model
config/data/download.yaml
CHANGED
@@ -1,17 +1,21 @@
 auto: True
-path: data/coco
-images:
-  base_url: http://images.cocodataset.org/zips/
-  datasets:
-    train2017:
-      file_name: train2017
+save_path: data/coco
+datasets:
+  images:
+    base_url: http://images.cocodataset.org/zips/
+    train2017:
+      file_name: train2017
       file_num: 118287
-    val2017:
-      file_name: val2017
-      file_num: 5000
-    test2017:
-      file_name: test2017
-      file_num: 40670
+    val2017:
+      file_name: val2017
+      file_num: 5000
+    test2017:
+      file_name: test2017
+      file_num: 40670
+  annotations:
+    base_url: http://images.cocodataset.org/annotations/
+    annotations:
+      file_name: annotations_trainval2017
 hydra:
   run:
     dir: ./runs
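
To make the new layout concrete, here is a small sketch of the on-disk structure it implies. The dict literal mirrors the YAML above, and the paths follow the extract_to and final_place rules in utils/get_dataset.py below.

# Sketch: where each entry in download.yaml ends up after extraction.
# Mirrors the extract_to/final_place logic in utils/get_dataset.py (see below).
import os

save_path = "data/coco"
datasets = {
    "images": ["train2017", "val2017", "test2017"],
    "annotations": ["annotations"],
}
for data_type, names in datasets.items():
    # Image sets extract under data/coco/images/; annotation archives
    # extract directly under data/coco/.
    extract_to = os.path.join(save_path, data_type) if data_type != "annotations" else save_path
    for name in names:
        print(os.path.join(extract_to, name))
# data/coco/images/train2017
# data/coco/images/val2017
# data/coco/images/test2017
# data/coco/annotations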
utils/get_dataset.py
CHANGED
@@ -1,83 +1,83 @@
 import os
 import zipfile
 
-import hydra
 import requests
+from hydra import main
 from loguru import logger
 from tqdm import tqdm
 
 
-def download_file(url, dest_path):
+def download_file(url, destination):
     """
-    Downloads a file from a given URL.
+    Downloads a file from the specified URL to the destination path with progress logging.
     """
-    logger.info(f"Downloading {os.path.basename(dest_path)}...")
-    with requests.get(url, stream=True) as r:
-        r.raise_for_status()
-        total_length = int(r.headers.get("content-length", 0))
-        with open(dest_path, "wb") as f, tqdm(
-            total=total_length, unit="iB", unit_scale=True, desc=os.path.basename(dest_path), leave=True
-        ) as bar:
-            for chunk in r.iter_content(chunk_size=1024 * 1024):
-                f.write(chunk)
-                bar.update(len(chunk))
-    logger.info("Download complete!")
+    logger.info(f"Downloading {os.path.basename(destination)}...")
+    with requests.get(url, stream=True) as response:
+        response.raise_for_status()
+        total_size = int(response.headers.get("content-length", 0))
+        progress = tqdm(total=total_size, unit="iB", unit_scale=True, desc=os.path.basename(destination), leave=True)
 
+        with open(destination, "wb") as file:
+            for data in response.iter_content(chunk_size=1024 * 1024):  # 1 MB chunks
+                file.write(data)
+                progress.update(len(data))
+        progress.close()
+    logger.info("Download completed.")
 
-def unzip_file(zip_path, extract_to):
+
+def unzip_file(source, destination):
     """
-    Unzips a ZIP file to a specified directory.
+    Extracts a ZIP file to the specified directory and removes the ZIP file after extraction.
     """
-    logger.info(f"Unzipping {os.path.basename(zip_path)}...")
-    with zipfile.ZipFile(zip_path, "r") as zip_ref:
-        zip_ref.extractall(extract_to)
-    os.remove(zip_path)
-    logger.info(f"Removed {zip_path}")
+    logger.info(f"Unzipping {os.path.basename(source)}...")
+    with zipfile.ZipFile(source, "r") as zip_ref:
+        zip_ref.extractall(destination)
+    os.remove(source)
+    logger.info(f"Removed {source}.")
 
 
-def check_files(directory, expected_files):
+def check_files(directory, expected_count=None):
     """
-    Checks if a directory contains the expected number of files.
+    Returns True if the number of files in the directory matches expected_count, False otherwise.
     """
-    files = os.listdir(directory)
-    return len(files) == expected_files
+    files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
+    return len(files) == expected_count if expected_count is not None else bool(files)
 
 
-@hydra.main(config_path="../config/data", config_name="download", version_base=None)
-def prepare_dataset(download_cfg):
-    data_dir = download_cfg.path
-    base_url = download_cfg.images.base_url
-    datasets = download_cfg.images.datasets
-
-    for dataset_type in datasets:
-        file_name, expected_files = datasets[dataset_type].values()
-        url = f"{base_url}{file_name}"
-        local_zip_path = os.path.join(data_dir, file_name)
-        extract_to = os.path.join(data_dir, dataset_type, "images")
-
-        # Ensure the extraction directory exists
-        os.makedirs(extract_to, exist_ok=True)
-
-        # Check if the correct number of files exists
-        if check_files(extract_to, expected_files):
-            logger.info(f"Dataset {dataset_type} already verified.")
-            continue
-
-        if not os.path.exists(local_zip_path):
-            download_file(url, local_zip_path)
-
-        unzip_file(local_zip_path, extract_to)
-
-        print(os.path.exists(local_zip_path), check_files(extract_to, expected_files))
+@main(config_path="../config/data", config_name="download", version_base=None)
+def prepare_dataset(cfg):
+    """
+    Prepares dataset by downloading and unzipping if necessary.
+    """
+    data_dir = cfg.save_path
+    for data_type, settings in cfg.datasets.items():
+        base_url = settings["base_url"]
+        for dataset_type, dataset_args in settings.items():
+            if dataset_type == "base_url":
+                continue  # Skip the base_url entry
+            file_name = f"{dataset_args.get('file_name', dataset_type)}.zip"
+            url = f"{base_url}{file_name}"
+            local_zip_path = os.path.join(data_dir, file_name)
+            extract_to = os.path.join(data_dir, data_type) if data_type != "annotations" else data_dir
+            final_place = os.path.join(extract_to, dataset_type)
 
+            os.makedirs(extract_to, exist_ok=True)
+            if check_files(final_place, dataset_args.get("file_num")):
+                logger.info(f"Dataset {dataset_type} already verified.")
+                continue
+
+            if not os.path.exists(local_zip_path):
+                download_file(url, local_zip_path)
+            unzip_file(local_zip_path, extract_to)
+
+            if not check_files(final_place, dataset_args.get("file_num")):
+                logger.error(f"Error verifying the {dataset_type} dataset after extraction.")
 
 
 if __name__ == "__main__":
+    import sys
+
+    sys.path.append("./")
    from tools.log_helper import custom_logger

    custom_logger()
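
With auto: True set in download.yaml, the script is meant to be run directly. Assuming the module-level entry point calls prepare_dataset() just past the end of this hunk, a typical invocation from the repository root would be:

python utils/get_dataset.py

Standard Hydra overrides then apply as usual, e.g. python utils/get_dataset.py save_path=data/coco.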