Skip to content

Commit

Permalink
documentation improvements and other review changes
Browse files Browse the repository at this point in the history
Co-authored-by: Marcel Koch <[email protected]>
  • Loading branch information
upsj and MarcelKoch committed Jan 19, 2025
1 parent 877e0ad commit 4b6b8a2
Show file tree
Hide file tree
Showing 11 changed files with 369 additions and 249 deletions.
54 changes: 35 additions & 19 deletions common/unified/components/range_minimum_query_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,38 +22,56 @@ namespace range_minimum_query {
template <typename IndexType>
void compute_lookup_small(std::shared_ptr<const DefaultExecutor> exec,
const IndexType* values, IndexType size,
block_argmin_storage_type<IndexType>& block_argmin,
IndexType* block_min, uint16* block_type)
bit_packed_span<int, IndexType, uint32>& block_argmin,
IndexType* block_min, uint16* block_tree_index)
{
#ifdef GKO_COMPILING_DPCPP
// The Intel SYCL compiler doesn't support constexpr initialization of
// non-trivial objects on the device.
GKO_NOT_IMPLEMENTED;
#else
using tree_index_type = std::decay_t<decltype(*block_type)>;
using device_lut_type =
gko::device_block_range_minimum_query_lookup_table<small_block_size>;
static_assert(device_lut_type::view_type::num_trees <=
std::numeric_limits<tree_index_type>::max(),
"block type storage too small");
using device_type = device_range_minimum_query<IndexType>;
constexpr auto block_size = device_type::block_size;
using tree_index_type = std::decay_t<decltype(*block_tree_index)>;
using device_lut_type = typename device_type::block_lut_type;
using lut_type = typename device_type::block_lut_view_type;
static_assert(
lut_type::num_trees <= std::numeric_limits<tree_index_type>::max(),
"block type storage too small");
// block_argmin stores multiple values per memory word, so we need to make
// sure that no two different threads write to the same memory location.
// The easiest way to do that is to have every thread handle all elements
// that map to the same memory location.
// The argmin inside a block is in the range [0, block_size - 1], so
// it needs ceil_log2_constexpr(block_size) bits. For efficiency
// reasons, we round that up to the next power of two.
// This expression is essentially bits_per_word /
// round_up_pow2_constexpr(ceil_log2_constexpr(block_size)), i.e. how
// many values are stored per word.
constexpr auto collation_width =
1 << (std::decay_t<decltype(block_argmin)>::bits_per_word_log2 -
ceil_log2_constexpr(ceil_log2_constexpr(small_block_size)));
ceil_log2_constexpr(ceil_log2_constexpr(block_size)));
const device_lut_type lut{exec};
run_kernel(
exec,
[] GKO_KERNEL(auto collated_block_idx, auto values, auto block_argmin,
auto block_min, auto block_type, auto lut, auto size) {
auto block_min, auto block_tree_index, auto lut,
auto size) {
// we need to put this here because some compilers interpret capture
// rules around constexpr incorrectly
constexpr auto block_size = device_type::block_size;
constexpr auto infinity = std::numeric_limits<IndexType>::max();
const auto num_blocks = ceildiv(size, small_block_size);
const auto num_blocks = ceildiv(size, block_size);
for (auto block_idx = collated_block_idx * collation_width;
block_idx <
std::min<int64>((collated_block_idx + 1) * collation_width,
num_blocks);
block_idx++) {
const auto i = block_idx * small_block_size;
IndexType local_values[small_block_size];
const auto i = block_idx * block_size;
IndexType local_values[block_size];
int argmin = 0;
#pragma unroll
for (int local_i = 0; local_i < small_block_size; local_i++) {
for (int local_i = 0; local_i < block_size; local_i++) {
// use "infinity" as sentinel for minimum computations
local_values[local_i] =
local_i + i < size ? values[local_i + i] : infinity;
Expand All @@ -63,16 +81,14 @@ void compute_lookup_small(std::shared_ptr<const DefaultExecutor> exec,
}
const auto tree_number = lut->compute_tree_index(local_values);
const auto min = local_values[argmin];
// TODO collate these so a single thread handles the argmins for
// an entire memory word
block_argmin.set(block_idx, argmin);
block_min[block_idx] = min;
block_type[block_idx] =
block_tree_index[block_idx] =
static_cast<tree_index_type>(tree_number);
}
},
ceildiv(ceildiv(size, small_block_size), collation_width), values,
block_argmin, block_min, block_type, lut.get(), size);
ceildiv(ceildiv(size, block_size), collation_width), values,
block_argmin, block_min, block_tree_index, lut.get(), size);
#endif
}

Expand Down
7 changes: 1 addition & 6 deletions core/base/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,7 @@ template GKO_DECLARE_ARRAY_FILL(uint32);

// this is necessary because compilers use different types for uint64_t and
// size_t, namely unsigned long long and unsigned long
void array_fill_instantiation_helper(array<uint64>& a)
{
if constexpr (!std::is_same_v<uint64, size_type>) {
a.fill(0);
}
}
void array_fill_instantiation_helper(array<uint64>& a) { a.fill(0); }


#define GKO_DECLARE_ARRAY_REDUCE_ADD(_type) \
Expand Down
47 changes: 25 additions & 22 deletions core/components/bit_packed_storage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,19 +86,22 @@ constexpr int round_up_pow2(T value)
* non-power-of-two number of bits would make hard otherwise.
* The class is a non-owning view.
*
* @tparam ValueType the type used to represent values in the span
* @tparam IndexType the type used to represent indices in the span
* @tparam StorageType the type used to represent values in the span, and also
* used internally to access the underlying memory. It
* @tparam StorageType the type used to internally represent the values. It
* needs to be large enough to represent all values to be
* stored.
*/
template <typename IndexType, typename StorageType>
template <typename ValueType, typename IndexType, typename StorageType>
class bit_packed_span {
static_assert(std::is_unsigned_v<StorageType>);

public:
using value_type = ValueType;
using index_type = IndexType;
using storage_type = StorageType;
/** How many bits are available in StorageType */
constexpr static int bits_per_word = sizeof(StorageType) * CHAR_BIT;
constexpr static int bits_per_word = sizeof(storage_type) * CHAR_BIT;
/** Binary logarithm of bits_per_word */
constexpr static int bits_per_word_log2 =
ceil_log2_constexpr(bits_per_word);
Expand All @@ -118,7 +121,7 @@ class bit_packed_span {

/*
* Returns the binary logarithm of the number of values that can be stored inside
* a single StorageType word. This gets used to avoid integer divisions in
* a single storage_type word. This gets used to avoid integer divisions in
* favor of faster bit shifts.
*/
constexpr static int values_per_word_log2(int num_bits)
Expand All @@ -127,18 +130,18 @@ class bit_packed_span {
}

/**
* Computes how many StorageType words will be necessary to store size
* Computes how many storage_type words will be necessary to store size
* values requiring num_bits bits.
*
* @param size The number of values to store
* @param num_bits The number of bits necessary to store values inside this
* span. This means that all values need to be in the range
* [0, 2^num_bits).
*/
constexpr static IndexType storage_size(IndexType size, int num_bits)
constexpr static index_type storage_size(index_type size, int num_bits)
{
const auto shift = values_per_word_log2(num_bits);
const auto div = StorageType{1} << shift;
const auto div = storage_type{1} << shift;
return (size + div - 1) >> shift;
}

Expand All @@ -149,12 +152,12 @@ class bit_packed_span {
* @param i The index to write to
* @param value The value to write. It needs to be in [0, 2^num_bits).
*/
constexpr void set_from_zero(IndexType i, StorageType value)
constexpr void set_from_zero(index_type i, value_type value)
{
assert(value >= 0);
assert(value <= mask_);
assert(value <= static_cast<value_type>(mask_));
const auto [block, shift] = get_block_and_shift(i);
data_[block] |= value << shift;
data_[block] |= static_cast<storage_type>(value) << shift;
}

/**
Expand All @@ -163,7 +166,7 @@ class bit_packed_span {
* @param i The index whose stored bits are reset to zero.
*/
constexpr void clear(IndexType i)
constexpr void clear(index_type i)
{
const auto [block, shift] = get_block_and_shift(i);
data_[block] &= ~(mask_ << shift);
Expand All @@ -175,7 +178,7 @@ class bit_packed_span {
* @param i The index to write to
* @param value The value to write. It needs to be in [0, 2^num_bits).
*/
constexpr void set(IndexType i, StorageType value)
constexpr void set(index_type i, value_type value)
{
clear(i);
set_from_zero(i, value);
Expand All @@ -187,17 +190,17 @@ class bit_packed_span {
* @param i The index to read from
* @return The value stored at index i. It is in [0, 2^num_bits).
*/
constexpr StorageType get(IndexType i) const
constexpr value_type get(index_type i) const
{
const auto [block, shift] = get_block_and_shift(i);
return (data_[block] >> shift) & mask_;
return static_cast<value_type>((data_[block] >> shift) & mask_);
}

explicit constexpr bit_packed_span(StorageType* data, int num_bits,
IndexType size)
explicit constexpr bit_packed_span(storage_type* data, int num_bits,
index_type size)
: data_{data},
size_{size},
mask_{(StorageType{1} << num_bits) - 1},
mask_{(storage_type{1} << num_bits) - 1},
bits_per_value_{round_up_pow2(num_bits)},
values_per_word_log2_{values_per_word_log2(num_bits)},
local_index_mask_{(1 << values_per_word_log2_) - 1}
Expand All @@ -206,17 +209,17 @@ class bit_packed_span {
}

private:
constexpr std::pair<int, int> get_block_and_shift(IndexType i) const
constexpr std::pair<int, int> get_block_and_shift(index_type i) const
{
assert(i >= 0);
assert(i < size_);
return std::make_pair(i >> values_per_word_log2_,
(i & local_index_mask_) * bits_per_value_);
}

StorageType* data_;
IndexType size_;
StorageType mask_;
storage_type* data_;
index_type size_;
storage_type mask_;
int bits_per_value_;
int values_per_word_log2_;
int local_index_mask_;
Expand Down
14 changes: 7 additions & 7 deletions core/components/range_minimum_query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,24 @@ device_range_minimum_query<IndexType>::device_range_minimum_query(
: num_blocks_{static_cast<index_type>(
ceildiv(static_cast<index_type>(data.get_size()), block_size))},
lut_{data.get_executor()},
block_types_{data.get_executor(), static_cast<size_type>(num_blocks_)},
block_tree_indexs_{data.get_executor(),

Check warning on line 31 in core/components/range_minimum_query.cpp

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"indexs" should be "indexes" or "indices".
static_cast<size_type>(num_blocks_)},
block_argmin_storage_{
data.get_executor(),
static_cast<size_type>(block_argmin_view_type::storage_size(
static_cast<size_type>(num_blocks_), block_argmin_num_bits))},
block_min_{data.get_executor(), static_cast<size_type>(num_blocks_)},
superblock_storage_{
data.get_executor(),
static_cast<size_type>(
superblock_view_type::compute_storage_size(num_blocks_))},
superblock_storage_{data.get_executor(),
static_cast<size_type>(
superblock_view_type::storage_size(num_blocks_))},
values_{std::move(data)}
{
const auto exec = values_.get_executor();
auto block_argmin = block_argmin_view_type{
block_argmin_storage_.get_data(), block_argmin_num_bits, num_blocks_};
exec->run(make_compute_lookup_small(
values_.get_const_data(), static_cast<index_type>(values_.get_size()),
block_argmin, block_min_.get_data(), block_types_.get_data()));
block_argmin, block_min_.get_data(), block_tree_indexs_.get_data()));

Check warning on line 48 in core/components/range_minimum_query.cpp

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"indexs" should be "indexes" or "indices".
auto superblocks =
superblock_view_type{block_min_.get_const_data(),
superblock_storage_.get_data(), num_blocks_};
Expand All @@ -61,7 +61,7 @@ device_range_minimum_query<IndexType>::get() const
return range_minimum_query{values_.get_const_data(),
block_min_.get_const_data(),
block_argmin_storage_.get_const_data(),
block_types_.get_const_data(),
block_tree_indexs_.get_const_data(),

Check warning on line 64 in core/components/range_minimum_query.cpp

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"indexs" should be "indexes" or "indices".
superblock_storage_.get_const_data(),
lut_.get(),
static_cast<index_type>(values_.get_size())};
Expand Down
Loading

0 comments on commit 4b6b8a2

Please sign in to comment.