File size: 1,519 Bytes
570db9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from typing import List, Tuple, Optional, Any, Union

from .model import _classifier, _regressor, Classifier, Regressor
from .clip import _clip_ebc, CLIP_EBC


# Backbone identifiers accepted on the CLIP path of `get_model`; the name that
# remains after the leading five characters are stripped must appear here.
clip_names = [
    "resnet50",
    "resnet50x4",
    "resnet50x16",
    "resnet50x64",
    "resnet101",
    "vit_b_16",
    "vit_b_32",
    "vit_l_14",
]


def get_model(
    backbone: str,
    input_size: int,
    reduction: int,
    bins: Optional[List[Tuple[float, float]]] = None,
    anchor_points: Optional[List[float]] = None,
    **kwargs: Any,
) -> "Union[Regressor, Classifier, CLIP_EBC]":
    """Build a model from a backbone name and an optional bin configuration.

    The builder is selected in this order:

    1. If ``"clip"`` occurs anywhere in the lower-cased ``backbone``, the first
       five characters are dropped (presumably a ``"clip_"``-style prefix --
       confirm against callers), the remainder must be listed in
       ``clip_names``, and ``_clip_ebc`` is called.
    2. Otherwise, if both ``bins`` and ``anchor_points`` are ``None``,
       ``_regressor`` is called.
    3. Otherwise both must be provided, and ``_classifier`` is called.

    Args:
        backbone: Backbone name; compared case-insensitively.
        input_size: Forwarded unchanged to the selected builder.
        reduction: Forwarded unchanged to the selected builder.
        bins: Forwarded to ``_clip_ebc`` / ``_classifier``; not passed to
            ``_regressor``.
        anchor_points: Forwarded to ``_clip_ebc`` / ``_classifier``; not passed
            to ``_regressor``.
        **kwargs: Forwarded to ``_clip_ebc`` only. The regressor and classifier
            paths do not receive them.

    Returns:
        Whatever the selected builder returns.

    Raises:
        AssertionError: If a CLIP backbone is not in ``clip_names``, or if
            exactly one of ``bins`` / ``anchor_points`` is ``None`` on the
            non-CLIP path.
    """
    backbone = backbone.lower()

    if "clip" in backbone:
        # Drop the five-character prefix to obtain the bare backbone name.
        backbone = backbone[5:]
        # Raised explicitly instead of via `assert` so the check still runs
        # under `python -O`; AssertionError is kept so the exception type seen
        # by existing callers does not change.
        if backbone not in clip_names:
            raise AssertionError(f"Expected backbone to be in {clip_names}, got {backbone}")
        return _clip_ebc(
            backbone=backbone,
            input_size=input_size,
            reduction=reduction,
            bins=bins,
            anchor_points=anchor_points,
            **kwargs
        )

    if bins is None and anchor_points is None:
        return _regressor(
            backbone=backbone,
            input_size=input_size,
            reduction=reduction,
        )

    # Reaching here means at least one of the two was given; both are required.
    if bins is None or anchor_points is None:
        raise AssertionError(
            f"Expected bins and anchor_points to be both None or not None, got {bins} and {anchor_points}"
        )
    return _classifier(
        backbone=backbone,
        input_size=input_size,
        reduction=reduction,
        bins=bins,
        anchor_points=anchor_points,
    )


# Names exported by `from <this module> import *`.
__all__ = ["get_model"]