| """Tracing |
| |
| This module contains functionality to support the JIT's tracing frontend, notably: |
| * torch.jit.trace |
| * torch.jit.trace_module |
| |
This is not intended to be imported directly; please use the
functionality exposed in `torch.jit`.
| """ |
| import torch |
| |
| import copy |
| import os |
| import contextlib |
| import functools |
| import warnings |
| import inspect |
| import re |
| from typing import Any, Dict, List, Optional, Set |
| |
| from torch.jit._state import _python_cu, _enabled |
| from torch.jit._script import ScriptModule, _CachedForward, script |
| from torch._jit_internal import _qualified_name, is_scripting, get_callable_argument_names |
| from torch.autograd import function |
| from torch.nn import Module |
| |
| from torch.testing._comparison import default_tolerances |
| |
| _flatten = torch._C._jit_flatten |
| _unflatten = torch._C._jit_unflatten |
| |
| |
| def _create_interpreter_name_lookup_fn(frames_up=1): |
| def _get_interpreter_name_for_var(var): |
| frame = inspect.currentframe() |
| if not frame: |
| raise RuntimeError("failed to inspect frame") |
| |
| i = 0 |
| while i < frames_up + 1: |
| frame = frame.f_back |
| if not frame: |
| raise RuntimeError("failed to get frame") |
| i += 1 |
| |
        f_locals = frame.f_locals
| |
| for k, v in f_locals.items(): |
| if isinstance(v, torch.Tensor) and var is v: |
| return k if k != "self" else "" |
| return "" |
| |
| return _get_interpreter_name_for_var |
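
# Sketch of what the returned lookup does (hypothetical direct call; in
# practice the tracer invokes it from C++ while recording a graph):
#
#     x = torch.randn(3)
#     name_of = _create_interpreter_name_lookup_fn(frames_up=0)
#     name_of(x)  # returns "x", the caller-frame local bound to the tensor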
| |
| |
| def _unique_state_dict(module, keep_vars=False): |
    # Parameter.detach() always creates a new torch.Tensor instance, so ids of
    # detached values can't be used for deduplication. Instead we fetch the
    # state dict with keep_vars=True (so values are the original Parameters
    # and Buffers), dedupe by id, and detach afterwards if requested.
| state_dict = module.state_dict(keep_vars=True) |
| filtered_dict = type(state_dict)() |
| seen_ids: Set[int] = set() |
| for k, v in state_dict.items(): |
| if id(v) in seen_ids: |
| continue |
| seen_ids.add(id(v)) |
| if keep_vars: |
| filtered_dict[k] = v |
| else: |
| filtered_dict[k] = v.detach() |
| return filtered_dict |
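
# Sketch of the dedup behavior for tied weights (hypothetical module that
# registers the same submodule twice):
#
#     lin = torch.nn.Linear(2, 2)
#     tied = torch.nn.Sequential(lin, lin)
#     sorted(_unique_state_dict(tied))  # ['0.bias', '0.weight']; '1.*' deduped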
| |
| |
| class ONNXTracedModule(torch.nn.Module): |
| def __init__( |
| self, |
| inner, |
| strict=True, |
| force_outplace=False, |
| return_inputs=False, |
| return_inputs_states=False, |
| ): |
| super(ONNXTracedModule, self).__init__() |
        # inner may be a Module, or it may be an arbitrary callable.
        # If it's a Module, its parameters are gathered automatically, which
        # lets us avoid special-casing functions versus modules.
| self.inner = inner |
| self.strict = strict |
| self._force_outplace = force_outplace |
| self._return_inputs = return_inputs |
| self._return_inputs_states = return_inputs_states |
| |
| def forward(self, *args: torch.Tensor): |
| in_vars, in_desc = _flatten(args) |
| # NOTE: use full state, because we need it for BatchNorm export |
| # This differs from the compiler path, which doesn't support it at the moment. |
| module_state = list(_unique_state_dict(self, keep_vars=True).values()) |
| |
| ret_inputs = [] |
| inputs_states = [] |
| outs = [] |
| |
| def wrapper(*args): |
| in_args: List[torch.Tensor] = [] |
| for i in range(len(in_vars)): |
| if not isinstance(args[i], torch.Tensor): |
| raise RuntimeError('Expected Tensor argument') |
| in_args.append(args[i]) |
| |
| trace_inputs = _unflatten(in_args, in_desc) |
| |
| ret_inputs.append( |
| tuple(x.clone(memory_format=torch.preserve_format) for x in args) |
| ) |
            if self._return_inputs_states:
                inputs_states.append(_unflatten(in_args, in_desc))
            outs.append(self.inner(*trace_inputs))
            if self._return_inputs_states:
                # Pair the input state captured before running `inner` with the
                # (possibly mutated) inputs observed after it ran.
                inputs_states[0] = (inputs_states[0], trace_inputs)
| out_vars, _ = _flatten(outs) |
| if len(out_vars) == 1: |
| return out_vars[0] |
| else: |
| return tuple(out_vars) |
| |
| graph, out = torch._C._create_graph_by_tracing( |
| wrapper, |
| in_vars + module_state, |
| _create_interpreter_name_lookup_fn(), |
| self.strict, |
| self._force_outplace, |
| ) |
| |
| if self._return_inputs: |
| return graph, outs[0], ret_inputs[0] |
| if self._return_inputs_states: |
| return graph, outs[0], inputs_states[0] |
| else: |
| return graph, outs[0] |
| |
| |
| def _clone_inputs(args): |
| def clone_input(a): |
| if a is None: |
| return None |
| elif isinstance(a, torch.Tensor): |
| # TODO: figure out one liner to .clone() and set requires_grad |
| v = ( |
| a.detach() |
| .clone(memory_format=None if a.is_mkldnn else torch.preserve_format) |
| .requires_grad_(a.requires_grad) |
| ) |
| if a.grad is not None: |
| v.grad = clone_input(v.grad) |
| return v |
| else: |
| return a.clone(memory_format=torch.preserve_format) |
| |
| return function._nested_map( |
| lambda x: isinstance(x, torch.Tensor), clone_input, condition_msg="tensors" |
| )(args) |
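
# Sketch: clones are detached from the original autograd history but preserve
# requires_grad, so the traced run builds its own independent graph:
#
#     x = torch.randn(3, requires_grad=True)
#     (y,) = _clone_inputs((x,))
#     assert y is not x and y.requires_grad and y.grad_fn is None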
| |
| |
| # This is purely for developer debugging. We are not going to advertise it. |
| _JIT_TIME = os.environ.get("PYTORCH_JIT_TIME", False) # CUDA-only timing |
| _JIT_DISABLE = os.environ.get("PYTORCH_JIT_DISABLE", False) |
| _JIT_STATS = os.environ.get("PYTORCH_JIT_STATS", False) |
| |
| |
| @contextlib.contextmanager |
| def _time(trace_name, name, time=True): |
| if (not _JIT_TIME and not time) or not torch.cuda.is_available(): |
| yield |
| return |
| stream = torch.cuda.current_stream() |
| start = torch.cuda.Event(enable_timing=True) |
| end = torch.cuda.Event(enable_timing=True) |
| stream.record_event(start) |
| try: |
| yield |
| finally: |
| stream.record_event(end) |
| end.synchronize() |
| print("{} {} time: {} ms".format(trace_name, name, start.elapsed_time(end))) |
| |
| |
| def verify(model, args, loss_fn=torch.sum, devices=None): |
| """ |
    Verify that a JIT compiled model has the same behavior as its uncompiled
    version, including the backwards pass. If your model returns multiple
    outputs, you must also specify a `loss_fn` to produce a loss for which
    the backwards pass will be computed.
| |
| This function has side-effects (e.g., it executes your model / saves and loads |
| parameters), so don't expect the model to come out exactly the same as what |
| you passed in. |
| |
| Args: |
| model (compiled torch.nn.Module or function): the module/function to be |
| verified. The module/function definition MUST have been decorated with |
| `@torch.jit.compile`. |
| args (tuple or Tensor): the positional arguments to pass to the |
| compiled function/module to be verified. A non-tuple is assumed to |
| be a single positional argument to be passed to the model. |
        loss_fn (function, optional): the loss function to be applied to
            the output of the model, before backwards is invoked. By default,
            we assume that a model returns a single result, and apply
            :func:`torch.sum` to it before calling backwards; if this is
            inappropriate, you can pass your own loss function. Note that if
            a model returns a tuple of results, these are passed as separate
            positional arguments to `loss_fn`.
| devices (iterable of device IDs, optional): the GPU devices which the |
| compiled module will be run on. This determines the RNG state we |
| must save when running both compiled and uncompiled versions of the model. |
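
    Example (sketch; assumes a function compiled with the legacy
    ``@torch.jit.compile`` decorator)::

        @torch.jit.compile
        def doubler(x):
            return x + x

        inputs = (torch.randn(2, 2, requires_grad=True),)
        torch.jit.verify(doubler, inputs, devices=[])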
| """ |
| # TODO: In principle, we track device information in our trace, so it |
| # should be possible to check if our execution actually obeyed the 'devices' |
| # the user provided. |
| |
| # TODO: Consider adding a utility function to torch.jit to test |
| # for this case |
| if not isinstance(model, torch._C.CompiledFunction): # type: ignore[attr-defined] |
| raise TypeError( |
| "Cannot verify an uncompiled module. Add @torch.jit.compile to compile it" |
| ) |
| is_module = isinstance(model, Module) |
| |
| if not isinstance(args, tuple): |
| args = (args,) |
| |
| saved_args = _clone_inputs(args) |
| if is_module: |
| saved_state = copy.deepcopy(model.state_dict()) |
| |
| def run_fwd_bwd(args, force_trace=False, assert_compiled=False): |
| params = list(model.parameters()) if is_module else [] |
| in_vars, _ = _flatten((args, params)) |
| # We use a special API to reset the trace and compile it from scratch. |
| compiled_fn = model |
| if force_trace: |
| compiled_fn.clear_cache() |
| if assert_compiled: |
| hits = compiled_fn.hits |
| out = model(*args) |
| if assert_compiled and compiled_fn.hits == hits: |
| raise RuntimeError("failed to use the compiled function") |
| if not isinstance(out, tuple): |
| out = (out,) |
| if loss_fn == torch.sum and len(out) != 1: |
| raise ValueError( |
| ( |
| "Model returns {} outputs, but default loss function " |
| "(torch.sum) can only handle a single output" |
| ).format(len(out)) |
| ) |
| out_vars, _ = _flatten(out) |
| saved_outs = [ |
| v.detach().clone(memory_format=torch.preserve_format) for v in out_vars |
| ] |
| loss = loss_fn(*out) |
| grads = torch.autograd.grad([loss], in_vars) |
| # TODO: I'm not sure if the clone here is necessary but it is safer |
| saved_grads = [ |
| v.detach().clone(memory_format=torch.preserve_format) for v in grads |
| ] |
| return (saved_outs, saved_grads) |
| |
| with torch.random.fork_rng(devices, _caller="torch.jit.verify"): |
| uncompiled_outs, uncompiled_grads = run_fwd_bwd(args, force_trace=True) |
| assert model.has_trace_for(*args) |
| |
| if is_module: |
| model.load_state_dict(saved_state) |
| compiled_outs, compiled_grads = run_fwd_bwd(args, assert_compiled=True) |
| |
| _verify_equal(uncompiled_outs, compiled_outs) |
| _verify_equal(uncompiled_grads, compiled_grads) |
| |
| |
| def _verify_equal(xs, ys): |
| for x, y in zip(xs, ys): |
| if x.sub(y).abs().max() > 1e-6: |
| raise RuntimeError("JIT and real computation mismatch") |
| |
| |
| def indent(s): |
| return "\n".join(["\t" + line for line in s.splitlines()]) |
| |
| |
| class TracingCheckError(Exception): |
| def __init__(self, graph_diff_error, tensor_compare_error, extra_msg=None): |
| self.message = "Tracing failed sanity checks!\n" |
| if extra_msg is not None: |
| self.message += extra_msg + "\n" |
| if graph_diff_error is not None: |
| self.message += "ERROR: Graphs differed across invocations!\n" |
| self.message += indent(graph_diff_error) + "\n" |
| if tensor_compare_error is not None: |
| self.message += ( |
| "ERROR: Tensor-valued Constant nodes differed in value " |
| "across invocations. This often indicates that the tracer has" |
| " encountered untraceable code.\n" |
| ) |
| self.message += indent(tensor_compare_error) + "\n" |
| super(TracingCheckError, self).__init__(self.message) |
| |
| |
| # Check the traced module against a set of user-provided validation inputs |
| @torch.no_grad() |
| def _check_trace( |
| check_inputs, |
| func, |
| traced_func, |
| check_tolerance, |
| strict, |
| force_outplace, |
| is_trace_module, |
| _module_class, |
| ): |
| # Note: tracing is independent of optimizations, which consume the trace |
| for inputs in check_inputs: |
| |
| if isinstance(inputs, torch.Tensor): |
| inputs = (inputs,) |
| |
| if is_trace_module: |
| copied_dict = {} |
| for name, data in inputs.items(): |
| copied_dict[name] = _clone_inputs(data) |
| check_mod = torch.jit.trace_module( |
| func.__self__ if hasattr(func, "__self__") else func, |
| copied_dict, |
| check_trace=False, |
| strict=strict, |
| _force_outplace=force_outplace, |
| _module_class=_module_class, |
| _compilation_unit=torch._C.CompilationUnit(), |
| ) |
| check_mod_func = check_mod._c._get_method(traced_func.name) |
| inputs = inputs[traced_func.name] |
| if isinstance(inputs, (torch.Tensor, dict)): |
| inputs = (inputs,) |
| else: |
| check_mod = torch.jit.trace( |
| func, |
| _clone_inputs(inputs), |
| check_trace=False, |
| strict=strict, |
| _force_outplace=force_outplace, |
| _module_class=_module_class, |
| ) |
| check_mod_func = check_mod |
| |
| def graph_diagnostic_info(): |
| mod_canonicalized = torch._C._jit_pass_canonicalize(traced_func.graph) |
| torch._C._jit_pass_inline(mod_canonicalized) |
| torch._C._jit_pass_erase_shape_information(mod_canonicalized) |
| mod_str = str(mod_canonicalized) |
| mod_str = re.sub(r"___torch_mangle_[0-9]+\.", "", mod_str) |
| check_canonicalized = torch._C._jit_pass_canonicalize(check_mod_func.graph) |
| torch._C._jit_pass_inline(check_canonicalized) |
| torch._C._jit_pass_erase_shape_information(check_canonicalized) |
| check_str = str(check_canonicalized) |
| check_str = re.sub(r"___torch_mangle_[0-9]+\.", "", check_str) |
| |
| graph_diff_errors = None |
| if mod_str != check_str: |
| import difflib |
| |
| graph_diff = difflib.ndiff( |
| mod_str.splitlines(True), check_str.splitlines(True) |
| ) |
| graph_diff_errors = "Graph diff:\n" + indent("".join(graph_diff)) + "\n" |
| |
| for n_mod, n_check in zip( |
| mod_canonicalized.nodes(), check_canonicalized.nodes() |
| ): |
| if str(n_mod) != str(n_check): |
| graph_diff_errors += "First diverging operator:\n" |
| node_diff = difflib.ndiff( |
| str(n_mod).splitlines(True), str(n_check).splitlines(True) |
| ) |
| source_printout = ( |
| "Node diff:\n" + indent("".join(node_diff)) + "\n" |
| ) |
| mod_stack = n_mod.sourceRange() |
| if mod_stack: |
| source_printout += ( |
| "Trace source location:\n" + indent(mod_stack) + "\n" |
| ) |
| check_stack = n_check.sourceRange() |
| if check_stack: |
| source_printout += ( |
| "Check source location:\n" + indent(check_stack) + "\n" |
| ) |
| graph_diff_errors += source_printout |
| |
| break # For now, only print out the first pair of nodes that diverges |
| |
| tensor_compare_errors = None |
| # Check Tensor-valued constant nodes |
| for n_mod, n_check in zip( |
| mod_canonicalized.nodes(), check_canonicalized.nodes() |
| ): |
| if n_mod.kind() != n_check.kind(): |
| break # Graphs have already diverged |
| |
| if n_mod.kind() == "prim::Constant" and not ( |
| n_mod.mustBeNone() or n_check.mustBeNone() |
| ): |
| if not n_mod.hasAttribute("value"): |
| continue |
| if n_mod.kindOf("value") != "t" or n_check.kindOf("value") != "t": |
| continue |
| |
| mod_tensor_val = n_mod.t("value") |
| check_tensor_val = n_check.t("value") |
| |
| try: |
| torch.testing.assert_close(mod_tensor_val, check_tensor_val, equal_nan=True) |
| except (RuntimeError, AssertionError) as e: |
| if tensor_compare_errors is None: |
| tensor_compare_errors = "" |
| tensor_compare_errors += "Node:\n" + indent(str(n_mod)) + "\n" |
| compare_stack = n_mod.sourceRange() |
| if compare_stack: |
| tensor_compare_errors += ( |
| "Source Location:\n" + indent(compare_stack) + "\n" |
| ) |
| tensor_compare_errors += "Comparison exception: " + indent( |
| str(e) |
| ) |
| |
| break # For now, only print the first diverging pair |
| |
| return graph_diff_errors, tensor_compare_errors |
| |
| def wrap_retval(x): |
| return x if isinstance(x, tuple) else (x,) |
| |
| def run_mod_and_filter_tensor_outputs(mod, inputs, running_what): |
| try: |
| outs = wrap_retval(mod(*_clone_inputs(inputs))) |
| outs = [out for out in outs if isinstance(out, torch.Tensor)] |
| return outs |
| except Exception as e: |
| graph_diff_errors, tensor_compare_errors = graph_diagnostic_info() |
| msg = f"encountered an exception while running the {running_what} with test inputs.\nException:\n{indent(str(e))}" |
| raise TracingCheckError( |
| graph_diff_errors, |
| tensor_compare_errors, |
| extra_msg=msg, |
| ) from e |
| |
| has_warned = [False] |
| |
| def maybe_warn_nondeterministic(): |
| if has_warned[0]: |
| return |
| has_warned[0] = True |
| nondeterm_ops = [ |
| op for op in traced_func.graph.nodes() if op.isNondeterministic() |
| ] |
| if len(nondeterm_ops) > 0: |
| nondeterministic_ops_warning = "Trace had nondeterministic nodes. " |
| nondeterministic_ops_warning += ( |
| "Did you forget call .eval() on your model? Nodes:\n" |
| ) |
| nondeterministic_ops_warning += "\n".join( |
| [indent(str(op)) for op in nondeterm_ops][:20] |
| ) |
| nondeterministic_ops_warning += ( |
| "\nThis may cause errors in trace checking. To disable trace checking," |
| " pass check_trace=False to torch.jit.trace()" |
| ) |
| warnings.warn( |
| nondeterministic_ops_warning, category=TracerWarning, stacklevel=5 |
| ) |
| |
| def compare_outputs(original, reference, match_what): |
| all_ok = True |
| for i, (orig, ref) in enumerate(zip(original, reference)): |
| try: |
| if orig.is_quantized: |
| orig = orig.dequantize() |
| if ref.is_quantized: |
| ref = ref.dequantize() |
| if orig.is_mkldnn: |
| orig = orig.to_dense() |
| if ref.is_mkldnn: |
| ref = ref.to_dense() |
| if ref.is_complex() or orig.is_complex(): |
| torch.testing.assert_close( |
| orig.to(torch.cdouble), |
| ref.to(torch.cdouble), |
| rtol=check_tolerance, |
| atol=default_tolerances(orig, ref)[1], |
| equal_nan=True, |
| ) |
                elif orig.is_mps or ref.is_mps:
                    torch.testing.assert_close(
                        orig.float(),
                        ref.float(),
                        rtol=check_tolerance,
                        atol=default_tolerances(orig, ref)[1],
                        equal_nan=True,
                    )
                else:
                    torch.testing.assert_close(
                        orig.double(),
                        ref.double(),
                        rtol=check_tolerance,
                        atol=default_tolerances(orig, ref)[1],
                        equal_nan=True,
                    )
| except AssertionError as e: |
| maybe_warn_nondeterministic() |
| warnings.warn( |
| "Output nr " |
| + str(i + 1) |
| + ". of the traced function does not match " |
| "the corresponding output of the " |
| + match_what |
| + ". Detailed error:\n" |
| + str(e), |
| category=TracerWarning, |
| stacklevel=4, |
| ) |
| all_ok = False |
| |
| return all_ok |
| |
| traced_outs = run_mod_and_filter_tensor_outputs(traced_func, inputs, "trace") |
| fn_outs = run_mod_and_filter_tensor_outputs(func, inputs, "Python function") |
| if compare_outputs(traced_outs, fn_outs, "Python function"): |
| check_outs = run_mod_and_filter_tensor_outputs( |
| check_mod_func, inputs, "repeated trace" |
| ) |
| compare_outputs(traced_outs, check_outs, "repeated trace") |
| |
| diag_info = graph_diagnostic_info() |
| if any(info is not None for info in diag_info): |
| raise TracingCheckError(*diag_info) |
| |
| |
| class TracerWarning(Warning): |
| @staticmethod |
| def ignore_lib_warnings(): |
        # Ignore TracerWarnings from all torch submodules except torch.jit
        # itself, since the JIT relies on its own warnings (e.g. in _check_trace).
| warnings.filterwarnings( |
| "ignore", category=TracerWarning, module="torch.(?!jit)" |
| ) |
| |
| |
# We ignore the tracer warnings coming from inside the library, because all our
# shape checks in nn will trigger them.
| TracerWarning.ignore_lib_warnings() |
| torch._C._tracer_warn_use_python() |
| |
| |
| def make_tuple(example_inputs): |
| if isinstance(example_inputs, (torch.Tensor, dict)): |
| return (example_inputs,) |
| # done primarily so that weird iterables fail here and not pybind11 code |
| if not isinstance(example_inputs, tuple): |
| return tuple(example_inputs) |
| return example_inputs |
| |
| |
| def make_module(mod, _module_class, _compilation_unit): |
| if isinstance(mod, ScriptModule): |
| return mod |
| elif torch._jit_internal.module_has_exports(mod): |
| |
| infer_methods_stubs_fn = torch.jit._recursive.make_stubs_from_exported_methods |
| return torch.jit._recursive.create_script_module( |
| mod, |
| infer_methods_stubs_fn, |
| share_types=False, |
| is_tracing=True |
| ) |
| else: |
| if _module_class is None: |
| _module_class = TopLevelTracedModule |
| return _module_class(mod, _compilation_unit=_compilation_unit) |
| |
| |
| def wrap_check_inputs(check_inputs): |
| if check_inputs is None: |
| return None |
| |
| return [{"forward": c} for c in check_inputs] |
| |
| |
| def trace( |
| func, |
| example_inputs, |
| optimize=None, |
| check_trace=True, |
| check_inputs=None, |
| check_tolerance=1e-5, |
| strict=True, |
| _force_outplace=False, |
| _module_class=None, |
| _compilation_unit=_python_cu, |
| ): |
| """ |
    Trace a function and return an executable :class:`ScriptFunction` or
    :class:`ScriptModule` that will be optimized using just-in-time
    compilation. Tracing is ideal for code that operates only on
    ``Tensor``\\s and lists, dictionaries, and tuples of ``Tensor``\\s.
| |
| Using `torch.jit.trace` and `torch.jit.trace_module`, you can turn an |
| existing module or Python function into a TorchScript |
| :class:`ScriptFunction` or :class:`ScriptModule`. You must provide example |
| inputs, and we run the function, recording the operations performed on all |
| the tensors. |
| |
| * The resulting recording of a standalone function produces `ScriptFunction`. |
| * The resulting recording of `nn.Module.forward` or `nn.Module` produces |
| `ScriptModule`. |
| |
    The returned module also contains any parameters that the original
    module had.
| |
| Warning: |
| Tracing only correctly records functions and modules which are not data |
| dependent (e.g., do not have conditionals on data in tensors) and do not have |
| any untracked external dependencies (e.g., perform input/output or |
| access global variables). Tracing only records operations done when the given |
| function is run on the given tensors. Therefore, the returned |
| `ScriptModule` will always run the same traced graph on any input. This |
| has some important implications when your module is expected to run |
| different sets of operations, depending on the input and/or the module |
| state. For example, |
| |
| * Tracing will not record any control-flow like if-statements or loops. |
| When this control-flow is constant across your module, this is fine |
| and it often inlines the control-flow decisions. But sometimes the |
| control-flow is actually part of the model itself. For instance, a |
| recurrent network is a loop over the (possibly dynamic) length of an |
| input sequence. |
| * In the returned :class:`ScriptModule`, operations that have different |
| behaviors in ``training`` and ``eval`` modes will always behave as if |
| it is in the mode it was in during tracing, no matter which mode the |
| `ScriptModule` is in. |
| |
| In cases like these, tracing would not be appropriate and |
| :func:`scripting <torch.jit.script>` is a better choice. If you trace |
| such models, you may silently get incorrect results on subsequent |
| invocations of the model. The tracer will try to emit warnings when |
| doing something that may cause an incorrect trace to be produced. |
| |
| Args: |
| func (callable or torch.nn.Module): A Python function or `torch.nn.Module` |
| that will be run with `example_inputs`. `func` arguments and return |
| values must be tensors or (possibly nested) tuples that contain |
            tensors. When a module is passed to `torch.jit.trace`, only the
            ``forward`` method is run and traced (see
            :func:`torch.jit.trace_module <torch.jit.trace_module>` for details).
| example_inputs (tuple or torch.Tensor): A tuple of example inputs that |
| will be passed to the function while tracing. The resulting trace |
| can be run with inputs of different types and shapes assuming the |
| traced operations support those types and shapes. `example_inputs` |
| may also be a single Tensor in which case it is automatically |
| wrapped in a tuple. |
| |
| Keyword arguments: |
| check_trace (``bool``, optional): Check if the same inputs run through |
| traced code produce the same outputs. Default: ``True``. You might want |
| to disable this if, for example, your network contains non- |
| deterministic ops or if you are sure that the network is correct despite |
| a checker failure. |
| |
| check_inputs (list of tuples, optional): A list of tuples of input |
| arguments that should be used to check the trace against what is |
| expected. Each tuple is equivalent to a set of input arguments that |
| would be specified in ``example_inputs``. For best results, pass in |
| a set of checking inputs representative of the space of shapes and |
| types of inputs you expect the network to see. If not specified, |
            the original ``example_inputs`` are used for checking.
| check_tolerance (float, optional): Floating-point comparison tolerance |
| to use in the checker procedure. This can be used to relax the |
| checker strictness in the event that results diverge numerically |
| for a known reason, such as operator fusion. |
| strict (``bool``, optional): run the tracer in a strict mode or not |
| (default: ``True``). Only turn this off when you want the tracer to |
| record your mutable container types (currently ``list``/``dict``) |
| and you are sure that the container you are using in your |
| problem is a ``constant`` structure and does not get used as |
| control flow (if, for) conditions. |
| |
| Returns: |
| If `func` is `nn.Module` or ``forward`` of `nn.Module`, `trace` returns |
| a :class:`ScriptModule` object with a single ``forward`` method |
| containing the traced code. The returned `ScriptModule` will |
| have the same set of sub-modules and parameters as the original |
| ``nn.Module``. If ``func`` is a standalone function, ``trace`` |
| returns `ScriptFunction`. |
| |
| Example (tracing a function): |
| |
| .. testcode:: |
| |
| import torch |
| |
| def foo(x, y): |
| return 2 * x + y |
| |
| # Run `foo` with the provided inputs and record the tensor operations |
| traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3))) |
| |
| # `traced_foo` can now be run with the TorchScript interpreter or saved |
| # and loaded in a Python-free environment |
| |
| Example (tracing an existing module):: |
| |
| import torch |
| import torch.nn as nn |
| |
| class Net(nn.Module): |
| def __init__(self): |
| super(Net, self).__init__() |
| self.conv = nn.Conv2d(1, 1, 3) |
| |
| def forward(self, x): |
| return self.conv(x) |
| |
| n = Net() |
| example_weight = torch.rand(1, 1, 3, 3) |
| example_forward_input = torch.rand(1, 1, 3, 3) |
| |
| # Trace a specific method and construct `ScriptModule` with |
| # a single `forward` method |
| module = torch.jit.trace(n.forward, example_forward_input) |
| |
| # Trace a module (implicitly traces `forward`) and construct a |
| # `ScriptModule` with a single `forward` method |
| module = torch.jit.trace(n, example_forward_input) |
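
    Example (sketch: validating a trace against additional input shapes via
    ``check_inputs``)::

        import torch

        def pad_and_sum(x):
            return torch.nn.functional.pad(x, (0, 1)).sum()

        # Hypothetical extra shapes the trace should be checked against
        traced_pad_and_sum = torch.jit.trace(
            pad_and_sum,
            (torch.rand(3),),
            check_inputs=[(torch.rand(4),), (torch.rand(5),)],
        )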
| |
| """ |
| if not _enabled: |
| return func |
| if optimize is not None: |
        warnings.warn(
            "`optimize` is deprecated and has no effect. "
            "Use `with torch.jit.optimized_execution()` instead."
        )
| |
| if isinstance(func, torch.jit.ScriptModule): |
| # it is hard to trace it because the forward method on ScriptModule is already defined, so it |
| # would result in an error. |
| warnings.warn( |
| "The input to trace is already a ScriptModule, tracing it is a no-op. Returning the object as is." |
| ) |
| return func |
| |
| if isinstance(func, torch.nn.Module): |
| return trace_module( |
| func, |
| {"forward": example_inputs}, |
| None, |
| check_trace, |
| wrap_check_inputs(check_inputs), |
| check_tolerance, |
| strict, |
| _force_outplace, |
| _module_class, |
| ) |
| |
| if ( |
| hasattr(func, "__self__") |
| and isinstance(func.__self__, torch.nn.Module) |
| and func.__name__ == "forward" |
| ): |
| return trace_module( |
| func.__self__, |
| {"forward": example_inputs}, |
| None, |
| check_trace, |
| wrap_check_inputs(check_inputs), |
| check_tolerance, |
| strict, |
| _force_outplace, |
| _module_class, |
| ) |
| |
    # Special-case the common pattern of passing a single Tensor (or dict)
| if isinstance(example_inputs, (torch.Tensor, dict)): |
| example_inputs = (example_inputs,) |
| # done primarily so that weird iterables fail here and not pybind11 code |
| elif not isinstance(example_inputs, tuple): |
| example_inputs = tuple(example_inputs) |
| |
| var_lookup_fn = _create_interpreter_name_lookup_fn(0) |
| |
| if hasattr(func, "__self__") and isinstance(func.__self__, torch.nn.Module): |
        raise AttributeError(
            "trace doesn't support compiling individual methods of a module.\n"
            "Please use trace_module"
        )
| |
| name = _qualified_name(func) |
| traced = torch._C._create_function_from_trace( |
| name, |
| func, |
| example_inputs, |
| var_lookup_fn, |
| strict, |
| _force_outplace, |
| get_callable_argument_names(func) |
| ) |
| |
| # Check the trace against new traces created from user-specified inputs |
| if check_trace: |
| if check_inputs is not None: |
| _check_trace( |
| check_inputs, |
| func, |
| traced, |
| check_tolerance, |
| strict, |
| _force_outplace, |
| False, |
| _module_class, |
| ) |
| else: |
| _check_trace( |
| [example_inputs], |
| func, |
| traced, |
| check_tolerance, |
| strict, |
| _force_outplace, |
| False, |
| _module_class, |
| ) |
| |
| return traced |
| |
| |
| _trace_module_map: Optional[Dict[Any, Any]] = None |
| |
| |
| def trace_module( |
| mod, |
| inputs, |
| optimize=None, |
| check_trace=True, |
| check_inputs=None, |
| check_tolerance=1e-5, |
| strict=True, |
| _force_outplace=False, |
| _module_class=None, |
| _compilation_unit=_python_cu, |
| ): |
| """ |
| Trace a module and return an executable :class:`ScriptModule` that will be optimized |
| using just-in-time compilation. When a module is passed to :func:`torch.jit.trace <torch.jit.trace>`, only |
    the ``forward`` method is run and traced. With ``trace_module``, you can specify a dictionary of
    method names to example inputs to trace (see the ``inputs`` argument below).
| |
| See :func:`torch.jit.trace <torch.jit.trace>` for more information on tracing. |
| |
| Args: |
| mod (torch.nn.Module): A ``torch.nn.Module`` containing methods whose names are |
| specified in ``inputs``. The given methods will be compiled |
| as a part of a single `ScriptModule`. |
| inputs (dict): A dict containing sample inputs indexed by method names in ``mod``. |
                       While tracing, each value is passed to the method whose
                       name matches its key, e.g.
                       ``{'forward': example_forward_input, 'method2': example_method2_input}``
| Keyword arguments: |
| check_trace (``bool``, optional): Check if the same inputs run through |
| traced code produce the same outputs. Default: ``True``. You might want |
| to disable this if, for example, your network contains non- |
| deterministic ops or if you are sure that the network is correct despite |
| a checker failure. |
| |
        check_inputs (list of dicts, optional): A list of dicts of input arguments that should be used
                                                to check the trace against what is expected. Each dict
                                                is equivalent to a set of input arguments that would
                                                be specified in ``inputs``. For best results, pass in a
                                                set of checking inputs representative of the space of
                                                shapes and types of inputs you expect the network to see.
                                                If not specified, the original ``inputs`` are used for checking.
| check_tolerance (float, optional): Floating-point comparison tolerance to use in the checker procedure. |
| This can be used to relax the checker strictness in the event that |
| results diverge numerically for a known reason, such as operator fusion. |
| |
| Returns: |
        A :class:`ScriptModule` object containing a traced method for each method name
        given in ``inputs``. The returned :class:`ScriptModule` will have the same set
        of sub-modules and parameters as ``mod``.
| |
| Example (tracing a module with multiple methods):: |
| |
| import torch |
| import torch.nn as nn |
| |
| class Net(nn.Module): |
| def __init__(self): |
| super(Net, self).__init__() |
| self.conv = nn.Conv2d(1, 1, 3) |
| |
| def forward(self, x): |
| return self.conv(x) |
| |
| def weighted_kernel_sum(self, weight): |
| return weight * self.conv.weight |
| |
| |
| n = Net() |
| example_weight = torch.rand(1, 1, 3, 3) |
| example_forward_input = torch.rand(1, 1, 3, 3) |
| |
| # Trace a specific method and construct `ScriptModule` with |
| # a single `forward` method |
| module = torch.jit.trace(n.forward, example_forward_input) |
| |
| # Trace a module (implicitly traces `forward`) and construct a |
| # `ScriptModule` with a single `forward` method |
| module = torch.jit.trace(n, example_forward_input) |
| |
| # Trace specific methods on a module (specified in `inputs`), constructs |
| # a `ScriptModule` with `forward` and `weighted_kernel_sum` methods |
| inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight} |
| module = torch.jit.trace_module(n, inputs) |
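
        # Sketch: validate the traced methods against extra inputs; each entry
        # of `check_inputs` mirrors the structure of `inputs`
        check_inputs = [{'forward' : torch.rand(1, 1, 4, 4),
                         'weighted_kernel_sum' : torch.rand(1, 1, 3, 3)}]
        module = torch.jit.trace_module(n, inputs, check_inputs=check_inputs)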
| |
| """ |
| if not _enabled: |
| return mod |
| if optimize is not None: |
        warnings.warn(
            "`optimize` is deprecated and has no effect. "
            "Use `with torch.jit.optimized_execution()` instead."
        )
| |
| var_lookup_fn = _create_interpreter_name_lookup_fn(0) |
| |
| if not isinstance(mod, torch.nn.Module): |
| raise AttributeError("expected torch.nn.Module as the first argument") |
| |
| if not isinstance(inputs, dict): |
| raise AttributeError("expected a dictionary of (method_name, input) pairs") |
| |
| old_module_map = torch.jit._trace._trace_module_map |
| try: |
| trace_module_map: Dict[Any, Any] = {} |
| |
| def register_submods(mod, prefix): |
| for name, child in mod.named_children(): |
| submod_qualname = prefix + "." + name |
| trace_module_map[child] = submod_qualname |
| register_submods(child, submod_qualname) |
| |
| trace_module_map["__module"] = mod |
| torch.jit._trace._trace_module_map = trace_module_map |
| register_submods(mod, "__module") |
| |
| module = make_module(mod, _module_class, _compilation_unit) |
| |
| for method_name, example_inputs in inputs.items(): |
| if method_name == "forward": |
| # "forward" is a special case because we need to trace |
| # `Module.__call__`, which sets up some extra tracing, but uses |
| # argument names of the real `Module.forward` method. |
| func = mod |
| forward_method = getattr(mod, method_name) |
| argument_names = get_callable_argument_names(forward_method) |
| else: |
| func = getattr(mod, method_name) |
| argument_names = get_callable_argument_names(func) |
| |
| example_inputs = make_tuple(example_inputs) |
| |
| module._c._create_method_from_trace( |
| method_name, |
| func, |
| example_inputs, |
| var_lookup_fn, |
| strict, |
| _force_outplace, |
| argument_names, |
| ) |
| check_trace_method = module._c._get_method(method_name) |
| |
| # Check the trace against new traces created from user-specified inputs |
| if check_trace: |
| if check_inputs is not None: |
| _check_trace( |
| check_inputs, |
| func, |
| check_trace_method, |
| check_tolerance, |
| strict, |
| _force_outplace, |
| True, |
| _module_class, |
| ) |
| else: |
| _check_trace( |
| [inputs], |
| func, |
| check_trace_method, |
| check_tolerance, |
| strict, |
| _force_outplace, |
| True, |
| _module_class, |
| ) |
| finally: |
| torch.jit._trace._trace_module_map = old_module_map |
| |
| return module |
| |
| |
| def is_tracing(): |
| """ |
    Returns ``True`` if called during tracing (that is, while a function is
    being traced with ``torch.jit.trace``) and ``False`` otherwise.
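
    Example (sketch)::

        def fn(x):
            if torch.jit.is_tracing():
                # this branch is taken while torch.jit.trace records `fn`
                return x * 2
            return x + 2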
| """ |
| if is_scripting(): |
| return False |
| return torch._C._is_tracing() |
| |
| |
| class TracedModule(ScriptModule): |
| _disable_script_meta = True |
| |
| def __init__(self, orig, id_set=None, _compilation_unit=None): |
| # XXX: orig can be a nn.Module or a function! |
| super(TracedModule, self).__init__() |
| assert isinstance(orig, torch.nn.Module) |
| |
        # Copy a subset of `orig` to a temporary nn.Module.
        # This is a way to customize what will actually get compiled by create_script_module
        # Note: any `id_set` argument passed in is ignored; uniqueness is always
        # checked against a fresh set.
        id_set = set()
| |
| # This allows us to preserve the original module's qualified name by defining a new |
| # type with the attribute _jit_override_qualname. In torch._jit_internal._qualified_name |
| # we have a special case that will look up this attribute to override whatever qualname |
| # we would get from the python type system |
| class QualnameWrapper(torch.nn.Module): |
| pass |
| |
| QualnameWrapper._jit_override_qualname = torch._jit_internal._qualified_name( # type: ignore[attr-defined] |
| type(orig) |
| ) |
| |
| tmp_module = QualnameWrapper() |
| |
| def check_unique(param): |
| if param in id_set: |
| raise ValueError( |
| "TracedModules don't support parameter sharing between modules" |
| ) |
            id_set.add(param)

        tmp_module.training = orig.training
| |
| for name, param in orig._parameters.items(): |
| if param is not None: |
| tmp_module._parameters[name] = param |
| check_unique(param) |
| for name, buf in orig._buffers.items(): |
| if buf is not None: |
| tmp_module._buffers[name] = buf |
| check_unique(buf) |
| for name, val in orig.__dict__.items(): |
| if ( |
| torch._C._jit_is_script_object(val) |
| and name not in orig._parameters |
| and name not in orig._buffers |
| ): |
| setattr(tmp_module, name, val) |
| |
| if orig._backward_hooks: |
| raise ValueError( |
| "Modules that have backward hooks assigned can't be compiled: " |
| + str(orig) |
| ) |
| |
| for name, submodule in orig._modules.items(): |
| if submodule is None: |
| continue |
| tmp_module._modules[name] = make_module( |
| submodule, TracedModule, _compilation_unit=None |
| ) |
| |
| script_module = torch.jit._recursive.create_script_module( |
| tmp_module, lambda module: (), share_types=False, is_tracing=True |
| ) |
| |
| self.__dict__["_name"] = type(orig).__name__ |
| self.__dict__["_actual_script_module"] = script_module |
| for name in ("_parameters", "_buffers", "_modules", "training"): |
| delattr(self, name) |
| |
| def forward(self, *args, **kwargs): |
| raise RuntimeError("Trace submodules cannot be called.") |
| |
| def __getattr__(self, attr): |
| if "_actual_script_module" not in self.__dict__: |
| return super(TracedModule, self).__getattr__(attr) |
| return getattr(self._actual_script_module, attr) |
| |
| def __setattr__(self, attr, value): |
| if "_actual_script_module" not in self.__dict__: |
| return super(TracedModule, self).__setattr__(attr, value) |
| setattr(self._actual_script_module, attr, value) |
| |
| def _get_name(self): |
| return self._name |
| |
| def extra_repr(self): |
| return "original_name={}".format(self._name) |
| |
| |
| class TopLevelTracedModule(TracedModule): |
| forward = _CachedForward() |
| |
| def _reconstruct(self, cpp_module): |
| """ |
| Re-construct an instance of TopLevelTracedModule using an instance of a C++ module. |
| |
| Args: |
| cpp_module: The C++ module that this TopLevelTracedModule will be rebuilt around. |
| """ |
| self.__dict__["_actual_script_module"]._reconstruct(cpp_module) |
| |
| |
| def _script_if_tracing(fn): |
| @functools.wraps(fn) |
| def wrapper(*args, **kwargs): |
| if not is_tracing(): |
| # Not tracing, don't do anything |
| return fn(*args, **kwargs) |
| |
| compiled_fn = script(wrapper.__original_fn) # type: ignore[attr-defined] |
| return compiled_fn(*args, **kwargs) |
| |
| wrapper.__original_fn = fn # type: ignore[attr-defined] |
| wrapper.__script_if_tracing_wrapper = True # type: ignore[attr-defined] |
| |
| return wrapper |
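
# Usage sketch (hypothetical helper): decorating defers `torch.jit.script`
# compilation until the helper is first called under tracing; eager calls
# skip scripting entirely.
#
#     @_script_if_tracing
#     def flip_if_long(x: torch.Tensor) -> torch.Tensor:
#         return x.flip(0) if x.numel() > 2 else x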
| |
| |
| def _get_trace_graph(f, args=(), kwargs=None, strict=True, _force_outplace=False, |
| return_inputs=False, _return_inputs_states=False): |
| """ |
| .. warning:: |
| This function is internal-only and should only be used by the ONNX |
| exporter. If you are trying to get a graph through tracing, please go |
| through the public API instead:: |
| |
| trace = torch.jit.trace(nn.LSTMCell(), (input, hidden)) |
| trace_graph = trace.graph |
| |
    Trace a function or model, returning a tuple consisting of both the
    *trace* of an execution and the original return value. If ``return_inputs``
    is ``True``, the trace inputs are also returned as part of the tuple.
| |
| Tracing is guaranteed not to change the semantics of the function/module |
| that is traced. |
| |
| Args: |
| f (torch.nn.Module or function): the function or module |
| to be traced. |
| args (tuple or Tensor): the positional arguments to pass to the |
| function/module to be traced. A non-tuple is assumed to |
| be a single positional argument to be passed to the model. |
| kwargs (dict): the keyword arguments to pass to the function/module |
| to be traced. |
| |
| Example (trace a cell): |
| |
| .. testcode:: |
| |
| trace = torch.jit.trace(nn.LSTMCell(), (input, hidden)) |
| """ |
| if kwargs is None: |
| kwargs = {} |
| if not isinstance(args, tuple): |
| args = (args,) |
| outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs) |
| return outs |