File size: 24,640 Bytes
1f53a4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import copy
from collections import OrderedDict
from typing import Dict, Optional

import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.ops import MLP


class BaseNet:
    """
    Class to construct network.

    Holds name -> torchvision-constructor tables for the supported CNNs and ViTs,
    the name of the attribute that holds each network's classification head, and
    class methods that build extractors, per-label classifiers and auxiliary
    modules from those tables.
    """
    # Supported CNNs: name -> torchvision constructor.
    cnn = {
            'ResNet18': models.resnet18,
            'ResNet': models.resnet50,
            'DenseNet': models.densenet161,
            'EfficientNetB0': models.efficientnet_b0,
            'EfficientNetB2': models.efficientnet_b2,
            'EfficientNetB4': models.efficientnet_b4,
            'EfficientNetB6': models.efficientnet_b6,
            'EfficientNetV2s': models.efficientnet_v2_s,
            'EfficientNetV2m': models.efficientnet_v2_m,
            'EfficientNetV2l': models.efficientnet_v2_l,
            'ConvNeXtTiny': models.convnext_tiny,
            'ConvNeXtSmall': models.convnext_small,
            'ConvNeXtBase': models.convnext_base,
            'ConvNeXtLarge': models.convnext_large
            }

    # Supported ViTs: name -> torchvision constructor.
    vit = {
            'ViTb16': models.vit_b_16,
            'ViTb32': models.vit_b_32,
            'ViTl16': models.vit_l_16,
            'ViTl32': models.vit_l_32,
            'ViTH14': models.vit_h_14
            }

    # Every supported image network.
    net = {**cnn, **vit}

    # Network family -> name of the attribute holding its classification head.
    _classifier = {
            'ResNet': 'fc',
            'DenseNet': 'classifier',
            'EfficientNet': 'classifier',
            'ConvNext': 'classifier',
            'ViT': 'heads'
            }

    # Network name -> name of the attribute holding its classification head.
    classifier = {
                'ResNet18': _classifier['ResNet'],
                'ResNet': _classifier['ResNet'],
                'DenseNet': _classifier['DenseNet'],
                'EfficientNetB0': _classifier['EfficientNet'],
                'EfficientNetB2': _classifier['EfficientNet'],
                'EfficientNetB4': _classifier['EfficientNet'],
                'EfficientNetB6': _classifier['EfficientNet'],
                'EfficientNetV2s': _classifier['EfficientNet'],
                'EfficientNetV2m': _classifier['EfficientNet'],
                'EfficientNetV2l': _classifier['EfficientNet'],
                'ConvNeXtTiny': _classifier['ConvNext'],
                'ConvNeXtSmall': _classifier['ConvNext'],
                'ConvNeXtBase':  _classifier['ConvNext'],
                'ConvNeXtLarge':  _classifier['ConvNext'],
                'ViTb16': _classifier['ViT'],
                'ViTb32': _classifier['ViT'],
                'ViTl16': _classifier['ViT'],
                'ViTl32': _classifier['ViT'],
                'ViTH14': _classifier['ViT']
                }

    # Hyper-parameters shared by every MLP built with MLPNet().
    mlp_config = {
                'hidden_channels': [256, 256, 256],
                'dropout': 0.2
                }

    # Kept for backward compatibility. Heads and auxiliary modules are now built
    # with a fresh nn.Identity() each time, so no module instance is shared
    # between networks (or registered twice inside one network).
    DUMMY = nn.Identity()

    @classmethod
    def MLPNet(cls, mlp_num_inputs: Optional[int] = None, inplace: Optional[bool] = None) -> MLP:
        """
        Construct MLP.

        Args:
            mlp_num_inputs (int): the number of input of MLP
            inplace (bool, optional): parameter for the activation layer, which can optionally do the operation in-place. Defaults to None.

        Returns:
            MLP: MLP whose hidden sizes and dropout come from mlp_config
        """
        assert isinstance(mlp_num_inputs, int), f"Invalid number of inputs for MLP: {mlp_num_inputs}."
        mlp = MLP(
                in_channels=mlp_num_inputs,
                hidden_channels=cls.mlp_config['hidden_channels'],
                inplace=inplace,
                dropout=cls.mlp_config['dropout']
                )
        return mlp

    @staticmethod
    def _to_single_channel(conv: nn.Conv2d) -> None:
        """
        Make conv accept a 1-channel input by summing its weight over the input-channel axis.
        """
        conv.in_channels = 1
        # conv(x replicated over C channels) == conv'(x) when W' = sum_c W_c,
        # so a gray image gets the same response as its replicated color version.
        conv.weight = nn.Parameter(conv.weight.sum(dim=1, keepdim=True))

    @classmethod
    def align_in_channels_1ch(cls, net_name: Optional[str] = None, net: Optional[nn.Module] = None) -> nn.Module:
        """
        Modify network to handle gray scale image.

        Args:
            net_name (str): network name
            net (nn.Module): network itself; modified in place

        Returns:
            nn.Module: network available for gray scale

        Raises:
            ValueError: if net_name does not belong to a supported family.
        """
        if net_name.startswith('ResNet'):
            first_conv = net.conv1
        elif net_name.startswith('DenseNet'):
            first_conv = net.features.conv0
        elif net_name.startswith(('Efficient', 'ConvNeXt')):
            # EfficientNet and ConvNeXt both start with features[0][0] as their stem conv.
            first_conv = net.features[0][0]
        elif net_name.startswith('ViT'):
            first_conv = net.conv_proj
        else:
            raise ValueError(f"No specified net: {net_name}.")

        cls._to_single_channel(first_conv)
        return net

    @classmethod
    def set_net(
                cls,
                net_name: Optional[str] = None,
                in_channel: Optional[int] = None,
                vit_image_size: Optional[int] = None,
                pretrained: Optional[bool] = None
                ) -> nn.Module:
        """
        Modify network depending on in_channel and vit_image_size.

        Args:
            net_name (str): network name
            in_channel (int, optional): image channel(any of 1ch or 3ch). Defaults to None.
            vit_image_size (int, optional): image size which ViT handles if ViT is used. Defaults to None.
                                            vit_image_size should be a multiple of the patch size.
            pretrained (bool, optional): True when use pretrained CNN, otherwise False. Defaults to None.
                                         ViT is always built from pretrained weights.

        Returns:
            nn.Module: modified network
        """
        assert net_name in cls.net, f"No specified net: {net_name}."
        if net_name in cls.cnn:
            if pretrained:
                net = cls.cnn[net_name](weights='DEFAULT')
            else:
                net = cls.cnn[net_name]()
        else:
            # When ViT
            # always use pretrained
            net = cls.set_vit(net_name=net_name, vit_image_size=vit_image_size)

        if in_channel == 1:
            net = cls.align_in_channels_1ch(net_name=net_name, net=net)
        return net

    @classmethod
    def set_vit(cls, net_name: Optional[str] = None, vit_image_size: Optional[int] = None) -> nn.Module:
        """
        Build a pretrained ViT whose positional embedding matches vit_image_size.

        Args:
            net_name (str): ViT name
            vit_image_size (int, optional): image size which ViT handles. When None,
                the pretrained ViT is returned at its native resolution.

        Returns:
            nn.Module: modified ViT
        """
        base_vit = cls.vit[net_name]
        pretrained_vit = base_vit(weights='DEFAULT')

        if vit_image_size is None:
            # Nothing to align: keep the resolution the pretrained weights use.
            return pretrained_vit

        # Align weight depending on image size.
        # The patch size is stored on the model, so it is not parsed from net_name.
        aligned_weight = models.vision_transformer.interpolate_embeddings(
                                                    image_size=vit_image_size,
                                                    patch_size=pretrained_vit.patch_size,
                                                    model_state=pretrained_vit.state_dict()
                                                    )
        aligned_vit = base_vit(image_size=vit_image_size)  # Specify new image size.
        aligned_vit.load_state_dict(aligned_weight)        # Load weight which can handle the new image size.
        return aligned_vit

    @classmethod
    def construct_extractor(
                            cls,
                            net_name: Optional[str] = None,
                            mlp_num_inputs: Optional[int] = None,
                            in_channel: Optional[int] = None,
                            vit_image_size: Optional[int] = None,
                            pretrained: Optional[bool] = None
                            ) -> nn.Module:
        """
        Construct extractor of network depending on net_name.

        Args:
            net_name (str): network name.
            mlp_num_inputs (int, optional): number of input of MLP. Defaults to None.
            in_channel (int, optional): image channel(any of 1ch or 3ch). Defaults to None.
            vit_image_size (int, optional): image size which ViT handles if ViT is used. Defaults to None.
            pretrained (bool, optional): True when use pretrained CNN or ViT, otherwise False. Defaults to None.

        Returns:
            nn.Module: extractor of network (MLP, or the image network with its head replaced by nn.Identity)
        """
        if net_name == 'MLP':
            extractor = cls.MLPNet(mlp_num_inputs=mlp_num_inputs)
        else:
            extractor = cls.set_net(net_name=net_name, in_channel=in_channel, vit_image_size=vit_image_size, pretrained=pretrained)
            # Replace the classification head with a fresh identity so the extractor
            # outputs features; a fresh instance avoids sharing one module across models.
            setattr(extractor, cls.classifier[net_name], nn.Identity())
        return extractor

    @classmethod
    def get_classifier(cls, net_name: str) -> nn.Module:
        """
        Get classifier of network depending on net_name.

        Note that this builds a whole (untrained) network each call just to read its head.

        Args:
            net_name (str): network name

        Returns:
            nn.Module: classifier of network
        """
        net = cls.net[net_name]()
        classifier = getattr(net, cls.classifier[net_name])
        return classifier

    @classmethod
    def construct_multi_classifier(cls, net_name: Optional[str] = None, num_outputs_for_label: Optional[Dict[str, int]] = None) -> nn.ModuleDict:
        """
        Construct classifier for multi-label.

        Every label gets its own, independent classifier module.

        Args:
            net_name (str): network name
            num_outputs_for_label (Dict[str, int]): number of outputs for each label

        Returns:
            nn.ModuleDict: classifier for multi-label, keyed by label name

        Raises:
            ValueError: if net_name does not belong to a supported family.
        """
        classifiers = dict()
        if net_name == 'MLP':
            in_features = cls.mlp_config['hidden_channels'][-1]
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Linear(in_features, num_outputs)

        elif net_name.startswith('ResNet') or net_name.startswith('DenseNet'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.in_features
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Linear(in_features, num_outputs)

        elif net_name.startswith('EfficientNet'):
            base_classifier = cls.get_classifier(net_name)
            dropout = base_classifier[0].p
            in_features = base_classifier[1].in_features
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Sequential(
                                                        nn.Dropout(p=dropout, inplace=False),
                                                        nn.Linear(in_features, num_outputs)
                                                    )

        elif net_name.startswith('ConvNeXt'):
            base_classifier = cls.get_classifier(net_name)
            layer_norm = base_classifier[0]
            flatten = base_classifier[1]
            in_features = base_classifier[2].in_features
            for label_name, num_outputs in num_outputs_for_label.items():
                # Shape is changed before nn.Linear.
                # Copy the LayerNorm2d per label: reusing one instance would make every
                # label's head share (and jointly train) the same affine parameters.
                classifiers[label_name] = nn.Sequential(
                                                        copy.deepcopy(layer_norm),
                                                        copy.deepcopy(flatten),
                                                        nn.Linear(in_features, num_outputs)
                                                    )

        elif net_name.startswith('ViT'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.head.in_features
            for label_name, num_outputs in num_outputs_for_label.items():
                # Keep the 'head' submodule name used by torchvision's ViT heads.
                classifiers[label_name] = nn.Sequential(
                                                OrderedDict([
                                                        ('head', nn.Linear(in_features, num_outputs))
                                                        ])
                                                )

        else:
            raise ValueError(f"No specified net: {net_name}.")

        multi_classifier = nn.ModuleDict(classifiers)
        return multi_classifier

    @classmethod
    def get_classifier_in_features(cls, net_name: str) -> int:
        """
        Return in_feature of network indicating by net_name.
        This class is used in class MultiNetFusion() only.

        Args:
            net_name (str): net_name

        Returns:
            int : in_feature

        Raises:
            ValueError: if net_name does not belong to a supported family.

        Required:
        classifier.in_feature
        classifier.[1].in_features
        classifier.[2].in_features
        classifier.head.in_features
        """
        if net_name == 'MLP':
            in_features = cls.mlp_config['hidden_channels'][-1]

        elif net_name.startswith('ResNet') or net_name.startswith('DenseNet'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.in_features

        elif net_name.startswith('EfficientNet'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier[1].in_features

        elif net_name.startswith('ConvNeXt'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier[2].in_features

        elif net_name.startswith('ViT'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.head.in_features

        else:
            raise ValueError(f"No specified net: {net_name}.")
        return in_features

    @classmethod
    def construct_aux_module(cls, net_name: str) -> nn.Module:
        """
        Construct module to align the shape of feature from extractor depending on network.
        Actually, only when net_name == 'ConvNeXt'.
        Because ConvNeXt has the process of aligning the dimensions in its classifier.

        Needs to align shape of the feature extractor when ConvNeXt
        (classifier):
        Sequential(
            (0): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)
            (1): Flatten(start_dim=1, end_dim=-1)
            (2): Linear(in_features=768, out_features=1000, bias=True)
        )

        Args:
            net_name (str): net name

        Returns:
            nn.Module: layers such that they align the dimension of the output from the extractor like the original ConvNeXt,
                       or a fresh nn.Identity() for any other network.
        """
        if not net_name.startswith('ConvNeXt'):
            return nn.Identity()

        base_classifier = cls.get_classifier(net_name)
        layer_norm = base_classifier[0]
        flatten = base_classifier[1]
        aux_module = nn.Sequential(
                            layer_norm,
                            flatten
                            )
        return aux_module

    @classmethod
    def get_last_extractor(cls, net: Optional[nn.Module] = None, mlp: Optional[str] = None, net_name: Optional[str] = None) -> nn.Module:
        """
        Return the last extractor of network.
        This is for Grad-CAM.
        net should be one loaded weight.

        Args:
            net (nn.Module): network itself; must have an ``extractor_net`` attribute
            mlp (str): 'MLP', otherwise None (used only in the error message)
            net_name (str): network name

        Returns:
            nn.Module: last extractor of network

        Raises:
            ValueError: if net_name does not belong to a supported family.
        """
        assert (net_name is not None), f"Network does not contain CNN or ViT: mlp={mlp}, net={net_name}."

        _extractor = net.extractor_net

        if net_name.startswith('ResNet'):
            last_extractor = _extractor.layer4[-1]
        elif net_name.startswith('DenseNet'):
            # Last dense layer of the final block (denselayer24 for DenseNet161),
            # looked up by position instead of by a hard-coded name.
            last_extractor = list(_extractor.features.denseblock4.children())[-1]
        elif net_name.startswith('EfficientNet'):
            last_extractor = _extractor.features[-1]
        elif net_name.startswith('ConvNeXt'):
            last_extractor = _extractor.features[-1][-1].block
        elif net_name.startswith('ViT'):
            last_extractor = _extractor.encoder.layers[-1]
        else:
            raise ValueError(f"Cannot get last extractor of net: {net_name}.")
        return last_extractor


class MultiMixin:
    """
    Class to define auxiliary function to handle multi-label.

    Classes mixing this in must set ``self.multi_classifier`` to a mapping from
    label name to the classifier (callable module) for that label.
    """
    def multi_forward(self, out_features: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward out_features to classifier for each label.

        Args:
            out_features (torch.Tensor): output from extractor

        Returns:
            Dict[str, torch.Tensor]: output of the classifier of each label, keyed by
            label name, in the iteration order of ``self.multi_classifier``
        """
        return {
            label_name: classifier(out_features)
            for label_name, classifier in self.multi_classifier.items()
        }


class MultiWidget(nn.Module, BaseNet, MultiMixin):
    """
    Common base for the multi-label models.

    Combines torch's nn.Module with the network-building class methods of BaseNet
    and multi_forward() from MultiMixin, so subclasses inherit all three at once.
    """


class MultiNet(MultiWidget):
    """
    Model of MLP, CNN or ViT.

    A single feature extractor followed by one classifier per label.
    """
    def __init__(
                self,
                net_name: str = None,
                num_outputs_for_label: Dict[str, int] = None,
                mlp_num_inputs: int = None,
                in_channel: int = None,
                vit_image_size: int = None,
                pretrained: bool = None
                ) -> None:
        """
        Build the extractor and the per-label classifiers.

        Args:
            net_name (str): 'MLP', or the name of a CNN or ViT
            num_outputs_for_label (Dict[str, int]): number of classes for each label
            mlp_num_inputs (int): number of input of MLP.
            in_channel (int): number of image channel, ie gray scale(=1) or color image(=3).
            vit_image_size (int): image size to be input to ViT.
            pretrained (bool): True when use pretrained CNN or ViT, otherwise False.
        """
        super().__init__()

        # Remember the construction arguments.
        self.net_name = net_name
        self.num_outputs_for_label = num_outputs_for_label
        self.mlp_num_inputs = mlp_num_inputs
        self.in_channel = in_channel
        self.vit_image_size = vit_image_size
        self.pretrained = pretrained

        # Feature extractor: an MLP, or an image network whose head is an identity.
        extractor = self.construct_extractor(
            net_name=net_name,
            mlp_num_inputs=mlp_num_inputs,
            in_channel=in_channel,
            vit_image_size=vit_image_size,
            pretrained=pretrained,
        )
        self.extractor_net = extractor

        # One classifier per label, all fed by the same features.
        self.multi_classifier = self.construct_multi_classifier(
            net_name=net_name,
            num_outputs_for_label=num_outputs_for_label,
        )

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Run the extractor, then every label's classifier.

        Args:
            x (torch.Tensor): tabular data or image

        Returns:
            Dict[str, torch.Tensor]: classifier output keyed by label name
        """
        return self.multi_forward(self.extractor_net(x))


class MultiNetFusion(MultiWidget):
    """
    Fusion model of MLP and CNN or ViT.

    Tabular features (MLP) and image features (CNN/ViT) are concatenated, passed
    through an intermediate MLP, and then fed to one classifier per label.
    """
    def __init__(
                self,
                net_name: str = None,
                num_outputs_for_label: Dict[str, int] = None,
                mlp_num_inputs: int = None,
                in_channel: int = None,
                vit_image_size: int = None,
                pretrained: bool = None
                ) -> None:
        """
        Build both extractors, the fusion MLP and the per-label classifiers.

        Args:
            net_name (str): CNN or ViT name. The MLP branch is always built, so 'MLP' is not allowed here.
            num_outputs_for_label (Dict[str, int]): number of classes for each label
            mlp_num_inputs (int): number of input of MLP. Defaults to None.
            in_channel (int): number of image channel, ie gray scale(=1) or color image(=3).
            vit_image_size (int): image size to be input to ViT.
            pretrained (bool): True when use pretrained CNN or ViT, otherwise False.
        """
        assert (net_name != 'MLP'), 'net_name should not be MLP.'

        super().__init__()

        # Remember the construction arguments.
        self.net_name = net_name
        self.num_outputs_for_label = num_outputs_for_label
        self.mlp_num_inputs = mlp_num_inputs
        self.in_channel = in_channel
        self.vit_image_size = vit_image_size
        self.pretrained = pretrained

        # Tabular branch, then image branch (plus its shape-aligning module).
        self.extractor_mlp = self.construct_extractor(net_name='MLP', mlp_num_inputs=mlp_num_inputs)
        self.extractor_net = self.construct_extractor(
            net_name=net_name,
            in_channel=in_channel,
            vit_image_size=vit_image_size,
            pretrained=pretrained,
        )
        self.aux_module = self.construct_aux_module(net_name)

        # The fusion MLP takes the concatenation of both feature vectors.
        self.in_features_from_mlp = self.get_classifier_in_features('MLP')
        self.in_features_from_net = self.get_classifier_in_features(net_name)
        self.inter_mlp_in_feature = self.in_features_from_mlp + self.in_features_from_net
        self.inter_mlp = self.MLPNet(mlp_num_inputs=self.inter_mlp_in_feature, inplace=False)

        # Per-label heads sit on the fusion MLP output, hence the 'MLP' layout.
        self.multi_classifier = self.construct_multi_classifier(
            net_name='MLP',
            num_outputs_for_label=num_outputs_for_label,
        )

    def forward(self, x_mlp: torch.Tensor, x_net: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Extract both feature vectors, fuse them, and classify every label.

        Args:
            x_mlp (torch.Tensor): tabular data
            x_net (torch.Tensor): image

        Returns:
            Dict[str, torch.Tensor]: classifier output keyed by label name
        """
        tabular_features = self.extractor_mlp(x_mlp)
        image_features = self.aux_module(self.extractor_net(x_net))

        fused = self.inter_mlp(torch.cat([tabular_features, image_features], dim=1))
        return self.multi_forward(fused)


def create_net(
            mlp: Optional[str] = None,
            net: Optional[str] = None,
            num_outputs_for_label: Dict[str, int] = None,
            mlp_num_inputs: int = None,
            in_channel: int = None,
            vit_image_size: int = None,
            pretrained: bool = None
            ) -> nn.Module:
    """
    Create network.

    Which model is built depends on which of mlp / net is given:
    only mlp -> MLP model, only net -> CNN/ViT model, both -> fusion model.

    Args:
        mlp (Optional[str]): 'MLP' or None
        net (Optional[str]):  CNN, ViT name or None
        num_outputs_for_label (Dict[str, int]): number of outputs for each label
        mlp_num_inputs (int): number of input of MLP.
        in_channel (int): number of image channel, ie gray scale(=1) or color image(=3).
        vit_image_size (int): image size to be input to ViT.
        pretrained (bool): True when use pretrained CNN or ViT, otherwise False.

    Returns:
        nn.Module: network

    Raises:
        ValueError: if both mlp and net are None.
    """
    has_mlp = mlp is not None
    has_net = net is not None

    if has_mlp and has_net:
        # Tabular + image input.
        return MultiNetFusion(
                            net_name=net,
                            num_outputs_for_label=num_outputs_for_label,
                            mlp_num_inputs=mlp_num_inputs,
                            in_channel=in_channel,
                            vit_image_size=vit_image_size,
                            pretrained=pretrained
                            )

    if has_net:
        # Image input only.
        return MultiNet(
                        net_name=net,
                        num_outputs_for_label=num_outputs_for_label,
                        mlp_num_inputs=mlp_num_inputs,
                        in_channel=in_channel,
                        vit_image_size=vit_image_size,
                        pretrained=pretrained
                        )

    if has_mlp:
        # Tabular input only; there are no pretrained weights for the MLP.
        return MultiNet(
                        net_name='MLP',
                        num_outputs_for_label=num_outputs_for_label,
                        mlp_num_inputs=mlp_num_inputs,
                        in_channel=in_channel,
                        vit_image_size=vit_image_size,
                        pretrained=False
                        )

    raise ValueError(f"Invalid model type: mlp={mlp}, net={net}.")