| # Copyright (c) Meta Platforms, Inc. and affiliates |
| import torch |
| |
| from typing import ( |
| Callable, |
| Collection, |
| List, |
| Mapping, |
| MutableMapping, |
| Optional, |
| Tuple, |
| TypeVar, |
| Union, |
| cast, |
| ) |
| from torch.distributed.checkpoint.metadata import ( |
| STATE_DICT_TYPE, |
| ) |
| from torch.distributed._shard.sharded_tensor.api import ShardedTensor |
| from torch.distributed._tensor import DTensor |
| |
| PATH_ITEM = Union[str, int] |
| OBJ_PATH = Tuple[PATH_ITEM, ...] |
| T = TypeVar("T") |
| |
| STATE_DICT_ITEM = object |
| CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM] |
| |
| __all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"] |
| |
| |
| def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool: |
| return isinstance(value, torch.Tensor) |
| |
| |
| # TODO: update docstring for traverse.py |
def traverse_state_dict(
    state_dict: STATE_DICT_TYPE,
    visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
    keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors,
) -> None:
    """
    Invoke ``visitor`` on every value reachable from ``state_dict``.

    Containers (mappings and lists) are recursed into only while some element
    inside them, at any depth, satisfies ``keep_traversing``; a container for
    which no element satisfies the predicate is handed to ``visitor`` whole.
    With the default predicate, any collection holding at least one
    ``torch.Tensor`` is traversed element by element.

    ``visitor`` receives ``(path, value)``, where ``path`` is the tuple of
    keys leading to ``value`` (mapping keys are stringified; list positions
    stay ints).
    """

    # A node is "terminal" when neither it nor any nested container holds an
    # element the predicate wants to keep visiting.
    def _is_terminal(node: STATE_DICT_ITEM) -> bool:
        children: Collection[STATE_DICT_ITEM]
        if isinstance(node, Mapping):
            children = node.values()
        elif isinstance(node, list):
            children = node
        else:
            return True
        return not any(
            (isinstance(child, (Mapping, list)) and not _is_terminal(child))
            or (keep_traversing is not None and keep_traversing(child))
            for child in children
        )

    def _recurse(prefix: OBJ_PATH, node: STATE_DICT_ITEM) -> None:
        if _is_terminal(node):
            visitor(prefix, node)
        elif isinstance(node, Mapping):
            for key, child in node.items():
                _recurse(prefix + (str(key),), child)
        elif isinstance(node, list):
            for index, child in enumerate(node):
                _recurse(prefix + (index,), child)

    for key, child in state_dict.items():
        _recurse((str(key),), child)
| |
| |
def set_element(
    root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM
) -> None:
    """
    Store ``value`` in ``root_dict`` at the location described by ``path``,
    creating intermediate containers as needed.

    A ``str`` path element implies a dict level; an ``int`` element implies a
    list level (lists are padded with ``None`` up to the required index).
    """
    container = cast(CONTAINER_TYPE, root_dict)

    def _pad(seq: List[STATE_DICT_ITEM], idx: int) -> None:
        # Grow the list so that index ``idx`` is assignable.
        seq.extend([None] * (idx + 1 - len(seq)))

    # Walk consecutive (parent_key, child_key) pairs, materializing the
    # container each parent_key should hold based on the child key's type.
    for parent_key, child_key in zip(path, path[1:]):
        fresh = cast(STATE_DICT_ITEM, {} if type(child_key) == str else [])
        if isinstance(container, Mapping):
            container = cast(
                CONTAINER_TYPE, container.setdefault(parent_key, fresh)
            )
        else:
            _pad(container, parent_key)
            if container[parent_key] is None:
                container[parent_key] = fresh
            container = container[parent_key]

    last = path[-1]
    if type(last) == int:
        _pad(cast(List[STATE_DICT_ITEM], container), last)
    container[last] = value
| |
| |
def get_element(
    root_dict: STATE_DICT_TYPE,
    path: OBJ_PATH,
    default_value: Optional[T] = None,
) -> Optional[T]:
    """
    Retrieve the value at ``path`` from ``root_dict``.

    Returns ``default_value`` whenever a step of ``path`` cannot be resolved:
    a missing mapping key, a list index out of range, or a path element whose
    type does not match the container at that level.
    """
    cur_value = cast(CONTAINER_TYPE, root_dict)
    for part in path:
        if type(part) is int:
            # Bug fix: the bound check was `len(cur_value) < part`, which let
            # `part == len(cur_value)` through and raised IndexError instead
            # of returning the default. Use `<=` so the one-past-the-end index
            # also falls back to ``default_value``.
            if not isinstance(cur_value, list) or len(cur_value) <= part:
                return default_value
        elif not isinstance(cur_value, Mapping) or part not in cur_value:
            return default_value

        cur_value = cast(CONTAINER_TYPE, cur_value[part])
    return cast(Optional[T], cur_value)
| |
| |
| def _print_nested( |
| value: STATE_DICT_ITEM, |
| prefix: str = "", |
| print_fun: Callable[[str], None] = print, |
| ) -> None: |
| if type(value) is ShardedTensor: |
| print_fun(f"{prefix} ShardedTensor size: {value.size()}") |
| for shard in value.local_shards(): |
| _print_nested( |
| shard.tensor, |
| f"{shard.metadata.shard_offsets} ", |
| print_fun=print_fun, |
| ) |
| elif type(value) is (DTensor): |
| print_fun(f"{prefix} DistributedTensor size: {value.size()}") |
| # TODO: add local offset for _local_tensor in print_nested. |
| _print_nested( |
| value._local_tensor, |
| print_fun=print_fun, |
| ) |
| elif isinstance(value, torch.Tensor): |
| print_fun(f"{prefix} Tensor size: {value.size()}") |
| else: |
| print_fun(f"{prefix} Type: {type(value)}") |
| |
| |
def print_tensor(
    path: OBJ_PATH,
    value: STATE_DICT_ITEM,
    print_fun: Callable[[str], None] = print,
) -> None:
    """
    Visitor callback for ``traverse_state_dict`` that prints each value.

    Output is prefixed with the stringified ``path`` and routed through
    ``print_fun`` (the builtin ``print`` by default).
    """
    _print_nested(value, str(path), print_fun=print_fun)