team-10/venv/Lib/site-packages/torch/include/ATen/mps/IndexKernels.h
2025-08-02 02:00:33 +02:00

220 lines
9.3 KiB
C++

#pragma once
namespace at::mps {
// Metal shader source template for scatter kernels: copy a contiguous source
// buffer into a strided destination buffer, casting element types on the way.
// The string is a formatter template (note the doubled `{{`/`}}`, which are
// escaped literal braces); it is instantiated by substituting:
//   {0} - source element type name (appears as `constant {0} * src`)
//   {1} - destination element type name (appears as `device {1} * dst`)
//   {2} - the expression returned by the cast<{1},{0}> specialization,
//         converting `x` from {0} to {1}
// scatter_kernel_1..scatter_kernel_4 handle fixed tensor ranks with packed
// size/stride vectors; scatter_kernel_n handles arbitrary rank via size and
// stride arrays plus an explicit `ndim`, accumulating the destination offset
// from the innermost dimension outward.
static const char* SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
template<typename Y, typename X>
Y cast(const X x);
template<>
{1} cast<{1}, {0}>(const {0} x) {{
return {2};
}}
kernel void scatter_kernel_n(uint linear_index [[thread_position_in_grid]],
constant void * src_ [[buffer(0)]],
device void * dst_ [[buffer(1)]],
constant uint32_t * size [[buffer(2)]],
constant uint32_t * stride [[buffer(3)]],
constant uint32_t & numel [[buffer(4)]],
constant int32_t & ndim [[buffer(5)]]) {{
if (linear_index >= numel) return;
constant {0} * src = (constant {0} *)src_;
device {1} * dst = (device {1} *)dst_;
uint64_t dst_offs = 0;
auto dst_idx = linear_index;
for(int dim = ndim - 1; dim >= 0; --dim) {{
dst_offs += stride[dim] * (dst_idx % size[dim]);
dst_idx /= size[dim];
}}
dst[dst_offs] = cast<{1}>(src[linear_index]);
}}
kernel void scatter_kernel_4(uint linear_index [[thread_position_in_grid]],
constant void * src_ [[buffer(0)]],
device void * dst_ [[buffer(1)]],
constant packed_uint4 & size [[buffer(2)]],
constant packed_uint4 & stride [[buffer(3)]],
constant uint32_t & numel [[buffer(4)]]) {{
if (linear_index >= numel) return;
constant {0} * src = (constant {0} *)src_;
device {1} * dst = (device {1} *)dst_;
packed_uint4 local_index;
local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
local_index.y = linear_index / (size[3] * size[2]) % size[1];
local_index.z = linear_index / size[3] % size[2];
local_index.w = linear_index % size[3];
const packed_uint4 strided_index = local_index * stride;
dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w] = cast<{1}>(src[linear_index]);
}}
kernel void scatter_kernel_3(uint linear_index [[thread_position_in_grid]],
constant void * src_ [[buffer(0)]],
device void * dst_ [[buffer(1)]],
constant packed_uint3 & size [[buffer(2)]],
constant packed_uint3 & stride [[buffer(3)]],
constant uint32_t & numel [[buffer(4)]]) {{
if (linear_index >= numel) return;
constant {0} * src = (constant {0} *)src_;
device {1} * dst = (device {1} *)dst_;
packed_uint3 local_index;
local_index.x = linear_index / (size[2] * size[1]) % size[0];
local_index.y = linear_index / size[2] % size[1];
local_index.z = linear_index % size[2];
const packed_uint3 strided_index = local_index * stride;
dst[strided_index.x + strided_index.y + strided_index.z] = cast<{1}>(src[linear_index]);
}}
kernel void scatter_kernel_2(uint linear_index [[thread_position_in_grid]],
constant void * src_ [[buffer(0)]],
device void * dst_ [[buffer(1)]],
constant packed_uint2 & size [[buffer(2)]],
constant packed_uint2 & stride [[buffer(3)]],
constant uint32_t & numel [[buffer(4)]]) {{
if (linear_index >= numel) return;
constant {0} * src = (constant {0} *)src_;
device {1} * dst = (device {1} *)dst_;
packed_uint2 local_index;
local_index.x = linear_index / size[1] % size[0];
local_index.y = linear_index % size[1];
const packed_uint2 strided_index = local_index * stride;
dst[strided_index.x + strided_index.y] = cast<{1}>(src[linear_index]);
}}
kernel void scatter_kernel_1(uint linear_index [[thread_position_in_grid]],
constant void * src_ [[buffer(0)]],
device void * dst_ [[buffer(1)]],
constant int & size [[buffer(2)]],
constant int & stride [[buffer(3)]],
constant uint32_t & numel [[buffer(4)]]) {{
if (linear_index >= numel) return;
constant {0} * src = (constant {0} *)src_;
device {1} * dst = (device {1} *)dst_;
const int local_index = linear_index % size;
const int strided_index = local_index * stride;
dst[strided_index] = cast<{1}>(src[linear_index]);
}}
)METAL_SCATTER";
// Metal shader source template for gather kernels: the mirror of the scatter
// template above — read from a strided source buffer and write a contiguous
// destination buffer, casting element types on the way. Same formatter
// placeholders (doubled `{{`/`}}` are escaped literal braces):
//   {0} - source element type name (appears as `constant {0} * src`)
//   {1} - destination element type name (appears as `device {1} * dst`)
//   {2} - the expression returned by the cast<{1},{0}> specialization,
//         converting `x` from {0} to {1}
// gather_kernel_1..gather_kernel_4 handle fixed tensor ranks with packed
// size/stride vectors; gather_kernel_n handles arbitrary rank via size and
// stride arrays plus an explicit `ndim`, accumulating the source offset
// from the innermost dimension outward.
static const char* GATHER_OPS_TEMPLATE = R"METAL_GATHER(
template<typename Y, typename X>
Y cast(const X x);
template<>
{1} cast<{1}, {0}>(const {0} x) {{
return {2};
}}
kernel void gather_kernel_n(uint linear_index [[thread_position_in_grid]],
constant void * src_ [[buffer(0)]],
device void * dst_ [[buffer(1)]],
constant uint32_t * size [[buffer(2)]],
constant uint32_t * stride [[buffer(3)]],
constant uint32_t & numel [[buffer(4)]],
constant int32_t & ndim [[buffer(5)]]) {{
if (linear_index >= numel) return;
constant {0} * src = (constant {0} *)src_;
device {1} * dst = (device {1} *)dst_;
uint64_t src_offs = 0;
auto src_idx = linear_index;
for(int dim = ndim - 1; dim >= 0; --dim) {{
src_offs += stride[dim] * (src_idx % size[dim]);
src_idx /= size[dim];
}}
dst[linear_index] = cast<{1}>(src[src_offs]);
}}
kernel void gather_kernel_4(uint linear_index [[thread_position_in_grid]],
constant void * src_ [[buffer(0)]],
device void * dst_ [[buffer(1)]],
constant packed_uint4 & size [[buffer(2)]],
constant packed_uint4 & stride [[buffer(3)]],
constant uint32_t & numel [[buffer(4)]]) {{
if (linear_index >= numel) return;
constant {0} * src = (constant {0} *)src_;
device {1} * dst = (device {1} *)dst_;
packed_uint4 local_index;
local_index.x = linear_index / (size[3] * size[2] * size[1]) % size[0];
local_index.y = linear_index / (size[3] * size[2]) % size[1];
local_index.z = linear_index / size[3] % size[2];
local_index.w = linear_index % size[3];
const packed_uint4 strided_index = local_index * stride;
dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w]);
}}
kernel void gather_kernel_3(uint linear_index [[thread_position_in_grid]],
constant void * src_ [[buffer(0)]],
device void * dst_ [[buffer(1)]],
constant packed_uint3 & size [[buffer(2)]],
constant packed_uint3 & stride [[buffer(3)]],
constant uint32_t & numel [[buffer(4)]]) {{
if (linear_index >= numel) return;
constant {0} * src = (constant {0} *)src_;
device {1} * dst = (device {1} *)dst_;
packed_uint3 local_index;
local_index.x = linear_index / (size[2] * size[1]) % size[0];
local_index.y = linear_index / size[2] % size[1];
local_index.z = linear_index % size[2];
const packed_uint3 strided_index = local_index * stride;
dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z]);
}}
kernel void gather_kernel_2(uint linear_index [[thread_position_in_grid]],
constant void * src_ [[buffer(0)]],
device void * dst_ [[buffer(1)]],
constant packed_uint2 & size [[buffer(2)]],
constant packed_uint2 & stride [[buffer(3)]],
constant uint32_t & numel [[buffer(4)]]) {{
if (linear_index >= numel) return;
constant {0} * src = (constant {0} *)src_;
device {1} * dst = (device {1} *)dst_;
packed_uint2 local_index;
local_index.x = linear_index / size[1] % size[0];
local_index.y = linear_index % size[1];
const packed_uint2 strided_index = local_index * stride;
dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y]);
}}
kernel void gather_kernel_1(uint linear_index [[thread_position_in_grid]],
constant void * src_ [[buffer(0)]],
device void * dst_ [[buffer(1)]],
constant int & size [[buffer(2)]],
constant int & stride [[buffer(3)]],
constant uint32_t & numel [[buffer(4)]]) {{
if (linear_index >= numel) return;
constant {0} * src = (constant {0} *)src_;
device {1} * dst = (device {1} *)dst_;
const int local_index = linear_index % size;
const int strided_index = local_index * stride;
dst[linear_index] = cast<{1}>(src[strided_index]);
}}
)METAL_GATHER";
} // namespace at::mps