File size: 1,040 Bytes
966ae59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:

from functools import wraps
from typing import Callable

from tqdm.auto import tqdm


def tqdm_decorator(func: Callable) -> Callable:
    """Wrap a method so that it runs inside a ``tqdm`` progress bar.

    The wrapped function is expected to be a method whose first positional
    argument (``self``) exposes:

    * ``self.step`` -- used as the bar's initial position;
    * ``self.args.train_num_steps`` -- used as the bar's total;
    * ``self.accelerator.is_main_process`` -- the bar is disabled when this is
      falsy, so only the main process displays progress.

    The progress bar is passed to ``func`` as the keyword argument ``pbar``, so
    ``func`` must accept it and callers must not pass ``pbar`` themselves. The
    bar is used as a context manager, so it is closed when ``func`` returns or
    raises.

    Args:
        func: the method to wrap.

    Returns:
        A wrapper that has the same metadata as ``func`` (name, docstring,
        ``__wrapped__``; via ``functools.wraps``). It returns whatever ``func``
        returns.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # args[0] is the instance ("self") whose attributes configure the bar.
        owner = args[0]
        with tqdm(initial=owner.step,
                  total=owner.args.train_num_steps,
                  disable=not owner.accelerator.is_main_process) as pbar:
            # Forward the result so decorated methods can still return values.
            return func(*args, **kwargs, pbar=pbar)

    return wrapper