Skip to content

Commit

Permalink
WIP: not good approach yet
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Jan 20, 2025
1 parent 81d7814 commit ce6043c
Show file tree
Hide file tree
Showing 11 changed files with 270 additions and 108 deletions.
29 changes: 20 additions & 9 deletions common/unified/solver/chebyshev_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

#include "core/solver/chebyshev_kernels.hpp"

#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/matrix/dense.hpp>

#include "common/unified/base/kernel_launch.hpp"
#include "core/base/mixed_precision_types.hpp"


namespace gko {
Expand All @@ -22,18 +24,22 @@ void init_update(std::shared_ptr<const DefaultExecutor> exec,
matrix::Dense<ValueType>* update_sol,
matrix::Dense<ValueType>* output)
{
using type = device_type<highest_precision<ValueType, ScalarType>>;
run_kernel(
exec,
[] GKO_KERNEL(auto row, auto col, auto alpha, auto inner_sol,
auto update_sol, auto output) {
const auto inner_val = inner_sol(row, col);
update_sol(row, col) = inner_val;
output(row, col) += alpha * inner_val;
const auto inner_val = static_cast<type>(inner_sol(row, col));
update_sol(row, col) =
static_cast<device_type<ValueType>>(inner_val);
output(row, col) = static_cast<device_type<ValueType>>(
static_cast<type>(output(row, col)) +
static_cast<type>(alpha) * inner_val);
},
output->get_size(), alpha, inner_sol, update_sol, output);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE(
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_2(
GKO_DECLARE_CHEBYSHEV_INIT_UPDATE_KERNEL);


Expand All @@ -43,19 +49,24 @@ void update(std::shared_ptr<const DefaultExecutor> exec, const ScalarType alpha,
matrix::Dense<ValueType>* update_sol,
matrix::Dense<ValueType>* output)
{
using type = device_type<highest_precision<ValueType, ScalarType>>;
run_kernel(
exec,
[] GKO_KERNEL(auto row, auto col, auto alpha, auto beta, auto inner_sol,
auto update_sol, auto output) {
const auto val = inner_sol(row, col) + beta * update_sol(row, col);
inner_sol(row, col) = val;
update_sol(row, col) = val;
output(row, col) += alpha * val;
const auto val = static_cast<type>(inner_sol(row, col)) +
static_cast<type>(beta) *
static_cast<type>(update_sol(row, col));
inner_sol(row, col) = static_cast<device_type<ValueType>>(val);
update_sol(row, col) = static_cast<device_type<ValueType>>(val);
output(row, col) = static_cast<device_type<ValueType>>(
static_cast<type>(output(row, col)) +
static_cast<type>(alpha) * val);
},
output->get_size(), alpha, beta, inner_sol, update_sol, output);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE(
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_2(
GKO_DECLARE_CHEBYSHEV_UPDATE_KERNEL);


Expand Down
9 changes: 7 additions & 2 deletions core/device_hooks/common_kernels.inc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,11 @@
GKO_NOT_COMPILED(GKO_HOOK_MODULE); \
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE(_macro)

// Stub generator for kernels declared over a (ValueType, ScalarType) pair.
// Expands to a function template whose signature is produced by
// `_macro(ValueType, ScalarType)` and whose body is
// `GKO_NOT_COMPILED(GKO_HOOK_MODULE)`, followed by
// `GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_2(_macro)`.
// `_macro` must therefore be a declaration macro that is callable with two
// type arguments and also accepted by the instantiation macro.
// NOTE(review): presumably GKO_NOT_COMPILED reports that the module was not
// built when the stub is called, and the instantiation macro covers every
// supported mixed value-type pair -- confirm against their definitions.
#define GKO_STUB_FOR_EACH_MIXED_VALUE_TYPE_2(_macro) \
template <typename ValueType, typename ScalarType> \
_macro(ValueType, ScalarType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_2(_macro)

#define GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE_2_BASE(_macro) \
template <typename InputValueType, typename OutputValueType, \
typename IndexType> \
Expand Down Expand Up @@ -657,8 +662,8 @@ GKO_STUB_CB_GMRES_CONST(GKO_DECLARE_CB_GMRES_SOLVE_KRYLOV_KERNEL);
namespace chebyshev {


GKO_STUB_VALUE_AND_SCALAR_TYPE(GKO_DECLARE_CHEBYSHEV_INIT_UPDATE_KERNEL);
GKO_STUB_VALUE_AND_SCALAR_TYPE(GKO_DECLARE_CHEBYSHEV_UPDATE_KERNEL);
GKO_STUB_FOR_EACH_MIXED_VALUE_TYPE_2(GKO_DECLARE_CHEBYSHEV_INIT_UPDATE_KERNEL);
GKO_STUB_FOR_EACH_MIXED_VALUE_TYPE_2(GKO_DECLARE_CHEBYSHEV_UPDATE_KERNEL);


} // namespace chebyshev
Expand Down
124 changes: 88 additions & 36 deletions core/distributed/matrix.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand All @@ -10,6 +10,7 @@
#include <ginkgo/core/distributed/vector.hpp>
#include <ginkgo/core/matrix/coo.hpp>
#include <ginkgo/core/matrix/csr.hpp>
#include <ginkgo/core/matrix/dense.hpp>
#include <ginkgo/core/matrix/diagonal.hpp>

#include "core/components/prefix_sum_kernels.hpp"
Expand Down Expand Up @@ -55,7 +56,8 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
non_local_to_global_{exec},
one_scalar_{},
local_mtx_{local_matrix_template->clone(exec)},
non_local_mtx_{non_local_matrix_template->clone(exec)}
non_local_mtx_{non_local_matrix_template->clone(exec)},
local_only_{false}
{
GKO_ASSERT(
(dynamic_cast<ReadableFromMatrixData<ValueType, LocalIndexType>*>(
Expand All @@ -81,7 +83,8 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
non_local_to_global_{exec},
one_scalar_{},
non_local_mtx_(::gko::matrix::Coo<ValueType, LocalIndexType>::create(
exec, dim<2>{local_linop->get_size()[0], 0}))
exec, dim<2>{local_linop->get_size()[0], 0})),
local_only_{true}
{
this->set_size(size);
one_scalar_.init(exec, dim<2>{1, 1});
Expand All @@ -104,7 +107,8 @@ Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(
recv_sizes_(comm.size()),
gather_idxs_{exec},
non_local_to_global_{exec},
one_scalar_{}
one_scalar_{},
local_only_{false}
{
this->set_size(size);
local_mtx_ = std::move(local_linop);
Expand Down Expand Up @@ -445,8 +449,9 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(


template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
template <typename VectorValueType>
mpi::request Matrix<ValueType, LocalIndexType, GlobalIndexType>::communicate(
const local_vector_type* local_b) const
const gko::matrix::Dense<VectorValueType>* local_b) const
{
// This function can never return early!
// Even if the non-local part is empty, i.e. this process doesn't need
Expand All @@ -460,23 +465,26 @@ mpi::request Matrix<ValueType, LocalIndexType, GlobalIndexType>::communicate(
auto recv_size = recv_offsets_.back();
auto send_dim = dim<2>{static_cast<size_type>(send_size), num_cols};
auto recv_dim = dim<2>{static_cast<size_type>(recv_size), num_cols};
recv_buffer_.init(exec, recv_dim);
send_buffer_.init(exec, send_dim);

local_b->row_gather(&gather_idxs_, send_buffer_.get());
auto recv_buffer =
recv_buffer_.template get<VectorValueType>(exec, recv_dim);
auto send_buffer =
recv_buffer_.template get<VectorValueType>(exec, send_dim);
local_b->row_gather(&gather_idxs_, send_buffer);

auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
auto send_ptr = send_buffer->get_const_values();
auto recv_ptr = recv_buffer->get_values();
if (use_host_buffer) {
host_recv_buffer_.init(exec->get_master(), recv_dim);
host_send_buffer_.init(exec->get_master(), send_dim);
host_send_buffer_->copy_from(send_buffer_.get());
auto host_recv_buffer = host_recv_buffer_.template get<VectorValueType>(
exec->get_master(), recv_dim);
auto host_send_buffer = host_send_buffer_.template get<VectorValueType>(
exec->get_master(), send_dim);
host_send_buffer->copy_from(send_buffer);
send_ptr = host_send_buffer->get_const_values();
recv_ptr = host_recv_buffer->get_values();
}

mpi::contiguous_type type(num_cols, mpi::type_impl<ValueType>::get_type());
auto send_ptr = use_host_buffer ? host_send_buffer_->get_const_values()
: send_buffer_->get_const_values();
auto recv_ptr = use_host_buffer ? host_recv_buffer_->get_values()
: recv_buffer_->get_values();
exec->synchronize();
#ifdef GINKGO_FORCE_SPMV_BLOCKING_COMM
comm.all_to_all_v(use_host_buffer ? exec->get_master() : exec, send_ptr,
Expand All @@ -497,10 +505,14 @@ template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
const LinOp* b, LinOp* x) const
{
distributed::precision_dispatch_real_complex<ValueType>(
distributed::mixed_precision_dispatch_real_complex<ValueType>(
[this](const auto dense_b, auto dense_x) {
auto x_exec = dense_x->get_executor();
auto local_x = gko::matrix::Dense<ValueType>::create(
using x_value_type =
typename std::decay_t<decltype(*dense_x)>::value_type;
using b_value_type =
typename std::decay_t<decltype(*dense_b)>::value_type;
auto local_x = gko::matrix::Dense<x_value_type>::create(
x_exec, dense_x->get_local_vector()->get_size(),
gko::make_array_view(
x_exec,
Expand All @@ -509,16 +521,31 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
dense_x->get_local_vector()->get_stride());

auto comm = this->get_communicator();
auto req = this->communicate(dense_b->get_local_vector());
mpi::request req;
if (!local_only_) {
req = this->communicate(dense_b->get_local_vector());
}
local_mtx_->apply(dense_b->get_local_vector(), local_x);
if (local_only_) {
return;
}
req.wait();

auto exec = this->get_executor();
auto use_host_buffer = mpi::requires_host_buffer(exec, comm);

auto recv_size = recv_offsets_.back();
auto recv_dim = dim<2>{static_cast<size_type>(recv_size),
dense_b->get_size()[1]};
auto recv_buffer =
recv_buffer_.template get<b_value_type>(exec, recv_dim);
if (use_host_buffer) {
recv_buffer_->copy_from(host_recv_buffer_.get());
auto host_recv_buffer =
host_recv_buffer_.template get<b_value_type>(exec,
recv_dim);
recv_buffer->copy_from(host_recv_buffer);
}
non_local_mtx_->apply(one_scalar_.get(), recv_buffer_.get(),
non_local_mtx_->apply(one_scalar_.get(), recv_buffer,
one_scalar_.get(), local_x);
},
b, x);
Expand All @@ -529,11 +556,17 @@ template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
const LinOp* alpha, const LinOp* b, const LinOp* beta, LinOp* x) const
{
distributed::precision_dispatch_real_complex<ValueType>(
[this](const auto local_alpha, const auto dense_b,
const auto local_beta, auto dense_x) {
distributed::mixed_precision_dispatch_real_complex<ValueType>(
[this, alpha, beta](const auto dense_b, auto dense_x) {
const auto x_exec = dense_x->get_executor();
auto local_x = gko::matrix::Dense<ValueType>::create(
using x_value_type =
typename std::decay_t<decltype(*dense_x)>::value_type;
using b_value_type =
typename std::decay_t<decltype(*dense_b)>::value_type;
auto local_alpha = gko::make_temporary_conversion<ValueType>(alpha);
auto local_beta =
gko::make_temporary_conversion<x_value_type>(beta);
auto local_x = gko::matrix::Dense<x_value_type>::create(
x_exec, dense_x->get_local_vector()->get_size(),
gko::make_array_view(
x_exec,
Expand All @@ -542,20 +575,34 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
dense_x->get_local_vector()->get_stride());

auto comm = this->get_communicator();
auto req = this->communicate(dense_b->get_local_vector());
local_mtx_->apply(local_alpha, dense_b->get_local_vector(),
local_beta, local_x);
mpi::request req;
if (!local_only_) {
req = this->communicate(dense_b->get_local_vector());
}
local_mtx_->apply(local_alpha.get(), dense_b->get_local_vector(),
local_beta.get(), local_x);
if (local_only_) {
return;
}
req.wait();

auto exec = this->get_executor();
auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
auto recv_size = recv_offsets_.back();
auto recv_dim = dim<2>{static_cast<size_type>(recv_size),
dense_b->get_size()[1]};
auto recv_buffer =
recv_buffer_.template get<b_value_type>(exec, recv_dim);
if (use_host_buffer) {
recv_buffer_->copy_from(host_recv_buffer_.get());
auto host_recv_buffer =
host_recv_buffer_.template get<b_value_type>(exec,
recv_dim);
recv_buffer->copy_from(host_recv_buffer);
}
non_local_mtx_->apply(local_alpha, recv_buffer_.get(),
non_local_mtx_->apply(local_alpha.get(), recv_buffer,
one_scalar_.get(), local_x);
},
alpha, b, beta, x);
b, x);
}


Expand All @@ -582,21 +629,26 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::col_scale(
exec, n_local_cols,
make_const_array_view(exec, n_local_cols, scale_values));

auto req = this->communicate(
stride == 1 ? scaling_factors->get_local_vector()
: scaling_factors_single_stride->get_local_vector());
auto factors = stride == 1
? scaling_factors->get_local_vector()
: scaling_factors_single_stride->get_local_vector();
auto req = this->communicate(factors);
scale_diag->rapply(local_mtx_, local_mtx_);
req.wait();
if (n_non_local_cols > 0) {
auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
auto recv_buffer = recv_buffer_.template get<ValueType>(
exec, gko::dim<2>(n_non_local_cols, 1));
if (use_host_buffer) {
recv_buffer_->copy_from(host_recv_buffer_.get());
auto host_recv_buffer = host_recv_buffer_.template get<ValueType>(
exec->get_master(), gko::dim<2>(n_non_local_cols, 1));
recv_buffer->copy_from(host_recv_buffer);
}
const auto non_local_scale_diag =
gko::matrix::Diagonal<ValueType>::create_const(
exec, n_non_local_cols,
make_const_array_view(exec, n_non_local_cols,
recv_buffer_->get_const_values()));
recv_buffer->get_const_values()));
non_local_scale_diag->rapply(non_local_mtx_, non_local_mtx_);
}
}
Expand Down
20 changes: 10 additions & 10 deletions core/solver/chebyshev_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,18 @@ namespace kernels {
namespace chebyshev {


#define GKO_DECLARE_CHEBYSHEV_INIT_UPDATE_KERNEL(ValueType, ScalarType) \
void init_update(std::shared_ptr<const DefaultExecutor> exec, \
const ScalarType alpha, \
const matrix::Dense<ValueType>* inner_sol, \
matrix::Dense<ValueType>* update_sol, \
#define GKO_DECLARE_CHEBYSHEV_INIT_UPDATE_KERNEL(ValueType, ScalarType, ...) \
void init_update(std::shared_ptr<const DefaultExecutor> exec, \
const ScalarType alpha, \
const matrix::Dense<ValueType>* inner_sol, \
matrix::Dense<ValueType>* update_sol, \
matrix::Dense<ValueType>* output)

#define GKO_DECLARE_CHEBYSHEV_UPDATE_KERNEL(ValueType, ScalarType) \
void update(std::shared_ptr<const DefaultExecutor> exec, \
const ScalarType alpha, const ScalarType beta, \
matrix::Dense<ValueType>* inner_sol, \
matrix::Dense<ValueType>* update_sol, \
#define GKO_DECLARE_CHEBYSHEV_UPDATE_KERNEL(ValueType, ScalarType, ...) \
void update(std::shared_ptr<const DefaultExecutor> exec, \
const ScalarType alpha, const ScalarType beta, \
matrix::Dense<ValueType>* inner_sol, \
matrix::Dense<ValueType>* update_sol, \
matrix::Dense<ValueType>* output)

#define GKO_DECLARE_ALL_AS_TEMPLATES \
Expand Down
Loading

0 comments on commit ce6043c

Please sign in to comment.