Spaces:
Running
Running
File size: 8,473 Bytes
1d777c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 |
# Adapted with permission from the EdgeDB project;
# license: PSFL.
# Public names of this module: only TaskGroup is exported by a star-import.
__all__ = ["TaskGroup"]
from . import events
from . import exceptions
from . import tasks
class TaskGroup:
    """Asynchronous context manager for managing groups of tasks.

    Example use:

        async with asyncio.TaskGroup() as group:
            task1 = group.create_task(some_coroutine(...))
            task2 = group.create_task(other_coroutine(...))
        print("Both tasks have completed now.")

    All tasks are awaited when the context manager exits.

    Any exceptions other than `asyncio.CancelledError` raised within
    a task will cancel all remaining tasks and wait for them to exit.
    The exceptions are then combined and raised as an `ExceptionGroup`.
    """

    def __init__(self):
        # Set once by __aenter__; a group cannot be entered twice.
        self._entered = False
        # Set at the start of __aexit__.
        self._exiting = False
        # Set by _abort(); once set, no new tasks are accepted.
        self._aborting = False
        # Event loop and the task that entered the group; both are
        # filled in by __aenter__.
        self._loop = None
        self._parent_task = None
        # True when this group itself called cancel() on the parent task
        # (see _on_task_done); __aexit__ then has to call uncancel().
        self._parent_cancel_requested = False
        # Child tasks that have not completed yet.
        self._tasks = set()
        # Exceptions collected from failed child tasks (and from the
        # "async with" body); reset to None after they have been raised.
        self._errors = []
        # First SystemExit/KeyboardInterrupt seen; re-raised as-is.
        self._base_error = None
        # Future that __aexit__ awaits until the last child task is done.
        self._on_completed_fut = None

    def __repr__(self):
        """Return a short summary of the group's state for debugging."""
        # The leading empty string makes the joined result start with a
        # space whenever at least one detail follows.
        info = ['']
        if self._tasks:
            info.append(f'tasks={len(self._tasks)}')
        if self._errors:
            info.append(f'errors={len(self._errors)}')
        if self._aborting:
            info.append('cancelling')
        elif self._entered:
            info.append('entered')

        info_str = ' '.join(info)
        return f'<TaskGroup{info_str}>'

    async def __aenter__(self):
        """Enter the group and bind it to the running loop and current task.

        Returns the group itself.

        Raises RuntimeError if the group has already been entered or if
        no current task can be determined for the loop.
        """
        if self._entered:
            raise RuntimeError(
                f"TaskGroup {self!r} has been already entered")
        self._entered = True

        if self._loop is None:
            self._loop = events.get_running_loop()

        self._parent_task = tasks.current_task(self._loop)
        if self._parent_task is None:
            raise RuntimeError(
                f'TaskGroup {self!r} cannot determine the parent task')

        return self

    async def __aexit__(self, et, exc, tb):
        """Wait for every child task, then report the outcome.

        *et*, *exc* and *tb* describe the exception (if any) that left the
        "async with" body.  Nothing is returned, so that exception is
        never suppressed by a truthy return value.

        Raises, in order of priority:
          * the first SystemExit/KeyboardInterrupt recorded, as-is;
          * a pending CancelledError, when no other errors were collected;
          * BaseExceptionGroup wrapping all collected errors.
        """
        self._exiting = True

        if (exc is not None and
                self._is_base_error(exc) and
                self._base_error is None):
            self._base_error = exc

        # Use issubclass() rather than an identity check so that subclasses
        # of CancelledError are treated as cancellation too, matching the
        # "except exceptions.CancelledError" clause in the wait loop below.
        if et is not None and issubclass(et, exceptions.CancelledError):
            propagate_cancellation_error = exc
        else:
            propagate_cancellation_error = None
        if self._parent_cancel_requested:
            # If this flag is set we *must* call uncancel().
            if self._parent_task.uncancel() == 0:
                # If there are no pending cancellations left,
                # don't propagate CancelledError.
                propagate_cancellation_error = None

        if et is not None:
            if not self._aborting:
                # Our parent task is being cancelled:
                #
                #    async with TaskGroup() as g:
                #        g.create_task(...)
                #        await ...  # <- CancelledError
                #
                # or there's an exception in "async with":
                #
                #    async with TaskGroup() as g:
                #        g.create_task(...)
                #        1 / 0
                #
                self._abort()

        # We use while-loop here because "self._on_completed_fut"
        # can be cancelled multiple times if our parent task
        # is being cancelled repeatedly (or even once, when
        # our own cancellation is already in progress)
        while self._tasks:
            if self._on_completed_fut is None:
                self._on_completed_fut = self._loop.create_future()

            try:
                await self._on_completed_fut
            except exceptions.CancelledError as ex:
                if not self._aborting:
                    # Our parent task is being cancelled:
                    #
                    #    async def wrapper():
                    #        async with TaskGroup() as g:
                    #            g.create_task(foo)
                    #
                    # "wrapper" is being cancelled while "foo" is
                    # still running.
                    propagate_cancellation_error = ex
                    self._abort()

            self._on_completed_fut = None

        assert not self._tasks

        if self._base_error is not None:
            raise self._base_error

        # Propagate CancelledError if there is one, except if there
        # are other errors -- those have priority.
        if propagate_cancellation_error is not None and not self._errors:
            raise propagate_cancellation_error

        if et is not None and not issubclass(et, exceptions.CancelledError):
            self._errors.append(exc)

        if self._errors:
            # Exceptions are heavy objects that can have object
            # cycles (bad for GC); let's not keep a reference to
            # a bunch of them.
            try:
                me = BaseExceptionGroup(
                    'unhandled errors in a TaskGroup', self._errors)
                raise me from None
            finally:
                self._errors = None

    def create_task(self, coro, *, name=None, context=None):
        """Create a new task in this group and return it.

        Similar to `asyncio.create_task`.

        Raises RuntimeError if the group has not been entered, is already
        finished, or is shutting down.  In those cases *coro* is closed
        first so that it does not later trigger a "coroutine ... was never
        awaited" RuntimeWarning.
        """
        if not self._entered:
            self._close_rejected(coro)
            raise RuntimeError(f"TaskGroup {self!r} has not been entered")
        if self._exiting and not self._tasks:
            self._close_rejected(coro)
            raise RuntimeError(f"TaskGroup {self!r} is finished")
        if self._aborting:
            self._close_rejected(coro)
            raise RuntimeError(f"TaskGroup {self!r} is shutting down")
        if context is None:
            task = self._loop.create_task(coro)
        else:
            task = self._loop.create_task(coro, context=context)
        tasks._set_task_name(task, name)
        task.add_done_callback(self._on_task_done)
        self._tasks.add(task)
        return task

    @staticmethod
    def _close_rejected(coro):
        """Close a coroutine that will never be scheduled, if it can be."""
        # Tolerate objects without close() so that the caller still gets
        # the RuntimeError about the group state, not an AttributeError.
        close = getattr(coro, 'close', None)
        if close is not None:
            close()

    # Since Python 3.8 Tasks propagate all exceptions correctly,
    # except for KeyboardInterrupt and SystemExit which are
    # still considered special.
    def _is_base_error(self, exc: BaseException) -> bool:
        """Return True if *exc* is a SystemExit or KeyboardInterrupt."""
        assert isinstance(exc, BaseException)
        return isinstance(exc, (SystemExit, KeyboardInterrupt))

    def _abort(self):
        """Mark the group as aborting and cancel every unfinished task."""
        self._aborting = True

        for t in self._tasks:
            if not t.done():
                t.cancel()

    def _on_task_done(self, task):
        """Done-callback attached to every child task by create_task()."""
        self._tasks.discard(task)

        # Wake up __aexit__ once the last child task has completed.
        if self._on_completed_fut is not None and not self._tasks:
            if not self._on_completed_fut.done():
                self._on_completed_fut.set_result(True)

        if task.cancelled():
            return

        exc = task.exception()
        if exc is None:
            return

        self._errors.append(exc)
        if self._is_base_error(exc) and self._base_error is None:
            self._base_error = exc

        if self._parent_task.done():
            # Not sure if this case is possible, but we want to handle
            # it anyways.
            self._loop.call_exception_handler({
                'message': f'Task {task!r} has errored out but its parent '
                           f'task {self._parent_task} is already completed',
                'exception': exc,
                'task': task,
            })
            return

        if not self._aborting and not self._parent_cancel_requested:
            # If parent task *is not* being cancelled, it means that we want
            # to manually cancel it to abort whatever is being run right now
            # in the TaskGroup.  But we want to mark parent task as
            # "not cancelled" later in __aexit__.  Example situation that
            # we need to handle:
            #
            #    async def foo():
            #        try:
            #            async with TaskGroup() as g:
            #                g.create_task(crash_soon())
            #                await something  # <- this needs to be canceled
            #                                 #    by the TaskGroup, e.g.
            #                                 #    foo() needs to be cancelled
            #        except Exception:
            #            # Ignore any exceptions raised in the TaskGroup
            #            pass
            #        await something_else     # this line has to be called
            #                                 # after TaskGroup is finished.
            self._abort()
            self._parent_cancel_requested = True
            self._parent_task.cancel()
|