Suburst commited on
Commit
cd03dc0
·
verified ·
1 Parent(s): 021bdf4

Update Yolov5_Deepsort/models/yolo.py

Browse files
Files changed (1) hide show
  1. Yolov5_Deepsort/models/yolo.py +2 -1
Yolov5_Deepsort/models/yolo.py CHANGED
@@ -36,7 +36,7 @@ class Detect(nn.Module):
36
  a = torch.tensor(anchors).float().view(self.nl, -1, 2)
37
  self.register_buffer('anchors', a) # shape(nl,na,2)
38
  self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
39
- self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch).float() # output conv
40
  self.inplace = inplace # use in-place ops (e.g. slice assignment)
41
 
42
  def forward(self, x):
@@ -89,6 +89,7 @@ class Model(nn.Module):
89
  logger.info(f'Overriding model.yaml anchors with anchors={anchors}')
90
  self.yaml['anchors'] = round(anchors) # override yaml value
91
  self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
 
92
  self.names = [str(i) for i in range(self.yaml['nc'])] # default names
93
  self.inplace = self.yaml.get('inplace', True)
94
  # logger.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
 
36
  a = torch.tensor(anchors).float().view(self.nl, -1, 2)
37
  self.register_buffer('anchors', a) # shape(nl,na,2)
38
  self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
39
+ self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1).float() for x in ch).float() # output conv
40
  self.inplace = inplace # use in-place ops (e.g. slice assignment)
41
 
42
  def forward(self, x):
 
89
  logger.info(f'Overriding model.yaml anchors with anchors={anchors}')
90
  self.yaml['anchors'] = round(anchors) # override yaml value
91
  self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
92
+ self.model = self.model.float()
93
  self.names = [str(i) for i in range(self.yaml['nc'])] # default names
94
  self.inplace = self.yaml.get('inplace', True)
95
  # logger.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])