Adding all project files
This commit is contained in:
parent
6c9e127bdc
commit
cd4316ad0f
42289 changed files with 8009643 additions and 0 deletions
6
venv/Lib/site-packages/torch/_library/__init__.py
Normal file
@@ -0,0 +1,6 @@
import torch._library.autograd
import torch._library.fake_impl
import torch._library.simple_registry
import torch._library.utils
from torch._library.fake_class_registry import register_fake_class
from torch._library.triton import capture_triton, triton_op, wrap_triton
9 binary files not shown.
241
venv/Lib/site-packages/torch/_library/autograd.py
Normal file
|
@@ -0,0 +1,241 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import dataclasses
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional, Protocol
|
||||
|
||||
from torch import _C, _ops, autograd, Tensor
|
||||
from torch.utils import _pytree
|
||||
|
||||
from . import utils
|
||||
|
||||
|
||||
class InfoProtocol(Protocol):
|
||||
_backward_fn: Optional[Callable]
|
||||
_setup_context_fn: Optional[Callable]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Info:
|
||||
_backward_fn: Optional[Callable]
|
||||
_setup_context_fn: Optional[Callable]
|
||||
|
||||
|
||||
def make_autograd_impl(op: _ops.OpOverload, info: InfoProtocol) -> Callable:
|
||||
name: str = f"GeneratedBackwardFor_{op._namespace}_{op._opname}_{op._overloadname}"
|
||||
|
||||
has_kwarg_only_args = utils.has_kwarg_only_args(op._schema)
|
||||
|
||||
@dataclass
|
||||
class Metadata:
|
||||
keyset: _C.DispatchKeySet
|
||||
keyword_only_args: dict[str, Any]
|
||||
|
||||
def forward_no_grad(*args):
|
||||
metadata = args[-1]
|
||||
args = args[:-1]
|
||||
|
||||
with _C._AutoDispatchBelowAutograd():
|
||||
keyset = metadata.keyset
|
||||
kwargs = metadata.keyword_only_args
|
||||
result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
|
||||
return result
|
||||
|
||||
def forward(ctx, *args):
|
||||
metadata = args[-1]
|
||||
args = args[:-1]
|
||||
|
||||
with _C._AutoDispatchBelowAutograd():
|
||||
keyset = metadata.keyset
|
||||
kwargs = metadata.keyword_only_args
|
||||
result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
|
||||
if info._setup_context_fn:
|
||||
# The Dispatcher will remove args that are equal to their default
|
||||
# values from (args, kwargs). We're going to add them back so that
|
||||
# the user can access them.
|
||||
#
|
||||
# This is OK to do: The Dispatcher removed the args for serialization
|
||||
# FC/BC reasons (that is, a graph will not store args that are equal
|
||||
# to their default values), but that doesn't matter here. If the user
|
||||
# adds a new default arg, then they must update
|
||||
# their setup_context (along with the rest of their operator
|
||||
# registrations)
|
||||
args, kwargs = utils.fill_defaults(op._schema, args, kwargs)
|
||||
|
||||
if has_kwarg_only_args:
|
||||
info._setup_context_fn(
|
||||
ctx=ctx, inputs=args, keyword_only_inputs=kwargs, output=result
|
||||
)
|
||||
else:
|
||||
info._setup_context_fn(ctx=ctx, inputs=args, output=result)
|
||||
return result
|
||||
|
||||
def backward(ctx, *grads):
|
||||
if info._backward_fn:
|
||||
try:
|
||||
prev_needs_input_grad = ctx.needs_input_grad
|
||||
ctx.needs_input_grad = ctx.needs_input_grad[:-1]
|
||||
result = info._backward_fn(ctx, *grads)
|
||||
finally:
|
||||
ctx.needs_input_grad = prev_needs_input_grad
|
||||
if isinstance(result, tuple):
|
||||
return (*result, None)
|
||||
return result, None
|
||||
raise RuntimeError(
|
||||
f"Trying to backward through {op} but no autograd "
|
||||
f"formula was registered. "
|
||||
f"Please use register_autograd to add one."
|
||||
)
|
||||
|
||||
Generated = type(
|
||||
name,
|
||||
(autograd.Function,),
|
||||
{
|
||||
"forward": staticmethod(forward),
|
||||
"backward": staticmethod(backward),
|
||||
},
|
||||
)
|
||||
|
||||
schema = op._schema
|
||||
if any(
|
||||
utils.is_tensorlist_like_type(a.type)
|
||||
for a in (*schema.arguments, *schema.returns)
|
||||
):
|
||||
Generated = supports_tensorlist(Generated)
|
||||
|
||||
# The dispatcher passes any keyword-only-args as kwargs and the
|
||||
# rest of the args (even if specified as kwargs) as args.
|
||||
def autograd_impl(keyset, *args, **keyword_only_args):
|
||||
if _C.is_grad_enabled() and _pytree.tree_any_only(
|
||||
Tensor, lambda x: x.requires_grad, args, not_list_of_tensor
|
||||
):
|
||||
result = Generated.apply(*args, Metadata(keyset, keyword_only_args)) # type: ignore[attr-defined]
|
||||
else:
|
||||
result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
|
||||
return result
|
||||
|
||||
return autograd_impl
|
||||
|
||||
|
||||
def supports_tensorlist(cls: Any) -> Any:
|
||||
"""Allows a given autograd.Function class to support List[Tensor] inputs/outputs.
|
||||
|
||||
Regular autograd.Function has a constraint that it only directly supports autograd for
|
||||
Tensors. Applying @supports_tensorlist enables an autograd.Function to support
|
||||
autograd for List[Tensor] inputs and outputs.
|
||||
"""
|
||||
orig_forward = cls.forward
|
||||
orig_backward = cls.backward
|
||||
orig_apply = cls.apply
|
||||
|
||||
@dataclass
|
||||
class Metadata:
|
||||
input_spec: spec_t
|
||||
output_spec: Optional[spec_t] = None
|
||||
result_is_tuple: Optional[bool] = None
|
||||
|
||||
def new_forward(ctx, *args):
|
||||
metadata = args[-1]
|
||||
args = args[:-1]
|
||||
if not isinstance(metadata, Metadata):
|
||||
raise NotImplementedError(
|
||||
"NYI: calling supports_tensorlist autograd.Function.forward directly. "
|
||||
"You should probably be calling .apply instead. "
|
||||
"Please file an issue if not."
|
||||
)
|
||||
args = unflatten(list(args), metadata.input_spec)
|
||||
result = orig_forward(ctx, *args)
|
||||
metadata.result_is_tuple = isinstance(result, tuple)
|
||||
if not metadata.result_is_tuple:
|
||||
result = (result,)
|
||||
flat_result, output_spec = flatten(result, not_list_of_tensor)
|
||||
metadata.output_spec = output_spec
|
||||
|
||||
if hasattr(ctx, "_pt_metadata"):
|
||||
raise RuntimeError(
|
||||
"Please don't set ctx._pt_metadata; PyTorch uses it to store info"
|
||||
)
|
||||
ctx._pt_metadata = metadata
|
||||
|
||||
return tuple(flat_result)
|
||||
|
||||
def new_backward(ctx, *grads):
|
||||
if not hasattr(ctx, "_pt_metadata"):
|
||||
raise NotImplementedError(
|
||||
"NYI: calling supports_tensorlist autograd.Function.backward directly. "
|
||||
"This will automatically get called by PyTorch autograd. "
|
||||
"Please file an issue if you need this."
|
||||
)
|
||||
|
||||
metadata = ctx._pt_metadata
|
||||
grads = unflatten(list(grads), metadata.output_spec)
|
||||
|
||||
# If the user's input is ([x, y, z], w),
|
||||
# then needs_input_grad is (bool, bool, bool, bool, bool).
|
||||
# We need to
|
||||
# 1. get rid of the additional bool (which comes from the extra
|
||||
# `metadata input`)
|
||||
# 2. unflatten to get the right structure.
|
||||
prev_needs_input_grad = ctx.needs_input_grad
|
||||
try:
|
||||
ctx.needs_input_grad = unflatten(
|
||||
list(ctx.needs_input_grad[:-1]), metadata.input_spec
|
||||
)
|
||||
grad_inputs = orig_backward(ctx, *grads)
|
||||
finally:
|
||||
ctx.needs_input_grad = prev_needs_input_grad
|
||||
|
||||
if not isinstance(grad_inputs, tuple):
|
||||
grad_inputs = (grad_inputs,)
|
||||
# Assume that any Nones in the backward are Tensors.
|
||||
# If the forward has an arg that is [1, 2, 3], the backward should
|
||||
# return None as the grad.
|
||||
# If the forward has an arg that is [tensor, tensor], the backward
|
||||
# may return [None, None], [grad, None], [None, grad], or [grad, grad].
|
||||
flat_grad_inputs, grad_inputs_spec = flatten(
|
||||
grad_inputs, not_list_of_optional_tensor
|
||||
)
|
||||
if grad_inputs_spec != metadata.input_spec:
|
||||
raise RuntimeError(
|
||||
f"Expected the return from backward to be of the same structure "
|
||||
f"as the inputs. Got: {grad_inputs_spec} (return from backward), "
|
||||
f"{metadata.input_spec} (inputs)"
|
||||
)
|
||||
return tuple(flat_grad_inputs + [None])
|
||||
|
||||
def new_apply(*args):
|
||||
flat_args, input_spec = flatten(args, is_leaf=not_list_of_tensor)
|
||||
metadata = Metadata(input_spec)
|
||||
result = orig_apply(*flat_args, metadata) # type: ignore[misc]
|
||||
assert metadata.output_spec is not None
|
||||
result = unflatten(list(result), metadata.output_spec)
|
||||
if not metadata.result_is_tuple:
|
||||
assert isinstance(result, tuple)
|
||||
assert len(result) == 1
|
||||
return result[0]
|
||||
return result
|
||||
|
||||
cls.forward = new_forward
|
||||
cls.backward = new_backward
|
||||
cls.apply = new_apply
|
||||
return cls
|
||||
|
||||
|
||||
def not_list_of_tensor(tree):
|
||||
if isinstance(tree, tuple):
|
||||
return False
|
||||
if isinstance(tree, list):
|
||||
return any(not isinstance(l, Tensor) for l in tree)
|
||||
return True
|
||||
|
||||
|
||||
def not_list_of_optional_tensor(tree):
|
||||
if isinstance(tree, tuple):
|
||||
return False
|
||||
if isinstance(tree, list):
|
||||
return any(l is not None and not isinstance(l, Tensor) for l in tree)
|
||||
return True
|
||||
|
||||
|
||||
flatten = _pytree.tree_flatten
|
||||
unflatten = _pytree.tree_unflatten
|
||||
spec_t = _pytree.TreeSpec
|
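For orientation only (not part of the committed file): a minimal sketch of how the supports_tensorlist decorator above can wrap an autograd.Function that takes and returns List[Tensor]. torch._library.autograd is a private module, so the import path and behavior are assumptions that may change between releases; StackedSin is a hypothetical example class.

import torch
from torch._library.autograd import supports_tensorlist  # private API, assumed importable

@supports_tensorlist
class StackedSin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, xs):
        # xs arrives as the original list[Tensor] thanks to the decorator's unflattening.
        ctx.save_for_backward(*xs)
        return [x.sin() for x in xs]

    @staticmethod
    def backward(ctx, grads):
        # grads mirrors the output structure, i.e. it is a list of gradients.
        xs = ctx.saved_tensors
        return [g * x.cos() for g, x in zip(grads, xs)]

xs = [torch.randn(3, requires_grad=True) for _ in range(2)]
ys = StackedSin.apply(xs)  # list in, list out
torch.autograd.backward([y.sum() for y in ys])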
921
venv/Lib/site-packages/torch/_library/custom_ops.py
Normal file
|
@@ -0,0 +1,921 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import collections
|
||||
import inspect
|
||||
import logging
|
||||
import weakref
|
||||
from collections.abc import Iterable, Sequence
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Literal, Optional, overload, Union
|
||||
|
||||
import torch
|
||||
from torch import _C, _ops, Tensor
|
||||
from torch.types import _dtype
|
||||
from torch.utils._exposed_in import exposed_in
|
||||
|
||||
from . import autograd, utils
|
||||
|
||||
|
||||
device_types_t = Optional[Union[str, Sequence[str]]]
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@overload
|
||||
def custom_op(
|
||||
name: str,
|
||||
fn: Literal[None] = None,
|
||||
/,
|
||||
*,
|
||||
mutates_args: Union[str, Iterable[str]],
|
||||
device_types: device_types_t = None,
|
||||
schema: Optional[str] = None,
|
||||
) -> Callable[[Callable[..., object]], "CustomOpDef"]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def custom_op(
|
||||
name: str,
|
||||
fn: Callable[..., object],
|
||||
/,
|
||||
*,
|
||||
mutates_args: Union[str, Iterable[str]],
|
||||
device_types: device_types_t = None,
|
||||
schema: Optional[str] = None,
|
||||
) -> "CustomOpDef":
|
||||
...
|
||||
|
||||
|
||||
@exposed_in("torch.library")
|
||||
def custom_op(
|
||||
name: str,
|
||||
fn: Optional[Callable] = None,
|
||||
/,
|
||||
*,
|
||||
mutates_args: Union[str, Iterable[str]],
|
||||
device_types: device_types_t = None,
|
||||
schema: Optional[str] = None,
|
||||
) -> Union[Callable[[Callable[..., object]], "CustomOpDef"], "CustomOpDef"]:
|
||||
"""Wraps a function into custom operator.
|
||||
|
||||
Reasons why you may want to create a custom op include:
|
||||
- Wrapping a third-party library or custom kernel to work with PyTorch
|
||||
subsystems like Autograd.
|
||||
- Preventing torch.compile/export/FX tracing from peeking inside your function.
|
||||
|
||||
This API is used as a decorator around a function (please see examples).
|
||||
The provided function must have type hints; these are needed to interface
|
||||
with PyTorch's various subsystems.
|
||||
|
||||
Args:
|
||||
name (str): A name for the custom op that looks like "{namespace}::{name}",
|
||||
e.g. "mylib::my_linear". The name is used as the op's stable identifier
|
||||
in PyTorch subsystems (e.g. torch.export, FX graphs).
|
||||
To avoid name collisions, please use your project name as the namespace;
|
||||
e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
|
||||
mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
|
||||
This MUST be accurate; otherwise, the behavior is undefined. If "unknown",
|
||||
it pessimistically assumes that all inputs to the operator are being mutated.
|
||||
device_types (None | str | Sequence[str]): The device type(s) the function
|
||||
is valid for. If no device type is provided, then the function
|
||||
is used as the default implementation for all device types.
|
||||
Examples: "cpu", "cuda".
|
||||
When registering a device-specific implementation for an operator that accepts no Tensors,
|
||||
we require the operator to have a "device: torch.device" argument.
|
||||
schema (None | str): A schema string for the operator. If None
|
||||
(recommended) we'll infer a schema for the operator from its type
|
||||
annotations. We recommend letting us infer a schema unless you
|
||||
have a specific reason not to.
|
||||
Example: "(Tensor x, int y) -> (Tensor, Tensor)".
|
||||
|
||||
.. note::
|
||||
We recommend not passing in a ``schema`` arg and instead letting us infer
|
||||
it from the type annotations. It is error-prone to write your own schema.
|
||||
You may wish to provide your own schema if our interpretation of
|
||||
the type annotation is not what you want.
|
||||
For more info on how to write a schema string, see
|
||||
`here <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func>`_
|
||||
|
||||
Examples::
|
||||
>>> import torch
|
||||
>>> from torch import Tensor
|
||||
>>> from torch.library import custom_op
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> @custom_op("mylib::numpy_sin", mutates_args=())
|
||||
>>> def numpy_sin(x: Tensor) -> Tensor:
|
||||
>>> x_np = x.cpu().numpy()
|
||||
>>> y_np = np.sin(x_np)
|
||||
>>> return torch.from_numpy(y_np).to(device=x.device)
|
||||
>>>
|
||||
>>> x = torch.randn(3)
|
||||
>>> y = numpy_sin(x)
|
||||
>>> assert torch.allclose(y, x.sin())
|
||||
>>>
|
||||
>>> # Example of a custom op that only works for one device type.
|
||||
>>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu")
|
||||
>>> def numpy_sin_cpu(x: Tensor) -> Tensor:
|
||||
>>> x_np = x.numpy()
|
||||
>>> y_np = np.sin(x_np)
|
||||
>>> return torch.from_numpy(y_np)
|
||||
>>>
|
||||
>>> x = torch.randn(3)
|
||||
>>> y = numpy_sin_cpu(x)
|
||||
>>> assert torch.allclose(y, x.sin())
|
||||
>>>
|
||||
>>> # Example of a custom op that mutates an input
|
||||
>>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu")
|
||||
>>> def numpy_sin_inplace(x: Tensor) -> None:
|
||||
>>> x_np = x.numpy()
|
||||
>>> np.sin(x_np, out=x_np)
|
||||
>>>
|
||||
>>> x = torch.randn(3)
|
||||
>>> expected = x.sin()
|
||||
>>> numpy_sin_inplace(x)
|
||||
>>> assert torch.allclose(x, expected)
|
||||
>>>
|
||||
>>> # Example of a factory function
|
||||
>>> @torch.library.custom_op("mylib::bar", mutates_args={}, device_types="cpu")
|
||||
>>> def bar(device: torch.device) -> Tensor:
|
||||
>>> return torch.ones(3)
|
||||
>>>
|
||||
>>> bar("cpu")
|
||||
|
||||
"""
|
||||
|
||||
def inner(fn: Callable[..., object]) -> CustomOpDef:
|
||||
import torch
|
||||
|
||||
if schema is None:
|
||||
schema_str = torch.library.infer_schema(fn, mutates_args=mutates_args)
|
||||
else:
|
||||
schema_str = schema
|
||||
|
||||
namespace, opname = name.split("::")
|
||||
result = CustomOpDef(namespace, opname, schema_str, fn)
|
||||
if schema is not None:
|
||||
# Check that schema's alias annotations match those of `mutates_args`.
|
||||
expected = set()
|
||||
for arg in result._opoverload._schema.arguments:
|
||||
if arg.alias_info is not None and arg.alias_info.is_write:
|
||||
expected.add(arg.name)
|
||||
if expected != set(mutates_args):
|
||||
raise ValueError(
|
||||
f"Attempted to create a custom op with `mutates_args={mutates_args}` "
|
||||
f"and `schema={schema}. The schema suggests that the op mutates {expected}"
|
||||
f"which is different from what was provided to us in `mutates_args`. "
|
||||
f"Please make these consistent."
|
||||
)
|
||||
result.register_kernel(device_types)(fn)
|
||||
return result
|
||||
|
||||
if fn is None:
|
||||
return inner
|
||||
return inner(fn)
|
||||
|
||||
|
||||
class CustomOpDef:
|
||||
"""CustomOpDef is a wrapper around a function that turns it into a custom op.
|
||||
|
||||
It has various methods for registering additional behavior for this
|
||||
custom op.
|
||||
|
||||
You should not instantiate CustomOpDef directly; instead, use the
|
||||
:func:`torch.library.custom_op` API.
|
||||
"""
|
||||
|
||||
def __init__(self, namespace: str, name: str, schema: str, fn: Callable) -> None:
|
||||
# Fields used to interface with the PyTorch dispatcher
|
||||
self._namespace = namespace
|
||||
self._name = name
|
||||
self._schema = schema
|
||||
|
||||
self._init_fn = fn
|
||||
|
||||
self._backend_fns: dict[Union[str, None], Callable] = {}
|
||||
self._abstract_fn: Optional[Callable] = None
|
||||
self._setup_context_fn: Optional[Callable] = None
|
||||
self._backward_fn: Optional[Callable] = None
|
||||
self._torch_dispatch_fns: dict[type, Callable] = {}
|
||||
self._vmap_fn: Optional[Callable] = None
|
||||
self._autocast_cuda_dtype: Optional[_dtype] = None
|
||||
self._autocast_cpu_dtype: Optional[_dtype] = None
|
||||
|
||||
self._lib = get_library_allowing_overwrite(self._namespace, self._name)
|
||||
self._register_to_dispatcher()
|
||||
self._disabled_kernel: set = set()
|
||||
OPDEFS[self._qualname] = self
|
||||
|
||||
@property
|
||||
def _qualname(self) -> str:
|
||||
return f"{self._namespace}::{self._name}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<CustomOpDef({self._qualname})>"
|
||||
|
||||
@contextmanager
|
||||
def set_kernel_enabled(self, device_type: str, enabled: bool = True):
|
||||
"""
|
||||
Disable or re-enable an already registered kernel for this custom operator.
|
||||
|
||||
If the kernel is already disabled/enabled, this is a no-op.
|
||||
|
||||
Note:
|
||||
If a kernel is first disabled and then registered, it is disabled until enabled again.
|
||||
|
||||
Args:
|
||||
device_type (str): The device type to disable/enable the kernel for.
|
||||
enabled (bool): Whether to enable or disable the kernel.
|
||||
|
||||
Example:
|
||||
>>> inp = torch.randn(1)
|
||||
>>>
|
||||
>>> # define custom op `f`.
|
||||
>>> @custom_op("mylib::f", mutates_args=())
|
||||
>>> def f(x: Tensor) -> Tensor:
|
||||
>>> return torch.zeros(1)
|
||||
>>>
|
||||
>>> print(f(inp)) # tensor([0.]), default kernel
|
||||
>>>
|
||||
>>> @f.register_kernel("cpu")
|
||||
>>> def _(x):
|
||||
>>> return torch.ones(1)
|
||||
>>>
|
||||
>>> print(f(inp)) # tensor([1.]), CPU kernel
|
||||
>>>
|
||||
>>> # temporarily disable the CPU kernel
|
||||
>>> with f.set_kernel_enabled("cpu", enabled = False):
|
||||
>>> print(f(inp)) # tensor([0.]) with CPU kernel disabled
|
||||
|
||||
"""
|
||||
action = "enable" if enabled else "disable"
|
||||
originally_disabled = device_type in self._disabled_kernel
|
||||
if device_type not in self._backend_fns:
|
||||
log.warning(
|
||||
"Attempted to %s kernel for %s but no kernel was registered for this device type.",
|
||||
action,
|
||||
device_type,
|
||||
)
|
||||
|
||||
if not enabled:
|
||||
if originally_disabled:
|
||||
log.warning(
|
||||
"Attempted to disable kernel for %s but it was already disabled.",
|
||||
device_type,
|
||||
)
|
||||
else:
|
||||
self._disabled_kernel.add(device_type)
|
||||
else: # enable the kernel
|
||||
if not originally_disabled:
|
||||
log.warning(
|
||||
"Attempted to enable kernel for %s but it was already enabled.",
|
||||
device_type,
|
||||
)
|
||||
else:
|
||||
self._disabled_kernel.remove(device_type)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# restore original state
|
||||
if originally_disabled:
|
||||
self._disabled_kernel.add(device_type)
|
||||
else:
|
||||
self._disabled_kernel.discard(device_type)
|
||||
|
||||
def register_kernel(
|
||||
self, device_types: device_types_t, fn: Optional[Callable] = None, /
|
||||
) -> Callable:
|
||||
"""Register an implementation for a device type for this operator.
|
||||
|
||||
Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
|
||||
This API may be used as a decorator.
|
||||
|
||||
Args:
|
||||
fn (Callable): The function to register as the implementation for
|
||||
the given device types.
|
||||
device_types (str | Sequence[str]): The device type(s) to register an impl for.
|
||||
|
||||
Examples::
|
||||
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
||||
>>> import torch
|
||||
>>> from torch import Tensor
|
||||
>>> from torch.library import custom_op
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> # Create a custom op that works on cpu
|
||||
>>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
|
||||
>>> def numpy_sin(x: Tensor) -> Tensor:
|
||||
>>> x_np = x.numpy()
|
||||
>>> y_np = np.sin(x_np)
|
||||
>>> return torch.from_numpy(y_np)
|
||||
>>>
|
||||
>>> # Add implementations for the cuda device
|
||||
>>> @numpy_sin.register_kernel("cuda")
|
||||
>>> def _(x):
|
||||
>>> x_np = x.cpu().numpy()
|
||||
>>> y_np = np.sin(x_np)
|
||||
>>> return torch.from_numpy(y_np).to(device=x.device)
|
||||
>>>
|
||||
>>> x_cpu = torch.randn(3)
|
||||
>>> x_cuda = x_cpu.cuda()
|
||||
>>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
|
||||
>>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())
|
||||
|
||||
"""
|
||||
|
||||
def inner(fn):
|
||||
if device_types is None or isinstance(device_types, str):
|
||||
dtypes: list[Union[str, None]] = [device_types]
|
||||
else:
|
||||
dtypes = list(device_types)
|
||||
for device_type in dtypes:
|
||||
if device_type not in self._backend_fns:
|
||||
|
||||
def backend_impl(*args, **kwargs):
|
||||
result = self._backend_fns[device_type](*args, **kwargs)
|
||||
|
||||
def get_module():
|
||||
fn = self._backend_fns[device_type]
|
||||
return inspect.getmodule(fn)
|
||||
|
||||
utils.check_aliasing_constraint(
|
||||
self._name,
|
||||
utils.iter_tensors(args, kwargs),
|
||||
result,
|
||||
get_module,
|
||||
)
|
||||
return result
|
||||
|
||||
if device_type is None:
|
||||
self._lib.impl(
|
||||
self._name, backend_impl, "CompositeExplicitAutograd"
|
||||
)
|
||||
else:
|
||||
self._lib.impl(
|
||||
self._name,
|
||||
backend_impl,
|
||||
_C._dispatch_key_for_device(device_type),
|
||||
)
|
||||
|
||||
# Wrap function to choose between the default implementation or the device-specific
|
||||
# implementation depending on if the kernel is disabled.
|
||||
@torch._disable_dynamo
|
||||
def wrapped_fn(*args, **kwargs):
|
||||
if device_type in self._disabled_kernel:
|
||||
return self._init_fn(*args, **kwargs)
|
||||
else:
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
self._backend_fns[device_type] = wrapped_fn
|
||||
return fn
|
||||
|
||||
if device_types is not None and not utils.has_tensor_arg(
|
||||
self._opoverload._schema
|
||||
):
|
||||
device_arg_index = utils.get_device_arg_index(self._opoverload._schema)
|
||||
if device_arg_index is None:
|
||||
raise ValueError(
|
||||
"Functions without tensor inputs are required to have a `device: torch.device` argument"
|
||||
)
|
||||
self._register_backend_select_dispatcher(device_arg_index)
|
||||
|
||||
# See NOTE: [Supporting decorator and non-decorator usage]
|
||||
if fn is None:
|
||||
return inner
|
||||
return inner(fn)
|
||||
|
||||
def register_fake(self, fn: Callable, /) -> Callable:
|
||||
r"""Register a FakeTensor implementation for this custom op.
|
||||
|
||||
This is necessary to get the operator to work efficiently with torch.compile.
|
||||
|
||||
The Fake impl (sometimes also known as a meta kernel or abstract impl)
|
||||
specifies the behavior of this operator on Tensors that carry no data.
|
||||
Given some input Tensors with certain properties
|
||||
(sizes/strides/storage_offset/device), it specifies what the properties of
|
||||
the output Tensors are.
|
||||
|
||||
Please see :func:`torch.library.impl_abstract` for more details.
|
||||
|
||||
Args:
|
||||
fn (Callable): The function to register as the FakeTensor
|
||||
implementation.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> import numpy as np
|
||||
>>> from torch import Tensor
|
||||
>>>
|
||||
>>> # Example 1: an operator without data-dependent output shape
|
||||
>>> @torch.library.custom_op("mylib::linear", mutates_args=())
|
||||
>>> def linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
|
||||
>>> return (x @ weight.t()) + bias
|
||||
>>>
|
||||
>>> @linear.register_fake
|
||||
>>> def _(x, weight, bias):
|
||||
>>> assert x.dim() == 2
|
||||
>>> assert weight.dim() == 2
|
||||
>>> assert bias.dim() == 1
|
||||
>>> assert x.shape[1] == weight.shape[1]
|
||||
>>> assert weight.shape[0] == bias.shape[0]
|
||||
>>> assert x.device == weight.device
|
||||
>>> return x.new_empty(x.size(0), weight.size(0))
|
||||
>>>
|
||||
>>> x = torch.randn(2, 2)
|
||||
>>> weight = torch.randn(2, 2)
|
||||
>>> bias = torch.randn(2)
|
||||
>>> # xdoctest: +SKIP("Requires Python <= 3.11")
|
||||
>>> out = torch.compile(linear, fullgraph=True)(x, weight, bias)
|
||||
>>> # xdoctest: +SKIP("Requires Python <= 3.11")
|
||||
>>> assert torch.allclose(out, torch.nn.functional.linear(x, weight, bias))
|
||||
>>>
|
||||
>>> # Example 2: an operator with data-dependent output shape
|
||||
>>> @torch.library.custom_op("mylib::nonzero", mutates_args=())
|
||||
>>> def nonzero(x: Tensor) -> Tensor:
|
||||
>>> x_np = x.cpu().numpy()
|
||||
>>> res = np.stack(np.nonzero(x_np), axis=1)
|
||||
>>> return torch.tensor(res, device=x.device)
|
||||
>>>
|
||||
>>> @nonzero.register_fake
|
||||
>>> def _(x):
|
||||
>>> # Number of nonzero-elements is data-dependent.
|
||||
>>> # Since we cannot peek at the data in an abstract impl,
|
||||
>>> # we use the ctx object to construct a new symint that
|
||||
>>> # represents the data-dependent size.
|
||||
>>> ctx = torch.library.get_ctx()
|
||||
>>> nnz = ctx.new_dynamic_size()
|
||||
>>> shape = [nnz, x.dim()]
|
||||
>>> result = x.new_empty(shape, dtype=torch.int64)
|
||||
>>> return result
|
||||
>>>
|
||||
>>> x = torch.tensor([0, 1, 2, 0, 0, 1])
|
||||
>>> # xdoctest: +SKIP("Requires Python <= 3.11")
|
||||
>>> out = torch.compile(nonzero, fullgraph=True)(x)
|
||||
>>> # xdoctest: +SKIP("Requires Python <= 3.11")
|
||||
>>> assert torch.allclose(out, x.nonzero())
|
||||
|
||||
"""
|
||||
self._abstract_fn = fn
|
||||
return fn
|
||||
|
||||
def register_torch_dispatch(
|
||||
self, torch_dispatch_class: Any, fn: Optional[Callable] = None, /
|
||||
) -> Callable:
|
||||
r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``.
|
||||
|
||||
This allows for open registration to specify the behavior between the operator
|
||||
and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class``
|
||||
or the operator directly.
|
||||
|
||||
Please see :func:`torch.library.register_torch_dispatch` for examples and more details.
|
||||
"""
|
||||
|
||||
def register(fn):
|
||||
if torch_dispatch_class not in self._torch_dispatch_fns:
|
||||
|
||||
def inner(*args, **kwargs):
|
||||
return self._torch_dispatch_fns[torch_dispatch_class](
|
||||
*args, **kwargs
|
||||
)
|
||||
|
||||
self._lib._register_torch_dispatch_rule(
|
||||
self._name, torch_dispatch_class, inner
|
||||
)
|
||||
self._torch_dispatch_fns[torch_dispatch_class] = fn
|
||||
return fn
|
||||
|
||||
if fn is None:
|
||||
return register
|
||||
else:
|
||||
return register(fn)
|
||||
|
||||
def register_autograd(
|
||||
self,
|
||||
backward: Callable,
|
||||
/,
|
||||
*,
|
||||
setup_context: Optional[Callable] = None,
|
||||
) -> None:
|
||||
r"""Register a backward formula for this custom op.
|
||||
|
||||
In order for an operator to work with autograd, you need to register
|
||||
a backward formula:
|
||||
1. You must tell us how to compute gradients during the backward pass
|
||||
by providing us a "backward" function.
|
||||
2. If you need any values from the forward to compute gradients, you can
|
||||
use `setup_context` to save values for backward.
|
||||
|
||||
``backward_fn`` runs during the backward pass. It accepts ``(ctx, *grads)``:
|
||||
- ``grads`` is one or more gradients. The number of gradients matches
|
||||
the number of outputs of the operator.
|
||||
The ``ctx`` object is `the same ctx object <context_method_mixins>`_ used by
|
||||
:class:`torch.autograd.Function`. The semantics of ``backward_fn`` are the
|
||||
same as :meth:`torch.autograd.Function.backward`.
|
||||
|
||||
``setup_context(ctx, inputs, output)`` runs during the forward pass.
|
||||
Please save quantities needed for backward onto the ``ctx`` object via
|
||||
either :meth:`torch.autograd.function.FunctionCtx.save_for_backward`
|
||||
or assigning them as attributes of ``ctx``. If your custom op has
|
||||
kwarg-only arguments, we expect the signature of ``setup_context``
|
||||
to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``.
|
||||
|
||||
Both ``setup_context_fn`` and ``backward_fn`` must be traceable. That is,
|
||||
they may not directly access :meth:`torch.Tensor.data_ptr` and they must
|
||||
not depend on or mutate global state. If you need a non-traceable backward,
|
||||
you can make it a separate custom_op that you call inside ``backward_fn``.
|
||||
|
||||
If you need different autograd behavior on different devices, then we
|
||||
recommend creating two different custom operators, one for each device
|
||||
that needs different behavior, and switching between them at runtime.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> import numpy as np
|
||||
>>> from torch import Tensor
|
||||
>>>
|
||||
>>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
|
||||
>>> def numpy_sin(x: Tensor) -> Tensor:
|
||||
>>> x_np = x.cpu().numpy()
|
||||
>>> y_np = np.sin(x_np)
|
||||
>>> return torch.from_numpy(y_np).to(device=x.device)
|
||||
>>>
|
||||
>>> def setup_context(ctx, inputs, output) -> Tensor:
|
||||
>>> x, = inputs
|
||||
>>> ctx.save_for_backward(x)
|
||||
>>>
|
||||
>>> def backward(ctx, grad):
|
||||
>>> x, = ctx.saved_tensors
|
||||
>>> return grad * x.cos()
|
||||
>>>
|
||||
>>> numpy_sin.register_autograd(backward, setup_context=setup_context)
|
||||
>>>
|
||||
>>> x = torch.randn(3, requires_grad=True)
|
||||
>>> y = numpy_sin(x)
|
||||
>>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y))
|
||||
>>> assert torch.allclose(grad_x, x.cos())
|
||||
>>>
|
||||
>>> # Example with a keyword-only arg
|
||||
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
|
||||
>>> def numpy_mul(x: Tensor, *, val: float) -> Tensor:
|
||||
>>> x_np = x.cpu().numpy()
|
||||
>>> y_np = x_np * val
|
||||
>>> return torch.from_numpy(y_np).to(device=x.device)
|
||||
>>>
|
||||
>>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor:
|
||||
>>> ctx.val = keyword_only_inputs["val"]
|
||||
>>>
|
||||
>>> def backward(ctx, grad):
|
||||
>>> return grad * ctx.val
|
||||
>>>
|
||||
>>> numpy_mul.register_autograd(backward, setup_context=setup_context)
|
||||
>>>
|
||||
>>> x = torch.randn(3, requires_grad=True)
|
||||
>>> y = numpy_mul(x, val=3.14)
|
||||
>>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y))
|
||||
>>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
|
||||
|
||||
"""
|
||||
schema = self._opoverload._schema
|
||||
if not utils.is_functional_schema(schema):
|
||||
raise RuntimeError(
|
||||
f"Cannot register autograd formula for non-functional operator "
|
||||
f"{self} with schema {schema}. Please create "
|
||||
f"a functional operator and register an autograd formula for that."
|
||||
)
|
||||
|
||||
self._backward_fn = backward
|
||||
self._setup_context_fn = setup_context
|
||||
|
||||
def _register_to_dispatcher(self) -> None:
|
||||
if torch._running_with_deploy():
|
||||
utils.warn_deploy(stacklevel=5)
|
||||
return
|
||||
|
||||
lib = self._lib
|
||||
schema_str = self._name + self._schema
|
||||
cpp_schema = _C.parse_schema(schema_str)
|
||||
if utils.has_kwarg_only_tensors(cpp_schema):
|
||||
# If you want to support this, the progression is:
|
||||
# - supporting kwarg-only Tensors that are non-differentiable
|
||||
# - supporting kwarg-only Tensors (regardless of differentiability)
|
||||
raise NotImplementedError(
|
||||
f"custom_op with kwarg-only Tensor args. Please make your "
|
||||
f"tensors not kwarg-only. Got: {schema_str}"
|
||||
)
|
||||
|
||||
lib.define(
|
||||
schema_str,
|
||||
tags=[_C.Tag.pt2_compliant_tag, _C.Tag.needs_fixed_stride_order],
|
||||
)
|
||||
self._opoverload = utils.lookup_op(self._qualname)
|
||||
|
||||
def fake_impl(*args, **kwargs):
|
||||
if self._abstract_fn is None:
|
||||
if utils.can_generate_trivial_fake_impl(self._opoverload):
|
||||
return None
|
||||
raise RuntimeError(
|
||||
f"There was no fake impl registered for {self}. "
|
||||
f"This is necessary for torch.compile/export/fx tracing to work. "
|
||||
f"Please use `{self._init_fn.__name__}.register_fake` to add an "
|
||||
f"fake impl."
|
||||
)
|
||||
return self._abstract_fn(*args, **kwargs)
|
||||
|
||||
lib._register_fake(self._name, fake_impl, _stacklevel=4)
|
||||
|
||||
autograd_impl = autograd.make_autograd_impl(self._opoverload, self)
|
||||
lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True)
|
||||
|
||||
schema = self._opoverload._schema
|
||||
if schema.is_mutable:
|
||||
mutated_idxs, mutated_keys = utils.mutated_args_kwargs(schema)
|
||||
|
||||
def adinplaceorview_impl(keyset, *args, **kwargs):
|
||||
for idx in mutated_idxs:
|
||||
increment_version(args[idx])
|
||||
for key in mutated_keys:
|
||||
increment_version(kwargs[key])
|
||||
with _C._AutoDispatchBelowADInplaceOrView():
|
||||
return self._opoverload.redispatch(
|
||||
keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs
|
||||
)
|
||||
|
||||
lib.impl(
|
||||
self._name,
|
||||
adinplaceorview_impl,
|
||||
"ADInplaceOrView",
|
||||
with_keyset=True,
|
||||
)
|
||||
|
||||
def _register_backend_select_dispatcher(self, device_arg_index: int):
|
||||
"""
|
||||
Switch on the device argument to select the correct backend to dispatch to.
|
||||
"""
|
||||
|
||||
def backend_select(keyset, *args, **kwargs):
|
||||
device = args[device_arg_index].type
|
||||
if device not in self._backend_fns:
|
||||
raise RuntimeError(
|
||||
f"{self._name} does not have a kernel registered for {device}. "
|
||||
"Please use register_kernel to do so."
|
||||
)
|
||||
dispatch_key = _C._dispatch_key_for_device(device)
|
||||
dispatch_key = getattr(_C.DispatchKey, dispatch_key)
|
||||
return self._opoverload.redispatch(
|
||||
_C.DispatchKeySet(dispatch_key), *args, **kwargs
|
||||
)
|
||||
|
||||
self._lib.impl(self._name, backend_select, "BackendSelect", with_keyset=True)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self._opoverload(*args, **kwargs)
|
||||
|
||||
def register_vmap(
|
||||
self,
|
||||
func: Optional[Callable] = None,
|
||||
):
|
||||
r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op.
|
||||
|
||||
This API may be used as a decorator.
|
||||
|
||||
In order for an operator to work with :func:`torch.vmap`, you may need to register a
|
||||
vmap implementation in the following signature:
|
||||
|
||||
``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``,
|
||||
|
||||
where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``.
|
||||
|
||||
It specifies how we compute the batched version of ``op`` given inputs with an additional
|
||||
dimension (specified by ``in_dims``).
|
||||
|
||||
For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None``
|
||||
if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer
|
||||
specifying what dimension of the Tensor is being vmapped over.
|
||||
|
||||
``info`` is a collection of additional metadata that may be helpful:
|
||||
``info.batch_size`` specifies the size of the dimension being vmapped over, while
|
||||
``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`.
|
||||
|
||||
The return of the function ``func`` is a tuple of ``(output, out_dims)``. Similar to ``in_dims``,
|
||||
``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim``
|
||||
per output that specifies if the output has the vmapped dimension and what index it is in.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> import numpy as np
|
||||
>>> from torch import Tensor
|
||||
>>> from typing import Tuple
|
||||
>>>
|
||||
>>> def to_numpy(tensor):
|
||||
>>> return tensor.cpu().numpy()
|
||||
>>>
|
||||
>>> lib = torch.library.Library("mylib", "FRAGMENT")
|
||||
>>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
|
||||
>>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
>>> x_np = to_numpy(x)
|
||||
>>> dx = torch.tensor(3 * x_np ** 2, device=x.device)
|
||||
>>> return torch.tensor(x_np ** 3, device=x.device), dx
|
||||
>>>
|
||||
>>> def numpy_cube_vmap(info, in_dims, x):
|
||||
>>> result = numpy_cube(x)
|
||||
>>> return result, (in_dims[0], in_dims[0])
|
||||
>>>
|
||||
>>> numpy_cube.register_vmap(numpy_cube_vmap)
|
||||
>>>
|
||||
>>> x = torch.randn(3)
|
||||
>>> torch.vmap(numpy_cube)(x)
|
||||
>>>
|
||||
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
|
||||
>>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
|
||||
>>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
|
||||
>>>
|
||||
>>> @numpy_mul.register_vmap
|
||||
>>> def numpy_mul_vmap(info, in_dims, x, y):
|
||||
>>> x_bdim, y_bdim = in_dims
|
||||
>>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
|
||||
>>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
|
||||
>>> result = x * y
|
||||
>>> result = result.movedim(-1, 0)
|
||||
>>> return result, 0
|
||||
>>>
|
||||
>>>
|
||||
>>> x = torch.randn(3)
|
||||
>>> y = torch.randn(3)
|
||||
>>> torch.vmap(numpy_mul)(x, y)
|
||||
"""
|
||||
from torch._functorch.autograd_function import custom_function_call_vmap_helper
|
||||
from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter
|
||||
|
||||
def register(func):
|
||||
need_register = self._vmap_fn is None
|
||||
self._vmap_fn = func
|
||||
|
||||
if need_register:
|
||||
|
||||
def wrapped_func(keyset, *args, **kwargs):
|
||||
interpreter = retrieve_current_functorch_interpreter()
|
||||
return custom_function_call_vmap_helper(
|
||||
interpreter, self._vmap_fn, self._opoverload, *args, **kwargs
|
||||
)
|
||||
|
||||
self._lib.impl(
|
||||
self._name, wrapped_func, "FuncTorchBatched", with_keyset=True
|
||||
)
|
||||
|
||||
if func is None:
|
||||
return register
|
||||
else:
|
||||
return register(func)
|
||||
|
||||
def register_autocast(
|
||||
self,
|
||||
device_type: str,
|
||||
cast_inputs: _dtype,
|
||||
):
|
||||
r"""Register an autocast dispatch rule for this custom op.
|
||||
|
||||
Valid `device_type` values include: "cpu" and "cuda".
|
||||
|
||||
Args:
|
||||
device_type (str): Device type to use. 'cuda' or 'cpu'.
|
||||
The type is the same as the `type` attribute of a :class:`torch.device`.
|
||||
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
|
||||
cast_inputs (:class:`torch.dtype`): When custom op runs in an autocast-enabled region,
|
||||
casts incoming floating-point Tensors to the target dtype (non-floating-point Tensors
|
||||
are not affected), then executes custom op with autocast disabled.
|
||||
|
||||
|
||||
Examples::
|
||||
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
||||
>>> import torch
|
||||
>>> from torch import Tensor
|
||||
>>> from torch.library import custom_op
|
||||
>>>
|
||||
>>> # Create a custom op that works on cuda
|
||||
>>> @torch.library.custom_op("mylib::my_sin", mutates_args=())
|
||||
>>> def my_sin(x: Tensor) -> Tensor:
|
||||
>>> return torch.sin(x)
|
||||
>>>
|
||||
>>> # Register autocast dispatch rule for the cuda device
|
||||
>>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16)
|
||||
>>>
|
||||
>>> x = torch.randn(3, dtype=torch.float32, device="cuda")
|
||||
>>> with torch.autocast("cuda", dtype=torch.float16):
|
||||
>>> y = torch.ops.mylib.my_sin(x)
|
||||
>>> assert y.dtype == torch.float16
|
||||
|
||||
"""
|
||||
if not isinstance(device_type, str):
|
||||
raise ValueError(
|
||||
f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
|
||||
)
|
||||
if device_type not in ["cpu", "cuda"]:
|
||||
raise ValueError(f"Unknown device type: {device_type}")
|
||||
|
||||
need_register_cuda = self._autocast_cuda_dtype is None
|
||||
need_register_cpu = self._autocast_cpu_dtype is None
|
||||
if device_type == "cuda":
|
||||
self._autocast_cuda_dtype = cast_inputs
|
||||
else:
|
||||
self._autocast_cpu_dtype = cast_inputs
|
||||
|
||||
def kernel(_, *args, **kwargs):
|
||||
assert len(kwargs) == 0, "Custom ops do not support kwargs yet."
|
||||
autocast_keyset = torch._C.DispatchKeySet(
|
||||
torch._C.DispatchKey.AutocastCPU
|
||||
) | torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastCUDA)
|
||||
with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
|
||||
return self._opoverload(*_cast(args, device_type, cast_inputs))
|
||||
|
||||
if need_register_cuda and self._autocast_cuda_dtype:
|
||||
self._lib.impl(self._name, kernel, "AutocastCUDA", with_keyset=True)
|
||||
elif need_register_cpu and self._autocast_cpu_dtype:
|
||||
self._lib.impl(self._name, kernel, "AutocastCPU", with_keyset=True)
|
||||
|
||||
return kernel
|
||||
|
||||
|
||||
# TODO: Merge this function with torch.amp.autocast_mode._cast, and refactor it
|
||||
# into a utility function once custom ops support arbitrary input types.
|
||||
def _cast(value, device_type: str, dtype: _dtype):
|
||||
if isinstance(value, torch.Tensor):
|
||||
is_eligible = (
|
||||
value.is_floating_point()
|
||||
and value.device.type == device_type
|
||||
and (value.dtype is not torch.float64)
|
||||
)
|
||||
return value.to(dtype) if is_eligible else value
|
||||
elif isinstance(value, (str, bytes)):
|
||||
return value
|
||||
elif isinstance(value, collections.abc.Iterable):
|
||||
iterable = (_cast(v, device_type, dtype) for v in value)
|
||||
if isinstance(value, (list, tuple)):
|
||||
return type(value)(iterable)
|
||||
else:
|
||||
return iterable
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
def increment_version(val: Any) -> None:
|
||||
if isinstance(val, Tensor):
|
||||
torch.autograd.graph.increment_version(val)
|
||||
elif isinstance(val, (tuple, list)):
|
||||
for v in val:
|
||||
if isinstance(v, Tensor):
|
||||
torch.autograd.graph.increment_version(v)
|
||||
|
||||
|
||||
# NOTE: [Supporting decorator and non-decorator usage]
|
||||
#
|
||||
# Some APIs may be both used as a decorator and not as a decorator.
|
||||
# For example:
|
||||
#
|
||||
# >>> def fn(x):
|
||||
# >>> return x.sin()
|
||||
# >>>
|
||||
# >>> # Usage 1: not as a decorator
|
||||
# >>> numpy_sin.register_kernel("cuda", fn)
|
||||
# >>>
|
||||
# >>> # Usage 2: as a decorator
|
||||
# >>> @numpy_sin.register_kernel("cuda")
|
||||
# >>> def fn2(x):
|
||||
# >>> return x.sin
|
||||
#
|
||||
# The way we support this is that `register_kernel` accepts an optional `fn`.
|
||||
# If `fn` is provided (Usage 1), then we know that the user is using it not
|
||||
# as a decorator.
|
||||
# If `fn` is not provided (Usage 2), then `register_kernel` needs to return a
|
||||
# decorator.
|
||||
|
||||
|
||||
OPDEF_TO_LIB: dict[str, "torch.library.Library"] = {}
|
||||
OPDEFS: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
|
||||
|
||||
|
||||
def get_library_allowing_overwrite(
|
||||
namespace: str, name: str
|
||||
) -> "torch.library.Library":
|
||||
qualname = f"{namespace}::{name}"
|
||||
|
||||
if qualname in OPDEF_TO_LIB:
|
||||
OPDEF_TO_LIB[qualname]._destroy()
|
||||
del OPDEF_TO_LIB[qualname]
|
||||
|
||||
lib = torch.library.Library(namespace, "FRAGMENT") # noqa: TOR901
|
||||
OPDEF_TO_LIB[qualname] = lib
|
||||
return lib
|
||||
|
||||
|
||||
def _maybe_get_opdef(
|
||||
op: Union[CustomOpDef, _ops.OpOverload, str]
|
||||
) -> Optional[CustomOpDef]:
|
||||
if isinstance(op, CustomOpDef):
|
||||
return op
|
||||
if isinstance(op, _ops.OpOverload):
|
||||
op = op._name
|
||||
assert isinstance(op, str)
|
||||
if op in OPDEFS:
|
||||
return OPDEFS[op]
|
||||
return None
|
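A brief end-to-end sketch (not part of the committed file) tying together the public APIs defined above, following the docstring examples: torch.library.custom_op, CustomOpDef.register_fake, and CustomOpDef.register_autograd. The "mylib::numpy_sin" name is illustrative.

import numpy as np
import torch
from torch import Tensor

@torch.library.custom_op("mylib::numpy_sin", mutates_args=())
def numpy_sin(x: Tensor) -> Tensor:
    # Eager kernel backed by NumPy; opaque to torch.compile/export.
    return torch.from_numpy(np.sin(x.cpu().numpy())).to(device=x.device)

@numpy_sin.register_fake
def _(x):
    # Fake impl: propagate metadata (shape/dtype/device) only, never touch data.
    return torch.empty_like(x)

def setup_context(ctx, inputs, output):
    (x,) = inputs
    ctx.save_for_backward(x)

def backward(ctx, grad):
    (x,) = ctx.saved_tensors
    return grad * x.cos()

numpy_sin.register_autograd(backward, setup_context=setup_context)

x = torch.randn(3, requires_grad=True)
numpy_sin(x).sum().backward()
assert torch.allclose(x.grad, x.cos())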
341
venv/Lib/site-packages/torch/_library/fake_class_registry.py
Normal file
|
@@ -0,0 +1,341 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import copy
|
||||
import logging
|
||||
from typing import Any, Optional, Protocol, Union
|
||||
|
||||
import torch
|
||||
from torch._library.utils import parse_namespace
|
||||
from torch.utils._python_dispatch import _disable_current_modes
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FakeScriptObject:
|
||||
def __init__(self, wrapped_obj: Any, script_class_name: str, x: torch.ScriptObject):
|
||||
self.wrapped_obj = wrapped_obj
|
||||
|
||||
# The fully qualified name of the class of original script object
|
||||
self.script_class_name = script_class_name
|
||||
try:
|
||||
with _disable_current_modes():
|
||||
self.real_obj = copy.deepcopy(x)
|
||||
except RuntimeError:
|
||||
log.warning(
|
||||
"Unable to deepcopy the custom object %s. "
|
||||
"Defaulting to the user given object. This might be "
|
||||
"dangerous as side effects may be directly applied "
|
||||
"to the object.",
|
||||
script_class_name,
|
||||
)
|
||||
self.real_obj = x
|
||||
|
||||
|
||||
class FakeScriptMethod:
|
||||
def __init__(
|
||||
self,
|
||||
self_fake_obj: FakeScriptObject,
|
||||
method_name: str,
|
||||
schema: Optional[torch.FunctionSchema],
|
||||
):
|
||||
self.self_fake_obj = self_fake_obj
|
||||
self.method_name = method_name
|
||||
self.schema = schema
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
from torch._higher_order_ops.torchbind import call_torchbind
|
||||
|
||||
return call_torchbind(self.self_fake_obj, self.method_name, *args, **kwargs)
|
||||
|
||||
|
||||
class HasStaticMethodFromReal(Protocol):
|
||||
@classmethod
|
||||
def from_real(cls, real_obj: torch.ScriptObject):
|
||||
pass
|
||||
|
||||
|
||||
class FakeClassRegistry:
|
||||
def __init__(self) -> None:
|
||||
self._registered_class: dict[str, Any] = {}
|
||||
|
||||
def has_impl(self, full_qualname: str) -> bool:
|
||||
return full_qualname in self._registered_class
|
||||
|
||||
def get_impl(self, full_qualname: str) -> Any:
|
||||
self._check_registered(full_qualname)
|
||||
return self._registered_class[full_qualname]
|
||||
|
||||
def register(self, full_qualname: str, fake_class=None) -> None:
|
||||
if self.has_impl(full_qualname):
|
||||
log.warning(
|
||||
"%s is already registered. Previous fake class is overridden with %s.",
|
||||
full_qualname,
|
||||
fake_class,
|
||||
)
|
||||
self._registered_class[full_qualname] = fake_class
|
||||
|
||||
def deregister(self, full_qualname: str) -> Any:
|
||||
if not self.has_impl(full_qualname):
|
||||
log.warning(
|
||||
"Cannot deregister %s. Please use register_fake_class to register it first."
|
||||
" Or do you dereigster it twice?",
|
||||
full_qualname,
|
||||
)
|
||||
else:
|
||||
return self._registered_class.pop(full_qualname)
|
||||
|
||||
def clear(self) -> None:
|
||||
self._registered_class.clear()
|
||||
|
||||
def _check_registered(self, full_qualname: str) -> None:
|
||||
if full_qualname not in self._registered_class:
|
||||
raise RuntimeError(
|
||||
f"{full_qualname} is not registered. Please use register_fake_class to register it first."
|
||||
)
|
||||
|
||||
|
||||
global_fake_class_registry = FakeClassRegistry()
|
||||
|
||||
|
||||
# TODO: add this check at compile time for __obj_flatten__.
|
||||
def _check_valid_flat_script_obj(flat_x):
|
||||
if not isinstance(flat_x, tuple):
|
||||
raise RuntimeError("Expect flat x to be a tuple.")
|
||||
|
||||
for tp in flat_x:
|
||||
if not isinstance(tp, tuple):
|
||||
raise RuntimeError("Expect flat x to be a tuple of tuples.")
|
||||
|
||||
if not len(tp) == 2 or not isinstance(tp[0], str):
|
||||
raise RuntimeError(
|
||||
"Expect element of flat x to be a tuple of two elements with first element being a string"
|
||||
)
|
||||
|
||||
|
||||
def tracing_with_real(x: torch.ScriptObject) -> bool:
|
||||
if not hasattr(x, "tracing_mode"):
|
||||
return False
|
||||
|
||||
assert x.tracing_mode() in [
|
||||
"real",
|
||||
"fake",
|
||||
], f"tracing_mode can be either real or fake but got {x.tracing_mode()}"
|
||||
return x.tracing_mode() == "real"
|
||||
|
||||
|
||||
def maybe_to_fake_obj(
|
||||
fake_mode, x: torch.ScriptObject
|
||||
) -> Union[FakeScriptObject, torch.ScriptObject]:
|
||||
import torch.utils._pytree as pytree
|
||||
from torch.utils._python_dispatch import _disable_current_modes
|
||||
|
||||
# When tracing with real mode, people should implement meta kernels that can
|
||||
# handle the case of real script object + fake tensor inputs.
|
||||
if tracing_with_real(x):
|
||||
return x
|
||||
|
||||
# x.__obj_flatten__() could be calling some tensor operations inside but we don't
|
||||
# want to call these ops in surrounding dispatch modes when executing it.
|
||||
# Otherwise, for example, the fake tensor modes will error out when the tensors inside
|
||||
# script object executes some operations like clone if the allow_non_fake_input flag is set.
|
||||
with _disable_current_modes():
|
||||
flat_x = x.__obj_flatten__() # type: ignore[attr-defined]
|
||||
|
||||
_check_valid_flat_script_obj(flat_x)
|
||||
|
||||
fake_flattened = pytree.tree_map_only(
|
||||
torch.Tensor,
|
||||
lambda t: fake_mode.from_tensor(t),
|
||||
flat_x,
|
||||
)
|
||||
|
||||
fake_x = _find_fake_class_for_script_object(x).__obj_unflatten__(fake_flattened)
|
||||
|
||||
fake_x_wrapped = FakeScriptObject(fake_x, x._type().qualified_name(), x) # type: ignore[attr-defined]
|
||||
|
||||
for name in x._method_names(): # type: ignore[attr-defined]
|
||||
attr = getattr(fake_x, name, None)
|
||||
if attr:
|
||||
if not callable(attr):
|
||||
raise RuntimeError(f"Expect {name} to be a callable but got {attr}.")
|
||||
|
||||
real_attr = getattr(x, name) # type: ignore[attr-defined]
|
||||
|
||||
# real attr sometimes is not a torch.ScriptMethod and thus doesn't have a schema, e.g. __init__ or __eq__
|
||||
method_schema: Optional[torch.FunctionSchema] = None
|
||||
if isinstance(real_attr, torch.ScriptMethod):
|
||||
method_schema = real_attr.schema # type: ignore[attr-defined]
|
||||
|
||||
setattr(
|
||||
fake_x_wrapped,
|
||||
name,
|
||||
FakeScriptMethod(fake_x_wrapped, name, method_schema),
|
||||
)
|
||||
else:
|
||||
override_skip_list = {"__obj_flatten__", "__get_state__", "__set_state__"}
|
||||
if name not in override_skip_list:
|
||||
log.warning("fake object of %s doesn't implement method %s.", x, name)
|
||||
return fake_x_wrapped
|
||||
|
||||
|
||||
def register_fake_class(qualname, fake_class: Optional[HasStaticMethodFromReal] = None):
|
||||
r"""Register a fake implementation for this class.
|
||||
|
||||
It's in the same spirit of registering a fake implementation for
|
||||
an operator but with the difference that it
|
||||
associates a fake class with the original torch bind class (registered
|
||||
with torch::class_). In this way, torch.compile can handle them properly
|
||||
in components such as Dynamo and AOTAutograd.
|
||||
|
||||
This API may be used as a decorator (see example). For the fake class, users
|
||||
are required to provide a from_real classmethod that takes a real object and
|
||||
returns an instance of the fake class. All tensors in the fake object should also
|
||||
be properly fakified with to_fake_tensor() in from_real.
|
||||
|
||||
|
||||
Examples:
|
||||
# For a custom class Foo defined in test_custom_class_registration.cpp:
|
||||
|
||||
TORCH_LIBRARY(_TorchScriptTesting, m) {
|
||||
m.class_<TensorQueue>("_TensorQueue")
|
||||
.def(torch::init<at::Tensor>())
|
||||
.def("push", &TensorQueue::push)
|
||||
.def("pop", &TensorQueue::pop)
|
||||
.def("top", &TensorQueue::top)
|
||||
.def("size", &TensorQueue::size)
|
||||
.def("clone_queue", &TensorQueue::clone_queue)
|
||||
.def("__obj_flatten__", &TensorQueue::__obj_flatten__)
|
||||
.def_pickle(
|
||||
// __getstate__
|
||||
[](const c10::intrusive_ptr<TensorQueue>& self)
|
||||
-> c10::Dict<std::string, at::Tensor> {
|
||||
return self->serialize();
|
||||
},
|
||||
// __setstate__
|
||||
[](c10::Dict<std::string, at::Tensor> data)
|
||||
-> c10::intrusive_ptr<TensorQueue> {
|
||||
return c10::make_intrusive<TensorQueue>(std::move(data));
|
||||
});
|
||||
};
|
||||
# We could register a fake class FakeTensorQueue in Python as follows:
|
||||
import torch
|
||||
|
||||
@torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
|
||||
class FakeTensorQueue:
|
||||
def __init__(self, queue):
|
||||
self.queue = queue
|
||||
|
||||
@classmethod
|
||||
def __obj_unflatten__(cls, flattened_ctx):
|
||||
return cls(**dict(flattened_ctx))
|
||||
|
||||
def push(self, x):
|
||||
self.queue.append(x)
|
||||
|
||||
def pop(self):
|
||||
return self.queue.pop(0)
|
||||
|
||||
def size(self):
|
||||
return len(self.queue)
|
||||
|
||||
In this example, the original TensorQueue class needs to add an __obj_flatten__ method,
and the flattened result is passed into FakeTensorQueue's __obj_unflatten__ as inputs to
create the fake class. This protocol allows PyTorch to look at the contents of the script
object and properly handle them in subsystems like Dynamo and AOTAutograd.
|
||||
"""
|
||||
|
||||
def inner(fake_class: HasStaticMethodFromReal):
|
||||
ns, name = parse_namespace(qualname)
|
||||
|
||||
# This also checks whether the referred torch::class_ exists.
|
||||
torch._C._get_custom_class_python_wrapper(ns, name)
|
||||
|
||||
from_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None)
|
||||
if not from_method:
|
||||
raise RuntimeError(
|
||||
f"{fake_class} doesn't define a classmethod {_CONVERT_FROM_REAL_NAME}."
|
||||
)
|
||||
|
||||
if not isinstance(fake_class.__dict__[_CONVERT_FROM_REAL_NAME], classmethod):
|
||||
raise RuntimeError(
|
||||
f"{_CONVERT_FROM_REAL_NAME} method is not a classmethod."
|
||||
)
|
||||
|
||||
global_fake_class_registry.register(_full_qual_class_name(qualname), fake_class)
|
||||
return fake_class
|
||||
|
||||
if fake_class is None:
|
||||
return inner
|
||||
return inner(fake_class)
|
||||
|
||||
|
||||
def deregister_fake_class(qualname):
|
||||
return global_fake_class_registry.deregister(_full_qual_class_name(qualname))
|
||||
|
||||
|
||||
def has_fake_class(full_qualname) -> bool:
|
||||
return global_fake_class_registry.has_impl(full_qualname)
|
||||
|
||||
|
||||
def find_fake_class(full_qualname) -> Optional[Any]:
|
||||
if not has_fake_class(full_qualname):
|
||||
return None
|
||||
return global_fake_class_registry.get_impl(full_qualname)
|
||||
|
||||
|
||||
def _full_qual_class_name(qualname: str) -> str:
|
||||
ns, name = parse_namespace(qualname)
|
||||
return "__torch__.torch.classes." + ns + "." + name
|
||||
|
||||
|
||||
def _is_script_object(obj: Any) -> bool:
|
||||
return isinstance(
|
||||
obj, torch.ScriptObject
|
||||
) and obj._type().qualified_name().startswith( # type: ignore[attr-defined]
|
||||
"__torch__.torch.classes"
|
||||
)
|
||||
|
||||
|
||||
# Return the namespace and class name from fully qualified name.
|
||||
def _ns_and_class_name(full_qualname: str) -> tuple[str, str]:
|
||||
splits = full_qualname.split(".")
|
||||
assert len(splits) == 5, f"Could not split {full_qualname=}"
|
||||
_torch, _torch_ns, _classes, ns, class_name = splits
|
||||
return ns, class_name
|
||||
|
||||
|
||||
def _find_fake_class_for_script_object(x: torch.ScriptObject) -> Any:
|
||||
full_qualname = x._type().qualified_name() # type: ignore[attr-defined]
|
||||
ns, class_name = _ns_and_class_name(full_qualname)
|
||||
fake_class = find_fake_class(full_qualname)
|
||||
if fake_class is None:
|
||||
raise RuntimeError(
|
||||
f" ScriptObject's {full_qualname} haven't registered a fake class."
|
||||
f" Please use register_fake_class({ns}::{class_name}) to annotate a fake class for the script obj."
|
||||
f" Specifically, create a python class that implements a fake version for all the methods"
|
||||
f" that're used in the program and put annotated class in the program e.g. after loading the library."
|
||||
f" The fake methods can be written in the same way as a meta kernel for an operator but need to additionally"
|
||||
f" simulate the object's states. Be sure to add a {_CONVERT_FROM_REAL_NAME} classmethod"
|
||||
f" to enable creating a fake obj from a real one."
|
||||
)
|
||||
return fake_class
|
||||
|
||||
|
||||
_CONVERT_FROM_REAL_NAME = "__obj_unflatten__"
|
||||
|
||||
|
||||
def _fake_obj_from_real(fake_mode, x) -> Any:
|
||||
fake_class = _find_fake_class_for_script_object(x)
|
||||
|
||||
from_real_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None)
|
||||
if not from_real_method:
|
||||
raise RuntimeError(
|
||||
f"{fake_class} must define a classmethod {_CONVERT_FROM_REAL_NAME}"
|
||||
f" that converts the real object to the fake object."
|
||||
)
|
||||
|
||||
# from_real defined by the user needs the ctx to fakify the tensor states.
|
||||
ctx = torch._library.fake_impl.FakeImplCtx(fake_mode, None)
|
||||
with torch._library.fake_impl.set_ctx_getter(lambda: ctx):
|
||||
return fake_class.from_real(x)
|
213
venv/Lib/site-packages/torch/_library/fake_impl.py
Normal file
|
@ -0,0 +1,213 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import contextlib
|
||||
import functools
|
||||
from typing import Callable, Optional
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
from torch._library.utils import Kernel, RegistrationHandle
|
||||
|
||||
|
||||
class FakeImplHolder:
|
||||
"""A holder where one can register an fake impl to."""
|
||||
|
||||
def __init__(self, qualname: str):
|
||||
self.qualname: str = qualname
|
||||
self.kernel: Optional[Kernel] = None
|
||||
self.lib: Optional[torch.library.Library] = None
|
||||
|
||||
def register(self, func: Callable, source: str) -> RegistrationHandle:
|
||||
"""Register an fake impl.
|
||||
|
||||
Returns a RegistrationHandle that one can use to de-register this
|
||||
fake impl.
|
||||
"""
|
||||
if self.kernel is not None:
|
||||
raise RuntimeError(
|
||||
f"register_fake(...): the operator {self.qualname} "
|
||||
f"already has an fake impl registered at "
|
||||
f"{self.kernel.source}."
|
||||
)
|
||||
if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"):
|
||||
raise RuntimeError(
|
||||
f"register_fake(...): the operator {self.qualname} "
|
||||
f"already has an DispatchKey::Meta implementation via a "
|
||||
f"pre-existing torch.library or TORCH_LIBRARY registration. "
|
||||
f"Please either remove that registration or don't call "
|
||||
f"register_fake."
|
||||
)
|
||||
|
||||
if torch._C._dispatch_has_kernel_for_dispatch_key(
|
||||
self.qualname, "CompositeImplicitAutograd"
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"register_fake(...): the operator {self.qualname} "
|
||||
f"already has an implementation for this device type via a "
|
||||
f"pre-existing registration to "
|
||||
f"DispatchKey::CompositeImplicitAutograd."
|
||||
f"CompositeImplicitAutograd operators do not need an fake "
|
||||
f"impl; "
|
||||
f"instead, the operator will decompose into its constituents "
|
||||
f"and those "
|
||||
f"can have fake impls defined on them."
|
||||
)
|
||||
|
||||
# Store the kernel in this holder
|
||||
self.kernel = Kernel(func, source)
|
||||
|
||||
# Also register the fake impl to Meta key
|
||||
if self.lib is None:
|
||||
ns = self.qualname.split("::")[0]
|
||||
self.lib = torch.library.Library(ns, "FRAGMENT") # noqa: TOR901
|
||||
meta_kernel = construct_meta_kernel(self.qualname, self)
|
||||
self.lib.impl(self.qualname, meta_kernel, "Meta")
|
||||
|
||||
def deregister_fake_class():
|
||||
if self.lib:
|
||||
self.lib._destroy()
|
||||
self.lib = None
|
||||
self.kernel = None
|
||||
|
||||
return RegistrationHandle(deregister_fake_class)
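# Hedged usage sketch: user code normally reaches this holder through the public
# torch.library API rather than calling register() directly. Assuming an operator
# "mylib::foo(Tensor x) -> Tensor" has already been defined, a fake impl would be
# registered roughly like:
#
#     @torch.library.register_fake("mylib::foo")
#     def _(x):
#         return torch.empty_like(x)
#
# "mylib::foo" is a hypothetical operator name used only for illustration.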
|
||||
|
||||
|
||||
def construct_meta_kernel(qualname: str, fake_impl_holder: FakeImplHolder) -> Callable:
|
||||
assert fake_impl_holder.kernel is not None
|
||||
|
||||
@functools.wraps(fake_impl_holder.kernel.func)
|
||||
def meta_kernel(*args, **kwargs):
|
||||
assert fake_impl_holder.kernel is not None
|
||||
source = fake_impl_holder.kernel.source
|
||||
|
||||
def error_on_ctx():
|
||||
raise RuntimeError(
|
||||
f"{qualname} ({source}): You're trying to run this operator "
|
||||
f"with meta Tensors (as opposed to FakeTensors), but this "
|
||||
f"operator may return an output Tensor with data-dependent shape. Meta "
|
||||
f"Tensors don't support operators with outputs that have data-dependent shapes "
|
||||
f"but FakeTensors do. "
|
||||
f"If your operator does not return an output with data-dependent shape, "
|
||||
f"make sure the FakeTensor and/or meta kernel does not call "
|
||||
f"torch.library.get_ctx(). Otherwise, please use FakeTensors."
|
||||
)
|
||||
|
||||
with set_ctx_getter(error_on_ctx):
|
||||
return fake_impl_holder.kernel(*args, **kwargs)
|
||||
|
||||
return meta_kernel
|
||||
|
||||
|
||||
def get_none():
|
||||
return None
|
||||
|
||||
|
||||
global_ctx_getter: Callable = get_none
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_ctx_getter(ctx_getter):
|
||||
global global_ctx_getter
|
||||
prev = global_ctx_getter
|
||||
try:
|
||||
global_ctx_getter = ctx_getter
|
||||
yield
|
||||
finally:
|
||||
global_ctx_getter = prev
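# Minimal sketch of how this switch is used: code that invokes a fake impl can
# install a context getter for the duration of the call, e.g.
#
#     ctx = FakeImplCtx(fake_mode, op)   # `fake_mode` and `op` are assumed to exist
#     with set_ctx_getter(lambda: ctx):
#         result = kernel(*args, **kwargs)
#
# so that torch.library.get_ctx() inside the fake impl returns `ctx`.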
|
||||
|
||||
|
||||
class FakeImplCtx:
|
||||
"""
|
||||
Context object for writing fake implementations for custom operators.
|
||||
"""
|
||||
|
||||
def __init__(self, _fake_mode, _op):
|
||||
self._fake_mode = _fake_mode
|
||||
self._shape_env = _fake_mode.shape_env
|
||||
self._op = _op
|
||||
|
||||
@deprecated(
|
||||
"`create_unbacked_symint` is deprecated, please use `new_dynamic_size` instead",
|
||||
category=FutureWarning,
|
||||
)
|
||||
def create_unbacked_symint(self, *, min=2, max=None) -> torch.SymInt:
|
||||
return self.new_dynamic_size(min=min, max=max)
|
||||
|
||||
def new_dynamic_size(self, *, min=0, max=None) -> torch.SymInt:
|
||||
"""Constructs a new symint (symbolic int) representing a data-dependent value.
|
||||
|
||||
This is useful for writing the fake implementation (which is necessary
|
||||
for torch.compile) for a CustomOp where an output Tensor has a size
|
||||
that depends on the data of the input Tensors.
|
||||
|
||||
Args:
|
||||
min (int): A statically known inclusive lower bound for this symint. Default: 0
|
||||
max (Optional[int]): A statically known inclusive upper bound for this
|
||||
symint. Default: None
|
||||
|
||||
.. warning:
|
||||
|
||||
It is important that the ``min`` and ``max`` (if not None) values are set
|
||||
correctly, otherwise, there will be undefined behavior under
|
||||
torch.compile. The default value of ``min`` is 2 due to torch.compile
|
||||
specializing on 0/1 sizes.
|
||||
|
||||
You must also verify that your implementation on concrete Tensors
|
||||
(e.g. CPU/CUDA) only returns Tensors where the size that corresponds
|
||||
to the symint also respects these constraints.
|
||||
The easiest way to do this is to add an assertion in the CPU/CUDA/etc
|
||||
implementation that the size follows these bounds.
|
||||
|
||||
Example::
|
||||
|
||||
>>> # An operator with data-dependent output shape
|
||||
>>> lib = torch.library.Library("mymodule", "FRAGMENT")
|
||||
>>> lib.define("mymodule::custom_nonzero(Tensor x) -> Tensor")
|
||||
>>>
|
||||
>>> @torch.library.register_fake("mymodule::custom_nonzero")
|
||||
>>> def _(x):
|
||||
>>> # Number of nonzero-elements is data-dependent.
|
||||
>>> # Since we cannot peek at the data in a fake impl,
|
||||
>>> # we use the ctx object to construct a new symint that
|
||||
>>> # represents the data-dependent size.
|
||||
>>> ctx = torch.library.get_ctx()
|
||||
>>> nnz = ctx.new_dynamic_size()
|
||||
>>> shape = [nnz, x.dim()]
|
||||
>>> result = x.new_empty(shape, dtype=torch.int64)
|
||||
>>> return result
|
||||
>>>
|
||||
>>> @torch.library.impl(lib, "custom_nonzero", "CPU")
|
||||
>>> def _(x):
|
||||
>>> x_np = x.numpy()
|
||||
>>> res = np.stack(np.nonzero(x_np), axis=1)
|
||||
>>> return torch.tensor(res, device=x.device)
|
||||
|
||||
"""
|
||||
if (
|
||||
self._shape_env is None
|
||||
or not self._shape_env.allow_dynamic_output_shape_ops
|
||||
):
|
||||
raise torch._subclasses.fake_tensor.DynamicOutputShapeException(self._op)
|
||||
|
||||
if isinstance(min, torch.SymInt) or isinstance(max, torch.SymInt):
|
||||
raise ValueError(
|
||||
f"ctx.new_dynamic_size(min={min}, max={max}): expected "
|
||||
f"min and max to be statically known ints but got SymInt. "
|
||||
f"This is not supported."
|
||||
)
|
||||
|
||||
if min < 0:
|
||||
raise ValueError(
|
||||
f"ctx.new_dynamic_size(min={min}, ...): expected min to be "
|
||||
f"greater than or equal to 0: this API can only create "
|
||||
f"non-negative sizes."
|
||||
)
|
||||
|
||||
return allocate_size(self._shape_env, min, max)
|
||||
|
||||
|
||||
def allocate_size(shape_env, min_val=0, max_val=None):
|
||||
result = shape_env.create_unbacked_symint()
|
||||
torch.fx.experimental.symbolic_shapes._constrain_range_for_size(
|
||||
result, min=min_val, max=max_val
|
||||
)
|
||||
return result
|
320
venv/Lib/site-packages/torch/_library/infer_schema.py
Normal file
|
@ -0,0 +1,320 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import collections
|
||||
import inspect
|
||||
import typing
|
||||
from types import GenericAlias
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import device, dtype, Tensor, types
|
||||
from torch.utils._exposed_in import exposed_in
|
||||
|
||||
|
||||
# This is used as a negative test for
|
||||
# test_custom_ops.py::TestTypeConversion::test_type_eval.
|
||||
_TestTensor = torch.Tensor
|
||||
|
||||
|
||||
@exposed_in("torch.library")
|
||||
def infer_schema(
|
||||
prototype_function: typing.Callable,
|
||||
/,
|
||||
*,
|
||||
mutates_args,
|
||||
op_name: Optional[str] = None,
|
||||
) -> str:
|
||||
r"""Parses the schema of a given function with type hints. The schema is inferred from the
|
||||
function's type hints, and can be used to define a new operator.
|
||||
|
||||
We make the following assumptions:
|
||||
|
||||
* None of the outputs alias any of the inputs or each other.
|
||||
* | String type annotations "device, dtype, Tensor, types" without library specification are
|
||||
| assumed to be torch.*. Similarly, string type annotations "Optional, List, Sequence, Union"
|
||||
| without library specification are assumed to be typing.*.
|
||||
* | Only the args listed in ``mutates_args`` are being mutated. If ``mutates_args`` is "unknown",
|
||||
| it assumes that all inputs to the operator are being mutated.
|
||||
|
||||
Callers (e.g. the custom ops API) are responsible for checking these assumptions.
|
||||
|
||||
Args:
|
||||
prototype_function: The function from which to infer a schema, based on its type annotations.
|
||||
op_name (Optional[str]): The name of the operator in the schema. If ``op_name`` is None, then the
|
||||
name is not included in the inferred schema. Note that the input schema to
|
||||
``torch.library.Library.define`` requires an operator name.
|
||||
mutates_args ("unknown" | Iterable[str]): The arguments that are mutated in the function.
|
||||
|
||||
Returns:
|
||||
The inferred schema.
|
||||
|
||||
Example:
|
||||
>>> def foo_impl(x: torch.Tensor) -> torch.Tensor:
|
||||
>>> return x.sin()
|
||||
>>>
|
||||
>>> infer_schema(foo_impl, op_name="foo", mutates_args={})
|
||||
foo(Tensor x) -> Tensor
|
||||
>>>
|
||||
>>> infer_schema(foo_impl, mutates_args={})
|
||||
(Tensor x) -> Tensor
|
||||
"""
|
||||
UNKNOWN_MUTATES = "unknown"
|
||||
pf_globals = prototype_function.__globals__
|
||||
pf_locals = None
|
||||
# TODO: Once our minimum version is py3.10+ pass `eval_str=True` to
|
||||
# inspect.signature() and we no longer need to deal with stringified
|
||||
# annotations below.
|
||||
sig = inspect.signature(prototype_function)
|
||||
|
||||
def error_fn(what):
|
||||
raise ValueError(f"infer_schema(func): {what} Got func with signature {sig})")
|
||||
|
||||
def convert_type_string(annotation_type: str):
|
||||
try:
|
||||
return eval(annotation_type, pf_globals, pf_locals)
|
||||
except Exception:
|
||||
error_fn(
|
||||
f"Unsupported type annotation {annotation_type}. It is not a type."
|
||||
)
|
||||
|
||||
def unstringify_types(
|
||||
tys: tuple[Union[type[object], str], ...]
|
||||
) -> tuple[tuple[typing.Any, ...], bool]:
|
||||
res = []
|
||||
changed = False
|
||||
for ty in tys:
|
||||
ty, ty_changed = unstringify_type(ty)
|
||||
res.append(ty)
|
||||
changed |= ty_changed
|
||||
if changed:
|
||||
return tuple(res), True
|
||||
else:
|
||||
return tys, False # type: ignore[return-value]
|
||||
|
||||
def unstringify_type(ty: Union[type[object], str]) -> tuple[typing.Any, bool]:
|
||||
# Dig through a generic type and if it contains a stringified type
|
||||
# convert that to a real type. The second return value indicates if the
|
||||
# type contained a string or not.
|
||||
if isinstance(ty, str):
|
||||
return convert_type_string(ty), True
|
||||
elif origin := typing.get_origin(ty):
|
||||
args, args_changed = unstringify_types(typing.get_args(ty))
|
||||
if args_changed:
|
||||
return GenericAlias(origin, args), True
|
||||
|
||||
return ty, False
|
||||
|
||||
params = []
|
||||
seen_args = set()
|
||||
saw_kwarg_only_arg = False
|
||||
for idx, (name, param) in enumerate(sig.parameters.items()):
|
||||
if not supported_param(param):
|
||||
error_fn("We do not support positional-only args, varargs, or varkwargs.")
|
||||
|
||||
if param.kind == inspect.Parameter.KEYWORD_ONLY:
|
||||
# The first time we see a kwarg-only arg, add "*" to the schema.
|
||||
if not saw_kwarg_only_arg:
|
||||
params.append("*")
|
||||
saw_kwarg_only_arg = True
|
||||
|
||||
if param.annotation is inspect.Parameter.empty:
|
||||
error_fn(f"Parameter {name} must have a type annotation.")
|
||||
|
||||
# The annotation might have been converted to a string (e.g. by postponed evaluation of annotations);
|
||||
# we convert it to the actual type.
|
||||
annotation_type, _ = unstringify_type(param.annotation)
|
||||
|
||||
if annotation_type not in SUPPORTED_PARAM_TYPES:
|
||||
if annotation_type.__origin__ is tuple:
|
||||
list_type = tuple_to_list(annotation_type)
|
||||
example_type_str = "\n\n"
|
||||
# Only suggest the list type if this type is supported.
|
||||
if list_type in SUPPORTED_PARAM_TYPES.keys():
|
||||
example_type_str = f"For example, {list_type}.\n\n"
|
||||
error_fn(
|
||||
f"Parameter {name} has unsupported type {param.annotation}. "
|
||||
f"We do not support Tuple inputs in schema. As a workaround, please try to use List instead. "
|
||||
f"{example_type_str}"
|
||||
f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
|
||||
)
|
||||
else:
|
||||
error_fn(
|
||||
f"Parameter {name} has unsupported type {param.annotation}. "
|
||||
f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
|
||||
)
|
||||
|
||||
schema_type = SUPPORTED_PARAM_TYPES[annotation_type]
|
||||
if type(mutates_args) == str:
|
||||
if mutates_args != UNKNOWN_MUTATES:
|
||||
raise ValueError(
|
||||
"mutates_args must either be a sequence of the names of "
|
||||
"the arguments that are mutated or the string 'unknown'. "
|
||||
)
|
||||
if schema_type.startswith("Tensor"):
|
||||
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
|
||||
elif name in mutates_args:
|
||||
if not schema_type.startswith("Tensor"):
|
||||
error_fn(
|
||||
f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated"
|
||||
)
|
||||
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
|
||||
seen_args.add(name)
|
||||
if param.default is inspect.Parameter.empty:
|
||||
params.append(f"{schema_type} {name}")
|
||||
else:
|
||||
default_repr = None
|
||||
if param.default is None or isinstance(param.default, (int, float, bool)):
|
||||
default_repr = str(param.default)
|
||||
elif isinstance(param.default, (str, torch.device)):
|
||||
default_repr = f'"{param.default}"'
|
||||
elif isinstance(param.default, torch.dtype):
|
||||
dtype_repr = str(param.default)
|
||||
torch_dot = "torch."
|
||||
assert dtype_repr.startswith(torch_dot)
|
||||
default_repr = dtype_repr[len(torch_dot) :]
|
||||
else:
|
||||
error_fn(
|
||||
f"Parameter {name} has an unsupported default value type {type(param.default)}. "
|
||||
f"Please file an issue on GitHub so we can prioritize this."
|
||||
)
|
||||
params.append(f"{schema_type} {name}={default_repr}")
|
||||
if mutates_args != UNKNOWN_MUTATES:
|
||||
mutates_args_not_seen = set(mutates_args) - seen_args
|
||||
if len(mutates_args_not_seen) > 0:
|
||||
error_fn(
|
||||
f"{mutates_args_not_seen} in mutates_args were not found in "
|
||||
f"the custom op's signature. "
|
||||
f"mutates_args should contain the names of all args that the "
|
||||
f"custom op mutates, or just the string 'unknown' if you don't know."
|
||||
)
|
||||
return_annotation, _ = unstringify_type(sig.return_annotation)
|
||||
ret = parse_return(return_annotation, error_fn)
|
||||
if op_name is not None:
|
||||
return f"{op_name}({', '.join(params)}) -> {ret}"
|
||||
return f"({', '.join(params)}) -> {ret}"
|
||||
|
||||
|
||||
def derived_types(
|
||||
base_type: Union[type, typing._SpecialForm],
|
||||
cpp_type: str,
|
||||
list_base: bool,
|
||||
optional_base_list: bool,
|
||||
optional_list_base: bool,
|
||||
):
|
||||
result: list[tuple[Union[type, typing._SpecialForm, GenericAlias], str]] = [
|
||||
(base_type, cpp_type),
|
||||
(typing.Optional[base_type], f"{cpp_type}?"),
|
||||
]
|
||||
|
||||
def derived_seq_types(typ: Union[type, typing._SpecialForm]):
|
||||
return (
|
||||
typing.Sequence[typ], # type: ignore[valid-type] # noqa: UP006
|
||||
typing.List[typ], # type: ignore[valid-type] # noqa: UP006
|
||||
GenericAlias(collections.abc.Sequence, (typ,)),
|
||||
GenericAlias(list, (typ,)),
|
||||
)
|
||||
|
||||
if list_base:
|
||||
result.extend(
|
||||
(seq_typ, f"{cpp_type}[]") for seq_typ in derived_seq_types(base_type)
|
||||
)
|
||||
if optional_base_list:
|
||||
result.extend(
|
||||
(seq_typ, f"{cpp_type}?[]")
|
||||
for seq_typ in derived_seq_types(typing.Optional[base_type])
|
||||
)
|
||||
if optional_list_base:
|
||||
result.extend(
|
||||
(typing.Optional[seq_typ], f"{cpp_type}[]?")
|
||||
for seq_typ in derived_seq_types(base_type)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def get_supported_param_types():
|
||||
data: list[tuple[Union[type, typing._SpecialForm], str, bool, bool, bool]] = [
|
||||
# (python type, schema type, type[] variant, type?[] variant, type[]? variant
|
||||
(Tensor, "Tensor", True, True, False),
|
||||
(int, "SymInt", True, False, True),
|
||||
(float, "float", True, False, True),
|
||||
(bool, "bool", True, False, True),
|
||||
(str, "str", False, False, False),
|
||||
(types.Number, "Scalar", True, False, False),
|
||||
(dtype, "ScalarType", False, False, False),
|
||||
(device, "Device", False, False, False),
|
||||
]
|
||||
result = []
|
||||
for line in data:
|
||||
result.extend(derived_types(*line))
|
||||
return dict(result)
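# Quick sanity sketch of the resulting mapping (keys are typing constructs):
#
#     from typing import Optional, Sequence
#     SUPPORTED_PARAM_TYPES[torch.Tensor]                      # "Tensor"
#     SUPPORTED_PARAM_TYPES[Optional[torch.Tensor]]            # "Tensor?"
#     SUPPORTED_PARAM_TYPES[Sequence[Optional[torch.Tensor]]]  # "Tensor?[]"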
|
||||
|
||||
|
||||
SUPPORTED_RETURN_TYPES = {
|
||||
Tensor: "Tensor",
|
||||
typing.List[Tensor]: "Tensor[]", # noqa: UP006
|
||||
list[Tensor]: "Tensor[]",
|
||||
int: "SymInt",
|
||||
float: "float",
|
||||
bool: "bool",
|
||||
types.Number: "Scalar",
|
||||
}
|
||||
|
||||
|
||||
def parse_return(annotation, error_fn):
|
||||
if annotation is None:
|
||||
return "()"
|
||||
|
||||
if annotation is inspect.Parameter.empty:
|
||||
error_fn("No return type annotation was provided. Please add one.")
|
||||
|
||||
origin = typing.get_origin(annotation)
|
||||
if origin is not tuple:
|
||||
if annotation not in SUPPORTED_RETURN_TYPES.keys():
|
||||
error_fn(
|
||||
f"Return has unsupported type {annotation}. "
|
||||
f"The valid types are: {SUPPORTED_RETURN_TYPES}."
|
||||
)
|
||||
return SUPPORTED_RETURN_TYPES[annotation]
|
||||
|
||||
args = typing.get_args(annotation)
|
||||
for arg in args:
|
||||
if arg not in SUPPORTED_RETURN_TYPES:
|
||||
error_fn(
|
||||
f"Return has unsupported type {annotation}. "
|
||||
f"The valid types are: {SUPPORTED_RETURN_TYPES}."
|
||||
)
|
||||
|
||||
return "(" + ", ".join([SUPPORTED_RETURN_TYPES[arg] for arg in args]) + ")"
|
||||
|
||||
|
||||
SUPPORTED_PARAM_TYPES = get_supported_param_types()
|
||||
|
||||
|
||||
def supported_param(param: inspect.Parameter) -> bool:
|
||||
return param.kind in (
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
inspect.Parameter.KEYWORD_ONLY,
|
||||
)
|
||||
|
||||
|
||||
def tuple_to_list(tuple_type: type[tuple]) -> type[list]:
|
||||
"""
|
||||
Convert `tuple_type` into a list type with the same type arguments. Assumes that `tuple_type` is typing.Tuple type.
|
||||
"""
|
||||
type_args = getattr(tuple_type, "__args__", None)
|
||||
# Account for different python versions, e.g. python 3.8 would give ()
|
||||
# but python 3.12 would give None.
|
||||
if (
|
||||
tuple_type is typing.Tuple # noqa: UP006
|
||||
or tuple_type is tuple
|
||||
or type_args == ()
|
||||
or type_args is None
|
||||
):
|
||||
# Handle the case of an empty tuple type
|
||||
return list
|
||||
elif len(type_args) == 1:
|
||||
# General case: create a List with the same type arguments
|
||||
return list[type_args[0]] # type: ignore[valid-type]
|
||||
elif len(type_args) == 2 and type_args[1] is Ellipsis:
|
||||
return list[type_args[0]] # type: ignore[valid-type]
|
||||
else:
|
||||
return list[typing.Union[tuple(type_args)]] # type: ignore[misc, return-value]
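# Example behavior (sketch):
#     tuple_to_list(typing.Tuple[int, ...])   -> list[int]
#     tuple_to_list(tuple[int])               -> list[int]
#     tuple_to_list(typing.Tuple[int, float]) -> list[typing.Union[int, float]]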
|
85
venv/Lib/site-packages/torch/_library/simple_registry.py
Normal file
|
@ -0,0 +1,85 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from typing import Callable, Optional
|
||||
|
||||
from .fake_impl import FakeImplHolder
|
||||
from .utils import RegistrationHandle
|
||||
|
||||
|
||||
__all__ = ["SimpleLibraryRegistry", "SimpleOperatorEntry", "singleton"]
|
||||
|
||||
|
||||
class SimpleLibraryRegistry:
|
||||
"""Registry for the "simple" torch.library APIs
|
||||
|
||||
The "simple" torch.library APIs are a higher-level API on top of the
|
||||
raw PyTorch DispatchKey registration APIs that includes:
|
||||
- fake impl
|
||||
|
||||
Registrations for these APIs do not go into the PyTorch dispatcher's
|
||||
table because they may not directly involve a DispatchKey. For example,
|
||||
the fake impl is a Python function that gets invoked by FakeTensor.
|
||||
Instead, we manage them here.
|
||||
|
||||
SimpleLibraryRegistry is a mapping from a fully qualified operator name
|
||||
(including the overload) to SimpleOperatorEntry.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._data = {}
|
||||
|
||||
def find(self, qualname: str) -> "SimpleOperatorEntry":
|
||||
if qualname not in self._data:
|
||||
self._data[qualname] = SimpleOperatorEntry(qualname)
|
||||
return self._data[qualname]
|
||||
|
||||
|
||||
singleton: SimpleLibraryRegistry = SimpleLibraryRegistry()
|
||||
|
||||
|
||||
class SimpleOperatorEntry:
|
||||
"""This is 1:1 to an operator overload.
|
||||
|
||||
The fields of SimpleOperatorEntry are Holders where kernels can be
|
||||
registered to.
|
||||
"""
|
||||
|
||||
def __init__(self, qualname: str):
|
||||
self.qualname: str = qualname
|
||||
self.fake_impl: FakeImplHolder = FakeImplHolder(qualname)
|
||||
self.torch_dispatch_rules: GenericTorchDispatchRuleHolder = (
|
||||
GenericTorchDispatchRuleHolder(qualname)
|
||||
)
|
||||
|
||||
# For compatibility reasons. We can delete this soon.
|
||||
@property
|
||||
def abstract_impl(self):
|
||||
return self.fake_impl
|
||||
|
||||
|
||||
class GenericTorchDispatchRuleHolder:
|
||||
def __init__(self, qualname):
|
||||
self._data = {}
|
||||
self.qualname = qualname
|
||||
|
||||
def register(
|
||||
self, torch_dispatch_class: type, func: Callable
|
||||
) -> RegistrationHandle:
|
||||
if self.find(torch_dispatch_class):
|
||||
raise RuntimeError(
|
||||
f"{torch_dispatch_class} already has a `__torch_dispatch__` rule registered for {self.qualname}"
|
||||
)
|
||||
self._data[torch_dispatch_class] = func
|
||||
|
||||
def deregister():
|
||||
del self._data[torch_dispatch_class]
|
||||
|
||||
return RegistrationHandle(deregister)
|
||||
|
||||
def find(self, torch_dispatch_class):
|
||||
return self._data.get(torch_dispatch_class, None)
|
||||
|
||||
|
||||
def find_torch_dispatch_rule(op, torch_dispatch_class: type) -> Optional[Callable]:
|
||||
return singleton.find(op.__qualname__).torch_dispatch_rules.find(
|
||||
torch_dispatch_class
|
||||
)
|
274
venv/Lib/site-packages/torch/_library/triton.py
Normal file
|
@ -0,0 +1,274 @@
|
|||
import contextlib
|
||||
import threading
|
||||
from collections.abc import Generator, Iterable
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from torch.utils._exposed_in import exposed_in
|
||||
|
||||
from .custom_ops import custom_op, CustomOpDef
|
||||
from .infer_schema import infer_schema
|
||||
|
||||
|
||||
@exposed_in("torch.library")
|
||||
def triton_op(
|
||||
name: str,
|
||||
fn: Optional[Callable] = None,
|
||||
/,
|
||||
*,
|
||||
mutates_args: Union[str, Iterable[str]],
|
||||
schema: Optional[str] = None,
|
||||
) -> Callable:
|
||||
"""Create a custom operator whose implementation is backed by 1+ triton kernels.
|
||||
|
||||
This is a more structured way of using triton kernels with PyTorch.
|
||||
Prefer using triton kernels with no ``torch.library`` custom operator wrappers
|
||||
(like :func:`torch.library.custom_op`, :func:`torch.library.triton_op`) because
|
||||
that is simpler;
|
||||
only use :func:`torch.library.custom_op`/:func:`torch.library.triton_op` if you
|
||||
want to create an operator that behaves like PyTorch built-in operators.
|
||||
For example, you may use a ``torch.library`` wrapper API to define the
|
||||
behavior of the triton kernel when passed a tensor subclass or under
|
||||
a TorchDispatchMode.
|
||||
|
||||
Use :func:`torch.library.triton_op` instead of :func:`torch.library.custom_op`
|
||||
when the implementation
|
||||
consists of 1+ triton kernels. :func:`torch.library.custom_op` treats
|
||||
custom operators as opaque (:func:`torch.compile` and
|
||||
:func:`torch.export.export` will never trace into them), but ``triton_op``
|
||||
makes the implementation visible to these subsystems, allowing them
|
||||
to optimize the triton kernel(s).
|
||||
|
||||
Note that ``fn`` must only consist of calls to PyTorch-understood
|
||||
operators and triton kernels. Any triton kernels called inside ``fn``
|
||||
must be wrapped in a call to :func:`torch.library.wrap_triton`.
|
||||
|
||||
Args:
|
||||
name (str): A name for the custom op that looks like "{namespace}::{name}",
|
||||
e.g. "mylib::my_linear". The name is used as the op's stable identifier
|
||||
in PyTorch subsystems (e.g. torch.export, FX graphs).
|
||||
To avoid name collisions, please use your project name as the namespace;
|
||||
e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
|
||||
mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
|
||||
This MUST be accurate, otherwise, the behavior is undefined. If "unknown",
|
||||
it pessimistically assumes that all inputs to the operator are being mutated.
|
||||
schema (None | str): A schema string for the operator. If None
|
||||
(recommended) we'll infer a schema for the operator from its type
|
||||
annotations. We recommend letting us infer a schema unless you
|
||||
have a specific reason not to.
|
||||
Example: "(Tensor x, int y) -> (Tensor, Tensor)".
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
||||
>>> import torch
|
||||
>>> from torch.library import triton_op, wrap_triton
|
||||
>>>
|
||||
>>> import triton
|
||||
>>> from triton import language as tl
|
||||
>>>
|
||||
>>> @triton.jit
|
||||
>>> def add_kernel(
|
||||
>>> in_ptr0,
|
||||
>>> in_ptr1,
|
||||
>>> out_ptr,
|
||||
>>> n_elements,
|
||||
>>> BLOCK_SIZE: "tl.constexpr",
|
||||
>>> ):
|
||||
>>> pid = tl.program_id(axis=0)
|
||||
>>> block_start = pid * BLOCK_SIZE
|
||||
>>> offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
>>> mask = offsets < n_elements
|
||||
>>> x = tl.load(in_ptr0 + offsets, mask=mask)
|
||||
>>> y = tl.load(in_ptr1 + offsets, mask=mask)
|
||||
>>> output = x + y
|
||||
>>> tl.store(out_ptr + offsets, output, mask=mask)
|
||||
>>>
|
||||
>>> @triton_op("mylib::add", mutates_args={})
|
||||
>>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
>>> output = torch.empty_like(x)
|
||||
>>> n_elements = output.numel()
|
||||
>>>
|
||||
>>> def grid(meta):
|
||||
>>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
|
||||
>>>
|
||||
>>> # NB: we need to wrap the triton kernel in a call to wrap_triton
|
||||
>>> wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16)
|
||||
>>> return output
|
||||
>>>
|
||||
>>> @torch.compile
|
||||
>>> def f(x, y):
|
||||
>>> return add(x, y)
|
||||
>>>
|
||||
>>> x = torch.randn(3, device="cuda")
|
||||
>>> y = torch.randn(3, device="cuda")
|
||||
>>>
|
||||
>>> z = f(x, y)
|
||||
>>> assert torch.allclose(z, x + y)
|
||||
|
||||
"""
|
||||
|
||||
def dec(fn: Callable[..., object]) -> CustomOpDef:
|
||||
def backend_fn(*args, **kwargs): # type: ignore[no-untyped-def]
|
||||
# Optimization: we're passing regular Tensors into the triton kernel, so
|
||||
# no need to go through HOP dispatch
|
||||
with set_wrap_triton_enabled(False):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
result = custom_op(
|
||||
name,
|
||||
backend_fn,
|
||||
mutates_args=mutates_args,
|
||||
schema=infer_schema(fn, mutates_args=mutates_args),
|
||||
)
|
||||
from .._subclasses.functional_tensor import FunctionalTensorMode
|
||||
|
||||
# We require that the user pass us a function that is make_fx traceable,
|
||||
# so we can just register it as the Fake/meta kernel.
|
||||
result.register_fake(fn)
|
||||
|
||||
# We decompose the operator when FunctionalTensorMode is active.
|
||||
# The goal is to decompose the operator in AOTDispatcher.
|
||||
# - With torch.compile, this means that the backend (usually Inductor)
|
||||
# can see a call to the triton kernel(s) and so it can directly optimize
|
||||
# them by inlining them into the lowering process.
|
||||
def functional_decomp( # type: ignore[no-untyped-def]
|
||||
mode, op, types, args, kwargs
|
||||
):
|
||||
# NOTE [Export custom triton op]
|
||||
# For torch.export (strict and non-strict), we don't do functional decomposition.
|
||||
# Instead, we preserve the custom triton ops as custom ops. This is because we want
|
||||
# the exported program to be high-level and serializable. If we decompose
|
||||
# the custom op to a functional hop and make it a node in exported program,
|
||||
# we need to figure out ways of serializing the hop and its arguments, which can be triton.jited
|
||||
# functions and triton dtypes. This is undesirable because:
|
||||
# - it can be tedious to maintain a layer that serializes the jited function (e.g. with a string) and dtypes.
|
||||
# - exported program will contain the implementation detail (e.g. triton source code) for a specific
|
||||
# backend (GPU), which is probably at a wrong level of abstraction.
|
||||
# - changes to triton or the serialization logic for triton arguments can be BC breaking
|
||||
#
|
||||
# In the short term, we expect users to have a separate aot_compile stage that compiles the exported program
|
||||
# into a Cubin file on the same machine that users call export, which does autotuning and removes triton
|
||||
# dependency and serve the model with Cubin. This guarantees that triton changes won't break BC.
|
||||
# In the long term, we may export multiple cubins for the triton op directly
|
||||
from torch.export._trace import custom_triton_ops_decomposition_disabled
|
||||
|
||||
if custom_triton_ops_decomposition_disabled():
|
||||
return mode.__torch_dispatch__(op, types, args, kwargs)
|
||||
else:
|
||||
with mode:
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
|
||||
return result
|
||||
|
||||
if fn is None:
|
||||
return dec
|
||||
else:
|
||||
return dec(fn)
|
||||
|
||||
|
||||
wrap_triton_enabled = threading.local()
|
||||
wrap_triton_enabled_default = True
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_wrap_triton_enabled(enabled: bool) -> Generator[None, None, None]:
|
||||
"""If triton kernels annotated with @wrap_triton should dispatch via HOP
|
||||
or go straight to the triton kernel execution.
|
||||
|
||||
We have this switch because eager-mode performance of HOP dispatch is slow
|
||||
enough to matter (~1ms) and we know that wrap_triton isn't necessary in
|
||||
some situations (eager-mode with regular Tensors)
|
||||
"""
|
||||
try:
|
||||
prev = is_wrap_triton_enabled()
|
||||
wrap_triton_enabled.value = enabled
|
||||
yield
|
||||
finally:
|
||||
wrap_triton_enabled.value = prev
|
||||
|
||||
|
||||
def is_wrap_triton_enabled() -> bool:
|
||||
return getattr(wrap_triton_enabled, "value", wrap_triton_enabled_default)
|
||||
|
||||
|
||||
def capture_triton(triton_kernel: Callable, /) -> Any:
|
||||
"""This API has been renamed to wrap_triton"""
|
||||
return wrap_triton(triton_kernel)
|
||||
|
||||
|
||||
@exposed_in("torch.library")
|
||||
def wrap_triton(triton_kernel: Callable, /) -> Any:
|
||||
"""Allows capture of a triton kernel into a graph via make_fx or
|
||||
non-strict ``torch.export``.
|
||||
|
||||
These technologies perform Dispatcher-based tracing (via
|
||||
``__torch_dispatch__``) and cannot see calls to raw triton kernels.
|
||||
The ``wrap_triton`` API wraps a triton kernel into a callable that
|
||||
can actually be traced into a graph.
|
||||
|
||||
Please use this API together with :func:`torch.library.triton_op`.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> import torch
|
||||
>>> import triton
|
||||
>>> from triton import language as tl
|
||||
>>> from torch.fx.experimental.proxy_tensor import make_fx
|
||||
>>> from torch.library import wrap_triton
|
||||
>>>
|
||||
>>> @triton.jit
|
||||
>>> def add_kernel(
|
||||
>>> in_ptr0,
|
||||
>>> in_ptr1,
|
||||
>>> out_ptr,
|
||||
>>> n_elements,
|
||||
>>> BLOCK_SIZE: "tl.constexpr",
|
||||
>>> ):
|
||||
>>> pid = tl.program_id(axis=0)
|
||||
>>> block_start = pid * BLOCK_SIZE
|
||||
>>> offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
>>> mask = offsets < n_elements
|
||||
>>> x = tl.load(in_ptr0 + offsets, mask=mask)
|
||||
>>> y = tl.load(in_ptr1 + offsets, mask=mask)
|
||||
>>> output = x + y
|
||||
>>> tl.store(out_ptr + offsets, output, mask=mask)
|
||||
>>>
|
||||
>>> def add(x, y):
|
||||
>>> output = torch.empty_like(x)
|
||||
>>> n_elements = output.numel()
|
||||
>>>
|
||||
>>> def grid_fn(meta):
|
||||
>>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
|
||||
>>>
|
||||
>>> wrap_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
|
||||
>>> return output
|
||||
>>>
|
||||
>>> x = torch.randn(3, device="cuda")
|
||||
>>> y = torch.randn(3, device="cuda")
|
||||
>>> gm = make_fx(add)(x, y)
|
||||
>>> print(gm.code)
|
||||
>>> # def forward(self, x_1, y_1):
|
||||
>>> # empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False)
|
||||
>>> # triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation(
|
||||
>>> # kernel_idx = 0, constant_args_idx = 0,
|
||||
>>> # grid = [(1, 1, 1)], kwargs = {
|
||||
>>> # 'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like,
|
||||
>>> # 'n_elements': 3, 'BLOCK_SIZE': 16
|
||||
>>> # })
|
||||
>>> # return empty_like
|
||||
|
||||
"""
|
||||
from triton.runtime.autotuner import Autotuner
|
||||
from triton.runtime.jit import JITFunction
|
||||
|
||||
from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper
|
||||
|
||||
if not isinstance(triton_kernel, (JITFunction, Autotuner)):
|
||||
raise RuntimeError(
|
||||
"wrap_triton only works on functions annotated with triton.jit or triton.autotune"
|
||||
)
|
||||
if not is_wrap_triton_enabled():
|
||||
return triton_kernel
|
||||
return TraceableTritonKernelWrapper(triton_kernel, None, None)
|
478
venv/Lib/site-packages/torch/_library/utils.py
Normal file
|
@ -0,0 +1,478 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import dataclasses
|
||||
import inspect
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import Iterable, Iterator
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch import _C, _utils_internal
|
||||
from torch._ops import OpOverload
|
||||
|
||||
|
||||
def warn_deploy(stacklevel=3):
|
||||
warnings.warn(
|
||||
"Python torch.library APIs do nothing under torch::deploy (multipy). "
|
||||
"Please instead use C++ custom operator registration APIs.",
|
||||
RuntimeWarning,
|
||||
stacklevel=stacklevel,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Kernel:
|
||||
"""Models a (function, source location)"""
|
||||
|
||||
func: Callable
|
||||
source: str
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
|
||||
class RegistrationHandle:
|
||||
"""Does something when someone calls .destroy() on it"""
|
||||
|
||||
def __init__(self, on_destroy: Callable):
|
||||
self._on_destroy = on_destroy
|
||||
|
||||
def destroy(self) -> None:
|
||||
self._on_destroy()
|
||||
|
||||
|
||||
def get_source(stacklevel: int) -> str:
|
||||
"""Get a string that represents the caller.
|
||||
|
||||
Example: "/path/to/foo.py:42"
|
||||
|
||||
Use stacklevel=1 to get the caller's source
|
||||
Use stacklevel=2 to get the caller's caller's source
|
||||
etc.
|
||||
"""
|
||||
frame = inspect.getframeinfo(sys._getframe(stacklevel))
|
||||
source = f"{frame.filename}:{frame.lineno}"
|
||||
return source
|
||||
|
||||
|
||||
def parse_namespace(qualname: str) -> tuple[str, str]:
|
||||
splits = qualname.split("::")
|
||||
if len(splits) != 2:
|
||||
raise ValueError(
|
||||
f"Expected `qualname` to be of the form "
|
||||
f'"namespace::name", but got {qualname}. '
|
||||
f"The qualname passed to the torch.library APIs must consist "
|
||||
f"of a namespace and a name, e.g. aten::sin"
|
||||
)
|
||||
return splits[0], splits[1]
|
||||
|
||||
|
||||
def lookup_op(qualname: str) -> OpOverload:
|
||||
namespace, name = parse_namespace(qualname)
|
||||
if "." in name:
|
||||
name, overload = name.split(".")
|
||||
else:
|
||||
overload = "default"
|
||||
ns = getattr(torch.ops, namespace)
|
||||
packet = getattr(ns, name)
|
||||
return getattr(packet, overload)
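# Example (sketch): lookup_op("aten::sin") returns torch.ops.aten.sin.default,
# and lookup_op("aten::add.Tensor") returns torch.ops.aten.add.Tensor.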
|
||||
|
||||
|
||||
def is_builtin(op: OpOverload) -> bool:
|
||||
assert isinstance(op, OpOverload)
|
||||
return op.namespace in {"aten", "prim", "prims"}
|
||||
|
||||
|
||||
def is_functional_schema(schema: Any) -> bool:
|
||||
"""Check if the schema is functional.
|
||||
|
||||
An operator is functional if:
|
||||
- it does not mutate any of its inputs
|
||||
- it does not return a view on any of its inputs
|
||||
- it has at least one return
|
||||
"""
|
||||
|
||||
def is_functional(schema):
|
||||
if schema.is_mutable:
|
||||
return False
|
||||
rets = schema.returns
|
||||
is_non_mutating_view = len(rets) > 0 and any(
|
||||
r.alias_info is not None and not r.alias_info.is_write for r in rets
|
||||
)
|
||||
if is_non_mutating_view:
|
||||
return False
|
||||
if not schema.returns:
|
||||
return False
|
||||
return True
|
||||
|
||||
if isinstance(schema, torch._C.FunctionSchema):
|
||||
return is_functional(schema)
|
||||
|
||||
# Lazy import because not all PyTorch builds have torchgen
|
||||
from torchgen.model import FunctionSchema
|
||||
|
||||
if isinstance(schema, str):
|
||||
schema = FunctionSchema.parse(schema)
|
||||
assert isinstance(schema, FunctionSchema)
|
||||
return is_functional(schema)
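# Examples (sketch): "sin(Tensor self) -> Tensor" is functional;
# "add_(Tensor(a!) self, Tensor other) -> Tensor(a!)" is not (it mutates and
# returns an alias); a mutable schema with no returns is also not functional.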
|
||||
|
||||
|
||||
# should be torch._C.JitType but that annotation is busted
|
||||
def is_tensorlist_like_type(typ: Any) -> bool:
|
||||
return (
|
||||
typ == _C.ListType(_C.TensorType.get())
|
||||
or typ == _C.ListType(_C.OptionalType(_C.TensorType.get()))
|
||||
or typ == _C.OptionalType(_C.ListType(_C.TensorType.get()))
|
||||
or typ == _C.OptionalType(_C.ListType(_C.OptionalType(_C.TensorType.get())))
|
||||
)
|
||||
|
||||
|
||||
# should be torch._C.JitType but that annotation is busted
|
||||
def is_tensor_like_type(typ: Any) -> bool:
|
||||
return typ == _C.TensorType.get() or typ == _C.OptionalType(_C.TensorType.get())
|
||||
|
||||
|
||||
def mutates_and_returns_first_arg(op: OpOverload):
|
||||
"""Check if an op is an inplace aten op, i.e. it mutates and returns the first arg.
|
||||
|
||||
TODO: torchgen/model.py's FunctionSchema.parse is the source of truth for this,
|
||||
but not all PyTorch builds have torchgen (due to the yaml dependency being weird).
|
||||
Figure this out.
|
||||
|
||||
Example: add_(Tensor(a!) x, Tensor y) -> Tensor(a)
|
||||
"""
|
||||
if op.namespace != "aten":
|
||||
return False
|
||||
schema = op._schema
|
||||
if not len(schema.returns) == 1:
|
||||
return False
|
||||
if schema.returns[0].alias_info is None:
|
||||
return False
|
||||
alias_set = schema.returns[0].alias_info.after_set
|
||||
if len(alias_set) != 1:
|
||||
return False
|
||||
loc = next(iter(alias_set))
|
||||
if len(schema.arguments) < 1:
|
||||
return False
|
||||
first_arg = schema.arguments[0]
|
||||
if first_arg.alias_info is None:
|
||||
return False
|
||||
if not first_arg.alias_info.is_write:
|
||||
return False
|
||||
alias_set = first_arg.alias_info.after_set
|
||||
if len(alias_set) != 1:
|
||||
return False
|
||||
if loc != next(iter(alias_set)):
|
||||
return False
|
||||
for arg in schema.arguments[1:]:
|
||||
if arg.alias_info is not None:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def fill_defaults(schema, args, kwargs):
|
||||
new_args = []
|
||||
new_kwargs = {}
|
||||
for i in range(len(schema.arguments)):
|
||||
info = schema.arguments[i]
|
||||
if info.kwarg_only:
|
||||
if info.name in kwargs:
|
||||
new_kwargs[info.name] = kwargs[info.name]
|
||||
else:
|
||||
new_kwargs[info.name] = info.default_value
|
||||
else:
|
||||
if i < len(args):
|
||||
new_args.append(args[i])
|
||||
else:
|
||||
new_args.append(info.default_value)
|
||||
return tuple(new_args), new_kwargs
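# Example (sketch): for a schema like "foo(Tensor x, int n=1, *, float eps=1e-05)"
# called as foo(t), fill_defaults(schema, (t,), {}) returns ((t, 1), {"eps": 1e-05}).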
|
||||
|
||||
|
||||
def zip_schema(
|
||||
schema: _C.FunctionSchema, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> Iterable[tuple[_C.Argument, Any]]:
|
||||
"""zips schema.arguments and (args, kwargs) together.
|
||||
|
||||
Assumes that (args, kwargs) were the inputs to some torch._ops.OpOverload:
|
||||
that is, (args, kwargs) must be bindable to the schema (args, kwargs).
|
||||
"""
|
||||
assert len(schema.arguments) >= len(args) + len(kwargs)
|
||||
for i in range(len(schema.arguments)):
|
||||
info = schema.arguments[i]
|
||||
if info.kwarg_only:
|
||||
if info.name in kwargs:
|
||||
yield info, kwargs[info.name]
|
||||
continue
|
||||
if i >= len(args):
|
||||
if not info.kwarg_only and info.name in kwargs:
|
||||
yield info, kwargs[info.name]
|
||||
# args that are equal to their default values are not populated
|
||||
# if they are followed by args that are equal to their defaults.
|
||||
# Skip these.
|
||||
continue
|
||||
yield info, args[i]
|
||||
return
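# Example (sketch): for a schema like "foo(Tensor x, int n=1)" invoked as foo(t),
# zip_schema(schema, (t,), {}) yields only (x_info, t); `n` is skipped because the
# dispatcher does not populate trailing args that are equal to their defaults.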
|
||||
|
||||
|
||||
def hop_schema_from_fx_node(node):
|
||||
from torchgen.gen_schema_utils import FunctionSchemaGen
|
||||
|
||||
hop = node.target
|
||||
if not isinstance(hop, torch._ops.HigherOrderOperator):
|
||||
raise RuntimeError("fx_node's target must be a hop.")
|
||||
|
||||
def _collect_example_val(node):
|
||||
meta_val = node.meta.get("val", None)
|
||||
if meta_val is None:
|
||||
assert node.op == "get_attr"
|
||||
meta_val = getattr(node.graph.owning_module, node.target)
|
||||
return meta_val
|
||||
|
||||
example_inputs = []
|
||||
for arg in node.args:
|
||||
if isinstance(arg, (torch.fx.Node, torch.fx.node.Node)):
|
||||
example_inputs.append(_collect_example_val(arg))
|
||||
elif isinstance(
|
||||
arg, (torch.fx.immutable_collections.immutable_list, list, tuple)
|
||||
):
|
||||
example_inputs.append([_collect_example_val(x) for x in arg])
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported arg type {type(arg)}")
|
||||
|
||||
# Bind the arguments to make sure the number of inputs is correct
|
||||
bound_args: inspect.BoundArguments = inspect.signature(hop.__call__).bind(
|
||||
*example_inputs
|
||||
)
|
||||
|
||||
# We treat example_output as a single value in return. This is to differentiate 1. return a single val
|
||||
# vs 2. return a tuple with one element.
|
||||
example_output = _collect_example_val(node)
|
||||
return FunctionSchemaGen.from_example(
|
||||
hop._name, tuple(bound_args.arguments.items()), (list(example_output),)
|
||||
)
|
||||
|
||||
|
||||
def can_generate_trivial_fake_impl(op: OpOverload) -> bool:
|
||||
assert isinstance(op, OpOverload)
|
||||
if is_builtin(op):
|
||||
# We control the built-ins. These may (in rare cases)
|
||||
# do input metadata mutation (which we have banned on custom ops)
|
||||
return False
|
||||
schema = op._schema
|
||||
# It's suspicious if the op is not mutable but returns nothing, so we return False out of an abundance of caution
|
||||
if not schema.is_mutable:
|
||||
return False
|
||||
if len(schema.returns) > 0:
|
||||
return False
|
||||
# If the op returns nothing, then it has a trivial fake impl.
|
||||
return True
|
||||
|
||||
|
||||
def requires_set_python_module() -> bool:
|
||||
"""If an op was defined in C++ and extended from Python using the
|
||||
torch.library APIs, returns whether we require that there has been a
|
||||
m.set_python_module("mylib.ops") call from C++ that associates
|
||||
the C++ op with a python module.
|
||||
"""
|
||||
return getattr(_utils_internal, "REQUIRES_SET_PYTHON_MODULE", True)
|
||||
|
||||
|
||||
def handle_dispatch_mode(curr_mode, op_overload, *args, **kwargs):
|
||||
assert isinstance(curr_mode, torch.utils._python_dispatch.TorchDispatchMode)
|
||||
args_flattened, _ = torch.utils._pytree.tree_flatten((args, kwargs.values()))
|
||||
# TODO: need to double check the semantics of the "types" argument to torch_dispatch.
|
||||
# It's generated in PyInterpreter.cpp, but seems to be generated in two places,
|
||||
# where in one case we only include tensors with the python key, and in another
|
||||
# we include **all** tensors.
|
||||
overload_types = [
|
||||
type(a)
|
||||
for a in args_flattened
|
||||
if isinstance(a, torch.Tensor)
|
||||
and torch._C._dispatch_keys(a).has(torch._C.DispatchKey.Python)
|
||||
]
|
||||
# TODO: check that I got these args correct (in C++, we pass in "0000"??)
|
||||
|
||||
return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
|
||||
|
||||
|
||||
def has_kwarg_only_args(schema: _C.FunctionSchema):
|
||||
return any(a.kwarg_only for a in schema.arguments)
|
||||
|
||||
|
||||
def has_kwarg_only_tensors(schema: _C.FunctionSchema):
|
||||
for a in schema.arguments:
|
||||
if not (is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type)):
|
||||
continue
|
||||
if not a.kwarg_only:
|
||||
continue
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def has_tensor_arg(schema: _C.FunctionSchema) -> bool:
|
||||
"""
|
||||
Given a schema, returns True if the schema has a Tensor arg.
|
||||
A Tensor arg is any arg with a type annotation that might involve Tensor.
|
||||
"""
|
||||
return any(
|
||||
(is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type))
|
||||
for a in schema.arguments
|
||||
)
|
||||
|
||||
|
||||
def get_device_arg_index(schema: _C.FunctionSchema) -> Union[int, None]:
|
||||
"""
|
||||
Given a schema, returns the id of the `device: torch.device` argument.
|
||||
If it does not exist, returns None.
|
||||
"""
|
||||
for index, arg in enumerate(schema.arguments):
|
||||
if arg.type is _C.DeviceObjType.get() and arg.name == "device":
|
||||
return index
|
||||
return None
|
||||
|
||||
|
||||
def iter_tensors(
|
||||
args: tuple[Any], kwargs: dict[str, Any], allowed_nesting: int = 1
|
||||
) -> Iterator[torch.Tensor]:
|
||||
def check(arg):
|
||||
if isinstance(arg, torch.Tensor):
|
||||
yield arg
|
||||
elif allowed_nesting > 0 and isinstance(arg, (tuple, list)):
|
||||
yield from iter_tensors(tuple(arg), {}, allowed_nesting - 1)
|
||||
|
||||
for arg in args:
|
||||
yield from check(arg)
|
||||
for kwarg in kwargs.values():
|
||||
yield from check(kwarg)
|
||||
|
||||
|
||||
def check_aliasing_constraint(name, prev, result, get_module=lambda: "???"):
|
||||
"""
|
||||
custom operators' outputs must not alias any inputs or other outputs.
|
||||
"""
|
||||
storages = {id(t.untyped_storage()) for t in prev if isinstance(t, torch.Tensor)}
|
||||
tuple_result = result
|
||||
if not isinstance(result, tuple):
|
||||
tuple_result = (result,)
|
||||
for tensor in iter_tensors(tuple_result, {}):
|
||||
key = id(tensor.untyped_storage())
|
||||
if id(tensor.untyped_storage()) in storages:
|
||||
raise RuntimeError(
|
||||
f"{name} (with implementation in {get_module()}): "
|
||||
f"The output of this custom operator (1) must not "
|
||||
f"also be an input to this custom operator and "
|
||||
f"(2) may not alias any inputs to this custom operator "
|
||||
f"or other returns. "
|
||||
f"The most common way to trigger this error is if "
|
||||
f"we have y = custom_op(x) and y and x are the same Tensor. "
|
||||
f"Please instead return a clone of the offending output "
|
||||
f"tensor(s) (e.g. return x.clone()) or refactor the custom "
|
||||
f"operator to not return y."
|
||||
)
|
||||
storages.add(key)
|
||||
|
||||
|
||||
class MutationChecker:
|
||||
"""
|
||||
Check if an operator mutated its arguments.
|
||||
Usage:
|
||||
|
||||
checker = MutationChecker(op, flat_args, args_spec)
|
||||
op(*args, **kwargs)
|
||||
checker.check()
|
||||
"""
|
||||
|
||||
def __init__(self, op, flat_args, args_spec):
|
||||
self.op = op
|
||||
self.args_spec = args_spec
|
||||
self.flat_args = flat_args
|
||||
self.real_pre_hashes = [
|
||||
hash_tensor(a) if isinstance(a, torch.Tensor) else None for a in flat_args
|
||||
]
|
||||
|
||||
def check(self):
|
||||
real_post_hashes = [
|
||||
hash_tensor(a) if isinstance(a, torch.Tensor) else None
|
||||
for a in self.flat_args
|
||||
]
|
||||
was_mutated = [
|
||||
not torch.equal(pre, post)
|
||||
and not (pre.isnan().all() and post.isnan().all())
|
||||
if isinstance(pre, torch.Tensor) and isinstance(post, torch.Tensor)
|
||||
else None
|
||||
for pre, post in zip(self.real_pre_hashes, real_post_hashes)
|
||||
]
|
||||
was_mutated_args, was_mutated_kwargs = pytree.tree_unflatten(
|
||||
was_mutated, self.args_spec
|
||||
)
|
||||
for info, was_mutated in zip_schema(
|
||||
self.op._schema, was_mutated_args, was_mutated_kwargs
|
||||
):
|
||||
|
||||
def check_one(info, was_mutated):
|
||||
if info.is_write == was_mutated:
|
||||
return
|
||||
raise RuntimeError(
|
||||
f"{self.op._name}: for argument '{info.name}': the operator's schema "
|
||||
f"{self.op._schema} specified that "
|
||||
f"the operator {'mutates' if info.is_write else 'does not mutate'} "
|
||||
f"the argument, but this seems to be emperically wrong. "
|
||||
f"Please make the schema and operator behavior consistent. "
|
||||
f"You can specify that an operator mutates a Tensor by "
|
||||
f"e.g. changing its schema type from 'Tensor name' to 'Tensor(a!) name'"
|
||||
f"(use different identifiers (a, b, c, ...) for different Tensors)"
|
||||
)
|
||||
|
||||
if is_tensor_like_type(info.type):
|
||||
check_one(info, was_mutated)
|
||||
elif is_tensorlist_like_type(info.type):
|
||||
was_any_mutated = False if was_mutated is None else any(was_mutated)
|
||||
check_one(info, was_any_mutated)
|
||||
|
||||
|
||||
def hash_tensor(t: torch.Tensor) -> torch.Tensor:
|
||||
"""Some inexpensive hash. Used as a quick and dirty indicator for tensor mutation"""
|
||||
return t.detach().float().mean()
|
||||
|
||||
|
||||
def has_fake_kernel(op: torch._ops.OpOverload) -> bool:
|
||||
"""If an operator (that stays alive until FakeTensorMode) has a Fake kernel.
|
||||
Don't use this if the operator decomposes before FakeTensorMode.
|
||||
"""
|
||||
if can_generate_trivial_fake_impl(op):
|
||||
return True
|
||||
name = op._name
|
||||
if torch._C._dispatch_has_kernel_for_dispatch_key(
|
||||
name, "CompositeImplicitAutograd"
|
||||
):
|
||||
return True
|
||||
opdef = torch._library.custom_ops._maybe_get_opdef(name)
|
||||
if opdef is None:
|
||||
# the non-torch.library.custom_op path
|
||||
if torch._C._dispatch_has_kernel_for_dispatch_key(
|
||||
name, "CompositeExplicitAutograd"
|
||||
):
|
||||
return True
|
||||
entry = torch._library.simple_registry.singleton.find(name)
|
||||
if entry.fake_impl.kernel is not None:
|
||||
return True
|
||||
if torch._C._dispatch_has_kernel_for_dispatch_key(name, "Meta"):
|
||||
return True
|
||||
else:
|
||||
# the torch.library.custom_op path
|
||||
if opdef._abstract_fn is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def mutated_args_kwargs(schema: _C.FunctionSchema) -> tuple[list[int], list[str]]:
|
||||
idxs = []
|
||||
keys = []
|
||||
for i, info in enumerate(schema.arguments):
|
||||
if info.alias_info is not None and info.alias_info.is_write:
|
||||
if info.kwarg_only:
|
||||
keys.append(info.name)
|
||||
else:
|
||||
idxs.append(i)
|
||||
return idxs, keys
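# Example (sketch): for "foo(Tensor(a!) x, Tensor y, *, Tensor(b!) out)",
# mutated_args_kwargs(schema) returns ([0], ["out"]).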
|