| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| from enum import IntEnum |
| from typing import Optional, Set, Tuple |
| |
| import torch |
| |
| from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( |
| VkMemoryLayout, |
| VkStorageType, |
| ) |
| |
| from executorch.exir.tensor import TensorSpec |
| |
| from torch._export.utils import is_buffer, is_param |
| |
| from torch._subclasses.fake_tensor import FakeTensor |
| |
| from torch.export import ExportedProgram |
| |
| ## |
| ## Node type determination |
| ## |
| |
| |
def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Return True if ``node`` is a ``torch.fx.Node`` whose opcode is ``get_attr``.
    Anything that is not an FX node yields False.
    """
    if not isinstance(node, torch.fx.Node):
        return False
    return node.op == "get_attr"
| |
| |
| def is_constant(program: ExportedProgram, node: torch.fx.Node) -> bool: |
| return node.name in program.graph_signature.inputs_to_lifted_tensor_constants |
| |
| |
def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Check whether ``node`` holds parameter-like data of the exported program.

    A node qualifies if it is a ``get_attr`` node, or if the program's graph
    signature classifies it as a parameter, a buffer, or a lifted tensor constant.
    Checks are evaluated in that order and stop at the first match.
    """
    if is_get_attr_node(node):
        return True
    return any(
        check(program, node) for check in (is_param, is_buffer, is_constant)
    )
| |
| |
def is_symint_node(node: torch.fx.Node) -> bool:
    """
    Return True if the node's ``meta["val"]`` is a ``torch.SymInt``.
    Nodes without a ``"val"`` entry yield False.
    """
    # A missing "val" becomes None, which is never a SymInt.
    produced_value = node.meta.get("val")
    return isinstance(produced_value, torch.SymInt)
| |
| |
def is_tensor_node(node: torch.fx.Node) -> bool:
    """
    Decide whether ``node`` produces a tensor, or a list/tuple of tensors.

    A node counts as a tensor node if any of the following holds:
      * it carries a ``"spec"`` entry in its meta;
      * its ``meta["val"]`` is a ``FakeTensor``;
      * its ``meta["val"]`` is a list/tuple whose every element is a ``FakeTensor``.
    Otherwise, including when ``meta`` has no ``"val"``, it returns False.
    """
    meta = node.meta

    # All nodes with tensor values are tagged by the SpecPropPass transform.
    if "spec" in meta:
        return True

    if "val" not in meta:
        return False

    produced = meta["val"]
    if isinstance(produced, (list, tuple)):
        return all(isinstance(item, FakeTensor) for item in produced)
    return isinstance(produced, FakeTensor)
| |
| |
| ## |
| ## Memory Layout, Storage Type Determination |
| ## |
| |
# Three image extents, compared component-wise against device limits.
# required_image_extents returns them as (width, height, depth).
ImageExtents = Tuple[int, int, int]

# Default texture extent limits, one per ImageExtents component.
DEFAULT_TEXTURE_LIMITS: ImageExtents = (16384, 16384, 2048)
# Default buffer limit: 128 * 1024 * 1024 (= 2**27).
# Presumably passed as ``buffer_limit`` to within_buffer_limit, which treats
# it as a maximum element count; verify against callers.
DEFAULT_BUFFER_LIMIT = 128 << 20
| |
| |
class PackedDim(IntEnum):
    """
    Identifies which tensor dimension is packed: width, height, or channels.
    """

    WIDTH = 0  # last dim of the tensor sizes
    HEIGHT = 1  # second-to-last dim
    CHANNELS = 2  # third-to-last dim
| |
| |
# Every PackedDim member (WIDTH, HEIGHT, CHANNELS).
all_packed_dims: Set[PackedDim] = set(PackedDim)

# Storage types known to this module: plain buffers and 3D textures.
all_storage_types: Set[VkStorageType] = {VkStorageType.BUFFER, VkStorageType.TEXTURE_3D}

# Texture memory layouts known to this module; each packs a different dim.
# NOTE(review): these are mutable module-level sets; callers should copy
# before modifying them.
all_memory_layouts: Set[VkMemoryLayout] = {
    VkMemoryLayout.TENSOR_WIDTH_PACKED,
    VkMemoryLayout.TENSOR_HEIGHT_PACKED,
    VkMemoryLayout.TENSOR_CHANNELS_PACKED,
}
| |
| |
def within_buffer_limit(node: torch.fx.Node, buffer_limit: int) -> bool:
    """
    Checks whether the tensors produced by the given node can fit within the device's
    GPU buffer limit, which represents the maximum number of elements that can be stored
    in a GPU buffer.

    Args:
        node: a node for which ``is_tensor_node`` holds; its ``meta["val"]`` must be
            a FakeTensor or a list/tuple of them.
        buffer_limit: the element-count limit to compare against.

    Returns:
        True if every tensor produced by ``node`` has fewer than ``buffer_limit``
        elements, False otherwise.

    Raises:
        RuntimeError: if ``meta["val"]`` is neither a FakeTensor nor a list/tuple.
    """
    assert is_tensor_node(node)

    val = node.meta["val"]
    # NOTE(review): the comparison is strict, so a tensor with exactly
    # ``buffer_limit`` elements is rejected; confirm whether the limit is inclusive.
    if isinstance(val, FakeTensor):
        return val.numel() < buffer_limit
    elif isinstance(val, (list, tuple)):
        return all(x.numel() < buffer_limit for x in val)
    else:
        raise RuntimeError(f"Cannot get numel for val of type {type(val)}")
| |
| |
def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageExtents:
    """
    Compute the (width, height, depth) image extents used to represent a tensor of
    the given ``sizes`` with the given memory ``layout`` in the Vulkan Delegate.

    Sizes are interpreted as follows:
      * width is ``sizes[-1]``, height is ``sizes[-2]`` and channels is ``sizes[-3]``;
      * batch is ``sizes[0]`` when the tensor has at least 4 dims;
      * any of these the tensor does not have counts as 1.
    The dim selected by ``layout`` is divided by 4, rounding up. The depth extent
    is channels multiplied by batch.

    Raises:
        RuntimeError: if ``layout`` is not width-, height- or channels-packed.
    """
    ndim = len(sizes)
    # [width, height, channels] taken from the trailing dims; missing dims are 1.
    whc = [sizes[-(i + 1)] if ndim > i else 1 for i in range(3)]
    batch = sizes[0] if ndim > 3 else 1

    # Select which of width/height/channels the layout packs.
    if layout == VkMemoryLayout.TENSOR_WIDTH_PACKED:
        packed = 0
    elif layout == VkMemoryLayout.TENSOR_HEIGHT_PACKED:
        packed = 1
    elif layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED:
        packed = 2
    else:
        raise RuntimeError(f"Unsupported memory layout {layout}")

    # Ceiling division by 4 along the packed dim.
    whc[packed] = (whc[packed] + 3) // 4

    width, height, channels = whc
    return width, height, channels * batch
| |
| |
def extents_are_valid(extents: ImageExtents, limits: ImageExtents) -> bool:
    """
    Return True if each component of ``extents`` is at most the component of
    ``limits`` at the same position. Stops at the first component that exceeds
    its limit.
    """
    for axis, extent in enumerate(extents):
        if not extent <= limits[axis]:
            return False
    return True
| |
| |
def valid_texture_memory_layouts(
    tensor_sizes: torch.Size, texture_limits: ImageExtents
) -> Set[VkMemoryLayout]:
    """
    Given tensor sizes, determine the set of memory layouts which will produce a texture
    that can fit within the specified device limits.

    Args:
        tensor_sizes: sizes of the tensor to be stored as a texture.
        texture_limits: maximum image extents allowed by the device.

    Returns:
        A new set containing every layout from ``all_memory_layouts`` whose required
        image extents (per ``required_image_extents``) are within ``texture_limits``.
    """
    # Iterate the module-level set directly; it is only read, never mutated,
    # so no defensive copy is needed.
    return {
        layout
        for layout in all_memory_layouts
        if extents_are_valid(
            required_image_extents(tensor_sizes, layout), texture_limits
        )
    }
| |
| |
def possible_node_memory_layouts(
    node: torch.fx.Node, texture_limits: ImageExtents
) -> Set[VkMemoryLayout]:
    """
    Given a node, determine the set of memory layouts which can be used to represent all
    tensors produced by the node within ``texture_limits``.

    Behavior depends on ``meta["val"]``:
      * a single FakeTensor: the valid layouts for that tensor;
      * a list/tuple of tensors: the layouts valid for *every* tensor, i.e. the
        intersection of the per-tensor sets. An empty list/tuple places no
        constraint, so every layout in ``all_memory_layouts`` is returned;
      * any other value: an empty set.

    Args:
        node: a node for which ``is_tensor_node`` holds.
        texture_limits: maximum image extents allowed by the device.

    Returns:
        A new set of memory layouts; the module-level sets are never mutated.
    """
    assert is_tensor_node(node)
    val = node.meta["val"]
    if isinstance(val, FakeTensor):
        return valid_texture_memory_layouts(val.shape, texture_limits)

    if not isinstance(val, (list, tuple)):
        return set()

    # A layout is only usable if it fits *each* output tensor, so the per-tensor
    # sets must be intersected. A union would admit layouts that exceed the
    # texture limits for some of the outputs. Start from a copy of all layouts
    # so the module-level set is left untouched.
    valid_layouts = set(all_memory_layouts)
    for fake_tensor in val:
        valid_layouts &= valid_texture_memory_layouts(
            fake_tensor.shape, texture_limits
        )

    return valid_layouts
| |
| |
| ## |
| ## TensorSpec Utils |
| ## |
| |
| |
def set_node_spec_attr(node: torch.fx.Node, attr: str, value):
    """
    Set ``attr`` to ``value`` on the TensorSpec(s) stored in ``node.meta["spec"]``.

    If the spec is a list/tuple, the attribute is set on every element. Each
    element is asserted to be a TensorSpec.

    Raises:
        RuntimeError: if the spec is neither a TensorSpec nor a list/tuple.
    """
    assert "spec" in node.meta
    spec = node.meta["spec"]

    if isinstance(spec, (list, tuple)):
        for tensor_spec in spec:
            assert isinstance(tensor_spec, TensorSpec)
            setattr(tensor_spec, attr, value)
        return

    if not isinstance(spec, TensorSpec):
        raise RuntimeError(f"Cannot set attr for spec of type {type(spec)}")

    setattr(spec, attr, value)
| |
| |
def get_node_spec_attr(node: torch.fx.Node, attr: str, return_first: bool = True):
    """
    Read ``attr`` from the TensorSpec(s) stored in ``node.meta["spec"]``.

    Args:
        node: node whose ``meta["spec"]`` holds a TensorSpec or a list/tuple of them.
        attr: name of the attribute to read.
        return_first: for a list/tuple spec, controls what is returned. If True
            (the default), only the first spec's value is returned. If False, a
            list with one value per spec is returned.

    Returns:
        The attribute value, or ``None`` where a spec does not have the attribute.
        For a list/tuple spec with ``return_first=False``, a list of such values.
        An empty list/tuple spec with ``return_first=True`` yields ``None``.

    Raises:
        RuntimeError: if ``meta["spec"]`` is neither a TensorSpec nor a list/tuple.
    """
    assert "spec" in node.meta
    spec = node.meta["spec"]
    if isinstance(spec, TensorSpec):
        return getattr(spec, attr, None)
    elif isinstance(spec, (list, tuple)):
        if return_first:
            # Look the attribute up on the first spec itself, not on the
            # container. A list/tuple never has TensorSpec attributes, so
            # checking the container would always yield None.
            return getattr(spec[0], attr, None) if len(spec) > 0 else None
        return [getattr(s, attr, None) for s in spec]
    else:
        raise RuntimeError(f"Cannot get attr for spec of type {type(spec)}")
| |
| |
def get_node_storage_type(node: torch.fx.Node) -> Optional[VkStorageType]:
    """
    Return the ``vk_storage_type`` attribute of the node's spec, as read by
    ``get_node_spec_attr`` with its default arguments. Returns ``None`` when the
    spec does not have the attribute.
    """
    storage_type = get_node_spec_attr(node, attr="vk_storage_type")
    return storage_type
| |
| |
def get_node_memory_layout(node: torch.fx.Node) -> Optional[VkMemoryLayout]:
    """
    Return the ``vk_memory_layout`` attribute of the node's spec, as read by
    ``get_node_spec_attr`` with its default arguments. Returns ``None`` when the
    spec does not have the attribute.
    """
    memory_layout = get_node_spec_attr(node, attr="vk_memory_layout")
    return memory_layout