#pragma once

#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Operators.h>
#include <torch/library.h>

#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/util/intrusive_ptr.h>

namespace at::autocast {

TORCH_API bool is_autocast_enabled(at::DeviceType device_type);
TORCH_API void set_autocast_enabled(at::DeviceType device_type, bool enabled);
TORCH_API at::ScalarType get_autocast_dtype(at::DeviceType device_type);
TORCH_API void set_autocast_dtype(
    at::DeviceType device_type,
    at::ScalarType dtype);
TORCH_API void clear_cache();
TORCH_API int increment_nesting();
TORCH_API int decrement_nesting();
TORCH_API bool is_autocast_cache_enabled();
TORCH_API void set_autocast_cache_enabled(bool enabled);

// deprecated CUDA-specific autocast APIs
C10_DEPRECATED_MESSAGE(
    "at::autocast::is_enabled() is deprecated. Please use at::autocast::is_autocast_enabled(at::kCUDA) instead.")
TORCH_API inline bool is_enabled() {
  TORCH_WARN_DEPRECATION(
      "at::autocast::",
      __func__,
      "() is deprecated. Please use at::autocast::is_autocast_enabled(at::kCUDA) instead.")
  return is_autocast_enabled(at::kCUDA);
}
C10_DEPRECATED_MESSAGE(
    "at::autocast::set_enabled(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(at::kCUDA, enabled) instead.")
TORCH_API inline void set_enabled(bool enabled) {
  TORCH_WARN_DEPRECATION(
      "at::autocast::",
      __func__,
      "(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(at::kCUDA, enabled) instead.")
  set_autocast_enabled(at::kCUDA, enabled);
}
C10_DEPRECATED_MESSAGE(
    "at::autocast::get_autocast_gpu_dtype() is deprecated. Please use at::autocast::get_autocast_dtype(at::kCUDA) instead.")
TORCH_API inline at::ScalarType get_autocast_gpu_dtype() {
  TORCH_WARN_DEPRECATION(
      "at::autocast::",
      __func__,
      "() is deprecated. Please use at::autocast::get_autocast_dtype(at::kCUDA) instead.")
  return get_autocast_dtype(at::kCUDA);
}
C10_DEPRECATED_MESSAGE(
    "at::autocast::set_autocast_gpu_dtype(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(at::kCUDA, dtype) instead.")
TORCH_API inline void set_autocast_gpu_dtype(at::ScalarType dtype) {
  TORCH_WARN_DEPRECATION(
      "at::autocast::",
      __func__,
      "(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(at::kCUDA, dtype) instead.")
  set_autocast_dtype(at::kCUDA, dtype);
}
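
// Illustrative migration sketch (not part of this header's API surface): the
// deprecated CUDA-specific calls above map one-to-one onto the device-generic
// calls, e.g.
//
//   // old, warns at compile time and at runtime:
//   bool on = at::autocast::is_enabled();
//   at::autocast::set_autocast_gpu_dtype(at::kBFloat16);
//
//   // new, device-generic equivalents:
//   bool on2 = at::autocast::is_autocast_enabled(at::kCUDA);
//   at::autocast::set_autocast_dtype(at::kCUDA, at::kBFloat16);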

#define DECLARE_DEPRECATED_AUTOCAST_APIS(name, device_type)                  \
  C10_DEPRECATED_MESSAGE(                                                    \
      "at::autocast::is_" #name                                              \
      "_enabled() is deprecated. Please use at::autocast::is_autocast_enabled(" #device_type \
      ") instead.")                                                          \
  TORCH_API inline bool is_##name##_enabled() {                              \
    TORCH_WARN_DEPRECATION(                                                  \
        "at::autocast::",                                                    \
        __func__,                                                            \
        "() is deprecated. Please use at::autocast::is_autocast_enabled(" #device_type \
        ") instead.")                                                        \
    return is_autocast_enabled(device_type);                                 \
  }                                                                          \
                                                                             \
  C10_DEPRECATED_MESSAGE(                                                    \
      "at::autocast::set_" #name                                             \
      "_enabled(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(" #device_type \
      ", enabled) instead.")                                                 \
  TORCH_API inline void set_##name##_enabled(bool enabled) {                 \
    TORCH_WARN_DEPRECATION(                                                  \
        "at::autocast::",                                                    \
        __func__,                                                            \
        "(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(" #device_type \
        ", enabled) instead.")                                               \
    set_autocast_enabled(device_type, enabled);                              \
  }                                                                          \
                                                                             \
  C10_DEPRECATED_MESSAGE(                                                    \
      "at::autocast::get_autocast_" #name                                    \
      "_dtype() is deprecated. Please use at::autocast::get_autocast_dtype(" #device_type \
      ") instead.")                                                          \
  TORCH_API inline at::ScalarType get_autocast_##name##_dtype() {            \
    TORCH_WARN_DEPRECATION(                                                  \
        "at::autocast::",                                                    \
        __func__,                                                            \
        "() is deprecated. Please use at::autocast::get_autocast_dtype(" #device_type \
        ") instead.")                                                        \
    return get_autocast_dtype(device_type);                                  \
  }                                                                          \
                                                                             \
  C10_DEPRECATED_MESSAGE(                                                    \
      "at::autocast::set_autocast_" #name                                    \
      "_dtype(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(" #device_type \
      ", dtype) instead.")                                                   \
  TORCH_API inline void set_autocast_##name##_dtype(at::ScalarType dtype) {  \
    TORCH_WARN_DEPRECATION(                                                  \
        "at::autocast::",                                                    \
        __func__,                                                            \
        "(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(" #device_type \
        ", dtype) instead.")                                                 \
    set_autocast_dtype(device_type, dtype);                                  \
  }

#define AT_FORALL_DEPRECATED_AUTOCAST_BACKENDS(_) \
  _(cpu, at::kCPU)                                \
  _(mtia, at::kMTIA)                              \
  _(xpu, at::kXPU)                                \
  _(xla, at::kXLA)                                \
  _(hpu, at::kHPU)                                \
  _(ipu, at::kIPU)                                \
  _(privateuseone, at::kPrivateUse1)

// deprecated other backend specific autocast APIs
AT_FORALL_DEPRECATED_AUTOCAST_BACKENDS(DECLARE_DEPRECATED_AUTOCAST_APIS)
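
// Illustrative sketch of what one expansion above generates (e.g. for
// `_(cpu, at::kCPU)`), not additional API:
//
//   TORCH_API inline bool is_cpu_enabled() {
//     ...deprecation warning...
//     return is_autocast_enabled(at::kCPU);
//   }
//   // plus set_cpu_enabled(bool), get_autocast_cpu_dtype(), and
//   // set_autocast_cpu_dtype(at::ScalarType), each forwarding to the
//   // device-generic API with at::kCPU.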

const std::array<at::DeviceType, 9> _AUTOCAST_SUPPORTED_DEVICES{
    at::kCPU,
    at::kCUDA,
    at::kMTIA,
    at::kXPU,
    at::kIPU,
    at::kHPU,
    at::kXLA,
    at::kPrivateUse1,
    at::kMPS};

namespace {
inline bool is_autocast_eligible(
    const Tensor& tensor,
    c10::DeviceType device_type) {
  switch (device_type) {
    case c10::DeviceType::CUDA:
      return (tensor.is_cuda() || tensor.is_xla()) &&
          tensor.is_floating_point();
    case c10::DeviceType::CPU:
      return (tensor.is_cpu() || tensor.is_mkldnn()) &&
          tensor.is_floating_point();
    case c10::DeviceType::MTIA:
      return tensor.is_mtia() && tensor.is_floating_point();
    case c10::DeviceType::XPU:
      return tensor.is_xpu() && tensor.is_floating_point();
    case c10::DeviceType::IPU:
      return tensor.is_ipu() && tensor.is_floating_point();
    case c10::DeviceType::HPU:
      return tensor.is_hpu() && tensor.is_floating_point();
    case c10::DeviceType::XLA:
      return tensor.is_xla() && tensor.is_floating_point();
    case c10::DeviceType::PrivateUse1:
      return tensor.is_privateuseone() && tensor.is_floating_point();
    case c10::DeviceType::MPS:
      return tensor.is_mps() && tensor.is_floating_point();
    default:
      return false;
  }
}
} // namespace

inline DispatchKey get_autocast_dispatch_key_from_device_type(
    c10::DeviceType device_type) {
  switch (device_type) {
    case c10::DeviceType::CUDA:
      return DispatchKey::Autocast;
    case c10::DeviceType::CPU:
      return DispatchKey::AutocastCPU;
    case c10::DeviceType::MTIA:
      return DispatchKey::AutocastMTIA;
    case c10::DeviceType::XPU:
      return DispatchKey::AutocastXPU;
    case c10::DeviceType::IPU:
      return DispatchKey::AutocastIPU;
    case c10::DeviceType::HPU:
      return DispatchKey::AutocastHPU;
    case c10::DeviceType::XLA:
      return DispatchKey::AutocastXLA;
    case c10::DeviceType::PrivateUse1:
      return DispatchKey::AutocastPrivateUse1;
    case c10::DeviceType::MPS:
      return DispatchKey::AutocastMPS;
    default:
      throw std::runtime_error(
          "unknown device type for autocast in get_autocast_dispatch_key_from_device_type");
  }
}

inline bool is_autocast_available(c10::DeviceType device_type) {
  return std::find(
             _AUTOCAST_SUPPORTED_DEVICES.begin(),
             _AUTOCAST_SUPPORTED_DEVICES.end(),
             device_type) != _AUTOCAST_SUPPORTED_DEVICES.end();
}

inline at::ScalarType get_lower_precision_fp_from_device_type(
    c10::DeviceType device_type) {
  if (is_autocast_available(device_type)) {
    return get_autocast_dtype(device_type);
  } else {
    throw std::runtime_error(
        "unknown device type for autocast in get_lower_precision_fp_from_device_type");
  }
}

/********************************************************************
Logic to extract the promote type from any Tensor or TensorList args.
********************************************************************/

// Overload to catch Tensor args.
// If nextArg is floating-point, compare its scalar_type with our
// current best guess for the promote type, and update if necessary.
inline at::ScalarType prioritize(
    at::ScalarType current,
    const Tensor& nextArg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  if (current == at::kDouble) {
    TORCH_CHECK(false, "promote type is double in at::autocast::prioritize");
    return current;
  }
  at::ScalarType lower_precision_fp =
      get_lower_precision_fp_from_device_type(device_type);
  if (is_autocast_eligible(nextArg, device_type)) {
    auto next = nextArg.scalar_type();
    if (next == at::kDouble) {
      return current; // ignores double tensors
    } else if (current == at::kFloat || next == at::kFloat) {
      return at::kFloat; // prioritizes float over lower_precision_fp
    } else if (current == lower_precision_fp && next == lower_precision_fp) {
      return lower_precision_fp;
    } else {
      TORCH_CHECK(
          false, "Unexpected floating ScalarType in at::autocast::prioritize");
      return current;
    }
  } else {
    return current;
  }
}

// Overload to catch TensorList args (for e.g. cat, stack).
// Reuses the overload above to process each Tensor in the list.
inline at::ScalarType prioritize(
    at::ScalarType current,
    const TensorList& list,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  for (const auto& tensor : list) {
    current = prioritize(current, tensor, device_type);
  }
  return current;
}

inline at::ScalarType prioritize(
    at::ScalarType current,
    const ITensorListRef& list,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  for (const auto& tensor : list) {
    current = prioritize(current, tensor, device_type);
  }
  return current;
}

// Template to catch non-Tensor args (no-op that returns current best guess)
template <typename T>
inline at::ScalarType prioritize(
    at::ScalarType current,
    T nextArg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  return current;
}

// Overload for the tail case.
inline at::ScalarType promote_type(
    at::ScalarType current,
    c10::DeviceType device_type) {
  return current;
}

// Unpack args and determine if incoming lower_precision_fp tensors need to be
// promoted to float32. Non-Tensor arguments are ignored.
template <typename Arg0, typename... Args>
inline at::ScalarType promote_type(
    at::ScalarType current,
    c10::DeviceType device_type,
    Arg0 arg0,
    Args... args) {
  auto new_current = prioritize(current, arg0, device_type);
  return promote_type(new_current, device_type, args...);
}
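
// Illustrative trace (a sketch, not part of the header): for CUDA autocast
// with lower_precision_fp == fp16, promoting over (fp16 tensor, fp32 tensor):
//
//   auto t = promote_type(at::kHalf, c10::DeviceType::CUDA, a_fp16, b_fp32);
//   // prioritize(kHalf, a_fp16) -> kHalf  (both are lower_precision_fp)
//   // prioritize(kHalf, b_fp32) -> kFloat (float wins over fp16)
//   // t == at::kFloat, so both inputs will be cast to fp32 before the op.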

/****************************************************
Logic to apply cached casting to any Tensor argument.
****************************************************/
inline bool is_eligible(
    const Tensor& arg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  return (
      arg.defined() && is_autocast_eligible(arg, device_type) &&
      (arg.scalar_type() != at::kDouble));
}

// Overload to catch Tensor args
TORCH_API Tensor cached_cast(
    at::ScalarType to_type,
    const Tensor& arg,
    c10::DeviceType device_type = c10::DeviceType::CUDA);

// Overload to process std::optional<Tensor>
inline std::optional<Tensor> cached_cast(
    at::ScalarType to_type,
    const std::optional<Tensor>& arg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  if (arg.has_value()) {
    return cached_cast(to_type, *arg, device_type);
  } else {
    return std::nullopt;
  }
}

// Overload to process TensorLists
inline std::vector<Tensor> cached_cast(
    at::ScalarType to_type,
    const TensorList& arg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  std::vector<Tensor> vec;
  vec.reserve(arg.size());
  for (const auto& t : arg) {
    vec.emplace_back(cached_cast(to_type, t, device_type));
  }
  return vec;
}

inline std::vector<Tensor> cached_cast(
    at::ScalarType to_type,
    const ITensorListRef& arg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  std::vector<Tensor> vec;
  vec.reserve(arg.size());
  for (const auto& t : arg) {
    vec.emplace_back(cached_cast(to_type, t, device_type));
  }
  return vec;
}

// Template to catch non-Tensor args.
template <typename T>
inline T cached_cast(
    at::ScalarType to_type,
    T arg,
    c10::DeviceType device_type = c10::DeviceType::CUDA) {
  return arg;
}
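
// Illustrative sketch of why the cast is "cached" (assumed behavior of the
// out-of-line Tensor overload declared above; see the autocast implementation
// file for the real logic): within one autocast-enabled region, casting the
// same fp32 parameter to the lower-precision dtype twice should reuse the
// first result rather than re-casting:
//
//   auto w_half_1 = cached_cast(at::kHalf, weight); // performs the cast
//   auto w_half_2 = cached_cast(at::kHalf, weight); // hits the cache
//
// Non-Tensor args fall through the template overload unchanged.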

/*******************************************************
Logic to flip an output dtype flag.
Keep it simple for now by assuming only one such flag is
present in the argument list. If I ever need a function
with more than one flag I'll figure out something else.

The policy is:
If the user has explicitly specified a dtype, respect it.
Otherwise, set it to the autocast type.
********************************************************/

// Overload to catch dtype flags
std::optional<ScalarType> inline set_opt_dtype(
    at::ScalarType to_type,
    const std::optional<ScalarType>& dtype) {
  return dtype.has_value() ? dtype : to_type;
}

// Template to catch other args
template <typename T>
inline T set_opt_dtype(at::ScalarType to_type, T arg) {
  return arg;
}
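
// Illustrative sketch (not part of the API): for an op like softmax with a
// std::optional<ScalarType> dtype argument,
//
//   set_opt_dtype(at::kFloat, std::optional<ScalarType>{});          // -> kFloat
//   set_opt_dtype(at::kFloat, std::optional<ScalarType>{at::kHalf}); // -> kHalf
//
// i.e. a user-specified dtype is respected; an unset one becomes fp32.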

template <typename... Args>
inline bool firstarg_is_eligible(
    c10::DeviceType device_type,
    const Tensor& arg,
    Args... args) {
  return is_eligible(arg, device_type);
}

template <typename... Args>
inline at::ScalarType type_from_firstarg(
    c10::DeviceType device_type,
    at::ScalarType to_type,
    const Tensor& arg,
    Args... args) {
  return (is_eligible(arg, device_type) ? to_type : arg.scalar_type());
}

// Policies correspond to op categories that need code-divergent handling.
// Wrapper templates below are specialized based on a policy template parameter.
enum class CastPolicy : uint8_t {
  lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before
                          // running the op. Currently, lower_precision_fp is
                          // fp16 for AutocastCUDA, and is defined by user
                          // (default bf16) for AutocastCPU or other device.
  fp32, // Cast all inputs to at::kFloat before running the op.
  fp32_set_opt_dtype, // Treats functions (like softmax) that
                      // 1. we'd like to run in fp32 and
                      // 2. have a std::optional<ScalarType> arg that controls
                      //    the output type.
                      // fp32_set_opt_dtype wrappers' policy is: if the output
                      // type is already set, don't touch it, otherwise, set
                      // it to at::kFloat.
  fp32_append_dtype, // Treats functions (like norm) that
                     // 1. we'd like to run in fp32 and
                     // 2. have some overloads that accept an output type and
                     //    other overloads that don't.
                     // fp32_append_dtype wrappers wrap the overloads that
                     // don't have an output dtype.
                     // The wrapper policy is: append at::kFloat to the args,
                     // and redispatch to the type-aware overload.
  promote, // Run in the widest dtype among several args.
};
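
// Representative ops per policy, as registered via the AT_FORALL_* lists at
// the bottom of this file:
//   lower_precision_fp -> matmul-like and convolution ops (mm, conv2d, ...)
//   fp32               -> numerically sensitive ops (log, pow, losses, ...)
//   fp32_set_opt_dtype -> reductions with a dtype flag (softmax, sum, ...)
//   fp32_append_dtype  -> norm overloads without a dtype argument
//   promote            -> multi-input ops run in the widest dtype (addcmul, ...)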

/********************************************************************************************************
Templates to provide wrapper functions

I'm copying the pattern used in core/boxing/impl/WrapFunctionIntoFunctor.h to
extract args and return type. (see also
https://stackoverflow.com/questions/46533698/how-to-deduce-argument-list-from-function-pointer)

This strategy uses an exterior "WrapFunction" that extracts arguments on behalf
of (in my case several specializations of) an interior "WrapFunction_".
Interior WrapFunction_ specializations are defined for each CastPolicy.
********************************************************************************************************/

// Base template for WrapFunction_, which is specialized to contain a "call"
// method for each CastPolicy
template <
    CastPolicy policy,
    c10::DeviceType device_type,
    class Redispatch,
    Redispatch* F,
    class Ret,
    class ArgList>
struct WrapFunction_ {};

// CastPolicy::lower_precision_fp General_DeviceType
template <
    c10::DeviceType device_type,
    class Redispatch,
    Redispatch* F,
    class Ret,
    class... Args>
struct WrapFunction_<
    CastPolicy::lower_precision_fp,
    device_type,
    Redispatch,
    F,
    Ret,
    guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(
        get_autocast_dispatch_key_from_device_type(device_type));
    return (*F)(cached_cast(
        get_lower_precision_fp_from_device_type(device_type),
        args,
        device_type)...);
  }
};
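
// Call-flow sketch for the wrapper above (illustrative): when an autocast
// kernel registered through this specialization runs,
//   1. ExcludeDispatchKeyGuard masks the autocast dispatch key so the
//      redispatch below cannot re-enter this wrapper;
//   2. each argument is expanded through cached_cast(...), which casts
//      eligible Tensors to the device's lower-precision dtype and passes
//      everything else through unchanged;
//   3. (*F)(...) redispatches to the next kernel (e.g. autograd) with the
//      casted arguments, so inputs are saved for backward in
//      lower_precision_fp.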

// CastPolicy::fp32 General_DeviceType
template <
    c10::DeviceType device_type,
    class Redispatch,
    Redispatch* F,
    class Ret,
    class... Args>
struct WrapFunction_<
    CastPolicy::fp32,
    device_type,
    Redispatch,
    F,
    Ret,
    guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(
        get_autocast_dispatch_key_from_device_type(device_type));
    return (*F)(cached_cast(at::kFloat, args, device_type)...);
  }
};

// CastPolicy::fp32_set_opt_dtype General_DeviceType
template <
    c10::DeviceType device_type,
    class Redispatch,
    Redispatch* F,
    class Ret,
    class... Args>
struct WrapFunction_<
    CastPolicy::fp32_set_opt_dtype,
    device_type,
    Redispatch,
    F,
    Ret,
    guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(
        get_autocast_dispatch_key_from_device_type(device_type));
    if (firstarg_is_eligible(device_type, args...)) {
      return (*F)(set_opt_dtype(at::kFloat, args)...);
    } else {
      // If ineligible, calls F with unaltered args. Does not set opt dtype,
      // because setting opt dtype explicitly may interfere with internal
      // implicit promotion decisions.
      return (*F)(args...);
    }
  }
};

// CastPolicy::fp32_append_dtype General_DeviceType
template <
    c10::DeviceType device_type,
    class Redispatch,
    Redispatch* F,
    class Ret,
    class... Args>
struct WrapFunction_<
    CastPolicy::fp32_append_dtype,
    device_type,
    Redispatch,
    F,
    Ret,
    guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(
        get_autocast_dispatch_key_from_device_type(device_type));
    at::ScalarType out_type =
        type_from_firstarg(device_type, at::kFloat, args...);
    return (*F)(args..., out_type);
  }
};

// CastPolicy::promote General_DeviceType
template <
    c10::DeviceType device_type,
    class Redispatch,
    Redispatch* F,
    class Ret,
    class... Args>
struct WrapFunction_<
    CastPolicy::promote,
    device_type,
    Redispatch,
    F,
    Ret,
    guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(
        get_autocast_dispatch_key_from_device_type(device_type));
    auto to_type = promote_type(
        get_lower_precision_fp_from_device_type(device_type),
        device_type,
        args...);
    return (*F)(cached_cast(to_type, args, device_type)...);
  }
};

// Wrapper to infer return_type and parameter_types for WrapFunction_ (imitating
// core/boxing/impl/WrapFunctionIntoFunctor.h)
template <
    CastPolicy policy,
    c10::DeviceType device_type,
    class Registered, // The signature for which we're registering. The
                      // dispatcher's calling code invokes our registered
                      // functions with arguments matching Registered, so we
                      // register WrapFunction_::call methods with a matching
                      // signature to properly field those arguments.
                      // guts::function_traits below extracts return_type and
                      // parameter_types from Registered, which WrapFunction_
                      // templates above use to declare their call methods.
    class Redispatch, // The signature for the function we're redispatching to.
                      // In most cases this is the same as Registered, but for
                      // some ops (for example, ops where we append a dtype)
                      // it's useful to redispatch to a function with a
                      // different signature.
    Redispatch* F> // The actual function we're redispatching to.
struct WrapFunction final {
  using type = WrapFunction_<
      policy,
      device_type,
      Redispatch,
      F,
      typename guts::function_traits<Registered>::return_type,
      typename guts::function_traits<Registered>::parameter_types>;
};
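
// Illustrative instantiation (a sketch; the KERNEL* macros below generate
// this shape): wrapping at::mm for CUDA under the lower_precision_fp policy:
//
//   using MmWrap = WrapFunction<
//       CastPolicy::lower_precision_fp,
//       c10::DeviceType::CUDA,
//       Tensor(const Tensor&, const Tensor&), // Registered signature
//       Tensor(const Tensor&, const Tensor&), // Redispatch signature
//       &at::mm>;
//   // MmWrap::type::call is the function pointer handed to m.impl(...).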

/*****************************************************************************************************************
This section performs load-time registration for autocast wrappers.

It's debatable at what level operations should be patched. We'd like casts to
be autograd-exposed and precede autograd history recording, so that for
lower_precision_fp ops, input tensors are saved for backward in
lower_precision_fp rather than fp32. Saving inputs in lower_precision_fp
can significantly reduce a model's memory footprint.

Option 1 (strawman): Patch only at the level of explicit calls into
cudnn/cublas (cudnn_convolution, etc), because those are the code paths that are
guaranteed to use Tensor Cores, therefore they're the ones that will benefit
most from lower_precision_fp. Potential pitfall: convolutions (and other ops)
are wrapped in several layers of at::* calls. If one of those happens to record
autograd history, then we've lost the opportunity to save inputs in
lower_precision_fp.

Option 2: Patch the Python-exposed surface of calls, to make 100% sure autograd
history recording can't sneak in ahead of autocast. This mirrors Apex most
closely.

I think Option 2 is the right answer for all ops, not just convolutions. Option
2 is what I implement here.
*****************************************************************************************************************/

/********************************************************************************************************************
Explicit registration for out-of-place ops

The stuff below could be codegenned. Ed said
> you are going to have to write the function definition at some point, I
> wouldn't try to get clever about it
Therefore, for the moment, this is all copy pasted in from
VariableTypeEverything.cpp with appropriate substitutions.
********************************************************************************************************************/

} // namespace at::autocast

#define ADD_NS(RAW_OP) at::RAW_OP

#define _KERNEL_OVERLOAD_NARG_IMPL(_0, _1, _2, N, ...) N
#define _KERNEL_OVERLOAD_NARG(...) \
  C10_EXPAND_MSVC_WORKAROUND(_KERNEL_OVERLOAD_NARG_IMPL(__VA_ARGS__, 2, 1))

// Common cases where registration signature matches redispatch signature
// (that's why SIGNATURE is repeated in the WrapFunction instantiation)
#define KERNEL1(DISPATCHKEY, OP, POLICY)      \
  m.impl(                                     \
      TORCH_SELECTIVE_NAME("aten::" #OP),     \
      &::at::autocast::WrapFunction<          \
          ::at::autocast::CastPolicy::POLICY, \
          DISPATCHKEY,                        \
          decltype(ATEN_FN(OP)),              \
          decltype(ATEN_FN(OP)),              \
          &ATEN_FN(OP)>::type::call);

#define KERNEL2(DISPATCHKEY, OP, OVERLOAD, POLICY)      \
  m.impl(                                               \
      TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \
      &::at::autocast::WrapFunction<                    \
          ::at::autocast::CastPolicy::POLICY,           \
          DISPATCHKEY,                                  \
          decltype(ATEN_FN2(OP, OVERLOAD)),             \
          decltype(ATEN_FN2(OP, OVERLOAD)),             \
          &ATEN_FN2(OP, OVERLOAD)>::type::call);

#define _KERNEL_DISPATCH(DISPATCHKEY, NARG, ...) \
  C10_CONCATENATE(KERNEL, NARG)(DISPATCHKEY, __VA_ARGS__)

#define _KERNEL_IMPL(DISPATCHKEY, ...) \
  _KERNEL_DISPATCH(DISPATCHKEY, _KERNEL_OVERLOAD_NARG(__VA_ARGS__), __VA_ARGS__)

// Dispatches to KERNEL1 or KERNEL2 based on the number of trailing arguments
// ((OP, POLICY) vs. (OP, OVERLOAD, POLICY)).
#define KERNEL(DISPATCHKEY, ...) _KERNEL_IMPL(DISPATCHKEY, __VA_ARGS__)
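
// Illustrative expansion (a sketch): inside a TORCH_LIBRARY_IMPL block,
//
//   KERNEL(c10::DeviceType::CUDA, mm, lower_precision_fp)
//
// counts two trailing arguments, so it selects KERNEL1 and expands to an
// m.impl("aten::mm", ...) registration of the WrapFunction wrapper shown
// earlier; KERNEL(c10::DeviceType::CUDA, pow, Tensor_Scalar, fp32) counts
// three, selects KERNEL2, and registers "aten::pow.Tensor_Scalar".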

// Less-common but still useful case: redispatching to a function
// with a new signature (e.g. appending a dtype)
#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(      \
    DISPATCHKEY,                                    \
    REDISPATCH_FUNC,                                \
    REGISTER_NAME,                                  \
    REGISTER_SIGNATURE,                             \
    REDISPATCH_SIGNATURE,                           \
    POLICY)                                         \
  m.impl(                                           \
      TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
      &::at::autocast::WrapFunction<                \
          ::at::autocast::CastPolicy::POLICY,       \
          DISPATCHKEY,                              \
          REGISTER_SIGNATURE,                       \
          REDISPATCH_SIGNATURE,                     \
          &REDISPATCH_FUNC>::type::call);

// KERNEL_CPU/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CPU
// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastCPU
#define KERNEL_CPU(...) KERNEL(c10::DeviceType::CPU, __VA_ARGS__)

#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CPU( \
    REDISPATCH_FUNC,                               \
    REGISTER_NAME,                                 \
    REGISTER_SIGNATURE,                            \
    REDISPATCH_SIGNATURE,                          \
    POLICY)                                        \
  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(           \
      c10::DeviceType::CPU,                        \
      REDISPATCH_FUNC,                             \
      REGISTER_NAME,                               \
      REGISTER_SIGNATURE,                          \
      REDISPATCH_SIGNATURE,                        \
      POLICY)

// KERNEL_CUDA/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA
// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastCUDA
#define KERNEL_CUDA(...) KERNEL(c10::DeviceType::CUDA, __VA_ARGS__)

#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA( \
    REDISPATCH_FUNC,                                \
    REGISTER_NAME,                                  \
    REGISTER_SIGNATURE,                             \
    REDISPATCH_SIGNATURE,                           \
    POLICY)                                         \
  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(            \
      c10::DeviceType::CUDA,                        \
      REDISPATCH_FUNC,                              \
      REGISTER_NAME,                                \
      REGISTER_SIGNATURE,                           \
      REDISPATCH_SIGNATURE,                         \
      POLICY)

// KERNEL_MTIA/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MTIA
// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastMTIA
#define KERNEL_MTIA(...) KERNEL(c10::DeviceType::MTIA, __VA_ARGS__)

#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MTIA( \
    REDISPATCH_FUNC,                                \
    REGISTER_NAME,                                  \
    REGISTER_SIGNATURE,                             \
    REDISPATCH_SIGNATURE,                           \
    POLICY)                                         \
  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(            \
      c10::DeviceType::MTIA,                        \
      REDISPATCH_FUNC,                              \
      REGISTER_NAME,                                \
      REGISTER_SIGNATURE,                           \
      REDISPATCH_SIGNATURE,                         \
      POLICY)

// KERNEL_XPU/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU
// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastXPU
#define KERNEL_XPU(...) KERNEL(c10::DeviceType::XPU, __VA_ARGS__)

#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU( \
    REDISPATCH_FUNC,                               \
    REGISTER_NAME,                                 \
    REGISTER_SIGNATURE,                            \
    REDISPATCH_SIGNATURE,                          \
    POLICY)                                        \
  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(           \
      c10::DeviceType::XPU,                        \
      REDISPATCH_FUNC,                             \
      REGISTER_NAME,                               \
      REGISTER_SIGNATURE,                          \
      REDISPATCH_SIGNATURE,                        \
      POLICY)

// KERNEL_PRIVATEUSEONE/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE
// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastPrivateUse1
#define KERNEL_PRIVATEUSEONE(...) \
  KERNEL(c10::DeviceType::PrivateUse1, __VA_ARGS__)

#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE( \
    REDISPATCH_FUNC,                                         \
    REGISTER_NAME,                                           \
    REGISTER_SIGNATURE,                                      \
    REDISPATCH_SIGNATURE,                                    \
    POLICY)                                                  \
  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(                     \
      c10::DeviceType::PrivateUse1,                          \
      REDISPATCH_FUNC,                                       \
      REGISTER_NAME,                                         \
      REGISTER_SIGNATURE,                                    \
      REDISPATCH_SIGNATURE,                                  \
      POLICY)

// KERNEL_MPS
// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastMPS
#define KERNEL_MPS(...) KERNEL(c10::DeviceType::MPS, __VA_ARGS__)
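
// Minimal registration sketch (illustrative; real registrations live in the
// autocast implementation files, and the exact dispatch key names are
// assumptions here):
//
//   TORCH_LIBRARY_IMPL(aten, Autocast, m) {
//     KERNEL_CUDA(mm, lower_precision_fp)
//     KERNEL_CUDA(pow, Tensor_Scalar, fp32)
//     KERNEL_CUDA(softmax, int, fp32_set_opt_dtype)
//     KERNEL_CUDA(addcmul, promote)
//   }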

// Op lists for different policies.
// These are macros so that other backends can reuse the same policy op lists.
#define AT_FORALL_LOWER_PRECISION_FP(_)  \
  _(_convolution, deprecated)            \
  _(_convolution)                        \
  _(conv1d)                              \
  _(conv2d)                              \
  _(conv3d)                              \
  _(conv_tbc)                            \
  _(conv_transpose1d)                    \
  _(conv_transpose2d, input)             \
  _(conv_transpose3d, input)             \
  _(convolution)                         \
  _(prelu)                               \
  _(addmm)                               \
  _(addmv)                               \
  _(addr)                                \
  _(matmul)                              \
  _(einsum)                              \
  _(mm)                                  \
  _(mv)                                  \
  _(linalg_vecdot)                       \
  _(linear)                              \
  _(addbmm)                              \
  _(baddbmm)                             \
  _(bmm)                                 \
  _(chain_matmul)                        \
  _(linalg_multi_dot)                    \
  _(_thnn_fused_lstm_cell)               \
  _(_thnn_fused_gru_cell)                \
  _(lstm_cell)                           \
  _(gru_cell)                            \
  _(rnn_tanh_cell)                       \
  _(rnn_relu_cell)                       \
  _(_scaled_dot_product_flash_attention) \
  _(scaled_dot_product_attention)
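
// Illustrative reuse sketch (hypothetical helper macro, not defined here):
// a backend can stamp a whole policy list into its registration block, e.g.
//
//   #define _KERNEL_CPU_LOW_PRECISION_FP(...) \
//     KERNEL_CPU(__VA_ARGS__, lower_precision_fp)
//
//   TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
//     AT_FORALL_LOWER_PRECISION_FP(_KERNEL_CPU_LOW_PRECISION_FP)
//   }
//
// Each `_(op)` entry becomes a KERNEL1 registration and each
// `_(op, overload)` entry becomes a KERNEL2 registration.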

#define AT_FORALL_FP32(_)               \
  _(acos)                               \
  _(asin)                               \
  _(cosh)                               \
  _(erfinv)                             \
  _(exp)                                \
  _(expm1)                              \
  _(log)                                \
  _(log10)                              \
  _(log2)                               \
  _(log1p)                              \
  _(reciprocal)                         \
  _(rsqrt)                              \
  _(sinh)                               \
  _(tan)                                \
  _(pow, Tensor_Scalar)                 \
  _(pow, Tensor_Tensor)                 \
  _(pow, Scalar)                        \
  _(softplus)                           \
  _(layer_norm)                         \
  _(native_layer_norm)                  \
  _(group_norm)                         \
  _(frobenius_norm, dim)                \
  _(nuclear_norm)                       \
  _(nuclear_norm, dim)                  \
  _(cosine_similarity)                  \
  _(poisson_nll_loss)                   \
  _(cosine_embedding_loss)              \
  _(nll_loss)                           \
  _(nll_loss2d)                         \
  _(hinge_embedding_loss)               \
  _(kl_div)                             \
  _(l1_loss)                            \
  _(smooth_l1_loss)                     \
  _(huber_loss)                         \
  _(mse_loss)                           \
  _(margin_ranking_loss)                \
  _(multilabel_margin_loss)             \
  _(soft_margin_loss)                   \
  _(triplet_margin_loss)                \
  _(multi_margin_loss)                  \
  _(binary_cross_entropy_with_logits)   \
  _(dist)                               \
  _(pdist)                              \
  _(cdist)                              \
  _(renorm)                             \
  _(logsumexp)                          \
  _(upsample_nearest1d)                 \
  _(_upsample_nearest_exact1d)          \
  _(upsample_nearest2d)                 \
  _(_upsample_nearest_exact2d)          \
  _(upsample_nearest3d)                 \
  _(_upsample_nearest_exact3d)          \
  _(upsample_linear1d)                  \
  _(upsample_bilinear2d)                \
  _(_upsample_bilinear2d_aa)            \
  _(upsample_trilinear3d)               \
  _(upsample_bicubic2d)                 \
  _(_upsample_bicubic2d_aa)

#define AT_FORALL_FP32_SET_OPT_DTYPE(_) \
  _(prod)                               \
  _(prod, dim_int)                      \
  _(prod, dim_Dimname)                  \
  _(softmax, int)                       \
  _(softmax, Dimname)                   \
  _(log_softmax, int)                   \
  _(log_softmax, Dimname)               \
  _(cumprod)                            \
  _(cumprod, dimname)                   \
  _(cumsum)                             \
  _(cumsum, dimname)                    \
  _(linalg_vector_norm)                 \
  _(linalg_matrix_norm)                 \
  _(linalg_matrix_norm, str_ord)        \
  _(sum)                                \
  _(sum, dim_IntList)                   \
  _(sum, dim_DimnameList)

#define AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(_)                         \
  _(ADD_NS(norm),                                                           \
    "norm.Scalar",                                                          \
    Tensor(const Tensor&, const Scalar&),                                   \
    Tensor(const Tensor&, const std::optional<Scalar>&, ScalarType),        \
    fp32_append_dtype)                                                      \
  _(ADD_NS(norm),                                                           \
    "norm.ScalarOpt_dim",                                                   \
    Tensor(const Tensor&, const std::optional<Scalar>&, IntArrayRef, bool), \
    Tensor(                                                                 \
        const Tensor&,                                                      \
        const std::optional<Scalar>&,                                       \
        IntArrayRef,                                                        \
        bool,                                                               \
        ScalarType),                                                        \
    fp32_append_dtype)                                                      \
  _(ADD_NS(norm),                                                           \
    "norm.names_ScalarOpt_dim",                                             \
    Tensor(const Tensor&, const std::optional<Scalar>&, DimnameList, bool), \
    Tensor(                                                                 \
        const Tensor&,                                                      \
        const std::optional<Scalar>&,                                       \
        DimnameList,                                                        \
        bool,                                                               \
        ScalarType),                                                        \
    fp32_append_dtype)

#define AT_FORALL_PROMOTE(_) \
  _(addcdiv)                 \
  _(addcmul)                 \
  _(atan2)                   \
  _(bilinear)                \
  _(cross)                   \
  _(dot)                     \
  _(vdot)                    \
  _(grid_sampler)            \
  _(index_put)               \
  _(tensordot)               \
  _(scatter_add)