glenn-jocher commited on
Commit
a3ecf0f
·
unverified ·
1 Parent(s): fe6ebb9

Anchor override (#2350)

Browse files
Files changed (2) hide show
  1. models/yolo.py +5 -2
  2. train.py +1 -1
models/yolo.py CHANGED
@@ -62,7 +62,7 @@ class Detect(nn.Module):
62
 
63
 
64
  class Model(nn.Module):
65
- def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes
66
  super(Model, self).__init__()
67
  if isinstance(cfg, dict):
68
  self.yaml = cfg # model dict
@@ -75,8 +75,11 @@ class Model(nn.Module):
75
  # Define model
76
  ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
77
  if nc and nc != self.yaml['nc']:
78
- logger.info('Overriding model.yaml nc=%g with nc=%g' % (self.yaml['nc'], nc))
79
  self.yaml['nc'] = nc # override yaml value
 
 
 
80
  self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
81
  self.names = [str(i) for i in range(self.yaml['nc'])] # default names
82
  # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
 
62
 
63
 
64
  class Model(nn.Module):
65
+ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
66
  super(Model, self).__init__()
67
  if isinstance(cfg, dict):
68
  self.yaml = cfg # model dict
 
75
  # Define model
76
  ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
77
  if nc and nc != self.yaml['nc']:
78
+ logger.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
79
  self.yaml['nc'] = nc # override yaml value
80
+ if anchors:
81
+ logger.info(f'Overriding model.yaml anchors with anchors={anchors}')
82
+ self.yaml['anchors'] = round(anchors) # override yaml value
83
  self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
84
  self.names = [str(i) for i in range(self.yaml['nc'])] # default names
85
  # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
train.py CHANGED
@@ -84,7 +84,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
84
  model.load_state_dict(state_dict, strict=False) # load
85
  logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
86
  else:
87
- model = Model(opt.cfg, ch=3, nc=nc).to(device) # create
88
 
89
  # Freeze
90
  freeze = [] # parameter names to freeze (full or partial)
 
84
  model.load_state_dict(state_dict, strict=False) # load
85
  logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
86
  else:
87
+ model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
88
 
89
  # Freeze
90
  freeze = [] # parameter names to freeze (full or partial)