#pragma once

#include <ATen/core/TensorAccessor.h>
#include <ATen/NumericUtils.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <utility>
#include <vector>

namespace at::native {

#ifdef CPU_CAPABILITY
inline namespace CPU_CAPABILITY {
#else
inline namespace DEFAULT {
#endif

// Core topk loop, shared between CPU and QuantizedCPU
template <typename scalar_t, typename accscalar_t>
void topk_impl_loop(
    const int64_t mode_values_stride,
    const int64_t mode_indices_stride,
    const int64_t tmp_values_stride,
    const int64_t k,
    const int64_t dim_size,
    const bool largest,
    const bool sorted,
    char** data, const int64_t* strides, const int64_t n) {
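  // Invoked as a TensorIterator loop body: `data` and `strides` describe
  // `n` independent slices, with data[0] pointing at the output values,
  // data[1] at the output indices, and data[2] at the input slice of
  // `dim_size` elements (see the accessors constructed below).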
  // If k is zero, the output values and indices are empty tensors,
  // so iterating over the other dims is pointless.
  if (k == 0) {
    return;
  }
  using elem_t = std::pair<accscalar_t, int64_t>;
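  // One (value, original index) pair per element of a slice; allocated
  // once here and refilled for every slice below.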
  std::vector<elem_t> queue(dim_size);
  for (const auto i : c10::irange(n)) {
    TensorAccessor<scalar_t, 1> mode_values(
        reinterpret_cast<scalar_t*>(data[0] + i * strides[0]),
        &k, &mode_values_stride);
    TensorAccessor<int64_t, 1> mode_indices(
        reinterpret_cast<int64_t*>(data[1] + i * strides[1]),
        &k, &mode_indices_stride);
    TensorAccessor<const scalar_t, 1> tmp_values(
        reinterpret_cast<scalar_t*>(data[2] + i * strides[2]),
        &dim_size, &tmp_values_stride);
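
    // Heuristic: when k is much smaller than the slice, a partial sort
    // of the whole queue is cheaper; otherwise select the k-th element
    // with nth_element and sort just the front if requested.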
    auto n_2 = dim_size;
    auto use_partial_sort = k * 64 <= n_2;

    for (const auto j : c10::irange(n_2)) {
      queue[j].first = tmp_values[j];
      queue[j].second = j;
    }

    // We want NaN to be sorted as top for NumPy compatibility: both
    // comparators treat NaN as greater than any real value, so NaN ranks
    // first when taking the largest elements and sorts to the back when
    // taking the smallest.
    if (use_partial_sort) {
      if (largest) {
        std::partial_sort(queue.begin(), queue.begin() + k, queue.end(),
          [](const elem_t& x, const elem_t& y) -> bool {
            return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
          });
      } else {
        std::partial_sort(queue.begin(), queue.begin() + k, queue.end(),
          [](const elem_t& x, const elem_t& y) -> bool {
            return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
          });
      }
    } else {
      if (largest) {
        std::nth_element(queue.begin(), queue.begin() + k - 1, queue.end(),
          [](const elem_t& x, const elem_t& y) -> bool {
            return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
          });
        if (sorted) {
          // nth_element leaves the element at k - 1 in its final sorted
          // position, so only the first k - 1 elements still need sorting.
          std::sort(queue.begin(), queue.begin() + k - 1,
            [](const elem_t& x, const elem_t& y) -> bool {
              return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
            });
        }
      } else {
        std::nth_element(queue.begin(), queue.begin() + k - 1, queue.end(),
          [](const elem_t& x, const elem_t& y) -> bool {
            return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
          });
        if (sorted) {
          std::sort(queue.begin(), queue.begin() + k - 1,
            [](const elem_t& x, const elem_t& y) -> bool {
              return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
            });
        }
      }
    }
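
    // Copy the k selected (value, original index) pairs into the outputs.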
    for (const auto j : c10::irange(k)) {
      mode_values[j] = queue[j].first;
      mode_indices[j] = queue[j].second;
    }
  }
}

} // namespace CPU_CAPABILITY
} // namespace at::native
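
// Usage sketch (illustrative, not part of this header): the loop is
// typically wrapped in a dtype dispatch and driven by a TensorIterator
// built over every dim except the reduction dim, along the lines of:
//
//   AT_DISPATCH_ALL_TYPES(self.scalar_type(), "topk_cpu", [&] {
//     iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
//       topk_impl_loop<scalar_t, scalar_t>(
//           values_stride, indices_stride, input_stride,
//           k, dim_size, largest, sorted, data, strides, n);
//     });
//   });
//
// Here `iter`, the three dim strides, `k`, `dim_size`, `largest`, and
// `sorted` are assumed to come from the caller's topk setup.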