henry000 committed
Commit eff2849 · Parent: 23db031

πŸ› [Fix] Dataset autodownload bug, new dataset format

Files changed (3):
  1. config/config.py +19 -0
  2. config/data/download.yaml +16 -12
  3. utils/get_dataset.py +60 -60
config/config.py CHANGED
@@ -14,6 +14,25 @@ class Download:
     path: str
 
 
+@dataclass
+class Dataset:
+    file_name: str
+    file_num: int
+
+
+@dataclass
+class Datasets:
+    base_url: str
+    images: Dict[str, Dataset]
+
+
+@dataclass
+class Download:
+    auto: bool
+    save_path: str
+    datasets: Datasets
+
+
 @dataclass
 class Config:
     model: Model
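
The three dataclasses mirror the reworked YAML below: Dataset describes one downloadable archive, Datasets a group of archives sharing a base_url, and Download the top-level options. As a sanity check, a minimal sketch, not part of this commit, assuming omegaconf (a hydra dependency) is installed:

    from omegaconf import OmegaConf

    # Load the download config directly, outside hydra, and address it the
    # same way utils/get_dataset.py does after this commit.
    cfg = OmegaConf.load("config/data/download.yaml")
    print(cfg.save_path)                           # data/coco
    print(cfg.datasets.images.base_url)            # http://images.cocodataset.org/zips/
    print(cfg.datasets.images.train2017.file_num)  # 118287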
config/data/download.yaml CHANGED
@@ -1,17 +1,21 @@
 auto: True
-path: data/coco
-images:
-  base_url: http://images.cocodataset.org/zips/
-  datasets:
-    train:
-      file_name: train2017.zip
+save_path: data/coco
+datasets:
+  images:
+    base_url: http://images.cocodataset.org/zips/
+    train2017:
+      file_name: train2017
       file_num: 118287
-    val:
-      file_name: val2017.zip
-      num_files: 5000
-    test:
-      file_name: test2017.zip
-      num_files: 40670
+    val2017:
+      file_name: val2017
+      file_num: 5000
+    test2017:
+      file_name: test2017
+      file_num: 40670
+  annotations:
+    base_url: http://images.cocodataset.org/annotations/
+    annotations:
+      file_name: annotations_trainval2017
 hydra:
   run:
     dir: ./runs
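
Each group under datasets now carries its own base_url, so the download loop can derive every archive URL generically instead of hard-coding the images section. A short sketch, illustrative only and not part of the commit, of the URLs the new layout yields:

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("config/data/download.yaml")
    for group in cfg.datasets.values():
        for name, entry in group.items():
            if name == "base_url":
                continue  # group metadata, not a dataset entry
            # Every entry resolves to <base_url><file_name>.zip
            print(f"{group.base_url}{entry.file_name}.zip")

This prints the three image archives plus annotations_trainval2017.zip, matching what prepare_dataset downloads.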
utils/get_dataset.py CHANGED
@@ -1,83 +1,83 @@
 import os
 import zipfile
 
-import hydra
 import requests
+from hydra import main
 from loguru import logger
-from tqdm.rich import tqdm
+from tqdm import tqdm
 
 
-def download_file(url, dest_path):
+def download_file(url, destination):
     """
-    Downloads a file from a specified URL to a destination path with progress logging.
+    Downloads a file from the specified URL to the destination path with progress logging.
     """
-    logger.info(f"Downloading {os.path.basename(dest_path)}...")
-    with requests.get(url, stream=True) as r:
-        r.raise_for_status()
-        total_length = int(r.headers.get("content-length", 0))
-        with open(dest_path, "wb") as f, tqdm(
-            total=total_length, unit="iB", unit_scale=True, desc=os.path.basename(dest_path), leave=True
-        ) as bar:
-            for chunk in r.iter_content(chunk_size=1024 * 1024):
-                f.write(chunk)
-                bar.update(len(chunk))
-    logger.info("Download complete!")
+    logger.info(f"Downloading {os.path.basename(destination)}...")
+    with requests.get(url, stream=True) as response:
+        response.raise_for_status()
+        total_size = int(response.headers.get("content-length", 0))
+        progress = tqdm(total=total_size, unit="iB", unit_scale=True, desc=os.path.basename(destination), leave=True)
+
+        with open(destination, "wb") as file:
+            for data in response.iter_content(chunk_size=1024 * 1024):  # 1 MB chunks
+                file.write(data)
+                progress.update(len(data))
+    progress.close()
+    logger.info("Download completed.")
 
 
-def unzip_file(zip_path, extract_to):
+def unzip_file(source, destination):
     """
-    Unzips a ZIP file to a specified directory.
+    Extracts a ZIP file to the specified directory and removes the ZIP file after extraction.
     """
-    logger.info(f"Unzipping {os.path.basename(zip_path)}...")
-    with zipfile.ZipFile(zip_path, "r") as zip_ref:
-        zip_ref.extractall(extract_to)
-    os.remove(zip_path)
-    logger.info(f"Removed {zip_path}")
+    logger.info(f"Unzipping {os.path.basename(source)}...")
+    with zipfile.ZipFile(source, "r") as zip_ref:
+        zip_ref.extractall(destination)
+    os.remove(source)
+    logger.info(f"Removed {source}.")
 
 
-def check_files(directory, expected_count):
+def check_files(directory, expected_count=None):
     """
-    Checks if the specified directory has the expected number of files.
+    Returns True if the directory exists and its file count matches expected_count (or contains any files when expected_count is None).
     """
-    num_files = len([name for name in os.listdir(directory) if os.path.isfile(os.path.join(directory, name))])
-    return num_files == expected_count
+    files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))] if os.path.isdir(directory) else []
+    return len(files) == expected_count if expected_count is not None else bool(files)
 
 
-@hydra.main(config_path="../config/data", config_name="download", version_base=None)
-def prepare_dataset(download_cfg):
-    data_dir = download_cfg.path
-    base_url = download_cfg.images.base_url
-    datasets = download_cfg.images.datasets
-
-    for dataset_type in datasets:
-        file_name, expected_files = datasets[dataset_type].values()
-        url = f"{base_url}{file_name}"
-        local_zip_path = os.path.join(data_dir, file_name)
-        extract_to = os.path.join(data_dir, dataset_type, "images")
-
-        # Ensure the extraction directory exists
-        os.makedirs(extract_to, exist_ok=True)
-
-        # Check if the correct number of files exists
-        if check_files(extract_to, expected_files):
-            logger.info(f"Dataset {dataset_type} already verified.")
-            continue
-
-        if os.path.exists(local_zip_path):
-            logger.info(f"Dataset {dataset_type} already downloaded.")
-        else:
-            download_file(url, local_zip_path)
-
-        unzip_file(local_zip_path, extract_to)
-
-        print(os.path.exists(local_zip_path), check_files(extract_to, expected_files))
-
-        # Additional verification post extraction
-        if not check_files(extract_to, expected_files):
-            logger.error(f"Error in verifying the {dataset_type} dataset after extraction.")
+@main(config_path="../config/data", config_name="download", version_base=None)
+def prepare_dataset(cfg):
+    """
+    Prepares the dataset by downloading and unzipping it if necessary.
+    """
+    data_dir = cfg.save_path
+    for data_type, settings in cfg.datasets.items():
+        base_url = settings["base_url"]
+        for dataset_type, dataset_args in settings.items():
+            if dataset_type == "base_url":
+                continue  # Skip the base_url entry
+            file_name = f"{dataset_args.get('file_name', dataset_type)}.zip"
+            url = f"{base_url}{file_name}"
+            local_zip_path = os.path.join(data_dir, file_name)
+            extract_to = os.path.join(data_dir, data_type) if data_type != "annotations" else data_dir  # annotation zips already contain annotations/
+            final_place = os.path.join(extract_to, dataset_type)  # directory whose contents get verified
+
+            os.makedirs(extract_to, exist_ok=True)
+            if check_files(final_place, dataset_args.get("file_num")):
+                logger.info(f"Dataset {dataset_type} already verified.")
+                continue
+
+            if not os.path.exists(local_zip_path):
+                download_file(url, local_zip_path)
+            unzip_file(local_zip_path, extract_to)
+
+            if not check_files(final_place, dataset_args.get("file_num")):
+                logger.error(f"Error verifying the {dataset_type} dataset after extraction.")
 
 
 if __name__ == "__main__":
+    import sys
+
+    sys.path.append("./")
     from tools.log_helper import custom_logger
 
     custom_logger()
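
Since prepare_dataset is decorated with hydra's main, it picks up config/data/download.yaml automatically when the script is launched from the repository root, and any key can be overridden on the command line using standard hydra syntax, for example:

    python utils/get_dataset.py save_path=/tmp/coco

Entries without a file_num, such as annotations, make check_files fall back to verifying that the target directory simply contains files.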