Add `device` argument to PyTorch Hub models (#3104)
Browse files* Allow to manual selection of device for torchhub models
* single line device
nested torch.device(torch.device(device)) ok
Co-authored-by: Glenn Jocher <[email protected]>
- hubconf.py +21 -20
hubconf.py
CHANGED
@@ -8,7 +8,7 @@ Usage:
|
|
8 |
import torch
|
9 |
|
10 |
|
11 |
-
def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
|
12 |
"""Creates a specified YOLOv5 model
|
13 |
|
14 |
Arguments:
|
@@ -18,6 +18,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
|
|
18 |
classes (int): number of model classes
|
19 |
autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
|
20 |
verbose (bool): print all information to screen
|
|
|
21 |
|
22 |
Returns:
|
23 |
YOLOv5 pytorch model
|
@@ -50,7 +51,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
|
|
50 |
model.names = ckpt['model'].names # set class names attribute
|
51 |
if autoshape:
|
52 |
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
|
53 |
-
device = select_device('0' if torch.cuda.is_available() else 'cpu')
|
54 |
return model.to(device)
|
55 |
|
56 |
except Exception as e:
|
@@ -59,49 +60,49 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
|
|
59 |
raise Exception(s) from e
|
60 |
|
61 |
|
62 |
-
def custom(path='path/to/model.pt', autoshape=True, verbose=True):
|
63 |
# YOLOv5 custom or local model
|
64 |
-
return _create(path, autoshape=autoshape, verbose=verbose)
|
65 |
|
66 |
|
67 |
-
def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
|
68 |
# YOLOv5-small model https://github.com/ultralytics/yolov5
|
69 |
-
return _create('yolov5s', pretrained, channels, classes, autoshape, verbose)
|
70 |
|
71 |
|
72 |
-
def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
|
73 |
# YOLOv5-medium model https://github.com/ultralytics/yolov5
|
74 |
-
return _create('yolov5m', pretrained, channels, classes, autoshape, verbose)
|
75 |
|
76 |
|
77 |
-
def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
|
78 |
# YOLOv5-large model https://github.com/ultralytics/yolov5
|
79 |
-
return _create('yolov5l', pretrained, channels, classes, autoshape, verbose)
|
80 |
|
81 |
|
82 |
-
def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
|
83 |
# YOLOv5-xlarge model https://github.com/ultralytics/yolov5
|
84 |
-
return _create('yolov5x', pretrained, channels, classes, autoshape, verbose)
|
85 |
|
86 |
|
87 |
-
def yolov5s6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
|
88 |
# YOLOv5-small-P6 model https://github.com/ultralytics/yolov5
|
89 |
-
return _create('yolov5s6', pretrained, channels, classes, autoshape, verbose)
|
90 |
|
91 |
|
92 |
-
def yolov5m6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
|
93 |
# YOLOv5-medium-P6 model https://github.com/ultralytics/yolov5
|
94 |
-
return _create('yolov5m6', pretrained, channels, classes, autoshape, verbose)
|
95 |
|
96 |
|
97 |
-
def yolov5l6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
|
98 |
# YOLOv5-large-P6 model https://github.com/ultralytics/yolov5
|
99 |
-
return _create('yolov5l6', pretrained, channels, classes, autoshape, verbose)
|
100 |
|
101 |
|
102 |
-
def yolov5x6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
|
103 |
# YOLOv5-xlarge-P6 model https://github.com/ultralytics/yolov5
|
104 |
-
return _create('yolov5x6', pretrained, channels, classes, autoshape, verbose)
|
105 |
|
106 |
|
107 |
if __name__ == '__main__':
|
|
|
8 |
import torch
|
9 |
|
10 |
|
11 |
+
def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
|
12 |
"""Creates a specified YOLOv5 model
|
13 |
|
14 |
Arguments:
|
|
|
18 |
classes (int): number of model classes
|
19 |
autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
|
20 |
verbose (bool): print all information to screen
|
21 |
+
device (str, torch.device, None): device to use for model parameters
|
22 |
|
23 |
Returns:
|
24 |
YOLOv5 pytorch model
|
|
|
51 |
model.names = ckpt['model'].names # set class names attribute
|
52 |
if autoshape:
|
53 |
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
|
54 |
+
device = select_device('0' if torch.cuda.is_available() else 'cpu') if device is None else torch.device(device)
|
55 |
return model.to(device)
|
56 |
|
57 |
except Exception as e:
|
|
|
60 |
raise Exception(s) from e
|
61 |
|
62 |
|
63 |
+
def custom(path='path/to/model.pt', autoshape=True, verbose=True, device=None):
|
64 |
# YOLOv5 custom or local model
|
65 |
+
return _create(path, autoshape=autoshape, verbose=verbose, device=device)
|
66 |
|
67 |
|
68 |
+
def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
|
69 |
# YOLOv5-small model https://github.com/ultralytics/yolov5
|
70 |
+
return _create('yolov5s', pretrained, channels, classes, autoshape, verbose, device)
|
71 |
|
72 |
|
73 |
+
def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
|
74 |
# YOLOv5-medium model https://github.com/ultralytics/yolov5
|
75 |
+
return _create('yolov5m', pretrained, channels, classes, autoshape, verbose, device)
|
76 |
|
77 |
|
78 |
+
def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
|
79 |
# YOLOv5-large model https://github.com/ultralytics/yolov5
|
80 |
+
return _create('yolov5l', pretrained, channels, classes, autoshape, verbose, device)
|
81 |
|
82 |
|
83 |
+
def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
|
84 |
# YOLOv5-xlarge model https://github.com/ultralytics/yolov5
|
85 |
+
return _create('yolov5x', pretrained, channels, classes, autoshape, verbose, device)
|
86 |
|
87 |
|
88 |
+
def yolov5s6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
|
89 |
# YOLOv5-small-P6 model https://github.com/ultralytics/yolov5
|
90 |
+
return _create('yolov5s6', pretrained, channels, classes, autoshape, verbose, device)
|
91 |
|
92 |
|
93 |
+
def yolov5m6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
|
94 |
# YOLOv5-medium-P6 model https://github.com/ultralytics/yolov5
|
95 |
+
return _create('yolov5m6', pretrained, channels, classes, autoshape, verbose, device)
|
96 |
|
97 |
|
98 |
+
def yolov5l6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
|
99 |
# YOLOv5-large-P6 model https://github.com/ultralytics/yolov5
|
100 |
+
return _create('yolov5l6', pretrained, channels, classes, autoshape, verbose, device)
|
101 |
|
102 |
|
103 |
+
def yolov5x6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
|
104 |
# YOLOv5-xlarge-P6 model https://github.com/ultralytics/yolov5
|
105 |
+
return _create('yolov5x6', pretrained, channels, classes, autoshape, verbose, device)
|
106 |
|
107 |
|
108 |
if __name__ == '__main__':
|