| import threading |
| from contextlib import contextmanager |
| from typing import Optional, Iterator |
| |
| # Simple dynamic scoping implementation. The name "parametrize" comes |
| # from Racket. |
| # |
| # WARNING WARNING: LOOKING TO EDIT THIS FILE? Think carefully about |
| # why you need to add a toggle to the global behavior of code |
| # generation. The parameters here should really only be used |
| # for "temporary" situations, where we need to temporarily change |
| # the codegen in some cases because we cannot conveniently update |
| # all call sites, and are slated to be removed once all call |
| # sites have been updated. If you don't have a plan for how to get there, |
| # DON'T add a new entry here. |
| |
| |
class Locals(threading.local):
    """Per-thread storage for the codegen toggles that ``parametrize`` sets.

    Because this subclasses ``threading.local``, an assignment made in one
    thread is invisible to every other thread. A thread that has not assigned
    a toggle reads the class-level default below.
    """

    # Read through use_const_ref_for_mutable_tensors(). ``None`` is the
    # "not set in this thread" sentinel; parametrize() installs a bool.
    use_const_ref_for_mutable_tensors: Optional[bool] = None


# The single shared instance; its attribute values are tracked per thread.
_locals = Locals()
| |
| |
def use_const_ref_for_mutable_tensors() -> bool:
    """Return the calling thread's ``use_const_ref_for_mutable_tensors`` toggle.

    The toggle is stored on the thread-local ``_locals`` object. It is
    normally set by entering a ``parametrize(...)`` block.

    Returns:
        The bool currently installed for this thread.

    Raises:
        AssertionError: if the toggle has not been initialized (is ``None``)
            in this thread.
    """
    value = _locals.use_const_ref_for_mutable_tensors
    # Use an explicit raise rather than ``assert``: ``assert`` statements are
    # stripped under ``python -O``. The function would then return ``None``
    # despite its ``bool`` annotation. AssertionError is kept so any existing
    # handlers still match.
    if value is None:
        raise AssertionError(
            "need to initialize local.use_const_ref_for_mutable_tensors with "
            "local.parametrize"
        )
    return value
| |
| |
@contextmanager
def parametrize(*, use_const_ref_for_mutable_tensors: bool) -> Iterator[None]:
    """Temporarily set this thread's ``use_const_ref_for_mutable_tensors`` toggle.

    Inside the ``with`` body the toggle reads as the given value. On exit,
    whether normal or via an exception, the value that was present before
    entry is put back. Nested uses therefore unwind correctly.

    Args:
        use_const_ref_for_mutable_tensors: value to install for the
            duration of the ``with`` block.
    """
    saved = _locals.use_const_ref_for_mutable_tensors
    try:
        _locals.use_const_ref_for_mutable_tensors = use_const_ref_for_mutable_tensors
        yield
    finally:
        # Restore the prior value (possibly None) so the change is scoped
        # to the with-block only.
        _locals.use_const_ref_for_mutable_tensors = saved