| # mypy: allow-untyped-defs |
| import functools |
| import math |
| import operator |
| from typing import * # noqa: F403 |
| |
| import torch |
| import torch.nn.functional as F |
| from torch.fx.operator_schemas import normalize_function |
| from torch.nested._internal.sdpa import jagged_scaled_dot_product_attention |
| |
| from .nested_tensor import NestedTensor |
| |
| |
| __all__: List[Any] = [] |
| |
| JAGGED_OPS_TABLE: Dict[Any, Any] = {} |
| |
| |
| # Simplifying assumption: we assume that the batch dim is always the left-most |
| # dim, and the ragged dim is always the second dim. |
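# For illustration: for an NT with outer shape (B, j0, D0, D1), the values tensor
# has shape (sum(j), D0, D1), so:
#   _outer_to_inner_dim(4, 0) == 0  # batch dim collapses into values dim 0
#   _outer_to_inner_dim(4, 1) == 0  # ragged dim also maps to values dim 0
#   _outer_to_inner_dim(4, 2) == 1
#   _outer_to_inner_dim(4, 3) == 2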
| def _outer_to_inner_dim(ndim, dim): |
| assert dim >= 0 and dim < ndim |
| return 0 if dim < 2 else dim - 1 |
| |
| |
| def _wrap_jagged_dim( |
| ndim, dim, op_name, convert_to_inner_dim=True, allow_batch_dim=False |
| ): |
| from torch._prims_common import canonicalize_dims |
| |
| wrapped = canonicalize_dims(ndim, dim) |
| if wrapped == 1: |
| raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=1") |
| elif wrapped == 0 and not allow_batch_dim: |
| raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=0") |
| return _outer_to_inner_dim(ndim, wrapped) if convert_to_inner_dim else wrapped |
| |
| |
| def _wrap_jagged_dims(ndim, dims, op_name, ragged_idx=1): |
| """ |
| For NestedTensor operators, |
| wraps dimensions to non-negative values, |
| and returns metadata related to reduction dimension(s). |
| """ |
| from torch._prims_common import canonicalize_dims |
| |
| assert isinstance( |
| dims, (tuple, list) |
| ), f"_wrap_jagged_dims(): cannot iterate over dimensions of type {type(dims)}" |
| |
| wrapped_dims = [ |
| canonicalize_dims(ndim, d) for d in dims |
| ] # convert all indices to non-negative values |
| |
| operate_on_batch = 0 in wrapped_dims |
| operate_on_ragged = ragged_idx in wrapped_dims |
| operate_on_non_batch = any(d != 0 and d != ragged_idx for d in wrapped_dims) |
| |
| outer_to_inner_dim = tuple( |
| _outer_to_inner_dim(ndim, d) for d in wrapped_dims if d != 0 |
| ) |
| |
| return outer_to_inner_dim, operate_on_batch, operate_on_ragged, operate_on_non_batch |
| |
| |
| def check_schema(schema_str: str, func, *args, **kwargs) -> None: |
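    # check_schema() validates the positional args against a small schema
    # mini-language: each entry is "<name>: <type>" where <type> is one of
    # "t" (dense tensor), "jt" (contiguous jagged NT), "jt_all" (any jagged NT),
    # or "any"; a trailing "?" marks the arg optional, and a final "..." allows
    # any number of extra unchecked args. Illustrative (hypothetical) usage:
    #   check_schema("self: jt_all, dim: any, keepdim: any?", some_func, nt, 0)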
| named_arg_types = schema_str.split(", ") |
| num_optional_args = [x.endswith("?") for x in named_arg_types].count(True) |
| min_args = len(named_arg_types) - num_optional_args |
| |
    # special case: a trailing ellipsis allows any number of unchecked args at the end
| if named_arg_types[-1] == "...": |
| named_arg_types = named_arg_types[:-1] |
| else: |
| if not (len(args) >= min_args and len(args) <= len(named_arg_types)): |
| raise ValueError( |
| f"NestedTensor {func.__name__}({schema_str}): expected at least {min_args} " |
| f"arguments and at most {len(named_arg_types)} arguments, but got: " |
| f"{len(args)} arguments" |
| ) |
| |
| arg_type_check_fns = { |
| "t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor), |
| "jt": lambda x: isinstance(x, NestedTensor) |
| and x._lengths is None |
| and x._ragged_idx == 1, # ops with "jt" require contiguous JT only |
| "jt_all": lambda x: isinstance( |
| x, NestedTensor |
| ), # ops with "jt_all" can accept all kinds of JT |
| "any": lambda x: True, |
| } |
| for i, named_arg_type in enumerate(named_arg_types): |
| name, arg_type = named_arg_type.split(": ") |
| is_optional = arg_type.endswith("?") |
| normalized_arg_type = arg_type[:-1] if is_optional else arg_type |
| if normalized_arg_type not in arg_type_check_fns.keys(): |
| raise AssertionError(f"Unknown arg type: {normalized_arg_type}") |
| |
| if i >= len(args): |
| if not is_optional: |
| raise ValueError( |
| f"NestedTensor {func.__name__}({schema_str}) " |
| f"missing required argument: {name}" |
| ) |
| continue |
| |
| _check_fn = arg_type_check_fns[normalized_arg_type] |
| |
| def check_fn(x, is_optional=is_optional): |
| if is_optional: |
| return x is None or _check_fn(x) |
| else: |
| return _check_fn(x) |
| |
| if not check_fn(args[i]): |
| type_to_desc = { |
| "t": "tensor", |
| "t?": "optional tensor", |
| "jt": "contiguous jagged layout NestedTensor", |
| "jt_all": "jagged layout NestedTensor", |
| "any": "<any type>", |
| } |
| |
| raise ValueError( |
| f"NestedTensor {func.__name__}({schema_str}): expected {name} to be a " |
| f"{type_to_desc[arg_type]}" |
| ) |
| |
| |
| def check_ragged_dim_same( |
| func, a: NestedTensor, a_name: str, b: NestedTensor, b_name: str |
| ) -> None: |
    # Compare the symbolic ragged sizes (nested ints); these only match when the
    # offsets tensors are the same.
| if a._size[a._ragged_idx] != b._size[b._ragged_idx]: |
| raise RuntimeError( |
| f"NestedTensor {func.__name__}: expected {a_name} and {b_name} to have the " |
| "same exact offsets tensor." |
| ) |
| |
| |
| # returns True if the raggedness-relevant portions of the NT shape |
| # match those of the specified size |
| def raggedness_matches(nt, size): |
| end = nt._ragged_idx + 1 |
| nt_ragged = nt._size[:end] |
| size_ragged = size[:end] |
| return len(nt_ragged) == len(size_ragged) and ( |
| all(ns == s or s == -1 for ns, s in zip(nt_ragged, size_ragged)) |
| ) |
| |
| |
| def squeeze_leading_ones(t): |
| # Note: [ Squeezing leading ones ] |
| # |
| # Squeeze leading ones from t. |
| # |
| # We want: |
| # (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?) |
| # (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?) (not yet supported) |
| # |
| # 1) Squeeze extra ones and grab values from NT |
| # (1, 1, ?, ?) -> (?, ?) and (sum(*), ?, ?) -> (B, j0, ?, ?) |
| # 2) Do dense broadcasting: |
| # (sum(*), ?, ?) + (?, ?) -> (sum(*), ?, ?) |
| # 3) Construct nested tensor |
| # (sum(*), ?, ?) -> (B, j0, ?, ?) |
| # |
    # If unsqueezing on the 0th dim becomes supported, we would add a step (4)
    # to unsqueeze there, and we would need to update this function to record
    # how many leading ones were squeezed.
| while t.dim() > 0 and t.shape[0] == 1: |
| t = t.squeeze(0) |
| return t |
| |
| |
| def register_func(tables, aten_ops, schema_str): |
| if not isinstance(aten_ops, list): |
| aten_ops = [aten_ops] |
| if not isinstance(tables, list): |
| tables = [tables] |
| |
| def wrapper(func): |
| for aten_op in aten_ops: |
| |
| def get_inner(aten_op): |
| def inner(*args, **kwargs): |
| check_schema(schema_str, func, *args, **kwargs) |
| return func(aten_op, *args, **kwargs) |
| |
| return inner |
| |
| for table in tables: |
| table[aten_op] = get_inner(aten_op) |
| return func |
| |
| return wrapper |
| |
| |
| register_jagged_func = functools.partial(register_func, JAGGED_OPS_TABLE) |
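# Illustrative registration sketch (hypothetical impl, mirroring the real
# registrations below): the decorated function receives the aten op as its first
# argument and is responsible for unwrapping / rewrapping NestedTensor values.
#
#   @register_jagged_func(torch.ops.aten.cos.default, "self: jt_all")
#   def cos_default(func, *args, **kwargs):
#       inp = args[0]
#       return NestedTensor(func(inp._values), **extract_kwargs(inp))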
| |
| |
| def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]: |
| dispatch_func = JAGGED_OPS_TABLE.get(func, None) |
| if dispatch_func is not None: |
| return dispatch_func |
| |
| # Handle pointwise fallbacks |
| if torch.Tag.pointwise in func.tags: |
        # Assume the only tensor args are the unary/binary operands themselves
| num_tensor_args = sum(isinstance(x, torch.Tensor) for x in args) |
| if num_tensor_args == 1: |
| # Build up the check schema string. The first tensor arg is assumed to be |
| # an NJT and other args are sent through as-is. |
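            # e.g. for aten.sin.default ("sin(Tensor self) -> Tensor"), the
            # resulting check schema string is "self: jt_all, ...".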
| schema_parts = [] |
| for arg in func._schema.arguments: |
| if isinstance(arg.type, torch.TensorType): |
| schema_parts.append(f"{arg.name}: jt_all") |
| break |
| else: |
| schema_parts.append(f"{arg.name}: any") |
| schema_parts.append("...") |
| check_schema_str = ", ".join(schema_parts) |
| check_schema(check_schema_str, func, *args, **kwargs) |
| return functools.partial(jagged_unary_pointwise, func) |
| elif num_tensor_args == 2: |
| check_schema("lhs: any, rhs: any, ...", func, *args, **kwargs) |
| return functools.partial(jagged_binary_pointwise, func) |
| |
| return None |
| |
| |
| def extract_kwargs(arg): |
| kwargs = { |
| "offsets": arg.offsets(), |
| "_metadata_cache": arg._metadata_cache, |
| "_ragged_idx": arg._ragged_idx, |
| } |
| return kwargs |
| |
| |
| def jagged_unary_pointwise(func, *args, **kwargs): |
| # assume if we get here that there is a single NJT input in the args |
| njt = next(arg for arg in args if isinstance(arg, NestedTensor)) |
| return NestedTensor( |
| func(*(arg._values if arg is njt else arg for arg in args), **kwargs), |
| **extract_kwargs(njt), |
| ) |
| |
| |
| def jagged_binary_pointwise(func, *args, **kwargs): |
| a, b = args[0], args[1] |
| assert isinstance(a, NestedTensor) or isinstance(b, NestedTensor) |
| |
| mismatch_error_msg = ( |
| "cannot call binary pointwise function {} with inputs of shapes {} and {}" |
| ) |
| # a is NT, b is NT |
| if isinstance(a, NestedTensor) and isinstance(b, NestedTensor): |
| # ex: (B, j0, D) + (B, j0, D) |
| # ex: (B, j0, D) + (B, j0, 1) |
| if raggedness_matches(a, b._size): |
| return NestedTensor( |
| func(a._values, b._values, *args[2:], **kwargs), **extract_kwargs(a) |
| ) |
| raise RuntimeError(mismatch_error_msg.format(func.__name__, a._size, b._size)) |
| # either a is NT or b is NT at this point |
| a_is_nt = isinstance(a, NestedTensor) |
| extracted_kwargs = extract_kwargs(a) if a_is_nt else extract_kwargs(b) |
| |
| # === Handle broadcasting across the batch / ragged dims === |
| |
| # Easy case: take advantage of pre-existing broadcasting logic |
| # ex: (B, j0, ?, ?) + (?) -> (B, j0, ?, ?) |
| # ex: (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?) |
| # ex: (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?) |
| nt, t = (a, b) if a_is_nt else (b, a) |
| # See Note: [ Squeezing leading ones ] |
| if t.dim() > nt.dim(): |
| raise NotImplementedError("NYI: broadcasting NT with T with larger dim") |
| t_squeezed = squeeze_leading_ones(t) |
| if nt.dim() >= t_squeezed.dim() + 2: |
| lhs, rhs = (nt._values, t_squeezed) if a_is_nt else (t_squeezed, nt._values) |
| return NestedTensor(func(lhs, rhs, *args[2:], **kwargs), **extracted_kwargs) |
| |
| # Harder case: do manual broadcasting over unbound components |
| # when NT dim == non-NT dim |
| # ex: (B, j0, D_0, D_1) + (B, 1, D_0, D_1) -> (B, j0, D_0, D_1) |
| if a.dim() == b.dim(): |
| # ex: (B, j0, D_0, D_1) + (1, 1, D_0, D_1) -> should |
| # be (B, j0, D_0, D_1) but not yet supported |
| if a.shape[0] != b.shape[0]: |
| raise RuntimeError( |
| mismatch_error_msg.format(func.__name__, a.shape, b.shape) |
| ) |
| |
| # need to use offsets to broadcast across ragged dim properly |
| # NB: inefficient fallback here; Triton codegen can help this |
| # TODO: Make this work with autograd |
| outputs = [] |
| for a_comp, b_comp in zip(a.unbind(), b.unbind()): |
| outputs.append(func(a_comp, b_comp, *args[2:], **kwargs)) |
| new_values = torch.cat(outputs, dim=0) |
| return NestedTensor(new_values, **extracted_kwargs) |
| |
    # ex: (B, j0, D_0, D_1) + (A, B, 1, D_0, D_1) -> error because this breaks the
    # invariant that the ragged dim immediately follows the left-most batch dim
| raise RuntimeError(mismatch_error_msg.format(func.__name__, a.shape, b.shape)) |
| |
| |
| def jagged_torch_function(func, *args, **kwargs): |
| # SDPA has special kernels that handle nested tensors. |
| # Dispatch to the correct implementation here |
| if func is torch._C._nn.scaled_dot_product_attention: |
| return jagged_scaled_dot_product_attention(*args, **kwargs) |
| |
| if func.__name__ == "apply_": |
| func(args[0]._values, *args[1:], **kwargs) |
| return args[0] |
| |
| # Handle flatten() here because it's CompositeImplicit. |
| if func.__name__ == "flatten": |
| |
| def _flatten_sig(input, start_dim=0, end_dim=-1): |
| pass |
| |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| _flatten_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| |
        # NB: stay in outer dim space because we're going to redispatch on an NT input
| start_dim = _wrap_jagged_dim( |
| inp.dim(), new_kwargs["start_dim"], "flatten", convert_to_inner_dim=False |
| ) |
| end_dim = _wrap_jagged_dim( |
| inp.dim(), new_kwargs["end_dim"], "flatten", convert_to_inner_dim=False |
| ) |
| |
| if start_dim == end_dim: |
| return inp |
| |
| product = functools.reduce(operator.mul, inp.shape[start_dim : end_dim + 1]) |
| new_shape = (*inp.shape[:start_dim], product, *inp.shape[end_dim + 1 :]) |
| |
| return inp.reshape(*new_shape) |
| |
| raise NotImplementedError(func) |
| |
| |
| @register_jagged_func( |
| [ |
| torch.ops.aten.is_non_overlapping_and_dense.default, |
| torch.ops.aten.sym_size.default, |
| torch.ops.aten.dim.default, |
| torch.ops.aten.numel.default, |
| torch.ops.aten.sym_numel.default, |
| torch.ops.aten.sym_stride.default, |
| torch.ops.aten.sym_storage_offset.default, |
| ], |
| "self: jt_all", |
| ) |
| def tensor_attr_supported_getter(func, *args, **kwargs): |
| if func == torch.ops.aten.is_non_overlapping_and_dense.default: |
| return False |
| |
| if func == torch.ops.aten.sym_size.default: |
| return args[0]._size |
| |
| if func == torch.ops.aten.dim.default: |
| return len(args[0]._size) |
| |
| if func in (torch.ops.aten.sym_numel.default, torch.ops.aten.numel.default): |
| if args[0]._lengths is not None: |
| return int(sum(args[0]._lengths) * math.prod(args[0]._size[2:])) |
| return args[0]._values.numel() |
| |
| if func == torch.ops.aten.sym_stride.default: |
| return args[0]._strides |
| |
| if func == torch.ops.aten.sym_storage_offset.default: |
| return args[0]._values.storage_offset() |
| |
| |
| @register_jagged_func(torch.ops.prim.layout.default, "self: jt_all") |
| def prim_layout_default(func, *args, **kwargs): |
| return torch.jagged |
| |
| |
| @register_jagged_func( |
| [torch.ops.aten.size.default], |
| "self: jt_all", |
| ) |
| def tensor_attr_unsupported_getter(func, *args, **kwargs): |
| if func == torch.ops.aten.size.default: |
| raise RuntimeError( |
| "NestedTensors does not support directly calling torch.ops.aten.size " |
| "please use `nested_tensor.size()` instead." |
| ) |
| |
| |
| @register_jagged_func(torch.ops.aten.is_contiguous.default, "self: jt_all") |
| def is_contiguous_general(func, *args, **kwargs): |
| from torch._prims_common import is_contiguous_for_memory_format |
| |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| inp = new_kwargs.pop("input") |
| |
| # If created from narrow() check for lengths |
| if inp.lengths() is not None: |
| return False |
| |
| new_kwargs["memory_format"] = new_kwargs.get( |
| "memory_format", torch.contiguous_format |
| ) |
| if new_kwargs["memory_format"] == torch.preserve_format: |
| return True |
| return is_contiguous_for_memory_format(inp._values, **new_kwargs) |
| |
| |
| register_jagged_func( |
| torch.ops.aten.is_contiguous.memory_format, "self: jt_all, memory_format: any?" |
| )(is_contiguous_general) |
| |
| |
| @register_jagged_func( |
| torch.ops.aten.clone.default, "input: jt_all, memory_format: any?" |
| ) |
| def clone_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| |
| new_meta = extract_kwargs(inp) |
| |
| if inp._lengths is not None: |
| if new_kwargs["memory_format"] == torch.contiguous_format: |
| # need to copy to remove "holes" non-contiguity / lengths metadata |
| # TODO: write a kernel for this |
| from .nested_tensor import jagged_from_list |
| |
| # TODO: We probably want the output to have the same ragged structure / nested int. |
| assert ( |
| inp._ragged_idx == 1 |
| ), "NJT with ragged_idx != 1 not supported for contiguous clone" |
| contig, _ = jagged_from_list(inp.unbind(), offsets=None) |
| return contig |
| else: |
| # need to preserve any lengths metadata present |
| new_meta["lengths"] = inp._lengths |
| |
| return NestedTensor(func(inp._values, **new_kwargs), **new_meta) |
| |
| |
| @register_jagged_func(torch.ops.aten.linear.default, "input: jt, weight: t, bias: t?") |
| def linear_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| |
| return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) |
| |
| |
| @register_jagged_func( |
| torch.ops.aten.linear_backward.default, |
| "self: jt, grad_output: jt, weight: t, output_mask: any", |
| ) |
| def linear_backward_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| grad_output = new_kwargs.pop("grad_output") |
| weight = new_kwargs.pop("weight") |
| |
| check_ragged_dim_same(func, inp, "self", grad_output, "grad_output") |
| ds = NestedTensor( |
| torch.matmul(grad_output._values, weight), **extract_kwargs(grad_output) |
| ) |
| dw = torch.matmul(grad_output._values.transpose(-2, -1), inp._values) |
| db = None # NYI: gradient for bias, need to reduce over ragged dim |
| return (ds, dw, db) |
| |
| |
| @register_jagged_func(torch.ops.aten.to.dtype, "input: jt_all, dtype: any") |
| def to_dtype(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| |
| return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) |
| |
| |
| @register_jagged_func(torch.ops.aten._to_copy.default, "self: jt_all") |
| def to_copy_default(func, *args, **kwargs): |
| from .nested_tensor import _tensor_symint_registry |
| |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| # don't change layout |
| new_kwargs.pop("layout") |
| |
| new_values = func(inp._values, **new_kwargs) |
| new_offsets = inp._offsets.to(device=new_values.device) |
| |
| from torch._subclasses.fake_tensor import FakeTensor |
| from torch._subclasses.functional_tensor import ( |
| FunctionalTensor, |
| mb_unwrap_functional_tensor, |
| ) |
| |
| if isinstance(new_offsets, (FakeTensor, FunctionalTensor)): |
| # Temporary hack until we have the union find |
| tgt = mb_unwrap_functional_tensor(new_offsets) |
| src = mb_unwrap_functional_tensor(inp._offsets) |
| tgt.nested_int_memo = src.nested_int_memo |
| else: |
| _tensor_symint_registry[new_offsets] = _tensor_symint_registry[inp._offsets] |
| inp_kwargs = extract_kwargs(inp) |
| inp_kwargs["offsets"] = new_offsets |
| |
| return NestedTensor(new_values, **inp_kwargs) |
| |
| |
| @register_jagged_func( |
| torch.ops.aten.copy_.default, "self: jt_all, src: jt_all, non_blocking: any?" |
| ) |
| def copy_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| inp = new_kwargs.pop("input") |
| src = new_kwargs.pop("src") |
| if inp._size != src._size: |
| raise RuntimeError( |
| "copy_ only supports Nested Tensors that have same size and the exact same offset tensor." |
| ) |
| inp.values().copy_(src.values()) |
| return inp |
| |
| |
| register_jagged_func(torch.ops.aten.detach.default, "self: jt_all")( |
| jagged_unary_pointwise |
| ) |
| |
| |
| @register_jagged_func( |
| [ |
| torch.ops.aten.empty_like.default, |
| torch.ops.aten.ones_like.default, |
| torch.ops.aten.zeros_like.default, |
| torch.ops.aten.randn_like.default, |
| ], |
| "self: jt_all", |
| ) |
| def like_factory_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| |
    # The default layout is technically torch.strided, but only the jagged layout is
    # supported here. Rather than force users to specify the layout, assume jagged.
    # The layout kwarg must be set to strided before redispatching on the dense values.
| new_kwargs["layout"] = torch.strided |
| |
| return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) |
| |
| |
| @register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all") |
| def zero__default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| func(inp._values) |
| return inp |
| |
| |
| @register_jagged_func( |
| torch.ops.aten._softmax.default, "self: jt_all, dim: any, half_to_float: any" |
| ) |
| def _softmax_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| if isinstance(new_kwargs["dim"], tuple): |
| raise RuntimeError( |
| "softmax(): not supported for dimensions of type 'tuple' for NestedTensor" |
| ) |
| |
| inp = new_kwargs.pop("input") |
| |
| ( |
| new_kwargs["dim"], |
| reduce_on_batch, |
| reduce_on_ragged, |
| reduce_on_non_batch, |
| ) = _wrap_jagged_dims( |
| inp.dim(), |
| (new_kwargs["dim"],), |
| "softmax", |
| inp._ragged_idx, |
| ) |
| |
| if reduce_on_batch: |
| raise RuntimeError( |
| "softmax(): not supported when reducing across the batch dimension for NestedTensor" |
| ) |
| |
| if reduce_on_ragged and inp._ragged_idx > 1: |
| raise RuntimeError( |
| "softmax(): not supported when reducing along the ragged dimension for ragged_idx > 1 for NestedTensor" |
| ) |
| |
| if reduce_on_ragged and inp._lengths is not None: |
| raise RuntimeError( |
| "softmax(): not supported where lengths is not None " |
| + "if reducing across the ragged dimension for NestedTensor" |
| ) |
| |
| new_kwargs["dim"] = new_kwargs["dim"][ |
| 0 |
| ] # torch.softmax takes in the reduction dimension as an integer |
| |
| if reduce_on_ragged: |
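        # Illustrative shape flow (ragged_idx == 1 is guaranteed by the checks
        # above): values reshaped to 2D (sum(j), prod(trailing dims)) -> padded
        # dense (B, max_seqlen, prod(trailing dims)) with -inf padding so padded
        # entries softmax to 0 -> softmax over the padded ragged dim -> back to
        # jagged values of the original shape.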
| padded_softmax_values = torch.nn.functional.softmax( |
| torch.ops.aten._jagged_to_padded_dense_forward( |
| inp._values.reshape( |
| inp._values.shape[0], -1 |
| ), # values are required to be 2D tensors for j2pd |
| [inp._offsets], |
| max_lengths=[inp._max_seqlen], # max length of ragged dimension |
| padding_value=float("-inf"), # e^-inf = 0 |
| ), |
| dim=inp._ragged_idx, |
| ) |
| |
| softmax_values = torch.ops.aten._padded_dense_to_jagged_forward( |
| padded_softmax_values, |
| [inp._offsets], |
| total_L=inp._values.shape[ |
| 0 |
| ], # providing this parameter helps avoid a GPU/CPU sync |
| ).reshape( |
| -1, *inp._values.shape[1:] |
| ) # expand softmax_values back to original shape (inp._values.shape) |
| |
| return NestedTensor(softmax_values, **extract_kwargs(inp)) |
| |
| return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) |
| |
| |
| @register_jagged_func( |
| torch.ops.aten._softmax_backward_data.default, |
| "grad_output: jt, output: jt, dim: any, input_dtype: any", |
| ) |
| def _softmax_backward(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| grad_out = new_kwargs.pop("grad_output") |
| output = new_kwargs.pop("output") |
| return NestedTensor( |
| func(grad_out._values, output._values, **new_kwargs), **extract_kwargs(grad_out) |
| ) |
| |
| |
| @register_jagged_func( |
| torch.ops.aten.native_dropout.default, "self: jt, float: any, train: any?" |
| ) |
| def native_dropout_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| out1, out2 = func(inp._values, **new_kwargs) |
| return ( |
| NestedTensor(out1, **extract_kwargs(inp)), |
| NestedTensor(out2, **extract_kwargs(inp)), |
| ) |
| |
| |
| @register_jagged_func( |
| torch.ops.aten.native_dropout_backward.default, |
| "grad_output: jt, mask: jt, scale: any", |
| ) |
| def native_dropout_backward_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| grad_output = new_kwargs.pop("grad_output") |
| mask = new_kwargs.pop("mask") |
| return NestedTensor( |
| func(grad_output._values, mask._values, **new_kwargs), |
| **extract_kwargs(grad_output), |
| ) |
| |
| |
| @register_jagged_func(torch.ops.aten.prod.dim_int, "self: jt, dim: any, keepdim: any?") |
| def prod_dim_int(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
    # TODO: Figure out how to handle this better
    # keepdim=True is required to keep the result in jagged format
| if not new_kwargs["keepdim"]: |
| raise RuntimeError("prod(): keepdim=True must be set for NestedTensor") |
| dim = new_kwargs["dim"] |
| new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size), dim, "prod") |
| |
    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
| |
| |
| @register_jagged_func( |
| torch.ops.aten.split.Tensor, "self: jt, split_size: any, dim: any" |
| ) |
| def split_tensor(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| |
| new_kwargs["dim"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], "split") |
| |
| return tuple( |
| NestedTensor(values=x, **extract_kwargs(inp)) |
| for x in func(inp._values, **new_kwargs) |
| ) |
| |
| |
| @register_jagged_func( |
| torch.ops.aten.split_with_sizes.default, "self: jt, split_sizes: any, dim: any" |
| ) |
| def split_with_sizes_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| |
| new_kwargs["dim"] = _wrap_jagged_dim( |
| inp.dim(), new_kwargs["dim"], "split_with_sizes" |
| ) |
| |
| return [ |
| NestedTensor(values=x, **extract_kwargs(inp)) |
| for x in func(inp._values, **new_kwargs) |
| ] |
| |
| |
| @register_jagged_func( |
| torch.ops.aten.narrow.default, "self: jt, dim: any, start: any, length: any" |
| ) |
| def narrow(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| inp = new_kwargs.pop("input") |
| |
| dim = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], "narrow") |
| values = func( |
| inp._values, |
| dim=dim, |
| start=new_kwargs["start"], |
| length=new_kwargs["length"], |
| ) |
| return NestedTensor(values, **extract_kwargs(inp)) |
| |
| |
| @register_jagged_func(torch.ops.aten.chunk.default, "self: jt, chunks: any, dim: any?") |
| def chunk_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| |
| new_kwargs["dim"] = _wrap_jagged_dim( |
| inp.dim(), new_kwargs["dim"], "chunk", allow_batch_dim=True |
| ) |
| |
| if new_kwargs["dim"] == 0: |
| chunks = new_kwargs["chunks"] |
| |
| # get _offsets of the chunks |
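        # Illustrative example: with offsets = [0, 2, 5, 9, 14] (batch size 4)
        # and chunks = 2, lengths = [2, 3, 4, 5], chunked_lengths splits into
        # ([2, 3], [4, 5]), and the per-chunk offsets become [0, 2, 5] and
        # [0, 4, 9].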
| lengths = inp._offsets.diff() |
| chunked_lengths = lengths.chunk(chunks) |
| chunked_offsets = [torch.cumsum(x, dim=0) for x in chunked_lengths] |
| chunked_offsets = [F.pad(x, (1, 0), value=0) for x in chunked_offsets] # type: ignore[arg-type] |
| nested_kwargs = [ |
| {"offsets": per_offsets, "_ragged_idx": inp._ragged_idx} |
| for per_offsets in chunked_offsets |
| ] |
| |
| # get _values of the chunks |
| split_sizes = [x.sum().item() for x in chunked_lengths] |
| chunk_values = inp._values.split(split_sizes) |
| |
        return [
            NestedTensor(values=values, **nt_kwargs)
            for values, nt_kwargs in zip(chunk_values, nested_kwargs)
        ]
| else: |
| return [ |
| NestedTensor(values=x, **extract_kwargs(inp)) |
| for x in func(inp._values, **new_kwargs) |
| ] |
| |
| |
| @register_jagged_func(torch.ops.aten.unbind.int, "self: jt_all, dim: any?") |
| def unbind_int(func, *args, **kwargs): |
| # Note that this specializes on the length of the offsets |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| dim = new_kwargs["dim"] |
| if dim != 0: |
| raise RuntimeError("unbind(): only supported for NestedTensor on dim=0") |
| |
| inp = new_kwargs.pop("input") |
| values = inp.values() |
| offsets = inp.offsets() |
| lengths = inp.lengths() |
| ragged_idx = inp._ragged_idx |
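    # Illustrative example (for ragged_idx == 1): for values of shape (7, D) with
    # offsets = [0, 2, 5, 7] and lengths = None, this returns views of sizes
    # 2, 3, and 2 along dim 0 of values; with lengths = [1, 2, 1] (i.e. holes),
    # component i is instead values.narrow(0, offsets[i], lengths[i]).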
| |
| if lengths is None: |
| return torch.split(values, offsets.diff().tolist(), dim=(ragged_idx - 1)) |
| |
| if ragged_idx <= 0: |
| raise RuntimeError( |
| "unbind(): nested tensor ragged_idx out of bounds (should be >= 1)" |
| ) |
| for i in range(lengths.shape[0]): |
| if offsets[i] + lengths[i] > values.shape[ragged_idx - 1]: |
| raise RuntimeError( |
| "unbind(): nested tensor offsets and lengths do not match ragged_idx dimension" |
| ) |
| return [ |
| torch.narrow(values, dim=(ragged_idx - 1), start=offsets[i], length=lengths[i]) |
| for i in range(lengths.shape[0]) |
| ] |
| |
| |
| @register_jagged_func(torch.ops.aten.squeeze.dim, "self: jt, dim: any") |
| def squeeze_dim(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| values = inp._values |
| |
| new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size), new_kwargs["dim"], "squeeze") |
| return NestedTensor(func(values, **new_kwargs), **extract_kwargs(inp)) |
| |
| |
| @register_jagged_func(torch.ops.aten.unsqueeze.default, "self: jt, dim: any") |
| def unsqueeze_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| values = inp._values |
| |
| # Account for collapsed jagged dim |
| dim = new_kwargs["dim"] |
| new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size) + 1, dim, "unsqueeze") |
| return NestedTensor(func(values, **new_kwargs), **extract_kwargs(inp)) |
| |
| |
| @register_jagged_func(torch.ops.aten.cat.default, "tensors: any, dim: any") |
| def cat_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| tensors = new_kwargs.pop("tensors") |
| |
| # Convert any non-nested to nested |
| nested = [t for t in tensors if t.is_nested] |
| assert len(nested) > 0 |
| first = nested[0] |
| tensors = [t if t.is_nested else t.expand_as(first) for t in tensors] |
| |
| # Account for collapsed jagged dim |
| dim = new_kwargs["dim"] |
| new_kwargs["dim"] = _wrap_jagged_dim(len(first.shape), dim, "cat") |
| |
| return NestedTensor( |
| func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0]) |
| ) |
| |
| |
| @register_jagged_func(torch.ops.aten.matmul.default, "self: jt, other: any") |
| def matmul_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| other = new_kwargs.pop("other") |
| |
| if inp.is_nested and not other.is_nested: |
| return NestedTensor( |
| func(inp._values, other, **new_kwargs), **extract_kwargs(inp) |
| ) |
| elif inp.is_nested and other.is_nested: |
| # BMM with equivalent ragged dims between the two inputs |
| if inp.dim() > 3 and other.dim() > 3 and raggedness_matches(inp, other._size): |
| return NestedTensor(func(inp._values, other._values), **extract_kwargs(inp)) |
| |
| raise RuntimeError( |
| f"matmul(): not supported between inputs of shapes {inp._size} and {other.shape}" |
| ) |
| |
| |
| @register_jagged_func( |
| torch.ops.aten.expand.default, "self: jt, size: any, implicit: any?" |
| ) |
| def expand_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| size = new_kwargs["size"] |
| |
| assert ("implicit" not in new_kwargs) or (not new_kwargs.pop("implicit")) |
| if not raggedness_matches(inp, size): |
| raise RuntimeError(f"expand(): cannot expand shape {inp._size} -> {size}") |
| |
| expand_arg = [-1, *size[2:]] |
| return NestedTensor(func(inp._values, expand_arg), **extract_kwargs(inp)) |
| |
| |
| @register_jagged_func(torch.ops.aten.expand_as.default, "self: t, other: jt") |
| def expand_as_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| other = new_kwargs.pop("other") |
| |
| return NestedTensor(func(inp, other._values), **extract_kwargs(other)) |
| |
| |
| @register_jagged_func(torch.ops.aten.where.self, "condition: jt, self: jt, other: jt") |
| def where_self(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| condition = new_kwargs.pop("condition") |
| inp = new_kwargs.pop("input") |
| other = new_kwargs.pop("other") |
| |
| assert condition._size == other._size == inp._size |
| |
| return NestedTensor( |
| func(condition._values, inp._values, other._values, **new_kwargs), |
| **extract_kwargs(condition), |
| ) |
| |
| |
| @register_jagged_func(torch.ops.aten._pin_memory.default, "self: jt, device: any?") |
| def _pin_memory_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| |
| return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) |
| |
| |
| @register_jagged_func(torch.ops.aten.is_pinned.default, "self: jt, device: any?") |
| def is_pinned_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| |
| return func(inp._values, **new_kwargs) |
| |
| |
| @register_jagged_func( |
| torch.ops.aten.is_same_size.default, "self: jt_all, other: jt_all" |
| ) |
| def is_same_size_default(func, *args, **kwargs): |
| return args[0]._size == args[1]._size |
| |
| |
| @register_jagged_func( |
| torch.ops.aten.sum.dim_IntList, |
| "self: jt_all, dim: any?, keepdim: any?, dtype: any?", |
| ) |
| def sum_dim_IntList(func, *args, **kwargs): |
| """ |
| Performs a sum along the provided tensor dimension. |
| Returns a dense tensor if the ragged dimension is reduced away, else returns a nested tensor. |
| """ |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| inp = new_kwargs.pop("input") |
| |
| ( |
| new_kwargs["dim"], |
| reduce_on_batch, |
| reduce_on_ragged, |
| reduce_on_non_batch, |
| ) = _wrap_jagged_dims( |
| inp.dim(), |
| new_kwargs["dim"], |
| "sum", |
| inp._ragged_idx, |
| ) |
| |
| if reduce_on_ragged and inp._lengths is not None: |
| raise RuntimeError( |
| "sum(): not supported where lengths is not None " |
| + "if reducing across the ragged dimension for NestedTensor" |
| ) |
| |
| if reduce_on_ragged: # raggedness reduced away --> return dense tensor |
| if ( |
| reduce_on_batch |
| ): # reduction cases: (batch, ragged), (batch, ragged, non-batch), etc. |
| out = func( |
| inp._values, **new_kwargs |
| ) # no need to read offsets --> apply sum directly on values |
| else: |
| if ( |
| reduce_on_non_batch |
| ): # invalid reduction cases: (ragged, non-batch), etc. |
| raise RuntimeError( |
| "sum(): not supported along a ragged and non-batch dimension for NestedTensor" |
| ) |
| # reduction cases: (ragged) |
| values_ragged_dim_outer = inp._values.permute( |
| inp._ragged_idx - 1, # outer dimension |
| *range(0, inp._ragged_idx - 1), |
| *range(inp._ragged_idx, inp.dim() - 1), |
| ) # shift reduction dimension of values backward to outer dimension |
| |
| # _jagged_to_padded_dense_forward requires values to be a 2D tensor |
| # with the ragged dimension as the 0th dimension |
| padded = torch.ops.aten._jagged_to_padded_dense_forward( |
| values_ragged_dim_outer.reshape(values_ragged_dim_outer.shape[0], -1), |
| [inp._offsets], |
| max_lengths=[inp._max_seqlen], |
| ) |
| |
| padded_ragged_dim_original = padded.view( |
| padded.shape[0], |
| inp._max_seqlen, |
| *values_ragged_dim_outer.shape[ |
| 1: |
| ], # expand non-batch dimensions of padded tensor |
| ).permute( |
| 0, |
| *range(2, inp._ragged_idx + 1), |
| 1, |
| *range(inp._ragged_idx + 1, inp.dim()), |
| ) # shift reduction dimension of padded tensor forward to original ragged dimension |
| |
| out = torch.sum( |
| padded_ragged_dim_original, |
| dim=inp._ragged_idx, |
| ) # need to read offsets --> pad jagged dimension and apply sum |
| |
| if new_kwargs["keepdim"]: |
| # TODO: Fix this; it's a bug. should be unsqueezing on ragged_idx |
| out = out.unsqueeze(0) |
| return out |
| else: # raggedness preserved --> return nested tensor |
| if ( |
| reduce_on_batch |
| ): # invalid reduction cases: (batch), (batch, non-batch), etc. |
| raise RuntimeError( |
| "sum(): not supported along the batch dimension but not the ragged dimension for NestedTensor" |
| ) |
| # reduction cases: (non-batch), (non-batch, non-batch), etc. |
| return NestedTensor( |
| func(inp._values, **new_kwargs), **extract_kwargs(inp) |
| ) # apply sum directly on values |
| |
| |
| @register_jagged_func( |
| torch.ops.aten.transpose.int, "self: jt_all, dim0: any, dim1: any" |
| ) |
| def transpose_int(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| from torch._prims_common import canonicalize_dims |
| |
| inp = new_kwargs.pop("input") |
| dim0, dim1 = canonicalize_dims(inp.dim(), (new_kwargs["dim0"], new_kwargs["dim1"])) |
| |
| if inp._lengths is not None: |
| raise ValueError( |
| "transpose(): not supported on jagged layout nested tensor with holes" |
| ) |
| |
| # To support the SDPA API, inputs need to have the ragged idx transposed to dim 2 |
    # instead of 1, although the internal Flash and memory-efficient implementations will
| # use the inputs with raggedness in dim 1. |
| if dim0 == inp._ragged_idx or dim1 == inp._ragged_idx: |
| if dim0 == 0 or dim1 == 0: |
| raise ValueError( |
| "Transpose is not supported on the batch dimension for jagged NT" |
| ) |
| if dim0 == inp._ragged_idx: |
| to_dim = dim1 |
| else: |
| to_dim = dim0 |
| inp_kwargs = extract_kwargs(inp) |
| inp_kwargs["_ragged_idx"] = to_dim |
| return NestedTensor( |
| inp.values().transpose( |
| _outer_to_inner_dim(len(inp._size), dim0), |
| _outer_to_inner_dim(len(inp._size), dim1), |
| ), |
| **inp_kwargs, |
| ) |
| |
| new_kwargs["dim0"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim0"], "transpose") |
| new_kwargs["dim1"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim1"], "transpose") |
| |
| return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) |
| |
| |
| @register_jagged_func(torch.ops.aten.permute.default, "self: jt_all, dims: any") |
| def permute_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| inp = new_kwargs.pop("input") |
| dims = new_kwargs.pop("dims") |
| inp_kwargs = extract_kwargs(inp) |
| inp_dim = len(inp._size) |
| |
| # The first two checks are the same as the checks in the normal permute implementation |
| if inp_dim != len(dims): |
| raise ValueError( |
| f"permute(): number of dimensions in the tensor input ({inp_dim}) " |
| + f"does not match the length of the desired ordering of dimensions ({len(dims)}).", |
| ) |
| |
| from torch._prims_common import canonicalize_dims |
| |
| canonicalized_dims = canonicalize_dims(inp_dim, dims) |
| |
| if len(canonicalized_dims) != len(set(canonicalized_dims)): |
| raise ValueError("permute(): duplicate dims are not allowed.") |
| |
| if inp._lengths is not None: |
| raise ValueError( |
| "permute(): not supported on jagged layout nested tensor with holes" |
| ) |
| if canonicalized_dims[0] != 0: |
| raise ValueError( |
| "Permute is not supported on the batch dimension for jagged NT" |
| ) |
| inp_kwargs["_ragged_idx"] = canonicalized_dims.index(inp._ragged_idx) |
| inner_dims = [_outer_to_inner_dim(inp_dim, dim) for dim in canonicalized_dims[1:]] |
| new_kwargs["dims"] = inner_dims |
| return NestedTensor(func(inp._values, **new_kwargs), **inp_kwargs) |
| |
| |
| @register_jagged_func( |
| [torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default], |
| "self: jt_all, size: any", |
| ) |
| def view_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| size = new_kwargs.pop("size") |
| |
| if inp._ragged_idx != 1 and tuple(inp._size) != tuple(size): |
| raise RuntimeError( |
| f"view(): does not support ragged_idx != 1 except when inp._size == size. " |
| f"inp._size is ({inp._size}) and size is ({size})." |
| ) |
| |
| # Ensure specified size still includes batch and ragged dims |
| if len(size) < 3 or not raggedness_matches(inp, size): |
| raise RuntimeError(f"view(): cannot view shape {inp._size} as {size}") |
| |
| # outer size: the size of the NT, e.g. [3, j0, 10] |
| # inner size: the size of the values, e.g. [8, 10] (e.g. for offsets = [0, 3, 5, 8]) |
| # this function gets inner_size[inner_idx] for a given inner_idx. |
| # |
| # example: for outer size [a, b, c, j0, d, e, f] |
    # assume that j0 is ragged and the others are concrete integers,
    # and ragged_idx=3
    # inner size will be [b, c, inp._values.size(ragged_idx - 1), d, e, f]
    # therefore:
    # inner_size[0] = outer_size[1]
    # inner_size[1] = outer_size[2]
    # inner_size[2] = inp._values.size(ragged_idx - 1)
| # inner_size[3] = outer_size[4] |
| # inner_size[4] = outer_size[5] |
| def get_inner_size(inner_idx): |
| nonlocal inp, size |
| if inner_idx == inp._ragged_idx - 1: |
| return inp._values.size(inner_idx) |
| else: |
| return size[inner_idx + 1] |
| |
| inner_size = [get_inner_size(i) for i in range(len(size) - 1)] |
| |
| return NestedTensor(func(inp._values, inner_size), **extract_kwargs(inp)) |
| |
| |
| @register_jagged_func( |
| torch.ops.aten.native_layer_norm.default, |
| "input: jt_all, normalized_shape: any, weight: any?, bias: any?, eps: any", |
| ) |
| def native_layer_norm_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| |
| if inp.dim() <= 2: |
| raise RuntimeError( |
| "layer_norm(): not supported for NestedTensor objects with 2 or fewer dimensions" |
| ) |
| |
| normalized_shape = new_kwargs["normalized_shape"] |
| ragged_size = inp.shape[inp._ragged_idx] |
| |
| num_dims_not_normalized = inp.dim() - len(normalized_shape) |
| |
| if ( |
| num_dims_not_normalized == 0 |
| ): # error if trying to normalize over the batch dimension |
| raise RuntimeError( |
| "layer_norm(): not supported when normalizing over the batch dimension for NestedTensor" |
| ) |
| |
| if ragged_size in normalized_shape and inp._lengths is not None: |
| raise RuntimeError( |
| "layer_norm(): not supported where lengths is not None if operating on the ragged dimension for NestedTensor" |
| ) |
| |
| if ( |
| ragged_size in normalized_shape |
| ): # special handling for normalizing over the ragged dimension |
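        # Illustrative shape flow (assuming ragged_idx == 1): values of shape
        # (sum(j), *rest) are flattened to (sum(j), prod(rest)) and padded to a
        # dense (B, max_seqlen, prod(rest)); mean / variance are computed over
        # dims (1, 2) with a mask so padded positions don't contribute, and the
        # normalized result is converted back to jagged values of shape
        # (sum(j), *rest).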
| padded_input = torch.ops.aten._jagged_to_padded_dense_forward( |
| inp._values.flatten( |
| start_dim=inp._ragged_idx |
| ), # _jagged_to_padded_dense_forward requires values to be a 2D tensor |
| [inp._offsets], |
| max_lengths=[inp._max_seqlen], # max length of ragged dimension |
| ) |
| |
| padded_mask = torch.ops.aten._jagged_to_padded_dense_forward( |
| torch.ones((inp._values.shape[0], 1), device=inp.device, dtype=inp.dtype), |
| [inp._offsets], |
| max_lengths=[inp._max_seqlen], # max length of ragged dimension |
| ).expand( |
| padded_input.shape |
| ) # mask elements outside of the ragged dimension and expand to the same shape as padded input (3D dense tensor) |
| |
| ragged_lengths = ( |
| inp._offsets.diff().unsqueeze(1).unsqueeze(1) * padded_input.shape[2] |
| ) # ragged dim * inner dim, since we sum over dims (1, 2) (the layer on which we normalize) |
| |
| mean = ( |
| torch.sum( |
| padded_input, |
| dim=(1, 2), |
| keepdim=True, |
| ) |
| / ragged_lengths |
| ) # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm |
| |
| padded_normalized = ( |
| padded_input - mean |
| ) * padded_mask # mask elements outside of the ragged dimension size for correct variance calculation |
| |
| variance = ( |
| torch.sum( |
| torch.square(padded_normalized), |
| dim=(1, 2), |
| keepdim=True, |
| ) |
| / ragged_lengths |
| ) # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm |
| |
| std = torch.sqrt(variance + new_kwargs["eps"]) |
| padded_layer_norm = padded_normalized / std |
| |
| jagged_layer_norm_values = torch.ops.aten._padded_dense_to_jagged_forward( |
| padded_layer_norm, |
| [inp._offsets], |
| total_L=inp._values.shape[ |
| 0 |
| ], # providing this parameter helps avoid a GPU/CPU sync |
| ).unflatten( |
| -1, inp.shape[inp._ragged_idx + 1 :] |
| ) # unflatten last dimension back into original nested tensor shape, e.g. (B, *, WH) --> (B, *, W, H) |
| |
| return ( |
| NestedTensor(jagged_layer_norm_values, **extract_kwargs(inp)), |
| mean, |
| std, |
| ) |
| |
| output, mean, std = func(inp._values, **new_kwargs) |
| return (NestedTensor(output, **extract_kwargs(inp)), mean, std) |
| |
| |
| @register_jagged_func( |
| torch.ops.aten.native_layer_norm_backward.default, |
| "grad_out: jt, input: jt, normalized_shape: any, mean: any, rstd: any, weight: any?, bias: any?, output_mask: any", |
| ) |
| def native_layer_norm_backward_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| grad_out = new_kwargs.pop("grad_out") |
| inp = new_kwargs.pop("input") |
| d_input, d_gamma, d_beta = func(grad_out._values, inp._values, **new_kwargs) |
| if d_input is None: |
| return (None, d_gamma, d_beta) |
| |
| return (NestedTensor(d_input, **extract_kwargs(inp)), d_gamma, d_beta) |
| |
| |
| @register_jagged_func(torch.ops.aten.select.int, "self: jt, dim: any, index: any") |
| def select_int(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| new_kwargs["dim"] = _wrap_jagged_dim( |
| inp.dim(), new_kwargs["dim"], "select", allow_batch_dim=True |
| ) |
| |
| # handle batch dim slicing via unbind() for now |
| # TODO: make this more efficient |
| if new_kwargs["dim"] == 0: |
| return inp.unbind()[new_kwargs["index"]] |
| |
| return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) |
| |
| |
| @register_jagged_func( |
| torch.ops.aten.slice.Tensor, |
| "self: jt, dim: any?, start: any?, end: any?, step: any?", |
| ) |
| def slice_tensor(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| new_kwargs["dim"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], "slice") |
| |
| return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) |
| |
| |
| @register_jagged_func( |
| torch.ops.aten.convolution.default, |
| "input: jt, weight: t, bias: t?, stride: any, padding: any, " |
| "dilation: any, transposed: any, output_padding: any, groups: any", |
| ) |
| def convolution_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| |
| return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) |
| |
| |
| @register_jagged_func( |
| torch.ops.aten.mean.dim, "self: jt_all, dim: any?, keepdim: any?, dtype: any?" |
| ) |
| def mean_dim(func, *args, **kwargs): |
| """ |
| Performs a mean along the provided tensor dimension. |
| Returns a dense tensor if the ragged dimension is reduced away, else returns a nested tensor. |
| """ |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| if len(new_kwargs["dim"]) > 1: |
| raise RuntimeError( |
| "mean(): not supported across multiple dimensions for NestedTensor" |
| ) |
| |
| inp = new_kwargs.pop("input") |
| |
| ( |
| new_kwargs["dim"], |
| reduce_on_batch, |
| reduce_on_ragged, |
| reduce_on_non_batch, |
| ) = _wrap_jagged_dims( |
| inp.dim(), |
| new_kwargs["dim"], |
| "mean", |
| inp._ragged_idx, |
| ) |
| |
| if reduce_on_batch: |
| raise RuntimeError( |
| "mean(): not supported along the batch dimension but not the ragged dimension for NestedTensor" |
| ) |
| |
| if reduce_on_ragged and inp._lengths is not None: |
| raise RuntimeError( |
| "mean(): not supported where lengths is not None " |
| + "if reducing across the ragged dimension for NestedTensor" |
| ) |
| |
| if not new_kwargs["keepdim"]: |
| raise RuntimeError("mean(): not supported when keepdim=False for NestedTensor") |
| |
| if reduce_on_ragged: # raggedness reduced away |
| torch_sum = torch.sum(inp, dim=inp._ragged_idx, keepdim=new_kwargs["keepdim"]) |
| |
| # for every non-batch dimension, |
| # unsqueeze lengths into the same shape as the PyTorch sum, |
| # as the extra dimensions must all be divided by the same length |
| lengths = inp._offsets.diff() |
| for _ in range(inp.dim() - 2): |
| lengths = lengths.unsqueeze(-1) |
| |
| return torch_sum / lengths.broadcast_to(torch_sum.shape) |
| |
| return NestedTensor( |
| func(inp._values, **new_kwargs), **extract_kwargs(inp) |
| ) # raggedness preserved |
| |
| |
| @register_jagged_func(torch.ops.aten.stack.default, "tensors: any, dim: any") |
| def stack_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| # guaranteed this is non-empty if we got here |
| tensors = new_kwargs.pop("tensors") |
| for t in tensors: |
| if not isinstance(t, NestedTensor): |
| raise RuntimeError("stack(): expected all nested tensors inputs") |
| |
| if t.dim() != tensors[0].dim(): |
| raise RuntimeError( |
| "stack(): expected all nested tensors to have the same dim" |
| ) |
| |
| if not raggedness_matches(t, tensors[0].shape): |
| raise RuntimeError( |
| "stack(): expected all nested tensors to have the same nested structure" |
| ) |
| |
| new_kwargs["dim"] = _wrap_jagged_dim( |
| tensors[0].dim() + 1, new_kwargs["dim"], "stack" |
| ) |
| |
| return NestedTensor( |
| func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0]) |
| ) |
| |
| |
| @register_jagged_func( |
| torch.ops.aten.embedding.default, |
| "weight: t, indices: jt, padding_idx: any?, scale_grad_by_freq: any?, sparse: any?", |
| ) |
| def embedding_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| # guaranteed this is non-empty if we got here |
| indices = new_kwargs.pop("indices") |
| weight = new_kwargs.pop("weight") |
| |
| return NestedTensor( |
| func(weight, indices._values, **new_kwargs), **extract_kwargs(indices) |
| ) |
| |
| |
| @register_jagged_func( |
| [ |
| torch.ops.aten.values.default, |
| torch.ops.aten._nested_get_values.default, |
| ], |
| "self: jt_all", |
| ) |
| def values_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| |
| # TODO: Handle inference mode properly. |
| # See https://github.com/pytorch/pytorch/issues/112024#issuecomment-1779554292 |
| return inp._values.detach() |
| |
| |
| @register_jagged_func(torch.ops.aten.all.default, "self: jt_all") |
| def all_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| |
| return func(inp._values) |
| |
| |
| @register_jagged_func( |
| torch.ops.aten._nested_view_from_jagged.default, |
| "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?, min_seqlen: t?, max_seqlen: t?", |
| ) |
| def _nested_view_from_jagged_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| values, offsets, lengths = ( |
| new_kwargs["input"], |
| new_kwargs["offsets"], |
| new_kwargs["lengths"], |
| ) |
| ragged_idx = new_kwargs["ragged_idx"] |
| min_seqlen = new_kwargs["min_seqlen"] |
| max_seqlen = new_kwargs["max_seqlen"] |
| metadata_cache = {} |
| if min_seqlen is not None: |
| metadata_cache["min_seqlen"] = min_seqlen |
| if max_seqlen is not None: |
| metadata_cache["max_seqlen"] = max_seqlen |
| |
| return NestedTensor( |
| values, |
| offsets, |
| lengths=lengths, |
| _ragged_idx=ragged_idx, |
| _metadata_cache=metadata_cache, |
| ) |
| |
| |
| @register_jagged_func(torch.ops.aten._nested_get_offsets.default, "self: jt_all") |
| def _nested_get_offsets(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| return inp._offsets |
| |
| |
| @register_jagged_func(torch.ops.aten._nested_get_lengths.default, "self: jt_all") |
| def _nested_get_lengths(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| return inp._lengths |
| |
| |
| @register_jagged_func(torch.ops.aten._nested_get_ragged_idx.default, "self: jt_all") |
| def _nested_get_ragged_idx(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| return inp._ragged_idx |
| |
| |
| @register_jagged_func(torch.ops.aten._nested_get_min_seqlen.default, "self: jt_all") |
| def _nested_get_min_seqlen(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| return inp._metadata_cache.get("min_seqlen", None) |
| |
| |
| @register_jagged_func(torch.ops.aten._nested_get_max_seqlen.default, "self: jt_all") |
| def _nested_get_max_seqlen(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| |
| inp = new_kwargs.pop("input") |
| return inp._metadata_cache.get("max_seqlen", None) |
| |
| |
# If a section of the NestedTensor is fully masked out, we still retain the section with a length of 0
| @register_jagged_func(torch.ops.aten.masked_select.default, "self: jt, mask: any") |
| def masked_select_default(func, *args, **kwargs): |
| _, new_kwargs = normalize_function( # type: ignore[misc] |
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True |
| ) |
| inp = new_kwargs.pop("input") |
| mask = new_kwargs.pop("mask") |
| |
| if inp.ndim > 2: |
| raise RuntimeError("masked_select only support 2-D selections currently") |
| elif inp.shape != mask.shape: |
| raise RuntimeError( |
| f"Mask with shape {mask.shape} is not compatible with input's shape {inp.shape}" |
| ) |
| res_values = inp._values.masked_select(mask.values()) |
| mask_cumsum = F.pad(mask.values().cumsum(dim=0), (1, 0)) # type: ignore[arg-type] |
| |
    new_nt_kwargs = extract_kwargs(inp)
    new_nt_kwargs["offsets"] = mask_cumsum[inp._offsets]
    return NestedTensor(
        values=res_values,
        **new_nt_kwargs,
    )
| |
| |
| # Make the dummy available on the C++ side. |
| @register_jagged_func(torch.ops.aten._nested_get_jagged_dummy.default, "self: any") |
| def _nested_get_jagged_dummy(func, *args, **kwargs): |
| from torch.nested._internal.nested_tensor import _nt_view_dummy |
| |
| return _nt_view_dummy() |
| |
| |
| with torch.library._scoped_library("aten", "IMPL") as aten: |
| aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CPU") |
| aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CUDA") |
| aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "Meta") |