193 lines
6.1 KiB
Python
193 lines
6.1 KiB
Python
# mypy: allow-untyped-defs
|
|
r"""
|
|
This package enables an interface for accessing MPS (Metal Performance Shaders) backend in Python.
|
|
Metal is Apple's API for programming metal GPU (graphics processor unit). Using MPS means that increased
|
|
performance can be achieved, by running work on the metal GPU(s).
|
|
See https://developer.apple.com/documentation/metalperformanceshaders for more details.
|
|
"""
|
|
from typing import Union
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
|
|
|
|
_is_in_bad_fork = getattr(torch._C, "_mps_is_in_bad_fork", lambda: False)
|
|
_default_mps_generator: torch._C.Generator = None # type: ignore[assignment]
|
|
|
|
|
|
# local helper function (not public or exported)
|
|
def _get_default_mps_generator() -> torch._C.Generator:
|
|
global _default_mps_generator
|
|
if _default_mps_generator is None:
|
|
_default_mps_generator = torch._C._mps_get_default_generator()
|
|
return _default_mps_generator
|
|
|
|
|
|
def device_count() -> int:
|
|
r"""Returns the number of available MPS devices."""
|
|
return int(torch._C._has_mps and torch._C._mps_is_available())
|
|
|
|
|
|
def synchronize() -> None:
|
|
r"""Waits for all kernels in all streams on a MPS device to complete."""
|
|
return torch._C._mps_deviceSynchronize()
|
|
|
|
|
|
def get_rng_state(device: Union[int, str, torch.device] = "mps") -> Tensor:
|
|
r"""Returns the random number generator state as a ByteTensor.
|
|
|
|
Args:
|
|
device (torch.device or int, optional): The device to return the RNG state of.
|
|
Default: ``'mps'`` (i.e., ``torch.device('mps')``, the current MPS device).
|
|
"""
|
|
return _get_default_mps_generator().get_state()
|
|
|
|
|
|
def set_rng_state(
|
|
new_state: Tensor, device: Union[int, str, torch.device] = "mps"
|
|
) -> None:
|
|
r"""Sets the random number generator state.
|
|
|
|
Args:
|
|
new_state (torch.ByteTensor): The desired state
|
|
device (torch.device or int, optional): The device to set the RNG state.
|
|
Default: ``'mps'`` (i.e., ``torch.device('mps')``, the current MPS device).
|
|
"""
|
|
new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
|
|
_get_default_mps_generator().set_state(new_state_copy)
|
|
|
|
|
|
def manual_seed(seed: int) -> None:
|
|
r"""Sets the seed for generating random numbers.
|
|
|
|
Args:
|
|
seed (int): The desired seed.
|
|
"""
|
|
# the torch.mps.manual_seed() can be called from the global
|
|
# torch.manual_seed() in torch/random.py. So we need to make
|
|
# sure mps is available (otherwise we just return without
|
|
# erroring out)
|
|
if not torch._C._has_mps:
|
|
return
|
|
seed = int(seed)
|
|
_get_default_mps_generator().manual_seed(seed)
|
|
|
|
|
|
def seed() -> None:
|
|
r"""Sets the seed for generating random numbers to a random number."""
|
|
_get_default_mps_generator().seed()
|
|
|
|
|
|
def empty_cache() -> None:
|
|
r"""Releases all unoccupied cached memory currently held by the caching
|
|
allocator so that those can be used in other GPU applications.
|
|
"""
|
|
torch._C._mps_emptyCache()
|
|
|
|
|
|
def set_per_process_memory_fraction(fraction) -> None:
|
|
r"""Set memory fraction for limiting process's memory allocation on MPS device.
|
|
The allowed value equals the fraction multiplied by recommended maximum device memory
|
|
(obtained from Metal API device.recommendedMaxWorkingSetSize).
|
|
If trying to allocate more than the allowed value in a process, it will raise an out of
|
|
memory error in allocator.
|
|
|
|
Args:
|
|
fraction(float): Range: 0~2. Allowed memory equals total_memory * fraction.
|
|
|
|
.. note::
|
|
Passing 0 to fraction means unlimited allocations
|
|
(may cause system failure if out of memory).
|
|
Passing fraction greater than 1.0 allows limits beyond the value
|
|
returned from device.recommendedMaxWorkingSetSize.
|
|
"""
|
|
|
|
if not isinstance(fraction, float):
|
|
raise TypeError("Invalid type for fraction argument, must be `float`")
|
|
if fraction < 0 or fraction > 2:
|
|
raise ValueError(f"Invalid fraction value: {fraction}. Allowed range: 0~2")
|
|
|
|
torch._C._mps_setMemoryFraction(fraction)
|
|
|
|
|
|
def current_allocated_memory() -> int:
|
|
r"""Returns the current GPU memory occupied by tensors in bytes.
|
|
|
|
.. note::
|
|
The returned size does not include cached allocations in
|
|
memory pools of MPSAllocator.
|
|
"""
|
|
return torch._C._mps_currentAllocatedMemory()
|
|
|
|
|
|
def driver_allocated_memory() -> int:
|
|
r"""Returns total GPU memory allocated by Metal driver for the process in bytes.
|
|
|
|
.. note::
|
|
The returned size includes cached allocations in MPSAllocator pools
|
|
as well as allocations from MPS/MPSGraph frameworks.
|
|
"""
|
|
return torch._C._mps_driverAllocatedMemory()
|
|
|
|
|
|
def recommended_max_memory() -> int:
|
|
r"""Returns recommended max Working set size for GPU memory in bytes.
|
|
|
|
.. note::
|
|
Recommended max working set size for Metal.
|
|
returned from device.recommendedMaxWorkingSetSize.
|
|
"""
|
|
return torch._C._mps_recommendedMaxMemory()
|
|
|
|
|
|
def compile_shader(source: str):
|
|
r"""Compiles compute shader from source and allows one to invoke kernels
|
|
defined there from the comfort of Python runtime
|
|
Example::
|
|
|
|
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_MPS)
|
|
>>> lib = torch.mps.compile_shader(
|
|
... "kernel void full(device float* out, constant float& val, uint idx [[thread_position_in_grid]]) { out[idx] = val; }"
|
|
... )
|
|
>>> x = torch.zeros(16, device="mps")
|
|
>>> lib.full(x, 3.14)
|
|
"""
|
|
from pathlib import Path
|
|
|
|
from torch.utils._cpp_embed_headers import _embed_headers
|
|
|
|
if not hasattr(torch._C, "_mps_compileShader"):
|
|
raise RuntimeError("MPS is not available")
|
|
source = _embed_headers(
|
|
[l + "\n" for l in source.split("\n")],
|
|
[Path(__file__).parent.parent / "include"],
|
|
set(),
|
|
)
|
|
return torch._C._mps_compileShader(source)
|
|
|
|
|
|
def is_available() -> bool:
|
|
return device_count() > 0
|
|
|
|
|
|
from . import profiler
|
|
from .event import Event
|
|
|
|
|
|
__all__ = [
|
|
"compile_shader",
|
|
"device_count",
|
|
"get_rng_state",
|
|
"manual_seed",
|
|
"seed",
|
|
"set_rng_state",
|
|
"synchronize",
|
|
"empty_cache",
|
|
"set_per_process_memory_fraction",
|
|
"current_allocated_memory",
|
|
"driver_allocated_memory",
|
|
"Event",
|
|
"profiler",
|
|
"recommended_max_memory",
|
|
"is_available",
|
|
]
|