developer0hye commited on
Commit
be86c21
·
unverified ·
1 Parent(s): 17b0f71

rename class autoShape -> AutoShape (#3173)

Browse files

* rename class autoShape -> AutoShape

follow other class naming convention

* rename class autoShape -> AutoShape

follow other classes' naming convention

* rename class autoShape -> AutoShape

Files changed (2) hide show
  1. models/common.py +3 -3
  2. models/yolo.py +3 -3
models/common.py CHANGED
@@ -223,18 +223,18 @@ class NMS(nn.Module):
223
  return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
224
 
225
 
226
- class autoShape(nn.Module):
227
  # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
228
  conf = 0.25 # NMS confidence threshold
229
  iou = 0.45 # NMS IoU threshold
230
  classes = None # (optional list) filter by class
231
 
232
  def __init__(self, model):
233
- super(autoShape, self).__init__()
234
  self.model = model.eval()
235
 
236
  def autoshape(self):
237
- print('autoShape already enabled, skipping... ') # model already converted to model.autoshape()
238
  return self
239
 
240
  @torch.no_grad()
 
223
  return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
224
 
225
 
226
+ class AutoShape(nn.Module):
227
  # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
228
  conf = 0.25 # NMS confidence threshold
229
  iou = 0.45 # NMS IoU threshold
230
  classes = None # (optional list) filter by class
231
 
232
  def __init__(self, model):
233
+ super(AutoShape, self).__init__()
234
  self.model = model.eval()
235
 
236
  def autoshape(self):
237
+ print('AutoShape already enabled, skipping... ') # model already converted to model.autoshape()
238
  return self
239
 
240
  @torch.no_grad()
models/yolo.py CHANGED
@@ -215,9 +215,9 @@ class Model(nn.Module):
215
  self.model = self.model[:-1] # remove
216
  return self
217
 
218
- def autoshape(self): # add autoShape module
219
- logger.info('Adding autoShape... ')
220
- m = autoShape(self) # wrap model
221
  copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
222
  return m
223
 
 
215
  self.model = self.model[:-1] # remove
216
  return self
217
 
218
+ def autoshape(self): # add AutoShape module
219
+ logger.info('Adding AutoShape... ')
220
+ m = AutoShape(self) # wrap model
221
  copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
222
  return m
223