| import os |
| from collections import OrderedDict |
| from pathlib import Path |
| |
| import torch |
| import torch._prims as prims |
| from torchgen.gen import parse_native_yaml |
| |
| |
| ROOT = Path(__file__).absolute().parent.parent.parent.parent |
| NATIVE_FUNCTION_YAML_PATH = ROOT / Path("aten/src/ATen/native/native_functions.yaml") |
| TAGS_YAML_PATH = ROOT / Path("aten/src/ATen/native/tags.yaml") |
| |
| BUILD_DIR = "build/ir" |
| ATEN_OPS_CSV_FILE = "aten_ops.csv" |
| PRIMS_OPS_CSV_FILE = "prims_ops.csv" |
| |
| |
| def get_aten(): |
| parsed_yaml = parse_native_yaml(NATIVE_FUNCTION_YAML_PATH, TAGS_YAML_PATH) |
| native_functions = parsed_yaml.native_functions |
| |
| aten_ops = OrderedDict() |
| for function in native_functions: |
| if "core" in function.tags: |
| op_name = str(function.func.name) |
| aten_ops[op_name] = function |
| |
| op_schema_pairs = [] |
| for key, op in sorted(aten_ops.items()): |
| op_name = f"aten.{key}" |
| schema = str(op.func).replace("*", r"\*") |
| |
| op_schema_pairs.append((op_name, schema)) |
| |
| return op_schema_pairs |
| |
| |
| def get_prims(): |
| op_schema_pairs = [] |
| for op_name in prims.__all__: |
| op_overload = getattr(prims, op_name, None) |
| |
| if not isinstance(op_overload, torch._ops.OpOverload): |
| continue |
| |
| op_overloadpacket = op_overload.overloadpacket |
| |
| op_name = str(op_overload).replace(".default", "") |
| schema = op_overloadpacket.schema.replace("*", r"\*") |
| |
| op_schema_pairs.append((op_name, schema)) |
| |
| return op_schema_pairs |
| |
| |
| def main(): |
| aten_ops_list = get_aten() |
| prims_ops_list = get_prims() |
| |
| os.makedirs(BUILD_DIR, exist_ok=True) |
| |
| with open(os.path.join(BUILD_DIR, ATEN_OPS_CSV_FILE), "w") as f: |
| f.write("Operator,Schema\n") |
| for name, schema in aten_ops_list: |
| f.write(f'"``{name}``","{schema}"\n') |
| |
| with open(os.path.join(BUILD_DIR, PRIMS_OPS_CSV_FILE), "w") as f: |
| f.write("Operator,Schema\n") |
| for name, schema in prims_ops_list: |
| f.write(f'"``{name}``","{schema}"\n') |
| |
| |
| if __name__ == "__main__": |
| main() |