| # Owner(s): ["module: unknown"] |
| |
| from functools import partial |
| from textwrap import dedent |
| |
| import torch |
| from torch.testing import FileCheck |
| from torch.testing._internal.common_device_type import ( |
| instantiate_device_type_tests, |
| OpDTypes, |
| ops, |
| ) |
| from torch.testing._internal.common_jit import ( |
| check_against_reference, |
| JitCommonTestCase, |
| ) |
| from torch.testing._internal.common_methods_invocations import op_db |
| from torch.testing._internal.common_utils import ( |
| clone_input_helper, |
| first_sample, |
| IS_SANDCASTLE, |
| run_tests, |
| TestCase, |
| unMarkDynamoStrictTest, |
| ) |
| from torch.testing._internal.jit_metaprogramming_utils import ( |
| check_alias_annotation, |
| create_script_fn, |
| create_traced_fn, |
| ) |
| from torch.testing._internal.jit_utils import ( |
| disable_autodiff_subgraph_inlining, |
| is_lambda, |
| ) |
| |
| |
| # variant testing is only done with torch.float and torch.cfloat to avoid
| # excessive test times and maximize the signal-to-noise ratio
| _variant_ops = partial( |
| ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float, torch.cfloat) |
| ) |
| |
| |
| # Tests operators for consistency between JIT and eager, and also checks
| # the correctness of JIT-specific alias schemas and the intended
| # autodifferentiation behavior.
| # Inherits from JitCommonTestCase instead of TestCase directly to share
| # functionality with the original test_jit.py method operator tests.
| @unMarkDynamoStrictTest |
| class TestJit(JitCommonTestCase): |
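| # exact_dtype=True makes assertEqual comparisons in this class require matching dtypes, not just matching values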
| exact_dtype = True |
| |
| # Tests that the forward and backward passes of operations produce the |
| # same values for the cross-product of op variants (function, method, inplace) |
| # and runtimes (eager, traced, scripted). |
| # TODO WARNING: inplace x {traced, scripted} not currently tested |
| @_variant_ops(op_db) |
| def test_variant_consistency_jit(self, device, dtype, op): |
| _requires_grad = dtype in op.supported_backward_dtypes( |
| torch.device(device).type |
| ) |
| |
| include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex |
| samples = op.sample_inputs( |
| device, |
| dtype, |
| requires_grad=_requires_grad, |
| include_conjugated_inputs=include_conjugated_inputs, |
| ) |
| |
| # Acquires variants to test |
| func = op.get_op() |
| method = op.get_method() |
| variants = { |
| # TODO: inplace tests currently fail, fix and add inplace variant |
| "function": func, |
| "method": method, |
| } |
| |
| # scripting incorrectly strips the torch.ops prefix from these operators,
| # so don't bother testing this case; count the skip as coverage
| if isinstance(func, torch._ops.OpOverload): |
| self.skipTest("variant consistency doesn't work on torch.ops") |
| |
| # TODO: find a better way to standardize on op registration itself
| has_fake_function = op.name in ["resize_", "resize_as_"] |
| |
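| # When an op is backed by a fake function (resize_, resize_as_), only its
| # method variant is tested, on samples created without requires_grad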
| if has_fake_function: |
| variants = {"method": getattr(torch.Tensor, op.name)} |
| samples = op.sample_inputs(device, dtype, requires_grad=False) |
| |
| tested = False |
| for sample in samples: |
| # Test traced and scripted consistency |
| for func_type, variant in variants.items(): |
| if variant is None: |
| continue |
| |
| # scripting and check_alias_annotation do not work with lambdas;
| # lambdas are typically used to simulate functional variants for ops
| # that only exist as methods, so rely on the other variant for testing
| # for now
| if is_lambda(variant): |
| continue |
| |
| tested = True |
| try: |
| self.indiv_variant_test_jit( |
| device, dtype, op, sample, func_type, variant, has_fake_function |
| ) |
| except Exception as e: |
| variant_error_info = dedent( |
| f""" |
| Error testing {op.name} {func_type} variant |
| with dtype: {dtype} |
| with inputs {sample}: |
| """ |
| ) |
| raise Exception(variant_error_info) from e # noqa: TRY002 |
| |
| assert tested, "JIT Test does not execute any logic" |
| |
| def indiv_variant_test_jit( |
| self, device, dtype, op, sample, func_type, variant, has_fake_function |
| ): |
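| # Runs the scripted/traced vs. eager consistency checks, the alias annotation
| # check, the shape analysis check, and the autodiff node check for a single
| # (sample, variant) pair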
| _requires_grad = dtype in op.supported_backward_dtypes( |
| torch.device(device).type |
| ) |
| support_script = op.supports_scripting |
| # Name used to look up the script function variant (inplace variants carry a trailing underscore)
| name = op.name + "_" if func_type == "inplace" else op.name |
| |
| # run with disable_autodiff_subgraph_inlining(True) to test
| # autodiff support. The context manager keeps DifferentiableGraph
| # nodes in the graph (when autodiff creates them) instead of inlining them
| with disable_autodiff_subgraph_inlining(): |
| # Check scripted forward, grad, and grad grad |
| if support_script: |
| script_fn = create_script_fn(self, name, func_type) |
| |
| def out_fn(output): |
| # Processes the output for autograd |
| if sample.output_process_fn_grad is not None: |
| return sample.output_process_fn_grad(output) |
| return output |
| |
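| # Inplace ops (names ending in "_") mutate their input, so hand each check a fresh clone of the sample input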
| def get_sample(): |
| return ( |
| clone_input_helper(sample.input) |
| if op.name[-1] == "_" |
| else sample.input |
| ) |
| |
| if support_script: |
| check_against_reference( |
| self, |
| script_fn, |
| op.get_op(), |
| out_fn, |
| (get_sample(),) + sample.args, |
| sample.kwargs, |
| no_grad=not _requires_grad, |
| no_gradgrad=not op.supports_gradgrad, |
| ) |
| |
| # Check traced forward, grad, and grad grad |
| # TODO: fix tracing here |
| supports_tracing = op.supports_tracing and not has_fake_function |
| if op.assert_jit_shape_analysis: |
| self.assertTrue(supports_tracing) |
| |
| if supports_tracing: |
| traced_fn = create_traced_fn(self, variant) |
| check_against_reference( |
| self, |
| traced_fn, |
| op.get_op(), |
| out_fn, |
| (get_sample(),) + sample.args, |
| sample.kwargs, |
| no_grad=not _requires_grad, |
| no_gradgrad=not op.supports_gradgrad, |
| ) |
| |
| # Check alias annotation schema for correctness (make
| # sure inputs that aren't supposed to be modified aren't)
| # Note: only runs in float32 because the schema isn't affected by dtype,
| # so running it on all dtypes would be excessive
| if dtype == torch.float32: |
| # TODO: no reason why we can't run this with the traced graph
| if support_script and op.name != "rsub": |
| check_alias_annotation( |
| name, |
| (get_sample(),) + sample.args, |
| sample.kwargs, |
| func_type=func_type, |
| aten_name=op.aten_name, |
| ) |
| |
| # TODO: use script graph as well |
| checked_shape_analysis = False |
| if supports_tracing: |
| out = variant(get_sample(), *sample.args, **sample.kwargs) |
| |
| # right now, only a single tensor output or a tuple of tensor outputs is supported
| # TODO: list of tensor outputs
| tuple_of_tensors = isinstance(out, tuple) and all( |
| isinstance(elem, torch.Tensor) for elem in out |
| ) |
| |
| if isinstance(out, torch.Tensor) or tuple_of_tensors: |
| if tuple_of_tensors: |
| sizes = [elem.size() for elem in out] |
| else: |
| sizes = out.size() |
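| # checkShapeAnalysis (from JitCommonTestCase) runs the JIT shape propagation
| # pass on the traced graph and, when op.assert_jit_shape_analysis is set,
| # verifies the propagated output sizes match the eager sizes computed above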
| self.checkShapeAnalysis( |
| sizes, traced_fn.graph, op.assert_jit_shape_analysis |
| ) |
| checked_shape_analysis = True |
| if op.assert_jit_shape_analysis: |
| self.assertTrue(checked_shape_analysis) |
| |
| # Check autodifferentiation of nodes for traced and scripted graphs; only need to check once per sample
| if dtype is torch.float32: |
| # Sandcastle doesn't fuse nodes |
| if IS_SANDCASTLE: |
| # without fusion, nodes that would normally appear in FusionGroups inside
| # the DifferentiableGraphs show up as regular nodes, so count them as nonfusible
| nonfusible_nodes = ( |
| op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes |
| ) |
| fusible_nodes = [] |
| else: |
| nonfusible_nodes = op.autodiff_nonfusible_nodes |
| fusible_nodes = op.autodiff_fusible_nodes |
| |
| if supports_tracing: |
| self.assertAutodiffNode( |
| traced_fn.last_graph, |
| op.assert_autodiffed, |
| nonfusible_nodes, |
| fusible_nodes, |
| ) |
| if support_script: |
| self.assertAutodiffNode( |
| script_fn.last_graph, |
| op.assert_autodiffed, |
| nonfusible_nodes, |
| fusible_nodes, |
| ) |
| |
| # alias testing is only done with torch.float for the same reason |
| _alias_ops = partial(ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float,)) |
| |
| @_alias_ops(op for op in op_db if op.aliases) |
| def test_jit_alias_remapping(self, device, dtype, op): |
| # NOTE: only tests on first sample |
| samples = op.sample_inputs(device, dtype, requires_grad=True) |
| sample = first_sample(self, samples) |
| |
| # [Scripting Data Preparation] |
| # Prepare data for test scripting |
| # Below we prepare strings of args/kwargs with and without type annotations. |
| # These strings are inserted into function template strings, which are then scripted with TorchScript.
| # - args is ["t0"], corresponding to the "input" tensor required by the op
| # - args_kw is the value of args plus strings of kwargs used to call the op (without type annotations), for example,
| #   ["t0", "1.0", "(1,)", "True", "tensor(1.0)"] -> def fn(t0): return variant(t0, 1.0, (1,), True, tensor(1.0))
| args = ["t0"] |
| |
| def quote_strs(v): |
| if isinstance(v, str): |
| return f"'{v}'" |
| |
| return str(v) |
| |
| args_kw = ( |
| args |
| + [f"{v}" for v in sample.args] |
| + [f"{k}={quote_strs(v)}" for k, v in sample.kwargs.items()] |
| ) |
| |
| # Prepare data for test tracing |
| sample_args_kwargs = () |
| if len(sample.args) > 0: |
| sample_args_kwargs += (sample.args,) |
| if len(sample.kwargs) > 0: |
| sample_args_kwargs += (sample.kwargs,) |
| |
| original_name = op.aten_name |
| original_name_inplace = original_name + "_" |
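| # run the op eagerly once to learn the result dtype; used below to detect
| # inplace calls that would require an unsupported type promotion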
| expected_dtype = op(sample.input, *sample.args, **sample.kwargs).dtype |
| |
| for a_op in op.aliases: |
| inplace = a_op.inplace_variant |
| method_or_inplace = [a_op.inplace_variant, a_op.method_variant] |
| # materialize the variants as a list so they can be iterated once for
| # scripting and once for tracing below (a generator would be exhausted
| # after the first loop)
| variants = [
| v
| for v in (a_op.op, a_op.method_variant, a_op.inplace_variant)
| if v is not None
| ]
| |
| # Test scripting: |
| for variant in variants: |
| variant_name = variant.__name__ |
| op_name = original_name_inplace if variant is inplace else original_name |
| |
| if variant in method_or_inplace: |
| fn_template = """ |
| def _fn(t0{c}): |
| return t0.{alias_name}({args_kw}) |
| """ |
| # the first input tensor becomes the method receiver (t0), so drop it from the call args
| script = fn_template.format( |
| c=", " if len(args_kw[1:]) > 1 else "", |
| args_kw=", ".join(args_kw[1:]), |
| alias_name=variant_name, |
| ) |
| else: |
| fn_template = """ |
| def _fn({args}): |
| return variant({args_kw}) |
| """ |
| script = fn_template.format( |
| args=", ".join(args), |
| args_kw=", ".join(args_kw), |
| ) |
| |
| # Required to avoid an "undefined value: tensor" error during JIT
| # compilation of the function template
| script = script.replace("tensor(", "torch.tensor(") |
| |
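| # compile the generated source as a TorchScript CompilationUnit and fetch the defined _fn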
| scripted = torch.jit.CompilationUnit(script)._fn |
| |
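| # an inplace variant cannot write a promoted result dtype back into its
| # input, so the scripted call is expected to raise here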
| if variant is inplace and not torch.can_cast(expected_dtype, dtype): |
| try: |
| inp = clone_input_helper(sample.input) |
| scripted(inp) |
| except Exception:
| continue |
| self.fail( |
| "Inplace operation on integer tensor that should be promoted to float didn't fail!" |
| ) |
| |
| inp = clone_input_helper(sample.input) |
| scripted(inp) |
| inp = clone_input_helper(sample.input) |
| graph = scripted.graph_for(inp) |
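| # the alias should be remapped to the original aten op: the graph must
| # contain op.aten_name and must not mention the alias name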
| FileCheck().check(op.aten_name).check_not(variant_name).run(graph) |
| |
| # Test tracing: |
| for variant in variants: |
| variant_name = variant.__name__ |
| op_name = original_name_inplace if variant is inplace else original_name |
| |
| def _fn(*sample_args, **sample_kwargs): |
| return variant(*sample_args, **sample_kwargs) |
| |
| inp = (clone_input_helper(sample.input),) + sample_args_kwargs |
| traced = torch.jit.trace(_fn, *inp) |
| inp = (clone_input_helper(sample.input),) + sample_args_kwargs |
| traced(*inp) |
| inp = (clone_input_helper(sample.input),) + sample_args_kwargs |
| graph = traced.graph_for(*inp) |
| FileCheck().check(op_name).check_not(variant_name).run(graph) |
| |
| |
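| # create device-specific copies of TestJit (e.g. TestJitCPU, TestJitCUDA) in this module's namespace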
| instantiate_device_type_tests(TestJit, globals()) |
| |
| if __name__ == "__main__": |
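| # enable TestCase's default dtype check (intended to catch tests that change torch's global default dtype)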
| TestCase._default_dtype_check_enabled = True |
| run_tests() |