|
import logging |
|
import os |
|
import random |
|
import sys |
|
import time |
|
from shutil import get_terminal_size |
|
|
|
import numpy as np |
|
import torch |
|
|
|
# Module logger looked up by the fixed name 'base' (not __name__);
# logging.getLogger returns the same Logger object for every call with this
# name, so all modules asking for 'base' share its handlers and level.
logger = logging.getLogger(name='base')
|
|
|
|
|
def make_exp_dirs(opt): |
|
"""Make dirs for experiments.""" |
|
path_opt = opt['path'].copy() |
|
if opt['is_train']: |
|
overwrite = True if 'debug' in opt['name'] else False |
|
os.makedirs(path_opt.pop('experiments_root'), exist_ok=overwrite) |
|
os.makedirs(path_opt.pop('models'), exist_ok=overwrite) |
|
else: |
|
os.makedirs(path_opt.pop('results_root')) |
|
|
|
|
|
def set_random_seed(seed):
    """Seed all random number generators with the same value.

    Seeds, in order: Python's ``random`` module, NumPy's global generator,
    PyTorch's default generator, the current CUDA device's generator and
    the generators of all CUDA devices.

    Args:
        seed (int): Seed passed unchanged to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
|
|
|
|
|
class ProgressBar:
    """A progress bar which can print the progress.

    Modified from:
    https://github.com/hellock/cvbase/blob/master/cvbase/progress.py

    Args:
        task_num (int): Expected total number of tasks. When greater than 0,
            a bar with count, throughput and ETA is redrawn in place on each
            update; otherwise only a running count is written.
            Default: 0.
        bar_width (int): Requested bar width in characters, capped by a
            limit derived from the terminal width. Default: 50.
        start (bool): Whether to call :meth:`start` immediately. If False,
            :meth:`start` must be called before :meth:`update`.
            Default: True.
    """

    def __init__(self, task_num=0, bar_width=50, start=True):
        self.task_num = task_num
        self.bar_width = min(bar_width, self._get_max_bar_width())
        self.completed = 0
        if start:
            self.start()

    def _get_max_bar_width(self):
        """Return the widest bar allowed by the terminal (at least 10)."""
        terminal_width, _ = get_terminal_size()
        # At most 60% of the terminal, and 50 columns fewer than its width
        # (presumably to leave room for the text printed after the bar).
        max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
        if max_bar_width < 10:
            print(f'terminal width is too small ({terminal_width}), '
                  'please consider widen the terminal for better '
                  'progressbar visualization')
            max_bar_width = 10
        return max_bar_width

    def start(self):
        """Write the initial progress output and start the clock."""
        if self.task_num > 0:
            # Two lines (bar + message) are written here; update() moves the
            # cursor back up over exactly these two lines to redraw them.
            sys.stdout.write(f"[{' ' * self.bar_width}] 0/{self.task_num}, "
                             f'elapsed: 0s, ETA:\nStart...\n')
        else:
            sys.stdout.write('completed: 0, elapsed: 0s')
        sys.stdout.flush()
        self.start_time = time.time()

    def update(self, msg='In progress...'):
        """Count one more completed task and refresh the output.

        Args:
            msg (str): Text written on the line below the bar; only used
                when ``task_num > 0``. Default: 'In progress...'.
        """
        self.completed += 1
        elapsed = time.time() - self.start_time
        fps = self._fps(self.completed, elapsed)
        if self.task_num > 0:
            percentage = self.completed / float(self.task_num)
            # Clamp so that updates beyond task_num neither draw a bar wider
            # than bar_width nor report a negative ETA.
            eta = max(int(elapsed * (1 - percentage) / percentage + 0.5), 0)
            mark_width = min(int(self.bar_width * percentage), self.bar_width)
            bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
            # ESC[2F: cursor to the start of the line two lines up;
            # ESC[J: clear from the cursor to the end of the screen. Together
            # they erase the previous bar + message so they are redrawn.
            sys.stdout.write('\033[2F')
            sys.stdout.write('\033[J')
            sys.stdout.write(
                f'[{bar_chars}] {self.completed}/{self.task_num}, '
                f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, '
                f'ETA: {eta:5}s\n{msg}\n')
        else:
            sys.stdout.write(
                f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s, '
                f'{fps:.1f} tasks/s')
        sys.stdout.flush()

    @staticmethod
    def _fps(completed, elapsed):
        """Return tasks per second, or ``inf`` if no positive time elapsed.

        ``time.time()`` may return the same value for two close calls (or go
        backwards after a clock adjustment), so a plain division could raise
        ZeroDivisionError or produce a negative rate.
        """
        return completed / elapsed if elapsed > 0 else float('inf')
|
|
|
|
|
class AverageMeter:
    """Computes and stores the average and current value.

    Imported from
    https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262

    Attributes:
        val: Most recent value passed to :meth:`update`.
        avg: Weighted running average, ``sum / count``.
        sum: Weighted sum of all values (``val * n`` per update).
        count: Total weight ``n`` accumulated so far.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset all statistics to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: The new value; stored as ``self.val``.
            n (int): Weight of ``val`` (e.g. how many samples it represents).
                Default: 1.
        """
        # The original never stored val, so ``self.val`` stayed 0 forever.
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against n=0 on an empty meter, which would divide by zero.
        if self.count:
            self.avg = self.sum / self.count
|
|