Spaces:
Running
Running
from multiprocessing import Pool | |
import os | |
from typing import Callable, Iterable, Sized | |
from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task, | |
TaskProgressColumn, TextColumn, TimeRemainingColumn) | |
from rich.text import Text | |
import os.path as osp | |
import portalocker | |
from ..smp import load, dump | |
class _Worker:
    """Callable adapter used by ``track_progress_rich`` to invoke ``func``.

    Each call receives one ``(payload, index)`` pair and returns
    ``(func(...), index)`` so the caller can restore task order later.
    The payload is turned into call arguments as follows:

    * ``dict``            -> keyword arguments (``func(**payload)``)
    * ``tuple`` / ``list`` -> positional arguments (``func(*payload)``)
    * anything else       -> the single positional argument (``func(payload)``)
    """

    def __init__(self, func) -> None:
        self.func = func

    def __call__(self, inputs):
        payload, position = inputs
        if isinstance(payload, dict):
            return self.func(**payload), position
        if isinstance(payload, (tuple, list)):
            return self.func(*payload), position
        # Scalars (and any other object) are passed through as one argument.
        return self.func(payload), position
class _SkipFirstTimeRemainingColumn(TimeRemainingColumn):
    """Time-remaining column that shows a placeholder early in a run.

    While ``task.completed`` is at most ``skip_times``, the column renders the
    fixed text ``-:--:--``. After that it defers to the regular
    ``TimeRemainingColumn`` estimate.

    Args:
        skip_times (int): Completed-step threshold up to which the
            placeholder is shown. Defaults to 0. Every other positional and
            keyword argument is forwarded to ``TimeRemainingColumn``.
    """

    def __init__(self, *args, skip_times=0, **kwargs):
        super().__init__(*args, **kwargs)
        self.skip_times = skip_times

    def render(self, task: Task) -> Text:
        """Render the remaining-time estimate, or the placeholder early on."""
        still_warming_up = task.completed <= self.skip_times
        if still_warming_up:
            placeholder = Text('-:--:--', style='progress.remaining')
            return placeholder
        return super().render(task)
def _tasks_with_index(tasks):
    """Lazily pair every task with its zero-based position.

    Yields:
        tuple: ``(task, index)``, the pair shape that ``_Worker.__call__``
        unpacks.
    """
    yield from ((item, position) for position, item in enumerate(tasks))
def _env_flag(name: str, default: bool) -> bool:
    """Interpret environment variable ``name`` as a boolean switch.

    An unset variable yields ``default``. Otherwise the stripped value is
    compared case-insensitively: ``''``, ``'0'``, ``'false'``, ``'no'`` and
    ``'off'`` mean False, anything else means True. Plain string truthiness
    would treat ``VERBOSE=0`` as enabled.
    """
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() not in ('', '0', 'false', 'no', 'off')


def _save_result(save: str, key, result, verbose: bool) -> None:
    """Store ``result`` under ``key`` in the result file ``save``.

    The file is loaded, updated and dumped back while a ``portalocker`` lock
    on ``save`` is held (5 s timeout). Presumably the lock guards against
    other processes writing the same file; confirm.
    """
    with portalocker.Lock(save, timeout=5) as fh:
        ans = load(save)
        ans[key] = result
        if verbose:
            print(key, result, flush=True)
        dump(ans, save)
        # fsync the file before the lock is released.
        fh.flush()
        os.fsync(fh.fileno())


def track_progress_rich(func: Callable,
                        tasks: Iterable = tuple(),
                        task_num: int = None,
                        nproc: int = 1,
                        chunksize: int = 1,
                        description: str = 'Processing',
                        save=None, keys=None,
                        color: str = 'blue') -> list:
    """Run ``func`` over ``tasks`` with a rich progress bar and return results.

    With ``nproc == 1`` the tasks run sequentially in the calling process.
    Otherwise they are distributed over a :class:`multiprocessing.Pool` via
    :meth:`Pool.imap_unordered`, and the results are put back into task order
    before returning.

    Args:
        func (callable): The function applied to each task.
        tasks (Iterable or Sized): The task arguments. Each element becomes
            the arguments of one ``func`` call:

            - a ``dict`` is passed as keyword arguments;
            - a ``tuple`` or ``list`` is unpacked as positional arguments;
            - anything else is passed as the single positional argument.

            When ``func`` takes no arguments, pass an empty sequence together
            with ``task_num``; ``func`` is then called ``task_num`` times with
            no arguments. Defaults to an empty tuple.
        task_num (int, optional): Number of tasks. Required when ``tasks`` is
            empty. If given together with a non-empty sized ``tasks``, it must
            equal ``len(tasks)``. For an iterator without ``len`` it is used
            as the progress-bar total and to validate ``keys``.
            Defaults to None.
        nproc (int): Number of worker processes; 1 runs everything in the
            calling process. Must be positive. Defaults to 1.
        chunksize (int): Passed to :meth:`Pool.imap_unordered`.
            Defaults to 1.
        description (str): Label shown next to the progress bar.
            Defaults to ``'Processing'``.
        save (str, optional): Path of a result file. If the file does not
            exist, an empty dict is dumped there first. Each finished result
            is then stored under ``keys[idx]`` (load/update/dump while holding
            a ``portalocker`` lock). The file's directory must already exist.
            Requires ``keys``. Defaults to None.
        keys (Sequence, optional): One key per task, aligned with task order;
            used as the keys written to ``save``. Defaults to None.
        color (str): Stored as a field on the progress task; none of the
            columns configured here reference it. Defaults to ``'blue'``.

    Environment:
        VERBOSE: When ``save`` is set, print ``key result`` for every stored
            result if enabled. ``''``, ``'0'``, ``'false'``, ``'no'`` and
            ``'off'`` (case-insensitive) disable it; other values enable it.
            If unset, it is on for ``nproc == 1`` and off otherwise.

    Returns:
        list: The return values of ``func``, in the same order as ``tasks``.

    Raises:
        TypeError: If ``func`` is not callable or ``tasks`` is not iterable.
        ValueError: If ``tasks`` is empty and ``task_num`` is None; if
            ``task_num`` or ``len(keys)`` disagrees with the number of tasks;
            if ``nproc`` is not positive; if ``save`` is given without
            ``keys``; or if the directory of ``save`` does not exist.

    Examples:
        >>> import time
        >>> def func(x):
        ...     time.sleep(1)
        ...     return x**2
        >>> track_progress_rich(func, range(10), nproc=2)
    """
    if not callable(func):
        raise TypeError('func must be a callable object')
    if not isinstance(tasks, Iterable):
        raise TypeError(
            f'tasks must be an iterable object, but got {type(tasks)}')
    if isinstance(tasks, Sized):
        if len(tasks) == 0:
            if task_num is None:
                raise ValueError('If tasks is an empty iterable, '
                                 'task_num must be set')
            # Zero-argument calls: one empty positional tuple per task.
            tasks = tuple(tuple() for _ in range(task_num))
        else:
            if task_num is not None and task_num != len(tasks):
                raise ValueError('task_num does not match the length of tasks')
            task_num = len(tasks)
    if nproc <= 0:
        raise ValueError('nproc must be a positive number')

    # Validate keys against the normalized task count. When the count is
    # unknown (an iterator without len and no task_num), it cannot be checked.
    if keys is not None and task_num is not None and len(keys) != task_num:
        raise ValueError(
            f'keys has {len(keys)} entries but there are {task_num} tasks')
    if save is not None:
        # Fail before any work is done rather than on the first keys[idx].
        if keys is None:
            raise ValueError('keys must be provided when save is set')
        save_dir = osp.dirname(save)
        if save_dir and not osp.exists(save_dir):
            raise ValueError(f'directory of save does not exist: {save_dir}')
        if not osp.exists(save):
            dump({}, save)

    skip_times = nproc * chunksize if nproc > 1 else 0
    prog_bar = Progress(
        TextColumn('{task.description}'),
        BarColumn(),
        _SkipFirstTimeRemainingColumn(skip_times=skip_times),
        MofNCompleteColumn(),
        TaskProgressColumn(show_speed=True),
    )

    worker = _Worker(func)
    task_id = prog_bar.add_task(
        total=task_num, color=color, description=description)
    tasks = _tasks_with_index(tasks)

    # Use single process when nproc is 1, else use multiprocess.
    # NOTE(review): the VERBOSE default differs between the two paths
    # (on for nproc == 1, off otherwise); confirm whether that is intended.
    with prog_bar:
        if nproc == 1:
            verbose = _env_flag('VERBOSE', True)
            results = []
            for task in tasks:
                result, idx = worker(task)
                results.append(result)
                if save is not None:
                    _save_result(save, keys[idx], result, verbose)
                prog_bar.update(task_id, advance=1, refresh=True)
        else:
            verbose = _env_flag('VERBOSE', False)
            unordered_results = []
            with Pool(nproc) as pool:
                gen = pool.imap_unordered(worker, tasks, chunksize)
                try:
                    for result, idx in gen:
                        unordered_results.append((result, idx))
                        if save is not None:
                            _save_result(save, keys[idx], result, verbose)
                        prog_bar.update(task_id, advance=1, refresh=True)
                except Exception:
                    prog_bar.stop()
                    raise
            # imap_unordered yields in completion order; restore task order.
            results = [None] * len(unordered_results)
            for result, idx in unordered_results:
                results[idx] = result
    return results