Anchor override (#2350)
Browse files- models/yolo.py +5 -2
- train.py +1 -1
models/yolo.py
CHANGED
@@ -62,7 +62,7 @@ class Detect(nn.Module):
|
|
62 |
|
63 |
|
64 |
class Model(nn.Module):
|
65 |
-
def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes
|
66 |
super(Model, self).__init__()
|
67 |
if isinstance(cfg, dict):
|
68 |
self.yaml = cfg # model dict
|
@@ -75,8 +75,11 @@ class Model(nn.Module):
|
|
75 |
# Define model
|
76 |
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
|
77 |
if nc and nc != self.yaml['nc']:
|
78 |
-
logger.info(
|
79 |
self.yaml['nc'] = nc # override yaml value
|
|
|
|
|
|
|
80 |
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
|
81 |
self.names = [str(i) for i in range(self.yaml['nc'])] # default names
|
82 |
# print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
|
|
|
62 |
|
63 |
|
64 |
class Model(nn.Module):
|
65 |
+
def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
|
66 |
super(Model, self).__init__()
|
67 |
if isinstance(cfg, dict):
|
68 |
self.yaml = cfg # model dict
|
|
|
75 |
# Define model
|
76 |
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
|
77 |
if nc and nc != self.yaml['nc']:
|
78 |
+
logger.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
|
79 |
self.yaml['nc'] = nc # override yaml value
|
80 |
+
if anchors:
|
81 |
+
logger.info(f'Overriding model.yaml anchors with anchors={anchors}')
|
82 |
+
self.yaml['anchors'] = round(anchors) # override yaml value
|
83 |
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
|
84 |
self.names = [str(i) for i in range(self.yaml['nc'])] # default names
|
85 |
# print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
|
train.py
CHANGED
@@ -84,7 +84,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|
84 |
model.load_state_dict(state_dict, strict=False) # load
|
85 |
logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
|
86 |
else:
|
87 |
-
model = Model(opt.cfg, ch=3, nc=nc).to(device) # create
|
88 |
|
89 |
# Freeze
|
90 |
freeze = [] # parameter names to freeze (full or partial)
|
|
|
84 |
model.load_state_dict(state_dict, strict=False) # load
|
85 |
logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
|
86 |
else:
|
87 |
+
model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
|
88 |
|
89 |
# Freeze
|
90 |
freeze = [] # parameter names to freeze (full or partial)
|