Simplified PyTorch hub for custom models (#1677)
Browse files- hubconf.py +24 -1
hubconf.py
CHANGED
@@ -106,8 +106,31 @@ def yolov5x(pretrained=False, channels=3, classes=80):
|
|
106 |
return create('yolov5x', pretrained, channels, classes)
|
107 |
|
108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
if __name__ == '__main__':
|
110 |
-
model = create(name='yolov5s', pretrained=True, channels=3, classes=80) # example
|
|
|
111 |
model = model.autoshape() # for PIL/cv2/np inputs and NMS
|
112 |
|
113 |
# Verify inference
|
|
|
106 |
return create('yolov5x', pretrained, channels, classes)
|
107 |
|
108 |
|
109 |
+
def custom(model='path/to/model.pt'):
|
110 |
+
"""YOLOv5-custom model from https://github.com/ultralytics/yolov5
|
111 |
+
|
112 |
+
Arguments (3 format options):
|
113 |
+
model (str): 'path/to/model.pt'
|
114 |
+
model (dict): torch.load('path/to/model.pt')
|
115 |
+
model (nn.Module): 'torch.load('path/to/model.pt')['model']
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
pytorch model
|
119 |
+
"""
|
120 |
+
if isinstance(model, str):
|
121 |
+
model = torch.load(model) # load checkpoint
|
122 |
+
if isinstance(model, dict):
|
123 |
+
model = model['model'] # load model
|
124 |
+
|
125 |
+
hub_model = Model(model.yaml).to(next(model.parameters()).device) # create
|
126 |
+
hub_model.load_state_dict(model.float().state_dict()) # load state_dict
|
127 |
+
hub_model.names = model.names # class names
|
128 |
+
return hub_model
|
129 |
+
|
130 |
+
|
131 |
if __name__ == '__main__':
|
132 |
+
model = create(name='yolov5s', pretrained=True, channels=3, classes=80) # pretrained example
|
133 |
+
# model = custom(model='path/to/model.pt') # custom example
|
134 |
model = model.autoshape() # for PIL/cv2/np inputs and NMS
|
135 |
|
136 |
# Verify inference
|