rename class autoShape -> AutoShape (#3173)
Browse files* rename class autoShape -> AutoShape
follow other class naming convention
* rename class autoShape -> AutoShape
follow other classes' naming convention
* rename class autoShape -> AutoShape
- models/common.py +3 -3
- models/yolo.py +3 -3
models/common.py
CHANGED
@@ -223,18 +223,18 @@ class NMS(nn.Module):
|
|
223 |
return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
|
224 |
|
225 |
|
226 |
-
class
|
227 |
# input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
|
228 |
conf = 0.25 # NMS confidence threshold
|
229 |
iou = 0.45 # NMS IoU threshold
|
230 |
classes = None # (optional list) filter by class
|
231 |
|
232 |
def __init__(self, model):
|
233 |
-
super(
|
234 |
self.model = model.eval()
|
235 |
|
236 |
def autoshape(self):
|
237 |
-
print('
|
238 |
return self
|
239 |
|
240 |
@torch.no_grad()
|
|
|
223 |
return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
|
224 |
|
225 |
|
226 |
+
class AutoShape(nn.Module):
|
227 |
# input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
|
228 |
conf = 0.25 # NMS confidence threshold
|
229 |
iou = 0.45 # NMS IoU threshold
|
230 |
classes = None # (optional list) filter by class
|
231 |
|
232 |
def __init__(self, model):
|
233 |
+
super(AutoShape, self).__init__()
|
234 |
self.model = model.eval()
|
235 |
|
236 |
def autoshape(self):
|
237 |
+
print('AutoShape already enabled, skipping... ') # model already converted to model.autoshape()
|
238 |
return self
|
239 |
|
240 |
@torch.no_grad()
|
models/yolo.py
CHANGED
@@ -215,9 +215,9 @@ class Model(nn.Module):
|
|
215 |
self.model = self.model[:-1] # remove
|
216 |
return self
|
217 |
|
218 |
-
def autoshape(self): # add
|
219 |
-
logger.info('Adding
|
220 |
-
m =
|
221 |
copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
|
222 |
return m
|
223 |
|
|
|
215 |
self.model = self.model[:-1] # remove
|
216 |
return self
|
217 |
|
218 |
+
def autoshape(self): # add AutoShape module
|
219 |
+
logger.info('Adding AutoShape... ')
|
220 |
+
m = AutoShape(self) # wrap model
|
221 |
copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
|
222 |
return m
|
223 |
|