|
from collections import OrderedDict |
|
from pathlib import Path |
|
from typing import Dict, List, Union |
|
|
|
import torch |
|
from omegaconf import ListConfig, OmegaConf |
|
from torch import nn |
|
|
|
from yolo.config.config import ModelConfig, YOLOLayer |
|
from yolo.tools.dataset_preparation import prepare_weight |
|
from yolo.utils.logger import logger |
|
from yolo.utils.module_utils import get_layer_map |
|
|
|
|
|
class YOLO(nn.Module):
    """
    A preliminary YOLO (You Only Look Once) model class still under development.

    The network is assembled layer by layer from the architecture description in
    ``model_cfg.model``. Layers are numbered from 1; index 0 stands for the input
    image. A layer reads its input from ``source``, which can be:

    - ``-1`` (the default): the previous layer's output.
    - ``0``: the input image.
    - a positive int: an absolute layer index.
    - an int ``< -1``: an offset relative to the current layer.
    - a string: the tag of an earlier layer.
    - a list of any of the above.

    Parameters:
        model_cfg: Configuration for the YOLO model. Expected to define the layers,
            parameters, and any other relevant configuration details.
        class_num: Number of classes passed to Detection / Segmentation /
            Classification layers.
    """

    def __init__(self, model_cfg: ModelConfig, class_num: int = 80):
        super().__init__()
        self.num_classes = class_num
        # Maps a layer-type name used in the config to the callable that builds it.
        self.layer_map = get_layer_map()
        self.model: List[YOLOLayer] = nn.ModuleList()
        self.reg_max = getattr(model_cfg.anchor, "reg_max", 16)
        self.build_model(model_cfg.model)

    def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
        """Instantiate every layer described by ``model_arch`` and append it to ``self.model``.

        Args:
            model_arch: Mapping of section name to a list of single-entry
                ``{layer_type: layer_info}`` mappings. Sections are processed in
                iteration order. ``layer_info`` may contain ``args``, ``source``,
                ``tags`` and ``output``. Empty or null sections are skipped.

        Raises:
            ValueError: If two layers share a tag, or a layer type is unsupported.
        """
        self.layer_index = {}
        # output_dim[i] is the channel count produced by layer i; slot 0 is the
        # 3-channel input image.
        output_dim = [3]
        logger.info(f":tractor: Building YOLO")
        for arch_name in model_arch:
            layer_specs = model_arch[arch_name]
            if not layer_specs:
                # Empty/null sections add no layers; skipping them keeps layer
                # indices aligned with positions in self.model.
                continue
            logger.info(f" :building_construction: Building {arch_name}")
            for layer_spec in layer_specs:
                # 1-based index of the layer about to be created.
                layer_idx = len(self.model) + 1
                layer_type, layer_info = next(iter(layer_spec.items()))
                layer_args = layer_info.get("args", {})

                # Resolve which earlier output(s) feed this layer.
                source = self.get_source_idx(layer_info.get("source", -1), layer_idx)

                # Infer input channels from the producing layer(s).
                if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv", "CBLinear"]):
                    layer_args["in_channels"] = output_dim[source]
                if any(module in layer_type for module in ["Detection", "Segmentation", "Classification"]):
                    if isinstance(source, list):
                        layer_args["in_channels"] = [output_dim[idx] for idx in source]
                    else:
                        layer_args["in_channel"] = output_dim[source]
                    layer_args["num_classes"] = self.num_classes
                    layer_args["reg_max"] = self.reg_max

                layer = self.create_layer(layer_type, source, layer_info, **layer_args)
                self.model.append(layer)

                if layer.tags:
                    if layer.tags in self.layer_index:
                        raise ValueError(f"Duplicate tag '{layer_info['tags']}' found.")
                    self.layer_index[layer.tags] = layer_idx

                out_channels = self.get_out_channels(layer_type, layer_args, output_dim, source)
                output_dim.append(out_channels)
                setattr(layer, "out_c", out_channels)

    def forward(self, x):
        """Run ``x`` through every layer in order.

        Returns:
            dict mapping the tag of each layer flagged ``output`` to that layer's result.
        """
        # Slot 0 is the input. Slot -1 always holds the most recent result; it is
        # seeded with the input so a first layer using the default source (-1)
        # reads the image. This matches build_model, where output_dim[-1] == 3
        # for that layer.
        y = {0: x, -1: x}
        output = dict()
        for index, layer in enumerate(self.model, start=1):
            if isinstance(layer.source, list):
                model_input = [y[idx] for idx in layer.source]
            else:
                model_input = y[layer.source]
            x = layer(model_input)
            y[-1] = x
            # Only keep results that a later layer references by absolute index.
            if layer.usable:
                y[index] = x
            if layer.output:
                output[layer.tags] = x
        return output

    def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
        """Return the number of channels a layer produces.

        The rules are applied in this order:

        1. An explicit ``out_channels`` argument wins.
        2. ``CBFuse`` keeps the channels of its last source.
        3. A single int source passes its channels through.
        4. A list of sources sums their channels.

        Returns ``None`` if ``source`` is neither an int nor a list.
        """
        # Key membership works for plain dicts (e.g. the ``{}`` default) and for
        # OmegaConf DictConfig alike; hasattr() never sees plain dict keys.
        if "out_channels" in layer_args:
            return layer_args["out_channels"]
        if layer_type == "CBFuse":
            return output_dim[source[-1]]
        if isinstance(source, int):
            return output_dim[source]
        if isinstance(source, list):
            return sum(output_dim[idx] for idx in source)
        return None

    def get_source_idx(self, source: Union[ListConfig, str, int], layer_idx: int):
        """Resolve a ``source`` spec into absolute layer index/indices.

        Resolution rules:

        - Sequences (list, tuple, ListConfig) are resolved element-wise into a list.
        - A string is looked up as the tag of an already-built layer.
        - An int ``< -1`` is treated as an offset from ``layer_idx``.
        - ``-1`` and ``0`` are returned unchanged.

        Any resolved index ``> 0`` marks that layer ``usable`` so ``forward`` keeps
        its output.

        Raises:
            KeyError: If a string source names a tag not defined by an earlier layer.
        """
        if isinstance(source, (list, tuple, ListConfig)):
            return [self.get_source_idx(index, layer_idx) for index in source]
        if isinstance(source, str):
            try:
                source = self.layer_index[source]
            except KeyError:
                raise KeyError(
                    f"Unknown source tag '{source}'; a tag must be defined on an earlier layer."
                ) from None
        if source < -1:
            source += layer_idx
        if source > 0:
            self.model[source - 1].usable = True
        return source

    def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> YOLOLayer:
        """Build ``layer_type`` with ``kwargs`` and attach the bookkeeping attributes ``forward`` uses.

        Raises:
            ValueError: If ``layer_type`` is not present in the layer map.
        """
        if layer_type not in self.layer_map:
            raise ValueError(f"Unsupported layer type: {layer_type}")
        layer = self.layer_map[layer_type](**kwargs)
        setattr(layer, "layer_type", layer_type)
        setattr(layer, "source", source)
        setattr(layer, "in_c", kwargs.get("in_channels", None))
        setattr(layer, "output", layer_info.get("output", False))
        setattr(layer, "tags", layer_info.get("tags", None))
        # Set to True by get_source_idx once a later layer references this one.
        setattr(layer, "usable", False)
        return layer

    def save_load_weights(self, weights: Union[Path, str, OrderedDict]):
        """
        Update the model's weights with the provided weights.

        Entries whose key is absent from ``weights``, or whose shape differs, keep
        their current values. They are reported as warnings, keyed by the parameter
        name without its last two dot-separated components.

        args:
            weights: A checkpoint path (``Path`` or ``str``) or an OrderedDict state
                dict. A checkpoint that nests the state dict under
                ``"model_state_dict"`` is unwrapped.
        """
        if isinstance(weights, (str, Path)):
            # NOTE(review): weights_only=False unpickles arbitrary objects; only load
            # checkpoints from trusted sources.
            weights = torch.load(weights, map_location=torch.device("cpu"), weights_only=False)
        if "model_state_dict" in weights:
            weights = weights["model_state_dict"]

        model_state_dict = self.model.state_dict()

        error_dict = {"Mismatch": set(), "Not Found": set()}
        for model_key, model_weight in model_state_dict.items():
            if model_key not in weights:
                error_dict["Not Found"].add(tuple(model_key.split(".")[:-2]))
                continue
            if model_weight.shape != weights[model_key].shape:
                error_dict["Mismatch"].add(tuple(model_key.split(".")[:-2]))
                continue
            # Replacing the value of an existing key does not resize the dict, so
            # this is safe while iterating.
            model_state_dict[model_key] = weights[model_key]

        for error_name, error_set in error_dict.items():
            for weight_name in error_set:
                logger.warning(f":warning: Weight {error_name} for key: {'.'.join(weight_name)}")

        self.model.load_state_dict(model_state_dict)
|
|
|
|
|
def create_model(model_cfg: ModelConfig, weight_path: Union[bool, str, Path] = True, class_num: int = 80) -> YOLO:
    """Construct a YOLO model from its configuration and optionally load weights.

    Args:
        model_cfg: The model configuration. Its struct mode is switched off in
            place, presumably so that building can add new keys to layer args
            (e.g. ``in_channels``).
        weight_path: Controls weight loading:

            - ``True``: use ``weights/<model_cfg.name>.pt``.
            - a ``Path`` or ``str``: use that checkpoint.
            - a falsy value: skip weight loading.

            If the file is missing, ``prepare_weight`` is asked to fetch it. If
            it still does not exist afterwards, a warning is logged and the model
            keeps its initial weights.
        class_num: Number of classes for the prediction layers.

    Returns:
        YOLO: An instance of the model defined by the given configuration.
    """
    OmegaConf.set_struct(model_cfg, False)
    model = YOLO(model_cfg, class_num)

    if not weight_path:
        logger.info(":white_check_mark: Success load model")
        return model

    if weight_path is True:
        weight_path = Path("weights") / f"{model_cfg.name}.pt"
    elif isinstance(weight_path, str):
        weight_path = Path(weight_path)

    if not weight_path.exists():
        # NOTE(review): the leading "π" looks like a mis-encoded emoji; confirm the intended glyph.
        logger.info(f"π Weight {weight_path} not found, try downloading")
        prepare_weight(weight_path=weight_path)

    if weight_path.exists():
        model.save_load_weights(weight_path)
        logger.info(":white_check_mark: Success load model & weight")
    else:
        # Previously this path was silent, hiding that no pretrained weights were applied.
        logger.warning(f":warning: Weight {weight_path} unavailable; model keeps its initial weights")
    return model
|
|