| #!/usr/bin/env python3 |
| |
| # Copyright (c) Facebook, Inc. and its affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| import sys |
| import uuid |
| from dataclasses import dataclass, field |
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
| |
| import torch.distributed.elastic.rendezvous.registry as rdzv_registry |
| from torch.distributed.elastic import events, metrics |
| from torch.distributed.elastic.agent.server.api import WorkerSpec |
| from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent |
| from torch.distributed.elastic.multiprocessing import SignalException, Std |
| from torch.distributed.elastic.multiprocessing.errors import ChildFailedError |
| from torch.distributed.elastic.rendezvous import RendezvousParameters |
| from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint |
| from torch.distributed.elastic.utils.logging import get_logger |
| |
# Names exported by ``from ... import *``: the public launcher API.
__all__ = ["LaunchConfig", "elastic_launch", "launch_agent"]

# Module-level logger used by the launcher functions in this module.
logger = get_logger()
| |
| |
@dataclass
class LaunchConfig:
    """
    Configuration used to launch workers through the elastic agent.

    Args:
        min_nodes: Minimum amount of nodes that the user function will
            be launched on. Elastic agent ensures that the user
            function starts only when the min_nodes amount enters
            the rendezvous.
        max_nodes: Maximum amount of nodes that the user function
            will be launched on.
        nproc_per_node: On each node the elastic agent will launch
            this amount of workers that will execute the user
            defined function.
        rdzv_backend: rdzv_backend to use in the rendezvous (e.g. ``etcd``,
            the default, or ``static``).
        rdzv_endpoint: The endpoint of the rdzv sync. storage.
        rdzv_configs: Key, value pairs that specify rendezvous specific
            configuration. The mapping is copied on construction, so the
            caller's dict is never modified.
        rdzv_timeout: Legacy argument that specifies timeout for the
            rendezvous. It is going to be removed in future versions, see
            the note below. The default timeout is 900 seconds.
        run_id: The unique run id of the job. If left empty,
            ``launch_agent`` generates a random one.
        role: User defined role of the worker (defaults to "default_role").
        max_restarts: The maximum amount of restarts that elastic agent will
            conduct on workers before failure (defaults to 3).
        monitor_interval: The interval in seconds that is used by the
            elastic_agent as a period of monitoring workers (defaults to 30).
        start_method: The method used by the elastic agent to start the
            workers (spawn, fork, forkserver; defaults to spawn).
        log_dir: base log directory where log files are written. If not set,
            one is created in a tmp dir but NOT removed on exit.
        redirects: configuration to redirect stdout/stderr to log files.
            Pass a single ``Std`` enum to redirect all workers,
            or a mapping keyed by local_rank to selectively redirect.
        tee: configuration to "tee" stdout/stderr to console + log file.
        metrics_cfg: configuration to initialize metrics.
        local_addr: address of the local node if any. If not set, a lookup
            on the local machine's FQDN will be performed.

    .. note::
        ``rdzv_timeout`` is a legacy argument that will be removed in future.
        Set the timeout via ``rdzv_configs['timeout']``. When ``rdzv_timeout``
        is set (i.e. not ``-1``) it overrides ``rdzv_configs['timeout']``;
        when neither is given, ``rdzv_configs['timeout']`` becomes 900.
    """

    min_nodes: int
    max_nodes: int
    nproc_per_node: int
    run_id: str = ""
    role: str = "default_role"
    rdzv_endpoint: str = ""
    rdzv_backend: str = "etcd"
    rdzv_configs: Dict[str, Any] = field(default_factory=dict)
    rdzv_timeout: int = -1
    max_restarts: int = 3
    monitor_interval: float = 30
    start_method: str = "spawn"
    log_dir: Optional[str] = None
    redirects: Union[Std, Dict[int, Std]] = Std.NONE
    tee: Union[Std, Dict[int, Std]] = Std.NONE
    metrics_cfg: Dict[str, str] = field(default_factory=dict)
    local_addr: Optional[str] = None

    def __post_init__(self):
        """Normalize ``rdzv_configs['timeout']`` as described in the class note."""
        default_timeout = 900
        # Work on a private copy. Writing into the caller's dict would leak
        # this config's timeout into any other config sharing the same dict.
        self.rdzv_configs = dict(self.rdzv_configs or {})
        if self.rdzv_timeout != -1:
            # The legacy argument, when given, takes precedence.
            self.rdzv_configs["timeout"] = self.rdzv_timeout
        elif "timeout" not in self.rdzv_configs:
            self.rdzv_configs["timeout"] = default_timeout
| |
| |
class elastic_launch:
    """
    Launches a torchelastic agent on the container that invoked the entrypoint.

    1. Pass the ``entrypoint`` arguments positionally (i.e. no named
       parameters). ``entrypoint`` can be a function or a command.
    2. The return value is a map of each worker's output keyed by the
       worker's global rank.

    Usage

    ::

        def worker_fn(foo):
            # ...

        def main():
            config = LaunchConfig(min_nodes=1, max_nodes=1, nproc_per_node=2)

            # entrypoint is a function.
            outputs = elastic_launch(config, worker_fn)(foo)
            # return rank 0's output
            return outputs[0]

            # entrypoint is a command and ``script.py`` is the python module.
            outputs = elastic_launch(config, "script.py")(args)
            outputs = elastic_launch(config, "python")("script.py")
    """

    def __init__(
        self,
        config: LaunchConfig,
        entrypoint: Union[Callable, str, None],
    ):
        # Only remember what to launch; nothing starts until the instance is called.
        self._entrypoint = entrypoint
        self._config = config

    def __call__(self, *args):
        # Positional call arguments become the entrypoint's argument list.
        forwarded_args = [*args]
        return launch_agent(self._config, self._entrypoint, forwarded_args)
| |
| |
| def _get_entrypoint_name( |
| entrypoint: Union[Callable, str, None], args: List[Any] |
| ) -> str: |
| """Retrive entrypoint name with the rule: |
| 1. If entrypoint is a function, use ``entrypont.__qualname__``. |
| 2. If entrypoint is a string, check its value: |
| 2.1 if entrypoint equals to ``sys.executable`` (like "python"), use the first element from ``args`` |
| which does not start with hifen letter (for example, "-u" will be skipped). |
| 2.2 otherwise, use ``entrypoint`` value. |
| 3. Otherwise, return empty string. |
| """ |
| if isinstance(entrypoint, Callable): # type: ignore[arg-type] |
| return entrypoint.__name__ # type: ignore[union-attr] |
| elif isinstance(entrypoint, str): |
| if entrypoint == sys.executable: |
| return next((arg for arg in args if arg[0] != "-"), "") |
| else: |
| return entrypoint |
| else: |
| return "" |
| |
| |
def _get_addr_and_port(
    rdzv_parameters: RendezvousParameters,
) -> Tuple[Optional[str], Optional[int]]:
    """Extract the master address and port from a ``static`` rendezvous endpoint.

    Only the ``static`` backend is handled here. For any other backend
    ``(None, None)`` is returned without inspecting the endpoint.

    Args:
        rdzv_parameters: rendezvous parameters; only ``backend`` and
            ``endpoint`` are read.

    Returns:
        ``(master_addr, master_port)`` for the ``static`` backend, otherwise
        ``(None, None)``.

    Raises:
        ValueError: if the static endpoint is empty/blank, or does not
            include a port.
    """
    if rdzv_parameters.backend != "static":
        return (None, None)
    endpoint = rdzv_parameters.endpoint.strip()
    if not endpoint:
        raise ValueError(
            "Endpoint is missing. Try to add --master_addr and --master_port"
        )
    # -1 is a sentinel default: getting it back means the endpoint string
    # carried no explicit port.
    master_addr, master_port = parse_rendezvous_endpoint(endpoint, default_port=-1)
    if master_port == -1:
        raise ValueError(
            f"port is missing in endpoint: {endpoint}. Try to specify --master_port"
        )
    return (master_addr, master_port)
| |
| |
def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    """Run a ``LocalElasticAgent`` for ``entrypoint`` and return the workers' results.

    Builds rendezvous parameters and a ``WorkerSpec`` from ``config``. It then
    runs the agent and records success/failure events. Finally it shuts the
    rendezvous handler down, unless the agent was stopped by a signal.

    Args:
        config: launch configuration. NOTE: if ``config.run_id`` is empty, a
            random run id is generated and written back into ``config``.
        entrypoint: function or command executed by each worker.
        args: positional arguments for ``entrypoint``; passed to the
            ``WorkerSpec`` as a tuple.

    Returns:
        ``result.return_values`` from the agent run. Per the ``Dict[int, Any]``
        annotation and the ``elastic_launch`` docstring, this is each
        worker's output keyed by global rank.

    Raises:
        ChildFailedError: if the agent run reports failed workers.
        SignalException: re-raised when the agent dies with a signal; the
            rendezvous handler is NOT shut down in that case.
        Exception: any other error raised inside the ``try`` is recorded as
            an agent failure event and re-raised.
    """
    # Generate a run id when none was given (this mutates the caller's config).
    if not config.run_id:
        run_id = str(uuid.uuid4().int)
        logger.warning(f"config has no run_id, generated a random run_id: {run_id}")
        config.run_id = run_id

    entrypoint_name = _get_entrypoint_name(entrypoint, args)

    logger.info(
        f"Starting elastic_operator with launch configs:\n"
        f" entrypoint : {entrypoint_name}\n"
        f" min_nodes : {config.min_nodes}\n"
        f" max_nodes : {config.max_nodes}\n"
        f" nproc_per_node : {config.nproc_per_node}\n"
        f" run_id : {config.run_id}\n"
        f" rdzv_backend : {config.rdzv_backend}\n"
        f" rdzv_endpoint : {config.rdzv_endpoint}\n"
        f" rdzv_configs : {config.rdzv_configs}\n"
        f" max_restarts : {config.max_restarts}\n"
        f" monitor_interval : {config.monitor_interval}\n"
        f" log_dir : {config.log_dir}\n"
        f" metrics_cfg : {config.metrics_cfg}\n"
    )

    # Backend-specific settings from rdzv_configs (e.g. "timeout") are
    # forwarded as extra keyword arguments.
    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        local_addr=config.local_addr,
        **config.rdzv_configs,
    )

    # (None, None) unless the rendezvous backend is "static".
    master_addr, master_port = _get_addr_and_port(rdzv_parameters)

    # NOTE(review): the rendezvous handler is created here, before the try
    # block below. If LocalElasticAgent(...) raises, the finally clause never
    # runs and the handler is not shut down -- confirm whether that matters.
    spec = WorkerSpec(
        role=config.role,
        local_world_size=config.nproc_per_node,
        entrypoint=entrypoint,
        args=tuple(args),
        rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters),
        max_restarts=config.max_restarts,
        monitor_interval=config.monitor_interval,
        redirects=config.redirects,
        tee=config.tee,
        master_addr=master_addr,
        master_port=master_port,
        local_addr=config.local_addr,
    )

    agent = LocalElasticAgent(
        spec=spec, start_method=config.start_method, log_dir=config.log_dir
    )

    # Cleared only on SignalException so the rendezvous stays open.
    shutdown_rdzv = True
    try:
        metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))

        result = agent.run()
        # records that agent.run() has succeeded NOT that workers have succeeded
        events.record(agent.get_event_succeeded())

        if result.is_failed():
            # ChildFailedError is treated specially by @record
            # if the error files for the failed children exist
            # @record will copy the first error (root cause)
            # to the error file of the launcher process.
            raise ChildFailedError(
                name=entrypoint_name,
                failures=result.failures,
            )

        return result.return_values
    except ChildFailedError:
        # Re-raise untouched, before the generic handler below, so worker
        # failures are not additionally recorded as an agent failure event
        # (the agent run itself was already recorded as succeeded).
        raise
    except SignalException:
        # when the agent dies with a signal do NOT shutdown the rdzv_handler
        # since this closes the rendezvous on this rdzv_id permanently and
        # prevents any additional scaling events
        shutdown_rdzv = False
        events.record(agent.get_event_failed())
        raise
    except Exception:
        events.record(agent.get_event_failed())
        raise
    finally:
        if shutdown_rdzv:
            spec.rdzv_handler.shutdown()