| # Copyright (c) Facebook, Inc. and its affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| import abc |
| import logging |
| import threading |
| import time |
| from contextlib import contextmanager |
| from inspect import getframeinfo, stack |
| from typing import Any, Dict, List, Optional, Set |
| |
| |
| class TimerRequest: |
| """ |
| Data object representing a countdown timer acquisition and release |
| that is used between the ``TimerClient`` and ``TimerServer``. |
| A negative ``expiration_time`` should be interpreted as a "release" |
| request. |
| |
| .. note:: the type of ``worker_id`` is implementation specific. |
| It is whatever the TimerServer and TimerClient implementations |
| have on to uniquely identify a worker. |
| """ |
| |
| __slots__ = ["worker_id", "scope_id", "expiration_time"] |
| |
| def __init__(self, worker_id: Any, scope_id: str, expiration_time: float): |
| self.worker_id = worker_id |
| self.scope_id = scope_id |
| self.expiration_time = expiration_time |
| |
| def __eq__(self, other): |
| if isinstance(other, TimerRequest): |
| return ( |
| self.worker_id == other.worker_id |
| and self.scope_id == other.scope_id |
| and self.expiration_time == other.expiration_time |
| ) |
| return False |
| |
| |
| class TimerClient(abc.ABC): |
| """ |
| Client library to acquire and release countdown timers by communicating |
| with the TimerServer. |
| """ |
| |
| @abc.abstractmethod |
| def acquire(self, scope_id: str, expiration_time: float) -> None: |
| """ |
| Acquires a timer for the worker that holds this client object |
| given the scope_id and expiration_time. Typically registers |
| the timer with the TimerServer. |
| """ |
| pass |
| |
| @abc.abstractmethod |
| def release(self, scope_id: str): |
| """ |
| Releases the timer for the ``scope_id`` on the worker this |
| client represents. After this method is |
| called, the countdown timer on the scope is no longer in effect. |
| """ |
| pass |
| |
| |
| class RequestQueue(abc.ABC): |
| """ |
| Consumer queue holding timer acquisition/release requests |
| """ |
| |
| @abc.abstractmethod |
| def size(self) -> int: |
| """ |
| Returns the size of the queue at the time this method is called. |
| Note that by the time ``get`` is called the size of the queue |
| may have increased. The size of the queue should not decrease |
| until the ``get`` method is called. That is, the following assertion |
| should hold: |
| |
| size = q.size() |
| res = q.get(size, timeout=0) |
| assert size == len(res) |
| |
| -- or -- |
| |
| size = q.size() |
| res = q.get(size * 2, timeout=1) |
| assert size <= len(res) <= size * 2 |
| """ |
| pass |
| |
| @abc.abstractmethod |
| def get(self, size: int, timeout: float) -> List[TimerRequest]: |
| """ |
| Gets up to ``size`` number of timer requests in a blocking fashion |
| (no more than ``timeout`` seconds). |
| """ |
| pass |
| |
| |
| class TimerServer(abc.ABC): |
| """ |
| Entity that monitors active timers and expires them |
| in a timely fashion. This server is responsible for |
| reaping workers that have expired timers. |
| """ |
| |
| def __init__( |
| self, request_queue: RequestQueue, max_interval: float, daemon: bool = True |
| ): |
| """ |
| :param request_queue: Consumer ``RequestQueue`` |
| :param max_interval: max time (in seconds) to wait |
| for an item in the request_queue |
| :param daemon: whether to run the watchdog thread as a daemon |
| """ |
| super().__init__() |
| self._request_queue = request_queue |
| self._max_interval = max_interval |
| self._daemon = daemon |
| self._watchdog_thread: Optional[threading.Thread] = None |
| self._stop_signaled = False |
| |
| @abc.abstractmethod |
| def register_timers(self, timer_requests: List[TimerRequest]) -> None: |
| """ |
| Processes the incoming timer requests and registers them with the server. |
| The timer request can either be a acquire-timer or release-timer request. |
| Timer requests with a negative expiration_time should be interpreted |
| as a release-timer request. |
| """ |
| pass |
| |
| @abc.abstractmethod |
| def clear_timers(self, worker_ids: Set[Any]) -> None: |
| """ |
| Clears all timers for the given ``worker_ids``. |
| """ |
| pass |
| |
| @abc.abstractmethod |
| def get_expired_timers(self, deadline: float) -> Dict[str, List[TimerRequest]]: |
| """ |
| Returns all expired timers for each worker_id. An expired timer |
| is a timer for which the expiration_time is less than or equal to |
| the provided deadline. |
| """ |
| pass |
| |
| @abc.abstractmethod |
| def _reap_worker(self, worker_id: Any) -> bool: |
| """ |
| Reaps the given worker. Returns True if the worker has been |
| successfully reaped, False otherwise. If any uncaught exception |
| is thrown from this method, the worker is considered reaped |
| and all associated timers will be removed. |
| """ |
| |
| def _reap_worker_no_throw(self, worker_id: Any) -> bool: |
| """ |
| Wraps ``_reap_worker(worker_id)``, if an uncaught exception is |
| thrown, then it considers the worker as reaped. |
| """ |
| try: |
| return self._reap_worker(worker_id) |
| except Exception as e: |
| logging.error( |
| "Uncaught exception thrown from _reap_worker(), " |
| "check that the implementation correctly catches exceptions", |
| exc_info=e, |
| ) |
| return True |
| |
| def _watchdog_loop(self): |
| while not self._stop_signaled: |
| try: |
| self._run_watchdog() |
| except Exception as e: |
| logging.error("Error running watchdog", exc_info=e) |
| |
| def _run_watchdog(self): |
| batch_size = max(1, self._request_queue.size()) |
| timer_requests = self._request_queue.get(batch_size, self._max_interval) |
| self.register_timers(timer_requests) |
| now = time.time() |
| reaped_worker_ids = set() |
| for worker_id, expired_timers in self.get_expired_timers(now).items(): |
| logging.info( |
| f"Reaping worker_id=[{worker_id}]." |
| f" Expired timers: {self._get_scopes(expired_timers)}" |
| ) |
| if self._reap_worker_no_throw(worker_id): |
| logging.info(f"Successfully reaped worker=[{worker_id}]") |
| reaped_worker_ids.add(worker_id) |
| else: |
| logging.error( |
| f"Error reaping worker=[{worker_id}]. Will retry on next watchdog." |
| ) |
| self.clear_timers(reaped_worker_ids) |
| |
| def _get_scopes(self, timer_requests): |
| return [r.scope_id for r in timer_requests] |
| |
| def start(self) -> None: |
| logging.info( |
| f"Starting {type(self).__name__}..." |
| f" max_interval={self._max_interval}," |
| f" daemon={self._daemon}" |
| ) |
| self._watchdog_thread = threading.Thread( |
| target=self._watchdog_loop, daemon=self._daemon |
| ) |
| logging.info("Starting watchdog thread...") |
| self._watchdog_thread.start() |
| |
| def stop(self) -> None: |
| logging.info(f"Stopping {type(self).__name__}") |
| self._stop_signaled = True |
| if self._watchdog_thread: |
| logging.info("Stopping watchdog thread...") |
| self._watchdog_thread.join(self._max_interval) |
| self._watchdog_thread = None |
| else: |
| logging.info("No watchdog thread running, doing nothing") |
| |
| |
| _timer_client = None |
| |
| |
| def configure(timer_client: TimerClient): |
| """ |
| Configures a timer client. Must be called before using ``expires``. |
| """ |
| global _timer_client |
| _timer_client = timer_client |
| logging.info(f"Timer client configured to: {type(_timer_client).__name__}") |
| |
| |
| @contextmanager |
| def expires( |
| after: float, scope: Optional[str] = None, client: Optional[TimerClient] = None |
| ): |
| """ |
| Acquires a countdown timer that expires in ``after`` seconds from now, |
| unless the code-block that it wraps is finished within the timeframe. |
| When the timer expires, this worker is eligible to be reaped. The |
| exact meaning of "reaped" depends on the client implementation. In |
| most cases, reaping means to terminate the worker process. |
| Note that the worker is NOT guaranteed to be reaped at exactly |
| ``time.now() + after``, but rather the worker is "eligible" for being |
| reaped and the ``TimerServer`` that the client talks to will ultimately |
| make the decision when and how to reap the workers with expired timers. |
| |
| Usage:: |
| |
| torch.distributed.elastic.timer.configure(LocalTimerClient()) |
| with expires(after=10): |
| torch.distributed.all_reduce(...) |
| """ |
| if client is None: |
| if _timer_client is None: |
| raise RuntimeError("Configure timer client before using coundown timers.") |
| client = _timer_client |
| if scope is None: |
| # grab the caller file + lineno |
| caller = getframeinfo(stack()[1][0]) |
| scope = f"{caller.filename}#{caller.lineno}" |
| expiration = time.time() + after |
| client.acquire(scope, expiration) |
| try: |
| yield |
| finally: |
| client.release(scope) |