| import contextlib |
| |
| import pickle |
| import unittest |
| from types import FunctionType, ModuleType |
| from typing import Any, Dict, Set |
| from unittest import mock |
| |
# Value types that install_config_module() stores in a module's _config
# (and that save_config()/load_config() therefore round-trip).
CONFIG_TYPES = (
    int,
    float,
    bool,
    type(None),
    str,
    list,
    set,
    tuple,
    dict,
)
| |
| |
def install_config_module(module):
    """
    Converts a module-level config into a `ConfigModule()`.

    Walks the module's globals (skipping dunder names, modules and functions):

    * values whose type is in ``CONFIG_TYPES`` are copied into
      ``module._config`` (and deleted from the module when top-level);
    * nested ``class Name:`` bodies become subconfigs whose entries are stored
      under dotted keys (``"Name.key"``) and reached via a ``SubConfigProxy``;
    * module-level ``@property`` objects are moved onto the new module class
      and recorded in ``_bypass_keys``.

    Finally ``module.__class__`` is replaced by a fresh ``ConfigModule``
    subclass so attribute access is routed through ``_config``.

    Raises:
        AssertionError: a nested class is not defined in ``module`` itself, or
            a global has a type this function does not know how to store.
    """

    class ConfigModuleInstance(ConfigModule):
        # A new class (and set) per installed module, so properties added
        # here only affect this module.
        _bypass_keys = set()

    config = {}

    def visit(source, dest, prefix):
        """Walk the module structure and move everything to module._config"""
        for key, value in list(source.__dict__.items()):
            if key.startswith("__") or isinstance(value, (ModuleType, FunctionType)):
                continue

            name = f"{prefix}{key}"
            if isinstance(value, property) and dest is module:
                # make @property work at the module level
                delattr(module, key)
                setattr(ConfigModuleInstance, key, value)
                ConfigModuleInstance._bypass_keys.add(key)
            elif isinstance(value, CONFIG_TYPES):
                config[name] = value
                if dest is module:
                    delattr(module, key)
            elif isinstance(value, type):
                # a subconfig with `class Blah:` syntax.  Explicit check rather
                # than `assert` so it is not stripped under `python -O`.
                if value.__module__ != module.__name__:
                    raise AssertionError(
                        f"subconfig class {name} must be defined in "
                        f"{module.__name__}, not {value.__module__}"
                    )
                proxy = SubConfigProxy(module, f"{name}.")
                visit(value, proxy, f"{name}.")
                setattr(dest, key, proxy)
            else:
                raise AssertionError(f"Unhandled config {key}={value} ({type(value)})")

    visit(module, module, "")
    module._config = config
    module._allowed_keys = set(config.keys())
    module.__class__ = ConfigModuleInstance
| |
| |
| class ConfigModule(ModuleType): |
| _config: Dict[str, Any] |
| _allowed_keys: Set[str] |
| _bypass_keys: Set[str] |
| |
| def __init__(self): |
| raise NotImplementedError( |
| f"use {__name__}.install_config_module(sys.modules[__name__])" |
| ) |
| |
| def __setattr__(self, name, value): |
| if name in self._bypass_keys: |
| super().__setattr__(name, value) |
| elif name not in self._allowed_keys: |
| raise AttributeError(f"{self.__name__}.{name} does not exist") |
| else: |
| self._config[name] = value |
| |
| def __getattr__(self, name): |
| try: |
| return self._config[name] |
| except KeyError: |
| # make hasattr() work properly |
| raise AttributeError(f"{self.__name__}.{name} does not exist") |
| |
| def __delattr__(self, name): |
| # must support delete because unittest.mock.patch deletes |
| # then recreate things |
| del self._config[name] |
| |
| def save_config(self): |
| """Convert config to a pickled blob""" |
| config = dict(self._config) |
| for key in config.get("_save_config_ignore", ()): |
| config.pop(key) |
| return pickle.dumps(config, protocol=2) |
| |
| def load_config(self, data): |
| """Restore from a prior call to save_config()""" |
| self.to_dict().update(pickle.loads(data)) |
| |
| def to_dict(self): |
| return self._config |
| |
| def patch(self, arg1=None, arg2=None, **kwargs): |
| """ |
| Decorator and/or context manager to make temporary changes to a config. |
| |
| As a decorator: |
| |
| @config.patch("name", val) |
| @config.patch(name1=val1, name2=val2): |
| @config.patch({"name1": val1, "name2", val2}) |
| def foo(...): |
| ... |
| |
| As a context manager: |
| |
| with config.patch("name", val): |
| ... |
| """ |
| if arg1 is not None: |
| if arg2 is not None: |
| # patch("key", True) syntax |
| changes = {arg1: arg2} |
| else: |
| # patch({"key": True}) syntax |
| changes = arg1 |
| assert not kwargs |
| else: |
| # patch(key=True) syntax |
| changes = kwargs |
| assert arg2 is None |
| assert isinstance(changes, dict), f"expected `dict` got {type(changes)}" |
| prior = {} |
| config = self |
| |
| class ConfigPatch(ContextDecorator): |
| def __enter__(self): |
| assert not prior |
| for key in changes.keys(): |
| # KeyError on invalid entry |
| prior[key] = config._config[key] |
| config._config.update(changes) |
| |
| def __exit__(self, exc_type, exc_val, exc_tb): |
| config._config.update(prior) |
| prior.clear() |
| |
| return ConfigPatch() |
| |
| |
class ContextDecorator(contextlib.ContextDecorator):
    """
    Same as contextlib.ContextDecorator, but with support for
    `unittest.TestCase`.

    Decorating a ``unittest.TestCase`` subclass returns a subclass that enters
    the context in ``setUpClass`` and exits it in ``tearDownClass``, so it is
    active for every test in the class.  Any other callable is wrapped exactly
    as ``contextlib.ContextDecorator`` does.
    """

    def __call__(self, func):
        if isinstance(func, type) and issubclass(func, unittest.TestCase):

            class _TestCase(func):
                @classmethod
                def setUpClass(cls):
                    self.__enter__()
                    try:
                        super().setUpClass()
                    except Exception:
                        # unittest does not run tearDownClass when setUpClass
                        # fails, so undo the context here.
                        self.__exit__(None, None, None)
                        raise

                @classmethod
                def tearDownClass(cls):
                    try:
                        super().tearDownClass()
                    finally:
                        self.__exit__(None, None, None)

            # Impersonate the decorated class so unittest test ids and reports
            # show the original `module.QualName` instead of this module's
            # `ContextDecorator.__call__.<locals>._TestCase`.
            _TestCase.__name__ = func.__name__
            _TestCase.__qualname__ = func.__qualname__
            _TestCase.__module__ = func.__module__
            return _TestCase

        return super().__call__(func)
| |
| |
class SubConfigProxy:
    """
    Shim to redirect to main config.
    `config.triton.cudagraphs` maps to _config["triton.cudagraphs"]

    Attribute reads, writes and deletions on the proxy are forwarded to the
    owning config object under the dotted name ``prefix + name``.
    """

    def __init__(self, config, prefix):
        # `super().__setattr__` to bypass custom `__setattr__`
        super().__setattr__("_config", config)
        super().__setattr__("_prefix", prefix)

    def __setattr__(self, name, value):
        setattr(self._config, self._prefix + name, value)

    def __getattr__(self, name):
        # Use the full getattr() protocol instead of calling
        # `self._config.__getattr__` directly.  A nested subconfig's proxy is
        # installed with setattr() on its parent proxy, so it lives in the
        # config module's __dict__ under a dotted name (e.g. "a.b"); a direct
        # __getattr__ call only consults _config and would never find it.
        return getattr(self._config, self._prefix + name)

    def __delattr__(self, name):
        delattr(self._config, self._prefix + name)
| |
| |
def patch_object(obj, name, value):
    """
    Workaround `mock.patch.object` issue with ConfigModule

    For a ConfigModule, returns ``obj.patch(...)`` (a context manager /
    decorator that temporarily edits ``obj._config``); for anything else,
    defers to ``mock.patch.object(obj, name, value)``.
    """
    if isinstance(obj, ConfigModule):
        # Dict form rather than `obj.patch(name, value)`: the two-argument
        # form treats `value is None` as "no second argument" and rejects it,
        # so patching a config entry to None would fail.
        return obj.patch({name: value})
    return mock.patch.object(obj, name, value)