PyTorch 1.7.0 Compatibility Updates (#1233)
Browse files* torch 1.7.0 compatibility updates
* add inference verification
- hubconf.py +8 -0
- models/experimental.py +7 -0
- models/yolo.py +0 -1
- utils/torch_utils.py +1 -1
hubconf.py
CHANGED
@@ -108,3 +108,11 @@ def yolov5x(pretrained=False, channels=3, classes=80):
|
|
108 |
|
109 |
if __name__ == '__main__':
|
110 |
model = create(name='yolov5s', pretrained=True, channels=3, classes=80) # example
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
if __name__ == '__main__':
|
110 |
model = create(name='yolov5s', pretrained=True, channels=3, classes=80) # example
|
111 |
+
model = model.fuse().eval().autoshape() # for autoshaping of PIL/cv2/np inputs and NMS
|
112 |
+
|
113 |
+
# Verify inference
|
114 |
+
from PIL import Image
|
115 |
+
|
116 |
+
img = Image.open('inference/images/zidane.jpg')
|
117 |
+
y = model(img)
|
118 |
+
print(y[0].shape)
|
models/experimental.py
CHANGED
@@ -136,6 +136,13 @@ def attempt_load(weights, map_location=None):
|
|
136 |
attempt_download(w)
|
137 |
model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval()) # load FP32 model
|
138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
if len(model) == 1:
|
140 |
return model[-1] # return model
|
141 |
else:
|
|
|
136 |
attempt_download(w)
|
137 |
model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval()) # load FP32 model
|
138 |
|
139 |
+
# Compatibility updates
|
140 |
+
for m in model.modules():
|
141 |
+
if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
|
142 |
+
m.inplace = True # pytorch 1.7.0 compatibility
|
143 |
+
elif type(m) is Conv:
|
144 |
+
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
|
145 |
+
|
146 |
if len(model) == 1:
|
147 |
return model[-1] # return model
|
148 |
else:
|
models/yolo.py
CHANGED
@@ -165,7 +165,6 @@ class Model(nn.Module):
|
|
165 |
print('Fusing layers... ')
|
166 |
for m in self.model.modules():
|
167 |
if type(m) is Conv and hasattr(m, 'bn'):
|
168 |
-
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability
|
169 |
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
170 |
delattr(m, 'bn') # remove batchnorm
|
171 |
m.forward = m.fuseforward # update forward
|
|
|
165 |
print('Fusing layers... ')
|
166 |
for m in self.model.modules():
|
167 |
if type(m) is Conv and hasattr(m, 'bn'):
|
|
|
168 |
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
169 |
delattr(m, 'bn') # remove batchnorm
|
170 |
m.forward = m.fuseforward # update forward
|
utils/torch_utils.py
CHANGED
@@ -74,7 +74,7 @@ def initialize_weights(model):
|
|
74 |
elif t is nn.BatchNorm2d:
|
75 |
m.eps = 1e-3
|
76 |
m.momentum = 0.03
|
77 |
-
elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
|
78 |
m.inplace = True
|
79 |
|
80 |
|
|
|
74 |
elif t is nn.BatchNorm2d:
|
75 |
m.eps = 1e-3
|
76 |
m.momentum = 0.03
|
77 |
+
elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
|
78 |
m.inplace = True
|
79 |
|
80 |
|