Martin Cerman committed
Commit dea5a8a · unverified
1 Parent(s): 8228669

🔨 [FIX] Fixes memory leak (#83)


* Fixes memory leak

* Changed total_loss to use float type and adjusted collection of loss

---------

Co-authored-by: Martin Cerman <[email protected]>

Files changed (1)
  1. yolo/tools/solver.py +2 -2
yolo/tools/solver.py CHANGED
@@ -86,7 +86,7 @@ class ModelTrainer:
 
     def train_one_epoch(self, dataloader):
         self.model.train()
-        total_loss = defaultdict(lambda: torch.tensor(0.0, device=self.device))
+        total_loss = defaultdict(float)
         total_samples = 0
         self.optimizer.next_epoch(len(dataloader))
         for batch_size, images, targets, *_ in dataloader:
@@ -96,7 +96,7 @@ class ModelTrainer:
             for loss_name, loss_val in loss_each.items():
                 if self.use_ddp:  # collecting loss for each batch
                     distributed.all_reduce(loss_val, op=distributed.ReduceOp.AVG)
-                total_loss[loss_name] += loss_val * batch_size
+                total_loss[loss_name] += loss_val.item() * batch_size
             total_samples += batch_size
             self.progress.one_batch(loss_each)
 
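
For context, a minimal sketch of the failure mode this commit addresses (not part of the repository; fake_loss, param, and the loop counts below are hypothetical stand-ins for the trainer's real loss computation). If each per-batch loss tensor still references its autograd graph, summing those tensors into a running total keeps every batch's graph, and the memory behind it, reachable for the whole epoch; accumulating loss_val.item() stores a plain Python float instead, so each batch's tensors can be freed as soon as the batch finishes.

import torch
from collections import defaultdict

# Hypothetical stand-ins; the real trainer computes loss_each inside train_one_batch.
param = torch.nn.Parameter(torch.randn(10))

def fake_loss(p):
    return (p ** 2).sum()  # a scalar tensor that carries a grad_fn

batch_size = 16

# Pre-fix pattern: tensor accumulation keeps a reference to every batch's graph.
total_loss_old = defaultdict(lambda: torch.tensor(0.0))
for _ in range(3):
    loss_val = fake_loss(param)
    total_loss_old["BoxLoss"] += loss_val * batch_size  # running sum now points at all three graphs

# Post-fix pattern (this commit): .item() yields a detached Python float.
total_loss_new = defaultdict(float)
for _ in range(3):
    loss_val = fake_loss(param)
    total_loss_new["BoxLoss"] += loss_val.item() * batch_size  # no tensors are retained across batches

Note that the DDP all_reduce still operates on the tensor loss_val, as in the diff above; only the accumulation afterwards switches to a plain float.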