blob: 75dc42453348f4cf77dd6a412f091c40a3ba5765 [file] [log] [blame]
# Copyright (c) Meta Platforms, Inc. and affiliates
import torch
from typing import (
Callable,
Collection,
List,
Mapping,
MutableMapping,
Optional,
Tuple,
TypeVar,
Union,
cast,
)
from torch.distributed.checkpoint.metadata import (
STATE_DICT_TYPE,
)
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._tensor import DTensor
PATH_ITEM = Union[str, int]
OBJ_PATH = Tuple[PATH_ITEM, ...]
T = TypeVar("T")
STATE_DICT_ITEM = object
CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM]
__all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"]
def _keep_visiting_tensors(value: "STATE_DICT_ITEM") -> bool:
    # Default ``keep_traversing`` predicate: keep descending into any
    # container that holds at least one tensor.
    return isinstance(value, torch.Tensor)


def traverse_state_dict(
    state_dict: "STATE_DICT_TYPE",
    visitor: "Callable[[OBJ_PATH, STATE_DICT_ITEM], None]",
    keep_traversing: "Callable[[STATE_DICT_ITEM], bool]" = _keep_visiting_tensors,
) -> None:
    """
    Invoke ``visitor`` for each value recursively in ``state_dict``.

    Traversal is short-circuited when it finds a collection for which
    ``keep_traversing`` evaluates to false for all elements.
    By default, all collections with at least one ``torch.Tensor`` element
    are traversed. If ``keep_traversing`` is ``None``, every collection is
    treated as terminal and ``visitor`` is invoked on the top-level values.

    ``visitor`` takes a path argument that is a tuple of the keys used to
    reach it; mapping keys appear as ``str`` and list indices as ``int``.
    """

    # a value is terminal if it has no other container values inside it
    # and none of its elements satisfy ``keep_traversing``
    def _is_terminal(value: "STATE_DICT_ITEM") -> bool:
        values: Collection[STATE_DICT_ITEM]
        if isinstance(value, Mapping):
            values = value.values()
        elif isinstance(value, list):
            values = value
        else:
            # scalars / tensors / arbitrary objects are always terminal
            return True

        for entry in values:
            if isinstance(entry, (Mapping, list)) and not _is_terminal(entry):
                return False
            if keep_traversing is not None and keep_traversing(entry):
                return False
        return True

    def _traverse_obj(path: "OBJ_PATH", value: "STATE_DICT_ITEM") -> None:
        if _is_terminal(value):
            visitor(path, value)
        elif isinstance(value, Mapping):
            # NOTE: mapping keys are stringified, so non-str keys lose their
            # original type in the reported path.
            for k, v in value.items():
                _traverse_obj(path + (str(k),), v)
        elif isinstance(value, list):
            for i, v in enumerate(value):
                _traverse_obj(path + (i,), v)

    for key, value in state_dict.items():
        _traverse_obj((str(key),), value)
def set_element(
    root_dict: "STATE_DICT_TYPE", path: "OBJ_PATH", value: "STATE_DICT_ITEM"
) -> None:
    """
    Set ``value`` in ``root_dict`` along the ``path`` object path.

    Intermediate containers are created on demand: a ``str`` path item
    creates/expects a dict, an ``int`` path item creates/expects a list
    (padded with ``None`` placeholders up to the required index).
    """
    cur_container = cast("CONTAINER_TYPE", root_dict)

    def extend_list(lst: "List[STATE_DICT_ITEM]", idx: int) -> None:
        # Pad with None so that lst[idx] is a valid slot.
        while len(lst) <= idx:
            lst.append(None)

    for i in range(1, len(path)):
        prev_key = path[i - 1]
        key = path[i]
        # The container needed at this level is dictated by the NEXT key:
        # a str key means a dict, an int key means a list.
        def_val = cast("STATE_DICT_ITEM", {} if isinstance(key, str) else [])

        if isinstance(cur_container, Mapping):
            cur_container = cast(
                "CONTAINER_TYPE", cur_container.setdefault(prev_key, def_val)
            )
        else:
            extend_list(cur_container, prev_key)
            # Replace a placeholder with the container type the next key needs.
            if cur_container[prev_key] is None:
                cur_container[prev_key] = def_val
            cur_container = cur_container[prev_key]

    key = path[-1]
    if isinstance(key, int):
        extend_list(cast("List[STATE_DICT_ITEM]", cur_container), key)

    cur_container[key] = value
def get_element(
    root_dict: "STATE_DICT_TYPE",
    path: "OBJ_PATH",
    default_value: "Optional[T]" = None,
) -> "Optional[T]":
    """
    Retrieve the value at ``path`` from ``root_dict``, returning ``default_value`` if not found.
    """
    cur_value = cast("CONTAINER_TYPE", root_dict)
    for part in path:
        if isinstance(part, int):
            # An int path item indexes a list; any out-of-range index —
            # including part == len(cur_value) and negative indices — means
            # the element does not exist. (The previous check used
            # ``len(cur_value) < part``, which raised IndexError for
            # part == len(cur_value) instead of returning the default.)
            if not isinstance(cur_value, list) or not 0 <= part < len(cur_value):
                return default_value
        elif not isinstance(cur_value, Mapping) or part not in cur_value:
            return default_value
        cur_value = cast("CONTAINER_TYPE", cur_value[part])
    return cast("Optional[T]", cur_value)
def _print_nested(
value: STATE_DICT_ITEM,
prefix: str = "",
print_fun: Callable[[str], None] = print,
) -> None:
if type(value) is ShardedTensor:
print_fun(f"{prefix} ShardedTensor size: {value.size()}")
for shard in value.local_shards():
_print_nested(
shard.tensor,
f"{shard.metadata.shard_offsets} ",
print_fun=print_fun,
)
elif type(value) is (DTensor):
print_fun(f"{prefix} DistributedTensor size: {value.size()}")
# TODO: add local offset for _local_tensor in print_nested.
_print_nested(
value._local_tensor,
print_fun=print_fun,
)
elif isinstance(value, torch.Tensor):
print_fun(f"{prefix} Tensor size: {value.size()}")
else:
print_fun(f"{prefix} Type: {type(value)}")
def print_tensor(
    path: OBJ_PATH,
    value: STATE_DICT_ITEM,
    print_fun: Callable[[str], None] = print,
) -> None:
    """
    Callback for ``traverse_state_dict`` that prints the visited content.

    By default the content is printed using the builtin ``print``, but this
    can be changed by passing a different ``print_fun`` callable.
    """
    _print_nested(value, str(path), print_fun=print_fun)