henry000 commited on
Commit
6aabc6c
·
1 Parent(s): 745aab9

🎨 [Update] progress and tqdm

Browse files
yolo/tools/log_helper.py CHANGED
@@ -47,7 +47,7 @@ class CustomProgress:
47
  self.wandb = wandb.init(project="YOLO", resume="allow", mode="online", dir="runs", name=cfg.name)
48
 
49
  def start_train(self, num_epochs: int):
50
- self.task_epoch = self.progress.add_task("[cyan]Epochs", total=num_epochs)
51
 
52
  def one_epoch(self):
53
  self.progress.update(self.task_epoch, advance=1)
@@ -63,9 +63,9 @@ class CustomProgress:
63
  for loss_name, loss_value in loss_dict.items():
64
  self.wandb.log({f"Loss/{loss_name}": loss_value})
65
 
66
- loss_str = "Loss"
67
  for loss_name, loss_val in loss_dict.items():
68
- loss_str += f" {loss_name[:-4]}: {loss_val:.2f} |"
69
 
70
  self.progress.update(self.batch_task, advance=1, description=f"[green]Batches [white]{loss_str}")
71
 
 
47
  self.wandb = wandb.init(project="YOLO", resume="allow", mode="online", dir="runs", name=cfg.name)
48
 
49
  def start_train(self, num_epochs: int):
50
+ self.task_epoch = self.progress.add_task("[cyan]Epochs [white]| Loss | Box | DFL | BCE |", total=num_epochs)
51
 
52
  def one_epoch(self):
53
  self.progress.update(self.task_epoch, advance=1)
 
63
  for loss_name, loss_value in loss_dict.items():
64
  self.wandb.log({f"Loss/{loss_name}": loss_value})
65
 
66
+ loss_str = "| -.-- |"
67
  for loss_name, loss_val in loss_dict.items():
68
+ loss_str += f" {loss_val:2.2f} |"
69
 
70
  self.progress.update(self.batch_task, advance=1, description=f"[green]Batches [white]{loss_str}")
71
 
yolo/utils/dataloader.py CHANGED
@@ -7,6 +7,7 @@ import numpy as np
7
  import torch
8
  from loguru import logger
9
  from PIL import Image
 
10
  from torch.utils.data import DataLoader, Dataset
11
  from torchvision.transforms import functional as TF
12
  from tqdm.rich import tqdm
@@ -74,7 +75,7 @@ class YoloDataset(Dataset):
74
 
75
  data = []
76
  valid_inputs = 0
77
- for image_name in tqdm(images_list, desc="Filtering data"):
78
  if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
79
  continue
80
  image_id, _ = path.splitext(image_name)
 
7
  import torch
8
  from loguru import logger
9
  from PIL import Image
10
+ from rich.progress import track
11
  from torch.utils.data import DataLoader, Dataset
12
  from torchvision.transforms import functional as TF
13
  from tqdm.rich import tqdm
 
75
 
76
  data = []
77
  valid_inputs = 0
78
+ for image_name in track(images_list, description="Filtering data"):
79
  if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
80
  continue
81
  image_id, _ = path.splitext(image_name)