| # Adapted with permission from the EdgeDB project; |
| # license: PSFL. |
| |
| |
| __all__ = ["TaskGroup"] |
| |
| from . import events |
| from . import exceptions |
| from . import tasks |
| |
| |
| class TaskGroup: |
| """Asynchronous context manager for managing groups of tasks. |
| |
| Example use: |
| |
| async with asyncio.TaskGroup() as group: |
| task1 = group.create_task(some_coroutine(...)) |
| task2 = group.create_task(other_coroutine(...)) |
| print("Both tasks have completed now.") |
| |
| All tasks are awaited when the context manager exits. |
| |
| Any exceptions other than `asyncio.CancelledError` raised within |
| a task will cancel all remaining tasks and wait for them to exit. |
| The exceptions are then combined and raised as an `ExceptionGroup`. |
| """ |
| def __init__(self): |
| self._entered = False |
| self._exiting = False |
| self._aborting = False |
| self._loop = None |
| self._parent_task = None |
| self._parent_cancel_requested = False |
| self._tasks = set() |
| self._errors = [] |
| self._base_error = None |
| self._on_completed_fut = None |
| |
| def __repr__(self): |
| info = [''] |
| if self._tasks: |
| info.append(f'tasks={len(self._tasks)}') |
| if self._errors: |
| info.append(f'errors={len(self._errors)}') |
| if self._aborting: |
| info.append('cancelling') |
| elif self._entered: |
| info.append('entered') |
| |
| info_str = ' '.join(info) |
| return f'<TaskGroup{info_str}>' |
| |
| async def __aenter__(self): |
| if self._entered: |
| raise RuntimeError( |
| f"TaskGroup {self!r} has been already entered") |
| self._entered = True |
| |
| if self._loop is None: |
| self._loop = events.get_running_loop() |
| |
| self._parent_task = tasks.current_task(self._loop) |
| if self._parent_task is None: |
| raise RuntimeError( |
| f'TaskGroup {self!r} cannot determine the parent task') |
| |
| return self |
| |
| async def __aexit__(self, et, exc, tb): |
| self._exiting = True |
| |
| if (exc is not None and |
| self._is_base_error(exc) and |
| self._base_error is None): |
| self._base_error = exc |
| |
| propagate_cancellation_error = \ |
| exc if et is exceptions.CancelledError else None |
| if self._parent_cancel_requested: |
| # If this flag is set we *must* call uncancel(). |
| if self._parent_task.uncancel() == 0: |
| # If there are no pending cancellations left, |
| # don't propagate CancelledError. |
| propagate_cancellation_error = None |
| |
| if et is not None: |
| if not self._aborting: |
| # Our parent task is being cancelled: |
| # |
| # async with TaskGroup() as g: |
| # g.create_task(...) |
| # await ... # <- CancelledError |
| # |
| # or there's an exception in "async with": |
| # |
| # async with TaskGroup() as g: |
| # g.create_task(...) |
| # 1 / 0 |
| # |
| self._abort() |
| |
| # We use while-loop here because "self._on_completed_fut" |
| # can be cancelled multiple times if our parent task |
| # is being cancelled repeatedly (or even once, when |
| # our own cancellation is already in progress) |
| while self._tasks: |
| if self._on_completed_fut is None: |
| self._on_completed_fut = self._loop.create_future() |
| |
| try: |
| await self._on_completed_fut |
| except exceptions.CancelledError as ex: |
| if not self._aborting: |
| # Our parent task is being cancelled: |
| # |
| # async def wrapper(): |
| # async with TaskGroup() as g: |
| # g.create_task(foo) |
| # |
| # "wrapper" is being cancelled while "foo" is |
| # still running. |
| propagate_cancellation_error = ex |
| self._abort() |
| |
| self._on_completed_fut = None |
| |
| assert not self._tasks |
| |
| if self._base_error is not None: |
| raise self._base_error |
| |
| # Propagate CancelledError if there is one, except if there |
| # are other errors -- those have priority. |
| if propagate_cancellation_error and not self._errors: |
| raise propagate_cancellation_error |
| |
| if et is not None and et is not exceptions.CancelledError: |
| self._errors.append(exc) |
| |
| if self._errors: |
| # Exceptions are heavy objects that can have object |
| # cycles (bad for GC); let's not keep a reference to |
| # a bunch of them. |
| try: |
| me = BaseExceptionGroup('unhandled errors in a TaskGroup', self._errors) |
| raise me from None |
| finally: |
| self._errors = None |
| |
| def create_task(self, coro, *, name=None, context=None): |
| """Create a new task in this group and return it. |
| |
| Similar to `asyncio.create_task`. |
| """ |
| if not self._entered: |
| raise RuntimeError(f"TaskGroup {self!r} has not been entered") |
| if self._exiting and not self._tasks: |
| raise RuntimeError(f"TaskGroup {self!r} is finished") |
| if self._aborting: |
| raise RuntimeError(f"TaskGroup {self!r} is shutting down") |
| if context is None: |
| task = self._loop.create_task(coro) |
| else: |
| task = self._loop.create_task(coro, context=context) |
| tasks._set_task_name(task, name) |
| task.add_done_callback(self._on_task_done) |
| self._tasks.add(task) |
| return task |
| |
| # Since Python 3.8 Tasks propagate all exceptions correctly, |
| # except for KeyboardInterrupt and SystemExit which are |
| # still considered special. |
| |
| def _is_base_error(self, exc: BaseException) -> bool: |
| assert isinstance(exc, BaseException) |
| return isinstance(exc, (SystemExit, KeyboardInterrupt)) |
| |
| def _abort(self): |
| self._aborting = True |
| |
| for t in self._tasks: |
| if not t.done(): |
| t.cancel() |
| |
| def _on_task_done(self, task): |
| self._tasks.discard(task) |
| |
| if self._on_completed_fut is not None and not self._tasks: |
| if not self._on_completed_fut.done(): |
| self._on_completed_fut.set_result(True) |
| |
| if task.cancelled(): |
| return |
| |
| exc = task.exception() |
| if exc is None: |
| return |
| |
| self._errors.append(exc) |
| if self._is_base_error(exc) and self._base_error is None: |
| self._base_error = exc |
| |
| if self._parent_task.done(): |
| # Not sure if this case is possible, but we want to handle |
| # it anyways. |
| self._loop.call_exception_handler({ |
| 'message': f'Task {task!r} has errored out but its parent ' |
| f'task {self._parent_task} is already completed', |
| 'exception': exc, |
| 'task': task, |
| }) |
| return |
| |
| if not self._aborting and not self._parent_cancel_requested: |
| # If parent task *is not* being cancelled, it means that we want |
| # to manually cancel it to abort whatever is being run right now |
| # in the TaskGroup. But we want to mark parent task as |
| # "not cancelled" later in __aexit__. Example situation that |
| # we need to handle: |
| # |
| # async def foo(): |
| # try: |
| # async with TaskGroup() as g: |
| # g.create_task(crash_soon()) |
| # await something # <- this needs to be canceled |
| # # by the TaskGroup, e.g. |
| # # foo() needs to be cancelled |
| # except Exception: |
| # # Ignore any exceptions raised in the TaskGroup |
| # pass |
| # await something_else # this line has to be called |
| # # after TaskGroup is finished. |
| self._abort() |
| self._parent_cancel_requested = True |
| self._parent_task.cancel() |