blob: 701036789ffcb65c69417388ce1754d93e89792c [file] [log] [blame]
import logging
import os
import sys
from os.path import abspath, dirname
from types import ModuleType
import torch
try:
import torch._prims
import torch._refs
HAS_REFS_PRIMS = True
except ImportError:
HAS_REFS_PRIMS = False
# log level (levels print what it says + all levels listed below it)
# logging.DEBUG print full traces <-- lowest level + print tracing of every instruction
# torchdynamo.logging.CODE print compiled functions + graphs
# logging.INFO print the steps that dynamo is running
# logging.WARN print warnings (including graph breaks)
# logging.ERROR print exceptions (and what user code was being processed when it occurred)
# NOTE: changing log_level will automatically update the levels of all torchdynamo loggers
log_level = logging.WARNING
# the name of a file to write the logs to
log_file_name = None
# Verbose will print full stack traces on warnings and errors
verbose = False
# verify the correctness of optimized backend
verify_correctness = False
# need this many ops to create an FX graph
minimum_call_count = 1
# turn on/off DCE pass
dead_code_elimination = True
# disable (for a function) when cache reaches this size
cache_size_limit = 64
# specializing int/float by default
specialize_int_float = True
# Assume these functions return constants
constant_functions = {
torch.jit.is_scripting: False,
torch.jit.is_tracing: False,
torch._C._get_tracing_state: None,
torch.fx._symbolic_trace.is_fx_tracing: False,
torch.onnx.is_in_onnx_export: False,
}
# don't specialize on shapes and strides and put shape ops in graph
dynamic_shapes = os.environ.get("TORCHDYNAMO_DYNAMIC_SHAPES") == "1"
# Set this to False to assume nn.Modules() contents are immutable (similar assumption as freezing)
guard_nn_modules = False
# Run the FX graph as it is created to get better type information
dynamic_propagation = True
# Run the FX graph with FakeTensors
fake_tensor_propagation = True
# run FX normalization passes in optimizer
normalize_ir = False
# If a tensor subclass type is in this set, torchdynamo will inline the
# __torch_function__ logic of the subclass.
traceable_tensor_subclasses = set()
# Suppress errors in torchdynamo.optimize, instead forcing a fallback to eager.
# This is a good way to get your model to work one way or another, but you may
# lose optimization opportunities this way. Devs, if your benchmark model is failing
# this way, you should figure out why instead of suppressing it.
suppress_errors = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False))
# Record and write an execution record of the current frame to a file
# if an exception is encountered
replay_record_enabled = False
# Show a warning on every graph break
print_graph_breaks = False
# If a PyTorch module is in this allowlist, torchdynamo will be allowed
# to inline objects from it or its children.
skipfiles_inline_module_allowlist = {
torch.nn,
torch.distributions,
torch.testing,
}
if HAS_REFS_PRIMS:
skipfiles_inline_module_allowlist |= {
torch._refs,
torch._prims,
torch._decomp,
}
# If a string representing a PyTorch module is in this ignorelist,
# the `allowed_functions.is_allowed` function will not consider it
# when creating a list of PyTorch functions that will appear in
# FX IR.
allowed_functions_module_string_ignorelist = {
"torch.distributions",
"torch.testing",
"torch._refs",
"torch._prims",
"torch._decomp",
}
# Debug Flag to try minifier at different stages. Possible values are {None, "aot", "dynamo"}
# None - Minifier is switched off
# dynamo - Runs minifier on the TorchDynamo produced graphs, if compilation fails
# aot - Runs minifier on the Aot Autograd produced graphs, if compilation fails
repro_after = os.environ.get("TORCHDYNAMO_REPRO_AFTER", None)
# Compiler compilation debug info
# 1: Dumps the original graph out to repro.py if compilation fails
# 2: Dumps a minifier_launcher.py if compilation fails.
# 3: Always dumps a minifier_laucher.py. Good for segfaults.
# 4: Dumps a minifier_launcher.py if the accuracy fails.
repro_level = int(os.environ.get("TORCHDYNAMO_REPRO_LEVEL", 2))
# Not all backends support scalars. Some calls on torch.Tensor (like .item()) return a scalar type.
# When this flag is set to False, we introduce a graph break instead of capturing.
capture_scalar_outputs = False
# Should almost always be true in prod. This relaxes the requirement that cond's true_fn and
# false_fn produces code with identical guards.
enforce_cond_guards_match = True
# Automatically split model graph into pieces to match DDP bucket sizes
# to allow DDP comm/compute overlap
optimize_ddp = False
# If True, raises exception if TorchDynamo is called with a context manager
raise_on_ctx_manager_usage = True
# If True, raise when aot autograd is unsafe to use
raise_on_unsafe_aot_autograd = False
# How to import torchdynamo, either torchdynamo or torch.dynamo
dynamo_import = __name__.replace(".config", "")
# How to import torchinductor, either torchinductor or torch.inductor
inductor_import = dynamo_import.replace("dynamo", "inductor")
# root folder of the project
if "torch." in dynamo_import:
base_dir = dirname(dirname(dirname(abspath(__file__))))
else:
base_dir = dirname(dirname(abspath(__file__)))
debug_dir_root = os.path.join(os.getcwd(), "torchdynamo_debug")
class _AccessLimitingConfig(ModuleType):
def __setattr__(self, name, value):
if name not in _allowed_config_names:
raise AttributeError(f"{__name__}.{name} does not exist")
# automatically set logger level whenever config.log_level is modified
if name == "log_level":
from .logging import set_loggers_level
set_loggers_level(value)
return object.__setattr__(self, name, value)
_allowed_config_names = {*globals().keys()}
sys.modules[__name__].__class__ = _AccessLimitingConfig