// Metal helper functions
#pragma once
#include <metal_stdlib>

namespace c10 {
namespace metal {

namespace detail {

// Maps a scalar element type T to its Metal vector types with 4, 3 and 2
// elements (type4 / type3 / type2). The primary template is empty, so looking
// up the vector type of an unsupported scalar fails to compile.
template <typename T>
struct vectypes {};

template <>
struct vectypes<float> {
  using type4 = float4;
  using type3 = float3;
  using type2 = float2;
};

template <>
struct vectypes<half> {
  using type4 = half4;
  using type3 = half3;
  using type2 = half2;
};

// bfloat support is only compiled for Metal 3.1+ (__METAL_VERSION__ >= 310).
#if __METAL_VERSION__ >= 310
template <>
struct vectypes<bfloat> {
  using type4 = bfloat4;
  using type3 = bfloat3;
  using type2 = bfloat2;
};
#endif

template <>
struct vectypes<short> {
  using type4 = short4;
  using type3 = short3;
  using type2 = short2;
};

template <>
struct vectypes<int> {
  using type4 = int4;
  using type3 = int3;
  using type2 = int2;
};

// NOTE(review): the scalar type of this specialization was lost when the file
// was mangled; `long` is a reconstruction. The mapping to short vectors is kept
// exactly as found, but it looks like a possible copy-paste slip (a 64-bit
// scalar would not normally map to short vectors) -- confirm the intent.
template <>
struct vectypes<long> {
  using type4 = short4;
  using type3 = short3;
  using type2 = short2;
};

// Selects the type used for intermediate arithmetic on values of type T.
// By default T itself is used; the specializations below widen half/bfloat
// to float and short/char/uchar to int.
template <typename T>
struct OpMathType {
  using type = T;
};

template <>
struct OpMathType<half> {
  using type = float;
};

template <>
struct OpMathType<short> {
  using type = int;
};

template <>
struct OpMathType<char> {
  using type = int;
};

template <>
struct OpMathType<uchar> {
  using type = int;
};

#if __METAL_VERSION__ >= 310
template <>
struct OpMathType<bfloat> {
  using type = float;
};
#endif
} // namespace detail

// Maximum of two floating-point values that propagates NaN: if either operand
// is NaN (isunordered), NaN is returned regardless of how ::metal::max itself
// treats NaN operands.
template <typename T>
::metal::enable_if_t<::metal::is_floating_point_v<T>, T> max(T a, T b) {
  return ::metal::isunordered(a, b) ? NAN : ::metal::max(a, b);
}

// Maximum of two integral values; forwards to ::metal::max.
template <typename T>
::metal::enable_if_t<::metal::is_integral_v<T>, T> max(T a, T b) {
  return ::metal::max(a, b);
}

// Minimum of two floating-point values that propagates NaN: if either operand
// is NaN (isunordered), NaN is returned regardless of how ::metal::min itself
// treats NaN operands.
template <typename T>
::metal::enable_if_t<::metal::is_floating_point_v<T>, T> min(T a, T b) {
  return ::metal::isunordered(a, b) ? NAN : ::metal::min(a, b);
}

// Minimum of two integral values; forwards to ::metal::min.
template <typename T>
::metal::enable_if_t<::metal::is_integral_v<T>, T> min(T a, T b) {
  return ::metal::min(a, b);
}

#if __METAL_VERSION__ >= 310
// bfloat specializations: the comparison is performed on float copies of the
// operands and the result is converted back to bfloat; a NaN operand still
// yields NaN.
template <>
inline bfloat min(bfloat a, bfloat b) {
  return bfloat(
      ::metal::isunordered(a, b) ? NAN : ::metal::min(float(a), float(b)));
}

template <>
inline bfloat max(bfloat a, bfloat b) {
  return bfloat(
      ::metal::isunordered(a, b) ? NAN : ::metal::max(float(a), float(b)));
}
#endif

// 2- and 4-element vector types whose elements are of scalar type T.
template <typename T>
using vec2type_t = typename detail::vectypes<T>::type2;

template <typename T>
using vec4type_t = typename detail::vectypes<T>::type4;

// Type to use for intermediate arithmetic on values of type T.
template <typename T>
using opmath_t = typename detail::OpMathType<T>::type;

// Result type of calling a value of type F with arguments of types Args...
// TODO: Move it to a type_traits header maybe.
template <typename F, typename... Args>
using result_of = decltype(::metal::declval<F>()(::metal::declval<Args>()...));

// float2 and half2 are the types treated as complex numbers.
template <typename T>
constexpr constant bool is_complex_v =
    ::metal::is_same_v<T, float2> || ::metal::is_same_v<T, half2>;

// True when T is both a floating-point type and a scalar type.
template <typename T>
constexpr constant bool is_scalar_floating_point_v =
    ::metal::is_floating_point_v<T> && ::metal::is_scalar_v<T>;

// True when T is both an integral type and a scalar type.
template <typename T>
constexpr constant bool is_scalar_integral_v =
    ::metal::is_integral_v<T> && ::metal::is_scalar_v<T>;

} // namespace metal
} // namespace c10