hjc-owo
init repo
966ae59
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description: decorator that runs a training-loop method inside a tqdm progress bar.
from functools import wraps
from typing import Callable

from tqdm.auto import tqdm
def tqdm_decorator(func: Callable) -> Callable:
    """Wrap a method so that it runs inside a tqdm progress bar.

    **The wrapped function is assumed to be a method whose first positional
    argument (``self``) provides:**

    - ``self.step``: the bar's starting count (tqdm ``initial``),
    - ``self.args.train_num_steps``: the bar's ``total``,
    - ``self.accelerator.is_main_process``: the bar is disabled unless this
      is truthy.

    The open bar is passed to ``func`` as the keyword argument ``pbar``. The
    bar is closed when ``func`` returns or raises, because it is used as a
    context manager.

    Args:
        func: the method to wrap. It must accept a ``pbar`` keyword argument
            and must not receive one from its caller, because the wrapper
            supplies it.

    Returns:
        A wrapper that keeps ``func``'s metadata (name, docstring,
        ``__wrapped__``), calls ``func`` inside the progress bar, and returns
        ``func``'s result.
    """

    @wraps(func)  # keep __name__/__doc__ so logging and introspection see the real method
    def wrapper(*args, **kwargs):
        owner = args[0]  # the ``self`` of the decorated method
        with tqdm(initial=owner.step,
                  total=owner.args.train_num_steps,
                  disable=not owner.accelerator.is_main_process) as pbar:
            # Propagate the return value; the original dropped it and always returned None.
            return func(*args, **kwargs, pbar=pbar)

    return wrapper