File size: 5,129 Bytes
183312f b80dc1e 5556058 31cab2b 856cce6 1197f7d 0f9ffa2 b80dc1e 1851849 b80dc1e 1851849 542860e 856cce6 7c6ce21 0f9ffa2 97681c2 92614f8 c14cffe 7c6ce21 c14cffe 5556058 97681c2 5556058 92614f8 5556058 97681c2 5556058 97681c2 5556058 92614f8 97681c2 92614f8 97681c2 92614f8 97681c2 0f9ffa2 5556058 542860e 198ddce 3186e72 198ddce 5556058 542860e 198ddce 3186e72 542860e 5556058 97681c2 5556058 b67aac7 542860e 97681c2 92614f8 97681c2 542860e 0f9ffa2 5556058 0f9ffa2 5556058 542860e 97681c2 b80dc1e 1851849 b80dc1e 1851849 542860e b80dc1e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
from typing import Any, Dict, List, Union
import torch.nn as nn
from loguru import logger
from omegaconf import ListConfig, OmegaConf
from yolo.config.config import Config, Model, YOLOLayer
from yolo.tools.layer_helper import get_layer_map
from yolo.tools.log_helper import log_model
class YOLO(nn.Module):
    """
    A preliminary YOLO (You Only Look Once) model class still under development.

    The network is assembled layer by layer from a configuration mapping. Layers are
    numbered from 1 in build order; index 0 denotes the model input.

    Parameters:
        model_cfg: Configuration for the YOLO model. Its ``model`` attribute is handed to
            ``build_model`` and is expected to define the layers, parameters, and any
            other relevant configuration details.
        num_classes: Number of classes, passed to every layer whose type name contains
            "Detection".
    """

    def __init__(self, model_cfg: "Model", num_classes: int):
        super().__init__()
        self.num_classes = num_classes
        self.layer_map = get_layer_map()  # Get the map Dict[str: Module]
        self.model: List[YOLOLayer] = nn.ModuleList()
        self.build_model(model_cfg.model)
        log_model(self.model)

    def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
        """Instantiate every layer described by ``model_arch`` and append it to ``self.model``.

        Args:
            model_arch: Mapping of section name to a list of layer specs. Each spec is a
                single-entry mapping ``{layer_type: layer_info}``; ``layer_info`` may hold
                ``args``, ``source``, ``tags`` and ``output``.

        Raises:
            ValueError: If two layers carry the same tag or a layer type is not in
                ``self.layer_map``.
        """
        self.layer_index = {}
        # output_dim[i] holds the channel count produced by layer i; entry 0 stands for
        # the model input and is hard-coded to 3.
        output_dim = [3]
        # 1-based index of the layer being built. It advances once per layer actually
        # created, so an empty section cannot shift the numbering of later layers.
        layer_idx = 1
        logger.info("🚜 Building YOLO")
        for arch_name in model_arch:
            logger.info(f" 🏗️ Building {arch_name}")
            for layer_spec in model_arch[arch_name] or ():
                layer_type, layer_info = next(iter(layer_spec.items()))
                # NOTE: when the spec has an "args" mapping, the keys added below are
                # written into that same mapping (no copy is made).
                layer_args = layer_info.get("args", {})

                # Get input source
                source = self.get_source_idx(layer_info.get("source", -1), layer_idx)

                # Find in channels
                if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "CBLinear"]):
                    layer_args["in_channels"] = output_dim[source]
                if "Detection" in layer_type:
                    layer_args["in_channels"] = [output_dim[idx] for idx in source]
                    layer_args["num_classes"] = self.num_classes

                # create layers
                layer = self.create_layer(layer_type, source, layer_info, **layer_args)
                self.model.append(layer)

                if layer.tags:
                    if layer.tags in self.layer_index:
                        raise ValueError(f"Duplicate tag '{layer_info['tags']}' found.")
                    self.layer_index[layer.tags] = layer_idx

                out_channels = self.get_out_channels(layer_type, layer_args, output_dim, source)
                output_dim.append(out_channels)
                layer.out_c = out_channels
                layer_idx += 1

    def forward(self, x):
        """Run ``x`` through every layer in build order.

        Each layer reads its input(s) from the outputs recorded under its ``source``
        index or indices; index 0 is ``x`` itself. Only outputs of layers that have a
        ``save`` attribute are recorded for later layers.

        Returns:
            list: Outputs of the layers whose ``output`` attribute is truthy, in build order.
        """
        saved = {0: x}
        outputs = []
        for index, layer in enumerate(self.model, start=1):
            if isinstance(layer.source, list):
                layer_input = [saved[idx] for idx in layer.source]
            else:
                layer_input = saved[layer.source]
            x = layer(layer_input)
            if hasattr(layer, "save"):
                saved[index] = x
            if layer.output:
                outputs.append(x)
        return outputs

    def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
        """Return the channel count a layer produces, or ``None`` when none is derived.

        Args:
            layer_type: Layer type name from the configuration.
            layer_args: Arguments of the layer; ``out_channels`` is read for
                Conv/ELAN/ADown-like types.
            output_dim: Channel counts of all previous layers, indexed by layer index.
            source: Resolved source index, or list of indices, of the layer.

        Returns:
            The channel count, or ``None`` for "IDetect" and any type not handled here.
        """
        if any(module in layer_type for module in ["Conv", "ELAN", "ADown"]):
            return layer_args["out_channels"]
        if layer_type == "CBFuse":
            return output_dim[source[-1]]
        if layer_type in ["Pool", "UpSample"]:
            return output_dim[source]
        if layer_type == "Concat":
            return sum(output_dim[idx] for idx in source)
        # "IDetect" and every other layer type: no channel count is derived.
        return None

    def get_source_idx(self, source: "Union[ListConfig, list, tuple, str, int]", layer_idx: int):
        """Resolve a ``source`` entry to absolute layer indices and mark those layers as saved.

        Args:
            source: A tag registered in ``self.layer_index``, an absolute index (0 is the
                model input), a negative offset relative to ``layer_idx``, or a sequence
                of any of these.
            layer_idx: 1-based index of the layer that consumes ``source``.

        Returns:
            The resolved index, or a list of resolved indices for a sequence.

        Raises:
            KeyError: If a tag is not registered in ``self.layer_index``.
            IndexError: If a resolved index is not in ``[0, layer_idx)``.
            TypeError: If ``source`` is not a tag, an int, or a sequence of them.
        """
        if isinstance(source, str):
            try:
                source = self.layer_index[source]
            except KeyError:
                raise KeyError(
                    f"Unknown source tag '{source}'; known tags: {sorted(self.layer_index)}"
                ) from None
        if isinstance(source, int):
            resolved = source + layer_idx if source < 0 else source
            # Without this check an out-of-range negative offset would wrap around
            # through Python's negative indexing and silently select the wrong layer.
            if not 0 <= resolved < layer_idx:
                raise IndexError(
                    f"Source {source} of layer {layer_idx} resolves to {resolved}, "
                    f"expected an index in [0, {layer_idx})."
                )
            if resolved > 0:
                # forward() only keeps outputs of layers that have a "save" attribute.
                setattr(self.model[resolved - 1], "save", True)
            return resolved
        if isinstance(source, (list, tuple, ListConfig)):
            return [self.get_source_idx(item, layer_idx) for item in source]
        raise TypeError(f"Unsupported source type: {type(source).__name__}")

    def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> "YOLOLayer":
        """Instantiate ``layer_type`` with ``kwargs`` and attach the build metadata to it.

        The created layer receives the attributes ``layer_type``, ``source``, ``in_c``
        (``kwargs["in_channels"]`` or ``None``), ``output`` (default ``False``) and
        ``tags`` (default ``None``).

        Raises:
            ValueError: If ``layer_type`` is not present in ``self.layer_map``.
        """
        if layer_type not in self.layer_map:
            raise ValueError(f"Unsupported layer type: {layer_type}")
        layer = self.layer_map[layer_type](**kwargs)
        layer.layer_type = layer_type
        layer.source = source
        layer.in_c = kwargs.get("in_channels", None)
        layer.output = layer_info.get("output", False)
        layer.tags = layer_info.get("tags", None)
        return layer
def get_model(cfg: Config) -> YOLO:
    """Constructs and returns a model from a configuration object.

    Args:
        cfg (Config): The full configuration. ``cfg.model`` describes the model and
            ``cfg.hyper.data.class_num`` supplies the number of classes.

    Returns:
        YOLO: An instance of the model defined by the given configuration.
    """
    # Turn struct mode off on the model config so that keys not present in it
    # (such as the "in_channels" added while building layers) can be set.
    OmegaConf.set_struct(cfg.model, False)
    model = YOLO(cfg.model, cfg.hyper.data.class_num)
    logger.info("✅ Success load model")
    return model
|