Nick Konovalchuk
commited on
Commit
·
be88154
1
Parent(s):
5f8446e
feat(hub): loading a custom model with `torch.hub.load` (#1396)
Browse files- hubconf.py +3 -0
- yolox/models/build.py +44 -24
hubconf.py
CHANGED
@@ -5,6 +5,8 @@
|
|
5 |
Usage example:
|
6 |
import torch
|
7 |
model = torch.hub.load("Megvii-BaseDetection/YOLOX", "yolox_s")
|
|
|
|
|
8 |
"""
|
9 |
dependencies = ["torch"]
|
10 |
|
@@ -16,4 +18,5 @@ from yolox.models import ( # isort:skip # noqa: F401, E402
|
|
16 |
yolox_l,
|
17 |
yolox_x,
|
18 |
yolov3,
|
|
|
19 |
)
|
|
|
5 |
Usage example:
|
6 |
import torch
|
7 |
model = torch.hub.load("Megvii-BaseDetection/YOLOX", "yolox_s")
|
8 |
+
model = torch.hub.load("Megvii-BaseDetection/YOLOX", "yolox_custom",
|
9 |
+
exp_path="exp.py", ckpt_path="ckpt.pth")
|
10 |
"""
|
11 |
dependencies = ["torch"]
|
12 |
|
|
|
18 |
yolox_l,
|
19 |
yolox_x,
|
20 |
yolov3,
|
21 |
+
yolox_custom
|
22 |
)
|
yolox/models/build.py
CHANGED
@@ -14,6 +14,7 @@ __all__ = [
|
|
14 |
"yolox_l",
|
15 |
"yolox_x",
|
16 |
"yolov3",
|
|
|
17 |
]
|
18 |
|
19 |
_CKPT_ROOT_URL = "https://github.com/Megvii-BaseDetection/YOLOX/releases/download"
|
@@ -28,16 +29,20 @@ _CKPT_FULL_PATH = {
|
|
28 |
}
|
29 |
|
30 |
|
31 |
-
def create_yolox_model(
|
32 |
-
|
33 |
-
) -> nn.Module:
|
34 |
"""creates and loads a YOLOX model
|
35 |
|
36 |
Args:
|
37 |
-
name (str): name of model. for example, "yolox-s", "yolox-tiny"
|
|
|
38 |
pretrained (bool): load pretrained weights into the model. Default to True.
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
41 |
|
42 |
Returns:
|
43 |
YOLOX model (nn.Module)
|
@@ -48,44 +53,59 @@ def create_yolox_model(
|
|
48 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
49 |
device = torch.device(device)
|
50 |
|
51 |
-
assert name in _CKPT_FULL_PATH
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
ckpt =
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
yolox_model.to(device)
|
63 |
return yolox_model
|
64 |
|
65 |
|
66 |
-
def yolox_nano(pretrained=True, num_classes=80, device=None):
|
67 |
return create_yolox_model("yolox-nano", pretrained, num_classes, device)
|
68 |
|
69 |
|
70 |
-
def yolox_tiny(pretrained=True, num_classes=80, device=None):
|
71 |
return create_yolox_model("yolox-tiny", pretrained, num_classes, device)
|
72 |
|
73 |
|
74 |
-
def yolox_s(pretrained=True, num_classes=80, device=None):
|
75 |
return create_yolox_model("yolox-s", pretrained, num_classes, device)
|
76 |
|
77 |
|
78 |
-
def yolox_m(pretrained=True, num_classes=80, device=None):
|
79 |
return create_yolox_model("yolox-m", pretrained, num_classes, device)
|
80 |
|
81 |
|
82 |
-
def yolox_l(pretrained=True, num_classes=80, device=None):
|
83 |
return create_yolox_model("yolox-l", pretrained, num_classes, device)
|
84 |
|
85 |
|
86 |
-
def yolox_x(pretrained=True, num_classes=80, device=None):
|
87 |
return create_yolox_model("yolox-x", pretrained, num_classes, device)
|
88 |
|
89 |
|
90 |
-
def yolov3(pretrained=True, num_classes=80, device=None):
|
91 |
-
return create_yolox_model("
|
|
|
|
|
|
|
|
|
|
14 |
"yolox_l",
|
15 |
"yolox_x",
|
16 |
"yolov3",
|
17 |
+
"yolox_custom"
|
18 |
]
|
19 |
|
20 |
_CKPT_ROOT_URL = "https://github.com/Megvii-BaseDetection/YOLOX/releases/download"
|
|
|
29 |
}
|
30 |
|
31 |
|
32 |
+
def create_yolox_model(name: str, pretrained: bool = True, num_classes: int = 80, device=None,
|
33 |
+
exp_path: str = None, ckpt_path: str = None) -> nn.Module:
|
|
|
34 |
"""creates and loads a YOLOX model
|
35 |
|
36 |
Args:
|
37 |
+
name (str): name of model. for example, "yolox-s", "yolox-tiny" or "yolox_custom"
|
38 |
+
if you want to load your own model.
|
39 |
pretrained (bool): load pretrained weights into the model. Default to True.
|
40 |
+
device (str): default device to for model. Default to None.
|
41 |
+
num_classes (int): number of model classes. Default to 80.
|
42 |
+
exp_path (str): path to your own experiment file. Required if name="yolox_custom"
|
43 |
+
ckpt_path (str): path to your own ckpt. Required if name="yolox_custom" and you want to
|
44 |
+
load a pretrained model
|
45 |
+
|
46 |
|
47 |
Returns:
|
48 |
YOLOX model (nn.Module)
|
|
|
53 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
54 |
device = torch.device(device)
|
55 |
|
56 |
+
assert name in _CKPT_FULL_PATH or name == "yolox_custom", \
|
57 |
+
f"user should use one of value in {_CKPT_FULL_PATH.keys()} or \"yolox_custom\""
|
58 |
+
if name in _CKPT_FULL_PATH:
|
59 |
+
exp: Exp = get_exp(exp_name=name)
|
60 |
+
exp.num_classes = num_classes
|
61 |
+
yolox_model = exp.get_model()
|
62 |
+
if pretrained and num_classes == 80:
|
63 |
+
weights_url = _CKPT_FULL_PATH[name]
|
64 |
+
ckpt = load_state_dict_from_url(weights_url, map_location="cpu")
|
65 |
+
if "model" in ckpt:
|
66 |
+
ckpt = ckpt["model"]
|
67 |
+
yolox_model.load_state_dict(ckpt)
|
68 |
+
else:
|
69 |
+
assert exp_path is not None, "for a \"yolox_custom\" model exp_path must be provided"
|
70 |
+
exp: Exp = get_exp(exp_file=exp_path)
|
71 |
+
yolox_model = exp.get_model()
|
72 |
+
if ckpt_path:
|
73 |
+
ckpt = torch.load(ckpt_path, map_location="cpu")
|
74 |
+
if "model" in ckpt:
|
75 |
+
ckpt = ckpt["model"]
|
76 |
+
yolox_model.load_state_dict(ckpt)
|
77 |
|
78 |
yolox_model.to(device)
|
79 |
return yolox_model
|
80 |
|
81 |
|
82 |
+
def yolox_nano(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
|
83 |
return create_yolox_model("yolox-nano", pretrained, num_classes, device)
|
84 |
|
85 |
|
86 |
+
def yolox_tiny(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
|
87 |
return create_yolox_model("yolox-tiny", pretrained, num_classes, device)
|
88 |
|
89 |
|
90 |
+
def yolox_s(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
|
91 |
return create_yolox_model("yolox-s", pretrained, num_classes, device)
|
92 |
|
93 |
|
94 |
+
def yolox_m(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
|
95 |
return create_yolox_model("yolox-m", pretrained, num_classes, device)
|
96 |
|
97 |
|
98 |
+
def yolox_l(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
|
99 |
return create_yolox_model("yolox-l", pretrained, num_classes, device)
|
100 |
|
101 |
|
102 |
+
def yolox_x(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
|
103 |
return create_yolox_model("yolox-x", pretrained, num_classes, device)
|
104 |
|
105 |
|
106 |
+
def yolov3(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
|
107 |
+
return create_yolox_model("yolov3", pretrained, num_classes, device)
|
108 |
+
|
109 |
+
|
110 |
+
def yolox_custom(ckpt_path: str = None, exp_path: str = None, device: str = None) -> nn.Module:
|
111 |
+
return create_yolox_model("yolox_custom", ckpt_path=ckpt_path, exp_path=exp_path, device=device)
|