| # mypy: allow-untyped-defs |
| import contextlib |
| |
| import torch |
| |
| |
| # Common testing utilities for use in public testing APIs. |
| # NB: these should all be importable without optional dependencies |
| # (like numpy and expecttest). |
| |
| |
def wrapper_set_seed(op, *args, **kwargs):
    """Call ``op(*args, **kwargs)`` under a fixed manual seed and return its result.

    The call runs inside ``freeze_rng_state()`` after ``torch.manual_seed(42)``,
    which is useful for functions such as dropout.
    See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details.
    """
    with freeze_rng_state():
        torch.manual_seed(42)
        result = op(*args, **kwargs)

        # mark_step has to run while still inside freeze_rng_state so that
        # the numerics match eager execution.
        produced_lazy_tensor = (
            isinstance(result, torch.Tensor) and result.device.type == "lazy"
        )
        if produced_lazy_tensor:
            torch._lazy.mark_step()  # type: ignore[attr-defined]

    return result
| |
| |
| @contextlib.contextmanager |
| def freeze_rng_state(): |
| # no_dispatch needed for test_composite_compliance |
| # Some OpInfos use freeze_rng_state for rng determinism, but |
| # test_composite_compliance overrides dispatch for all torch functions |
| # which we need to disable to get and set rng state |
| with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch(): |
| rng_state = torch.get_rng_state() |
| if torch.cuda.is_available(): |
| cuda_rng_state = torch.cuda.get_rng_state() |
| try: |
| yield |
| finally: |
| # Modes are not happy with torch.cuda.set_rng_state |
| # because it clones the state (which could produce a Tensor Subclass) |
| # and then grabs the new tensor's data pointer in generator.set_state. |
| # |
| # In the long run torch.cuda.set_rng_state should probably be |
| # an operator. |
| # |
| # NB: Mode disable is to avoid running cross-ref tests on thes seeding |
| with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch(): |
| if torch.cuda.is_available(): |
| torch.cuda.set_rng_state(cuda_rng_state) # type: ignore[possibly-undefined] |
| torch.set_rng_state(rng_state) |