| from .modules import * # noqa: F403 |
| from .parameter import ( |
| Parameter as Parameter, |
| UninitializedParameter as UninitializedParameter, |
| UninitializedBuffer as UninitializedBuffer, |
| ) |
| from .parallel import DataParallel as DataParallel |
| from . import init |
| from . import functional |
| from . import utils |
| |
| |
def factory_kwargs(kwargs):
    r"""
    Given kwargs, returns a canonicalized dict of factory kwargs that can be directly passed
    to factory functions like torch.empty, or errors if unrecognized kwargs are present.

    This function makes it simple to write code like this::

        class MyModule(nn.Module):
            def __init__(self, **kwargs):
                factory_kwargs = torch.nn.factory_kwargs(kwargs)
                self.weight = Parameter(torch.empty(10, **factory_kwargs))

    Why should you use this function instead of just passing `kwargs` along directly?

    1. This function does error validation, so if there are unexpected kwargs we will
       immediately report an error, instead of deferring it to the factory call
    2. This function supports a special `factory_kwargs` argument, which can be used to
       explicitly specify a kwarg to be used for factory functions, in the event one of the
       factory kwargs conflicts with an already existing argument in the signature (e.g.
       in the signature ``def f(dtype, **kwargs)``, you can specify ``dtype`` for factory
       functions, as distinct from the dtype argument, by saying
       ``f(dtype1, factory_kwargs={"dtype": dtype2})``)

    Args:
        kwargs: a mapping of keyword arguments, or ``None``. The only accepted keys are
            ``device``, ``dtype``, ``memory_format`` and ``factory_kwargs``. The value
            stored under ``factory_kwargs`` may be a mapping or ``None``; its entries
            are copied into the result as they are (they are not checked against the
            list of accepted keys).

    Returns:
        A new dict holding the entries of ``kwargs["factory_kwargs"]`` (if any) followed
        by whichever of ``device``, ``dtype`` and ``memory_format`` were given directly
        in ``kwargs``, in that order. Neither ``kwargs`` nor the nested mapping is
        modified. An empty dict is returned when ``kwargs`` is ``None``.

    Raises:
        TypeError: if ``kwargs`` contains a key other than the accepted ones, or if the
            same key is given both directly in ``kwargs`` and inside ``factory_kwargs``.
    """
    if kwargs is None:
        return {}

    # A tuple (not a set) so that the merge below always visits the keys in the
    # same order; set iteration over strings changes with hash randomization.
    simple_keys = ("device", "dtype", "memory_format")
    expected_keys = {*simple_keys, "factory_kwargs"}
    if not kwargs.keys() <= expected_keys:
        raise TypeError(f"unexpected kwargs {kwargs.keys() - expected_keys}")

    # Work on a copy so that neither `kwargs` nor the nested mapping is modified.
    # An explicit ``factory_kwargs=None`` is treated like an absent entry, which
    # mirrors how a ``None`` `kwargs` is handled above.
    nested = kwargs.get("factory_kwargs")
    merged = {} if nested is None else dict(nested)

    for key in simple_keys:
        if key not in kwargs:
            continue
        if key in merged:
            raise TypeError(f"{key} specified twice, in **kwargs and in factory_kwargs")
        merged[key] = kwargs[key]

    return merged