File size: 464 Bytes
533763f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from .ResNet import *
from .VGGNet import *

__all__ = ['get_backbone']


def get_backbone(model_name='', pretrained=True, num_classes=None, **kwargs):
    """Build a backbone network by dispatching on a substring of ``model_name``.

    Args:
        model_name: Model identifier. Names containing ``'res'`` are passed to
            ``get_resnet``; otherwise names containing ``'vgg'`` are passed to
            ``get_vgg``. The match is a case-sensitive substring test and
            ``'res'`` is checked first, so a name containing both goes to
            ``get_resnet``.
        pretrained: Forwarded unchanged to the selected builder.
        num_classes: Forwarded unchanged to the selected builder.
        **kwargs: Extra keyword arguments forwarded to the selected builder.

    Returns:
        Whatever the selected builder (``get_resnet`` or ``get_vgg``) returns.

    Raises:
        NotImplementedError: If ``model_name`` matches neither family
            (including the default empty string).
    """
    if 'res' in model_name:
        return get_resnet(model_name, pretrained=pretrained, num_classes=num_classes, **kwargs)
    if 'vgg' in model_name:
        return get_vgg(model_name, pretrained=pretrained, num_classes=num_classes, **kwargs)
    # Keep NotImplementedError (callers may already catch it), but report what
    # was asked for so a typo or a missing argument is obvious from the traceback.
    raise NotImplementedError(
        "Unsupported backbone {!r}: expected a name containing 'res' or 'vgg'".format(model_name)
    )