glenn-jocher commited on
Commit
87ca35b
·
unverified ·
1 Parent(s): 54043a9

Simplified PyTorch hub for custom models (#1677)

Browse files
Files changed (1) hide show
  1. hubconf.py +24 -1
hubconf.py CHANGED
@@ -106,8 +106,31 @@ def yolov5x(pretrained=False, channels=3, classes=80):
106
  return create('yolov5x', pretrained, channels, classes)
107
 
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  if __name__ == '__main__':
110
- model = create(name='yolov5s', pretrained=True, channels=3, classes=80) # example
 
111
  model = model.autoshape() # for PIL/cv2/np inputs and NMS
112
 
113
  # Verify inference
 
106
  return create('yolov5x', pretrained, channels, classes)
107
 
108
 
109
+ def custom(model='path/to/model.pt'):
110
+ """YOLOv5-custom model from https://github.com/ultralytics/yolov5
111
+
112
+ Arguments (3 format options):
113
+ model (str): 'path/to/model.pt'
114
+ model (dict): torch.load('path/to/model.pt')
115
+ model (nn.Module): 'torch.load('path/to/model.pt')['model']
116
+
117
+ Returns:
118
+ pytorch model
119
+ """
120
+ if isinstance(model, str):
121
+ model = torch.load(model) # load checkpoint
122
+ if isinstance(model, dict):
123
+ model = model['model'] # load model
124
+
125
+ hub_model = Model(model.yaml).to(next(model.parameters()).device) # create
126
+ hub_model.load_state_dict(model.float().state_dict()) # load state_dict
127
+ hub_model.names = model.names # class names
128
+ return hub_model
129
+
130
+
131
  if __name__ == '__main__':
132
+ model = create(name='yolov5s', pretrained=True, channels=3, classes=80) # pretrained example
133
+ # model = custom(model='path/to/model.pt') # custom example
134
  model = model.autoshape() # for PIL/cv2/np inputs and NMS
135
 
136
  # Verify inference