| # Copyright (c) Facebook, Inc. and its affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
# Stack of LevelInfo records. Dim's size setter appends one entry each time a
# Dim is bound to a size, and Dim.__del__ pops trailing entries again.
_vmap_levels = []


@dataclass
class LevelInfo:
    """Bookkeeping record for one entry of the ``_vmap_levels`` stack."""

    # Value returned by _vmap_increment_nesting when the owning Dim was bound.
    level: int
    # Starts out True; Dim.__del__ sets it to False when the owning Dim is
    # destroyed, which makes the entry eligible for popping.
    alive: bool = True
| |
class Dim:
    """A named dimension that can be bound, once, to an integer size.

    Binding a size (via the constructor or the ``size`` setter) calls
    ``_vmap_increment_nesting`` and pushes a ``LevelInfo`` for the returned
    level onto the module-level ``_vmap_levels`` stack. ``__del__`` marks that
    entry dead and releases dead entries from the top of the stack.
    """

    def __init__(self, name: str, size: Union[None, int] = None):
        """Create a dim called *name*; bind it immediately if *size* is given."""
        self.name = name
        self._size = None
        self._vmap_level = None
        # Index of this dim's entry in _vmap_levels; assigned when bound.
        self._vmap_stack = None
        if size is not None:
            self.size = size

    def __del__(self):
        """Mark this dim's level dead and unwind dead levels from the stack top."""
        # getattr: __del__ also runs for instances whose __init__ never ran
        # (or failed early), in which case the attribute does not exist.
        if getattr(self, '_vmap_level', None) is None:
            return
        # The size setter recorded _vmap_stack as an index into _vmap_levels,
        # so that is the list that must be indexed here.
        _vmap_levels[self._vmap_stack].alive = False
        # Only the top entry is ever released, and only while it is dead and
        # matches current_level(). The emptiness check stops the loop cleanly
        # once the last entry has been popped.
        while (_vmap_levels
               and not _vmap_levels[-1].alive
               and current_level() == _vmap_levels[-1].level):
            _vmap_decrement_nesting()
            _vmap_levels.pop()

    @property
    def size(self):
        """The bound size. Asserts that the dim has been bound."""
        assert self.is_bound
        return self._size

    @size.setter
    def size(self, size: int):
        """Bind the dim to *size*.

        The first binding records the size and enters a new nesting level.
        Re-binding to the same size is a no-op.

        Raises:
            DimensionBindError: if the dim is already bound to a different size.
        """
        if self._size is None:
            self._size = size
            self._vmap_level = _vmap_increment_nesting(size, 'same')
            # Remember where our LevelInfo lives so __del__ can flag it.
            self._vmap_stack = len(_vmap_levels)
            _vmap_levels.append(LevelInfo(self._vmap_level))

        elif self._size != size:
            raise DimensionBindError(
                f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}")

    @property
    def is_bound(self):
        """True once a size has been assigned."""
        return self._size is not None

    def __repr__(self):
        return self.name
| |
| |
def extract_name(inst):
    """Return the variable name targeted by a store instruction.

    *inst* must be a ``STORE_FAST`` or ``STORE_NAME`` instruction; anything
    else trips the assertion. The name is read from ``inst.argval``.
    """
    assert inst.opname in ('STORE_FAST', 'STORE_NAME')
    target = inst.argval
    return target
| |
# Maps (code object, call-site byte offset, lists) to a zero-argument factory
# that builds the dims for that call site, so the caller's bytecode is only
# decoded on the first call from each site.
_cache = {}


def dims(lists=0):
    """Create dims named after the variables the call result is assigned to.

    The caller's bytecode is inspected to find the assignment target(s) that
    directly follow the call:

    * ``i = dims()`` returns a single ``Dim`` named ``'i'`` (a ``DimList``
      if ``lists`` is non-zero).
    * ``i, j, k = dims(lists=1)`` returns a tuple with one object per target;
      the last ``lists`` targets become ``DimList`` objects, the rest ``Dim``.

    Args:
        lists: number of trailing assignment targets to create as ``DimList``.

    Raises:
        AssertionError: if the call is not directly followed by a
            ``STORE_FAST``/``STORE_NAME`` or an ``UNPACK_SEQUENCE`` of them.
    """
    frame = inspect.currentframe()
    assert frame is not None
    calling_frame = frame.f_back
    assert calling_frame is not None
    code, lasti = calling_frame.f_code, calling_frame.f_lasti
    # A frame object held in a local forms a reference cycle that keeps the
    # caller's locals alive until the cyclic GC runs. Drop the references as
    # soon as the needed data has been read.
    del frame, calling_frame

    # `lists` changes what the factory builds, so it must be part of the key.
    key = (code, lasti, lists)
    if key not in _cache:
        instructions = list(dis.get_instructions(code))
        # f_lasti is a byte offset. Locate the call by offset instead of
        # assuming index == offset // 2, which does not hold when the
        # bytecode contains inline cache entries hidden by dis.
        offsets = [inst.offset for inst in instructions]
        first = offsets.index(lasti) + 1
        unpack = instructions[first]

        if unpack.opname in ('STORE_FAST', 'STORE_NAME'):
            # just a single dim, not a list
            name = unpack.argval
            ctor = Dim if lists == 0 else DimList
            _cache[key] = lambda: ctor(name=name)
        else:
            assert unpack.opname == 'UNPACK_SEQUENCE'
            ndims = unpack.argval
            names = tuple(extract_name(instructions[first + 1 + i]) for i in range(ndims))
            # Targets at index >= first_list are created as DimList.
            first_list = len(names) - lists
            _cache[key] = lambda: tuple(
                Dim(n) if i < first_list else DimList(name=n)
                for i, n in enumerate(names))
    return _cache[key]()
| |
| |
def _dim_set(positional, arg):
    """Resolve *arg* into a collection of dims.

    * ``None`` returns ``positional`` itself, unchanged.
    * A single ``Dim`` or ``int`` returns a 1-tuple.
    * Anything else is iterated and returned as a tuple.

    ``Dim`` items are kept as they are; ``int`` items are looked up as
    indices into ``positional``. Any other item type trips the assertion.
    """
    if arg is None:
        return positional

    def resolve(item):
        if not isinstance(item, Dim):
            assert isinstance(item, int)
            return positional[item]
        return item

    if isinstance(arg, (Dim, int)):
        return (resolve(arg),)
    return tuple(map(resolve, arg))