|
import os |
|
from typing import Dict, List, Optional, Union |
|
|
|
import torch |
|
from loguru import logger |
|
from omegaconf import ListConfig, OmegaConf |
|
from torch import nn |
|
|
|
from yolo.config.config import ModelConfig, YOLOLayer |
|
from yolo.tools.dataset_preparation import prepare_weight |
|
from yolo.utils.module_utils import get_layer_map |
|
|
|
|
|
class YOLO(nn.Module):
    """
    A preliminary YOLO (You Only Look Once) model class still under development.

    The network is assembled from a configuration mapping. Layers are numbered
    from 1 in build order; index 0 refers to the network input and index -1 to
    the output of the immediately preceding layer.

    Parameters:
        model_cfg: Configuration for the YOLO model. Expected to define the layers,
                   parameters, and any other relevant configuration details.
        class_num: Number of classes passed to detection layers (default 80).
    """

    def __init__(self, model_cfg: ModelConfig, class_num: int = 80):
        super().__init__()
        self.num_classes = class_num
        self.layer_map = get_layer_map()
        self.model: List[YOLOLayer] = nn.ModuleList()
        # Fall back to defaults when the anchor config does not define these fields.
        self.reg_max = getattr(model_cfg.anchor, "reg_max", 16)
        self.strides = getattr(model_cfg.anchor, "strides", None)
        self.build_model(model_cfg.model)

    def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
        """Instantiate every layer described in ``model_arch`` and append it to ``self.model``.

        Args:
            model_arch: Mapping of architecture section name to a list of layer
                specs. Each spec is a one-entry mapping ``{layer_type: layer_info}``
                where ``layer_info`` may carry ``args``, ``source``, ``tags`` and
                ``output``. Empty or missing sections are skipped.

        Raises:
            ValueError: If two layers declare the same tag, or a layer type is
                not present in ``self.layer_map``.
        """
        self.layer_index = {}
        # output_dim[i] holds the channel count produced by layer i; it is seeded
        # with 3 for index 0 (the network input).
        output_dim = [3]
        logger.info("π Building YOLO")
        for arch_name in model_arch:
            layer_specs = model_arch[arch_name]
            if not layer_specs:
                # Nothing to build; skipping keeps layer numbering contiguous.
                continue
            logger.info(f" ποΈ Building {arch_name}")
            for layer_spec in layer_specs:
                # Derive the 1-based index from the layers built so far so it always
                # matches `enumerate(self.model, start=1)` used in `forward`.
                layer_idx = len(self.model) + 1
                layer_type, layer_info = next(iter(layer_spec.items()))
                layer_args = layer_info.get("args", {})

                # Resolve the input source(s) to absolute layer indices.
                source = self.get_source_idx(layer_info.get("source", -1), layer_idx)

                # Inject the input channel count(s) from the resolved source(s).
                if any(module in layer_type for module in ("Conv", "ELAN", "ADown", "AConv", "CBLinear")):
                    layer_args["in_channels"] = output_dim[source]
                if "Detection" in layer_type:
                    layer_args["in_channels"] = [output_dim[idx] for idx in source]
                    layer_args["num_classes"] = self.num_classes
                    layer_args["reg_max"] = self.reg_max

                layer = self.create_layer(layer_type, source, layer_info, **layer_args)
                self.model.append(layer)

                if layer.tags:
                    if layer.tags in self.layer_index:
                        raise ValueError(f"Duplicate tag '{layer_info['tags']}' found.")
                    self.layer_index[layer.tags] = layer_idx

                out_channels = self.get_out_channels(layer_type, layer_args, output_dim, source)
                output_dim.append(out_channels)
                layer.out_c = out_channels

    def forward(self, x):
        """Run ``x`` through every layer in order.

        Returns:
            dict: Maps the tag of each layer flagged ``output`` to that layer's result.
        """
        # Cache of reusable activations: key 0 is the input, key -1 is always the
        # latest result, and key i is kept only for layers marked `usable`.
        y = {0: x}
        output = {}
        for index, layer in enumerate(self.model, start=1):
            if isinstance(layer.source, list):
                model_input = [y[idx] for idx in layer.source]
            else:
                model_input = y[layer.source]
            x = layer(model_input)
            y[-1] = x
            if layer.usable:
                y[index] = x
            if layer.output:
                output[layer.tags] = x
        return output

    def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
        """Return the output channel count of a layer, or ``None`` when it is not tracked.

        Args:
            layer_type: Name of the layer type being built.
            layer_args: Constructor arguments of the layer.
            output_dim: Output channel counts of all layers built so far (index 0 is the input).
            source: Resolved source index, or list of indices, feeding the layer.

        Returns:
            The channel count, or ``None`` for ``IDetect`` and any layer type not handled here.
        """
        if any(module in layer_type for module in ("Conv", "ELAN", "ADown", "AConv")):
            return layer_args["out_channels"]
        if layer_type == "CBFuse":
            return output_dim[source[-1]]
        if layer_type in ("Pool", "UpSample"):
            return output_dim[source]
        if layer_type == "Concat":
            return sum(output_dim[idx] for idx in source)
        # "IDetect" and every unlisted layer type have no tracked channel count.
        return None

    def get_source_idx(self, source: Union[ListConfig, str, int], layer_idx: int):
        """Resolve a source reference into a layer index usable by ``forward``.

        Args:
            source: A tag name, an integer index, or a sequence (``ListConfig``,
                ``list`` or ``tuple``) of those.
            layer_idx: 1-based index of the layer that consumes the source.

        Returns:
            An int, or a list of ints when ``source`` is a sequence. ``-1`` and
            ``0`` are returned unchanged.

        Raises:
            KeyError: If ``source`` is a tag that has not been registered yet.
        """
        if isinstance(source, (ListConfig, list, tuple)):
            return [self.get_source_idx(index, layer_idx) for index in source]
        if isinstance(source, str):
            source = self.layer_index[source]
        if source < -1:
            # Relative reference: count backwards from the consuming layer.
            source += layer_idx
        if source > 0:
            # The referenced layer's output must be retained during `forward`.
            self.model[source - 1].usable = True
        return source

    def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> YOLOLayer:
        """Instantiate ``layer_type`` with ``kwargs`` and attach routing metadata to it.

        Args:
            layer_type: Key into ``self.layer_map``.
            source: Resolved source index/indices stored on the layer.
            layer_info: Layer spec; ``output`` and ``tags`` are read from it.
            **kwargs: Constructor arguments for the layer class.

        Raises:
            ValueError: If ``layer_type`` is not present in ``self.layer_map``.
        """
        if layer_type not in self.layer_map:
            raise ValueError(f"Unsupported layer type: {layer_type}")

        layer = self.layer_map[layer_type](**kwargs)
        layer.layer_type = layer_type
        layer.source = source
        layer.in_c = kwargs.get("in_channels", None)
        layer.output = layer_info.get("output", False)
        layer.tags = layer_info.get("tags", None)
        # Set to True by `get_source_idx` once a later layer references this one.
        layer.usable = False
        return layer
|
|
|
|
|
def create_model(model_cfg: ModelConfig, weight_path: Union[bool, str] = True, class_num: int = 80) -> YOLO:
    """Construct a YOLO model from its configuration and optionally load weights.

    Args:
        model_cfg: Model configuration. Struct mode is switched off on it so
            that keys can be added while the model is built.
        weight_path: ``True`` to use ``weights/<model_cfg.name>.pt``, a string to
            use that path, or a falsy value to skip weight loading.
        class_num: Number of classes passed to the model (default 80).

    Returns:
        YOLO: An instance of the model defined by the given configuration.
    """
    OmegaConf.set_struct(model_cfg, False)
    model = YOLO(model_cfg, class_num)

    if not weight_path:
        logger.info("✅ Success load model")
        return model

    if weight_path is True:
        weight_path = os.path.join("weights", f"{model_cfg.name}.pt")

    if not os.path.exists(weight_path):
        logger.info(f"π Weight {weight_path} not found, try downloading")
        prepare_weight(weight_path=weight_path)

    if os.path.exists(weight_path):
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        # TODO: map_location is fixed to CPU here.
        state_dict = torch.load(weight_path, map_location=torch.device("cpu"))
        # strict=False tolerates keys that do not match the layer list.
        model.model.load_state_dict(state_dict, strict=False)
        logger.info("✅ Success load model & weight")
    else:
        # Do not fail silently: the caller would otherwise get a model without
        # the requested weights and no indication of it.
        logger.warning(f"Weight {weight_path} is still missing; model returned without loaded weights")
    return model
|
|