84 lines
3.1 KiB
Python
84 lines
3.1 KiB
Python
from __future__ import annotations
|
|
|
|
import torchgen.api.meta as meta
|
|
import torchgen.api.structured as structured
|
|
from torchgen.api.types import kernel_signature
|
|
from torchgen.context import with_native_function_and_index
|
|
from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup
|
|
from torchgen.utils import mapMaybe
|
|
|
|
|
|
def torch_api_key_word_prefix(bankend_index: BackendIndex) -> str:
    """Return the export-macro prefix to place before a structured kernel name.

    Args:
        bankend_index: Backend whose ``external`` flag and ``dispatch_key.name``
            select the macro.

    Returns:
        ``""`` for external backends; otherwise the macro name followed by a
        single trailing space (``"TORCH_XPU_API "`` for the ``XPU`` dispatch
        key, ``"TORCH_API "`` for every other key).
    """
    # External backends get no export macro at all.
    if bankend_index.external:
        return ""

    # The Intel GPU ATen library lives out-of-tree but still relies on torchgen
    # to produce its structured kernels, and those kernels must be visible to
    # that library. "TORCH_API" means "hidden" for out-of-tree backends, so XPU
    # gets its own "TORCH_XPU_API" macro instead. In-tree backends such as cpu
    # and cuda keep "TORCH_API", which is "visible" for them.
    dispatch_key_name = bankend_index.dispatch_key.name
    macro = "TORCH_XPU_API" if dispatch_key_name == "XPU" else "TORCH_API"
    return f"{macro} "
|
|
|
|
|
|
@with_native_function_and_index
def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> str | None:
    """Build the forward declaration for a single unstructured kernel.

    Args:
        f: The native function to declare.
        backend_index: Backend used to look up the kernel and its signature.

    Returns:
        ``None`` when ``backend_index`` has no kernel for ``f`` or when the
        kernel name contains ``"legacy::"``; otherwise one declaration string
        of the form ``"<prefix> <signature>;"``.
    """
    signature = kernel_signature(f, backend_index)
    kernel_info = backend_index.get_kernel(f)

    # Nothing to declare: no kernel registered, or a legacy kernel.
    if kernel_info is None or "legacy::" in kernel_info.kernel:
        return None

    # External backends declare their kernels `static`; all others use the
    # TORCH_API export macro.
    if backend_index.external:
        linkage = "static"
    else:
        linkage = "TORCH_API"
    declaration = signature.decl(name=kernel_info.kernel)
    return f"{linkage} {declaration};"
|
|
|
|
|
|
@with_native_function_and_index
def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> list[str]:
    """Build the C++ struct declaration for a structured kernel group.

    Args:
        g: The structured native-functions group to declare.
        backend_index: Backend used to look up the kernel and pick the
            export-macro prefix.

    Returns:
        An empty list when ``backend_index`` has no kernel for ``g``;
        otherwise a one-element list holding the text of a
        ``structured_<kernel>`` struct that derives from
        ``at::meta::structured_<meta name>`` and declares ``impl(...)``.
    """
    meta_class = meta.name(g)
    impl_args = structured.impl_arguments(g)
    kernel_info = backend_index.get_kernel(g)
    if kernel_info is None:
        return []

    export_prefix = torch_api_key_word_prefix(backend_index)
    struct_header = (
        f"struct {export_prefix}structured_{kernel_info.kernel}"
        f" : public at::meta::structured_{meta_class} {{"
    )
    impl_params = ", ".join(arg.decl() for arg in impl_args)
    impl_decl = f"void impl({impl_params});"
    # The trailing empty element makes the emitted text end with a newline.
    return ["\n".join((struct_header, impl_decl, "};", ""))]
|
|
|
|
|
|
# Generates NativeFunctions.h, a list of forward declarations of all
|
|
# actual kernel definitions we keep in aten/src/ATen/native/
|
|
@with_native_function_and_index
def compute_native_function_declaration(
    g: NativeFunctionsGroup | NativeFunction, backend_index: BackendIndex
) -> list[str]:
    """Produce the kernel declarations for a function or a function group.

    Args:
        g: Either a single native function or a group of them.
        backend_index: Backend whose registered kernels are being declared.

    Returns:
        A list of declaration strings; it may be empty when no kernel needs
        to be declared.

    Raises:
        AssertionError: If ``g`` is a group with a structured kernel and
            ``backend_index`` is an external backend.
    """
    kernel_info = backend_index.get_kernel(g)

    # A lone function yields at most one unstructured declaration.
    if not isinstance(g, NativeFunctionsGroup):
        declaration = gen_unstructured(g, backend_index)
        return [declaration] if declaration is not None else []

    # A group without a structured kernel: declare each member individually,
    # dropping the members that produce no declaration.
    if kernel_info is None or not kernel_info.structured:

        def declare(f: NativeFunction) -> str | None:
            return gen_unstructured(f, backend_index)

        return list(mapMaybe(declare, g.functions()))

    # Structured hasn't been tested with external backends yet.
    if backend_index.external:
        raise AssertionError(
            "Structured external backend functions are not implemented yet."
        )
    return gen_structured(g, backend_index)
|