team-10/env/Lib/site-packages/torch/include/ATen/autocast_mode.h

946 lines
39 KiB
C
Raw Normal View History

2025-08-02 07:34:44 +02:00
#pragma once
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Operators.h>
#include <torch/library.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/util/intrusive_ptr.h>
namespace at::autocast {
// Per-device autocast state accessors. Only the declarations are visible in
// this header; the definitions live elsewhere.
//
// Query / set the autocast "enabled" flag for `device_type`.
TORCH_API bool is_autocast_enabled(at::DeviceType device_type);
TORCH_API void set_autocast_enabled(at::DeviceType device_type, bool enabled);
// Query / set the autocast dtype associated with `device_type`.
// get_lower_precision_fp_from_device_type() in this header returns this value
// as the device's lower-precision floating-point type.
TORCH_API at::ScalarType get_autocast_dtype(at::DeviceType device_type);
TORCH_API void set_autocast_dtype(
at::DeviceType device_type,
at::ScalarType dtype);
// Cache and nesting bookkeeping.
// NOTE(review): presumably clear_cache() drops the casts memoized by
// cached_cast(), and the nesting functions return the updated nesting depth --
// confirm against the definitions.
TORCH_API void clear_cache();
TORCH_API int increment_nesting();
TORCH_API int decrement_nesting();
// Query / set whether the autocast cache is enabled.
TORCH_API bool is_autocast_cache_enabled();
TORCH_API void set_autocast_cache_enabled(bool enabled);
// deprecated CUDA-specific autocast APIs
//
// Deprecated CUDA-only shim: emits a deprecation warning and then returns
// is_autocast_enabled(at::kCUDA).
// Note that TORCH_WARN_DEPRECATION(...) is intentionally written without a
// trailing semicolon here; keep the statement form as-is.
C10_DEPRECATED_MESSAGE(
"at::autocast::is_enabled() is deprecated. Please use at::autocast::is_autocast_enabled(at::kCUDA) instead.")
TORCH_API inline bool is_enabled() {
TORCH_WARN_DEPRECATION(
"at::autocast::",
__func__,
"() is deprecated. Please use at::autocast::is_autocast_enabled(at::kCUDA) instead.")
return is_autocast_enabled(at::kCUDA);
}
// Deprecated CUDA-only shim: emits a deprecation warning and then forwards
// `enabled` to set_autocast_enabled(at::kCUDA, enabled).
C10_DEPRECATED_MESSAGE(
"at::autocast::set_enabled(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(at::kCUDA, enabled) instead.")
TORCH_API inline void set_enabled(bool enabled) {
TORCH_WARN_DEPRECATION(
"at::autocast::",
__func__,
"(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(at::kCUDA, enabled) instead.")
set_autocast_enabled(at::kCUDA, enabled);
}
// Deprecated CUDA-only shim: emits a deprecation warning and then returns
// get_autocast_dtype(at::kCUDA).
C10_DEPRECATED_MESSAGE(
"at::autocast::get_autocast_gpu_dtype() is deprecated. Please use at::autocast::get_autocast_dtype(at::kCUDA) instead.")
TORCH_API inline at::ScalarType get_autocast_gpu_dtype() {
TORCH_WARN_DEPRECATION(
"at::autocast::",
__func__,
"() is deprecated. Please use at::autocast::get_autocast_dtype(at::kCUDA) instead.")
return get_autocast_dtype(at::kCUDA);
}
// Deprecated CUDA-only shim: emits a deprecation warning and then forwards
// `dtype` to set_autocast_dtype(at::kCUDA, dtype).
C10_DEPRECATED_MESSAGE(
"at::autocast::set_autocast_gpu_dtype(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(at::kCUDA, dtype) instead.")
TORCH_API inline void set_autocast_gpu_dtype(at::ScalarType dtype) {
TORCH_WARN_DEPRECATION(
"at::autocast::",
__func__,
"(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(at::kCUDA, dtype) instead.")
set_autocast_dtype(at::kCUDA, dtype);
}
// Generates the four deprecated per-backend accessors for backend `name`:
//   is_<name>_enabled(), set_<name>_enabled(bool),
//   get_autocast_<name>_dtype(), set_autocast_<name>_dtype(ScalarType).
// Each generated function carries a C10_DEPRECATED_MESSAGE attribute, emits a
// runtime deprecation warning, and forwards to the generic device-type based
// API using `device_type`. `#name` / `#device_type` are stringized into the
// messages so the suggested replacement call is spelled out for the user.
// TORCH_WARN_DEPRECATION(...) is deliberately written without a trailing
// semicolon, matching the hand-written CUDA shims in this header.
// Comments cannot be placed inside the continued lines below.
#define DECLARE_DEPRECATED_AUTOCAST_APIS(name, device_type) \
  C10_DEPRECATED_MESSAGE( \
      "at::autocast::is_" #name \
      "_enabled() is deprecated. Please use at::autocast::is_autocast_enabled(" #device_type \
      ") instead.") \
  TORCH_API inline bool is_##name##_enabled() { \
    TORCH_WARN_DEPRECATION( \
        "at::autocast::", \
        __func__, \
        "() is deprecated. Please use at::autocast::is_autocast_enabled(" #device_type \
        ") instead.") \
    return is_autocast_enabled(device_type); \
  } \
  \
  C10_DEPRECATED_MESSAGE( \
      "at::autocast::set_" #name \
      "_enabled(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(" #device_type \
      ", enabled) instead.") \
  TORCH_API inline void set_##name##_enabled(bool enabled) { \
    TORCH_WARN_DEPRECATION( \
        "at::autocast::", \
        __func__, \
        "(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(" #device_type \
        ", enabled) instead.") \
    set_autocast_enabled(device_type, enabled); \
  } \
  \
  C10_DEPRECATED_MESSAGE( \
      "at::autocast::get_autocast_" #name \
      "_dtype() is deprecated. Please use at::autocast::get_autocast_dtype(" #device_type \
      ") instead.") \
  TORCH_API inline at::ScalarType get_autocast_##name##_dtype() { \
    TORCH_WARN_DEPRECATION( \
        "at::autocast::", \
        __func__, \
        "() is deprecated. Please use at::autocast::get_autocast_dtype(" #device_type \
        ") instead.") \
    return get_autocast_dtype(device_type); \
  } \
  \
  C10_DEPRECATED_MESSAGE( \
      "at::autocast::set_autocast_" #name \
      "_dtype(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(" #device_type \
      ", dtype) instead.") \
  TORCH_API inline void set_autocast_##name##_dtype(at::ScalarType dtype) { \
    TORCH_WARN_DEPRECATION( \
        "at::autocast::", \
        __func__, \
        "(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(" #device_type \
        ", dtype) instead.") \
    set_autocast_dtype(device_type, dtype); \
  }
// X-macro listing the (name, device_type) pairs of the backends that receive
// macro-generated deprecated accessors. `_` is invoked once per entry with
// (name, device_type). CUDA is not listed: its deprecated shims are written
// out by hand earlier in this header.
#define AT_FORALL_DEPRECATED_AUTOCAST_BACKENDS(_) \
_(cpu, at::kCPU) \
_(mtia, at::kMTIA) \
_(xpu, at::kXPU) \
_(xla, at::kXLA) \
_(hpu, at::kHPU) \
_(ipu, at::kIPU) \
_(privateuseone, at::kPrivateUse1)
// deprecated other backend specific autocast APIs
// Expands DECLARE_DEPRECATED_AUTOCAST_APIS once for every backend listed in
// AT_FORALL_DEPRECATED_AUTOCAST_BACKENDS.
AT_FORALL_DEPRECATED_AUTOCAST_BACKENDS(DECLARE_DEPRECATED_AUTOCAST_APIS)
// Device types for which autocast is available; searched linearly by
// is_autocast_available(). The entries are the same set of device types that
// is_autocast_eligible() and get_autocast_dispatch_key_from_device_type()
// handle explicitly, so keep all three in sync when adding a backend.
// NOTE(review): a namespace-scope `const` object has internal linkage, so
// every translation unit including this header gets its own copy.
const std::array<at::DeviceType, 9> _AUTOCAST_SUPPORTED_DEVICES{
at::kCPU,
at::kCUDA,
at::kMTIA,
at::kXPU,
at::kIPU,
at::kHPU,
at::kXLA,
at::kPrivateUse1,
at::kMPS};
namespace {
inline bool is_autocast_eligible(
const Tensor& tensor,
c10::DeviceType device_type) {
switch (device_type) {
case c10::DeviceType::CUDA:
return (tensor.is_cuda() || tensor.is_xla()) &&
tensor.is_floating_point();
case c10::DeviceType::CPU:
return (tensor.is_cpu() || tensor.is_mkldnn()) &&
tensor.is_floating_point();
case c10::DeviceType::MTIA:
return tensor.is_mtia() && tensor.is_floating_point();
case c10::DeviceType::XPU:
return tensor.is_xpu() && tensor.is_floating_point();
case c10::DeviceType::IPU:
return tensor.is_ipu() && tensor.is_floating_point();
case c10::DeviceType::HPU:
return tensor.is_hpu() && tensor.is_floating_point();
case c10::DeviceType::XLA:
return tensor.is_xla() && tensor.is_floating_point();
case c10::DeviceType::PrivateUse1:
return tensor.is_privateuseone() && tensor.is_floating_point();
case c10::DeviceType::MPS:
return tensor.is_mps() && tensor.is_floating_point();
default:
return false;
}
}
} // namespace
// Translates a device type into the autocast dispatch key that belongs to it.
// Throws std::runtime_error when `device_type` has no autocast key.
inline DispatchKey get_autocast_dispatch_key_from_device_type(
    c10::DeviceType device_type) {
  using DT = c10::DeviceType;
  using DK = DispatchKey;
  switch (device_type) {
    case DT::CUDA:
      // CUDA maps to the un-suffixed Autocast key.
      return DK::Autocast;
    case DT::CPU:
      return DK::AutocastCPU;
    case DT::MTIA:
      return DK::AutocastMTIA;
    case DT::XPU:
      return DK::AutocastXPU;
    case DT::IPU:
      return DK::AutocastIPU;
    case DT::HPU:
      return DK::AutocastHPU;
    case DT::XLA:
      return DK::AutocastXLA;
    case DT::PrivateUse1:
      return DK::AutocastPrivateUse1;
    case DT::MPS:
      return DK::AutocastMPS;
    default:
      break;
  }
  throw std::runtime_error(
      "unknown device type for autocast in get_autocast_dispatch_key_from_device_type");
}
// Returns true when `device_type` appears in _AUTOCAST_SUPPORTED_DEVICES,
// i.e. when autocast is available for that device type.
inline bool is_autocast_available(c10::DeviceType device_type) {
  const auto last = _AUTOCAST_SUPPORTED_DEVICES.end();
  // The comparison already yields the answer; no if/else needed.
  return std::find(_AUTOCAST_SUPPORTED_DEVICES.begin(), last, device_type) !=
      last;
}
// Returns the autocast dtype configured for `device_type` (the value of
// get_autocast_dtype). Throws std::runtime_error when autocast is not
// available for that device type.
inline at::ScalarType get_lower_precision_fp_from_device_type(
    c10::DeviceType device_type) {
  if (!is_autocast_available(device_type)) {
    throw std::runtime_error(
        "unknown device type for autocast in get_lower_precision_fp_from_device_type");
  }
  return get_autocast_dtype(device_type);
}
/********************************************************************
Promote-type extraction from Tensor / TensorList arguments.
********************************************************************/
// Tensor overload.
// Folds one tensor argument into the running promote type `current`:
//  - tensors that are not autocast-eligible, and double tensors, leave
//    `current` untouched;
//  - as soon as either side is float, the result is float;
//  - if both sides are the device's lower-precision type, it is kept;
//  - any other combination is reported through TORCH_CHECK.
// A `current` of double is rejected up front.
inline at::ScalarType prioritize(
    at::ScalarType current,
    const Tensor& nextArg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  if (current == at::kDouble) {
    TORCH_CHECK(false, "promote type is double in at::autocast::prioritize");
    return current;
  }

  // Looked up before the eligibility test, so an unsupported device type is
  // reported even for ineligible tensors.
  const at::ScalarType lower_precision_fp =
      get_lower_precision_fp_from_device_type(device_type);

  if (!is_autocast_eligible(nextArg, device_type)) {
    return current;
  }

  const at::ScalarType next = nextArg.scalar_type();
  if (next == at::kDouble) {
    // Double tensors do not take part in the promotion.
    return current;
  }
  if (current == at::kFloat || next == at::kFloat) {
    // Float wins over lower_precision_fp.
    return at::kFloat;
  }
  if (current == lower_precision_fp && next == lower_precision_fp) {
    return lower_precision_fp;
  }

  TORCH_CHECK(
      false, "Unexpected floating ScalarType in at::autocast::prioritize");
  return current;
}
// TensorList overload (used by ops such as cat and stack).
// Folds every tensor of `list` into the promote type, in order, by
// delegating to the single-Tensor overload.
inline at::ScalarType prioritize(
    at::ScalarType current,
    const TensorList& list,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  at::ScalarType best = current;
  for (const auto& element : list) {
    best = prioritize(best, element, device_type);
  }
  return best;
}
// ITensorListRef overload.
// Same folding as the TensorList overload: each element is passed, in order,
// to the element-wise prioritize call together with the running result.
inline at::ScalarType prioritize(
    at::ScalarType current,
    const ITensorListRef& list,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  at::ScalarType best = current;
  for (const auto& element : list) {
    best = prioritize(best, element, device_type);
  }
  return best;
}
// Generic fallback for arguments that match none of the tensor overloads.
// Such arguments do not influence the promote type: the current best guess
// is handed back unchanged.
template <typename T>
inline at::ScalarType prioritize(
    at::ScalarType current,
    T nextArg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  return current;
}
// Recursion terminator for promote_type: with no arguments left to inspect,
// the accumulated promote type is the final answer.
inline at::ScalarType promote_type(
    at::ScalarType current,
    c10::DeviceType device_type) {
  return current;
}
// Walks the argument pack from left to right, folding each argument into the
// promote type via prioritize(), to decide whether incoming
// lower_precision_fp tensors have to be promoted to float32. Arguments that
// are not tensors fall through to the no-op prioritize template.
template <typename Arg0, typename... Args>
inline at::ScalarType promote_type(
    at::ScalarType current,
    c10::DeviceType device_type,
    Arg0 arg0,
    Args... args) {
  return promote_type(
      prioritize(current, arg0, device_type), device_type, args...);
}
/****************************************************
Cached casting of Tensor arguments.
****************************************************/
// A tensor qualifies for casting when it is defined, passes
// is_autocast_eligible() for `device_type`, and is not double. The checks
// run in that order and stop at the first failure.
inline bool is_eligible(
    const Tensor& arg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  if (!arg.defined()) {
    return false;
  }
  if (!is_autocast_eligible(arg, device_type)) {
    return false;
  }
  return arg.scalar_type() != at::kDouble;
}
// Overload to catch Tensor args
// Declaration only; the definition is not part of this header. The other
// cached_cast overloads in this header delegate to it for every tensor they
// encounter.
// NOTE(review): presumably returns `arg` cast to `to_type` when the tensor is
// eligible for `device_type`, possibly reusing a cached cast -- confirm
// against the definition.
TORCH_API Tensor cached_cast(
at::ScalarType to_type,
const Tensor& arg,
c10::DeviceType device_type = c10::DeviceType::CUDA);
// std::optional<Tensor> overload.
// An empty optional stays empty; a populated one has its tensor run through
// the Tensor overload of cached_cast.
inline std::optional<Tensor> cached_cast(
    at::ScalarType to_type,
    const std::optional<Tensor>& arg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  if (!arg.has_value()) {
    return std::nullopt;
  }
  return cached_cast(to_type, *arg, device_type);
}
// TensorList overload.
// Produces a new vector holding cached_cast(to_type, t, device_type) for each
// element t of `arg`, in the original order.
inline std::vector<Tensor> cached_cast(
    at::ScalarType to_type,
    const TensorList& arg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  std::vector<Tensor> casted;
  casted.reserve(arg.size());
  for (const auto& element : arg) {
    casted.emplace_back(cached_cast(to_type, element, device_type));
  }
  return casted;
}
// ITensorListRef overload.
// Same element-wise behaviour as the TensorList overload: the result vector
// holds one cached_cast result per input element, in order.
inline std::vector<Tensor> cached_cast(
    at::ScalarType to_type,
    const ITensorListRef& arg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  std::vector<Tensor> casted;
  casted.reserve(arg.size());
  for (const auto& element : arg) {
    casted.emplace_back(cached_cast(to_type, element, device_type));
  }
  return casted;
}
// Generic fallback for arguments that match none of the tensor overloads:
// the argument is returned as received, no cast is attempted.
template <typename T>
inline T cached_cast(
    at::ScalarType to_type,
    T arg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  return arg;
}
/*******************************************************
Logic to flip an output dtype flag.
Kept simple for now by assuming that at most one such flag
is present in the argument list. If a function with more
than one flag ever shows up, something else will be needed.
The policy is:
If the user has explicitly specified a dtype, respect it.
Otherwise, set it to the autocast type.
********************************************************/
// Overload for the dtype flag itself: a dtype supplied by the caller is
// returned untouched, an empty one is replaced by `to_type`.
inline std::optional<ScalarType> set_opt_dtype(
    at::ScalarType to_type,
    const std::optional<ScalarType>& dtype) {
  if (dtype.has_value()) {
    return dtype;
  }
  return to_type;
}
// Generic fallback: every argument that is not an optional dtype flag is
// passed through unchanged.
template <typename T>
inline T set_opt_dtype(at::ScalarType to_type, T arg) {
  return arg;
}
// Reports whether the first (Tensor) argument of an op is eligible for
// autocasting on `device_type`. The remaining arguments are accepted only so
// that a whole argument pack can be forwarded; they are not inspected.
template <typename... Args>
inline bool firstarg_is_eligible(
    c10::DeviceType device_type,
    const Tensor& arg,
    Args... args) {
  return is_eligible(arg, device_type);
}
// Picks a dtype based on the first (Tensor) argument of an op: `to_type` when
// that tensor is eligible for autocasting on `device_type`, otherwise the
// tensor's own scalar type. Trailing arguments are ignored.
template <typename... Args>
inline at::ScalarType type_from_firstarg(
    c10::DeviceType device_type,
    at::ScalarType to_type,
    const Tensor& arg,
    Args... args) {
  if (is_eligible(arg, device_type)) {
    return to_type;
  }
  return arg.scalar_type();
}
// Policies correspond to op categories that need code-divergent handling.
// Wrapper templates below are specialized based on a policy template parameter.
// Each enumerator has a matching WrapFunction_ partial specialization later in
// this header whose static call() implements the behaviour described next to
// the enumerator.
enum class CastPolicy : uint8_t {
lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before
// running the op. Currently, lower_precision_fp is
// fp16 for AutocastCUDA, and is defined by user
// (default bf16) for AutocastCPU or other device.
fp32, // Cast all inputs to at::kFloat before running the op.
fp32_set_opt_dtype, // Treats functions (like softmax) that
// 1. we'd like to run in fp32 and
// 2. have a std::optional<ScalarType> arg that controls
// the output type.
// fp32_set_opt_dtype wrappers' policy is: if the output
// type is already set, don't touch it, otherwise, set
// it to at::kFloat.
fp32_append_dtype, // Treats functions (like norm) that
// 1. we'd like to run in fp32 and
// 2. have some overloads that accept an output type and
// other overloads that don't.
// fp32_append_dtype wrappers wrap the overloads that don't
// have an output dtype.
// The wrapper policy is: append at::kFloat to the args,
// and redispatch to the type-aware overload.
promote, // Run in the widest dtype among several args.
};
/********************************************************************************************************
Templates to provide wrapper functions
I'm copying the pattern used in core/boxing/impl/WrapFunctionIntoFunctor.h to
extract args and return type. (see also
https://stackoverflow.com/questions/46533698/how-to-deduce-argument-list-from-function-pointer)
This strategy uses an exterior "WrapFunction" that extracts arguments on behalf
of (in my case several specializations of) an interior "WrapFunction_".
Interior WrapFunction_ specializations are defined for each CastPolicy.
********************************************************************************************************/
// Base template for WrapFunction_, which is specialized to contain a "call"
// method for each CastPolicy.
// Template parameters:
//   policy      - the CastPolicy selecting the specialization.
//   device_type - device type whose autocast key is excluded while redispatching.
//   Redispatch  - signature of the function redispatched to.
//   F           - pointer to the function redispatched to.
//   Ret         - return type of the generated call().
//   ArgList     - guts::typelist::typelist of call()'s parameter types.
// The primary template is intentionally empty: a policy without a
// specialization provides no call() member.
template <
CastPolicy policy,
c10::DeviceType device_type,
class Redispatch,
Redispatch* F,
class Ret,
class ArgList>
struct WrapFunction_ {};
// CastPolicy::lower_precision_fp General_DeviceType
// call() excludes the autocast dispatch key of `device_type`, casts every
// argument through cached_cast() to the device's lower-precision fp type and
// redispatches to F. Non-tensor arguments pass through the generic
// cached_cast overload unchanged.
template <
    c10::DeviceType device_type,
    class Redispatch,
    Redispatch* F,
    class Ret,
    class... Args>
struct WrapFunction_<
    CastPolicy::lower_precision_fp,
    device_type,
    Redispatch,
    F,
    Ret,
    guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(
        get_autocast_dispatch_key_from_device_type(device_type));
    // Resolve the target dtype once instead of once per pack element; the
    // promote specialization looks it up a single time in the same way.
    const at::ScalarType lower_precision_fp =
        get_lower_precision_fp_from_device_type(device_type);
    return (*F)(cached_cast(lower_precision_fp, args, device_type)...);
  }
};
// CastPolicy::fp32 General_DeviceType
// call() excludes the autocast dispatch key of `device_type`, passes every
// argument through cached_cast() with at::kFloat as the target type and
// redispatches to F.
template <
    c10::DeviceType device_type,
    class Redispatch,
    Redispatch* F,
    class Ret,
    class... Args>
struct WrapFunction_<
    CastPolicy::fp32,
    device_type,
    Redispatch,
    F,
    Ret,
    guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    // Keeps the autocast key excluded for the duration of the redispatch.
    c10::impl::ExcludeDispatchKeyGuard autocast_excluded(
        get_autocast_dispatch_key_from_device_type(device_type));
    return (*F)(cached_cast(at::kFloat, args, device_type)...);
  }
};
// CastPolicy::fp32_set_opt_dtype General_DeviceType
// call() excludes the autocast dispatch key of `device_type`. When the first
// argument is eligible, every argument is passed through set_opt_dtype() so
// that an unset optional dtype becomes at::kFloat; otherwise F receives the
// arguments exactly as given.
template <
    c10::DeviceType device_type,
    class Redispatch,
    Redispatch* F,
    class Ret,
    class... Args>
struct WrapFunction_<
    CastPolicy::fp32_set_opt_dtype,
    device_type,
    Redispatch,
    F,
    Ret,
    guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard autocast_excluded(
        get_autocast_dispatch_key_from_device_type(device_type));
    if (!firstarg_is_eligible(device_type, args...)) {
      // Ineligible first argument: call F with unaltered args. The opt dtype
      // is deliberately left alone, because setting it explicitly may
      // interfere with internal implicit promotion decisions.
      return (*F)(args...);
    }
    return (*F)(set_opt_dtype(at::kFloat, args)...);
  }
};
// CastPolicy::fp32_append_dtype General_DeviceType
// call() excludes the autocast dispatch key of `device_type` and redispatches
// to F with one extra trailing argument: the output dtype chosen by
// type_from_firstarg() (at::kFloat for an eligible first argument, otherwise
// the first argument's own scalar type).
template <
    c10::DeviceType device_type,
    class Redispatch,
    Redispatch* F,
    class Ret,
    class... Args>
struct WrapFunction_<
    CastPolicy::fp32_append_dtype,
    device_type,
    Redispatch,
    F,
    Ret,
    guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard autocast_excluded(
        get_autocast_dispatch_key_from_device_type(device_type));
    const at::ScalarType appended_dtype =
        type_from_firstarg(device_type, at::kFloat, args...);
    return (*F)(args..., appended_dtype);
  }
};
// CastPolicy::promote General_DeviceType
// call() excludes the autocast dispatch key of `device_type`, derives one
// common dtype from all arguments via promote_type() (starting from the
// device's lower-precision fp type), casts every argument to it through
// cached_cast() and redispatches to F.
template <
    c10::DeviceType device_type,
    class Redispatch,
    Redispatch* F,
    class Ret,
    class... Args>
struct WrapFunction_<
    CastPolicy::promote,
    device_type,
    Redispatch,
    F,
    Ret,
    guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard autocast_excluded(
        get_autocast_dispatch_key_from_device_type(device_type));
    const at::ScalarType starting_type =
        get_lower_precision_fp_from_device_type(device_type);
    const at::ScalarType common_type =
        promote_type(starting_type, device_type, args...);
    return (*F)(cached_cast(common_type, args, device_type)...);
  }
};
// Wrapper to infer return_type and parameter_types for WrapFunction_ (imitating
// core/boxing/impl/WrapFunctionIntoFunctor.h)
// Usage: WrapFunction<...>::type::call is the function that gets registered;
// `type` names the WrapFunction_ specialization selected by `policy`, with its
// return type and parameter list taken from `Registered`.
template <
CastPolicy policy,
c10::DeviceType device_type,
class Registered, // The signature for which we're registering. The
// dispatcher's calling code invokes our registered
// functions with arguments matching Registered, so we
// register WrapFunction_::call methods with a matching
// signature to properly field those arguments.
// guts::function_traits below extracts return_type and
// parameter_types from Registered, which WrapFunction_
// templates above use to declare their call methods.
class Redispatch, // The signature for the function we're redispatching to.
// In most cases this is the same as Registered, but for
// some ops (for example, ops where we append a dtype)
// it's useful to redispatch to a function with a
// different signature.
Redispatch* F> // The actual function we're redispatching to.
struct WrapFunction final {
// The policy-specific implementation; its static call() has the signature
// described by Registered.
using type = WrapFunction_<
policy,
device_type,
Redispatch,
F,
typename guts::function_traits<Registered>::return_type,
typename guts::function_traits<Registered>::parameter_types>;
};
/*****************************************************************************************************************
This section performs load-time registration for autocast wrappers.
It's debatable at what level operations should be patched. We'd like casts to
be autograd-exposed and precede autograd history recording, so that for
lower_precision_fp ops, input tensors are saved for backward in
lower_precision_fp rather than fp32. Saving inputs in lower_precision_fp
can significantly reduce a model's memory footprint.
Option 1 (strawman): Patch only at the level of explicit calls into
cudnn/cublas (cudnn_convolution, etc), because those are the code paths that are
guaranteed to use Tensor Cores, therefore they're the ones that will benefit
most from lower_precision_fp. Potential pitfall: convolutions (and other ops)
are wrapped in several layers of at::* calls. If one of those happens to record
autograd history, then we've lost the opportunity to save inputs in
lower_precision_fp.
Option 2: Patch the Python-exposed surface of calls, to make 100% sure autograd
history recording can't sneak in ahead of autocast. This mirrors Apex most
closely.
I think Option 2 is the right answer for all ops, not just convolutions. Option
2 is what I implement here.
*****************************************************************************************************************/
/********************************************************************************************************************
Explicit registration for out-of-place ops
The stuff below could be codegenned. Ed said
> you are going to have to write the function definition at some point, I
> wouldn't try to get clever about it
Therefore, for the moment, this is all copy-pasted in from
VariableTypeEverything.cpp with appropriate substitutions.
********************************************************************************************************************/
} // namespace at::autocast
// Qualifies an op name with the at:: namespace.
#define ADD_NS(RAW_OP) at::RAW_OP
// Argument-counting helper: yields the token in the fourth position.
// Invoked through _KERNEL_OVERLOAD_NARG as (args..., 2, 1), this produces 1
// for two arguments (OP, POLICY) and 2 for three arguments
// (OP, OVERLOAD, POLICY), i.e. the suffix of KERNEL1 / KERNEL2.
#define _KERNEL_OVERLOAD_NARG_IMPL(_0, _1, _2, N, ...) N
#define _KERNEL_OVERLOAD_NARG(...) \
C10_EXPAND_MSVC_WORKAROUND(_KERNEL_OVERLOAD_NARG_IMPL(__VA_ARGS__, 2, 1))
// Common cases where registration signature matches redispatch signature
// (that's why SIGNATURE is repeated in the WrapFunction instantiation)
// KERNEL1 registers an op without an overload name: "aten::<OP>".
// Despite its name, DISPATCHKEY is used as WrapFunction's device_type template
// argument (the KERNEL_<BACKEND> macros pass a c10::DeviceType value).
// The expansion calls m.impl(...), so an object named `m` must be in scope at
// the point of use.
// NOTE(review): presumably `m` is the library handle provided by a
// TORCH_LIBRARY_IMPL block -- confirm at the call sites.
#define KERNEL1(DISPATCHKEY, OP, POLICY) \
m.impl( \
TORCH_SELECTIVE_NAME("aten::" #OP), \
&::at::autocast::WrapFunction< \
::at::autocast::CastPolicy::POLICY, \
DISPATCHKEY, \
decltype(ATEN_FN(OP)), \
decltype(ATEN_FN(OP)), \
&ATEN_FN(OP)>::type::call);
// KERNEL2 is the overload-aware counterpart of KERNEL1: it registers
// "aten::<OP>.<OVERLOAD>" and redispatches to ATEN_FN2(OP, OVERLOAD), using
// the same signature for registration and redispatch.
// Like KERNEL1 it expects an object named `m` to be in scope.
#define KERNEL2(DISPATCHKEY, OP, OVERLOAD, POLICY) \
m.impl( \
TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \
&::at::autocast::WrapFunction< \
::at::autocast::CastPolicy::POLICY, \
DISPATCHKEY, \
decltype(ATEN_FN2(OP, OVERLOAD)), \
decltype(ATEN_FN2(OP, OVERLOAD)), \
&ATEN_FN2(OP, OVERLOAD)>::type::call);
// Pastes NARG (1 or 2) onto "KERNEL" and forwards the arguments to the
// resulting KERNEL1 / KERNEL2 macro.
#define _KERNEL_DISPATCH(DISPATCHKEY, NARG, ...) \
C10_CONCATENATE(KERNEL, NARG)(DISPATCHKEY, __VA_ARGS__)
// Computes NARG from the variadic arguments via _KERNEL_OVERLOAD_NARG.
#define _KERNEL_IMPL(DISPATCHKEY, ...) \
_KERNEL_DISPATCH(DISPATCHKEY, _KERNEL_OVERLOAD_NARG(__VA_ARGS__), __VA_ARGS__)
// It will dispatch to KERNEL1 or KERNEL2 based on its inputs.
// (OP, POLICY) selects KERNEL1; (OP, OVERLOAD, POLICY) selects KERNEL2.
#define KERNEL(DISPATCHKEY, ...) _KERNEL_IMPL(DISPATCHKEY, __VA_ARGS__)
// Less-common but still useful case: redispatching to a function
// with a new signature (e.g. appending a dtype)
// Parameters:
//   DISPATCHKEY          - used as WrapFunction's device_type template argument.
//   REDISPATCH_FUNC      - function whose address is redispatched to.
//   REGISTER_NAME        - string literal appended to "aten::" for registration.
//   REGISTER_SIGNATURE   - signature the wrapper is registered under.
//   REDISPATCH_SIGNATURE - signature of REDISPATCH_FUNC.
//   POLICY               - CastPolicy enumerator name.
// Like KERNEL1/KERNEL2 it expects an object named `m` to be in scope.
#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
DISPATCHKEY, \
REDISPATCH_FUNC, \
REGISTER_NAME, \
REGISTER_SIGNATURE, \
REDISPATCH_SIGNATURE, \
POLICY) \
m.impl( \
TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
&::at::autocast::WrapFunction< \
::at::autocast::CastPolicy::POLICY, \
DISPATCHKEY, \
REGISTER_SIGNATURE, \
REDISPATCH_SIGNATURE, \
&REDISPATCH_FUNC>::type::call);
// KERNEL_CPU/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CPU
// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastCPU
// Both macros simply prepend c10::DeviceType::CPU to their arguments and
// forward to the generic KERNEL / KERNEL_DIFFERENT_REDISPATCH_SIGNATURE.
#define KERNEL_CPU(...) KERNEL(c10::DeviceType::CPU, __VA_ARGS__)
#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CPU( \
REDISPATCH_FUNC, \
REGISTER_NAME, \
REGISTER_SIGNATURE, \
REDISPATCH_SIGNATURE, \
POLICY) \
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
c10::DeviceType::CPU, \
REDISPATCH_FUNC, \
REGISTER_NAME, \
REGISTER_SIGNATURE, \
REDISPATCH_SIGNATURE, \
POLICY)
// KERNEL_CUDA/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA
// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastCUDA
// Both macros simply prepend c10::DeviceType::CUDA to their arguments and
// forward to the generic KERNEL / KERNEL_DIFFERENT_REDISPATCH_SIGNATURE.
#define KERNEL_CUDA(...) KERNEL(c10::DeviceType::CUDA, __VA_ARGS__)
#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA( \
REDISPATCH_FUNC, \
REGISTER_NAME, \
REGISTER_SIGNATURE, \
REDISPATCH_SIGNATURE, \
POLICY) \
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
c10::DeviceType::CUDA, \
REDISPATCH_FUNC, \
REGISTER_NAME, \
REGISTER_SIGNATURE, \
REDISPATCH_SIGNATURE, \
POLICY)
// KERNEL_MTIA/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MTIA
// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastMTIA
// Both macros simply prepend c10::DeviceType::MTIA to their arguments and
// forward to the generic KERNEL / KERNEL_DIFFERENT_REDISPATCH_SIGNATURE.
#define KERNEL_MTIA(...) KERNEL(c10::DeviceType::MTIA, __VA_ARGS__)
#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MTIA( \
REDISPATCH_FUNC, \
REGISTER_NAME, \
REGISTER_SIGNATURE, \
REDISPATCH_SIGNATURE, \
POLICY) \
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
c10::DeviceType::MTIA, \
REDISPATCH_FUNC, \
REGISTER_NAME, \
REGISTER_SIGNATURE, \
REDISPATCH_SIGNATURE, \
POLICY)
// KERNEL_XPU/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU
// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastXPU
// Both macros simply prepend c10::DeviceType::XPU to their arguments and
// forward to the generic KERNEL / KERNEL_DIFFERENT_REDISPATCH_SIGNATURE.
#define KERNEL_XPU(...) KERNEL(c10::DeviceType::XPU, __VA_ARGS__)
#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU( \
REDISPATCH_FUNC, \
REGISTER_NAME, \
REGISTER_SIGNATURE, \
REDISPATCH_SIGNATURE, \
POLICY) \
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
c10::DeviceType::XPU, \
REDISPATCH_FUNC, \
REGISTER_NAME, \
REGISTER_SIGNATURE, \
REDISPATCH_SIGNATURE, \
POLICY)
// KERNEL_PRIVATEUSEONE/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE
// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastPrivateUse1
// Both macros simply prepend c10::DeviceType::PrivateUse1 to their arguments
// and forward to the generic KERNEL / KERNEL_DIFFERENT_REDISPATCH_SIGNATURE.
#define KERNEL_PRIVATEUSEONE(...) \
KERNEL(c10::DeviceType::PrivateUse1, __VA_ARGS__)
#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE( \
REDISPATCH_FUNC, \
REGISTER_NAME, \
REGISTER_SIGNATURE, \
REDISPATCH_SIGNATURE, \
POLICY) \
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
c10::DeviceType::PrivateUse1, \
REDISPATCH_FUNC, \
REGISTER_NAME, \
REGISTER_SIGNATURE, \
REDISPATCH_SIGNATURE, \
POLICY)
// KERNEL_MPS
// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastMPS
// Prepends c10::DeviceType::MPS and forwards to the generic KERNEL macro.
// Unlike the other backends, no KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MPS
// variant is defined in the visible part of this header.
#define KERNEL_MPS(...) KERNEL(c10::DeviceType::MPS, __VA_ARGS__)
// Op lists for different policies.
// They are kept as X-macros so that other backends can reuse the same
// per-policy op lists.
// Each entry is either (OP) or (OP, OVERLOAD); the callback `_` is invoked
// once per entry.
// NOTE(review): presumably `_` is a backend-specific macro that appends the
// lower_precision_fp policy and forwards to KERNEL_<BACKEND> -- confirm at
// the registration sites.
#define AT_FORALL_LOWER_PRECISION_FP(_) \
_(_convolution, deprecated) \
_(_convolution) \
_(conv1d) \
_(conv2d) \
_(conv3d) \
_(conv_tbc) \
_(conv_transpose1d) \
_(conv_transpose2d, input) \
_(conv_transpose3d, input) \
_(convolution) \
_(prelu) \
_(addmm) \
_(addmv) \
_(addr) \
_(matmul) \
_(einsum) \
_(mm) \
_(mv) \
_(linalg_vecdot) \
_(linear) \
_(addbmm) \
_(baddbmm) \
_(bmm) \
_(chain_matmul) \
_(linalg_multi_dot) \
_(_thnn_fused_lstm_cell) \
_(_thnn_fused_gru_cell) \
_(lstm_cell) \
_(gru_cell) \
_(rnn_tanh_cell) \
_(rnn_relu_cell) \
_(_scaled_dot_product_flash_attention) \
_(scaled_dot_product_attention)
// X-macro op list; each entry is (OP) or (OP, OVERLOAD) and `_` is invoked
// once per entry.
// NOTE(review): judging by the name, these are the ops meant to be registered
// with the fp32 policy -- confirm at the registration sites.
#define AT_FORALL_FP32(_) \
_(acos) \
_(asin) \
_(cosh) \
_(erfinv) \
_(exp) \
_(expm1) \
_(log) \
_(log10) \
_(log2) \
_(log1p) \
_(reciprocal) \
_(rsqrt) \
_(sinh) \
_(tan) \
_(pow, Tensor_Scalar) \
_(pow, Tensor_Tensor) \
_(pow, Scalar) \
_(softplus) \
_(layer_norm) \
_(native_layer_norm) \
_(group_norm) \
_(frobenius_norm, dim) \
_(nuclear_norm) \
_(nuclear_norm, dim) \
_(cosine_similarity) \
_(poisson_nll_loss) \
_(cosine_embedding_loss) \
_(nll_loss) \
_(nll_loss2d) \
_(hinge_embedding_loss) \
_(kl_div) \
_(l1_loss) \
_(smooth_l1_loss) \
_(huber_loss) \
_(mse_loss) \
_(margin_ranking_loss) \
_(multilabel_margin_loss) \
_(soft_margin_loss) \
_(triplet_margin_loss) \
_(multi_margin_loss) \
_(binary_cross_entropy_with_logits) \
_(dist) \
_(pdist) \
_(cdist) \
_(renorm) \
_(logsumexp) \
_(upsample_nearest1d) \
_(_upsample_nearest_exact1d) \
_(upsample_nearest2d) \
_(_upsample_nearest_exact2d) \
_(upsample_nearest3d) \
_(_upsample_nearest_exact3d) \
_(upsample_linear1d) \
_(upsample_bilinear2d) \
_(_upsample_bilinear2d_aa) \
_(upsample_trilinear3d) \
_(upsample_bicubic2d) \
_(_upsample_bicubic2d_aa)
// X-macro op list; each entry is (OP) or (OP, OVERLOAD) and `_` is invoked
// once per entry.
// NOTE(review): judging by the name, these are the ops meant to be registered
// with the fp32_set_opt_dtype policy -- confirm at the registration sites.
#define AT_FORALL_FP32_SET_OPT_DTYPE(_) \
_(prod) \
_(prod, dim_int) \
_(prod, dim_Dimname) \
_(softmax, int) \
_(softmax, Dimname) \
_(log_softmax, int) \
_(log_softmax, Dimname) \
_(cumprod) \
_(cumprod, dimname) \
_(cumsum) \
_(cumsum, dimname) \
_(linalg_vector_norm) \
_(linalg_matrix_norm) \
_(linalg_matrix_norm, str_ord) \
_(sum) \
_(sum, dim_IntList) \
_(sum, dim_DimnameList)
// X-macro list of ops whose registered signature differs from the signature
// they redispatch to. Each entry supplies five arguments in the order taken
// by the KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_<BACKEND> macros:
// (REDISPATCH_FUNC, REGISTER_NAME, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE,
// POLICY). All three entries are norm overloads using fp32_append_dtype: the
// redispatch signature is the registered one plus a trailing ScalarType.
#define AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(_) \
_(ADD_NS(norm), \
"norm.Scalar", \
Tensor(const Tensor&, const Scalar&), \
Tensor(const Tensor&, const std::optional<Scalar>&, ScalarType), \
fp32_append_dtype) \
_(ADD_NS(norm), \
"norm.ScalarOpt_dim", \
Tensor(const Tensor&, const std::optional<Scalar>&, IntArrayRef, bool), \
Tensor( \
const Tensor&, \
const std::optional<Scalar>&, \
IntArrayRef, \
bool, \
ScalarType), \
fp32_append_dtype) \
_(ADD_NS(norm), \
"norm.names_ScalarOpt_dim", \
Tensor(const Tensor&, const std::optional<Scalar>&, DimnameList, bool), \
Tensor( \
const Tensor&, \
const std::optional<Scalar>&, \
DimnameList, \
bool, \
ScalarType), \
fp32_append_dtype)
// X-macro op list; every entry is a bare (OP) and `_` is invoked once per
// entry.
// NOTE(review): judging by the name, these are the ops meant to be registered
// with the promote policy -- confirm at the registration sites.
#define AT_FORALL_PROMOTE(_) \
_(addcdiv) \
_(addcmul) \
_(atan2) \
_(bilinear) \
_(cross) \
_(dot) \
_(vdot) \
_(grid_sampler) \
_(index_put) \
_(tensordot) \
_(scatter_add)