| # mypy: allow-untyped-defs |
| import ast |
| import functools |
| import inspect |
| from textwrap import dedent |
| from typing import Any, List, NamedTuple, Optional, Tuple |
| |
| from torch._C import ErrorReport |
| from torch._C._jit_tree_views import SourceRangeFactory |
| |
| |
def get_source_lines_and_file(
    obj: Any,
    error_msg: Optional[str] = None,
) -> Tuple[List[str], int, Optional[str]]:
    """
    Look up the source of ``obj`` with ``inspect.getsourcefile`` and
    ``inspect.getsourcelines``.

    Args:
        obj: object (typically a function) whose source code is needed.
        error_msg: optional extra text appended, on its own line, to the
            error raised when the source cannot be retrieved.

    Returns:
        A ``(sourcelines, file_lineno, filename)`` tuple. ``filename`` is
        whatever ``inspect.getsourcefile`` returned and may be ``None``.

    Raises:
        OSError: when ``inspect`` cannot retrieve the source; the original
            ``OSError`` is chained as the ``__cause__``.
    """
    try:
        filename = inspect.getsourcefile(obj)
        sourcelines, file_lineno = inspect.getsourcelines(obj)
    except OSError as exc:
        parts = [
            f"Can't get source for {obj}. TorchScript requires source access in "
            "order to carry out compilation, make sure original .py files are "
            "available."
        ]
        if error_msg:
            parts.append(error_msg)
        raise OSError("\n".join(parts)) from exc
    return sourcelines, file_lineno, filename
| |
| |
def normalize_source_lines(sourcelines: List[str]) -> List[str]:
    """
    Align every line of a function's source to at least the indentation of
    its ``def`` line.

    The first line whose stripped text starts with the ``def`` keyword
    (``def`` followed by whitespace) is taken as the function definition.
    Its leading whitespace is then prepended to every other line, after
    removing that same whitespace when a line already starts with it. This
    lets comments and continued string literals that sit at a lower
    indentation than the rest of the code survive a later ``dedent``.

    Args:
        sourcelines: function source code, separated into lines by
            the '\\n' character

    Returns:
        A list of source lines that have been correctly aligned. If no
        ``def`` line is found (e.g. a lambda), ``sourcelines`` is returned
        unchanged.
    """

    def remove_prefix(text: str, prefix: str) -> str:
        # Same as str.removeprefix (3.9+), kept local for older interpreters.
        return text[len(prefix) :] if text.startswith(prefix) else text

    def is_def_line(line: str) -> bool:
        # Require whitespace after "def" so identifiers such as `default=`
        # or `deferred` (e.g. on a continuation line of a multi-line
        # decorator call) are not mistaken for the function definition.
        stripped = line.lstrip()
        return stripped.startswith("def") and stripped[3:4].isspace()

    # Find the line number containing the function definition
    idx = next((i for i, line in enumerate(sourcelines) if is_def_line(line)), None)

    # This will happen when the function is a lambda- we won't find "def" anywhere in the source
    # lines in that case. Currently trying to JIT compile a lambda will throw an error up in
    # `parse_def()`, but we might want to handle this case in the future.
    if idx is None:
        return sourcelines

    # Leading whitespace of the `def` line (everything before the keyword)
    fn_def = sourcelines[idx]
    whitespace = fn_def[: len(fn_def) - len(fn_def.lstrip())]

    # Add this leading whitespace to all lines before and after the `def`
    aligned_prefix = [
        whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]
    ]
    aligned_suffix = [
        whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1 :]
    ]

    # Put it together again; the `def` line itself is kept verbatim
    return aligned_prefix + [fn_def] + aligned_suffix
| |
| |
class SourceContext(SourceRangeFactory):
    """
    Thin wrapper around ``SourceRangeFactory`` that additionally stores
    metadata about the function being compiled.

    Args:
        source: source text forwarded to ``SourceRangeFactory``.
        filename: file the source came from (may be ``None``); forwarded to
            the base class and also kept as ``self.filename``.
        file_lineno: starting line number, forwarded to the base class.
        leading_whitespace_len: forwarded to the base class; ``parse_def``
            passes the indentation width removed by ``dedent``.
        uses_true_division: stored unchanged as ``self.uses_true_division``.
        funcname: stored unchanged as ``self.funcname``.
    """

    def __init__(
        self,
        source,
        filename,
        file_lineno,
        leading_whitespace_len,
        uses_true_division=True,
        funcname=None,
    ):
        # Initialize the native base first; only then attach Python attributes.
        super().__init__(source, filename, file_lineno, leading_whitespace_len)
        self.funcname = funcname
        self.filename = filename
        self.uses_true_division = uses_true_division
| |
| |
@functools.lru_cache(maxsize=None)
def make_source_context(*args):
    """
    Build a ``SourceContext`` from ``args``, memoized on those positional
    arguments.

    The cache is unbounded (``maxsize=None``), so repeated calls with equal,
    hashable arguments return the very same ``SourceContext`` object.
    """
    # NOTE(review): unbounded cache keeps every distinct source string alive
    # for the process lifetime; kept as-is because callers may rely on
    # receiving the identical object for identical arguments.
    context = SourceContext(*args)
    return context
| |
| |
def fake_range():
    """
    Return a placeholder range made by ``make_raw_range(0, 1)`` on a
    ``SourceContext`` with empty source and no filename.
    """
    placeholder_ctx = SourceContext("", None, 0, 0)
    return placeholder_ctx.make_raw_range(0, 1)
| |
| |
class ParsedDef(NamedTuple):
    """Bundle returned by ``parse_def``: a parsed function and its source metadata."""

    ast: ast.Module  # module produced by ``ast.parse`` on the dedented source
    ctx: SourceContext  # context obtained from ``make_source_context``
    source: str  # in parse_def: the normalized (not yet dedented) source text
    filename: Optional[str]  # from ``inspect.getsourcefile``; may be None
    file_lineno: int  # starting line number reported by ``inspect.getsourcelines``
| |
| |
def parse_def(fn):
    """
    Parse the source of ``fn`` into a Python AST plus a ``SourceContext``.

    Args:
        fn: function whose source code is retrieved through ``inspect``.

    Returns:
        A ``ParsedDef`` holding the module AST, the (cached) source context,
        the normalized source text, the filename and the starting line number.

    Raises:
        OSError: if the source of ``fn`` cannot be retrieved.
        SyntaxError: if the retrieved source does not parse.
        RuntimeError: if the parsed module is not exactly one top-level
            ``def`` statement (an ``async def`` does not qualify).
    """
    lines, start_line, path = get_source_lines_and_file(
        fn, ErrorReport.call_stack()
    )
    source = "".join(normalize_source_lines(lines))
    dedented = dedent(source)
    tree = ast.parse(dedented)

    body = tree.body
    if not (len(body) == 1 and isinstance(body[0], ast.FunctionDef)):
        raise RuntimeError(
            f"Expected a single top-level function: {path}:{start_line}"
        )

    # Indentation width removed by `dedent`, measured on the first line.
    first_original = source.split("\n", 1)[0]
    first_dedented = dedented.split("\n", 1)[0]
    indent_width = len(first_original) - len(first_dedented)

    # Arguments are passed positionally on purpose: they are the lru_cache key.
    ctx = make_source_context(
        source, path, start_line, indent_width, True, fn.__name__
    )
    return ParsedDef(tree, ctx, source, path, start_line)