File size: 6,303 Bytes
230a441
18edc1d
183312f
306fc38
b80dc1e
5556058
7d7e199
31cab2b
afa32b4
230a441
dcceddd
b80dc1e
 
 
 
 
 
 
 
 
 
 
0a3c9de
b80dc1e
0a3c9de
542860e
856cce6
f99f89b
200b5c1
f99f89b
97681c2
 
92614f8
 
c14cffe
7c6ce21
200b5c1
 
5556058
97681c2
 
 
5556058
92614f8
5556058
 
78e3679
97681c2
5556058
 
 
f99f89b
97681c2
5556058
 
92614f8
97681c2
92614f8
 
97681c2
92614f8
97681c2
 
 
0f9ffa2
5556058
 
542860e
198ddce
2dd2ae5
198ddce
5556058
542860e
 
 
 
5727efb
 
198ddce
3186e72
2dd2ae5
3186e72
542860e
 
1504257
78e3679
97681c2
5556058
 
b67aac7
542860e
97681c2
 
 
 
 
92614f8
 
 
 
 
5727efb
92614f8
5727efb
 
92614f8
 
 
97681c2
542860e
0f9ffa2
5556058
0f9ffa2
5556058
 
5727efb
542860e
97681c2
 
b80dc1e
 
b6b57c7
b80dc1e
 
 
 
 
 
 
 
1504257
40afe67
0a3c9de
40afe67
f5518c0
 
68c38b8
40afe67
 
18edc1d
7d7e199
f5518c0
200b5c1
 
 
7d7e199
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
from typing import Dict, List, Optional, Union

import torch
from loguru import logger
from omegaconf import ListConfig, OmegaConf
from torch import nn

from yolo.config.config import ModelConfig, YOLOLayer
from yolo.tools.dataset_preparation import prepare_weight
from yolo.utils.module_utils import get_layer_map


class YOLO(nn.Module):
    """
    A preliminary YOLO (You Only Look Once) model class still under development.

    The network is assembled from ``model_cfg.model``: a mapping of section name to a
    list of layer specs. The first key of each spec is the layer type (looked up in
    ``get_layer_map()``). Its value may carry ``args`` (constructor kwargs), ``source``
    (which earlier output(s) feed the layer), ``tags`` and ``output`` (whether the
    layer's result is returned by ``forward``).

    Parameters:
        model_cfg: Configuration for the YOLO model. Expected to define the layers,
                   parameters, and any other relevant configuration details.
        class_num: Number of classes forwarded to ``*Detection`` layers.
    """

    def __init__(self, model_cfg: ModelConfig, class_num: int = 80):
        super(YOLO, self).__init__()
        self.num_classes = class_num
        self.layer_map = get_layer_map()  # Get the map Dict[str: Module]
        self.model: List[YOLOLayer] = nn.ModuleList()
        self.reg_max = getattr(model_cfg.anchor, "reg_max", 16)
        self.strides = getattr(model_cfg.anchor, "strides", None)
        self.build_model(model_cfg.model)

    def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
        """Instantiate every layer spec in ``model_arch`` and append it to ``self.model``.

        Layers receive a global 1-based index in build order; index 0 is the network
        input. ``forward`` relies on exactly the same numbering. ``output_dim[i]`` stores
        the channel count produced at index ``i``, so later layers can infer
        ``in_channels``.

        Raises:
            ValueError: if two layers share a tag or a layer type is unsupported.
        """
        self.layer_index = {}
        output_dim = [3]  # index 0 is the network input; 3 input channels are assumed
        logger.info("🚜 Building YOLO")
        for arch_name in model_arch:
            layer_specs = model_arch[arch_name] or []
            if layer_specs:
                logger.info(f"  🏗️  Building {arch_name}")
            for layer_spec in layer_specs:
                # Derive the index from the layers built so far. It must match the position
                # forward() enumerates (start=1), and an empty section must not skip a number.
                layer_idx = len(self.model) + 1
                layer_type, layer_info = next(iter(layer_spec.items()))
                layer_args = layer_info.get("args", {})

                # Get input source
                source = self.get_source_idx(layer_info.get("source", -1), layer_idx)

                # Find in channels
                if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv", "CBLinear"]):
                    layer_args["in_channels"] = output_dim[source]
                if "Detection" in layer_type:
                    layer_args["in_channels"] = [output_dim[idx] for idx in source]
                    layer_args["num_classes"] = self.num_classes
                    layer_args["reg_max"] = self.reg_max

                # create layers
                layer = self.create_layer(layer_type, source, layer_info, **layer_args)
                self.model.append(layer)

                if layer.tags:
                    if layer.tags in self.layer_index:
                        raise ValueError(f"Duplicate tag '{layer_info['tags']}' found.")
                    self.layer_index[layer.tags] = layer_idx

                out_channels = self.get_out_channels(layer_type, layer_args, output_dim, source)
                output_dim.append(out_channels)
                setattr(layer, "out_c", out_channels)

    def forward(self, x):
        """Run all layers in order and collect the results of layers flagged ``output``.

        Returns:
            dict mapping each output layer's ``tags`` value to that layer's result.
        """
        # Key 0 is the network input and key -1 always holds the most recent result.
        # -1 is seeded with the input so a first layer using the default source (-1)
        # receives the image, matching build_model, where output_dim[-1] is then 3.
        y = {0: x, -1: x}
        output = dict()
        for index, layer in enumerate(self.model, start=1):
            if isinstance(layer.source, list):
                model_input = [y[idx] for idx in layer.source]
            else:
                model_input = y[layer.source]
            x = layer(model_input)
            y[-1] = x
            # Only cache outputs that some later layer references by absolute index.
            if layer.usable:
                y[index] = x
            if layer.output:
                output[layer.tags] = x
        return output

    def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
        """Return the channel count a layer produces, or ``None`` when it is not tracked.

        ``None`` is returned for layer types not handled below, e.g. ``CBLinear`` and
        ``*Detection`` heads.
        """
        # TODO refactor, check out_channels in layer_args, next check source is list, CBFuse|Concat
        if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv"]):
            return layer_args["out_channels"]
        if layer_type == "CBFuse":
            return output_dim[source[-1]]
        if layer_type in ["Pool", "UpSample"]:
            return output_dim[source]
        if layer_type == "Concat":
            return sum(output_dim[idx] for idx in source)
        if layer_type == "IDetect":
            return None

    def get_source_idx(self, source: Union[ListConfig, str, int], layer_idx: int):
        """Resolve a layer's ``source`` spec to absolute index(es).

        Resolution rules:
        - ``0`` (network input) and ``-1`` (previous result) are returned unchanged.
        - Values below ``-1`` are relative to ``layer_idx``.
        - Strings are tags of layers built earlier.
        - Lists are resolved element-wise.

        Every referenced earlier layer (index > 0) is marked ``usable``, so ``forward``
        caches its output.

        Raises:
            KeyError: if a tag has not been defined by an earlier layer.
        """
        if isinstance(source, (ListConfig, list)):
            return [self.get_source_idx(index, layer_idx) for index in source]
        if isinstance(source, str):
            try:
                source = self.layer_index[source]
            except KeyError:
                raise KeyError(f"Unknown source tag '{source}' (tags must be defined before use)") from None
        if source < -1:
            source += layer_idx
        if source > 0:  # Using Previous Layer's Output
            self.model[source - 1].usable = True
        return source

    def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> YOLOLayer:
        """Instantiate ``layer_type`` with ``kwargs`` and attach the builder metadata.

        The attached attributes (``layer_type``, ``source``, ``in_c``, ``output``,
        ``tags``, ``usable``) are read by ``build_model`` and ``forward``.

        Raises:
            ValueError: if ``layer_type`` is not present in ``self.layer_map``.
        """
        if layer_type not in self.layer_map:
            raise ValueError(f"Unsupported layer type: {layer_type}")
        layer = self.layer_map[layer_type](**kwargs)
        setattr(layer, "layer_type", layer_type)
        setattr(layer, "source", source)
        setattr(layer, "in_c", kwargs.get("in_channels", None))
        setattr(layer, "output", layer_info.get("output", False))
        setattr(layer, "tags", layer_info.get("tags", None))
        setattr(layer, "usable", 0)  # flipped to True by get_source_idx when referenced
        return layer


def create_model(model_cfg: ModelConfig, weight_path: Union[bool, str] = True, class_num: int = 80) -> YOLO:
    """Construct a YOLO model from its configuration and optionally load pretrained weights.

    Args:
        model_cfg (ModelConfig): Model configuration. Its OmegaConf struct flag is
            disabled in place so that layer builders can add keys to it.
        weight_path (Union[bool, str]): What to load after the model is built:
            - ``True``: load ``weights/<model_cfg.name>.pt``.
            - A path string: load that checkpoint.
            - A falsy value: skip loading.
            If the file is missing, ``prepare_weight`` is asked to download it first.
        class_num (int): Number of object classes for the detection heads.

    Returns:
        YOLO: An instance of the model defined by the given configuration. If the
        weight file is still unavailable after the download attempt, the model keeps
        its initial parameters and a warning is logged.
    """
    # TODO: "weight_path -> weight = [True|None-False|Path]: True should be default of model name?"
    OmegaConf.set_struct(model_cfg, False)
    model = YOLO(model_cfg, class_num)
    if not weight_path:
        logger.info("✅ Success load model")
        return model

    if weight_path is True:
        weight_path = os.path.join("weights", f"{model_cfg.name}.pt")
    if not os.path.exists(weight_path):
        logger.info(f"🌐 Weight {weight_path} not found, try downloading")
        prepare_weight(weight_path=weight_path)
    if not os.path.exists(weight_path):
        # Previously this case was silent; make the missing pretrained weights visible.
        logger.warning(f"⚠️ Weight {weight_path} is unavailable, model keeps its initial parameters")
        return model

    # NOTE(review): torch.load unpickles the file, so only load checkpoints from trusted sources.
    # TODO: fix map_location
    state_dict = torch.load(weight_path, map_location=torch.device("cpu"))
    # strict=False tolerates key mismatches; report them instead of dropping them silently.
    incompatible = model.model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        logger.warning(
            f"⚠️ Weight {weight_path}: {len(incompatible.missing_keys)} missing keys (left as initialized), "
            f"{len(incompatible.unexpected_keys)} unexpected keys (ignored)"
        )
    logger.info("✅ Success load model & weight")
    return model