| |
| |
| |
| |
| |
| from importlib import import_module |
| import pkgutil |
| import sys |
| from . import layers |
| |
| |
def import_recursive(package):
    """Import every module and subpackage located beneath *package*.

    Each direct child reported by ``pkgutil.iter_modules`` over the package's
    ``__path__`` is imported by its fully dotted name; children that are
    themselves packages are then descended into the same way.

    Args:
        package: an already-imported package object (it must expose
            ``__path__`` and ``__name__``).
    """
    search_path = package.__path__
    parent_name = package.__name__
    for _loader, child_name, child_is_pkg in pkgutil.iter_modules(search_path):
        child = import_module(".".join((parent_name, child_name)))
        # Only packages have further children to walk.
        if child_is_pkg:
            import_recursive(child)
| |
| |
def find_subclasses_recursively(base_cls, sub_cls):
    """Add every direct and indirect subclass of *base_cls* to *sub_cls*.

    Args:
        base_cls: class whose descendants are collected. *base_cls* itself is
            not added.
        sub_cls: a set (anything with an ``update`` method) that receives the
            discovered classes. It is mutated in place; nothing is returned.

    The walk is iterative and expands each class at most once. Descendants
    shared through multiple inheritance (diamonds) are therefore not
    re-walked, and deep hierarchies cannot exhaust the recursion limit.
    """
    # Track visits locally rather than via ``sub_cls`` so that a caller-supplied,
    # pre-populated set still has all of its members' subclasses discovered.
    visited = set()
    pending = [base_cls]
    while pending:
        # Calling through ``type`` also works when the classes are
        # metaclasses, where ``cls.__subclasses__()`` resolves to the unbound
        # ``type.__subclasses__`` and raises TypeError.
        children = type.__subclasses__(pending.pop())
        sub_cls.update(children)
        for child in children:
            if child not in visited:
                visited.add(child)
                pending.append(child)
| |
| |
# Import every module under this package first, so that all ModelLayer
# subclasses are defined (and visible to __subclasses__) before collection.
import_recursive(sys.modules[__name__])

# All direct and indirect subclasses of layers.ModelLayer.
model_layer_subcls = set()
find_subclasses_recursively(layers.ModelLayer, model_layer_subcls)

# Register each layer class under its class name. The set is keyed by class
# identity, so its raw iteration order may differ between runs. A stable
# (module, name) order makes registration order reproducible, which matters
# if register_layer is order-sensitive (e.g. on duplicate names -- behavior
# not visible here, TODO confirm).
for cls in sorted(model_layer_subcls, key=lambda c: (c.__module__, c.__name__)):
    layers.register_layer(cls.__name__, cls)