| import dataclasses |
| import typing |
| import unittest |
| from collections import defaultdict |
| from typing import Dict, List |
| |
| import torchgen.model |
| |
| import yaml |
| |
| from tools.autograd import gen_autograd_functions, load_derivatives |
| from torchgen import dest |
| from torchgen.api.types import CppSignatureGroup, DispatcherSignature |
| from torchgen.context import native_function_manager |
| from torchgen.gen import ( |
| get_native_function_declarations, |
| get_native_function_schema_registrations, |
| LineLoader, |
| static_dispatch, |
| ) |
| from torchgen.model import ( |
| BackendIndex, |
| BackendMetadata, |
| DispatchKey, |
| Location, |
| NativeFunction, |
| OperatorName, |
| ) |
| from torchgen.native_function_generation import add_generated_native_functions |
| from torchgen.selective_build.selector import SelectiveBuilder |
| |
| |
class TestCreateDerivative(unittest.TestCase):
    """Exercises load_derivatives on formulas using named and indexed gradients."""

    def test_named_grads(self) -> None:
        # A formula spelled with grad_<return name> should report those names.
        two_output_schema = torchgen.model.FunctionSchema.parse(
            "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
        )
        fn = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=two_output_schema)

        result = load_derivatives.create_derivative(
            fn,
            formula="func_backward(grad_x, grad_y)",
            var_names=(),
            available_named_gradients=["grad_x", "grad_y"],
        )

        self.assertSetEqual(result.named_gradients, {"grad_x", "grad_y"})

    def test_non_differentiable_output(self) -> None:
        spec = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
        parsed = torchgen.model.FunctionSchema.parse(spec)
        fn = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=parsed)
        definition = {
            "name": spec,
            "dispatch": {"Default": {"a": "grads[0]", "b": "grads[2]"}},
        }

        _, info_by_key = load_derivatives.create_differentiability_info(
            defn_dict=definition,
            functions_by_signature={parsed.signature(): [fn]},
            functions_by_schema={spec: fn},
            op_counter=typing.Counter[str](),
            used_dispatch_keys=set(),
        )

        # y is a bool and thus not differentiable, so grad_y must be
        # absent from the available named gradients.
        expected_names = ["grad_x", "grad_z"]
        self.assertSequenceEqual(
            info_by_key["Default"].available_named_gradients,
            expected_names,
        )

    def test_indexed_grads(self) -> None:
        # A formula that only uses grads[i] should report no named gradients.
        two_output_schema = torchgen.model.FunctionSchema.parse(
            "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
        )
        fn = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=two_output_schema)

        result = load_derivatives.create_derivative(
            fn,
            formula="func_backward(grads[0], grads[1])",
            var_names=(),
            available_named_gradients=["grad_x", "grad_y"],
        )

        self.assertSetEqual(result.named_gradients, set())

    def test_named_grads_and_indexed_grads(self) -> None:
        spec = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
        parsed = torchgen.model.FunctionSchema.parse(spec)
        fn = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=parsed)
        # Uh-oh: "a" refers to a gradient by name while "b" refers to
        # one by index. Mixing the two styles is expected to be rejected.
        mixed_definition = {
            "name": spec,
            "dispatch": {
                "Default": {
                    "a": "grad_x",
                    "b": "grads[1]",
                }
            },
        }

        with self.assertRaisesRegex(
            RuntimeError, 'illegally mixes use of "grad_RETURN_NAME"'
        ):
            load_derivatives.create_differentiability_info(
                defn_dict=mixed_definition,
                functions_by_signature={parsed.signature(): [fn]},
                functions_by_schema={spec: fn},
                op_counter=typing.Counter[str](),
                used_dispatch_keys=set(),
            )
| |
| |
class TestGenAutogradFunctions(unittest.TestCase):
    """Tests for the text produced by gen_autograd_functions.process_function.

    Each test builds differentiability info for a small schema through
    load_derivatives.create_differentiability_info and inspects the result.
    """

    def test_non_differentiable_output_invalid_type(self) -> None:
        specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
        schema = torchgen.model.FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

        _, differentiability_info = load_derivatives.create_differentiability_info(
            defn_dict={
                "name": specification,
                "dispatch": {
                    "Default": {
                        "a": "grad_x",
                        "b": "grad_z",
                    }
                },
            },
            functions_by_signature={schema.signature(): [native_function]},
            functions_by_schema={specification: native_function},
            op_counter=typing.Counter[str](),
            used_dispatch_keys=set(),
        )
        definition = gen_autograd_functions.process_function(
            differentiability_info["Default"],
            gen_autograd_functions.FUNCTION_DEFINITION,
        )
        # grad_z should map to grads[1], not grads[2] because output 1
        # (y) is not differentiable.
        # TestCase assertions are used instead of bare ``assert`` so the
        # checks survive ``python -O`` and report the generated text on failure.
        self.assertNotIn("grad_z = grads[2]", definition)
        self.assertIn("grad_z = grads[1]", definition)

    def test_non_differentiable_output_output_differentiability(self) -> None:
        specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y, Tensor z)"
        schema = torchgen.model.FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

        _, differentiability_info = load_derivatives.create_differentiability_info(
            defn_dict={
                "name": specification,
                "dispatch": {
                    "Default": {
                        "a": "grad_x",
                        "b": "grad_z",
                    },
                    "AutogradNestedTensor": {
                        "a": "grad_z",
                        "b": "grad_x",
                    },
                },
                # Here y is a Tensor; it is marked non-differentiable
                # explicitly rather than through its type.
                "output_differentiability": [True, False, True],
            },
            functions_by_signature={schema.signature(): [native_function]},
            functions_by_schema={specification: native_function},
            op_counter=typing.Counter[str](),
            used_dispatch_keys=set(),
        )
        default_definition = gen_autograd_functions.process_function(
            differentiability_info["Default"],
            gen_autograd_functions.FUNCTION_DEFINITION,
        )
        # grad_z should map to grads[1], not grads[2] because output 1
        # (y) is not differentiable.
        self.assertNotIn("grad_z = grads[2]", default_definition)
        self.assertIn("grad_z = grads[1]", default_definition)

        # The same mapping is expected for the AutogradNestedTensor entry.
        nested_tensor_definition = gen_autograd_functions.process_function(
            differentiability_info["AutogradNestedTensor"],
            gen_autograd_functions.FUNCTION_DEFINITION,
        )
        self.assertNotIn("grad_z = grads[2]", nested_tensor_definition)
        self.assertIn("grad_z = grads[1]", nested_tensor_definition)

    def test_register_bogus_dispatch_key(self) -> None:
        specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
        schema = torchgen.model.FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

        # "AutogradRandomTensor" is the bogus key; the loader is expected to
        # reject it with a RuntimeError that names the key.
        with self.assertRaisesRegex(
            RuntimeError,
            "Invalid dispatch key AutogradRandomTensor in derivatives.yaml for",
        ):
            load_derivatives.create_differentiability_info(
                defn_dict={
                    "name": specification,
                    "dispatch": {
                        "Default": {
                            "a": "grad_x",
                            "b": "grad_z",
                        },
                        "AutogradRandomTensor": {
                            "a": "grad_x",
                            "b": "grad_z",
                        },
                    },
                },
                functions_by_signature={schema.signature(): [native_function]},
                functions_by_schema={specification: native_function},
                op_counter=typing.Counter[str](),
                used_dispatch_keys=set(),
            )
| |
| |
class TestGenSchemaRegistration(unittest.TestCase):
    """Tests for get_native_function_schema_registrations across namespaces.

    The function returns a pair; the tests below compare the first element
    against a list of strings and the second against one multi-line string.
    """

    # NOTE(review): leading whitespace inside the expected multi-line literals
    # is significant. The two-space indent before ``m.def`` is presumed to
    # match the generator's output -- confirm against torchgen.gen.

    def setUp(self) -> None:
        self.selector = SelectiveBuilder.get_nop_selector()

        # Function living in a "custom" namespace.
        custom_function, _ = torchgen.model.NativeFunction.from_yaml(
            {"func": "custom::func() -> bool"},
            loc=torchgen.model.Location(__file__, 1),
            valid_tags=set(),
        )
        self.custom_native_function = custom_function

        # Function living in the "quantized_decomposed" namespace.
        fragment_function, _ = torchgen.model.NativeFunction.from_yaml(
            {"func": "quantized_decomposed::func() -> bool"},
            loc=torchgen.model.Location(__file__, 1),
            valid_tags=set(),
        )
        self.fragment_custom_native_function = fragment_function

    def test_default_namespace_schema_registration_code_valid(self) -> None:
        functions = [DEFAULT_NATIVE_FUNCTION]
        default_ns_registrations, _ = get_native_function_schema_registrations(
            native_functions=functions,
            schema_selector=self.selector,
        )
        expected = ['m.def("func() -> bool", {});\n']
        self.assertEqual(default_ns_registrations, expected)

    def test_custom_namespace_schema_registration_code_valid(self) -> None:
        expected = """
TORCH_LIBRARY(custom, m) {
  m.def("func() -> bool", {});

};"""
        _, custom_ns_registrations = get_native_function_schema_registrations(
            native_functions=[self.custom_native_function],
            schema_selector=self.selector,
        )
        self.assertEqual(custom_ns_registrations, expected)

    def test_fragment_custom_namespace_schema_registration_code_valid(self) -> None:
        """Sometimes we want to extend an existing namespace, for example quantized
        namespace, which is already defined in native/quantized/library.cpp
        """
        expected = """
TORCH_LIBRARY_FRAGMENT(quantized_decomposed, m) {
  m.def("func() -> bool", {});

};"""
        _, fragment_registrations = get_native_function_schema_registrations(
            native_functions=[self.fragment_custom_native_function],
            schema_selector=self.selector,
        )
        self.assertEqual(fragment_registrations, expected)

    def test_mixed_namespace_schema_registration_code_valid(self) -> None:
        expected_custom = """
TORCH_LIBRARY(custom, m) {
  m.def("func() -> bool", {});

};"""
        functions = [DEFAULT_NATIVE_FUNCTION, self.custom_native_function]
        registration_pair = get_native_function_schema_registrations(
            native_functions=functions,
            schema_selector=self.selector,
        )
        aten_registrations, custom_registrations = registration_pair

        self.assertEqual(aten_registrations, ['m.def("func() -> bool", {});\n'])
        self.assertEqual(custom_registrations, expected_custom)

    def test_3_namespaces_schema_registration_code_valid(self) -> None:
        # A second custom namespace on top of the default and "custom" ones.
        custom2_native_function, _ = torchgen.model.NativeFunction.from_yaml(
            {"func": "custom2::func() -> bool"},
            loc=torchgen.model.Location(__file__, 1),
            valid_tags=set(),
        )
        functions = [
            DEFAULT_NATIVE_FUNCTION,
            self.custom_native_function,
            custom2_native_function,
        ]
        expected_custom = """
TORCH_LIBRARY(custom, m) {
  m.def("func() -> bool", {});

};
TORCH_LIBRARY(custom2, m) {
  m.def("func() -> bool", {});

};"""
        registration_pair = get_native_function_schema_registrations(
            native_functions=functions,
            schema_selector=self.selector,
        )
        aten_registrations, custom_registrations = registration_pair

        self.assertEqual(aten_registrations, ['m.def("func() -> bool", {});\n'])
        self.assertEqual(custom_registrations, expected_custom)
| |
| |
class TestGenNativeFunctionDeclaration(unittest.TestCase):
    """Tests for get_native_function_declarations on one and on two operators."""

    def setUp(self) -> None:
        # op_1 registers a single CPU kernel.
        op_1_entry = {"func": "op_1() -> bool", "dispatch": {"CPU": "kernel_1"}}
        self.op_1_native_function, op_1_index = NativeFunction.from_yaml(
            op_1_entry,
            loc=torchgen.model.Location(__file__, 1),
            valid_tags=set(),
        )

        # op_2 registers a CPU kernel plus a QuantizedCPU kernel whose name
        # carries a "custom::" prefix.
        op_2_entry = {
            "func": "op_2() -> bool",
            "dispatch": {"CPU": "kernel_2", "QuantizedCPU": "custom::kernel_3"},
        }
        self.op_2_native_function, op_2_index = NativeFunction.from_yaml(
            op_2_entry,
            loc=torchgen.model.Location(__file__, 1),
            valid_tags=set(),
        )

        # Merge both per-operator indices into one mapping per dispatch key.
        merged: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = {
            DispatchKey.CPU: {},
            DispatchKey.QuantizedCPU: {},
        }
        for parsed_index in (op_1_index, op_2_index):
            BackendIndex.grow_index(merged, parsed_index)

        self.backend_indices = {}
        for key, metadata in merged.items():
            self.backend_indices[key] = BackendIndex(
                dispatch_key=key,
                use_out_as_primary=True,
                external=False,
                device_guard=False,
                index=metadata,
            )

    def test_native_function_declaration_1_op_2_ns_error(self) -> None:
        # NOTE(review): presumably the AssertionError comes from op_2's kernels
        # spanning two namespaces ("custom::kernel_3") -- confirm in torchgen.
        both_ops = [
            self.op_1_native_function,
            self.op_2_native_function,
        ]
        with self.assertRaises(AssertionError):
            get_native_function_declarations(
                grouped_native_functions=both_ops,
                backend_indices=self.backend_indices,
                native_function_decl_gen=dest.compute_native_function_declaration,
            )

    def test_native_function_declaration_1_op_1_ns_valid(self) -> None:
        self.assertIsInstance(self.op_1_native_function, NativeFunction)
        lines = get_native_function_declarations(
            grouped_native_functions=[
                self.op_1_native_function,
            ],
            backend_indices=self.backend_indices,
            native_function_decl_gen=dest.compute_native_function_declaration,
        )
        # NOTE(review): the indentation of the closing quotes is part of the
        # literal (it ends with a newline plus that indent), so it presumably
        # has to match the generator's trailing whitespace -- confirm before
        # re-indenting this block.
        target = """
namespace at {
namespace native {
TORCH_API bool kernel_1();
} // namespace native
} // namespace at
        """
        joined = "\n".join(lines)
        self.assertEqual(joined, target)
| |
| |
# Test for native_function_generation
class TestNativeFunctionGeneratrion(unittest.TestCase):
    """Checks the out variants appended by add_generated_native_functions."""

    def setUp(self) -> None:
        self.native_functions: List[NativeFunction] = []
        self.backend_indices: Dict[
            DispatchKey, Dict[OperatorName, BackendMetadata]
        ] = defaultdict(dict)

        # Single-return operator, written as YAML text and parsed with
        # LineLoader; the nesting under "dispatch" is what makes the entry a
        # one-element list of mappings.
        one_return_yaml = """
- func: op(Tensor self) -> Tensor
  dispatch:
    CompositeExplicitAutograd: op
  autogen: op.out
        """
        entries = yaml.load(one_return_yaml, Loader=LineLoader)
        self.one_return_func, one_return_index = NativeFunction.from_yaml(
            entries[0], loc=Location(__file__, 1), valid_tags=set()
        )
        BackendIndex.grow_index(self.backend_indices, one_return_index)

        # Two-return operator, described directly as a dict.
        two_returns_entry = {
            "func": "op_2() -> (Tensor, Tensor)",
            "dispatch": {"CPU": "kernel_1"},
            "autogen": "op_2.out",
        }
        self.two_returns_func, two_returns_index = NativeFunction.from_yaml(
            two_returns_entry,
            loc=torchgen.model.Location(__file__, 1),
            valid_tags=set(),
        )
        BackendIndex.grow_index(self.backend_indices, two_returns_index)

    def test_functional_variant_autogen_out_variant(self) -> None:
        functions = [self.one_return_func]
        add_generated_native_functions(functions, self.backend_indices)

        # The list is extended in place: one new function follows the input.
        self.assertEqual(len(functions), 2)
        generated = functions[1]
        self.assertEqual(
            str(generated.func),
            "op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)",
        )

        # The new function's kernel is looked up under CompositeExplicitAutograd.
        out_name = generated.func.name
        composite_index = self.backend_indices[DispatchKey.CompositeExplicitAutograd]
        self.assertEqual(composite_index[out_name].kernel, "op_out")

    def test_functional_variant_autogen_out_variant_two_returns(self) -> None:
        functions = [self.two_returns_func]
        add_generated_native_functions(functions, self.backend_indices)

        self.assertEqual(len(functions), 2)
        generated = functions[1]
        # One out argument is expected per return value.
        self.assertEqual(
            str(generated.func),
            "op_2.out(*, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))",
        )

        out_name = generated.func.name
        composite_index = self.backend_indices[DispatchKey.CompositeExplicitAutograd]
        self.assertEqual(composite_index[out_name].kernel, "op_2_out")
| |
| |
# Test for static_dispatch
class TestStaticDispatchGeneratrion(unittest.TestCase):
    """Tests for the call text returned by static_dispatch.

    Both tests use one out-variant operator registered under
    CompositeExplicitAutograd and expect the same generated call.
    """

    def setUp(self) -> None:
        self.backend_indices: Dict[
            DispatchKey, Dict[OperatorName, BackendMetadata]
        ] = defaultdict(dict)
        # The operator is written as YAML text and parsed with LineLoader; the
        # nesting under "dispatch" makes the entry a one-element list.
        yaml_entry = """
- func: op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: op
        """
        es = yaml.load(yaml_entry, Loader=LineLoader)
        self.one_return_func, m = NativeFunction.from_yaml(
            es[0], loc=Location(__file__, 1), valid_tags=set()
        )

        BackendIndex.grow_index(self.backend_indices, m)
        dispatch_key = DispatchKey.CompositeExplicitAutograd
        # assertIn (rather than assertTrue(x in y)) reports both the key and
        # the container when the check fails. It also guards the lookup below:
        # backend_indices is a defaultdict, so indexing alone would silently
        # create an empty entry for a missing key.
        self.assertIn(dispatch_key, self.backend_indices)
        self.indices = [
            BackendIndex(
                dispatch_key=dispatch_key,
                use_out_as_primary=True,
                external=False,
                device_guard=False,
                index=self.backend_indices[dispatch_key],
            )
        ]

    def test_op_with_1_backend_generates_static_dispatch(self) -> None:
        # Dispatch through a signature built from the operator's schema.
        disp_sig = DispatcherSignature.from_schema(self.one_return_func.func)
        with native_function_manager(self.one_return_func):
            out = static_dispatch(
                sig=disp_sig,
                f=self.one_return_func,
                backend_indices=self.indices,
            )
        self.assertEqual(
            out, "return at::compositeexplicitautograd::op_out(out, self);"
        )

    def test_op_with_cpp_sig_generates_static_dispatch(self) -> None:
        # Dispatch through the C++ signature taken from the signature group.
        sig_group = CppSignatureGroup.from_native_function(
            self.one_return_func,
            method=False,
            fallback_binding=self.one_return_func.manual_cpp_binding,
        )
        # cpp signature puts out at the front
        with native_function_manager(self.one_return_func):
            out = static_dispatch(
                sig=sig_group.signature,
                f=self.one_return_func,
                backend_indices=self.indices,
            )
        self.assertEqual(
            out, "return at::compositeexplicitautograd::op_out(out, self);"
        )
| |
| |
# The most basic NativeFunction: a zero-argument "func" returning bool.
# Tests in this file derive variants from it with dataclasses.replace().
DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml(
    dict(func="func() -> bool"),
    loc=Location(__file__, 1),
    valid_tags=set(),
)


# Run the unittest runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()