team-10/env/Lib/site-packages/torch/include/c10/metal/utils.h
2025-08-02 07:34:44 +02:00

149 lines
3 KiB
C++

// Metal helper functions
#pragma once
#include <metal_stdlib>
namespace c10 {
namespace metal {
namespace detail {
// Trait mapping a scalar element type T to its vector types, exposed as the
// member aliases type2, type3 and type4.
template <typename T>
struct vectypes {
  // The primary template declares no aliases; only the explicit
  // specializations provide type2/type3/type4, so requesting them for an
  // unspecialized T is a compile-time error.
};
// float: 2-, 3- and 4-component float vectors.
template <>
struct vectypes<float> {
  using type2 = float2;
  using type3 = float3;
  using type4 = float4;
};
// half: 2-, 3- and 4-component half vectors.
template <>
struct vectypes<half> {
  using type2 = half2;
  using type3 = half3;
  using type4 = half4;
};
#if __METAL_VERSION__ >= 310
// bfloat: 2-, 3- and 4-component bfloat vectors. Only compiled when
// __METAL_VERSION__ >= 310 (see the enclosing preprocessor guard).
template <>
struct vectypes<bfloat> {
  using type2 = bfloat2;
  using type3 = bfloat3;
  using type4 = bfloat4;
};
#endif
// short: 2-, 3- and 4-component short vectors.
template <>
struct vectypes<short> {
  using type2 = short2;
  using type3 = short3;
  using type4 = short4;
};
// int: 2-, 3- and 4-component int vectors.
template <>
struct vectypes<int> {
  using type2 = int2;
  using type3 = int3;
  using type4 = int4;
};
template <>
struct vectypes<long> {
using type4 = short4;
using type3 = short3;
using type2 = short2;
};
// Trait selecting the type exposed as opmath_t<T>. By default it is T
// itself; the explicit specializations substitute a different type for
// specific element types.
template <typename T>
struct OpMathType {
  using type = T;
};
// half -> float.
// NOTE(review): presumably so that intermediate math on half values is
// carried out in float precision - confirm against callers of opmath_t.
template <>
struct OpMathType<half> {
  using type = float;
};
// short -> int.
template <>
struct OpMathType<short> {
  using type = int;
};
// char -> int.
template <>
struct OpMathType<char> {
  using type = int;
};
// uchar -> int.
template <>
struct OpMathType<uchar> {
  using type = int;
};
#if __METAL_VERSION__ >= 310
// bfloat -> float. Only compiled when __METAL_VERSION__ >= 310 (see the
// enclosing preprocessor guard).
template <>
struct OpMathType<bfloat> {
  using type = float;
};
#endif
} // namespace detail
// Maximum of two floating-point values of the same type T.
// If ::metal::isunordered(a, b) reports the pair as unordered (at least one
// operand is NaN) the result is NAN; otherwise it is ::metal::max(a, b).
// NOTE(review): presumably needed because ::metal::max on its own does not
// propagate NaN - confirm against the Metal specification.
template <typename T>
::metal::enable_if_t<::metal::is_floating_point_v<T>, T> max(T a, T b) {
  const auto unordered = ::metal::isunordered(a, b);
  return unordered ? NAN : ::metal::max(a, b);
}
// Maximum of two integral values of the same type T; forwards directly to
// ::metal::max since no NaN handling is required for integers.
template <typename T>
::metal::enable_if_t<::metal::is_integral_v<T>, T> max(T a, T b) {
  const T larger = ::metal::max(a, b);
  return larger;
}
// Minimum of two floating-point values of the same type T.
// If ::metal::isunordered(a, b) reports the pair as unordered (at least one
// operand is NaN) the result is NAN; otherwise it is ::metal::min(a, b).
// NOTE(review): presumably needed because ::metal::min on its own does not
// propagate NaN - confirm against the Metal specification.
template <typename T>
::metal::enable_if_t<::metal::is_floating_point_v<T>, T> min(T a, T b) {
  const auto unordered = ::metal::isunordered(a, b);
  return unordered ? NAN : ::metal::min(a, b);
}
// Minimum of two integral values of the same type T; forwards directly to
// ::metal::min since no NaN handling is required for integers.
template <typename T>
::metal::enable_if_t<::metal::is_integral_v<T>, T> min(T a, T b) {
  const T smaller = ::metal::min(a, b);
  return smaller;
}
#if __METAL_VERSION__ >= 310
// bfloat specialization of min (compiled only when __METAL_VERSION__ >= 310,
// see the enclosing preprocessor guard). Both operands are converted to
// float for the ::metal::min call and the float result is converted back to
// bfloat. An unordered pair (at least one NaN operand) yields NAN.
template <>
inline bfloat min(bfloat a, bfloat b) {
  const auto unordered = ::metal::isunordered(a, b);
  const float result = unordered ? NAN : ::metal::min(float(a), float(b));
  return bfloat(result);
}
// bfloat specialization of max (compiled only when __METAL_VERSION__ >= 310,
// see the enclosing preprocessor guard). Both operands are converted to
// float for the ::metal::max call and the float result is converted back to
// bfloat. An unordered pair (at least one NaN operand) yields NAN.
template <>
inline bfloat max(bfloat a, bfloat b) {
  const auto unordered = ::metal::isunordered(a, b);
  const float result = unordered ? NAN : ::metal::max(float(a), float(b));
  return bfloat(result);
}
#endif
// Shorthand for the 2-component vector type associated with scalar type T,
// i.e. detail::vectypes<T>::type2. Only usable for types that have a
// detail::vectypes specialization.
template <typename T>
using vec2type_t = typename detail::vectypes<T>::type2;
// Shorthand for the 4-component vector type associated with scalar type T,
// i.e. detail::vectypes<T>::type4. Only usable for types that have a
// detail::vectypes specialization.
template <typename T>
using vec4type_t = typename detail::vectypes<T>::type4;
// Shorthand for detail::OpMathType<T>::type: T itself unless a
// specialization substitutes another type (e.g. half -> float, short -> int).
template <typename T>
using opmath_t = typename detail::OpMathType<T>::type;
// Type produced by invoking a callable of type F with arguments of types
// Args..., computed in an unevaluated context via ::metal::declval.
// TODO: Maybe move this to a type_traits header.
template <typename F, typename... Args>
using result_of = decltype(::metal::declval<F>()(::metal::declval<Args>()...));
// True exactly when T is float2 or half2.
// NOTE(review): presumably these two-component vectors are how complex
// numbers are represented by callers - confirm at the use sites.
template <typename T>
constexpr constant bool is_complex_v =
    ::metal::is_same_v<T, half2> || ::metal::is_same_v<T, float2>;
// True when T satisfies both ::metal::is_scalar_v and
// ::metal::is_floating_point_v.
template <typename T>
constexpr constant bool is_scalar_floating_point_v =
    ::metal::is_scalar_v<T> && ::metal::is_floating_point_v<T>;
// True when T satisfies both ::metal::is_scalar_v and
// ::metal::is_integral_v.
template <typename T>
constexpr constant bool is_scalar_integral_v =
    ::metal::is_scalar_v<T> && ::metal::is_integral_v<T>;
} // namespace metal
} // namespace c10