glenn-jocher commited on
Commit
ee16983
·
unverified ·
1 Parent(s): 2e95cf3

PyTorch Hub custom model to CUDA device fix (#2636)

Browse files
Files changed (1) hide show
  1. hubconf.py +4 -1
hubconf.py CHANGED
@@ -128,7 +128,10 @@ def custom(path_or_model='path/to/model.pt', autoshape=True):
128
  hub_model = Model(model.yaml).to(next(model.parameters()).device) # create
129
  hub_model.load_state_dict(model.float().state_dict()) # load state_dict
130
  hub_model.names = model.names # class names
131
- return hub_model.autoshape() if autoshape else hub_model
 
 
 
132
 
133
 
134
  if __name__ == '__main__':
 
128
  hub_model = Model(model.yaml).to(next(model.parameters()).device) # create
129
  hub_model.load_state_dict(model.float().state_dict()) # load state_dict
130
  hub_model.names = model.names # class names
131
+ if autoshape:
132
+ hub_model = hub_model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
133
+ device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available
134
+ return hub_model.to(device)
135
 
136
 
137
  if __name__ == '__main__':