Skip to content

Commit

Permalink
Update DD matrix to support half
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzgoebel committed Jan 7, 2025
1 parent 12f827f commit 762277e
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 28 deletions.
5 changes: 3 additions & 2 deletions core/device_hooks/common_kernels.inc.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -363,7 +363,8 @@ GKO_STUB_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE_BASE(
namespace distributed_dd_matrix {


GKO_STUB_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(GKO_DECLARE_FILTER_NON_OWNING_IDXS);
GKO_STUB_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE_BASE(
GKO_DECLARE_FILTER_NON_OWNING_IDXS);


} // namespace distributed_dd_matrix
Expand Down
24 changes: 12 additions & 12 deletions core/distributed/dd_matrix.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -44,7 +44,7 @@ template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
DdMatrix<ValueType, LocalIndexType, GlobalIndexType>::DdMatrix(
std::shared_ptr<const Executor> exec, mpi::communicator comm,
ptr_param<const LinOp> matrix_template)
: EnableDistributedLinOp<
: EnableLinOp<
DdMatrix<value_type, local_index_type, global_index_type>>{exec},
DistributedBase{comm},
send_offsets_(comm.size() + 1),
Expand All @@ -68,7 +68,7 @@ template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
DdMatrix<ValueType, LocalIndexType, GlobalIndexType>::DdMatrix(
std::shared_ptr<const Executor> exec, mpi::communicator comm, dim<2> size,
std::shared_ptr<LinOp> local_linop)
: EnableDistributedLinOp<
: EnableLinOp<
DdMatrix<value_type, local_index_type, global_index_type>>{exec},
DistributedBase{comm},
send_offsets_(comm.size() + 1),
Expand Down Expand Up @@ -107,8 +107,8 @@ DdMatrix<ValueType, LocalIndexType, GlobalIndexType>::create(

template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void DdMatrix<ValueType, LocalIndexType, GlobalIndexType>::convert_to(
DdMatrix<next_precision<value_type>, local_index_type, global_index_type>*
result) const
DdMatrix<next_precision_base<value_type>, local_index_type,
global_index_type>* result) const
{
GKO_ASSERT(this->get_communicator().size() ==
result->get_communicator().size());
Expand All @@ -125,8 +125,8 @@ void DdMatrix<ValueType, LocalIndexType, GlobalIndexType>::convert_to(

template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void DdMatrix<ValueType, LocalIndexType, GlobalIndexType>::move_to(
DdMatrix<next_precision<value_type>, local_index_type, global_index_type>*
result)
DdMatrix<next_precision_base<value_type>, local_index_type,
global_index_type>* result)
{
GKO_ASSERT(this->get_communicator().size() ==
result->get_communicator().size());
Expand Down Expand Up @@ -430,8 +430,8 @@ void DdMatrix<ValueType, LocalIndexType, GlobalIndexType>::row_scale(
template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
DdMatrix<ValueType, LocalIndexType, GlobalIndexType>::DdMatrix(
const DdMatrix& other)
: EnableDistributedLinOp<DdMatrix<value_type, local_index_type,
global_index_type>>{other.get_executor()},
: EnableLinOp<DdMatrix<value_type, local_index_type,
global_index_type>>{other.get_executor()},
DistributedBase{other.get_communicator()}
{
*this = other;
Expand All @@ -441,8 +441,8 @@ DdMatrix<ValueType, LocalIndexType, GlobalIndexType>::DdMatrix(
template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
DdMatrix<ValueType, LocalIndexType, GlobalIndexType>::DdMatrix(
DdMatrix&& other) noexcept
: EnableDistributedLinOp<DdMatrix<value_type, local_index_type,
global_index_type>>{other.get_executor()},
: EnableLinOp<DdMatrix<value_type, local_index_type,
global_index_type>>{other.get_executor()},
DistributedBase{other.get_communicator()}
{
*this = std::move(other);
Expand Down Expand Up @@ -499,7 +499,7 @@ DdMatrix<ValueType, LocalIndexType, GlobalIndexType>::operator=(
#define GKO_DECLARE_DISTRIBUTED_DD_MATRIX(ValueType, LocalIndexType, \
GlobalIndexType) \
class DdMatrix<ValueType, LocalIndexType, GlobalIndexType>
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE_BASE(
GKO_DECLARE_DISTRIBUTED_DD_MATRIX);


Expand Down
27 changes: 13 additions & 14 deletions include/ginkgo/core/distributed/dd_matrix.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand All @@ -13,11 +13,11 @@


#include <ginkgo/core/base/dense_cache.hpp>
#include <ginkgo/core/base/lin_op.hpp>
#include <ginkgo/core/base/mpi.hpp>
#include <ginkgo/core/base/std_extensions.hpp>
#include <ginkgo/core/distributed/base.hpp>
#include <ginkgo/core/distributed/index_map.hpp>
#include <ginkgo/core/distributed/lin_op.hpp>
#include <ginkgo/core/distributed/matrix.hpp>


Expand Down Expand Up @@ -169,13 +169,12 @@ class Vector;
template <typename ValueType = default_precision,
typename LocalIndexType = int32, typename GlobalIndexType = int64>
class DdMatrix
: public EnableDistributedLinOp<
DdMatrix<ValueType, LocalIndexType, GlobalIndexType>>,
public ConvertibleTo<
DdMatrix<next_precision<ValueType>, LocalIndexType, GlobalIndexType>>,
: public EnableLinOp<DdMatrix<ValueType, LocalIndexType, GlobalIndexType>>,
public ConvertibleTo<DdMatrix<next_precision_base<ValueType>,
LocalIndexType, GlobalIndexType>>,
public DistributedBase {
friend class EnableDistributedPolymorphicObject<DdMatrix, LinOp>;
friend class DdMatrix<next_precision<ValueType>, LocalIndexType,
friend class EnablePolymorphicObject<DdMatrix, LinOp>;
friend class DdMatrix<next_precision_base<ValueType>, LocalIndexType,
GlobalIndexType>;

public:
Expand All @@ -189,17 +188,17 @@ class DdMatrix
gko::experimental::distributed::Vector<ValueType>;
using local_vector_type = typename global_vector_type::local_vector_type;

using EnableDistributedLinOp<DdMatrix>::convert_to;
using EnableDistributedLinOp<DdMatrix>::move_to;
using ConvertibleTo<DdMatrix<next_precision<ValueType>, LocalIndexType,
using EnableLinOp<DdMatrix>::convert_to;
using EnableLinOp<DdMatrix>::move_to;
using ConvertibleTo<DdMatrix<next_precision_base<ValueType>, LocalIndexType,
GlobalIndexType>>::convert_to;
using ConvertibleTo<DdMatrix<next_precision<ValueType>, LocalIndexType,
using ConvertibleTo<DdMatrix<next_precision_base<ValueType>, LocalIndexType,
GlobalIndexType>>::move_to;

void convert_to(DdMatrix<next_precision<value_type>, local_index_type,
void convert_to(DdMatrix<next_precision_base<value_type>, local_index_type,
global_index_type>* result) const override;

void move_to(DdMatrix<next_precision<value_type>, local_index_type,
void move_to(DdMatrix<next_precision_base<value_type>, local_index_type,
global_index_type>* result) override;

/**
Expand Down

0 comments on commit 762277e

Please sign in to comment.