| r""" |
| This package enables an interface for accessing MPS backend in python |
| """ |
| import torch |
| from .. import Tensor |
| |
| _is_in_bad_fork = getattr(torch._C, "_mps_is_in_bad_fork", lambda: False) |
| _default_mps_generator: torch._C.Generator = None # type: ignore[assignment] |
| |
| # local helper function (not public or exported) |
| def _get_default_mps_generator() -> torch._C.Generator: |
| global _default_mps_generator |
| if _default_mps_generator is None: |
| _default_mps_generator = torch._C._mps_get_default_generator() |
| return _default_mps_generator |
| |
| def synchronize() -> None: |
| r"""Waits for all kernels in all streams on a MPS device to complete.""" |
| return torch._C._mps_synchronize() |
| |
| def get_rng_state() -> Tensor: |
| r"""Returns the random number generator state as a ByteTensor.""" |
| return _get_default_mps_generator().get_state() |
| |
| def set_rng_state(new_state: Tensor) -> None: |
| r"""Sets the random number generator state. |
| |
| Args: |
| new_state (torch.ByteTensor): The desired state |
| """ |
| new_state_copy = new_state.clone(memory_format=torch.contiguous_format) |
| _get_default_mps_generator().set_state(new_state_copy) |
| |
| def manual_seed(seed: int) -> None: |
| r"""Sets the seed for generating random numbers. |
| |
| Args: |
| seed (int): The desired seed. |
| """ |
| # the torch.mps.manual_seed() can be called from the global |
| # torch.manual_seed() in torch/random.py. So we need to make |
| # sure mps is available (otherwise we just return without |
| # erroring out) |
| if not torch.has_mps: |
| return |
| seed = int(seed) |
| _get_default_mps_generator().manual_seed(seed) |
| |
| def seed() -> None: |
| r"""Sets the seed for generating random numbers to a random number.""" |
| _get_default_mps_generator().seed() |
| |
| def empty_cache() -> None: |
| r"""Releases all unoccupied cached memory currently held by the caching |
| allocator so that those can be used in other GPU applications. |
| """ |
| torch._C._mps_emptyCache() |
| |
| def set_per_process_memory_fraction(fraction) -> None: |
| r"""Set memory fraction for limiting process's memory allocation on MPS device. |
| The allowed value equals the fraction multiplied by recommended maximum device memory |
| (obtained from Metal API device.recommendedMaxWorkingSetSize). |
| If trying to allocate more than the allowed value in a process, it will raise an out of |
| memory error in allocator. |
| |
| Args: |
| fraction(float): Range: 0~2. Allowed memory equals total_memory * fraction. |
| |
| .. note:: |
| Passing 0 to fraction means unlimited allocations |
| (may cause system failure if out of memory). |
| Passing fraction greater than 1.0 allows limits beyond the value |
| returned from device.recommendedMaxWorkingSetSize. |
| """ |
| |
| if not isinstance(fraction, float): |
| raise TypeError('Invalid type for fraction argument, must be `float`') |
| if fraction < 0 or fraction > 2: |
| raise ValueError('Invalid fraction value: {}. Allowed range: 0~2'.format(fraction)) |
| |
| torch._C._mps_setMemoryFraction(fraction) |
| |
| def current_allocated_memory() -> int: |
| r"""Returns the current GPU memory occupied by tensors in bytes. |
| |
| .. note:: |
| The returned size does not include cached allocations in |
| memory pools of MPSAllocator. |
| """ |
| return torch._C._mps_currentAllocatedMemory() |
| |
| def driver_allocated_memory() -> int: |
| r"""Returns total GPU memory allocated by Metal driver for the process in bytes. |
| |
| .. note:: |
| The returned size includes cached allocations in MPSAllocator pools |
| as well as allocations from MPS/MPSGraph frameworks. |
| """ |
| return torch._C._mps_driverAllocatedMemory() |
| |
| __all__ = [ |
| 'get_rng_state', 'manual_seed', 'seed', 'set_rng_state', 'synchronize', |
| 'empty_cache', 'set_per_process_memory_fraction', 'current_allocated_memory', |
| 'driver_allocated_memory'] |