Consolidate all python targets in the tools folder (#80408)
Summary:
All buck targets that points to caffe2/tools folder are now moved to tools/BUCK.
This also eliminates all python library/binary import in pt_defs.bzl, which caused T124308913.
Test Plan: CI
Differential Revision: D37468313
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80408
Approved by: https://github.com/seemethere, https://github.com/malfet
diff --git a/.github/workflows/_buck-build-test.yml b/.github/workflows/_buck-build-test.yml
index 2d1e563..ae7f751 100644
--- a/.github/workflows/_buck-build-test.yml
+++ b/.github/workflows/_buck-build-test.yml
@@ -62,7 +62,15 @@
command: |
sh scripts/buck_setup.sh
- - name: Build C10
+ - name: Build tools
+ run: |
+ buck build tools: --keep-going
+
+ - name: Run tools tests
+ run: |
+ buck test tools:selective_build_test tools:gen_oplist_test tools:gen_operators_yaml_test
+
+ - name: Build c10
run: |
buck build c10:c10
diff --git a/.lintrunner.toml b/.lintrunner.toml
index 1b98226..302ff02 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -16,6 +16,7 @@
'torch/lib/**',
'venv/**',
'**/*.pyi',
+ 'tools/test/test_selective_build.py',
]
command = [
'python3',
@@ -145,6 +146,10 @@
exclude_patterns = [
# (linbinyu) copied from internal repo
'tools/code_analyzer/gen_operators_yaml.py',
+ 'tools/gen_vulkan_spv.py',
+ 'tools/test/gen_operators_yaml_test.py',
+ 'tools/test/gen_oplist_test.py',
+ 'tools/test/test_selective_build.py',
]
command = [
'python3',
@@ -334,6 +339,7 @@
command = [
'python3',
'tools/linter/adapters/grep_linter.py',
+ # @lint-ignore TXT2
'--pattern= ',
'--linter-name=TABS',
'--error-name=saw some tabs',
@@ -565,6 +571,9 @@
'torch/_decomp/**/*.py',
'test/onnx/**/*.py',
]
+exclude_patterns = [
+ 'tools/gen_vulkan_spv.py',
+]
command = [
'python3',
'tools/linter/adapters/black_linter.py',
diff --git a/buckbuild.bzl b/buckbuild.bzl
index 42abc49..ecacc34 100644
--- a/buckbuild.bzl
+++ b/buckbuild.bzl
@@ -3,8 +3,6 @@
load("@bazel_skylib//lib:paths.bzl", "paths")
load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
-load("//tools/build_defs:fb_python_binary.bzl", "fb_python_binary")
-load("//tools/build_defs:fb_python_library.bzl", "fb_python_library")
load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library")
load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode")
@@ -416,7 +414,7 @@
name = name,
default_outs = ["."],
outs = get_aten_generated_files(backends),
- cmd = "$(exe {}:gen_aten_bin) ".format(ROOT) + " ".join([
+ cmd = "$(exe {}torchgen:gen) ".format(ROOT_PATH) + " ".join([
"--source-path $(location {}:aten_src_path)/aten/src/ATen".format(ROOT),
"--install_dir $OUT",
] + extra_params),
@@ -442,7 +440,7 @@
name = genrule_name,
default_outs = ["."],
outs = get_unboxing_generated_files(),
- cmd = "$(exe {}:gen_unboxing_bin) ".format(ROOT) + " ".join([
+ cmd = "$(exe {}tools:gen_unboxing_bin) ".format(ROOT_PATH) + " ".join([
"--source-path $(location {}:aten_src_path)/aten/src/ATen".format(ROOT),
"--install_dir $OUT",
] + extra_params),
@@ -515,7 +513,7 @@
# @lint-ignore BUCKLINT
fb_native.genrule(
name = oplist_dir_name,
- cmd = ("$(exe {}:gen_oplist) ".format(ROOT) +
+ cmd = ("$(exe {}tools:gen_oplist) ".format(ROOT_PATH) +
"--model_file_list_path $(@query_outputs 'attrfilter(labels, pt_operator_library, deps(set({deps})))') " +
("" if enforce_traced_op_list else "--allow_include_all_overloads ") +
"--output_dir $OUT ").format(deps = " ".join(["\"{}\"".format(d) for d in deps])),
@@ -620,7 +618,7 @@
outs = get_generate_code_bin_outs(),
default_outs = ["."],
bash = "mkdir -p tools && " +
- "$(exe {}tools/setup_helpers:generate_code_bin) ".format(ROOT_PATH) + " ".join(
+ "$(exe {}tools:generate_code_bin) ".format(ROOT_PATH) + " ".join(
# Mobile build only needs libtorch - skip python bindings for now, except
# for ovrsource, which needs Python bindings.
(["--subset libtorch"] if not is_arvr_mode() else []) + [
@@ -630,7 +628,7 @@
] + extra_params,
),
cmd_exe = "@powershell -Command New-Item -Path tools -ItemType Directory -Force; " +
- "$(exe {}tools/setup_helpers:generate_code_bin) ".format(ROOT_PATH) + " ".join(
+ "$(exe {}tools:generate_code_bin) ".format(ROOT_PATH) + " ".join(
# Mobile build only needs libtorch - skip python bindings for now, except
# for ovrsource, which needs Python bindings.
(["--subset libtorch"] if not is_arvr_mode() else []) + [
@@ -950,7 +948,7 @@
"torch/csrc/api/include/torch/version.h.in",
"version.txt",
],
- cmd = "$(exe {}tools/setup_helpers:gen-version-header) ".format(ROOT_PATH) + " ".join([
+ cmd = "$(exe {}tools:gen-version-header) ".format(ROOT_PATH) + " ".join([
"--template-path",
"torch/csrc/api/include/torch/version.h.in",
"--version-path",
@@ -995,28 +993,13 @@
],
)
- fb_python_library(
- name = "substitutelib",
- srcs = ["tools/substitute.py"],
- base_module = "",
- )
-
- fb_python_binary(
- name = "substitute",
- main_module = "tools.substitute",
- visibility = ["PUBLIC"],
- deps = [
- ":substitutelib",
- ],
- )
-
# @lint-ignore BUCKLINT
fb_native.genrule(
name = "generate_aten_config",
srcs = [
"aten/src/ATen/Config.h.in",
],
- cmd = "$(exe :substitute) " + " ".join([
+ cmd = "$(exe {}tools:substitute) ".format(ROOT_PATH) + " ".join([
"--install_dir",
"$OUT",
"--input-file",
@@ -1072,79 +1055,6 @@
default_outs = ["."],
)
- fb_python_binary(
- name = "gen_aten_bin",
- main_module = "torchgen.gen",
- visibility = [
- "PUBLIC",
- ],
- deps = [
- ROOT_PATH + "torchgen:torchgen",
- ],
- )
-
- fb_python_binary(
- name = "gen_unboxing_bin",
- main_module = "tools.jit.gen_unboxing",
- visibility = [
- "PUBLIC",
- ],
- deps = [
- ROOT_PATH + "tools/jit:jit",
- ],
- )
-
- fb_python_library(
- name = "gen_oplist_lib",
- srcs = subdir_glob([
- ("tools/code_analyzer", "gen_oplist.py"),
- ("tools/code_analyzer", "gen_op_registration_allowlist.py"),
- ]),
- base_module = "",
- tests = [
- ":gen_oplist_test",
- ],
- deps = [
- third_party("pyyaml"),
- ROOT_PATH + "tools/lite_interpreter:gen_selected_mobile_ops_header",
- ROOT_PATH + "torchgen:torchgen",
- ],
- )
-
- fb_python_library(
- name = "gen_operators_yaml_lib",
- srcs = subdir_glob([
- ("tools/code_analyzer", "gen_operators_yaml.py"),
- ("tools/code_analyzer", "gen_op_registration_allowlist.py"),
- ]),
- base_module = "",
- tests = [
- ":gen_operators_yaml_test",
- ],
- deps = [
- third_party("pyyaml"),
- ROOT_PATH + "torchgen:torchgen",
- ],
- )
-
- fb_python_binary(
- name = "gen_oplist",
- main_module = "gen_oplist",
- visibility = ["PUBLIC"],
- deps = [
- ":gen_oplist_lib",
- ],
- )
-
- fb_python_binary(
- name = "gen_operators_yaml",
- main_module = "gen_operators_yaml",
- visibility = ["PUBLIC"],
- deps = [
- ":gen_operators_yaml_lib",
- ],
- )
-
gen_aten_files(
name = "gen_aten",
extra_flags = get_aten_codegen_extra_params(USED_PT_BACKENDS),
diff --git a/cmake/VulkanCodegen.cmake b/cmake/VulkanCodegen.cmake
index c39b54d..075f2b3 100644
--- a/cmake/VulkanCodegen.cmake
+++ b/cmake/VulkanCodegen.cmake
@@ -62,7 +62,7 @@
execute_process(
COMMAND
"${PYTHON_EXECUTABLE}"
- ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/gen_vulkan_spv.py
+ ${CMAKE_CURRENT_LIST_DIR}/../tools/gen_vulkan_spv.py
--glsl-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/vulkan/glsl
--output-path ${VULKAN_GEN_OUTPUT_PATH}
--glslc-path=${GLSLC_PATH}
diff --git a/pt_ops.bzl b/pt_ops.bzl
index 2dd4ce3..a8089a9 100644
--- a/pt_ops.bzl
+++ b/pt_ops.bzl
@@ -39,7 +39,7 @@
name = name,
out = "model_operators.yaml",
cmd = (
- "$(exe {root}:gen_operators_yaml) " +
+ "$(exe {exe}) " +
"{optionally_root_ops} " +
"{optionally_training_root_ops} " +
"--rule_name {rule_name} " +
@@ -52,7 +52,7 @@
"{optionally_model_traced_backends} " +
"{optionally_include_all_operators}"
).format(
- root = "//" if IS_OSS else "//xplat/caffe2",
+ exe = "//tools:gen_operators_yaml" if IS_OSS else "//xplat/caffe2/tools:gen_operators_yaml",
rule_name = name,
model_name = model_name,
dep_graph_yaml = "none" if IS_OSS else "$(location //xplat/caffe2:pytorch_op_deps)/fb/pytorch_op_deps.yaml ",
diff --git a/tools/BUCK.bzl b/tools/BUCK.bzl
new file mode 100644
index 0000000..959a73d
--- /dev/null
+++ b/tools/BUCK.bzl
@@ -0,0 +1,263 @@
+# @lint-ignore-every FBCODEBZLADDLOADS
+load("//tools/build_defs:glob_defs.bzl", "subdir_glob")
+
+# shared by internal and OSS BUCK
+def define_tools_targets(
+ python_binary,
+ python_library,
+ python_test,
+ third_party,
+ torchgen_deps,
+ contacts = []):
+ python_library(
+ name = "substitutelib",
+ srcs = ["substitute.py"],
+ base_module = "",
+ )
+
+ python_binary(
+ name = "substitute",
+ main_module = "substitute",
+ visibility = ["PUBLIC"],
+ deps = [
+ ":substitutelib",
+ ],
+ )
+
+ python_library(
+ name = "jit",
+ # @lint-ignore BUCKRESTRICTEDSYNTAX
+ srcs = glob([
+ "jit/*.py",
+ "jit/templates/*",
+ ]),
+ base_module = "tools",
+ visibility = ["PUBLIC"],
+ deps = [
+ torchgen_deps,
+ ],
+ )
+
+ python_binary(
+ name = "gen_unboxing_bin",
+ main_module = "tools.jit.gen_unboxing",
+ visibility = [
+ "PUBLIC",
+ ],
+ deps = [
+ ":jit",
+ ],
+ )
+
+ python_library(
+ name = "gen_selected_mobile_ops_header",
+ srcs = ["lite_interpreter/gen_selected_mobile_ops_header.py"],
+ base_module = "tools",
+ visibility = ["PUBLIC"],
+ )
+
+ python_library(
+ name = "gen_oplist_lib",
+ srcs = subdir_glob([
+ ("code_analyzer", "gen_oplist.py"),
+ ("code_analyzer", "gen_op_registration_allowlist.py"),
+ ]),
+ base_module = "",
+ tests = [
+ ":gen_oplist_test",
+ ],
+ deps = [
+ ":gen_selected_mobile_ops_header",
+ torchgen_deps,
+ third_party("pyyaml"),
+ ],
+ )
+
+ python_binary(
+ name = "gen_oplist",
+ main_module = "gen_oplist",
+ visibility = ["PUBLIC"],
+ deps = [
+ ":gen_oplist_lib",
+ ],
+ )
+
+ python_library(
+ name = "gen_operators_yaml_lib",
+ srcs = subdir_glob([
+ ("code_analyzer", "gen_operators_yaml.py"),
+ ("code_analyzer", "gen_op_registration_allowlist.py"),
+ ]),
+ base_module = "",
+ tests = [
+ ":gen_operators_yaml_test",
+ ],
+ deps = [
+ third_party("pyyaml"),
+ torchgen_deps,
+ ],
+ )
+
+ python_binary(
+ name = "gen_operators_yaml",
+ main_module = "gen_operators_yaml",
+ visibility = ["PUBLIC"],
+ deps = [
+ ":gen_operators_yaml_lib",
+ ],
+ )
+
+ python_library(
+ name = "autograd",
+ # @lint-ignore BUCKRESTRICTEDSYNTAX
+ srcs = glob(
+ ["autograd/*.py"],
+ ),
+ base_module = "tools",
+ resources = [
+ "autograd/deprecated.yaml",
+ "autograd/derivatives.yaml",
+ "autograd/templates/ADInplaceOrViewType.cpp",
+ "autograd/templates/Functions.cpp",
+ "autograd/templates/Functions.h",
+ "autograd/templates/TraceType.cpp",
+ "autograd/templates/VariableType.cpp",
+ "autograd/templates/VariableType.h",
+ "autograd/templates/annotated_fn_args.py.in",
+ "autograd/templates/python_enum_tag.cpp",
+ "autograd/templates/python_fft_functions.cpp",
+ "autograd/templates/python_functions.cpp",
+ "autograd/templates/python_functions.h",
+ "autograd/templates/python_linalg_functions.cpp",
+ "autograd/templates/python_nn_functions.cpp",
+ "autograd/templates/python_return_types.cpp",
+ "autograd/templates/python_sparse_functions.cpp",
+ "autograd/templates/python_special_functions.cpp",
+ "autograd/templates/python_torch_functions.cpp",
+ "autograd/templates/python_variable_methods.cpp",
+ "autograd/templates/variable_factories.h",
+ ],
+ visibility = ["PUBLIC"],
+ deps = [
+ third_party("pyyaml"),
+ torchgen_deps,
+ ],
+ )
+
+ python_library(
+ name = "generate_code",
+ srcs = [
+ "setup_helpers/generate_code.py",
+ ],
+ base_module = "tools",
+ deps = [
+ ":autograd",
+ ":jit",
+ torchgen_deps,
+ ],
+ )
+
+ python_binary(
+ name = "generate_code_bin",
+ main_module = "tools.setup_helpers.generate_code",
+ # Windows does not support inplace:
+ # https://github.com/facebook/buck/issues/2161.
+ #
+ # Note that //arvr/mode/embedded/win/clang-aarch64-release sets
+ # its target platform to
+ # ovr_config//platform/embedded:clang-aarch64-linux-release, hence
+ # that is why we are selecting that OS to trigger this behavior.
+ package_style = select({
+ "DEFAULT": "inplace",
+ "ovr_config//os:linux-arm64": "standalone",
+ }),
+ visibility = ["PUBLIC"],
+ # Because Windows does not support inplace packaging, we need to
+ # ensure it is unzipped before executing it, otherwise it will not
+ # be able to find any resources using path manipulation.
+ #
+ # See note above about why the OS is Linux here and not Windows.
+ zip_safe = select({
+ "DEFAULT": True,
+ "ovr_config//os:linux-arm64": False,
+ }),
+ deps = [
+ ":generate_code",
+ ],
+ )
+
+ python_library(
+ name = "gen-version-header-lib",
+ srcs = [
+ "setup_helpers/gen_version_header.py",
+ ],
+ base_module = "",
+ deps = [],
+ )
+
+ python_binary(
+ name = "gen-version-header",
+ main_module = "setup_helpers.gen_version_header",
+ visibility = ["PUBLIC"],
+ deps = [
+ ":gen-version-header-lib",
+ ],
+ )
+
+ python_library(
+ name = "gen_aten_vulkan_spv_lib",
+ srcs = [
+ "gen_vulkan_spv.py",
+ ],
+ base_module = "",
+ deps = [
+ torchgen_deps,
+ ],
+ )
+
+ python_binary(
+ name = "gen_aten_vulkan_spv_bin",
+ main_module = "gen_vulkan_spv",
+ visibility = [
+ "PUBLIC",
+ ],
+ deps = [
+ ":gen_aten_vulkan_spv_lib",
+ ],
+ )
+
+ python_test(
+ name = "selective_build_test",
+ srcs = [
+ "test/test_selective_build.py",
+ ],
+ contacts = contacts,
+ visibility = ["PUBLIC"],
+ deps = [
+ torchgen_deps,
+ ],
+ )
+
+ python_test(
+ name = "gen_oplist_test",
+ srcs = [
+ "test/gen_oplist_test.py",
+ ],
+ contacts = contacts,
+ visibility = ["PUBLIC"],
+ deps = [
+ ":gen_oplist_lib",
+ ],
+ )
+
+ python_test(
+ name = "gen_operators_yaml_test",
+ srcs = [
+ "test/gen_operators_yaml_test.py",
+ ],
+ visibility = ["PUBLIC"],
+ contacts = contacts,
+ deps = [
+ ":gen_operators_yaml_lib",
+ ],
+ )
diff --git a/tools/BUCK.oss b/tools/BUCK.oss
new file mode 100644
index 0000000..97f6794
--- /dev/null
+++ b/tools/BUCK.oss
@@ -0,0 +1,10 @@
+load("//:buckbuild.bzl", "third_party")
+load(":BUCK.bzl", "define_tools_targets")
+
+define_tools_targets(
+ python_binary = python_binary,
+ python_library = python_library,
+ python_test = python_test,
+ third_party = third_party,
+ torchgen_deps = "//torchgen:torchgen",
+)
diff --git a/tools/autograd/BUCK.oss b/tools/autograd/BUCK.oss
deleted file mode 100644
index 04403f4..0000000
--- a/tools/autograd/BUCK.oss
+++ /dev/null
@@ -1,35 +0,0 @@
-python_library(
- name = "autograd",
- srcs = glob(
- ["*.py"],
- ),
- base_module = "tools.autograd",
- resources = [
- "deprecated.yaml",
- "derivatives.yaml",
- "templates/ADInplaceOrViewType.cpp",
- "templates/Functions.cpp",
- "templates/Functions.h",
- "templates/TraceType.cpp",
- "templates/VariableType.cpp",
- "templates/VariableType.h",
- "templates/annotated_fn_args.py.in",
- "templates/python_fft_functions.cpp",
- "templates/python_functions.cpp",
- "templates/python_functions.h",
- "templates/python_linalg_functions.cpp",
- "templates/python_nn_functions.cpp",
- "templates/python_return_types.cpp",
- "templates/python_sparse_functions.cpp",
- "templates/python_special_functions.cpp",
- "templates/python_torch_functions.cpp",
- "templates/python_variable_methods.cpp",
- "templates/variable_factories.h",
- "templates/python_enum_tag.cpp",
- ],
- visibility = ["PUBLIC"],
- deps = [
- "//third_party:pyyaml",
- "//torchgen:torchgen",
- ],
-)
diff --git a/tools/build_defs/fb_python_binary.bzl b/tools/build_defs/fb_python_binary.bzl
deleted file mode 100644
index 5e69f32..0000000
--- a/tools/build_defs/fb_python_binary.bzl
+++ /dev/null
@@ -1,9 +0,0 @@
-# Only used for PyTorch open source BUCK build
-# @lint-ignore-every BUCKRESTRICTEDSYNTAX
-# @lint-ignore-every FBCODEBZLADDLOADS
-
-def fb_python_binary(**kwgs):
- if read_config("pt", "is_oss", "0") == "0":
- fail("This file is for open source pytorch build. Do not use it in fbsource!")
-
- python_binary(**kwgs)
diff --git a/tools/build_defs/fb_python_library.bzl b/tools/build_defs/fb_python_library.bzl
deleted file mode 100644
index e0ab86f..0000000
--- a/tools/build_defs/fb_python_library.bzl
+++ /dev/null
@@ -1,9 +0,0 @@
-# Only used for PyTorch open source BUCK build
-# @lint-ignore-every BUCKRESTRICTEDSYNTAX
-# @lint-ignore-every FBCODEBZLADDLOADS
-
-def fb_python_library(**kwgs):
- if read_config("pt", "is_oss", "0") == "0":
- fail("This file is for open source pytorch build. Do not use it in fbsource!")
-
- python_library(**kwgs)
diff --git a/aten/src/ATen/gen_vulkan_spv.py b/tools/gen_vulkan_spv.py
similarity index 100%
rename from aten/src/ATen/gen_vulkan_spv.py
rename to tools/gen_vulkan_spv.py
diff --git a/tools/jit/BUCK.oss b/tools/jit/BUCK.oss
deleted file mode 100644
index 8c0105f..0000000
--- a/tools/jit/BUCK.oss
+++ /dev/null
@@ -1,12 +0,0 @@
-python_library(
- name = "jit",
- srcs = glob([
- "*.py",
- "templates/*",
- ]),
- base_module = "tools.jit",
- visibility = ["PUBLIC"],
- deps = [
- "//torchgen:torchgen",
- ],
-)
diff --git a/tools/lite_interpreter/BUCK.oss b/tools/lite_interpreter/BUCK.oss
deleted file mode 100644
index 10415c2..0000000
--- a/tools/lite_interpreter/BUCK.oss
+++ /dev/null
@@ -1,6 +0,0 @@
-python_library(
- name = "gen_selected_mobile_ops_header",
- srcs = ["gen_selected_mobile_ops_header.py"],
- base_module = "tools.lite_interpreter",
- visibility = ["PUBLIC"],
-)
diff --git a/tools/setup_helpers/BUCK.oss b/tools/setup_helpers/BUCK.oss
deleted file mode 100644
index afcd31f..0000000
--- a/tools/setup_helpers/BUCK.oss
+++ /dev/null
@@ -1,41 +0,0 @@
-python_library(
- name = "generate_code",
- srcs = [
- "generate_code.py",
- ],
- base_module = "tools.setup_helpers",
- deps = [
- "//tools/autograd:autograd",
- "//tools/jit:jit",
- "//torchgen:torchgen",
- ],
-)
-
-python_binary(
- name = "generate_code_bin",
- main_module = "tools.setup_helpers.generate_code",
- visibility = ["PUBLIC"],
- # package_style = "inplace",
- zip_safe = False,
- deps = [
- ":generate_code",
- ],
-)
-
-python_library(
- name = "gen-version-header-lib",
- srcs = [
- "gen_version_header.py",
- ],
- base_module = "tools.setup_helpers",
- deps = [],
-)
-
-python_binary(
- name = "gen-version-header",
- main_module = "tools.setup_helpers.gen_version_header",
- visibility = ["PUBLIC"],
- deps = [
- ":gen-version-header-lib",
- ],
-)
diff --git a/tools/test/gen_operators_yaml_test.py b/tools/test/gen_operators_yaml_test.py
new file mode 100644
index 0000000..87455d3
--- /dev/null
+++ b/tools/test/gen_operators_yaml_test.py
@@ -0,0 +1,190 @@
+#!/usr/bin/env python3
+# Copyright 2004-present Facebook. All Rights Reserved.
+
+import unittest
+
+from gen_operators_yaml import make_filter_from_options, verify_all_specified_present
+
+
+class GenOperatorsYAMLTest(unittest.TestCase):
+ def setUp(self):
+ pass
+
+ def test_filter_creation(self):
+ filter_func = make_filter_from_options(
+ model_name="abc",
+ model_versions=["100", "101"],
+ model_assets=None,
+ model_backends=None,
+ )
+ config = [
+ {
+ "model": {
+ "name": "abc",
+ "version": 100,
+ "asset": "asset-1",
+ "backend": "CPU",
+ },
+ "root_operators": [],
+ "traced_operators": [],
+ },
+ {
+ "model": {
+ "name": "abc",
+ "version": 102,
+ "asset": "asset-1",
+ "backend": "CPU",
+ },
+ "root_operators": [],
+ },
+ {
+ "model": {
+ "name": "abcd",
+ "version": 100,
+ "asset": "asset-1",
+ "backend": "CPU",
+ },
+ "root_operators": [],
+ "traced_operators": [],
+ },
+ {
+ "model": {
+ "name": "abc",
+ "version": 101,
+ "asset": "asset-2",
+ "backend": "CPU",
+ },
+ "root_operators": [],
+ },
+ ]
+
+ filtered_configs = list(filter(filter_func, config))
+ assert (
+ len(filtered_configs) == 2
+ ), "Expected 2 elements in filtered_configs, but got {}".format(
+ len(filtered_configs)
+ )
+
+ def test_verification_success(self):
+ filter_func = make_filter_from_options(
+ model_name="abc",
+ model_versions=["100", "101"],
+ model_assets=["asset-1", "asset-2"],
+ model_backends=None,
+ )
+ config = [
+ {
+ "model": {
+ "name": "abc",
+ "version": 100,
+ "asset": "asset-1",
+ "backend": "CPU",
+ },
+ "root_operators": [],
+ "traced_operators": [],
+ },
+ {
+ "model": {
+ "name": "abc",
+ "version": 101,
+ "asset": "asset-2",
+ "backend": "CPU",
+ },
+ "root_operators": [],
+ },
+ ]
+ filtered_configs = list(filter(filter_func, config))
+ try:
+ verify_all_specified_present(
+ model_assets=["asset-1", "asset-2"],
+ model_versions=["100", "101"],
+ selected_models_yaml=filtered_configs,
+ rule_name="test",
+ model_name="abc",
+ new_style_rule=True,
+ )
+ except Exception:
+ self.fail(
+ "expected verify_all_specified_present to succeed instead it raised an exception"
+ )
+
+ def test_verification_fail(self):
+ config = [
+ {
+ "model": {
+ "name": "abc",
+ "version": 100,
+ "asset": "asset-1",
+ "backend": "CPU",
+ },
+ "root_operators": [],
+ "traced_operators": [],
+ },
+ {
+ "model": {
+ "name": "abc",
+ "version": 101,
+ "asset": "asset-2",
+ "backend": "CPU",
+ },
+ "root_operators": [],
+ },
+ ]
+
+ good_assets = ["asset-1", "asset-2"]
+ good_versions = ["100", "101"]
+ good_name = "abc"
+
+ # Test bad asset
+ filter_func_bad_asset = make_filter_from_options(
+ model_name=good_name,
+ model_versions=good_versions,
+ model_assets=["asset-1", "asset-2", "asset-3"],
+ model_backends=None,
+ )
+ filtered_configs_asset = list(filter(filter_func_bad_asset, config))
+ with self.assertRaises(RuntimeError):
+ verify_all_specified_present(
+ model_assets=["asset-1", "asset-2", "asset-3"],
+ model_versions=good_versions,
+ selected_models_yaml=filtered_configs_asset,
+ rule_name="test",
+ model_name=good_name,
+ new_style_rule=True,
+ )
+
+ # Test bad version
+ filter_func_bad_version = make_filter_from_options(
+ model_name=good_name,
+ model_versions=["100", "101", "102"],
+ model_assets=good_assets,
+ model_backends=None,
+ )
+ filtered_configs_version = list(filter(filter_func_bad_version, config))
+ with self.assertRaises(RuntimeError):
+ verify_all_specified_present(
+ model_assets=good_assets,
+ model_versions=["100", "101", "102"],
+ selected_models_yaml=filtered_configs_version,
+ rule_name="test",
+ model_name=good_name,
+ new_style_rule=True,
+ )
+
+ # Test bad name
+ filter_func_bad_name = make_filter_from_options(
+ model_name="abcd",
+ model_versions=good_versions,
+ model_assets=good_assets,
+ model_backends=None,
+ )
+ filtered_configs_name = list(filter(filter_func_bad_name, config))
+ with self.assertRaises(RuntimeError):
+ verify_all_specified_present(
+ model_assets=good_assets,
+ model_versions=good_versions,
+ selected_models_yaml=filtered_configs_name,
+ rule_name="test",
+ model_name="abcd",
+ new_style_rule=True,
+ )
diff --git a/tools/test/gen_oplist_test.py b/tools/test/gen_oplist_test.py
new file mode 100644
index 0000000..d58e2cc
--- /dev/null
+++ b/tools/test/gen_oplist_test.py
@@ -0,0 +1,35 @@
+#!/usr/bin/env python3
+# Copyright 2004-present Facebook. All Rights Reserved.
+
+import unittest
+from unittest.mock import MagicMock
+
+from gen_oplist import throw_if_any_op_includes_overloads
+
+
+class GenOplistTest(unittest.TestCase):
+ def setUp(self):
+ pass
+
+ def test_throw_if_any_op_includes_overloads(self):
+ selective_builder = MagicMock()
+ selective_builder.operators = MagicMock()
+ selective_builder.operators.items.return_value = [
+ ("op1", MagicMock(include_all_overloads=True)),
+ ("op2", MagicMock(include_all_overloads=False)),
+ ("op3", MagicMock(include_all_overloads=True)),
+ ]
+
+ self.assertRaises(
+ Exception, throw_if_any_op_includes_overloads, selective_builder
+ )
+
+ selective_builder.operators.items.return_value = [
+ ("op1", MagicMock(include_all_overloads=False)),
+ ("op2", MagicMock(include_all_overloads=False)),
+ ("op3", MagicMock(include_all_overloads=False)),
+ ]
+
+ # Here we do not expect it to throw an exception since none of the ops
+ # include all overloads.
+ throw_if_any_op_includes_overloads(selective_builder)
diff --git a/tools/test/test_selective_build.py b/tools/test/test_selective_build.py
new file mode 100644
index 0000000..50a3ba5
--- /dev/null
+++ b/tools/test/test_selective_build.py
@@ -0,0 +1,281 @@
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import unittest
+
+from torchgen.selective_build.operator import *
+from torchgen.selective_build.selector import (
+ combine_selective_builders,
+ SelectiveBuilder,
+)
+
+
+class TestSelectiveBuild(unittest.TestCase):
+ def test_selective_build_operator(self):
+ op = SelectiveBuildOperator(
+ "aten::add.int",
+ is_root_operator=True,
+ is_used_for_training=False,
+ include_all_overloads=False,
+ _debug_info=None,
+ )
+ self.assertTrue(op.is_root_operator)
+ self.assertFalse(op.is_used_for_training)
+ self.assertFalse(op.include_all_overloads)
+
+ def test_selector_factory(self):
+ yaml_config_v1 = """
+debug_info:
+ - model1@v100
+ - model2@v51
+operators:
+ aten::add:
+ is_used_for_training: No
+ is_root_operator: Yes
+ include_all_overloads: Yes
+ aten::add.int:
+ is_used_for_training: Yes
+ is_root_operator: No
+ include_all_overloads: No
+ aten::mul.int:
+ is_used_for_training: Yes
+ is_root_operator: No
+ include_all_overloads: No
+"""
+
+ yaml_config_v2 = """
+debug_info:
+ - model1@v100
+ - model2@v51
+operators:
+ aten::sub:
+ is_used_for_training: No
+ is_root_operator: Yes
+ include_all_overloads: No
+ debug_info:
+ - model1@v100
+ aten::sub.int:
+ is_used_for_training: Yes
+ is_root_operator: No
+ include_all_overloads: No
+"""
+
+ yaml_config_all = "include_all_operators: Yes"
+
+ yaml_config_invalid = "invalid:"
+
+ selector1 = SelectiveBuilder.from_yaml_str(yaml_config_v1)
+
+ self.assertTrue(selector1.is_operator_selected("aten::add"))
+ self.assertTrue(selector1.is_operator_selected("aten::add.int"))
+ # Overload name is not used for checking in v1.
+ self.assertTrue(selector1.is_operator_selected("aten::add.float"))
+
+ def gen():
+ return SelectiveBuilder.from_yaml_str(yaml_config_invalid)
+
+ self.assertRaises(Exception, gen)
+
+ selector_all = SelectiveBuilder.from_yaml_str(yaml_config_all)
+
+ self.assertTrue(selector_all.is_operator_selected("aten::add"))
+ self.assertTrue(selector_all.is_operator_selected("aten::sub"))
+ self.assertTrue(selector_all.is_operator_selected("aten::sub.int"))
+ self.assertTrue(selector_all.is_kernel_dtype_selected("add_kernel", "int32"))
+
+ selector2 = SelectiveBuilder.from_yaml_str(yaml_config_v2)
+
+ self.assertFalse(selector2.is_operator_selected("aten::add"))
+ self.assertTrue(selector2.is_operator_selected("aten::sub"))
+ self.assertTrue(selector2.is_operator_selected("aten::sub.int"))
+
+ selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
+ ["aten::add", "aten::add.int", "aten::mul.int"],
+ False,
+ False,
+ )
+ self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add.float"))
+ self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add"))
+ self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add.int"))
+ self.assertFalse(selector_legacy_v1.is_operator_selected("aten::sub"))
+
+ self.assertFalse(selector_legacy_v1.is_root_operator("aten::add"))
+ self.assertFalse(
+ selector_legacy_v1.is_operator_selected_for_training("aten::add")
+ )
+
+ selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
+ ["aten::add", "aten::add.int", "aten::mul.int"],
+ True,
+ False,
+ )
+
+ self.assertTrue(selector_legacy_v1.is_root_operator("aten::add"))
+ self.assertFalse(
+ selector_legacy_v1.is_operator_selected_for_training("aten::add")
+ )
+ self.assertTrue(selector_legacy_v1.is_root_operator("aten::add.float"))
+ self.assertFalse(
+ selector_legacy_v1.is_operator_selected_for_training("aten::add.float")
+ )
+
+ selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
+ ["aten::add", "aten::add.int", "aten::mul.int"],
+ False,
+ True,
+ )
+
+ self.assertFalse(selector_legacy_v1.is_root_operator("aten::add"))
+ self.assertTrue(
+ selector_legacy_v1.is_operator_selected_for_training("aten::add")
+ )
+ self.assertFalse(selector_legacy_v1.is_root_operator("aten::add.float"))
+ self.assertTrue(
+ selector_legacy_v1.is_operator_selected_for_training("aten::add.float")
+ )
+
+ def test_operator_combine(self):
+ op1 = SelectiveBuildOperator(
+ "aten::add.int",
+ is_root_operator=True,
+ is_used_for_training=False,
+ include_all_overloads=False,
+ _debug_info=None,
+ )
+ op2 = SelectiveBuildOperator(
+ "aten::add.int",
+ is_root_operator=False,
+ is_used_for_training=False,
+ include_all_overloads=False,
+ _debug_info=None,
+ )
+ op3 = SelectiveBuildOperator(
+ "aten::add",
+ is_root_operator=True,
+ is_used_for_training=False,
+ include_all_overloads=False,
+ _debug_info=None,
+ )
+ op4 = SelectiveBuildOperator(
+ "aten::add.int",
+ is_root_operator=True,
+ is_used_for_training=True,
+ include_all_overloads=False,
+ _debug_info=None,
+ )
+
+ op5 = combine_operators(op1, op2)
+
+ self.assertTrue(op5.is_root_operator)
+ self.assertFalse(op5.is_used_for_training)
+
+ op6 = combine_operators(op1, op4)
+
+ self.assertTrue(op6.is_root_operator)
+ self.assertTrue(op6.is_used_for_training)
+
+ def gen_new_op():
+ return combine_operators(op1, op3)
+
+ self.assertRaises(Exception, gen_new_op)
+
+ def test_training_op_fetch(self):
+ yaml_config = """
+operators:
+ aten::add.int:
+ is_used_for_training: No
+ is_root_operator: Yes
+ include_all_overloads: No
+ aten::add:
+ is_used_for_training: Yes
+ is_root_operator: No
+ include_all_overloads: Yes
+"""
+
+ selector = SelectiveBuilder.from_yaml_str(yaml_config)
+ self.assertTrue(selector.is_operator_selected_for_training("aten::add.int"))
+ self.assertTrue(selector.is_operator_selected_for_training("aten::add"))
+
+ def test_kernel_dtypes(self):
+ yaml_config = """
+kernel_metadata:
+ add_kernel:
+ - int8
+ - int32
+ sub_kernel:
+ - int16
+ - int32
+ add/sub_kernel:
+ - float
+ - complex
+"""
+
+ selector = SelectiveBuilder.from_yaml_str(yaml_config)
+
+ self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
+ self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
+ self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "int16"))
+ self.assertFalse(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
+ self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "float"))
+
+ self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "float"))
+ self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "complex"))
+ self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16"))
+ self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32"))
+
+ def test_merge_kernel_dtypes(self):
+ yaml_config1 = """
+kernel_metadata:
+ add_kernel:
+ - int8
+ add/sub_kernel:
+ - float
+ - complex
+ - none
+ mul_kernel:
+ - int8
+"""
+
+ yaml_config2 = """
+kernel_metadata:
+ add_kernel:
+ - int32
+ sub_kernel:
+ - int16
+ - int32
+ add/sub_kernel:
+ - float
+ - complex
+"""
+
+ selector1 = SelectiveBuilder.from_yaml_str(yaml_config1)
+ selector2 = SelectiveBuilder.from_yaml_str(yaml_config2)
+
+ selector = combine_selective_builders(selector1, selector2)
+
+ self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
+ self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
+ self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "int16"))
+ self.assertFalse(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
+ self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "float"))
+
+ self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "float"))
+ self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "complex"))
+ self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "none"))
+ self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16"))
+ self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32"))
+
+ self.assertTrue(selector.is_kernel_dtype_selected("mul_kernel", "int8"))
+ self.assertFalse(selector.is_kernel_dtype_selected("mul_kernel", "int32"))
+
+ def test_all_kernel_dtypes_selected(self):
+ yaml_config = """
+include_all_non_op_selectives: True
+"""
+
+ selector = SelectiveBuilder.from_yaml_str(yaml_config)
+
+ self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
+ self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
+ self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int16"))
+ self.assertTrue(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
+ self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "float"))