from typing import Dict, List, Union

import torch.nn as nn
from loguru import logger
from omegaconf import ListConfig, OmegaConf

from yolo.config.config import Config, Model, YOLOLayer
from yolo.tools.layer_helper import get_layer_map
from yolo.tools.log_helper import log_model


class YOLO(nn.Module):
    """
    A preliminary YOLO (You Only Look Once) model class still under development.

    Parameters:
        model_cfg: Configuration for the YOLO model. Expected to define the layers,
                   parameters, and any other relevant configuration details.
        num_classes: Number of object classes the detection head should predict.
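
    Example:
        An illustrative (hypothetical) fragment of the layer specification that
        ``build_model`` consumes; each key of ``model_cfg.model`` names an
        architecture section and maps to a list of single-key layer specs
        (``in_channels`` and ``num_classes`` are filled in by ``build_model``)::

            backbone:
              - Conv:
                  args: {out_channels: 64, kernel_size: 3, stride: 2}
                  source: -1      # take input from the previous layer
                  tags: stem      # optional name, reusable as a "source"
            head:
              - Detection:
                  args: {}
                  source: [stem]
                  output: True    # returned by forward()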
    """

    def __init__(self, model_cfg: Model, num_classes: int):
        super(YOLO, self).__init__()
        self.num_classes = num_classes
        self.layer_map = get_layer_map()  # maps layer-type name -> nn.Module class
        self.model: List[YOLOLayer] = nn.ModuleList()
        self.build_model(model_cfg.model)
        log_model(self.model)

    def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
        self.layer_index = {}
        output_dim, layer_idx = [3], 1  # start from 3 input (RGB) channels; layer indices are 1-based
        logger.info("🚜 Building YOLO")
        for arch_name in model_arch:
            logger.info(f"  🏗️  Building {arch_name}")
            for layer_idx, layer_spec in enumerate(model_arch[arch_name], start=layer_idx):
                layer_type, layer_info = next(iter(layer_spec.items()))
                layer_args = layer_info.get("args", {})

                # Get input source
                source = self.get_source_idx(layer_info.get("source", -1), layer_idx)

                # Infer in_channels from the source layer(s)
                if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "CBLinear"]):
                    layer_args["in_channels"] = output_dim[source]
                if "Detection" in layer_type:
                    layer_args["in_channels"] = [output_dim[idx] for idx in source]
                    layer_args["num_classes"] = self.num_classes

                # Create the layer and register it
                layer = self.create_layer(layer_type, source, layer_info, **layer_args)
                self.model.append(layer)

                if layer.tags:
                    if layer.tags in self.layer_index:
                        raise ValueError(f"Duplicate tag '{layer.tags}' found.")
                    self.layer_index[layer.tags] = layer_idx

                out_channels = self.get_out_channels(layer_type, layer_args, output_dim, source)
                output_dim.append(out_channels)
                setattr(layer, "out_c", out_channels)
            layer_idx += 1

    def forward(self, x):
        y = {0: x}  # intermediate outputs keyed by 1-based layer index; 0 holds the raw input
        output = []
        for index, layer in enumerate(self.model, start=1):
            if isinstance(layer.source, list):
                model_input = [y[idx] for idx in layer.source]
            else:
                model_input = y[layer.source]
            x = layer(model_input)
            if hasattr(layer, "save"):  # cache only outputs that a later layer references
                y[index] = x
            if layer.output:  # layers flagged with output: True in the config are returned
                output.append(x)
        return output

    def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
        if any(module in layer_type for module in ["Conv", "ELAN", "ADown"]):
            return layer_args["out_channels"]
        if layer_type == "CBFuse":
            return output_dim[source[-1]]
        if layer_type in ["Pool", "UpSample"]:
            return output_dim[source]
        if layer_type == "Concat":
            return sum(output_dim[idx] for idx in source)
        if layer_type == "IDetect":
            return None

    def get_source_idx(self, source: Union[ListConfig, str, int], layer_idx: int):
        if isinstance(source, ListConfig):
            return [self.get_source_idx(index, layer_idx) for index in source]
        if isinstance(source, str):
            source = self.layer_index[source]
        if source < 0:
            source += layer_idx  # negative sources are relative to the current layer
        if source > 0:
            setattr(self.model[source - 1], "save", True)  # mark it so forward() caches its output
        return source

    def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> YOLOLayer:
        if layer_type in self.layer_map:
            layer = self.layer_map[layer_type](**kwargs)
            setattr(layer, "layer_type", layer_type)
            setattr(layer, "source", source)
            setattr(layer, "in_c", kwargs.get("in_channels", None))
            setattr(layer, "output", layer_info.get("output", False))
            setattr(layer, "tags", layer_info.get("tags", None))
            return layer
        else:
            raise ValueError(f"Unsupported layer type: {layer_type}")


def get_model(cfg: Config) -> YOLO:
    """Constructs and returns a YOLO model from a configuration object.

    Args:
        cfg (Config): The full configuration, containing the model architecture
            and the dataset's class count.

    Returns:
        YOLO: An instance of the model defined by the given configuration.
    """
    OmegaConf.set_struct(cfg.model, False)
    model = YOLO(cfg.model, cfg.hyper.data.class_num)
    logger.info("✅ Successfully loaded the model")
    return model
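

# --- Usage sketch (illustrative; not part of the upstream module) ---
# A minimal, hypothetical example of building the model and running one forward
# pass. The config path, the 640x640 input size, and the assumption that a plain
# YAML file loaded with OmegaConf matches the `Config` schema are all guesses
# made for illustration; the project normally assembles `Config` elsewhere.
if __name__ == "__main__":
    import torch  # local import: the module itself only needs torch.nn

    cfg = OmegaConf.load("yolo/config/config.yaml")  # hypothetical config path
    model = get_model(cfg)
    dummy_images = torch.zeros(1, 3, 640, 640)  # one RGB image, assumed input size
    predictions = model(dummy_images)
    logger.info(f"Forward pass produced {len(predictions)} output tensors")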