Ported matmul CompositeImplicitAutograd impl into core (#85239)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85239
Approved by: https://github.com/ezyang, https://github.com/lezcano
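
The practical effect of this port is that matmul now decomposes in Python into mm/bmm plus views, so symbolic-shape tracing can follow it instead of failing on a missing meta function (hence the xfails removed below). A minimal sketch of how that is exercised, assuming a build with these registrations active (editor's illustration, not part of the patch):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(a, b):
    # batched tensor @ matrix: routed through the Python matmul decomposition added below
    return torch.matmul(a, b)

# tracing_mode="symbolic" matches the previously-xfailed test_proxy_tensor cases
g = make_fx(f, tracing_mode="symbolic")(torch.randn(2, 3, 4), torch.randn(4, 5))
print(g.graph)  # should show mm/bmm/view nodes rather than an opaque aten.matmul
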
diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py
index 26436f1..0cf820c 100644
--- a/functorch/_src/aot_autograd.py
+++ b/functorch/_src/aot_autograd.py
@@ -276,6 +276,9 @@
def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
fw_module = make_fx(flat_fn, aot_config.decompositions)(*flat_args)
+ if config.debug_graphs:
+ print("====== Forward (only) graph ======")
+ fw_module.print_readable()
with track_graph_compiling("inference"):
compiled_fw = aot_config.fw_compiler(fw_module, flat_args)
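
For context, the new print is gated on functorch's debug_graphs flag; a minimal sketch of flipping it, assuming the flag lives in functorch._src.config as a module-level boolean (as the guard above suggests):

from functorch._src import config  # assumed location of the debug flags

config.debug_graphs = True  # aot_dispatch_base will now print the forward-only graph
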
diff --git a/test/test_decomp.py b/test/test_decomp.py
index ffd44cd..719240d 100644
--- a/test/test_decomp.py
+++ b/test/test_decomp.py
@@ -24,6 +24,7 @@
instantiate_device_type_tests,
)
from torch.testing._internal.common_methods_invocations import op_db
+from torch._dispatch.python import enable_python_dispatcher
import itertools
import functools
@@ -486,7 +487,7 @@
# explicit clearing is necessary as I will create a fresh mode
# for each region
decomposed.clear()
- with enable_torch_dispatch_mode(DecompCrossRefMode):
+ with enable_torch_dispatch_mode(DecompCrossRefMode), enable_python_dispatcher():
decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
if aten_name in decomposition_names:
check_decomposed(aten_name)
@@ -495,7 +496,7 @@
cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)
decomposed.clear()
- with enable_torch_dispatch_mode(DecompCrossRefMode):
+ with enable_torch_dispatch_mode(DecompCrossRefMode), enable_python_dispatcher():
decomp_vjp_fn(cotangents)
if not run_all:
check_decomposed(op.aten_backward_name)
@@ -504,7 +505,7 @@
args = [sample_input.input] + list(sample_input.args)
kwargs = sample_input.kwargs
decomposed.clear()
- with enable_torch_dispatch_mode(DecompCrossRefMode):
+ with enable_torch_dispatch_mode(DecompCrossRefMode), enable_python_dispatcher():
func(*args, **kwargs)
if not run_all:
check_decomposed(aten_name)
diff --git a/test/test_ops.py b/test/test_ops.py
index 4cf025a..e96b59a 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1671,7 +1671,6 @@
'_refs.expand_as',
'_refs.as_strided', # _prims._as_strided_meta: "reduce() of empty sequence with no initial value"
'_refs.copy_to', # torch._C._jit_get_operation: No such operator aten::copy_to
- '_refs.clone', # test_meta.py: view size is not compatible with input tensor's size and stride
'_refs.equal', # 'bool' object has no attribute 'dtype'
'_refs.conj', # Calls _prims.conj
'_refs.real',
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index e238bdb..024e729 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -22,6 +22,7 @@
import types
import functools
+import itertools
aten = torch.ops.aten
@@ -1010,7 +1011,6 @@
xfail('linalg.eigvals'),
skip('_masked.logsumexp', ''), # Tensors of type TensorImpl do not have numel
xfail('__getitem__', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
- xfail('__rmatmul__', ''), # aten.new_empty.default - couldn't find symbolic meta function/decomposition
xfail('_masked.amax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('_masked.amin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('_masked.argmax', ''), # aten.argmax.default - couldn't find symbolic meta function/decomposition
@@ -1022,15 +1022,12 @@
xfail('_masked.mean', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, ...
xfail('_masked.median', ''), # aten.nanmedian.dim - couldn't find symbolic meta function/decomposition
xfail('_masked.norm', ''), # aten.linalg_vector_norm.default - couldn't find symbolic meta function/decomposition
- xfail('_masked.normalize', ''), # aten.linalg_vector_norm.default - couldn't find symbolic meta function/decomposition
xfail('_masked.prod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('_masked.softmax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('_masked.softmin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('_masked.std', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d...
xfail('_masked.sum', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('_masked.var', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d...
- xfail('addmm', ''), # aten.mm.default - couldn't find symbolic meta function/decomposition
- xfail('addmm', 'decomposed'), # aten.mm.default - couldn't find symbolic meta function/decomposition
xfail('addmv', ''), # aten.addmv.default - couldn't find symbolic meta function/decomposition
xfail('addr', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('aminmax', ''), # aten.aminmax.default - couldn't find symbolic meta function/decomposition
@@ -1042,13 +1039,11 @@
xfail('as_strided_scatter', ''), # aten.as_strided_scatter.default - couldn't find symbolic meta function/decomposition
xfail('baddbmm', ''), # aten.baddbmm.default - couldn't find symbolic meta function/decomposition
xfail('bernoulli', ''), # aten.bernoulli.default - couldn't find symbolic meta function/decomposition
- xfail('bmm', ''), # aten.bmm.default - couldn't find symbolic meta function/decomposition
xfail('bucketize', ''), # aten.bucketize.Tensor - couldn't find symbolic meta function/decomposition
xfail('cartesian_prod', ''), # Tensors of type TensorImpl do not have numel
xfail('cdist', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('cholesky_solve', ''), # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back...
xfail('chunk', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
- xfail('clone', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
xfail('column_stack', ''), # Tensors of type TensorImpl do not have numel
xfail('constant_pad_nd', ''), # aten.fill.Scalar - couldn't find symbolic meta function/decomposition
xfail('count_nonzero', ''), # Could not run 'aten::count_nonzero.dim_IntList' with arguments from the 'Meta' ba...
@@ -1150,7 +1145,6 @@
xfail('linalg.tensorinv', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('linalg.tensorsolve', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('linalg.vander', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
- xfail('linalg.vector_norm', ''), # TensorImpl do not have numel
xfail('logaddexp2', ''), # aten.logaddexp2.default - couldn't find symbolic meta function/decomposition
xfail('logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposition
xfail('logcumsumexp', ''), # aten.logcumsumexp.default - couldn't find symbolic meta function/decomposition
@@ -1161,14 +1155,12 @@
xfail('masked_fill', ''), # expected predicate to be bool, got torch.float32
xfail('masked_scatter', ''), # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition
xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decomposition
- xfail('matmul', ''), # aten.new_empty.default - couldn't find symbolic meta function/decomposition
xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decomposition
xfail('max', 'reduction_with_dim'), # aten.max.dim - couldn't find symbolic meta function/decomposition
xfail('median', ''), # Could not run 'aten::median' with arguments from the 'Meta' backend. This could be becau...
xfail('meshgrid', 'list_of_tensors'), # Tensors of type TensorImpl do not have numel
xfail('meshgrid', 'variadic_tensors'), # Tensors of type TensorImpl do not have numel
xfail('min', 'reduction_with_dim'), # aten.min.dim - couldn't find symbolic meta function/decomposition
- xfail('mm', ''), # aten.mm.default - couldn't find symbolic meta function/decomposition
xfail('mode', ''), # aten.mode.default - couldn't find symbolic meta function/decomposition
xfail('msort', ''), # aten.sort.default - couldn't find symbolic meta function/decomposition
xfail('nanquantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend.
@@ -1213,7 +1205,6 @@
xfail('nn.functional.local_response_norm', ''), # Tensors of type TensorImpl do not have numel
xfail('nn.functional.margin_ranking_loss', ''), # The underlying op of 'aten.stride' has no overload name '_schema'
xfail('nn.functional.max_pool1d', ''), # Trying to call aten.size on a tensor with symbolic shapes.
- xfail('nn.functional.max_pool2d', ''), # aten.max_pool2d_with_indices.default - couldn't find symbolic meta function/d...
xfail('nn.functional.max_pool3d', ''), # aten.max_pool3d_with_indices.default - couldn't find symbolic meta function/d...
xfail('nn.functional.max_unpool1d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta function/decom...
xfail('nn.functional.max_unpool2d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta function/decom...
@@ -1234,7 +1225,6 @@
xfail('nn.functional.unfold', ''), # aten.im2col.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.upsample_bilinear', ''), # aten.upsample_bilinear2d.vec - couldn't find symbolic meta function/de...
xfail('nn.functional.upsample_nearest', ''), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/deco...
- xfail('norm', ''), # TensorImpl does not have numel
xfail('norm', 'nuc'), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition
xfail('normal', ''), # aten.normal.Tensor_Tensor - couldn't find symbolic meta function/decomposition
xfail('normal', 'number_mean'), # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition
@@ -1337,7 +1327,9 @@
return op.op(*args, **kwargs)
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
new_f = None
- for sample_input in sample_inputs_itr:
+
+ # Limit ourselves to first 100 inputs so symbolic tracing tests don't take too long
+ for sample_input in itertools.islice(sample_inputs_itr, 100):
args = [sample_input.input] + list(sample_input.args)
kwargs = sample_input.kwargs
@@ -1345,7 +1337,6 @@
new_f = make_fx(f, tracing_mode=tracing_mode)(args, kwargs)
except DynamicOutputShapeException as e:
self.skipTest("Dynamic output shape operation in trace")
-
for arg in args:
if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
arg.uniform_(0, 1)
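
The islice change above just caps how many OpInfo samples each symbolic-tracing test consumes; a tiny standalone sketch of the pattern:

import itertools

def samples():
    i = 0
    while True:  # an OpInfo-style generator that may yield many sample inputs
        yield i
        i += 1

first_100 = list(itertools.islice(samples(), 100))  # only the first 100 are drawn
assert len(first_100) == 100
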
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 1d09c0b..05dc65d 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -1,4 +1,5 @@
import functools
+import operator
from enum import Enum
from itertools import product
from typing import Callable, Iterable, List, Optional, Tuple
@@ -12,6 +13,8 @@
from torch._prims_common.wrappers import out_wrapper
from torch.utils._pytree import tree_flatten, tree_map
+DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]
+
# None of these functions are publicly accessible; get at them
# from torch._decomps
__all__: List[str] = []
@@ -2030,3 +2033,131 @@
loss = loss * weight
return apply_loss_reduction(loss, reduction)
+
+
+def should_fold(tensor1: torch.Tensor, dim_tensor2: int) -> bool:
+ dim_tensor1 = tensor1.ndim
+ if dim_tensor1 >= 3 and (dim_tensor2 == 1 or dim_tensor2 == 2):
+ t1_sizes_ptr = tensor1.shape
+ t1_strides = tensor1.stride()
+ if (
+ dim_tensor1 == 3
+ and dim_tensor2 == 2
+ and t1_strides[-1] != 1
+ and t1_strides[0] == t1_sizes_ptr[1] * t1_sizes_ptr[2]
+ ):
+ # First dim is slowest moving, and then the following two dims are
+ # transposed. This can happen for example by permute(0, 2, 1).
+ # First 2 dims could be folded to use mm but would require permutation
+ # with actual data movement, which can be instead handled by BMM with each
+ # GEMM transposed.
+ # This can be generalized to a tensor with dim X + Y + Z where X, Y, and Z
+ # dims are contiguous, Y dims and Z dims are transposed, and X, Y, Z > 0.
+ # For example, this can happen by permute(0, 1, 5, 2, 3, 4), where X = 2,
+ # Y = 3, and Z = 1.
+ return False
+ else:
+ return True
+ else:
+ return False
+
+
+@aten.matmul.default.py_impl(DispatchKey.CompositeImplicitAutograd)
+def matmul(tensor1, tensor2):
+ dim_tensor1 = tensor1.dim()
+ dim_tensor2 = tensor2.dim()
+ assert dim_tensor1 != 0 and dim_tensor2 != 0
+ if dim_tensor1 == 1 and dim_tensor2 == 1:
+ return torch.dot(tensor1, tensor2)
+ elif dim_tensor1 == 2 and dim_tensor2 == 1:
+ return torch.mv(tensor1, tensor2)
+ elif dim_tensor1 == 1 and dim_tensor2 == 2:
+ return torch.squeeze(torch.mm(torch.unsqueeze(tensor1, 0), tensor2), 0)
+ elif dim_tensor1 == 2 and dim_tensor2 == 2:
+ return torch.mm(tensor1, tensor2)
+ elif should_fold(tensor1, dim_tensor2) or should_fold(tensor2, dim_tensor1):
+ # NB: Much of this was written with Copilot! (although still had to fix a bunch of issues)
+
+ # dim_tensor1 >=3 && (dim_tensor2 == 1 || dim_tensor2 == 2) ||
+ # dim_tensor2 >=3 && (dim_tensor1 == 1 || dim_tensor1 == 2)
+ # and some condition on the strides is fulfilled
+
+ # optimization: use mm instead of bmm by folding the batch of the larger tensor
+ # into its leading matrix dimension
+ transpose = dim_tensor2 > dim_tensor1
+ t1 = tensor2.mT if transpose else tensor1
+ t2 = (
+ tensor2 if not transpose else (tensor1.t() if dim_tensor1 == 2 else tensor1)
+ )
+ # Invariant: t1.dim() >= 3 && (t2.dim() == 1 || t2.dim() == 2)
+ # and t1 and t2 are matmul-compatible
+
+ # Why not t1.view(-1, sizes_1[-1])?
+ # If the last dim is 0, then view(-1, 0) won't work because the -1 becomes ambiguous.
+ # This can happen in e.g. [3, 5, 0] @ [0, 0].
+ sizes_1 = t1.shape
+ output_shape = list(sizes_1[:-1])
+ folded_dim1 = functools.reduce(operator.mul, output_shape)
+
+ # Readjust output_shape if we are multiplying by a matrix
+ t2_is_matrix = t2.dim() == 2
+ if t2_is_matrix:
+ output_shape.append(t2.shape[1])
+ # HACK: We need reshape with symint support
+ t1 = t1.contiguous()
+ t1_folded = t1.view(folded_dim1, sizes_1[-1])
+ if t2_is_matrix:
+ # FIXME This path always does an unnecessary copy when transpose == True as the returned
+ # result from BLAS is already C-transposed
+ output = t1_folded.mm(t2).view(output_shape)
+ return output.mT.contiguous() if transpose else output
+ else:
+ return t1_folded.mv(t2).view(output_shape)
+
+ elif dim_tensor1 >= 1 and dim_tensor2 >= 1:
+ # We are multiplying b1 x n x m1 by b2 x m2 x p (where b1 and b2 can be lists of batch dims);
+ # we track m1 vs m2 separately even though they must match for nicer error messages
+ n = tensor1.size(-2) if dim_tensor1 > 1 else 1
+ m1 = tensor1.size(-1)
+ batch_tensor1 = tensor1.shape[:-2]
+ m2 = tensor2.size(-2) if dim_tensor2 > 1 else tensor2.size(-1)
+ p = tensor2.size(-1) if dim_tensor2 > 1 else 1
+ batch_tensor2: List[int] = []
+ # TODO: handling of slice
+ for i in range(dim_tensor2 - 2):
+ batch_tensor2.append(tensor2.size(i))
+
+ # expand the batch portion (i.e. cut off matrix dimensions and expand rest)
+ expand_batch_portion = list(
+ torch.broadcast_shapes(batch_tensor1, batch_tensor2)
+ )
+
+ tensor1_expand_size = expand_batch_portion + [n, m1]
+ tensor2_expand_size = expand_batch_portion + [m2, p]
+
+ expand_batch_product = prod(expand_batch_portion)
+
+ # HACK: We need reshape with symint support
+ tensor1_expanded = (
+ tensor1.expand(tensor1_expand_size)
+ .contiguous()
+ .view(expand_batch_product, n, m1)
+ )
+ tensor2_expanded = (
+ tensor2.expand(tensor2_expand_size)
+ .contiguous()
+ .view(expand_batch_product, m2, p)
+ )
+
+ output_shape = expand_batch_portion
+ if dim_tensor1 > 1:
+ output_shape.append(n)
+
+ if dim_tensor2 > 1:
+ output_shape.append(p)
+
+ return tensor1_expanded.bmm(tensor2_expanded).view(output_shape)
+ else:
+ utils.check(False, lambda: "both arguments to matmul need to be at least 1D")
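
The folding branch is the interesting part of the decomposition: a batched tensor times a matrix collapses into a single mm instead of a bmm. A quick numeric check of that reasoning against the public API (a sketch, not part of the patch):

import torch

t1 = torch.randn(2, 3, 4)  # batch of two 3x4 matrices
t2 = torch.randn(4, 5)     # plain matrix

# Fold t1's batch dims into its row dimension, do one mm, then restore the shape,
# exactly as the t1_folded.mm(t2).view(output_shape) path above does.
folded = t1.contiguous().view(2 * 3, 4).mm(t2).view(2, 3, 5)

torch.testing.assert_close(folded, torch.matmul(t1, t2))
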
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index e63e277..4684a72 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -193,6 +193,16 @@
return self.new_empty(())
+@register_meta([aten.mm.default], register_dispatcher=False)
+def meta_mm(a, b):
+ check(a.dim() == 2, lambda: "a must be 2D")
+ check(b.dim() == 2, lambda: "b must be 2D")
+ N, M1 = a.shape
+ M2, P = b.shape
+ check(M1 == M2, lambda: "a and b must have same reduction dim")
+ return a.new_empty(N, P)
+
+
def _compute_reduction_shape(self, dims, keepdim):
if keepdim:
return tuple(self.shape[i] if i not in dims else 1 for i in range(self.ndim))
@@ -762,6 +772,234 @@
return self.view(self.shape)
+def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None):
+ check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
+ check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
+
+ batch1_sizes = batch1.size()
+ batch2_sizes = batch2.size()
+
+ bs = batch1_sizes[0]
+ contraction_size = batch1_sizes[2]
+ res_rows = batch1_sizes[1]
+ res_cols = batch2_sizes[2]
+ output_size = (bs, res_rows, res_cols)
+
+ check(
+ batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
+ lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}"
+ f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].",
+ )
+
+ # TODO: handle out
+
+ output = batch2.new_empty(output_size)
+
+ if not is_bmm and self_baddbmm is not None:
+ check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
+ check(
+ self_baddbmm.size() == output_size,
+ lambda: "Expected an input tensor shape with shape {output_size} but got shape: {self.size()}",
+ )
+
+ return output
+
+
+@register_meta(aten.bmm.default, register_dispatcher=False)
+def meta_bmm(self, mat2):
+ return common_meta_baddbmm_bmm(self, mat2, True)
+
+
+def div_rtn(x, y):
+ q = x // y
+ r = x % y
+ # WARNING: explicit bool conversion here is necessary;
+ # would be fixed by SymBool
+ if r != 0 and (bool(r < 0) != bool(y < 0)):
+ q -= 1
+ return q
+
+
+def pooling_output_shape_pad_lr(
+ inputSize, kernelSize, pad_l, pad_r, stride, dilation, ceil_mode
+):
+ outputSize = (
+ div_rtn(
+ inputSize
+ + pad_l
+ + pad_r
+ - dilation * (kernelSize - 1)
+ - 1
+ + (stride - 1 if ceil_mode else 0),
+ stride,
+ )
+ + 1
+ )
+ if ceil_mode:
+ if (outputSize - 1) * stride >= inputSize + pad_l:
+ outputSize -= 1
+ return outputSize
+
+
+def pooling_output_shape(inputSize, kernelSize, pad, stride, dilation, ceil_mode):
+ check(stride != 0, lambda: "stride should not be zero")
+ check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}")
+ check(
+ pad <= kernelSize // 2,
+ lambda: f"pad should be at most half of kernel size, but got pad={pad} and kernel_size={kernelSize}",
+ )
+ return pooling_output_shape_pad_lr(
+ inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode
+ )
+
+
+def pool2d_shape_check(
+ input,
+ kH,
+ kW,
+ dH,
+ dW,
+ padH,
+ padW,
+ dilationH,
+ dilationW,
+ nInputPlane,
+ inputHeight,
+ inputWidth,
+ outputHeight,
+ outputWidth,
+ memory_format,
+):
+ ndim = input.dim()
+ nOutputPlane = nInputPlane
+
+ check(
+ kW > 0 and kH > 0,
+ lambda: "kernel size should be greater than zero, but got kH: {kH}, kW: {kW}",
+ )
+ check(
+ dW > 0 and dH > 0,
+ lambda: "stride should be greater than zero, but got dH: {dH}, dW: {dW}",
+ )
+ check(
+ dilationH > 0 and dilationW > 0,
+ lambda: "dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}",
+ )
+
+ valid_dims = input.size(1) != 0 and input.size(2) != 0
+
+ if memory_format == torch.channels_last:
+ check(
+ ndim == 4 and valid_dims and input.size(3) != 0,
+ lambda: "Expected 4D (batch mode) tensor expected for input with channels_last layout"
+ " with optional 0 dim batch size for input, but got: {input.size()}",
+ )
+ else:
+ check(
+ (ndim == 3 and input.size(0) != 0 and valid_dims)
+ or (ndim == 4 and valid_dims and input.size(3) != 0),
+ lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}",
+ )
+
+ check(
+ kW // 2 >= padW and kH // 2 >= padH,
+ lambda: "pad should be smaller than or equal to half of kernel size, but got "
+ f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}",
+ )
+
+ check(
+ outputWidth >= 1 and outputHeight >= 1,
+ lambda: f"Given input size: ({nInputPlane}x{inputHeight}x{inputWidth}). "
+ f"Calculated output size: ({nOutputPlane}x{outputHeight}x{outputWidth}). "
+ "Output size is too small",
+ )
+
+
+@register_meta(aten.max_pool2d_with_indices.default, register_dispatcher=False)
+def meta_max_pool2d_with_indices(
+ input, kernel_size, stride=(), padding=(0,), dilation=(1,), ceil_mode=False
+):
+ # Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp
+ def unpack(name, val):
+ check(
+ len(val) in [1, 2],
+ lambda: f"max_pool2d: {name} must either be a single int, or a tuple of two ints",
+ )
+ H = val[0]
+ W = H if len(val) == 1 else val[1]
+ return H, W
+
+ kH, kW = unpack("kernel_size", kernel_size)
+
+ check(
+ len(stride) in [0, 1, 2],
+ lambda: "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
+ )
+ if len(stride) == 0:
+ dH, dW = kH, kW
+ else:
+ dH, dW = unpack("stride", stride)
+
+ padH, padW = unpack("padding", padding)
+ dilationH, dilationW = unpack("dilation", dilation)
+
+ memory_format = utils.suggest_memory_format(input)
+ if memory_format == torch.channels_last:
+ check(
+ input.dim() == 4,
+ lambda: "non-empty 4D (batch mode) tensor expected for input with channels_last layout",
+ )
+ elif memory_format == torch.contiguous_format:
+ check(
+ input.dim() in [3, 4],
+ lambda: "non-empty 3D or 4D (batch mode) tensor expected for input",
+ )
+ else:
+ check(
+ False,
+ lambda: "Unsupport memory format. Supports only ChannelsLast, Contiguous",
+ )
+
+ nbatch = input.size(-4) if input.dim() == 4 else 1
+ nInputPlane = input.size(-3)
+ inputHeight = input.size(-2)
+ inputWidth = input.size(-1)
+
+ outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
+ outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)
+
+ pool2d_shape_check(
+ input,
+ kH,
+ kW,
+ dH,
+ dW,
+ padH,
+ padW,
+ dilationH,
+ dilationW,
+ nInputPlane,
+ inputHeight,
+ inputWidth,
+ outputHeight,
+ outputWidth,
+ memory_format,
+ )
+
+ if input.dim() == 3:
+ size = [nInputPlane, outputHeight, outputWidth]
+ else:
+ size = [nbatch, nInputPlane, outputHeight, outputWidth]
+ return (
+ torch.empty(
+ size, dtype=input.dtype, device=input.device, memory_format=memory_format
+ ),
+ torch.empty(
+ size, dtype=torch.int64, device=input.device, memory_format=memory_format
+ ),
+ )
+
+
# We must also trigger meta registrations from PrimTorch ref
# decompositions
import torch._refs
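
These meta kernels only compute output shapes and dtypes; their formulas can be sanity-checked on the meta device, which should report the same shapes (a sketch; mm and max_pool2d_with_indices may also be served by existing C++ meta kernels, so this checks the formulas rather than the exact dispatch path):

import torch

a = torch.empty(2, 3, device="meta")
b = torch.empty(3, 4, device="meta")
print(torch.mm(a, b).shape)  # torch.Size([2, 4]), matching meta_mm's a.new_empty(N, P)

x = torch.empty(1, 8, 10, 10, device="meta")
out, indices = torch.ops.aten.max_pool2d_with_indices(x, [3, 3], [2, 2], [1, 1])
# pooling_output_shape: div_rtn(10 + 1 + 1 - 1*(3 - 1) - 1, 2) + 1 = 5 per spatial dim
print(out.shape, indices.dtype)  # torch.Size([1, 8, 5, 5]) torch.int64
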
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 7d334c9..6f59ad1 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -338,10 +338,7 @@
return x
if not utils.same_shape(x.shape, common_shape):
- common_rank = len(common_shape) + 1
- start = common_rank - (len(x.shape) + 1)
- dims = tuple(range(start, len(x.shape) + start))
- return prims.broadcast_in_dim(x, common_shape, dims)
+ return x.expand(common_shape)
return x
else:
@@ -1658,6 +1655,7 @@
#
# Data Movement References
#
+@register_decomposition(torch.ops.aten.clone.default)
def clone(
a: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format
) -> TensorLikeType:
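
Two notes on this hunk: the broadcast helper now calls Tensor.expand directly, which for trailing-aligned shapes matches the removed prims.broadcast_in_dim call, and the clone ref is registered as the decomposition for aten.clone.default (presumably why the clone entries in fake_tensor.py and test_ops.py are dropped). A sketch of the expand/broadcast_in_dim equivalence:

import torch
from torch._prims import broadcast_in_dim  # the removed call path, for comparison

x = torch.randn(3)
common_shape = (2, 3)

# dims computed as the removed code did: x's dims map onto the trailing dims of common_shape
start = len(common_shape) - len(x.shape)
dims = tuple(range(start, len(x.shape) + start))

assert torch.equal(x.expand(common_shape), broadcast_in_dim(x, common_shape, dims))
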
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 38b446e..ee0f8ab 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -727,7 +727,6 @@
aten.empty_strided.default,
aten.as_strided.default,
aten.zeros.default,
- aten.clone.default,
aten.detach.default,
]
# IDK: feels bad man, sym_numel on as_strided infinite loops otherwise
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index 34b4012..202990d 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -442,7 +442,8 @@
if func in [prim.device.default]:
return func(*args, **kwargs)
- return proxy_call(self, func, args, kwargs)
+ out = proxy_call(self, func, args, kwargs)
+ return out
SymInt = torch.SymIntNode
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index e938945..54fe764 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -1,6 +1,6 @@
import torch
import torch.utils._pytree as pytree
-from typing import Dict, List, Type, Optional, cast
+from typing import Set, Dict, List, Type, Optional, cast
import operator
import functools
from functools import lru_cache
@@ -17,7 +17,7 @@
__all__ = [
"has_symbolic_sizes_strides", "create_contiguous", "PySymInt", "ShapeEnv",
- "SymDispatchMode", "PySymFloat", "sym_float"
+ "SymDispatchMode", "PySymFloat", "sym_float", "FloorDiv"
]
SYM_FUNCTION_MODE = None
@@ -148,13 +148,31 @@
def __str__(self):
return f"{self.expr}"
+if HAS_SYMPY:
+ class FloorDiv(sympy.Function):
+ """
+ We maintain this so that:
+ 1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b.
+ 2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b)
+ """
+ nargs = (2,)
+
+ @classmethod
+ def eval(cls, base, divisor):
+ if base == 0:
+ return sympy.Integer(0)
+ if divisor == 1:
+ return base
+ if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
+ return base // divisor
+
# Methods that have a `__foo__` as well as `__rfoo__`
reflectable_magic_methods = {
'add': lambda a, b: a + b,
'sub': lambda a, b: a - b,
'mul': lambda a, b: a * b,
'mod': lambda a, b: a % b,
- 'floordiv': lambda a, b: (a - (a % b)) / b
+ 'floordiv': lambda a, b: FloorDiv(a, b)
}
magic_methods = {
@@ -225,8 +243,8 @@
# Maps from sympy ints to expressions representing them
# Populated from equality guards (i.e. a.shape[0] == b.shape[0])
self.replacements: Dict["sympy.Symbol", "sympy.Expr"] = {} #
- # Keys are Mod(x, y), values are 0 (for ease of substitution)
- self.divisible: Dict["sympy.Expr", "sympy.Integer"] = {}
+ # Set holds a % b expressions that evaluate to 0.
+ self.divisible: Set["sympy.Expr"] = set()
def _get_key(self):
"""
@@ -267,9 +285,7 @@
return all(guard.xreplace(new_env.var_to_val) == value for guard, value, _ in self.guards)
def get_nontrivial_guards(self):
- guards = [(self.simplify(guard), val) for guard, val, _ in self.guards]
- guards = [guard for guard in guards if len(guard[0].free_symbols) > 0]
- return guards
+ return [(self.simplify(guard), val) for guard, val, _ in self.guards if self._maybe_evaluate_static(guard) is None]
def get_shape_groups(self):
shape_groups = collections.defaultdict(list)
@@ -282,6 +298,7 @@
"""
Tries to evaluate expr without introducing guards
"""
+ expr = self.simplify(expr)
# Simplifies assuming that shape vars > 1 (since we cache on 0/1 shape values)
symbols = list(expr.free_symbols)
new_shape_env = {
@@ -289,7 +306,10 @@
for idx, k in enumerate(symbols)
}
new_expr = expr.xreplace(new_shape_env)
- new_expr = sympy.expand(new_expr)
+ floor_div_replace = {}
+ for atom in new_expr.atoms(FloorDiv):
+ floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1])
+ new_expr = sympy.expand(new_expr.xreplace(floor_div_replace))
if len(list(new_expr.free_symbols)) == 0:
return new_expr
return None
@@ -301,20 +321,25 @@
@_lru_cache
def _update_divisible(self):
- new_divisible = {}
+ new_divisible = set()
for k in self.divisible:
res = self.replace(k)
if len(res.free_symbols) > 0:
- new_divisible[k] = sympy.Integer(0)
+ new_divisible.add(k)
self.divisible = new_divisible
@_lru_cache
def simplify(self, expr: "sympy.Expr") -> "sympy.Expr":
expr = self.replace(expr)
- if expr.has(sympy.Mod):
+ if expr.has(FloorDiv):
self._update_divisible()
- expr = expr.xreplace(self.divisible)
+ div_replacements = {}
+ for atom in expr.atoms(FloorDiv):
+ base, divisor = atom.args
+ if self.replace(base % divisor) in self.divisible:
+ div_replacements[atom] = base / divisor
+ expr = expr.xreplace(div_replacements)
expr = sympy.expand(expr)
return expr
@@ -363,24 +388,23 @@
lhs = expr.lhs
rhs = expr.rhs
try:
- solutions = sympy.solveset(lhs - rhs, free[0], domain=sympy.S.Integers)
- if not solutions.is_finite_set:
- if expr.has(sympy.Mod):
- mod_expr = tuple(expr.atoms(sympy.Mod))[0]
- solutions = sympy.solveset(lhs - rhs, mod_expr, domain=sympy.S.Integers)
- if solutions.is_finite_set and len(solutions) == 1 and tuple(solutions)[0] == 0:
- self.divisible[mod_expr] = sympy.Integer(0)
+ solutions = sympy.solve(lhs - rhs, free[0], dict=True)
+ if len(solutions) != 1:
return
-
- if not isinstance(solutions, sympy.FiniteSet):
- return
-
- solutions = tuple(solutions)
- if len(solutions) == 1 and all(t.is_integer for t in sympy.preorder_traversal(solutions[0])):
- new_var = self._find(solutions[0])
+ solution = solutions[0][free[0]]
+ if all(t.is_integer for t in sympy.preorder_traversal(solution)):
+ new_var = self._find(solution)
self.replacements[cast(sympy.Symbol, free[0])] = new_var
- except ZeroDivisionError:
- pass
+ except NotImplementedError:
+ if expr.has(sympy.Mod):
+ mod_expr = tuple(expr.atoms(sympy.Mod))[0]
+ try:
+ solutions = sympy.solve(lhs - rhs, mod_expr, dict=True)
+ if len(solutions) == 1 and solutions[0][mod_expr] == 0:
+ self.divisible.add(mod_expr)
+ except NotImplementedError:
+ pass
+ return
@lru_cache(256)
def evaluate_expr(self, expr: "sympy.Expr"):
@@ -390,7 +414,6 @@
if len(expr.free_symbols) == 0:
return expr
expr = self.simplify(expr)
-
static_expr = self._maybe_evaluate_static(expr)
if static_expr is not None:
return static_expr
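
The FloorDiv function is the core of this hunk: floor division stays a symbolic atom until both arguments are concrete integers, or until a divisibility guard lets simplify() rewrite it as true division. A standalone sympy sketch mirroring the class added above:

import sympy

class FloorDiv(sympy.Function):
    # Same shape as the class above: keep a//b as an atom unless it can be folded.
    nargs = (2,)

    @classmethod
    def eval(cls, base, divisor):
        if base == 0:
            return sympy.Integer(0)
        if divisor == 1:
            return base
        if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
            return base // divisor

a = sympy.Symbol("a", positive=True, integer=True)

print(FloorDiv(7, 2))   # 3 -- both args concrete, eval() folds it
print(FloorDiv(a, 1))   # a -- dividing by one is a no-op
expr = FloorDiv(2 * a, 2)
print(expr)             # FloorDiv(2*a, 2) -- stays symbolic

# ShapeEnv.simplify(): once (2*a) % 2 is recorded in self.divisible, the FloorDiv atom
# is replaced by true division, which sympy then reduces to plain `a`.
base, divisor = expr.args
print(sympy.expand(expr.xreplace({expr: base / divisor})))  # a
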
diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py
index c7eb3c9..5d65f03 100644
--- a/torch/fx/proxy.py
+++ b/torch/fx/proxy.py
@@ -162,7 +162,6 @@
return a.node
elif isinstance(a, base_types) or a is None or a is ...:
return a
-
raise NotImplementedError(f"argument of type: {type(a)}")
@compatibility(is_backward_compatible=True)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 75ece1f..7f8d9b8 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -5558,7 +5558,7 @@
return samples
-def sample_inputs_matmul(op_info, device, dtype, requires_grad, **kwargs):
+def sample_inputs_matmul(op_info, device, dtype, requires_grad, is_rmatmul=False, **kwargs):
test_cases = (((L,), (L,)),
((S, M), (M,)),
((M,), (M, S)),
@@ -5577,12 +5577,10 @@
for lhs_shape, rhs_shape in test_cases:
lhs = make_tensor(lhs_shape, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad)
rhs = make_tensor(rhs_shape, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad)
- if op_info.name == 'matmul':
+ if not is_rmatmul:
sample_inputs.append(SampleInput(lhs, args=(rhs,)))
- elif op_info.name == '__rmatmul__':
- sample_inputs.append(SampleInput(rhs, args=(lhs,)))
else:
- raise RuntimeError("`op_info.name` must be 'matmul' or '__rmatmul__'")
+ sample_inputs.append(SampleInput(rhs, args=(lhs,)))
return tuple(sample_inputs)
@@ -9972,7 +9970,7 @@
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
check_batched_forward_grad=False,
- sample_inputs_func=sample_inputs_matmul,
+ sample_inputs_func=partial(sample_inputs_matmul, is_rmatmul=False),
decorators=[
# NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater),
@@ -12679,7 +12677,7 @@
*[torch.bfloat16]
if (SM53OrLater and CUDA11OrLater) or TEST_WITH_ROCM else []),
assert_autodiffed=True,
- sample_inputs_func=sample_inputs_matmul,
+ sample_inputs_func=partial(sample_inputs_matmul, is_rmatmul=True),
# Runs very slowly on slow gradcheck - alternatively reduce input sizes
gradcheck_fast_mode=True,
supports_out=False,
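
Finally, the OpInfo change above shares one sample-input function between matmul and __rmatmul__; with hypothetical stand-in names, the partial pattern reduces to:

from functools import partial

# Hypothetical stand-in for sample_inputs_matmul: the flag only swaps operand order.
def make_pair(lhs, rhs, is_rmatmul=False):
    return (rhs, lhs) if is_rmatmul else (lhs, rhs)

matmul_samples = partial(make_pair, is_rmatmul=False)   # the matmul OpInfo
rmatmul_samples = partial(make_pair, is_rmatmul=True)    # the __rmatmul__ OpInfo

assert matmul_samples("L", "R") == ("L", "R")
assert rmatmul_samples("L", "R") == ("R", "L")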