henry000 commited on
Commit
dc55a8e
Β·
1 Parent(s): 710e371

πŸ’š [Fix] the amp autocast, make cpu available

Browse files
Files changed (1) hide show
  1. yolo/utils/loss.py +3 -2
yolo/utils/loss.py CHANGED
@@ -24,7 +24,8 @@ def get_loss_function(*args, **kwargs):
24
  class BCELoss(nn.Module):
25
  def __init__(self) -> None:
26
  super().__init__()
27
- self.bce = BCEWithLogitsLoss(pos_weight=torch.tensor([1.0], device=torch.device("cuda")), reduction="none")
 
28
 
29
  def forward(self, predicts_cls: Tensor, targets_cls: Tensor, cls_norm: Tensor) -> Any:
30
  return self.bce(predicts_cls, targets_cls).sum() / cls_norm
@@ -138,7 +139,7 @@ class YOLOLoss:
138
  anchors_box = anchors_box / self.scaler[None, :, None]
139
  return anchors_cls, anchors_box
140
 
141
- @torch.autocast("cuda")
142
  def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
143
  # Batch_Size x (Anchor + Class) x H x W
144
  # TODO: check datatype, why targets has a little bit error with origin version
 
24
  class BCELoss(nn.Module):
25
  def __init__(self) -> None:
26
  super().__init__()
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ self.bce = BCEWithLogitsLoss(pos_weight=torch.tensor([1.0], device=device), reduction="none")
29
 
30
  def forward(self, predicts_cls: Tensor, targets_cls: Tensor, cls_norm: Tensor) -> Any:
31
  return self.bce(predicts_cls, targets_cls).sum() / cls_norm
 
139
  anchors_box = anchors_box / self.scaler[None, :, None]
140
  return anchors_cls, anchors_box
141
 
142
+ @torch.autocast("cuda" if torch.cuda.is_available() else "cpu")
143
  def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
144
  # Batch_Size x (Anchor + Class) x H x W
145
  # TODO: check datatype, why targets has a little bit error with origin version