File size: 7,574 Bytes
4775b4c fa09d11 183312f 306fc38 5556058 7d7e199 31cab2b afa32b4 230a441 0174b5b dcceddd b80dc1e 0a3c9de b80dc1e 0a3c9de 542860e 856cce6 f99f89b 97681c2 92614f8 802cb12 7c6ce21 200b5c1 802cb12 5556058 97681c2 5556058 92614f8 5556058 78e3679 97681c2 d44dbc0 5556058 f99f89b 97681c2 5556058 92614f8 97681c2 92614f8 97681c2 92614f8 97681c2 0f9ffa2 5556058 542860e 198ddce 2dd2ae5 198ddce 5556058 542860e 5727efb 198ddce 3186e72 2dd2ae5 3186e72 542860e 2ae070c 97681c2 5556058 2ae070c 542860e 2ae070c 97681c2 92614f8 5727efb 92614f8 5727efb 92614f8 97681c2 542860e 0f9ffa2 5556058 0f9ffa2 5556058 5727efb 542860e 97681c2 b80dc1e 4775b4c 802cb12 6972568 4775b4c 802cb12 4775b4c b80dc1e fa09d11 b80dc1e 40afe67 0a3c9de 40afe67 f5518c0 fa09d11 0ff0fd0 4775b4c 0ff0fd0 fa09d11 40afe67 fa09d11 4775b4c 802cb12 200b5c1 802cb12 7d7e199 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
from collections import OrderedDict
from pathlib import Path
from typing import Dict, List, Union
import torch
from omegaconf import ListConfig, OmegaConf
from torch import nn
from yolo.config.config import ModelConfig, YOLOLayer
from yolo.tools.dataset_preparation import prepare_weight
from yolo.utils.logger import logger
from yolo.utils.module_utils import get_layer_map
class YOLO(nn.Module):
    """
    A preliminary YOLO (You Only Look Once) model class still under development.

    The network is assembled layer by layer from ``model_cfg.model``. Layers are
    addressed by a 1-based index (index 0 is the input image) so that ``source``
    entries in the config can refer to earlier layers.

    Parameters:
        model_cfg: Configuration for the YOLO model. Expected to define the layers,
            parameters, and any other relevant configuration details
            (``model_cfg.model`` for the architecture, ``model_cfg.anchor`` for ``reg_max``).
        class_num: Number of classes passed to Detection/Segmentation/Classification layers.
    """

    def __init__(self, model_cfg: ModelConfig, class_num: int = 80):
        super().__init__()
        self.num_classes = class_num
        self.layer_map = get_layer_map()  # Get the map Dict[str: Module]
        self.model: nn.ModuleList = nn.ModuleList()
        # Falls back to 16 when the anchor config has no reg_max entry.
        self.reg_max = getattr(model_cfg.anchor, "reg_max", 16)
        self.build_model(model_cfg.model)

    def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
        """Instantiate every layer described in ``model_arch`` and append it to ``self.model``.

        Args:
            model_arch: Mapping of section name to a list of single-entry
                ``{layer_type: layer_info}`` dicts. ``layer_info`` may carry
                ``args``, ``source``, ``tags`` and ``output``.

        Raises:
            ValueError: On a duplicate tag, an unknown layer type, or a relative
                source offset pointing before the input image.
            KeyError: If a source tag was not declared by an earlier layer.
        """
        self.layer_index = {}  # tag -> 1-based layer index
        # output_dim[i] is the channel count of layer i's output; entry 0 is the 3-channel input.
        output_dim = [3]
        logger.info(f":tractor: Building YOLO")
        for arch_name, arch_layers in model_arch.items():
            if not arch_layers:
                # Empty (or None) sections contribute no layers and must not shift the indices.
                continue
            logger.info(f" :building_construction: Building {arch_name}")
            for layer_spec in arch_layers:
                # Derive the index from the layers built so far so it always matches this
                # layer's position in ``self.model`` (offset by 1) and in ``output_dim``.
                layer_idx = len(self.model) + 1
                layer_type, layer_info = next(iter(layer_spec.items()))
                layer_args = layer_info.get("args", {})

                # Get input source
                source = self.get_source_idx(layer_info.get("source", -1), layer_idx)

                # Find in channels
                if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv", "CBLinear"]):
                    layer_args["in_channels"] = output_dim[source]
                if any(module in layer_type for module in ["Detection", "Segmentation", "Classification"]):
                    if isinstance(source, list):
                        layer_args["in_channels"] = [output_dim[idx] for idx in source]
                    else:
                        layer_args["in_channel"] = output_dim[source]
                    layer_args["num_classes"] = self.num_classes
                    layer_args["reg_max"] = self.reg_max

                # create layers
                layer = self.create_layer(layer_type, source, layer_info, **layer_args)
                self.model.append(layer)

                if layer.tags:
                    if layer.tags in self.layer_index:
                        raise ValueError(f"Duplicate tag '{layer_info['tags']}' found.")
                    self.layer_index[layer.tags] = layer_idx

                out_channels = self.get_out_channels(layer_type, layer_args, output_dim, source)
                output_dim.append(out_channels)
                setattr(layer, "out_c", out_channels)

    def forward(self, x):
        """Run all layers in order and collect the outputs of layers flagged ``output``.

        Returns:
            dict: Mapping from each output layer's ``tags`` value to its output.
        """
        # y[0] is the input; y[-1] always holds the most recent output. Seeding y[-1] with the
        # input lets a first layer using the default source (-1) read the image, matching
        # build_model, where output_dim[-1] is the input's channel count for that layer.
        y = {0: x, -1: x}
        output = dict()
        for index, layer in enumerate(self.model, start=1):
            if isinstance(layer.source, list):
                model_input = [y[idx] for idx in layer.source]
            else:
                model_input = y[layer.source]
            x = layer(model_input)
            y[-1] = x
            # Only retain outputs that some later layer references by absolute index.
            if layer.usable:
                y[index] = x
            if layer.output:
                output[layer.tags] = x
        return output

    def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
        """Return the number of output channels of a layer.

        Priority: an explicit ``out_channels`` arg; for ``CBFuse`` the channels of its
        last source; otherwise the channels of a single source, or the sum of the
        channels over a list of sources.
        """
        # A membership test works for plain dicts and OmegaConf DictConfig alike;
        # ``hasattr`` never sees the keys of a plain dict.
        if "out_channels" in layer_args:
            return layer_args["out_channels"]
        if layer_type == "CBFuse":
            return output_dim[source[-1]]
        if isinstance(source, int):
            return output_dim[source]
        if isinstance(source, list):
            return sum(output_dim[idx] for idx in source)

    def get_source_idx(self, source: Union[ListConfig, list, tuple, str, int], layer_idx: int):
        """Resolve a config ``source`` into the layer index (or list of indices) to read from.

        Args:
            source: ``-1`` (previous output), ``0`` (input image), a positive absolute
                index, a relative offset below ``-1``, a tag registered in
                ``self.layer_index``, or a list/tuple/ListConfig of any of these.
            layer_idx: 1-based index of the layer currently being built.

        Returns:
            int or list: The resolved index, or a list of them; ``-1`` is kept as-is.

        Raises:
            KeyError: If a tag has not been registered by an earlier layer.
            ValueError: If a relative offset points before the input image.
        """
        if isinstance(source, (ListConfig, list, tuple)):
            return [self.get_source_idx(index, layer_idx) for index in source]
        if isinstance(source, str):
            source = self.layer_index[source]
        if source < -1:
            # Relative offset: -2 is two layers before the current one, and so on.
            source += layer_idx
            if source < 0:
                # Would otherwise wrap around via Python's negative indexing.
                raise ValueError(f"Source offset of layer {layer_idx} points before the input image")
        if source > 0:  # Using Previous Layer's Output
            # Mark the referenced layer so forward() keeps its output in ``y``.
            self.model[source - 1].usable = True
        return source

    def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> YOLOLayer:
        """Instantiate ``layer_type`` from ``self.layer_map`` and attach its routing metadata.

        Raises:
            ValueError: If ``layer_type`` is not present in the layer map.
        """
        if layer_type not in self.layer_map:
            raise ValueError(f"Unsupported layer type: {layer_type}")
        layer = self.layer_map[layer_type](**kwargs)
        setattr(layer, "layer_type", layer_type)
        setattr(layer, "source", source)
        setattr(layer, "in_c", kwargs.get("in_channels", None))
        setattr(layer, "output", layer_info.get("output", False))
        setattr(layer, "tags", layer_info.get("tags", None))
        # Set to True by get_source_idx when a later layer references this one.
        setattr(layer, "usable", 0)
        return layer

    def save_load_weights(self, weights: Union[Path, str, OrderedDict]):
        """
        Update the model's weights with the provided weights.

        Entries missing from ``weights``, or whose shape differs from the model's,
        keep the model's current values. Each affected key is logged as a warning,
        with its last two dot-separated parts dropped.

        args:
            weights: An OrderedDict state dict, a checkpoint dict holding it under
                "model_state_dict", or a path (Path or str) to a file readable by torch.load.
        """
        if isinstance(weights, (str, Path)):
            # NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
            # only load checkpoints from trusted sources.
            weights = torch.load(weights, map_location=torch.device("cpu"), weights_only=False)
        if "model_state_dict" in weights:
            weights = weights["model_state_dict"]
        model_state_dict = self.model.state_dict()

        # TODO1: autoload old version weight
        # TODO2: weight transform if num_class difference

        error_dict = {"Mismatch": set(), "Not Found": set()}
        for model_key, model_weight in model_state_dict.items():
            if model_key not in weights:
                error_dict["Not Found"].add(tuple(model_key.split(".")[:-2]))
                continue
            if model_weight.shape != weights[model_key].shape:
                error_dict["Mismatch"].add(tuple(model_key.split(".")[:-2]))
                continue
            model_state_dict[model_key] = weights[model_key]

        for error_name, error_set in error_dict.items():
            for weight_name in error_set:
                logger.warning(f":warning: Weight {error_name} for key: {'.'.join(weight_name)}")

        self.model.load_state_dict(model_state_dict)
def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80) -> YOLO:
    """Construct a YOLO model from its configuration and optionally load pretrained weights.

    Args:
        model_cfg (ModelConfig): Model configuration. Its struct flag is cleared via
            ``OmegaConf.set_struct`` so layer args can be filled in while building.
        weight_path (Union[bool, Path, str]): ``True`` loads ``weights/<model_cfg.name>.pt``.
            A ``Path`` or ``str`` loads that file. A falsy value skips weight loading.
            If the file does not exist, ``prepare_weight`` is called to try to fetch it.
        class_num (int): Number of classes for the model's prediction layers.

    Returns:
        YOLO: An instance of the model defined by the given configuration.
    """
    OmegaConf.set_struct(model_cfg, False)
    model = YOLO(model_cfg, class_num)
    if not weight_path:
        logger.info(":white_check_mark: Success load model")
        return model

    if weight_path is True:
        weight_path = Path("weights") / f"{model_cfg.name}.pt"
    elif isinstance(weight_path, str):
        weight_path = Path(weight_path)

    if not weight_path.exists():
        logger.info(f"🌐 Weight {weight_path} not found, try downloading")
        prepare_weight(weight_path=weight_path)
    if weight_path.exists():
        model.save_load_weights(weight_path)
        logger.info(":white_check_mark: Success load model & weight")
    else:
        # Previously this case was silent; make it visible that no weights were loaded.
        logger.warning(f":warning: Weight {weight_path} unavailable, model weights were not loaded")
    return model
|