blob: 50a3ba56eb795dee3cda6fbcf231299af52fe10f [file] [log] [blame]
from __future__ import absolute_import, division, print_function, unicode_literals
import unittest
from torchgen.selective_build.operator import *
from torchgen.selective_build.selector import (
combine_selective_builders,
SelectiveBuilder,
)
class TestSelectiveBuild(unittest.TestCase):
def test_selective_build_operator(self):
op = SelectiveBuildOperator(
"aten::add.int",
is_root_operator=True,
is_used_for_training=False,
include_all_overloads=False,
_debug_info=None,
)
self.assertTrue(op.is_root_operator)
self.assertFalse(op.is_used_for_training)
self.assertFalse(op.include_all_overloads)
def test_selector_factory(self):
yaml_config_v1 = """
debug_info:
- model1@v100
- model2@v51
operators:
aten::add:
is_used_for_training: No
is_root_operator: Yes
include_all_overloads: Yes
aten::add.int:
is_used_for_training: Yes
is_root_operator: No
include_all_overloads: No
aten::mul.int:
is_used_for_training: Yes
is_root_operator: No
include_all_overloads: No
"""
yaml_config_v2 = """
debug_info:
- model1@v100
- model2@v51
operators:
aten::sub:
is_used_for_training: No
is_root_operator: Yes
include_all_overloads: No
debug_info:
- model1@v100
aten::sub.int:
is_used_for_training: Yes
is_root_operator: No
include_all_overloads: No
"""
yaml_config_all = "include_all_operators: Yes"
yaml_config_invalid = "invalid:"
selector1 = SelectiveBuilder.from_yaml_str(yaml_config_v1)
self.assertTrue(selector1.is_operator_selected("aten::add"))
self.assertTrue(selector1.is_operator_selected("aten::add.int"))
# Overload name is not used for checking in v1.
self.assertTrue(selector1.is_operator_selected("aten::add.float"))
def gen():
return SelectiveBuilder.from_yaml_str(yaml_config_invalid)
self.assertRaises(Exception, gen)
selector_all = SelectiveBuilder.from_yaml_str(yaml_config_all)
self.assertTrue(selector_all.is_operator_selected("aten::add"))
self.assertTrue(selector_all.is_operator_selected("aten::sub"))
self.assertTrue(selector_all.is_operator_selected("aten::sub.int"))
self.assertTrue(selector_all.is_kernel_dtype_selected("add_kernel", "int32"))
selector2 = SelectiveBuilder.from_yaml_str(yaml_config_v2)
self.assertFalse(selector2.is_operator_selected("aten::add"))
self.assertTrue(selector2.is_operator_selected("aten::sub"))
self.assertTrue(selector2.is_operator_selected("aten::sub.int"))
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
["aten::add", "aten::add.int", "aten::mul.int"],
False,
False,
)
self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add.float"))
self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add"))
self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add.int"))
self.assertFalse(selector_legacy_v1.is_operator_selected("aten::sub"))
self.assertFalse(selector_legacy_v1.is_root_operator("aten::add"))
self.assertFalse(
selector_legacy_v1.is_operator_selected_for_training("aten::add")
)
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
["aten::add", "aten::add.int", "aten::mul.int"],
True,
False,
)
self.assertTrue(selector_legacy_v1.is_root_operator("aten::add"))
self.assertFalse(
selector_legacy_v1.is_operator_selected_for_training("aten::add")
)
self.assertTrue(selector_legacy_v1.is_root_operator("aten::add.float"))
self.assertFalse(
selector_legacy_v1.is_operator_selected_for_training("aten::add.float")
)
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
["aten::add", "aten::add.int", "aten::mul.int"],
False,
True,
)
self.assertFalse(selector_legacy_v1.is_root_operator("aten::add"))
self.assertTrue(
selector_legacy_v1.is_operator_selected_for_training("aten::add")
)
self.assertFalse(selector_legacy_v1.is_root_operator("aten::add.float"))
self.assertTrue(
selector_legacy_v1.is_operator_selected_for_training("aten::add.float")
)
def test_operator_combine(self):
op1 = SelectiveBuildOperator(
"aten::add.int",
is_root_operator=True,
is_used_for_training=False,
include_all_overloads=False,
_debug_info=None,
)
op2 = SelectiveBuildOperator(
"aten::add.int",
is_root_operator=False,
is_used_for_training=False,
include_all_overloads=False,
_debug_info=None,
)
op3 = SelectiveBuildOperator(
"aten::add",
is_root_operator=True,
is_used_for_training=False,
include_all_overloads=False,
_debug_info=None,
)
op4 = SelectiveBuildOperator(
"aten::add.int",
is_root_operator=True,
is_used_for_training=True,
include_all_overloads=False,
_debug_info=None,
)
op5 = combine_operators(op1, op2)
self.assertTrue(op5.is_root_operator)
self.assertFalse(op5.is_used_for_training)
op6 = combine_operators(op1, op4)
self.assertTrue(op6.is_root_operator)
self.assertTrue(op6.is_used_for_training)
def gen_new_op():
return combine_operators(op1, op3)
self.assertRaises(Exception, gen_new_op)
def test_training_op_fetch(self):
yaml_config = """
operators:
aten::add.int:
is_used_for_training: No
is_root_operator: Yes
include_all_overloads: No
aten::add:
is_used_for_training: Yes
is_root_operator: No
include_all_overloads: Yes
"""
selector = SelectiveBuilder.from_yaml_str(yaml_config)
self.assertTrue(selector.is_operator_selected_for_training("aten::add.int"))
self.assertTrue(selector.is_operator_selected_for_training("aten::add"))
def test_kernel_dtypes(self):
yaml_config = """
kernel_metadata:
add_kernel:
- int8
- int32
sub_kernel:
- int16
- int32
add/sub_kernel:
- float
- complex
"""
selector = SelectiveBuilder.from_yaml_str(yaml_config)
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "int16"))
self.assertFalse(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "float"))
self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "float"))
self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "complex"))
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16"))
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32"))
def test_merge_kernel_dtypes(self):
yaml_config1 = """
kernel_metadata:
add_kernel:
- int8
add/sub_kernel:
- float
- complex
- none
mul_kernel:
- int8
"""
yaml_config2 = """
kernel_metadata:
add_kernel:
- int32
sub_kernel:
- int16
- int32
add/sub_kernel:
- float
- complex
"""
selector1 = SelectiveBuilder.from_yaml_str(yaml_config1)
selector2 = SelectiveBuilder.from_yaml_str(yaml_config2)
selector = combine_selective_builders(selector1, selector2)
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "int16"))
self.assertFalse(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "float"))
self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "float"))
self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "complex"))
self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "none"))
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16"))
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32"))
self.assertTrue(selector.is_kernel_dtype_selected("mul_kernel", "int8"))
self.assertFalse(selector.is_kernel_dtype_selected("mul_kernel", "int32"))
def test_all_kernel_dtypes_selected(self):
yaml_config = """
include_all_non_op_selectives: True
"""
selector = SelectiveBuilder.from_yaml_str(yaml_config)
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int16"))
self.assertTrue(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "float"))