import torch | |
from typing import TypeVar | |
from contextlib import contextmanager | |
# Generic type variable defined at module scope. It is not referenced in the
# visible part of this module; presumably other modules import it — verify
# before removing.
T = TypeVar("T")
def all_same_mode(modes):
    """Return True if every element of ``modes`` compares equal to ``modes[0]``.

    ``modes`` must support both iteration and indexing (e.g. a list or tuple).
    An empty sequence yields True, because ``modes[0]`` is only evaluated
    inside the generator and the generator never runs for empty input.
    """
    # Hand the generator straight to all() instead of materializing a tuple
    # first: evaluation now stops at the first mismatch rather than running
    # (and possibly raising on) every comparison up front.
    return all(mode == modes[0] for mode in modes)
@contextmanager
def no_dispatch():
    """Run the ``with`` body while holding a ``torch._C._DisableTorchDispatch`` guard.

    The guard is constructed before the body runs. The local reference to it
    is deleted in ``finally``, so it is released whether the body exits
    normally or by raising. The context manager yields ``None``.
    """
    # NOTE(review): this relies on the C++ guard object undoing its effect
    # when its last Python reference is dropped; that destructor behavior
    # lives in torch._C and is not visible here — confirm against torch.
    dispatch_guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
    try:
        yield
    finally:
        # Delete explicitly rather than waiting for the generator frame to be
        # collected, so the guard is released as soon as the body finishes.
        del dispatch_guard