# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from enum import IntEnum
from typing import Optional, Set, Tuple

import torch

from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkMemoryLayout,
    VkStorageType,
)
from executorch.exir.tensor import TensorSpec
from torch._export.utils import is_buffer, is_param
from torch._subclasses.fake_tensor import FakeTensor
from torch.export import ExportedProgram

##
## Node type determination
##


def is_get_attr_node(node: torch.fx.Node) -> bool:
    return isinstance(node, torch.fx.Node) and node.op == "get_attr"


def is_constant(program: ExportedProgram, node: torch.fx.Node) -> bool:
    return node.name in program.graph_signature.inputs_to_lifted_tensor_constants


def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Check whether the given node represents a parameter, buffer, or lifted tensor
    constant within the exported program.
    """
    return (
        is_get_attr_node(node)
        or is_param(program, node)
        or is_buffer(program, node)
        or is_constant(program, node)
    )


def is_symint_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the given node produces a SymInt value.
    """
    return isinstance(node.meta.get("val"), torch.SymInt)


def is_tensor_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the given node produces a tensor value, or a collection of
    tensor values.
    """
    # All nodes with tensor values are tagged by the SpecPropPass transform
    if "spec" in node.meta:
        return True

    if "val" not in node.meta:
        return False

    if isinstance(node.meta["val"], FakeTensor):
        return True

    if isinstance(node.meta["val"], (list, tuple)):
        return all(isinstance(x, FakeTensor) for x in node.meta["val"])

    return False


##
## Memory Layout, Storage Type Determination
##

ImageExtents = Tuple[int, int, int]

DEFAULT_TEXTURE_LIMITS = (16384, 16384, 2048)
DEFAULT_BUFFER_LIMIT = 128 * (1024 * 1024)


class PackedDim(IntEnum):
    WIDTH = 0
    HEIGHT = 1
    CHANNELS = 2


all_packed_dims: Set[PackedDim] = {
    PackedDim.WIDTH,
    PackedDim.HEIGHT,
    PackedDim.CHANNELS,
}

all_storage_types: Set[VkStorageType] = {
    VkStorageType.BUFFER,
    VkStorageType.TEXTURE_3D,
}

all_memory_layouts: Set[VkMemoryLayout] = {
    VkMemoryLayout.TENSOR_WIDTH_PACKED,
    VkMemoryLayout.TENSOR_HEIGHT_PACKED,
    VkMemoryLayout.TENSOR_CHANNELS_PACKED,
}


def within_buffer_limit(node: torch.fx.Node, buffer_limit: int) -> bool:
    """
    Checks whether the tensors produced by the given node can fit within the device's
    GPU buffer limit, which represents the maximum number of elements that can be
    stored in a GPU buffer.
    """
    assert is_tensor_node(node)

    if isinstance(node.meta["val"], FakeTensor):
        return node.meta["val"].numel() < buffer_limit
    elif isinstance(node.meta["val"], (list, tuple)):
        return all(x.numel() < buffer_limit for x in node.meta["val"])
    else:
        raise RuntimeError(f"Cannot get numel for val of type {type(node.meta['val'])}")


def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageExtents:
    """
    Calculate the image extents that will be used to represent a tensor with the given
    sizes and memory layout in the Vulkan Delegate.
    """
    width = sizes[-1] if len(sizes) >= 1 else 1
    height = sizes[-2] if len(sizes) >= 2 else 1
    channels = sizes[-3] if len(sizes) >= 3 else 1
    batch = sizes[0] if len(sizes) >= 4 else 1

    if layout == VkMemoryLayout.TENSOR_WIDTH_PACKED:
        width = (width + 3) // 4
    elif layout == VkMemoryLayout.TENSOR_HEIGHT_PACKED:
        height = (height + 3) // 4
    elif layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED:
        channels = (channels + 3) // 4
    else:
        raise RuntimeError(f"Unsupported memory layout {layout}")

    return width, height, channels * batch
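
# Illustrative example (sizes assumed, not from the original source): a tensor of
# size (2, 8, 4, 10) maps to width=10, height=4, channels=8, batch=2. Under
# TENSOR_WIDTH_PACKED, four elements along the width axis share one texel, so:
#
#     required_image_extents(torch.Size([2, 8, 4, 10]), VkMemoryLayout.TENSOR_WIDTH_PACKED)
#     == ((10 + 3) // 4, 4, 8 * 2) == (3, 4, 16)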


def extents_are_valid(extents: ImageExtents, limits: ImageExtents) -> bool:
    return all(extents[i] <= limits[i] for i in range(len(extents)))


def valid_texture_memory_layouts(
    tensor_sizes: torch.Size, texture_limits: ImageExtents
) -> Set[VkMemoryLayout]:
    """
    Given tensor sizes, determine the set of memory layouts which will produce a
    texture that can fit within the specified device limits.
    """
    valid_layouts = set()
    for layout in all_memory_layouts:
        extents = required_image_extents(tensor_sizes, layout)
        if extents_are_valid(extents, texture_limits):
            valid_layouts.add(layout)

    return valid_layouts
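
# Illustrative example (sizes assumed, not from the original source): for
# torch.Size([1, 8000, 4, 4]) under DEFAULT_TEXTURE_LIMITS, the width- and
# height-packed layouts yield a depth extent of 8000 (> 2048) and are rejected,
# while TENSOR_CHANNELS_PACKED folds the channels dim into texels
# ((8000 + 3) // 4 == 2000), so only that layout is returned.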


def possible_node_memory_layouts(
    node: torch.fx.Node, texture_limits: ImageExtents
) -> Set[VkMemoryLayout]:
    """
    Given a node, determine the set of memory layouts which can be used to represent
    all tensors involved in the computation.
    """
    assert is_tensor_node(node)

    if isinstance(node.meta["val"], FakeTensor):
        return valid_texture_memory_layouts(node.meta["val"].shape, texture_limits)

    valid_layouts = set()
    if isinstance(node.meta["val"], (list, tuple)):
        for fake_tensor in node.meta["val"]:
            valid_layouts = valid_layouts.union(
                valid_texture_memory_layouts(fake_tensor.shape, texture_limits)
            )

    return valid_layouts


##
## TensorSpec Utils
##


def set_node_spec_attr(node: torch.fx.Node, attr: str, value):
    assert "spec" in node.meta
    spec = node.meta["spec"]
    if isinstance(spec, TensorSpec):
        setattr(spec, attr, value)
    elif isinstance(spec, (list, tuple)):
        for s in spec:
            assert isinstance(s, TensorSpec)
            setattr(s, attr, value)
    else:
        raise RuntimeError(f"Cannot set attr for spec of type {type(spec)}")


def get_node_spec_attr(node: torch.fx.Node, attr: str, return_first: bool = True):
    assert "spec" in node.meta
    spec = node.meta["spec"]
    if isinstance(spec, TensorSpec):
        return getattr(spec, attr) if hasattr(spec, attr) else None
    elif isinstance(spec, (list, tuple)):
        if return_first:
            return getattr(spec[0], attr) if hasattr(spec[0], attr) else None
        else:
            return [getattr(s, attr) if hasattr(s, attr) else None for s in spec]
    else:
        raise RuntimeError(f"Cannot get attr for spec of type {type(spec)}")


def get_node_storage_type(node: torch.fx.Node) -> Optional[VkStorageType]:
    return get_node_spec_attr(node, "vk_storage_type")


def get_node_memory_layout(node: torch.fx.Node) -> Optional[VkMemoryLayout]:
    return get_node_spec_attr(node, "vk_memory_layout")
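
# Illustrative usage sketch (the exported_program variable and the layout-selection
# policy are assumptions, not part of this module): a delegate pass might combine
# these helpers roughly as follows.
#
#     for node in exported_program.graph_module.graph.nodes:
#         if not is_tensor_node(node) or is_param_node(exported_program, node):
#             continue
#         layouts = possible_node_memory_layouts(node, DEFAULT_TEXTURE_LIMITS)
#         if not layouts or not within_buffer_limit(node, DEFAULT_BUFFER_LIMIT):
#             continue  # fall back to the default storage/layout selection
#         set_node_spec_attr(node, "vk_memory_layout", next(iter(layouts)))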