henry000 commited on
Commit
d5ba31a
Β·
1 Parent(s): ea725df

🚚 [Rename] tools and utils, move the function!

Browse files
model/yolo.py CHANGED
@@ -1,26 +1,10 @@
1
- import inspect
2
  from typing import Any, Dict, List, Union
3
 
4
  import torch
5
  import torch.nn as nn
6
  from loguru import logger
7
  from omegaconf import OmegaConf
8
-
9
- from model import module
10
- from utils.tools import load_model_cfg
11
-
12
-
13
- def get_layer_map():
14
- """
15
- Dynamically generates a dictionary mapping class names to classes,
16
- filtering to include only those that are subclasses of nn.Module,
17
- ensuring they are relevant neural network layers.
18
- """
19
- layer_map = {}
20
- for name, obj in inspect.getmembers(module, inspect.isclass):
21
- if issubclass(obj, nn.Module) and obj is not nn.Module:
22
- layer_map[name] = obj
23
- return layer_map
24
 
25
 
26
  class YOLO(nn.Module):
 
 
1
  from typing import Any, Dict, List, Union
2
 
3
  import torch
4
  import torch.nn as nn
5
  from loguru import logger
6
  from omegaconf import OmegaConf
7
+ from tools.layer_helper import get_layer_map
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
 
10
  class YOLO(nn.Module):
tools/layer_helper.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import torch.nn as nn
3
+ from model import module
4
+
5
+
6
+ def auto_pad():
7
+ raise NotImplementedError
8
+
9
+
10
+ def get_layer_map():
11
+ """
12
+ Dynamically generates a dictionary mapping class names to classes,
13
+ filtering to include only those that are subclasses of nn.Module,
14
+ ensuring they are relevant neural network layers.
15
+ """
16
+ layer_map = {}
17
+ for name, obj in inspect.getmembers(module, inspect.isclass):
18
+ if issubclass(obj, nn.Module) and obj is not nn.Module:
19
+ layer_map[name] = obj
20
+ return layer_map
tools/log_helper.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module for initializing logging tools used in machine learning and data processing.
3
+ Supports integration with Weights & Biases (wandb), Loguru, TensorBoard, and other
4
+ logging frameworks as needed.
5
+
6
+ This setup ensures consistent logging across various platforms, facilitating
7
+ effective monitoring and debugging.
8
+
9
+ Example:
10
+ from tools.logger import custom_logger
11
+ custom_logger()
12
+ """
13
+
14
+ import sys
15
+ from loguru import logger
16
+
17
+
18
+ def custom_logger():
19
+ logger.remove()
20
+ logger.add(
21
+ sys.stderr,
22
+ format="<green>{time:MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>",
23
+ )
train.py CHANGED
@@ -1,7 +1,6 @@
1
- import argparse
2
  from loguru import logger
3
  from model.yolo import get_model
4
- from utils.tools import load_model_cfg, custom_logger
5
  from utils.get_dataset import download_coco_dataset
6
  import hydra
7
  from config.config import Config
 
 
1
  from loguru import logger
2
  from model.yolo import get_model
3
+ from tools.log_helper import custom_logger
4
  from utils.get_dataset import download_coco_dataset
5
  import hydra
6
  from config.config import Config
utils/get_dataset.py CHANGED
@@ -78,7 +78,7 @@ def download_coco_dataset(download_cfg):
78
 
79
 
80
  if __name__ == "__main__":
81
- from tools import custom_logger
82
 
83
  custom_logger()
84
  download_coco_dataset()
 
78
 
79
 
80
  if __name__ == "__main__":
81
+ from tools.log_helper import custom_logger
82
 
83
  custom_logger()
84
  download_coco_dataset()
utils/tools.py DELETED
@@ -1,72 +0,0 @@
1
- import os
2
- import sys
3
- import yaml
4
- from loguru import logger
5
- from typing import Dict, Any
6
-
7
-
8
- def complete_path(file_name: str = "v7-base.yaml") -> str:
9
- """
10
- Ensures the path to a model configuration is a existing file
11
-
12
- Parameters:
13
- file_name (str): The filename or path, with default 'v7-base.yaml'.
14
-
15
- Returns:
16
- str: A complete path with necessary prefix and extension.
17
- """
18
- # Ensure the file has the '.yaml' extension if missing
19
- if not file_name.endswith(".yaml"):
20
- file_name += ".yaml"
21
-
22
- # Add folder prefix if only the filename is provided
23
- if os.path.dirname(file_name) == "":
24
- file_name = os.path.join("./config/model", file_name)
25
-
26
- return file_name
27
-
28
-
29
- def load_model_cfg(file_path: str) -> Dict[str, Any]:
30
- """
31
- Read a YAML configuration file, ensure necessary keys are present, and return its content as a dictionary.
32
-
33
- Args:
34
- file_path (str): The path to the YAML configuration file.
35
-
36
- Returns:
37
- Dict[str, Any]: The contents of the YAML file as a dictionary.
38
-
39
- Raises:
40
- FileNotFoundError: If the YAML file cannot be found.
41
- yaml.YAMLError: If there is an error parsing the YAML file.
42
- """
43
- file_path = complete_path(file_path)
44
- try:
45
- with open(file_path, "r") as file:
46
- model_cfg = yaml.safe_load(file) or {}
47
-
48
- # Check for required keys and set defaults if not present
49
- if "nc" not in model_cfg:
50
- model_cfg["nc"] = 80
51
- logger.warning("'nc' not found in the YAML file. Setting default 'nc' to 80.")
52
-
53
- if "model" not in model_cfg:
54
- logger.error("'model' is missing in the configuration file.")
55
- raise ValueError("Missing required key: 'model'")
56
-
57
- return model_cfg
58
-
59
- except FileNotFoundError:
60
- logger.error(f"YAML file not found: {file_path}")
61
- raise
62
- except yaml.YAMLError as e:
63
- logger.error(f"Error parsing YAML file: {e}")
64
- raise
65
-
66
-
67
- def custom_logger():
68
- logger.remove()
69
- logger.add(
70
- sys.stderr,
71
- format="<green>{time:MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>",
72
- )