Nick Konovalchuk commited on
Commit
be88154
·
1 Parent(s): 5f8446e

feat(hub): loading a custom model with `torch.hub.load` (#1396)

Browse files
Files changed (2) hide show
  1. hubconf.py +3 -0
  2. yolox/models/build.py +44 -24
hubconf.py CHANGED
@@ -5,6 +5,8 @@
5
  Usage example:
6
  import torch
7
  model = torch.hub.load("Megvii-BaseDetection/YOLOX", "yolox_s")
 
 
8
  """
9
  dependencies = ["torch"]
10
 
@@ -16,4 +18,5 @@ from yolox.models import ( # isort:skip # noqa: F401, E402
16
  yolox_l,
17
  yolox_x,
18
  yolov3,
 
19
  )
 
5
  Usage example:
6
  import torch
7
  model = torch.hub.load("Megvii-BaseDetection/YOLOX", "yolox_s")
8
+ model = torch.hub.load("Megvii-BaseDetection/YOLOX", "yolox_custom",
9
+ exp_path="exp.py", ckpt_path="ckpt.pth")
10
  """
11
  dependencies = ["torch"]
12
 
 
18
  yolox_l,
19
  yolox_x,
20
  yolov3,
21
+ yolox_custom
22
  )
yolox/models/build.py CHANGED
@@ -14,6 +14,7 @@ __all__ = [
14
  "yolox_l",
15
  "yolox_x",
16
  "yolov3",
 
17
  ]
18
 
19
  _CKPT_ROOT_URL = "https://github.com/Megvii-BaseDetection/YOLOX/releases/download"
@@ -28,16 +29,20 @@ _CKPT_FULL_PATH = {
28
  }
29
 
30
 
31
- def create_yolox_model(
32
- name: str, pretrained: bool = True, num_classes: int = 80, device=None
33
- ) -> nn.Module:
34
  """creates and loads a YOLOX model
35
 
36
  Args:
37
- name (str): name of model. for example, "yolox-s", "yolox-tiny".
 
38
  pretrained (bool): load pretrained weights into the model. Default to True.
39
- num_classes (int): number of model classes. Defalut to 80.
40
- device (str): default device to for model. Defalut to None.
 
 
 
 
41
 
42
  Returns:
43
  YOLOX model (nn.Module)
@@ -48,44 +53,59 @@ def create_yolox_model(
48
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
49
  device = torch.device(device)
50
 
51
- assert name in _CKPT_FULL_PATH, f"user should use one of value in {_CKPT_FULL_PATH.keys()}"
52
- exp: Exp = get_exp(exp_name=name)
53
- exp.num_classes = num_classes
54
- yolox_model = exp.get_model()
55
- if pretrained and num_classes == 80:
56
- weights_url = _CKPT_FULL_PATH[name]
57
- ckpt = load_state_dict_from_url(weights_url, map_location="cpu")
58
- if "model" in ckpt:
59
- ckpt = ckpt["model"]
60
- yolox_model.load_state_dict(ckpt)
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  yolox_model.to(device)
63
  return yolox_model
64
 
65
 
66
- def yolox_nano(pretrained=True, num_classes=80, device=None):
67
  return create_yolox_model("yolox-nano", pretrained, num_classes, device)
68
 
69
 
70
- def yolox_tiny(pretrained=True, num_classes=80, device=None):
71
  return create_yolox_model("yolox-tiny", pretrained, num_classes, device)
72
 
73
 
74
- def yolox_s(pretrained=True, num_classes=80, device=None):
75
  return create_yolox_model("yolox-s", pretrained, num_classes, device)
76
 
77
 
78
- def yolox_m(pretrained=True, num_classes=80, device=None):
79
  return create_yolox_model("yolox-m", pretrained, num_classes, device)
80
 
81
 
82
- def yolox_l(pretrained=True, num_classes=80, device=None):
83
  return create_yolox_model("yolox-l", pretrained, num_classes, device)
84
 
85
 
86
- def yolox_x(pretrained=True, num_classes=80, device=None):
87
  return create_yolox_model("yolox-x", pretrained, num_classes, device)
88
 
89
 
90
- def yolov3(pretrained=True, num_classes=80, device=None):
91
- return create_yolox_model("yolox-tiny", pretrained, num_classes, device)
 
 
 
 
 
14
  "yolox_l",
15
  "yolox_x",
16
  "yolov3",
17
+ "yolox_custom"
18
  ]
19
 
20
  _CKPT_ROOT_URL = "https://github.com/Megvii-BaseDetection/YOLOX/releases/download"
 
29
  }
30
 
31
 
32
+ def create_yolox_model(name: str, pretrained: bool = True, num_classes: int = 80, device=None,
33
+ exp_path: str = None, ckpt_path: str = None) -> nn.Module:
 
34
  """creates and loads a YOLOX model
35
 
36
  Args:
37
+ name (str): name of model. for example, "yolox-s", "yolox-tiny" or "yolox_custom"
38
+ if you want to load your own model.
39
  pretrained (bool): load pretrained weights into the model. Default to True.
40
+ device (str): default device to for model. Default to None.
41
+ num_classes (int): number of model classes. Default to 80.
42
+ exp_path (str): path to your own experiment file. Required if name="yolox_custom"
43
+ ckpt_path (str): path to your own ckpt. Required if name="yolox_custom" and you want to
44
+ load a pretrained model
45
+
46
 
47
  Returns:
48
  YOLOX model (nn.Module)
 
53
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
54
  device = torch.device(device)
55
 
56
+ assert name in _CKPT_FULL_PATH or name == "yolox_custom", \
57
+ f"user should use one of value in {_CKPT_FULL_PATH.keys()} or \"yolox_custom\""
58
+ if name in _CKPT_FULL_PATH:
59
+ exp: Exp = get_exp(exp_name=name)
60
+ exp.num_classes = num_classes
61
+ yolox_model = exp.get_model()
62
+ if pretrained and num_classes == 80:
63
+ weights_url = _CKPT_FULL_PATH[name]
64
+ ckpt = load_state_dict_from_url(weights_url, map_location="cpu")
65
+ if "model" in ckpt:
66
+ ckpt = ckpt["model"]
67
+ yolox_model.load_state_dict(ckpt)
68
+ else:
69
+ assert exp_path is not None, "for a \"yolox_custom\" model exp_path must be provided"
70
+ exp: Exp = get_exp(exp_file=exp_path)
71
+ yolox_model = exp.get_model()
72
+ if ckpt_path:
73
+ ckpt = torch.load(ckpt_path, map_location="cpu")
74
+ if "model" in ckpt:
75
+ ckpt = ckpt["model"]
76
+ yolox_model.load_state_dict(ckpt)
77
 
78
  yolox_model.to(device)
79
  return yolox_model
80
 
81
 
82
+ def yolox_nano(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
83
  return create_yolox_model("yolox-nano", pretrained, num_classes, device)
84
 
85
 
86
+ def yolox_tiny(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
87
  return create_yolox_model("yolox-tiny", pretrained, num_classes, device)
88
 
89
 
90
+ def yolox_s(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
91
  return create_yolox_model("yolox-s", pretrained, num_classes, device)
92
 
93
 
94
+ def yolox_m(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
95
  return create_yolox_model("yolox-m", pretrained, num_classes, device)
96
 
97
 
98
+ def yolox_l(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
99
  return create_yolox_model("yolox-l", pretrained, num_classes, device)
100
 
101
 
102
+ def yolox_x(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
103
  return create_yolox_model("yolox-x", pretrained, num_classes, device)
104
 
105
 
106
+ def yolov3(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
107
+ return create_yolox_model("yolov3", pretrained, num_classes, device)
108
+
109
+
110
+ def yolox_custom(ckpt_path: str = None, exp_path: str = None, device: str = None) -> nn.Module:
111
+ return create_yolox_model("yolox_custom", ckpt_path=ckpt_path, exp_path=exp_path, device=device)