| import inspect |
| import textwrap |
| |
| import torch.jit |
| from torch.jit._builtins import _find_builtin |
| |
# This file is for generating documentation using Sphinx autodoc.
# > help(torch.jit.supported_ops) will also give a nice list of the
# supported ops programmatically.
| |
| |
| def _hidden(name): |
| return name.startswith("_") and not name.startswith("__") |
| |
| |
| def _emit_type(type): |
| return str(type) |
| |
| |
def _emit_arg(indent, i, arg):
    """Format one schema argument as ``name : type[=default]``.

    Args:
        indent: Number of spaces placed in front of every argument except
            the first one.
        i: Zero-based position of the argument in the emitted list.
        arg: Object exposing ``name``, ``type`` and ``default_value``.

    Returns:
        The formatted argument. Arguments after the first start with a
        newline followed by ``indent`` spaces so they line up under the
        first argument.
    """
    pieces = [f"{arg.name} : {_emit_type(arg.type)}"]
    fallback = arg.default_value
    if fallback is not None:
        pieces.append(f"={str(fallback)}")
    text = "".join(pieces)
    if i > 0:
        # Continuation arguments go on their own, indented line.
        return f"\n{' ' * indent}{text}"
    return text
| |
| |
def _emit_args(indent, arguments):
    """Format a sequence of schema arguments as one comma-separated string.

    Args:
        indent: Indentation (in spaces) forwarded to ``_emit_arg`` for every
            argument after the first.
        arguments: Iterable of schema argument objects.

    Returns:
        The formatted arguments joined with ``","``; an empty string when
        ``arguments`` is empty.
    """
    rendered = []
    for position, argument in enumerate(arguments):
        rendered.append(_emit_arg(indent, position, argument))
    return ",".join(rendered)
| |
| |
def _emit_ret(ret):
    """Format a single schema return value.

    Args:
        ret: Object exposing a ``type`` attribute.

    Returns:
        The rendered type of ``ret``.
    """
    ret_type = ret.type
    return _emit_type(ret_type)
| |
| |
def _emit_rets(returns):
    """Format the return annotation of a schema.

    Args:
        returns: Sequence of schema return objects.

    Returns:
        The rendered type when there is exactly one return value, otherwise
        ``Tuple[...]`` wrapping the comma-separated rendered types (which is
        ``Tuple[]`` for an empty sequence).
    """
    if len(returns) != 1:
        joined = ", ".join(_emit_ret(item) for item in returns)
        return f"Tuple[{joined}]"
    return _emit_ret(returns[0])
| |
| |
def _emit_schema(mod, name, schema, arg_start=0, padding=4):
    """Render an operator schema as a Python-like signature string.

    Args:
        mod: Module (or class) name used as the qualifier, or ``None`` to
            emit ``name`` unqualified.
        name: Function or method name.
        schema: Object exposing ``arguments`` and ``returns``.
        arg_start: Index of the first argument to emit; earlier arguments
            are dropped (e.g. ``1`` to skip ``self``).
        padding: Extra spaces added to the hanging indent of continuation
            arguments.

    Returns:
        A string of the form ``qualified_name(args) -> returns``.
    """
    qualified_name = name if mod is None else f"{mod}.{name}"
    # Continuation arguments are indented past the opening parenthesis
    # (the ``+ 1``) plus the requested padding.
    hanging_indent = len(qualified_name) + 1 + padding
    args_str = _emit_args(hanging_indent, schema.arguments[arg_start:])
    rets_str = _emit_rets(schema.returns)
    return f"{qualified_name}({args_str}) -> {rets_str}"
| |
| |
def _get_tensor_ops():
    """Collect signatures of ``torch.Tensor`` methods that have ``aten`` schemas.

    Returns:
        A ``(header, signatures)`` tuple where ``header`` is
        ``"Supported Tensor Methods"`` and ``signatures`` is a list of
        strings produced by ``_emit_schema`` with ``self`` omitted.
    """

    def is_tensor_method(schema):
        # A method schema has a first argument named ``self`` whose type is
        # a subtype of Tensor.
        if not schema.arguments:
            return False
        receiver = schema.arguments[0]
        if receiver.name != "self":
            return False
        return bool(receiver.type.isSubtypeOf(torch._C.TensorType.get()))

    methods = []
    # Discover methods: every non-hidden attribute of torch.Tensor is looked
    # up as an ``aten::<name>`` operator.
    visible_names = (attr for attr in dir(torch.Tensor) if not _hidden(attr))
    for attr in visible_names:
        for schema in torch._C._jit_get_schemas_for_operator("aten::" + attr):
            if is_tensor_method(schema):
                methods.append(_emit_schema("Tensor", attr, schema, arg_start=1))

    return "Supported Tensor Methods", methods
| |
| |
def _get_nn_functional_ops():
    """Collect signatures for ``torch.nn.functional`` and builtin-bearing modules.

    Two sources are scanned:

    1. Functions defined in ``torch.nn.functional`` are compiled with
       ``torch.jit.script`` and their schema is emitted. Functions that fail
       to compile are skipped.
    2. For every module in ``torch.jit._builtins._modules_containing_builtins``,
       each non-hidden attribute that ``_find_builtin`` recognizes
       contributes one entry per registered operator schema.

    Returns:
        A ``(header, signatures)`` tuple where ``header`` is
        ``"Supported PyTorch Functions"`` and ``signatures`` is a list of
        strings produced by ``_emit_schema``.

    Raises:
        RuntimeError: If ``inspect.getmodule`` cannot resolve the module of
            a ``torch.nn.functional`` function.
    """
    functions = []

    # Iterate over torch.nn.functional
    mod = torch.nn.functional
    name = mod.__name__
    for elem in dir(mod):
        attr = getattr(mod, elem)
        # Ignore non-functions and internal methods. Only the first character
        # is passed to ``_hidden``, so any name starting with an underscore
        # (dunder names included) is skipped here.
        if not inspect.isfunction(attr) or _hidden(elem[0]):
            continue

        attr_module = inspect.getmodule(attr)
        if not attr_module:
            raise RuntimeError(f"Module for {attr} not found")

        if "torch.nn.functional" not in attr_module.__name__:
            # Ignore functions from outside torch.nn.functional
            continue

        try:
            # compile fn, get schema
            signature = _emit_schema(name, elem, torch.jit.script(attr).schema)
        except Exception:  # noqa: BLE001
            # Skip interpolate / boolean dispatched things. Catching
            # ``Exception`` (rather than a bare ``except``) keeps this
            # best-effort while letting KeyboardInterrupt / SystemExit
            # propagate.
            continue
        functions.append(signature)

    # Iterate over modules that we know contain a lot of builtins
    for mod in torch.jit._builtins._modules_containing_builtins:
        name = mod.__name__
        for elem in dir(mod):
            # remove _tan but not __and__; checked once per attribute
            # instead of once per schema.
            if _hidden(elem):
                continue
            builtin = _find_builtin(getattr(mod, elem))
            if builtin is None:
                continue
            for schema in torch._C._jit_get_schemas_for_operator(builtin):
                functions.append(_emit_schema(name, elem, schema))
    return "Supported PyTorch Functions", functions
| |
| |
def _get_builtins_helper():
    """Filter ``torch.jit._builtins._builtin_ops`` down to documentable entries.

    An entry is dropped when the callable has no ``__name__``, when
    ``inspect.getmodule`` cannot resolve its module, when its name,
    qualified name or module name is hidden (see ``_hidden``), or when its
    module name contains ``"torch._C"``.

    Returns:
        A list of ``(fn, builtin_name)`` tuples that passed every filter, in
        their original order.
    """
    kept = []
    for fn, op_name in torch.jit._builtins._builtin_ops:
        owner = inspect.getmodule(fn)

        # typing classes have no __name__; entries without a resolvable
        # module are skipped as well.
        if not hasattr(fn, "__name__") or not owner:
            continue

        # skip internal-only methods
        is_internal = (
            _hidden(fn.__name__) or _hidden(fn.__qualname__) or _hidden(owner.__name__)
        )
        if is_internal or "torch._C" in owner.__name__:
            continue

        kept.append((fn, op_name))

    return kept
| |
| |
| def _is_math_fn(fn): |
| mod = inspect.getmodule(fn) |
| if not mod: |
| raise RuntimeError(f"Module for {fn} not found") |
| |
| return mod.__name__ == "math" |
| |
| |
def _get_torchscript_builtins():
    """Collect signatures of the specially registered non-``math`` builtins.

    Returns:
        A ``(header, signatures)`` tuple where ``header`` is
        ``"TorchScript Builtin Functions"`` and ``signatures`` is a list of
        strings produced by ``_emit_schema``.

    Raises:
        RuntimeError: If the module of a builtin cannot be resolved.
    """
    functions = []
    # Materialize the filtered list up front so module-resolution errors
    # surface before any signature is emitted.
    non_math = [entry for entry in _get_builtins_helper() if not _is_math_fn(entry[0])]
    # Iterate over the specially added builtins
    for fn, _builtin_name in non_math:
        mod = inspect.getmodule(fn)
        if not mod:
            raise RuntimeError(f"Module for {fn} not found")
        builtin = _find_builtin(fn)
        if builtin is None:
            continue
        functions.extend(
            _emit_schema(mod.__name__, fn.__name__, schema)
            for schema in torch._C._jit_get_schemas_for_operator(builtin)
        )

    return "TorchScript Builtin Functions", functions
| |
| |
def _get_math_builtins():
    """Collect signatures of the supported functions from the ``math`` module.

    Schemas whose rendered signature mentions ``Tensor`` are skipped: those
    are Tensor ops that merely share a name with a ``math`` function.

    Returns:
        A ``(header, signatures)`` tuple where ``header`` is
        ``"``math`` Module"`` and ``signatures`` is a list of strings
        produced by ``_emit_schema``.

    Raises:
        RuntimeError: If the module of a builtin cannot be resolved.
    """
    functions = []
    math_entries = [entry for entry in _get_builtins_helper() if _is_math_fn(entry[0])]
    # Iterate over the specially added builtins
    for fn, _builtin_name in math_entries:
        mod = inspect.getmodule(fn)
        if not mod:
            raise RuntimeError(f"Module for {fn} not found")
        builtin = _find_builtin(fn)
        if builtin is None:
            continue
        for schema in torch._C._jit_get_schemas_for_operator(builtin):
            schema_str = _emit_schema(mod.__name__, fn.__name__, schema)
            if "Tensor" in schema_str:
                # Skip Tensor ops that have the same name as math functions
                # (they will show up in the tensor methods section)
                continue
            # Append the rendered signature, not the raw schema object, so
            # this section lists strings in the same ``module.name(...)``
            # form as the other sections.
            functions.append(schema_str)

    return "``math`` Module", functions
| |
| |
def _get_global_builtins():
    """Build the reST section describing supported Python built-in functions.

    The section contains three parts: a CSV table of built-ins that have no
    operator schema (with an explanatory note each), a CSV table mapping
    built-ins to the magic method they use, and a literal block listing the
    schemas of the remaining built-ins.

    Returns:
        A ``(header, section)`` tuple where ``header`` is
        ``"Python Built-in Functions"`` and ``section`` is a single reST
        string (unlike the other gatherers, which return a list).

    Raises:
        KeyError: If a built-in has no registered schema and no entry in
            ``schemaless_op_explanations``.
    """
    # Taken from the 'globals' map in torch/csrc/jit/frontend/ir_emitter.cpp
    supported_builtins = [
        "print",
        "tuple",
        "float",
        "complex",
        "int",
        "bool",
        "str",
        "getattr",
        "hasattr",
        "isinstance",
        "len",
        "hex",
        "oct",
        "round",
        "hash",
        "min",
        "max",
        "abs",
        "all",
        "divmod",
        "list",
        "ord",
        "chr",
        "bin",
        "range",
        "zip",
        "enumerate",
        "sorted",
    ]

    # Built-ins whose operator is looked up under a name other than the
    # default ``aten::<name>``.
    # NOTE(review): "range" maps to a name that presumably has no registered
    # schemas, which would route it to the schemaless table - confirm.
    op_renames = {
        "bool": "aten::Bool",
        "int": "aten::Int",
        "float": "aten::Float",
        "complex": "aten::Complex",
        "abs": "prim::abs",
        "max": "prim::max",
        "min": "prim::min",
        "range": "fake::does_not_exist",
    }

    # Notes shown for built-ins for which no schema is found.
    schemaless_op_explanations = {
        "print": "Print any value",
        "tuple": "Lists cannot be converted to tuples with this method since their size is not statically known",
        "getattr": "Attribute name must be a literal string",
        "hasattr": "Attribute name must be a literal string",
        "isinstance": "Result is static",
        "zip": "Arguments must be iterable. See :ref:`Iterables <jit_iterables>` for details.",
        "enumerate": "Arguments must be iterable. See :ref:`Iterables <jit_iterables>` for details.",
        "range": "Can only be used as an iterator in a for loop",
    }

    # (built-in, magic method) pairs rendered in the second table.
    magic_methods = [
        ("complex", "__complex__"),
        ("float", "__float__"),
        ("int", "__int__"),
        ("bool", "__bool__"),
        ("str", "__str__"),
        ("len", "__len__"),
        ("hex", "__hex__"),
        ("oct", "__oct__"),
    ]

    # One CSV row per pair; the magic method is wrapped in reST inline-code
    # markup.
    magic_methods_rows = []
    for fn, magic_method in magic_methods:
        magic_methods_rows.append(f'"{fn}", "``{magic_method}``"')

    schematized_ops = []
    schemaless_ops = []

    for fn in supported_builtins:
        # Default operator name, overridden by ``op_renames`` when present.
        op_name = f"aten::{fn}"
        if fn in op_renames:
            op_name = op_renames[fn]
        schemas = torch._C._jit_get_schemas_for_operator(op_name)
        for s in schemas:
            schematized_ops.append(_emit_schema(None, fn, s, padding=0))
        if len(schemas) > 0:
            # Empty entry: becomes a blank line between groups once joined.
            schematized_ops.append("")
        else:
            # No schema found: emit a CSV row with the explanatory note.
            table_row = f'":any:`{fn}`", "{schemaless_op_explanations[fn]}"'
            schemaless_ops.append(table_row)

    schematized_ops_str = "\n".join(schematized_ops)
    schemaless_ops_str = "\n".join(schemaless_ops)
    magic_methods_rows_str = "\n".join(magic_methods_rows)
    # Prefix every non-blank line with a tab before interpolating the blocks
    # into the section text below.
    schematized_ops_str = textwrap.indent(schematized_ops_str, "\t")
    schemaless_ops_str = textwrap.indent(schemaless_ops_str, "\t")
    magic_methods_rows_str = textwrap.indent(magic_methods_rows_str, "\t")
    # NOTE(review): whitespace inside this literal is significant to reST -
    # the ``:header:`` option lines must be indented under their
    # ``csv-table`` directive. Confirm the literal's exact indentation
    # against the rendered documentation.
    section = f"""
The functions in the following table are supported but do not have a static schema

.. csv-table::
    :header: "Function", "Note"

{schemaless_ops_str}

The following functions will use the corresponding magic method on :any:`TorchScript classes`

.. csv-table::
    :header: "Function", "Magic Method"

{magic_methods_rows_str}

These built-in functions use the schema

.. rst-class:: codeblock-height-limiter

::

{schematized_ops_str}
    """

    return "Python Built-in Functions", section
| |
| |
def _list_supported_ops():
    """Assemble the complete reST listing of supported operators.

    Every gatherer returns a ``(header, items)`` pair. ``items`` is either a
    ready-made reST string, which is inserted as-is, or a list of
    declarations, which is rendered as a literal block. Each section is
    preceded by a link target derived from its header.

    Returns:
        The concatenated reST text of all sections.
    """

    def emit_block(decls):
        # The leading space puts each declaration inside the ``::`` literal
        # block.
        entries = "".join(f" {d}\n\n" for d in decls)
        return f"\n.. rst-class:: codeblock-height-limiter\n\n::\n\n{entries}\n"

    op_gathering_fns = (
        _get_tensor_ops,
        _get_nn_functional_ops,
        _get_torchscript_builtins,
        _get_global_builtins,
        _get_math_builtins,
    )

    sections = []
    for gather in op_gathering_fns:
        header, items = gather()
        # Link target: header without backticks and hyphens, lower-cased,
        # with spaces turned into hyphens.
        link_target = header.replace("`", "").replace("-", "").lower().replace(" ", "-")
        underline = "~" * len(header)
        if isinstance(items, str):
            content = f"{items}\n"
        else:
            content = emit_block(items)
        sections.append(f".. _{link_target}:\n\n{header}\n{underline}\n{content}")

    return "".join(sections)
| |
| |
# Generate the module docstring at import time so that Sphinx autodoc and
# ``help(torch.jit.supported_ops)`` both display the supported-ops listing.
__doc__ = _list_supported_ops()