henry000 commited on
Commit
230a441
·
1 Parent(s): 7330b76

✨ [Add] autodownload model wieght!

Browse files
yolo/model/yolo.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import Any, Dict, List, Union
2
 
3
  import torch
@@ -6,6 +7,7 @@ from loguru import logger
6
  from omegaconf import ListConfig, OmegaConf
7
 
8
  from yolo.config.config import Config, Model, YOLOLayer
 
9
  from yolo.tools.drawer import draw_model
10
  from yolo.utils.logging_utils import log_model_structure
11
  from yolo.utils.module_utils import get_layer_map
@@ -127,8 +129,13 @@ def get_model(cfg: Config) -> YOLO:
127
  model = YOLO(cfg.model, cfg.class_num)
128
  logger.info("✅ Success load model")
129
  if cfg.weight:
130
- model.model.load_state_dict(torch.load(cfg.weight))
131
- logger.info("✅ Success load model weight")
 
 
 
 
 
132
  log_model_structure(model.model)
133
  draw_model(model=model)
134
  return model
 
1
+ import os
2
  from typing import Any, Dict, List, Union
3
 
4
  import torch
 
7
  from omegaconf import ListConfig, OmegaConf
8
 
9
  from yolo.config.config import Config, Model, YOLOLayer
10
+ from yolo.tools.dataset_preparation import prepare_weight
11
  from yolo.tools.drawer import draw_model
12
  from yolo.utils.logging_utils import log_model_structure
13
  from yolo.utils.module_utils import get_layer_map
 
129
  model = YOLO(cfg.model, cfg.class_num)
130
  logger.info("✅ Success load model")
131
  if cfg.weight:
132
+ if os.path.exists(cfg.weight):
133
+ model.model.load_state_dict(torch.load(cfg.weight))
134
+ logger.info("✅ Success load model weight")
135
+ else:
136
+ logger.info(f"🌐 Weight {cfg.weight} not found, try downloading")
137
+ prepare_weight(weight_name=cfg.weight)
138
+
139
  log_model_structure(model.model)
140
  draw_model(model=model)
141
  return model
yolo/tools/dataset_preparation.py CHANGED
@@ -1,10 +1,10 @@
1
  import os
2
  import zipfile
 
3
 
4
  import requests
5
- from hydra import main
6
  from loguru import logger
7
- from tqdm import tqdm
8
 
9
  from yolo.config.config import DatasetConfig
10
 
@@ -13,18 +13,24 @@ def download_file(url, destination):
13
  """
14
  Downloads a file from the specified URL to the destination path with progress logging.
15
  """
16
- logger.info(f"Downloading {os.path.basename(destination)}...")
17
  with requests.get(url, stream=True) as response:
18
  response.raise_for_status()
19
  total_size = int(response.headers.get("content-length", 0))
20
- progress = tqdm(total=total_size, unit="iB", unit_scale=True, desc=os.path.basename(destination), leave=True)
21
-
22
- with open(destination, "wb") as file:
23
- for data in response.iter_content(chunk_size=1024 * 1024): # 1 MB chunks
24
- file.write(data)
25
- progress.update(len(data))
26
- progress.close()
27
- logger.info("Download completed.")
 
 
 
 
 
 
 
28
 
29
 
30
  def unzip_file(source, destination):
@@ -46,7 +52,6 @@ def check_files(directory, expected_count=None):
46
  return len(files) == expected_count if expected_count is not None else bool(files)
47
 
48
 
49
- @main(config_path="../config/data", config_name="download", version_base=None)
50
  def prepare_dataset(cfg: DatasetConfig):
51
  """
52
  Prepares dataset by downloading and unzipping if necessary.
@@ -76,6 +81,19 @@ def prepare_dataset(cfg: DatasetConfig):
76
  logger.error(f"Error verifying the {dataset_type} dataset after extraction.")
77
 
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  if __name__ == "__main__":
80
  import sys
81
 
@@ -83,4 +101,4 @@ if __name__ == "__main__":
83
  from utils.logging_utils import custom_logger
84
 
85
  custom_logger()
86
- prepare_dataset()
 
1
  import os
2
  import zipfile
3
+ from typing import Optional
4
 
5
  import requests
 
6
  from loguru import logger
7
+ from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
8
 
9
  from yolo.config.config import DatasetConfig
10
 
 
13
  """
14
  Downloads a file from the specified URL to the destination path with progress logging.
15
  """
 
16
  with requests.get(url, stream=True) as response:
17
  response.raise_for_status()
18
  total_size = int(response.headers.get("content-length", 0))
19
+ with Progress(
20
+ TextColumn("[progress.description]{task.description}"),
21
+ BarColumn(),
22
+ "[progress.percentage]{task.percentage:>3.1f}%",
23
+ "•",
24
+ "{task.completed}/{task.total} bytes",
25
+ "•",
26
+ TimeRemainingColumn(),
27
+ ) as progress:
28
+ task = progress.add_task(f"📥 Downloading {os.path.basename(destination)}...", total=total_size)
29
+ with open(destination, "wb") as file:
30
+ for data in response.iter_content(chunk_size=1024 * 1024): # 1 MB chunks
31
+ file.write(data)
32
+ progress.update(task, advance=len(data))
33
+ logger.info("✅ Download completed.")
34
 
35
 
36
  def unzip_file(source, destination):
 
52
  return len(files) == expected_count if expected_count is not None else bool(files)
53
 
54
 
 
55
  def prepare_dataset(cfg: DatasetConfig):
56
  """
57
  Prepares dataset by downloading and unzipping if necessary.
 
81
  logger.error(f"Error verifying the {dataset_type} dataset after extraction.")
82
 
83
 
84
+ def prepare_weight(downlaod_link: Optional[str] = None, weight_name: str = "v9-c.pt"):
85
+ if downlaod_link is None:
86
+ downlaod_link = "https://github.com/WongKinYiu/yolov9mit/releases/download/v1.0-alpha/"
87
+ weight_link = f"{downlaod_link}{weight_name}"
88
+
89
+ if os.path.exists(weight_name):
90
+ logger.info(f"Weight file '{weight_name}' already exists.")
91
+ try:
92
+ download_file(weight_link, weight_name)
93
+ except requests.exceptions.RequestException as e:
94
+ logger.warning(f"Failed to download the weight file: {e}")
95
+
96
+
97
  if __name__ == "__main__":
98
  import sys
99
 
 
101
  from utils.logging_utils import custom_logger
102
 
103
  custom_logger()
104
+ prepare_weight()