Feng Wang commited on
Commit
27cdfe4
·
1 Parent(s): f609557

feat(model): support hub load

Browse files
Files changed (3) hide show
  1. hubconf.py +19 -0
  2. yolox/models/__init__.py +1 -0
  3. yolox/models/build.py +91 -0
hubconf.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+
4
+ """
5
+ Usage example:
6
+ import torch
7
+ model = torch.hub.load("Megvii-BaseDetection/YOLOX", "yolox_s")
8
+ """
9
+ dependencies = ["torch"]
10
+
11
+ from yolox.models import ( # noqa: F401, E402
12
+ yolox_tiny,
13
+ yolox_nano,
14
+ yolox_s,
15
+ yolox_m,
16
+ yolox_l,
17
+ yolox_x,
18
+ yolov3,
19
+ )
yolox/models/__init__.py CHANGED
@@ -2,6 +2,7 @@
2
  # -*- coding:utf-8 -*-
3
  # Copyright (c) Megvii Inc. All rights reserved.
4
 
 
5
  from .darknet import CSPDarknet, Darknet
6
  from .losses import IOUloss
7
  from .yolo_fpn import YOLOFPN
 
2
  # -*- coding:utf-8 -*-
3
  # Copyright (c) Megvii Inc. All rights reserved.
4
 
5
+ from .build import *
6
  from .darknet import CSPDarknet, Darknet
7
  from .losses import IOUloss
8
  from .yolo_fpn import YOLOFPN
yolox/models/build.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.hub import load_state_dict_from_url
7
+
8
+ __all__ = [
9
+ "create_yolox_model",
10
+ "yolox_nano",
11
+ "yolox_tiny",
12
+ "yolox_s",
13
+ "yolox_m",
14
+ "yolox_l",
15
+ "yolox_x",
16
+ "yolov3",
17
+ ]
18
+
19
+ _CKPT_ROOT_URL = "https://github.com/Megvii-BaseDetection/YOLOX/releases/download"
20
+ _CKPT_FULL_PATH = {
21
+ "yolox-nano": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_nano.pth",
22
+ "yolox-tiny": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_tiny.pth",
23
+ "yolox-s": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_s.pth",
24
+ "yolox-m": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_m.pth",
25
+ "yolox-l": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_l.pth",
26
+ "yolox-x": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_x.pth",
27
+ "yolov3": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_darknet.pth",
28
+ }
29
+
30
+
31
+ def create_yolox_model(
32
+ name: str, pretrained: bool = True, num_classes: int = 80, device=None
33
+ ) -> nn.Module:
34
+ """creates and loads a YOLOX model
35
+
36
+ Args:
37
+ name (str): name of model. for example, "yolox-s", "yolox-tiny".
38
+ pretrained (bool): load pretrained weights into the model. Default to True.
39
+ num_classes (int): number of model classes. Defalut to 80.
40
+ device (str): default device to for model. Defalut to None.
41
+
42
+ Returns:
43
+ YOLOX model (nn.Module)
44
+ """
45
+ from yolox.exp import get_exp, Exp
46
+
47
+ if device is None:
48
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
49
+ device = torch.device(device)
50
+
51
+ assert name in _CKPT_FULL_PATH, f"user should use one of value in {_CKPT_FULL_PATH.keys()}"
52
+ exp: Exp = get_exp(exp_name=name)
53
+ exp.num_classes = num_classes
54
+ yolox_model = exp.get_model()
55
+ if pretrained and num_classes == 80:
56
+ weights_url = _CKPT_FULL_PATH[name]
57
+ ckpt = load_state_dict_from_url(weights_url, map_location="cpu")
58
+ if "model" in ckpt:
59
+ ckpt = ckpt["model"]
60
+ yolox_model.load_state_dict(ckpt)
61
+
62
+ yolox_model.to(device)
63
+ return yolox_model
64
+
65
+
66
+ def yolox_nano(pretrained=True, num_classes=80, device=None):
67
+ return create_yolox_model("yolox-nano", pretrained, num_classes, device)
68
+
69
+
70
+ def yolox_tiny(pretrained=True, num_classes=80, device=None):
71
+ return create_yolox_model("yolox-tiny", pretrained, num_classes, device)
72
+
73
+
74
+ def yolox_s(pretrained=True, num_classes=80, device=None):
75
+ return create_yolox_model("yolox-s", pretrained, num_classes, device)
76
+
77
+
78
+ def yolox_m(pretrained=True, num_classes=80, device=None):
79
+ return create_yolox_model("yolox-m", pretrained, num_classes, device)
80
+
81
+
82
+ def yolox_l(pretrained=True, num_classes=80, device=None):
83
+ return create_yolox_model("yolox-l", pretrained, num_classes, device)
84
+
85
+
86
+ def yolox_x(pretrained=True, num_classes=80, device=None):
87
+ return create_yolox_model("yolox-x", pretrained, num_classes, device)
88
+
89
+
90
+ def yolov3(pretrained=True, num_classes=80, device=None):
91
+ return create_yolox_model("yolox-tiny", pretrained, num_classes, device)