# Demo750's picture
# Upload folder using huggingface_hub
# 569f484 verified
from multiprocessing import Pool
import os
from typing import Callable, Iterable, Sized
from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task,
TaskProgressColumn, TextColumn, TimeRemainingColumn)
from rich.text import Text
import os.path as osp
import portalocker
from ..smp import load, dump
class _Worker:
"""Function wrapper for ``track_progress_rich``"""
def __init__(self, func) -> None:
self.func = func
def __call__(self, inputs):
inputs, idx = inputs
if not isinstance(inputs, (tuple, list, dict)):
inputs = (inputs, )
if isinstance(inputs, dict):
return self.func(**inputs), idx
else:
return self.func(*inputs), idx
class _SkipFirstTimeRemainingColumn(TimeRemainingColumn):
    """Time-remaining column that renders a placeholder at the start.

    Args:
        skip_times (int): While the task's completed count is less than or
            equal to this value, ``-:--:--`` is rendered instead of the
            parent column's output. Defaults to 0.
    """

    def __init__(self, *args, skip_times=0, **kwargs):
        super().__init__(*args, **kwargs)
        self.skip_times = skip_times

    def render(self, task: Task) -> Text:
        """Render the placeholder or the parent column's time remaining."""
        still_skipping = task.completed <= self.skip_times
        if not still_skipping:
            return super().render(task)
        return Text('-:--:--', style='progress.remaining')
def _tasks_with_index(tasks):
"""Add index to tasks."""
for idx, task in enumerate(tasks):
yield task, idx
def _store_result(save, key, result, verbose):
    """Store one finished result in the ``save`` file under ``key``.

    The file is locked with ``portalocker`` (waiting up to 5 seconds for the
    lock) while its content is loaded, updated and dumped back.

    Args:
        save (str): Path of the result file.
        key: Key under which ``result`` is stored.
        result: Return value of the finished task.
        verbose (bool): Whether to print ``key`` and ``result``.
    """
    with portalocker.Lock(save, timeout=5) as fh:
        ans = load(save)
        ans[key] = result
        if verbose:
            print(key, result, flush=True)
        dump(ans, save)
        # Push the lock handle's buffers down to the OS before releasing.
        fh.flush()
        os.fsync(fh.fileno())


def track_progress_rich(func: Callable,
                        tasks: Iterable = tuple(),
                        task_num: int = None,
                        nproc: int = 1,
                        chunksize: int = 1,
                        description: str = 'Processing',
                        save=None, keys=None,
                        color: str = 'blue') -> list:
    """Track the progress of parallel task execution with a progress bar.

    With ``nproc == 1`` the tasks run sequentially in the current process.
    Otherwise a :class:`multiprocessing.Pool` is used and tasks are consumed
    with :func:`Pool.imap_unordered`; the results are put back into input
    order before they are returned.

    Args:
        func (callable): The function to be applied to each task.
        tasks (Iterable or Sized): A tuple of tasks. There are several cases
            for different format tasks:

            - When ``func`` accepts no arguments: tasks should be an empty
              tuple, and ``task_num`` must be specified.
            - When ``func`` accepts only one argument: tasks should be a tuple
              containing the argument.
            - When ``func`` accepts multiple arguments: tasks should be a
              tuple, with each element representing a set of arguments.
              If an element is a ``dict``, it will be parsed as a set of
              keyword-only arguments.

            Defaults to an empty tuple.
        task_num (int, optional): If ``tasks`` is an iterator which does not
            have length, the number of tasks can be provided by ``task_num``.
            Defaults to None.
        nproc (int): Process (worker) number, if nproc is 1,
            use single process. Defaults to 1.
        chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
            Defaults to 1.
        description (str): The description of progress bar.
            Defaults to "Processing".
        save (str, optional): Path of a file that receives every result as
            soon as it is available. Its directory must already exist. If
            the file does not exist it is created from an empty dict; after
            each task the content is loaded, the result is stored under
            ``keys[idx]`` and the content is dumped back. Defaults to None.
        keys (Sequence, optional): Keys used for the ``save`` file, one per
            task and indexed by task position. Required when ``save`` is
            given. Defaults to None.
        color (str): The color of progress bar. Defaults to "blue".

    Examples:
        >>> import time
        >>> def func(x):
        ...    time.sleep(1)
        ...    return x**2
        >>> track_progress_rich(func, range(10), nproc=2)

    Returns:
        list: The task results.

    Raises:
        TypeError: If ``func`` is not callable, ``tasks`` is not iterable,
            or ``save`` is given without ``keys``.
        ValueError: If ``tasks`` is empty and ``task_num`` is not set, if
            ``task_num`` does not match ``len(tasks)``, or if ``nproc`` is
            not positive.
        AssertionError: If the directory of ``save`` does not exist or the
            number of ``keys`` does not match the number of tasks.
    """
    if save is not None:
        # Fail before any task runs; otherwise the first ``keys[idx]`` lookup
        # would raise only after a task had already been computed.
        if keys is None:
            raise TypeError('keys must be provided when save is set')
        assert osp.exists(osp.dirname(save)) or osp.dirname(save) == '', (
            f'directory of the save file does not exist: {save}')
        if not osp.exists(save):
            dump({}, save)

    if not callable(func):
        raise TypeError('func must be a callable object')
    if not isinstance(tasks, Iterable):
        raise TypeError(
            f'tasks must be an iterable object, but got {type(tasks)}')

    if isinstance(tasks, Sized):
        if len(tasks) == 0:
            if task_num is None:
                raise ValueError('If tasks is an empty iterable, '
                                 'task_num must be set')
            else:
                # ``func`` takes no arguments: run it ``task_num`` times.
                tasks = tuple(tuple() for _ in range(task_num))
        else:
            if task_num is not None and task_num != len(tasks):
                raise ValueError('task_num does not match the length of tasks')
            task_num = len(tasks)

    # Compare against the resolved task count rather than ``len(tasks)``, so
    # the check also works for iterators without a length and for the
    # "empty tasks + task_num" case. It is skipped when the count is unknown.
    if keys is not None and task_num is not None:
        assert len(keys) == task_num, (
            f'expected {task_num} keys (one per task), but got {len(keys)}')

    if nproc <= 0:
        raise ValueError('nproc must be a positive number')

    # NOTE(review): the single-process path prints saved results unless
    # VERBOSE is set to an empty string, while the pool path prints only when
    # VERBOSE is set to a non-empty string. The two defaults are kept as they
    # were; confirm whether the difference is intended.
    default_verbose = nproc == 1
    verbose = bool(os.environ.get('VERBOSE', default_verbose))

    skip_times = nproc * chunksize if nproc > 1 else 0
    prog_bar = Progress(
        TextColumn('{task.description}'),
        BarColumn(),
        _SkipFirstTimeRemainingColumn(skip_times=skip_times),
        MofNCompleteColumn(),
        TaskProgressColumn(show_speed=True),
    )

    worker = _Worker(func)
    task_id = prog_bar.add_task(
        total=task_num, color=color, description=description)
    tasks = _tasks_with_index(tasks)

    # Use single process when nproc is 1, else use multiprocess.
    with prog_bar:
        if nproc == 1:
            results = []
            for task in tasks:
                result, idx = worker(task)
                results.append(result)
                if save is not None:
                    _store_result(save, keys[idx], result, verbose)
                prog_bar.update(task_id, advance=1, refresh=True)
        else:
            indexed_results = []
            with Pool(nproc) as pool:
                gen = pool.imap_unordered(worker, tasks, chunksize)
                try:
                    for result, idx in gen:
                        indexed_results.append((result, idx))
                        if save is not None:
                            _store_result(save, keys[idx], result, verbose)
                        prog_bar.update(task_id, advance=1, refresh=True)
                except Exception:
                    prog_bar.stop()
                    raise
            # Results arrive in completion order; restore the input order.
            results = [None] * len(indexed_results)
            for result, idx in indexed_results:
                results[idx] = result
    return results