Suburst commited on
Commit
dd3bc12
·
verified ·
1 Parent(s): 6435abd

Update Yolov5_Deepsort/models/yolo.py

Browse files
Files changed (1) hide show
  1. Yolov5_Deepsort/models/yolo.py +7 -1
Yolov5_Deepsort/models/yolo.py CHANGED
@@ -89,7 +89,13 @@ class Model(nn.Module):
89
  logger.info(f'Overriding model.yaml anchors with anchors={anchors}')
90
  self.yaml['anchors'] = round(anchors) # override yaml value
91
  self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
92
- self.model = self.model.float()
 
 
 
 
 
 
93
  self.names = [str(i) for i in range(self.yaml['nc'])] # default names
94
  self.inplace = self.yaml.get('inplace', True)
95
  # logger.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
 
89
  logger.info(f'Overriding model.yaml anchors with anchors={anchors}')
90
  self.yaml['anchors'] = round(anchors) # override yaml value
91
  self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
92
+ for module in self.model.modules():
93
+ if hasattr(module, 'weight') and module.weight is not None:
94
+ module.weight.data = module.weight.data.float()
95
+ if hasattr(module, 'bias') and module.bias is not None:
96
+ module.bias.data = module.bias.data.float()
97
+
98
+ #self.model = self.model.float()
99
  self.names = [str(i) for i in range(self.yaml['nc'])] # default names
100
  self.inplace = self.yaml.get('inplace', True)
101
  # logger.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])