78 lines
2.5 KiB
Python
78 lines
2.5 KiB
Python
# mypy: ignore-errors
|
|
from typing import Any, Optional
|
|
|
|
import torch
|
|
import torch.utils._pytree as pytree
|
|
from torch._subclasses.fake_tensor import is_fake
|
|
from torch.testing._internal.two_tensor import TwoTensor
|
|
from torch.utils._python_dispatch import return_and_correct_aliasing
|
|
|
|
|
|
class WrapperSubclass(torch.Tensor):
|
|
@staticmethod
|
|
def __new__(cls, a, outer_size=None, outer_stride=None):
|
|
if outer_size is None:
|
|
outer_size = a.size()
|
|
if outer_stride is None:
|
|
outer_stride = a.stride()
|
|
|
|
kwargs = {}
|
|
kwargs["strides"] = outer_stride
|
|
kwargs["storage_offset"] = a.storage_offset()
|
|
kwargs["device"] = a.device
|
|
kwargs["layout"] = a.layout
|
|
kwargs["requires_grad"] = a.requires_grad
|
|
kwargs["dtype"] = a.dtype
|
|
out = torch.Tensor._make_wrapper_subclass(cls, outer_size, **kwargs)
|
|
|
|
return out
|
|
|
|
def __init__(self, a, outer_size=None, outer_stride=None):
|
|
self.a = a
|
|
|
|
def __repr__(self):
|
|
return f"WrapperSubclass({repr(self.a)})"
|
|
|
|
def __tensor_flatten__(self):
|
|
return ["a"], None
|
|
|
|
@staticmethod
|
|
def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
|
|
assert meta is None
|
|
a = inner_tensors["a"]
|
|
if is_fake(a):
|
|
assert outer_size is not None
|
|
assert outer_stride is not None
|
|
return WrapperSubclass(a, outer_size, outer_stride)
|
|
|
|
@classmethod
|
|
def __torch_dispatch__(cls, func, types, args, kwargs):
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
args_a = pytree.tree_map_only(WrapperSubclass, lambda x: x.a, args)
|
|
|
|
kwargs_a = pytree.tree_map_only(WrapperSubclass, lambda x: x.a, kwargs)
|
|
|
|
out_a = func(*args_a, **kwargs_a)
|
|
out_a_flat, spec = pytree.tree_flatten(out_a)
|
|
out_flat = [
|
|
WrapperSubclass(o_a) if isinstance(o_a, torch.Tensor) else o_a
|
|
for o_a in out_a_flat
|
|
]
|
|
out = pytree.tree_unflatten(out_flat, spec)
|
|
from torch._higher_order_ops.cond import cond_op
|
|
|
|
if func is cond_op:
|
|
return out
|
|
else:
|
|
return return_and_correct_aliasing(func, args, kwargs, out)
|
|
|
|
def __coerce_same_metadata_as_tangent__(
|
|
self, expected_metadata: Any, expected_type: Optional[type] = None
|
|
):
|
|
if expected_type == type(self.a):
|
|
return self.a
|
|
elif expected_type is TwoTensor:
|
|
return TwoTensor(self.a, self.a.clone())
|
|
|
|
return None
|