Update hubconf.py for unified loading (#3005)
Browse files- hubconf.py +7 -27
hubconf.py
CHANGED
@@ -18,7 +18,7 @@ dependencies = ['torch', 'yaml']
|
|
18 |
check_requirements(Path(__file__).parent / 'requirements.txt', exclude=('tensorboard', 'pycocotools', 'thop'))
|
19 |
|
20 |
|
21 |
-
def create(name, pretrained, channels, classes, autoshape, verbose):
|
22 |
"""Creates a specified YOLOv5 model
|
23 |
|
24 |
Arguments:
|
@@ -33,7 +33,7 @@ def create(name, pretrained, channels, classes, autoshape, verbose):
|
|
33 |
YOLOv5 pytorch model
|
34 |
"""
|
35 |
set_logging(verbose=verbose)
|
36 |
-
fname =
|
37 |
try:
|
38 |
if pretrained and channels == 3 and classes == 80:
|
39 |
model = attempt_load(fname, map_location=torch.device('cpu')) # download/load FP32 model
|
@@ -60,30 +60,9 @@ def create(name, pretrained, channels, classes, autoshape, verbose):
|
|
60 |
raise Exception(s) from e
|
61 |
|
62 |
|
63 |
-
def custom(
|
64 |
-
|
65 |
-
|
66 |
-
Arguments (3 options):
|
67 |
-
path_or_model (str): 'path/to/model.pt'
|
68 |
-
path_or_model (dict): torch.load('path/to/model.pt')
|
69 |
-
path_or_model (nn.Module): torch.load('path/to/model.pt')['model']
|
70 |
-
|
71 |
-
Returns:
|
72 |
-
pytorch model
|
73 |
-
"""
|
74 |
-
set_logging(verbose=verbose)
|
75 |
-
|
76 |
-
model = torch.load(path_or_model) if isinstance(path_or_model, str) else path_or_model # load checkpoint
|
77 |
-
if isinstance(model, dict):
|
78 |
-
model = model['ema' if model.get('ema') else 'model'] # load model
|
79 |
-
|
80 |
-
hub_model = Model(model.yaml).to(next(model.parameters()).device) # create
|
81 |
-
hub_model.load_state_dict(model.float().state_dict()) # load state_dict
|
82 |
-
hub_model.names = model.names # class names
|
83 |
-
if autoshape:
|
84 |
-
hub_model = hub_model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
|
85 |
-
device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available
|
86 |
-
return hub_model.to(device)
|
87 |
|
88 |
|
89 |
def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
|
@@ -127,7 +106,8 @@ def yolov5x6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=Tr
|
|
127 |
|
128 |
|
129 |
if __name__ == '__main__':
|
130 |
-
model = create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True,
|
|
|
131 |
# model = custom(path_or_model='path/to/model.pt') # custom
|
132 |
|
133 |
# Verify inference
|
|
|
18 |
check_requirements(Path(__file__).parent / 'requirements.txt', exclude=('tensorboard', 'pycocotools', 'thop'))
|
19 |
|
20 |
|
21 |
+
def create(name, pretrained, channels=3, classes=80, autoshape=True, verbose=True):
|
22 |
"""Creates a specified YOLOv5 model
|
23 |
|
24 |
Arguments:
|
|
|
33 |
YOLOv5 pytorch model
|
34 |
"""
|
35 |
set_logging(verbose=verbose)
|
36 |
+
fname = Path(name).with_suffix('.pt') # checkpoint filename
|
37 |
try:
|
38 |
if pretrained and channels == 3 and classes == 80:
|
39 |
model = attempt_load(fname, map_location=torch.device('cpu')) # download/load FP32 model
|
|
|
60 |
raise Exception(s) from e
|
61 |
|
62 |
|
63 |
+
def custom(path='path/to/model.pt', autoshape=True, verbose=True):
|
64 |
+
# YOLOv5 custom or local model
|
65 |
+
return create(path, autoshape, verbose)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
|
68 |
def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
|
|
|
106 |
|
107 |
|
108 |
if __name__ == '__main__':
|
109 |
+
model = create(name='weights/yolov5s.pt', pretrained=True, channels=3, classes=80, autoshape=True,
|
110 |
+
verbose=True) # pretrained
|
111 |
# model = custom(path_or_model='path/to/model.pt') # custom
|
112 |
|
113 |
# Verify inference
|