PyTorch Hub custom model to CUDA device fix (#2636)
Browse files- hubconf.py +4 -1
hubconf.py
CHANGED
@@ -128,7 +128,10 @@ def custom(path_or_model='path/to/model.pt', autoshape=True):
|
|
128 |
hub_model = Model(model.yaml).to(next(model.parameters()).device) # create
|
129 |
hub_model.load_state_dict(model.float().state_dict()) # load state_dict
|
130 |
hub_model.names = model.names # class names
|
131 |
-
|
|
|
|
|
|
|
132 |
|
133 |
|
134 |
if __name__ == '__main__':
|
|
|
128 |
hub_model = Model(model.yaml).to(next(model.parameters()).device) # create
|
129 |
hub_model.load_state_dict(model.float().state_dict()) # load state_dict
|
130 |
hub_model.names = model.names # class names
|
131 |
+
if autoshape:
|
132 |
+
hub_model = hub_model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
|
133 |
+
device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available
|
134 |
+
return hub_model.to(device)
|
135 |
|
136 |
|
137 |
if __name__ == '__main__':
|