-
Notifications
You must be signed in to change notification settings - Fork 91
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This merge adds a class to store a segmented array. A flat ginkgo array is partitioned into multiple segments by an additional index offset array. The segment `i` starts within the flat buffer at index `offsets[i]`, and ends (exclusively) at index `offsets[i + 1]`. The class only provides access to the flat buffer and the offsets. Related PR: #1545
- Loading branch information
Showing
10 changed files
with
657 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors | ||
// | ||
// SPDX-License-Identifier: BSD-3-Clause | ||
|
||
#include <ginkgo/core/base/segmented_array.hpp> | ||
|
||
|
||
#include "core/base/array_access.hpp" | ||
#include "core/components/prefix_sum_kernels.hpp" | ||
|
||
|
||
namespace gko { | ||
namespace { | ||
|
||
|
||
GKO_REGISTER_OPERATION(prefix_sum, components::prefix_sum_nonnegative); | ||
|
||
|
||
} | ||
|
||
|
||
template <typename T> | ||
size_type segmented_array<T>::get_size() const | ||
{ | ||
return buffer_.get_size(); | ||
} | ||
|
||
|
||
template <typename T> | ||
size_type segmented_array<T>::get_segment_count() const | ||
{ | ||
return offsets_.get_size() ? offsets_.get_size() - 1 : 0; | ||
} | ||
|
||
|
||
template <typename T> | ||
T* segmented_array<T>::get_flat_data() | ||
{ | ||
return buffer_.get_data(); | ||
} | ||
|
||
|
||
template <typename T> | ||
const T* segmented_array<T>::get_const_flat_data() const | ||
{ | ||
return buffer_.get_const_data(); | ||
} | ||
|
||
|
||
template <typename T> | ||
const gko::array<int64>& segmented_array<T>::get_offsets() const | ||
{ | ||
return offsets_; | ||
} | ||
|
||
|
||
template <typename T> | ||
std::shared_ptr<const Executor> segmented_array<T>::get_executor() const | ||
{ | ||
return buffer_.get_executor(); | ||
} | ||
|
||
|
||
template <typename T> | ||
segmented_array<T>::segmented_array(std::shared_ptr<const Executor> exec) | ||
: buffer_(exec), offsets_(exec, 1) | ||
{ | ||
offsets_.fill(0); | ||
} | ||
|
||
|
||
array<int64> sizes_to_offsets(const gko::array<int64>& sizes) | ||
{ | ||
auto exec = sizes.get_executor(); | ||
array<int64> offsets(exec, sizes.get_size() + 1); | ||
exec->copy(sizes.get_size(), sizes.get_const_data(), offsets.get_data()); | ||
exec->run(make_prefix_sum(offsets.get_data(), offsets.get_size())); | ||
return offsets; | ||
} | ||
|
||
|
||
template <typename T> | ||
segmented_array<T> segmented_array<T>::create_from_sizes( | ||
const gko::array<int64>& sizes) | ||
{ | ||
return create_from_offsets(sizes_to_offsets(sizes)); | ||
} | ||
|
||
|
||
template <typename T> | ||
segmented_array<T> segmented_array<T>::create_from_sizes( | ||
gko::array<T> buffer, const gko::array<int64>& sizes) | ||
{ | ||
return create_from_offsets(std::move(buffer), sizes_to_offsets(sizes)); | ||
} | ||
|
||
|
||
template <typename T> | ||
segmented_array<T> segmented_array<T>::create_from_offsets( | ||
gko::array<int64> offsets) | ||
{ | ||
GKO_THROW_IF_INVALID(offsets.get_size() > 0, | ||
"The offsets for segmented_arrays require at least " | ||
"one element."); | ||
auto size = | ||
static_cast<size_type>(get_element(offsets, offsets.get_size() - 1)); | ||
return create_from_offsets(array<T>{offsets.get_executor(), size}, | ||
std::move(offsets)); | ||
} | ||
|
||
|
||
template <typename T> | ||
segmented_array<T> segmented_array<T>::create_from_offsets( | ||
gko::array<T> buffer, gko::array<int64> offsets) | ||
{ | ||
GKO_ASSERT_EQ(buffer.get_size(), | ||
get_element(offsets, offsets.get_size() - 1)); | ||
segmented_array<T> result(buffer.get_executor()); | ||
result.offsets_ = std::move(offsets); | ||
result.buffer_ = std::move(buffer); | ||
return result; | ||
} | ||
|
||
|
||
template <typename T> | ||
segmented_array<T>::segmented_array(std::shared_ptr<const Executor> exec, | ||
segmented_array&& other) | ||
: segmented_array(exec) | ||
{ | ||
*this = std::move(other); | ||
} | ||
|
||
|
||
template <typename T> | ||
segmented_array<T>::segmented_array(std::shared_ptr<const Executor> exec, | ||
const segmented_array& other) | ||
: segmented_array(exec) | ||
{ | ||
*this = other; | ||
} | ||
|
||
|
||
template <typename T> | ||
segmented_array<T>::segmented_array(const segmented_array& other) | ||
: segmented_array(other.get_executor()) | ||
{ | ||
*this = other; | ||
} | ||
|
||
|
||
template <typename T> | ||
segmented_array<T>::segmented_array(segmented_array&& other) | ||
: segmented_array(other.get_executor()) | ||
{ | ||
*this = std::move(other); | ||
} | ||
|
||
|
||
template <typename T> | ||
segmented_array<T>& segmented_array<T>::operator=(const segmented_array& other) | ||
{ | ||
if (this != &other) { | ||
buffer_ = other.buffer_; | ||
offsets_ = other.offsets_; | ||
} | ||
return *this; | ||
} | ||
|
||
|
||
template <typename T> | ||
segmented_array<T>& segmented_array<T>::operator=(segmented_array&& other) | ||
{ | ||
if (this != &other) { | ||
buffer_ = std::move(other.buffer_); | ||
offsets_ = std::exchange(other.offsets_, | ||
array<int64>{other.get_executor(), {0}}); | ||
} | ||
return *this; | ||
} | ||
|
||
|
||
#define GKO_DECLARE_SEGMENTED_ARRAY(_type) class segmented_array<_type> | ||
|
||
GKO_INSTANTIATE_FOR_EACH_POD_TYPE(GKO_DECLARE_SEGMENTED_ARRAY); | ||
|
||
|
||
} // namespace gko |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors | ||
// | ||
// SPDX-License-Identifier: BSD-3-Clause | ||
|
||
#ifndef GINKGO_SEGMENTED_ARRAY_HPP | ||
#define GINKGO_SEGMENTED_ARRAY_HPP | ||
|
||
|
||
#include <ginkgo/core/base/segmented_array.hpp> | ||
|
||
|
||
namespace gko { | ||
|
||
|
||
/** | ||
* Helper struct storing an array segment | ||
* | ||
* @tparam T The value type of the array | ||
*/ | ||
template <typename T> | ||
struct array_segment { | ||
T* begin; | ||
T* end; | ||
}; | ||
|
||
|
||
/** | ||
* Helper function to create a device-compatible view of an array segment. | ||
*/ | ||
template <typename T> | ||
constexpr array_segment<T> get_array_segment(segmented_array<T>& sarr, | ||
size_type segment_id) | ||
{ | ||
assert(segment_id < sarr.get_segment_count()); | ||
auto offsets = sarr.get_offsets().get_const_data(); | ||
auto data = sarr.get_flat_data(); | ||
return {data + offsets[segment_id], data + offsets[segment_id + 1]}; | ||
} | ||
|
||
|
||
/** | ||
* Helper function to create a device-compatible view of a const array segment. | ||
*/ | ||
template <typename T> | ||
constexpr array_segment<const T> get_array_segment( | ||
const segmented_array<T>& sarr, size_type segment_id) | ||
{ | ||
assert(segment_id < sarr.get_segment_count()); | ||
auto offsets = sarr.get_offsets().get_const_data(); | ||
auto data = sarr.get_const_flat_data(); | ||
return {data + offsets[segment_id], data + offsets[segment_id + 1]}; | ||
} | ||
|
||
|
||
} // namespace gko | ||
|
||
#endif // GINKGO_SEGMENTED_ARRAY_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.