blob: 7cfbb803f9b8bf8fa07b96fdcf47b07dab9d2c6b [file] [log] [blame]
#!/usr/bin/env python3
import os
from pathlib import Path
from torch.jit._decompositions import decomposition_table
# from torchgen.code_template import CodeTemplate
# Opening part of the generated C++ file: the "@generated" banner with the
# regeneration command, the #include lines, the torch::jit namespace openers,
# and the start of the `decomp_funcs` raw string literal. The literal is left
# open on purpose; the serialized decomposition source is appended after it.
DECOMP_HEADER = r"""
/**
* @generated
* This is an auto-generated file. Please do not modify it by hand.
* To re-generate, please run:
* cd ~/pytorch && python torchgen/decompositions/gen_jit_decompositions.py
*/
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/decomposition_registry_util.h>
namespace torch {
namespace jit {
const std::string decomp_funcs =
R"("""
# Middle part of the generated C++ file: closes the `decomp_funcs` raw string
# literal, defines GetSerializedDecompositions(), and opens
# GetDecompositionMapping() up to the start of the `decomposition_mapping`
# initializer list. The mapping entries are appended after it.
DECOMP_CENTER = r"""
)";
const std::string& GetSerializedDecompositions() {
return decomp_funcs;
}
const OperatorMap<std::string>& GetDecompositionMapping() {
// clang-format off
static const OperatorMap<std::string> decomposition_mapping {
"""
# Closing part of the generated C++ file: ends the initializer list, returns
# the mapping, and closes both namespaces.
DECOMP_END = r"""
};
// clang-format on
return decomposition_mapping;
}
} // namespace jit
} // namespace torch
"""
# Base name of the generated C++ source file.
DECOMPOSITION_UTIL_FILE_NAME = "decomposition_registry_util.cpp"
def gen_serialized_decompisitions() -> str:
    """Return the source of every registered decomposition as one string.

    Reads the ``code`` attribute of each value in ``decomposition_table``,
    in the table's iteration order, and joins the pieces with newlines.
    An empty table yields an empty string.
    """
    registered_funcs = decomposition_table.values()
    return "\n".join(fn.code for fn in registered_funcs)
def gen_decomposition_mappings() -> str:
    """Render ``decomposition_table`` as C++ initializer-list entries.

    Every ``(schema, function)`` pair in the table becomes one line made of
    the schema key and the function's ``name`` attribute, each wrapped in
    double quotes inside a brace pair with a trailing comma. The lines are
    joined with newlines in the table's iteration order.
    """
    entries = (
        "".join((' {"', schema, '", "', fn.name, '"},'))
        for schema, fn in decomposition_table.items()
    )
    return "\n".join(entries)
def write_decomposition_util_file(path: str) -> None:
    """Generate ``decomposition_registry_util.cpp`` inside directory *path*.

    The file content is the concatenation of ``DECOMP_HEADER``, the
    serialized decomposition sources, ``DECOMP_CENTER``, the schema-to-name
    mapping entries and ``DECOMP_END``, written as UTF-8 bytes.

    Args:
        path: Directory to write into. Anything ``os.path.join`` accepts
            works (``str`` or ``os.PathLike``).

    Raises:
        OSError: If the output file cannot be opened or written, e.g.
            because *path* does not exist.
    """
    file_components = [
        DECOMP_HEADER,
        gen_serialized_decompisitions(),
        DECOMP_CENTER,
        gen_decomposition_mappings(),
        DECOMP_END,
    ]
    # Assemble and encode the content before touching the file system, so a
    # failure while generating it cannot leave a truncated output file.
    encoded_output = "".join(file_components).encode("utf-8")

    # Build the destination once so the logged path is exactly the path that
    # gets opened (manual "/" concatenation diverges from os.path.join for
    # empty or slash-terminated paths and fails for os.PathLike input).
    out_path = os.path.join(path, DECOMPOSITION_UTIL_FILE_NAME)
    print("writing file to : ", out_path)

    # Binary mode writes the bytes verbatim, without newline translation.
    with open(out_path, "wb") as out_file:
        out_file.write(encoded_output)
def main() -> None:
    """Regenerate the decomposition registry file in ``torch/csrc/jit/runtime``.

    The repository root is located by walking up from this script and taking
    the nearest ancestor that contains ``torch/csrc/jit/runtime``. If no
    ancestor has that directory, the historical fixed choice (the fourth
    ancestor of this file) is used, so the resulting error is the same as
    before.
    """
    runtime_subdir = Path("torch") / "csrc" / "jit" / "runtime"
    script_path = Path(__file__).resolve()

    # NOTE(review): the regeneration command in the generated header places
    # this script at <repo>/torchgen/decompositions/, which would make the
    # repo root parents[2] rather than the previously hard-coded parents[3].
    # Probing the ancestors nearest-first works for either layout.
    for candidate in script_path.parents:
        if (candidate / runtime_subdir).is_dir():
            pytorch_dir = candidate
            break
    else:
        pytorch_dir = script_path.parents[3]

    upgrader_path = pytorch_dir / runtime_subdir
    write_decomposition_util_file(str(upgrader_path))
if __name__ == "__main__":
main()