blob: 1eeba84b6bafdd0714b7204ae77870abed6b28a3 [file] [log] [blame]
import dataclasses
import typing
import unittest
from collections import defaultdict
from typing import Dict, List
import torchgen.model
import yaml
from tools.autograd import gen_autograd_functions, load_derivatives
from torchgen import dest
from torchgen.api.types import CppSignatureGroup, DispatcherSignature
from torchgen.context import native_function_manager
from torchgen.gen import (
get_native_function_declarations,
get_native_function_schema_registrations,
LineLoader,
static_dispatch,
)
from torchgen.model import (
BackendIndex,
BackendMetadata,
DispatchKey,
Location,
NativeFunction,
OperatorName,
)
from torchgen.native_function_generation import add_generated_native_functions
from torchgen.selective_build.selector import SelectiveBuilder
class TestCreateDerivative(unittest.TestCase):
    """Tests for how ``load_derivatives`` discovers which output gradients a
    derivative formula refers to, by name (``grad_x``) or by index
    (``grads[0]``)."""

    @staticmethod
    def _native_function_for(specification: str) -> "NativeFunction":
        """Return DEFAULT_NATIVE_FUNCTION with its schema replaced by the
        parsed ``specification``."""
        parsed_schema = torchgen.model.FunctionSchema.parse(specification)
        return dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=parsed_schema)

    @staticmethod
    def _differentiability_info(
        specification: str,
        native_function: "NativeFunction",
        dispatch: typing.Dict[str, typing.Dict[str, str]],
    ) -> typing.Any:
        """Run ``create_differentiability_info`` on a single derivatives.yaml
        entry named ``specification`` whose only candidate function is
        ``native_function``, and return its result unchanged."""
        return load_derivatives.create_differentiability_info(
            defn_dict={"name": specification, "dispatch": dispatch},
            functions_by_signature={
                native_function.func.signature(): [native_function]
            },
            functions_by_schema={specification: native_function},
            op_counter=typing.Counter[str](),
            used_dispatch_keys=set(),
        )

    def test_named_grads(self) -> None:
        function = self._native_function_for(
            "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
        )
        derivative = load_derivatives.create_derivative(
            function,
            formula="func_backward(grad_x, grad_y)",
            var_names=(),
            available_named_gradients=["grad_x", "grad_y"],
        )
        # Both outputs are referenced by name in the formula.
        self.assertSetEqual(derivative.named_gradients, {"grad_x", "grad_y"})

    def test_non_differentiable_output(self) -> None:
        specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
        function = self._native_function_for(specification)
        _, info_by_key = self._differentiability_info(
            specification,
            function,
            {"Default": {"a": "grads[0]", "b": "grads[2]"}},
        )
        # y is a bool and therefore not differentiable, so no grad_y is offered.
        self.assertSequenceEqual(
            info_by_key["Default"].available_named_gradients,
            ["grad_x", "grad_z"],
        )

    def test_indexed_grads(self) -> None:
        function = self._native_function_for(
            "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
        )
        derivative = load_derivatives.create_derivative(
            function,
            formula="func_backward(grads[0], grads[1])",
            var_names=(),
            available_named_gradients=["grad_x", "grad_y"],
        )
        # Index-based references are not counted as named gradients.
        self.assertSetEqual(derivative.named_gradients, set())

    def test_named_grads_and_indexed_grads(self) -> None:
        specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
        function = self._native_function_for(specification)
        # One formula refers to a gradient by name (grad_x) and the other by
        # index (grads[1]); mixing the two styles must be rejected.
        with self.assertRaisesRegex(
            RuntimeError, 'illegally mixes use of "grad_RETURN_NAME"'
        ):
            self._differentiability_info(
                specification,
                function,
                {"Default": {"a": "grad_x", "b": "grads[1]"}},
            )
class TestGenAutogradFunctions(unittest.TestCase):
    """Tests for the autograd Function definitions emitted by
    ``gen_autograd_functions.process_function`` and for dispatch-key
    validation in ``load_derivatives.create_differentiability_info``.

    Checks use ``assertIn``/``assertNotIn`` instead of bare ``assert``
    statements: bare asserts are removed when Python runs with ``-O``, which
    would make these tests pass vacuously, and they report nothing useful on
    failure.
    """

    def test_non_differentiable_output_invalid_type(self) -> None:
        specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
        schema = torchgen.model.FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
        _, differentiability_info = load_derivatives.create_differentiability_info(
            defn_dict={
                "name": specification,
                "dispatch": {
                    "Default": {
                        "a": "grad_x",
                        "b": "grad_z",
                    }
                },
            },
            functions_by_signature={schema.signature(): [native_function]},
            functions_by_schema={specification: native_function},
            op_counter=typing.Counter[str](),
            used_dispatch_keys=set(),
        )
        definition = gen_autograd_functions.process_function(
            differentiability_info["Default"],
            gen_autograd_functions.FUNCTION_DEFINITION,
        )
        # grad_z should map to grads[1], not grads[2], because output 1 (y)
        # is a bool and therefore not differentiable.
        self.assertNotIn("grad_z = grads[2]", definition)
        self.assertIn("grad_z = grads[1]", definition)

    def test_non_differentiable_output_output_differentiability(self) -> None:
        specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y, Tensor z)"
        schema = torchgen.model.FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
        _, differentiability_info = load_derivatives.create_differentiability_info(
            defn_dict={
                "name": specification,
                "dispatch": {
                    "Default": {
                        "a": "grad_x",
                        "b": "grad_z",
                    },
                    "AutogradNestedTensor": {
                        "a": "grad_z",
                        "b": "grad_x",
                    },
                },
                # y is a Tensor, but is explicitly marked non-differentiable.
                "output_differentiability": [True, False, True],
            },
            functions_by_signature={schema.signature(): [native_function]},
            functions_by_schema={specification: native_function},
            op_counter=typing.Counter[str](),
            used_dispatch_keys=set(),
        )
        default_definition = gen_autograd_functions.process_function(
            differentiability_info["Default"],
            gen_autograd_functions.FUNCTION_DEFINITION,
        )
        # grad_z should map to grads[1], not grads[2], because output 1 (y)
        # is marked non-differentiable by output_differentiability.
        self.assertNotIn("grad_z = grads[2]", default_definition)
        self.assertIn("grad_z = grads[1]", default_definition)

        # The same remapping must apply to the per-dispatch-key formula.
        nested_tensor_definition = gen_autograd_functions.process_function(
            differentiability_info["AutogradNestedTensor"],
            gen_autograd_functions.FUNCTION_DEFINITION,
        )
        self.assertNotIn("grad_z = grads[2]", nested_tensor_definition)
        self.assertIn("grad_z = grads[1]", nested_tensor_definition)

    def test_register_bogus_dispatch_key(self) -> None:
        specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
        schema = torchgen.model.FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
        # AutogradRandomTensor is not an accepted dispatch key, so building
        # the differentiability info must fail with a descriptive error.
        with self.assertRaisesRegex(
            RuntimeError,
            "Invalid dispatch key AutogradRandomTensor in derivatives.yaml for",
        ):
            load_derivatives.create_differentiability_info(
                defn_dict={
                    "name": specification,
                    "dispatch": {
                        "Default": {
                            "a": "grad_x",
                            "b": "grad_z",
                        },
                        "AutogradRandomTensor": {
                            "a": "grad_x",
                            "b": "grad_z",
                        },
                    },
                },
                functions_by_signature={schema.signature(): [native_function]},
                functions_by_schema={specification: native_function},
                op_counter=typing.Counter[str](),
                used_dispatch_keys=set(),
            )
class TestGenSchemaRegistration(unittest.TestCase):
    """Tests for ``get_native_function_schema_registrations``.

    It returns a pair: the aten registrations as a list of ``m.def(...)``
    lines, and a single string holding one ``TORCH_LIBRARY`` (or
    ``TORCH_LIBRARY_FRAGMENT``) block per non-aten namespace.
    """

    # Expected aten registration lines for DEFAULT_NATIVE_FUNCTION.
    _ATEN_REGISTRATIONS = ['m.def("func() -> bool", {});\n']

    # Expected registration block for ``custom::func() -> bool``.
    _CUSTOM_REGISTRATION = """
TORCH_LIBRARY(custom, m) {
  m.def("func() -> bool", {});

};"""

    @staticmethod
    def _parse_native_function(func: str) -> "NativeFunction":
        """Build a NativeFunction from a bare ``func`` schema string, dropping
        the backend metadata that ``from_yaml`` also returns."""
        native_function, _ = torchgen.model.NativeFunction.from_yaml(
            {"func": func},
            loc=torchgen.model.Location(__file__, 1),
            valid_tags=set(),
        )
        return native_function

    def setUp(self) -> None:
        self.selector = SelectiveBuilder.get_nop_selector()
        self.custom_native_function = self._parse_native_function(
            "custom::func() -> bool"
        )
        self.fragment_custom_native_function = self._parse_native_function(
            "quantized_decomposed::func() -> bool"
        )

    def test_default_namespace_schema_registration_code_valid(self) -> None:
        registrations, _ = get_native_function_schema_registrations(
            native_functions=[DEFAULT_NATIVE_FUNCTION],
            schema_selector=self.selector,
        )
        self.assertEqual(registrations, self._ATEN_REGISTRATIONS)

    def test_custom_namespace_schema_registration_code_valid(self) -> None:
        _, registrations = get_native_function_schema_registrations(
            native_functions=[self.custom_native_function],
            schema_selector=self.selector,
        )
        self.assertEqual(registrations, self._CUSTOM_REGISTRATION)

    def test_fragment_custom_namespace_schema_registration_code_valid(self) -> None:
        """Sometimes we want to extend an existing namespace, for example the
        quantized namespace, which is already defined in
        native/quantized/library.cpp. Such a namespace is expected to be
        registered with TORCH_LIBRARY_FRAGMENT rather than TORCH_LIBRARY.
        """
        _, registrations = get_native_function_schema_registrations(
            native_functions=[self.fragment_custom_native_function],
            schema_selector=self.selector,
        )
        self.assertEqual(
            registrations,
            """
TORCH_LIBRARY_FRAGMENT(quantized_decomposed, m) {
  m.def("func() -> bool", {});

};""",
        )

    def test_mixed_namespace_schema_registration_code_valid(self) -> None:
        result = get_native_function_schema_registrations(
            native_functions=[DEFAULT_NATIVE_FUNCTION, self.custom_native_function],
            schema_selector=self.selector,
        )
        aten_registrations, custom_registrations = result
        self.assertEqual(aten_registrations, self._ATEN_REGISTRATIONS)
        self.assertEqual(custom_registrations, self._CUSTOM_REGISTRATION)

    def test_3_namespaces_schema_registration_code_valid(self) -> None:
        custom2_native_function = self._parse_native_function(
            "custom2::func() -> bool"
        )
        result = get_native_function_schema_registrations(
            native_functions=[
                DEFAULT_NATIVE_FUNCTION,
                self.custom_native_function,
                custom2_native_function,
            ],
            schema_selector=self.selector,
        )
        aten_registrations, custom_registrations = result
        self.assertEqual(aten_registrations, self._ATEN_REGISTRATIONS)
        # One block per custom namespace, concatenated in input order.
        self.assertEqual(
            custom_registrations,
            self._CUSTOM_REGISTRATION
            + """
TORCH_LIBRARY(custom2, m) {
  m.def("func() -> bool", {});

};""",
        )
class TestGenNativeFunctionDeclaration(unittest.TestCase):
    """Tests for ``get_native_function_declarations``."""

    def setUp(self) -> None:
        location = torchgen.model.Location(__file__, 1)
        self.op_1_native_function, op_1_backend_index = NativeFunction.from_yaml(
            {"func": "op_1() -> bool", "dispatch": {"CPU": "kernel_1"}},
            loc=location,
            valid_tags=set(),
        )
        self.op_2_native_function, op_2_backend_index = NativeFunction.from_yaml(
            {
                "func": "op_2() -> bool",
                "dispatch": {"CPU": "kernel_2", "QuantizedCPU": "custom::kernel_3"},
            },
            loc=location,
            valid_tags=set(),
        )

        # Merge both ops' backend metadata into one raw per-key index...
        raw_index: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = {
            DispatchKey.CPU: {},
            DispatchKey.QuantizedCPU: {},
        }
        for per_op_index in (op_1_backend_index, op_2_backend_index):
            BackendIndex.grow_index(raw_index, per_op_index)

        # ...then wrap each key's entries in a BackendIndex.
        self.backend_indices = {
            key: BackendIndex(
                dispatch_key=key,
                use_out_as_primary=True,
                external=False,
                device_guard=False,
                index=entries,
            )
            for key, entries in raw_index.items()
        }

    def test_native_function_declaration_1_op_2_ns_error(self) -> None:
        # op_2's kernels live in two namespaces (kernel_2 and
        # custom::kernel_3); generating declarations for it is expected to
        # trip an assertion.
        with self.assertRaises(AssertionError):
            get_native_function_declarations(
                grouped_native_functions=[
                    self.op_1_native_function,
                    self.op_2_native_function,
                ],
                backend_indices=self.backend_indices,
                native_function_decl_gen=dest.compute_native_function_declaration,
            )

    def test_native_function_declaration_1_op_1_ns_valid(self) -> None:
        self.assertIsInstance(self.op_1_native_function, NativeFunction)
        declaration = get_native_function_declarations(
            grouped_native_functions=[self.op_1_native_function],
            backend_indices=self.backend_indices,
            native_function_decl_gen=dest.compute_native_function_declaration,
        )
        expected = """
namespace at {
namespace native {
TORCH_API bool kernel_1();
} // namespace native
} // namespace at
        """
        self.assertEqual("\n".join(declaration), expected)
# Tests for native_function_generation
class TestNativeFunctionGeneratrion(unittest.TestCase):
    """Tests for ``add_generated_native_functions`` autogenerating ``.out``
    variants of functional operators."""

    def setUp(self) -> None:
        self.native_functions: List[NativeFunction] = []
        self.backend_indices: Dict[
            DispatchKey, Dict[OperatorName, BackendMetadata]
        ] = defaultdict(dict)

        # ``dispatch`` and ``autogen`` belong to the single ``op`` entry, so
        # they are indented under the list item.
        yaml_entry = """
- func: op(Tensor self) -> Tensor
  dispatch:
    CompositeExplicitAutograd: op
  autogen: op.out
        """
        entries = yaml.load(yaml_entry, Loader=LineLoader)
        self.one_return_func, one_return_backend_index = NativeFunction.from_yaml(
            entries[0], loc=Location(__file__, 1), valid_tags=set()
        )
        BackendIndex.grow_index(self.backend_indices, one_return_backend_index)

        self.two_returns_func, two_returns_backend_index = NativeFunction.from_yaml(
            {
                "func": "op_2() -> (Tensor, Tensor)",
                "dispatch": {"CPU": "kernel_1"},
                "autogen": "op_2.out",
            },
            loc=torchgen.model.Location(__file__, 1),
            valid_tags=set(),
        )
        BackendIndex.grow_index(self.backend_indices, two_returns_backend_index)

    def _assert_autogenerated_out(
        self,
        functional: "NativeFunction",
        expected_schema: str,
        expected_kernel: str,
    ) -> None:
        """Run ``add_generated_native_functions`` on ``[functional]`` and check
        that exactly one function was appended, with ``expected_schema``, and
        that its CompositeExplicitAutograd kernel is ``expected_kernel``."""
        native_functions = [functional]
        add_generated_native_functions(native_functions, self.backend_indices)
        self.assertEqual(len(native_functions), 2)

        generated = native_functions[1]
        self.assertEqual(str(generated.func), expected_schema)
        composite_index = self.backend_indices[DispatchKey.CompositeExplicitAutograd]
        self.assertEqual(composite_index[generated.func.name].kernel, expected_kernel)

    def test_functional_variant_autogen_out_variant(self) -> None:
        self._assert_autogenerated_out(
            self.one_return_func,
            "op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)",
            "op_out",
        )

    def test_functional_variant_autogen_out_variant_two_returns(self) -> None:
        self._assert_autogenerated_out(
            self.two_returns_func,
            "op_2.out(*, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))",
            "op_2_out",
        )
# Tests for static_dispatch
class TestStaticDispatchGeneratrion(unittest.TestCase):
    """Tests for ``static_dispatch`` code generation.

    The fixture is a single out= operator with one CompositeExplicitAutograd
    kernel. Both tests expect the same generated call, whether the call site
    is described by a dispatcher signature or by a C++ API signature.
    """

    # The statement static_dispatch is expected to emit for the fixture op.
    _EXPECTED_CALL = "return at::compositeexplicitautograd::op_out(out, self);"

    def setUp(self) -> None:
        self.backend_indices: Dict[
            DispatchKey, Dict[OperatorName, BackendMetadata]
        ] = defaultdict(dict)
        # ``dispatch`` belongs to the ``op.out`` entry, so it is indented
        # under the list item.
        yaml_entry = """
- func: op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: op
        """
        es = yaml.load(yaml_entry, Loader=LineLoader)
        self.one_return_func, m = NativeFunction.from_yaml(
            es[0], loc=Location(__file__, 1), valid_tags=set()
        )
        BackendIndex.grow_index(self.backend_indices, m)
        dispatch_key = DispatchKey.CompositeExplicitAutograd
        # assertIn (unlike assertTrue(key in ...)) reports the container's
        # contents if the kernel was not registered under this key.
        self.assertIn(dispatch_key, self.backend_indices)
        self.indices = [
            BackendIndex(
                dispatch_key=dispatch_key,
                use_out_as_primary=True,
                external=False,
                device_guard=False,
                index=self.backend_indices[dispatch_key],
            )
        ]

    def test_op_with_1_backend_generates_static_dispatch(self) -> None:
        disp_sig = DispatcherSignature.from_schema(self.one_return_func.func)
        with native_function_manager(self.one_return_func):
            out = static_dispatch(
                sig=disp_sig,
                f=self.one_return_func,
                backend_indices=self.indices,
            )
        self.assertEqual(out, self._EXPECTED_CALL)

    def test_op_with_cpp_sig_generates_static_dispatch(self) -> None:
        sig_group = CppSignatureGroup.from_native_function(
            self.one_return_func,
            method=False,
            fallback_binding=self.one_return_func.manual_cpp_binding,
        )
        # The C++ API signature puts ``out`` at the front. The generated call
        # should still be identical to the dispatcher-signature case.
        with native_function_manager(self.one_return_func):
            out = static_dispatch(
                sig=sig_group.signature,
                f=self.one_return_func,
                backend_indices=self.indices,
            )
        self.assertEqual(out, self._EXPECTED_CALL)
# Baseline fixture: the simplest possible NativeFunction, ``func() -> bool``,
# with no dispatch entries. Tests derive variants from it with
# ``dataclasses.replace()`` (e.g. swapping in a different ``func`` schema).
# ``from_yaml`` returns a pair; only the NativeFunction (element 0) is kept.
DEFAULT_NATIVE_FUNCTION = torchgen.model.NativeFunction.from_yaml(
    {"func": "func() -> bool"},
    loc=torchgen.model.Location(__file__, 1),
    valid_tags=set(),
)[0]
if __name__ == "__main__":
    # Running this file directly executes every TestCase in this module.
    unittest.main()