| import functools |
| import unittest |
| from unittest.mock import patch |
| |
| import torch |
| |
| |
| aten = torch.ops.aten |
| |
| # This list is not meant to be comprehensive |
| _COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY = [ |
| aten.arctan2.default, |
| aten.divide.Tensor, |
| aten.divide.Scalar, |
| aten.divide.Tensor_mode, |
| aten.divide.Scalar_mode, |
| aten.multiply.Tensor, |
| aten.multiply.Scalar, |
| aten.subtract.Tensor, |
| aten.subtract.Scalar, |
| aten.true_divide.Tensor, |
| aten.true_divide.Scalar, |
| aten.greater.Tensor, |
| aten.greater.Scalar, |
| aten.greater_equal.Tensor, |
| aten.greater_equal.Scalar, |
| aten.less_equal.Tensor, |
| aten.less_equal.Scalar, |
| aten.less.Tensor, |
| aten.less.Scalar, |
| aten.not_equal.Tensor, |
| aten.not_equal.Scalar, |
| aten.cat.names, |
| aten.sum.dim_DimnameList, |
| aten.mean.names_dim, |
| aten.prod.dim_Dimname, |
| aten.all.dimname, |
| aten.norm.names_ScalarOpt_dim, |
| aten.norm.names_ScalarOpt_dim_dtype, |
| aten.var.default, |
| aten.var.dim, |
| aten.var.names_dim, |
| aten.var.correction_names, |
| aten.std.default, |
| aten.std.dim, |
| aten.std.names_dim, |
| aten.std.correction_names, |
| aten.absolute.default, |
| aten.arccos.default, |
| aten.arccosh.default, |
| aten.arcsin.default, |
| aten.arcsinh.default, |
| aten.arctan.default, |
| aten.arctanh.default, |
| aten.clip.default, |
| aten.clip.Tensor, |
| aten.fix.default, |
| aten.negative.default, |
| aten.square.default, |
| aten.size.int, |
| aten.size.Dimname, |
| aten.stride.int, |
| aten.stride.Dimname, |
| aten.repeat_interleave.self_Tensor, |
| aten.repeat_interleave.self_int, |
| aten.sym_size.int, |
| aten.sym_stride.int, |
| aten.atleast_1d.Sequence, |
| aten.atleast_2d.Sequence, |
| aten.atleast_3d.Sequence, |
| aten.linear.default, |
| aten.conv2d.default, |
| aten.conv2d.padding, |
| aten.mish_backward.default, |
| aten.silu_backward.default, |
| aten.index_add.dimname, |
| aten.pad_sequence.default, |
| aten.index_copy.dimname, |
| aten.upsample_nearest1d.vec, |
| aten.upsample_nearest2d.vec, |
| aten.upsample_nearest3d.vec, |
| aten._upsample_nearest_exact1d.vec, |
| aten._upsample_nearest_exact2d.vec, |
| aten._upsample_nearest_exact3d.vec, |
| aten.rnn_tanh.input, |
| aten.rnn_tanh.data, |
| aten.rnn_relu.input, |
| aten.rnn_relu.data, |
| aten.lstm.input, |
| aten.lstm.data, |
| aten.gru.input, |
| aten.gru.data, |
| aten._upsample_bilinear2d_aa.vec, |
| aten._upsample_bicubic2d_aa.vec, |
| aten.upsample_bilinear2d.vec, |
| aten.upsample_trilinear3d.vec, |
| aten.upsample_linear1d.vec, |
| aten.matmul.default, |
| aten.upsample_bicubic2d.vec, |
| aten.__and__.Scalar, |
| aten.__and__.Tensor, |
| aten.__or__.Tensor, |
| aten.__or__.Scalar, |
| aten.__xor__.Tensor, |
| aten.__xor__.Scalar, |
| aten.scatter.dimname_src, |
| aten.scatter.dimname_value, |
| aten.scatter_add.dimname, |
| aten.is_complex.default, |
| aten.logsumexp.names, |
| aten.where.ScalarOther, |
| aten.where.ScalarSelf, |
| aten.where.Scalar, |
| aten.where.default, |
| aten.item.default, |
| aten.any.dimname, |
| aten.std_mean.default, |
| aten.std_mean.dim, |
| aten.std_mean.names_dim, |
| aten.std_mean.correction_names, |
| aten.var_mean.default, |
| aten.var_mean.dim, |
| aten.var_mean.names_dim, |
| aten.var_mean.correction_names, |
| aten.broadcast_tensors.default, |
| aten.stft.default, |
| aten.stft.center, |
| aten.istft.default, |
| aten.index_fill.Dimname_Scalar, |
| aten.index_fill.Dimname_Tensor, |
| aten.index_select.dimname, |
| aten.diag.default, |
| aten.cumsum.dimname, |
| aten.cumprod.dimname, |
| aten.meshgrid.default, |
| aten.meshgrid.indexing, |
| aten.fft_fft.default, |
| aten.fft_ifft.default, |
| aten.fft_rfft.default, |
| aten.fft_irfft.default, |
| aten.fft_hfft.default, |
| aten.fft_ihfft.default, |
| aten.fft_fftn.default, |
| aten.fft_ifftn.default, |
| aten.fft_rfftn.default, |
| aten.fft_ihfftn.default, |
| aten.fft_irfftn.default, |
| aten.fft_hfftn.default, |
| aten.fft_fft2.default, |
| aten.fft_ifft2.default, |
| aten.fft_rfft2.default, |
| aten.fft_irfft2.default, |
| aten.fft_hfft2.default, |
| aten.fft_ihfft2.default, |
| aten.fft_fftshift.default, |
| aten.fft_ifftshift.default, |
| aten.selu.default, |
| aten.margin_ranking_loss.default, |
| aten.hinge_embedding_loss.default, |
| aten.nll_loss.default, |
| aten.prelu.default, |
| aten.relu6.default, |
| aten.pairwise_distance.default, |
| aten.pdist.default, |
| aten.special_ndtr.default, |
| aten.cummax.dimname, |
| aten.cummin.dimname, |
| aten.logcumsumexp.dimname, |
| aten.max.other, |
| aten.max.names_dim, |
| aten.min.other, |
| aten.min.names_dim, |
| aten.linalg_eigvals.default, |
| aten.median.names_dim, |
| aten.nanmedian.names_dim, |
| aten.mode.dimname, |
| aten.gather.dimname, |
| aten.sort.dimname, |
| aten.sort.dimname_stable, |
| aten.argsort.default, |
| aten.argsort.dimname, |
| aten.rrelu.default, |
| aten.conv_transpose1d.default, |
| aten.conv_transpose2d.input, |
| aten.conv_transpose3d.input, |
| aten.conv1d.default, |
| aten.conv1d.padding, |
| aten.conv3d.default, |
| aten.conv3d.padding, |
| aten.float_power.Tensor_Tensor, |
| aten.float_power.Tensor_Scalar, |
| aten.float_power.Scalar, |
| aten.ldexp.Tensor, |
| aten._version.default, |
| ] |
| |
| |
| def make_test_cls_with_mocked_export( |
| cls, cls_prefix, fn_suffix, mocked_export_fn, xfail_prop=None |
| ): |
| MockedTestClass = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {}) |
| MockedTestClass.__qualname__ = MockedTestClass.__name__ |
| |
| for name in dir(cls): |
| if name.startswith("test_"): |
| fn = getattr(cls, name) |
| if not callable(fn): |
| setattr(MockedTestClass, name, getattr(cls, name)) |
| continue |
| new_name = f"{name}{fn_suffix}" |
| new_fn = _make_fn_with_mocked_export(fn, mocked_export_fn) |
| new_fn.__name__ = new_name |
| if xfail_prop is not None and hasattr(fn, xfail_prop): |
| new_fn = unittest.expectedFailure(new_fn) |
| setattr(MockedTestClass, new_name, new_fn) |
| # NB: Doesn't handle slots correctly, but whatever |
| elif not hasattr(MockedTestClass, name): |
| setattr(MockedTestClass, name, getattr(cls, name)) |
| |
| return MockedTestClass |
| |
| |
| def _make_fn_with_mocked_export(fn, mocked_export_fn): |
| @functools.wraps(fn) |
| def _fn(*args, **kwargs): |
| try: |
| from . import test_export |
| except ImportError: |
| import test_export |
| |
| with patch(f"{test_export.__name__}.export", mocked_export_fn): |
| return fn(*args, **kwargs) |
| |
| return _fn |
| |
| |
| # Controls tests generated in test/export/test_export_training_ir_to_run_decomp.py |
| def expectedFailureTrainingIRToRunDecomp(fn): |
| fn._expected_failure_training_ir_to_run_decomp = True |
| return fn |
| |
| |
| # Controls tests generated in test/export/test_export_training_ir_to_run_decomp.py |
| def expectedFailureTrainingIRToRunDecompNonStrict(fn): |
| fn._expected_failure_training_ir_to_run_decomp_non_strict = True |
| return fn |
| |
| |
| # Controls tests generated in test/export/test_export_nonstrict.py |
| def expectedFailureNonStrict(fn): |
| fn._expected_failure_non_strict = True |
| return fn |
| |
| |
| # Controls tests generated in test/export/test_retraceability.py |
| def expectedFailureRetraceability(fn): |
| fn._expected_failure_retrace = True |
| return fn |
| |
| |
| # Controls tests generated in test/export/test_serdes.py |
| def expectedFailureSerDer(fn): |
| fn._expected_failure_serdes = True |
| return fn |
| |
| |
| def expectedFailureSerDerPreDispatch(fn): |
| fn._expected_failure_serdes_pre_dispatch = True |
| return fn |
| |
| |
| def expectedFailurePreDispatchRunDecomp(fn): |
| fn._expected_failure_pre_dispatch = True |
| return fn |