| import itertools |
| import logging |
| |
| from torch.hub import _Faketqdm, tqdm |
| |
# Whether the torch.compile() progress bar is disabled (disabled by default).
# This flag lives in this module rather than in the dynamo config because
# putting it there would create a circular import.
# NOTE: the module-level progress bar is only created if this is False at
# import time; get_step_logger also checks it on every call.
disable_progress = True
| |
| |
| # Return all loggers that torchdynamo/torchinductor is responsible for |
def get_loggers():
    """Return the loggers that torchdynamo/torchinductor are responsible for.

    Returns:
        A fresh list of ``logging.Logger`` objects, in this order:
        ``torch.fx.experimental.symbolic_shapes``, ``torch._dynamo``,
        ``torch._inductor``.
    """
    owned_names = (
        "torch.fx.experimental.symbolic_shapes",
        "torch._dynamo",
        "torch._inductor",
    )
    return [logging.getLogger(owned) for owned in owned_names]
| |
| |
# get_step_logger creates a logging function that prepends a step number
# ("Step N: ...") to every message it logs. Call get_step_logger lazily
# (i.e. at runtime, not at module-load time) so that step numbers are
# assigned properly, e.g.:
#
#   @functools.lru_cache(None)
#   def _step_logger():
#       return get_step_logger(logging.getLogger(...))
#
#   def fn():
#       _step_logger()(logging.INFO, "msg")
| |
# Shared counter for the step numbers handed out by get_step_logger; starts at 1.
_step_counter = itertools.count(1)

# Progress-bar length: one tick per compilation phase (Dynamo, AOT, Backend).
# Update the step counts below if more phases are added. This is very
# inductor-centric: three steps are expected when triton is importable,
# otherwise two. _inductor.utils.has_triton() cannot be used here because it
# causes a circular import.
if not disable_progress:
    try:
        import triton  # noqa: F401
    except ImportError:
        num_steps = 2
    else:
        num_steps = 3
    pbar = tqdm(total=num_steps, desc="torch.compile()", delay=0)
| |
| |
def _make_progress_bar():
    """Create the torch.compile() progress bar with one tick per expected phase."""
    # Same sizing as the import-time setup: three steps when triton is
    # importable, otherwise two. _inductor.utils.has_triton() would cause a
    # circular import, hence the direct import probe.
    try:
        import triton  # noqa: F401
    except ImportError:
        num_steps = 2
    else:
        num_steps = 3
    return tqdm(total=num_steps, desc="torch.compile()", delay=0)


def get_step_logger(logger):
    """Return a function that logs to ``logger`` with a step number prepended.

    Each call takes the next value from the module-level ``_step_counter``, so
    every returned logger is bound to its own step number. When the progress
    bar is enabled (``disable_progress`` is false), each call also advances it
    by one tick and, for a real tqdm bar, shows ``logger.name`` as the postfix.

    Args:
        logger: The ``logging.Logger`` that messages are emitted on.

    Returns:
        A callable ``log(level, msg)`` that emits ``"Step <n>: <msg>"`` at
        ``level`` via ``logger.log``.
    """
    global pbar
    if not disable_progress:
        # The bar is created at import time only if progress was enabled then.
        # If the flag was flipped afterwards, create it now instead of failing
        # with a NameError on the undefined global.
        if "pbar" not in globals():
            pbar = _make_progress_bar()
        pbar.update(1)
        if not isinstance(pbar, _Faketqdm):
            pbar.set_postfix_str(logger.name)

    step = next(_step_counter)

    def log(level, msg):
        # %-style args keep formatting lazy when the level is filtered out.
        logger.log(level, "Step %s: %s", step, msg)

    return log