From 373511094eb7a1e53e3035ea942347b23c8a3b0c Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Tue, 29 Oct 2024 01:04:27 +0100 Subject: [PATCH] batch with half --- .../base/batch_multi_vector_kernels.cpp | 13 +++-- common/cuda_hip/matrix/batch_csr_kernels.cpp | 8 +-- .../cuda_hip/matrix/batch_dense_kernels.cpp | 12 ++-- common/cuda_hip/matrix/batch_ell_kernels.cpp | 8 +-- core/base/batch_multi_vector.cpp | 27 ++++++++- core/device_hooks/common_kernels.inc.cpp | 57 +++++++++++-------- core/log/batch_logger.cpp | 4 +- core/matrix/batch_csr.cpp | 29 +++++++++- core/matrix/batch_dense.cpp | 28 ++++++++- core/matrix/batch_ell.cpp | 29 +++++++++- core/matrix/batch_identity.cpp | 3 +- core/preconditioner/batch_jacobi.cpp | 2 +- core/solver/batch_bicgstab.cpp | 2 +- core/solver/batch_cg.cpp | 2 +- core/solver/batch_dispatch.hpp | 40 ++++++++++++- cuda/preconditioner/batch_jacobi_kernels.cu | 4 +- cuda/solver/batch_bicgstab_kernels.cu | 50 ++++++++-------- cuda/solver/batch_cg_kernels.cu | 51 +++++++++-------- dpcpp/base/batch_multi_vector_kernels.dp.cpp | 13 +++-- dpcpp/base/batch_multi_vector_kernels.hpp | 41 ------------- dpcpp/matrix/batch_csr_kernels.dp.cpp | 8 +-- dpcpp/matrix/batch_dense_kernels.dp.cpp | 12 ++-- dpcpp/matrix/batch_ell_kernels.dp.cpp | 8 +-- dpcpp/preconditioner/batch_block_jacobi.hpp | 7 ++- .../batch_jacobi_kernels.dp.cpp | 4 +- dpcpp/solver/batch_bicgstab_kernels.dp.cpp | 3 +- dpcpp/solver/batch_cg_kernels.dp.cpp | 3 +- .../batch_jacobi_kernels.hip.cpp | 4 +- hip/solver/batch_bicgstab_kernels.hip.cpp | 51 ++++++++--------- hip/solver/batch_cg_kernels.hip.cpp | 45 ++++++++------- .../ginkgo/core/base/batch_multi_vector.hpp | 36 ++++++++++-- include/ginkgo/core/log/logger.hpp | 12 ++++ include/ginkgo/core/matrix/batch_csr.hpp | 35 ++++++++++-- include/ginkgo/core/matrix/batch_dense.hpp | 33 +++++++++-- include/ginkgo/core/matrix/batch_ell.hpp | 35 ++++++++++-- omp/base/batch_multi_vector_kernels.cpp | 13 +++-- omp/matrix/batch_csr_kernels.cpp | 8 +-- omp/matrix/batch_dense_kernels.cpp | 12 ++-- omp/matrix/batch_ell_kernels.cpp | 8 +-- omp/preconditioner/batch_jacobi_kernels.cpp | 4 +- omp/solver/batch_bicgstab_kernels.cpp | 3 +- omp/solver/batch_cg_kernels.cpp | 3 +- reference/base/batch_multi_vector_kernels.cpp | 13 +++-- reference/matrix/batch_csr_kernels.cpp | 8 +-- reference/matrix/batch_dense_kernels.cpp | 12 ++-- reference/matrix/batch_ell_kernels.cpp | 8 +-- .../preconditioner/batch_jacobi_kernels.cpp | 4 +- reference/solver/batch_bicgstab_kernels.cpp | 3 +- reference/solver/batch_cg_kernels.cpp | 3 +- 49 files changed, 526 insertions(+), 295 deletions(-) diff --git a/common/cuda_hip/base/batch_multi_vector_kernels.cpp b/common/cuda_hip/base/batch_multi_vector_kernels.cpp index 17f65487464..b71fa14f4b3 100644 --- a/common/cuda_hip/base/batch_multi_vector_kernels.cpp +++ b/common/cuda_hip/base/batch_multi_vector_kernels.cpp @@ -55,7 +55,7 @@ void scale(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_MULTI_VECTOR_SCALE_KERNEL); @@ -81,7 +81,7 @@ void add_scaled(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_MULTI_VECTOR_ADD_SCALED_KERNEL); @@ -101,7 +101,7 @@ void compute_dot(std::shared_ptr exec, x_ub, y_ub, res_ub, [] __device__(auto val) { return val; }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_DOT_KERNEL); @@ -121,7 +121,7 @@ void compute_conj_dot(std::shared_ptr exec, x_ub, y_ub, res_ub, [] __device__(auto val) { return conj(val); }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_CONJ_DOT_KERNEL); @@ -139,7 +139,7 @@ void compute_norm2(std::shared_ptr exec, x_ub, res_ub); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_NORM2_KERNEL); @@ -156,7 +156,8 @@ void copy(std::shared_ptr exec, x_ub, result_ub); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL); } // namespace batch_multi_vector diff --git a/common/cuda_hip/matrix/batch_csr_kernels.cpp b/common/cuda_hip/matrix/batch_csr_kernels.cpp index d48cdbaf32a..0db100363b8 100644 --- a/common/cuda_hip/matrix/batch_csr_kernels.cpp +++ b/common/cuda_hip/matrix/batch_csr_kernels.cpp @@ -46,7 +46,7 @@ void simple_apply(std::shared_ptr exec, } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_CSR_SIMPLE_APPLY_KERNEL); @@ -72,7 +72,7 @@ void advanced_apply(std::shared_ptr exec, alpha_ub, mat_ub, b_ub, beta_ub, x_ub); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_CSR_ADVANCED_APPLY_KERNEL); @@ -91,7 +91,7 @@ void scale(std::shared_ptr exec, mat_ub); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_CSR_SCALE_KERNEL); @@ -110,7 +110,7 @@ void add_scaled_identity(std::shared_ptr exec, alpha_ub, beta_ub, mat_ub); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_CSR_ADD_SCALED_IDENTITY_KERNEL); diff --git a/common/cuda_hip/matrix/batch_dense_kernels.cpp b/common/cuda_hip/matrix/batch_dense_kernels.cpp index ee4d87abaa3..e0f1fc5e8dc 100644 --- a/common/cuda_hip/matrix/batch_dense_kernels.cpp +++ b/common/cuda_hip/matrix/batch_dense_kernels.cpp @@ -45,7 +45,7 @@ void simple_apply(std::shared_ptr exec, mat_ub, b_ub, x_ub); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_DENSE_SIMPLE_APPLY_KERNEL); @@ -71,7 +71,7 @@ void advanced_apply(std::shared_ptr exec, alpha_ub, mat_ub, b_ub, beta_ub, x_ub); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_DENSE_ADVANCED_APPLY_KERNEL); @@ -90,7 +90,8 @@ void scale(std::shared_ptr exec, mat_ub); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_SCALE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_DENSE_SCALE_KERNEL); template @@ -108,7 +109,8 @@ void scale_add(std::shared_ptr exec, alpha_ub, mat_ub, in_out_ub); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_SCALE_ADD_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_DENSE_SCALE_ADD_KERNEL); template @@ -126,7 +128,7 @@ void add_scaled_identity(std::shared_ptr exec, alpha_ub, beta_ub, mat_ub); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_DENSE_ADD_SCALED_IDENTITY_KERNEL); diff --git a/common/cuda_hip/matrix/batch_ell_kernels.cpp b/common/cuda_hip/matrix/batch_ell_kernels.cpp index 38d34707d45..dddb53e34ff 100644 --- a/common/cuda_hip/matrix/batch_ell_kernels.cpp +++ b/common/cuda_hip/matrix/batch_ell_kernels.cpp @@ -46,7 +46,7 @@ void simple_apply(std::shared_ptr exec, } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_ELL_SIMPLE_APPLY_KERNEL); @@ -72,7 +72,7 @@ void advanced_apply(std::shared_ptr exec, alpha_ub, mat_ub, b_ub, beta_ub, x_ub); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL); @@ -91,7 +91,7 @@ void scale(std::shared_ptr exec, mat_ub); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_ELL_SCALE_KERNEL); @@ -110,7 +110,7 @@ void add_scaled_identity(std::shared_ptr exec, alpha_ub, beta_ub, mat_ub); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_ELL_ADD_SCALED_IDENTITY_KERNEL); diff --git a/core/base/batch_multi_vector.cpp b/core/base/batch_multi_vector.cpp index f4485377f25..1eb3cd8f60d 100644 --- a/core/base/batch_multi_vector.cpp +++ b/core/base/batch_multi_vector.cpp @@ -281,7 +281,7 @@ void MultiVector::compute_norm2( template void MultiVector::convert_to( - MultiVector>* result) const + MultiVector>* result) const { result->values_ = this->values_; result->set_size(this->get_size()); @@ -290,14 +290,35 @@ void MultiVector::convert_to( template void MultiVector::move_to( - MultiVector>* result) + MultiVector>* result) { this->convert_to(result); } +#if GINKGO_ENABLE_HALF +template +void MultiVector::convert_to( + MultiVector>>* + result) const +{ + result->values_ = this->values_; + result->set_size(this->get_size()); +} + + +template +void MultiVector::move_to( + MultiVector>>* + result) +{ + this->convert_to(result); +} +#endif + + #define GKO_DECLARE_BATCH_MULTI_VECTOR(_type) class MultiVector<_type> -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_MULTI_VECTOR); } // namespace batch diff --git a/core/device_hooks/common_kernels.inc.cpp b/core/device_hooks/common_kernels.inc.cpp index b26697cd6a9..6cdbe1348ac 100644 --- a/core/device_hooks/common_kernels.inc.cpp +++ b/core/device_hooks/common_kernels.inc.cpp @@ -341,12 +341,15 @@ GKO_STUB_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(GKO_DECLARE_SEPARATE_LOCAL_NONLOCAL); namespace batch_multi_vector { -GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_SCALE_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_ADD_SCALED_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_DOT_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_CONJ_DOT_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_NORM2_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_MULTI_VECTOR_SCALE_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_MULTI_VECTOR_ADD_SCALED_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_DOT_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_CONJ_DOT_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_NORM2_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL); } // namespace batch_multi_vector @@ -355,10 +358,13 @@ GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL); namespace batch_csr { -GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_CSR_SIMPLE_APPLY_KERNEL); -GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_CSR_ADVANCED_APPLY_KERNEL); -GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_CSR_SCALE_KERNEL); -GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_CSR_ADD_SCALED_IDENTITY_KERNEL); +GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_CSR_SIMPLE_APPLY_KERNEL); +GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_CSR_ADVANCED_APPLY_KERNEL); +GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF(GKO_DECLARE_BATCH_CSR_SCALE_KERNEL); +GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_CSR_ADD_SCALED_IDENTITY_KERNEL); } // namespace batch_csr @@ -367,11 +373,12 @@ GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_CSR_ADD_SCALED_IDENTITY_KERNEL); namespace batch_dense { -GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_SIMPLE_APPLY_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_ADVANCED_APPLY_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_SCALE_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_SCALE_ADD_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_ADD_SCALED_IDENTITY_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_DENSE_SIMPLE_APPLY_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_DENSE_ADVANCED_APPLY_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_DENSE_SCALE_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_DENSE_SCALE_ADD_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_DENSE_ADD_SCALED_IDENTITY_KERNEL); } // namespace batch_dense @@ -380,10 +387,13 @@ GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_ADD_SCALED_IDENTITY_KERNEL); namespace batch_ell { -GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_ELL_SIMPLE_APPLY_KERNEL); -GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL); -GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_ELL_SCALE_KERNEL); -GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_ELL_ADD_SCALED_IDENTITY_KERNEL); +GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_ELL_SIMPLE_APPLY_KERNEL); +GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL); +GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF(GKO_DECLARE_BATCH_ELL_SCALE_KERNEL); +GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_ELL_ADD_SCALED_IDENTITY_KERNEL); } // namespace batch_ell @@ -506,7 +516,7 @@ GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( namespace batch_bicgstab { -GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); } // namespace batch_bicgstab @@ -515,7 +525,7 @@ GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); namespace batch_cg { -GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_CG_APPLY_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_CG_APPLY_KERNEL); } // namespace batch_cg @@ -916,9 +926,10 @@ namespace batch_jacobi { GKO_STUB_INDEX_TYPE( GKO_DECLARE_BATCH_BLOCK_JACOBI_COMPUTE_CUMULATIVE_BLOCK_STORAGE); GKO_STUB_INDEX_TYPE(GKO_DECLARE_BATCH_BLOCK_JACOBI_FIND_ROW_BLOCK_MAP); -GKO_STUB_VALUE_AND_INT32_TYPE( +GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_BLOCK_JACOBI_EXTRACT_PATTERN_KERNEL); -GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_BLOCK_JACOBI_COMPUTE_KERNEL); +GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_BLOCK_JACOBI_COMPUTE_KERNEL); } // namespace batch_jacobi diff --git a/core/log/batch_logger.cpp b/core/log/batch_logger.cpp index 286803c0ae1..da4e715aa20 100644 --- a/core/log/batch_logger.cpp +++ b/core/log/batch_logger.cpp @@ -65,7 +65,7 @@ log_data::log_data(std::shared_ptr exec, #define GKO_DECLARE_LOG_DATA(_type) class log_data<_type> -GKO_INSTANTIATE_FOR_EACH_NON_COMPLEX_VALUE_TYPE(GKO_DECLARE_LOG_DATA); +GKO_INSTANTIATE_FOR_EACH_NON_COMPLEX_VALUE_TYPE_WITH_HALF(GKO_DECLARE_LOG_DATA); #undef GKO_DECLARE_LOG_DATA @@ -92,7 +92,7 @@ void BatchConvergence::on_batch_solver_completed( #define GKO_DECLARE_BATCH_CONVERGENCE(_type) class BatchConvergence<_type> -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_CONVERGENCE); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_CONVERGENCE); } // namespace log diff --git a/core/matrix/batch_csr.cpp b/core/matrix/batch_csr.cpp index 1b1dc22a6c4..141c5b86d02 100644 --- a/core/matrix/batch_csr.cpp +++ b/core/matrix/batch_csr.cpp @@ -246,7 +246,7 @@ void Csr::add_scaled_identity( template void Csr::convert_to( - Csr, IndexType>* result) const + Csr, IndexType>* result) const { result->values_ = this->values_; result->col_idxs_ = this->col_idxs_; @@ -257,14 +257,37 @@ void Csr::convert_to( template void Csr::move_to( - Csr, IndexType>* result) + Csr, IndexType>* result) { this->convert_to(result); } +#if GINKGO_ENABLE_HALF +template +void Csr::convert_to( + Csr>, + IndexType>* result) const +{ + result->values_ = this->values_; + result->col_idxs_ = this->col_idxs_; + result->row_ptrs_ = this->row_ptrs_; + result->set_size(this->get_size()); +} + + +template +void Csr::move_to( + Csr>, + IndexType>* result) +{ + this->convert_to(result); +} +#endif + + #define GKO_DECLARE_BATCH_CSR_MATRIX(ValueType) class Csr -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_CSR_MATRIX); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_CSR_MATRIX); } // namespace matrix diff --git a/core/matrix/batch_dense.cpp b/core/matrix/batch_dense.cpp index 6390a4c7ad0..0c1838abb56 100644 --- a/core/matrix/batch_dense.cpp +++ b/core/matrix/batch_dense.cpp @@ -245,7 +245,7 @@ void Dense::add_scaled_identity( template void Dense::convert_to( - Dense>* result) const + Dense>* result) const { result->values_ = this->values_; result->set_size(this->get_size()); @@ -253,14 +253,36 @@ void Dense::convert_to( template -void Dense::move_to(Dense>* result) +void Dense::move_to( + Dense>* result) { this->convert_to(result); } +#if GINKGO_ENABLE_HALF +template +void Dense::convert_to( + Dense>>* + result) const +{ + result->values_ = this->values_; + result->set_size(this->get_size()); +} + + +template +void Dense::move_to( + Dense>>* + result) +{ + this->convert_to(result); +} +#endif + + #define GKO_DECLARE_BATCH_DENSE_MATRIX(_type) class Dense<_type> -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_MATRIX); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_DENSE_MATRIX); } // namespace matrix diff --git a/core/matrix/batch_ell.cpp b/core/matrix/batch_ell.cpp index 3722c41de60..3b829d3ba4c 100644 --- a/core/matrix/batch_ell.cpp +++ b/core/matrix/batch_ell.cpp @@ -266,7 +266,7 @@ void Ell::add_scaled_identity( template void Ell::convert_to( - Ell, IndexType>* result) const + Ell, IndexType>* result) const { result->values_ = this->values_; result->col_idxs_ = this->col_idxs_; @@ -277,14 +277,37 @@ void Ell::convert_to( template void Ell::move_to( - Ell, IndexType>* result) + Ell, IndexType>* result) { this->convert_to(result); } +#if GINKGO_ENABLE_HALF +template +void Ell::convert_to( + Ell>, + IndexType>* result) const +{ + result->values_ = this->values_; + result->col_idxs_ = this->col_idxs_; + result->num_elems_per_row_ = this->num_elems_per_row_; + result->set_size(this->get_size()); +} + + +template +void Ell::move_to( + Ell>, + IndexType>* result) +{ + this->convert_to(result); +} +#endif + + #define GKO_DECLARE_BATCH_ELL_MATRIX(ValueType) class Ell -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_ELL_MATRIX); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_ELL_MATRIX); } // namespace matrix diff --git a/core/matrix/batch_identity.cpp b/core/matrix/batch_identity.cpp index 2220120d00b..6ee2d55f6fe 100644 --- a/core/matrix/batch_identity.cpp +++ b/core/matrix/batch_identity.cpp @@ -113,7 +113,8 @@ void Identity::apply_impl(const MultiVector* alpha, #define GKO_DECLARE_BATCH_IDENTITY_MATRIX(ValueType) class Identity -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_IDENTITY_MATRIX); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_IDENTITY_MATRIX); } // namespace matrix diff --git a/core/preconditioner/batch_jacobi.cpp b/core/preconditioner/batch_jacobi.cpp index f92ccd18cfc..a9d173ee9ac 100644 --- a/core/preconditioner/batch_jacobi.cpp +++ b/core/preconditioner/batch_jacobi.cpp @@ -175,7 +175,7 @@ void Jacobi::generate_precond( #define GKO_DECLARE_BATCH_JACOBI(_type) class Jacobi<_type, int32> -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_JACOBI); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_JACOBI); } // namespace preconditioner diff --git a/core/solver/batch_bicgstab.cpp b/core/solver/batch_bicgstab.cpp index c22c712b411..7e3e5330631 100644 --- a/core/solver/batch_bicgstab.cpp +++ b/core/solver/batch_bicgstab.cpp @@ -57,7 +57,7 @@ void Bicgstab::solver_apply( #define GKO_DECLARE_BATCH_BICGSTAB(_type) class Bicgstab<_type> -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_BICGSTAB); } // namespace solver diff --git a/core/solver/batch_cg.cpp b/core/solver/batch_cg.cpp index 0ab1ca8564f..607a6311e71 100644 --- a/core/solver/batch_cg.cpp +++ b/core/solver/batch_cg.cpp @@ -55,7 +55,7 @@ void Cg::solver_apply( #define GKO_DECLARE_BATCH_CG(_type) class Cg<_type> -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_CG); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_CG); } // namespace solver diff --git a/core/solver/batch_dispatch.hpp b/core/solver/batch_dispatch.hpp index 018a6674df5..72873fb91ec 100644 --- a/core/solver/batch_dispatch.hpp +++ b/core/solver/batch_dispatch.hpp @@ -85,6 +85,23 @@ using DeviceValueType = gko::kernels::hip::hip_type; #include "dpcpp/stop/batch_criteria.hpp" +namespace gko { +namespace kernels { +namespace dpcpp { + + +template +inline std::decay_t as_device_type(T val) +{ + return val; +} + + +} // namespace dpcpp +} // namespace kernels +} // namespace gko + + namespace gko { namespace batch { namespace solver { @@ -114,6 +131,23 @@ using DeviceValueType = ValueType; #include "reference/stop/batch_criteria.hpp" +namespace gko { +namespace kernels { +namespace host { + + +template +inline std::decay_t as_device_type(T val) +{ + return val; +} + + +} // namespace host +} // namespace kernels +} // namespace gko + + namespace gko { namespace batch { namespace solver { @@ -181,6 +215,7 @@ class batch_solver_dispatch { using value_type = ValueType; using device_value_type = DeviceValueType; using real_type = remove_complex; + using device_real_type = DeviceValueType; batch_solver_dispatch( const KernelCaller& kernel_caller, const SettingsType& settings, @@ -270,8 +305,9 @@ class batch_solver_dispatch { { if (logger_type_ == log::detail::log_type::simple_convergence_completion) { - device::batch_log::SimpleFinalLogger logger( - log_data.res_norms.get_data(), log_data.iter_counts.get_data()); + device::batch_log::SimpleFinalLogger logger( + device::as_device_type(log_data.res_norms.get_data()), + log_data.iter_counts.get_data()); dispatch_on_preconditioner(logger, amat, b_item, x_item); } else { GKO_NOT_IMPLEMENTED; diff --git a/cuda/preconditioner/batch_jacobi_kernels.cu b/cuda/preconditioner/batch_jacobi_kernels.cu index 8768937dc6d..41f861cc3af 100644 --- a/cuda/preconditioner/batch_jacobi_kernels.cu +++ b/cuda/preconditioner/batch_jacobi_kernels.cu @@ -100,7 +100,7 @@ void extract_common_blocks_pattern( blocks_pattern); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_BLOCK_JACOBI_EXTRACT_PATTERN_KERNEL); @@ -158,7 +158,7 @@ void compute_block_jacobi( cumulative_block_storage, block_pointers, blocks_pattern, blocks); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_BLOCK_JACOBI_COMPUTE_KERNEL); diff --git a/cuda/solver/batch_bicgstab_kernels.cu b/cuda/solver/batch_bicgstab_kernels.cu index 8a5eee6b196..991ba9c2dc4 100644 --- a/cuda/solver/batch_bicgstab_kernels.cu +++ b/cuda/solver/batch_bicgstab_kernels.cu @@ -75,13 +75,13 @@ template using settings = gko::kernels::batch_bicgstab::settings; -template +template class kernel_caller { public: - using value_type = CuValueType; + using cu_value_type = cuda_type; kernel_caller(std::shared_ptr exec, - const settings> settings) + const settings> settings) : exec_{std::move(exec)}, settings_{settings} {} @@ -91,16 +91,17 @@ public: void launch_apply_kernel( const gko::kernels::batch_bicgstab::storage_config& sconf, LogType& logger, PrecType& prec, const BatchMatrixType& mat, - const value_type* const __restrict__ b_values, - value_type* const __restrict__ x_values, - value_type* const __restrict__ workspace_data, const int& block_size, + const cu_value_type* const __restrict__ b_values, + cu_value_type* const __restrict__ x_values, + cu_value_type* const __restrict__ workspace_data, const int& block_size, const size_t& shared_size) const { batch_single_kernels::apply_kernel <<get_stream()>>>(sconf, settings_.max_iterations, - settings_.residual_tol, logger, prec, mat, - b_values, x_values, workspace_data); + as_cuda_type(settings_.residual_tol), + logger, prec, mat, b_values, x_values, + workspace_data); } @@ -108,21 +109,20 @@ public: typename LogType> void call_kernel( LogType logger, const BatchMatrixType& mat, PrecType prec, - const gko::batch::multi_vector::uniform_batch& b, - const gko::batch::multi_vector::uniform_batch& x) const + const gko::batch::multi_vector::uniform_batch& b, + const gko::batch::multi_vector::uniform_batch& x) const { - using real_type = gko::remove_complex; + using real_type = gko::remove_complex; const size_type num_batch_items = mat.num_batch_items; constexpr int align_multiple = 8; const int padded_num_rows = ceildiv(mat.num_rows, align_multiple) * align_multiple; - const int shmem_per_blk = - get_max_dynamic_shared_memory(exec_); + const int shmem_per_blk = get_max_dynamic_shared_memory< + StopType, PrecType, LogType, BatchMatrixType, cu_value_type>(exec_); // TODO const int block_size = 256; // get_num_threads_per_block( + // BatchMatrixType, cu_value_type>( // exec_, mat.num_rows); GKO_ASSERT(block_size >= 2 * config::warp_size); @@ -130,18 +130,18 @@ public: padded_num_rows, mat.get_single_item_num_nnz()); const auto sconf = gko::kernels::batch_bicgstab::compute_shared_storage( + cu_value_type>( shmem_per_blk, padded_num_rows, mat.get_single_item_num_nnz(), b.num_rhs); const size_t shared_size = - sconf.n_shared * padded_num_rows * sizeof(value_type) + + sconf.n_shared * padded_num_rows * sizeof(cu_value_type) + (sconf.prec_shared ? prec_size : 0); - auto workspace = gko::array( + auto workspace = gko::array( exec_, - sconf.gmem_stride_bytes * num_batch_items / sizeof(value_type)); - GKO_ASSERT(sconf.gmem_stride_bytes % sizeof(value_type) == 0); + sconf.gmem_stride_bytes * num_batch_items / sizeof(cu_value_type)); + GKO_ASSERT(sconf.gmem_stride_bytes % sizeof(cu_value_type) == 0); - value_type* const workspace_data = workspace.get_data(); + cu_value_type* const workspace_data = workspace.get_data(); // TODO: split compilation // Template parameters launch_apply_kernel exec_; - const settings> settings_; + const settings> settings_; }; @@ -223,13 +223,13 @@ void apply(std::shared_ptr exec, batch::MultiVector* const x, batch::log::detail::log_data>& logdata) { - using cu_value_type = cuda_type; auto dispatcher = batch::solver::create_dispatcher( - kernel_caller(exec, settings), settings, mat, precon); + kernel_caller(exec, settings), settings, mat, precon); dispatcher.apply(b, x, logdata); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); } // namespace batch_bicgstab diff --git a/cuda/solver/batch_cg_kernels.cu b/cuda/solver/batch_cg_kernels.cu index 32e66d7ee54..e9656a3b8a8 100644 --- a/cuda/solver/batch_cg_kernels.cu +++ b/cuda/solver/batch_cg_kernels.cu @@ -75,13 +75,14 @@ template using settings = gko::kernels::batch_cg::settings; -template +template class kernel_caller { public: - using value_type = CuValueType; + using cu_value_type = cuda_type; + ; kernel_caller(std::shared_ptr exec, - const settings> settings) + const settings> settings) : exec_{std::move(exec)}, settings_{settings} {} @@ -91,36 +92,36 @@ public: void launch_apply_kernel( const gko::kernels::batch_cg::storage_config& sconf, LogType& logger, PrecType& prec, const BatchMatrixType& mat, - const value_type* const __restrict__ b_values, - value_type* const __restrict__ x_values, - value_type* const __restrict__ workspace_data, const int& block_size, + const cu_value_type* const __restrict__ b_values, + cu_value_type* const __restrict__ x_values, + cu_value_type* const __restrict__ workspace_data, const int& block_size, const size_t& shared_size) const { batch_single_kernels::apply_kernel <<get_stream()>>>(sconf, settings_.max_iterations, - settings_.residual_tol, logger, prec, mat, - b_values, x_values, workspace_data); + as_cuda_type(settings_.residual_tol), + logger, prec, mat, b_values, x_values, + workspace_data); } template void call_kernel( LogType logger, const BatchMatrixType& mat, PrecType prec, - const gko::batch::multi_vector::uniform_batch& b, - const gko::batch::multi_vector::uniform_batch& x) const + const gko::batch::multi_vector::uniform_batch& b, + const gko::batch::multi_vector::uniform_batch& x) const { - using real_type = gko::remove_complex; + using real_type = gko::remove_complex; const size_type num_batch_items = mat.num_batch_items; constexpr int align_multiple = 8; const int padded_num_rows = ceildiv(mat.num_rows, align_multiple) * align_multiple; - const int shmem_per_blk = - get_max_dynamic_shared_memory(exec_); + const int shmem_per_blk = get_max_dynamic_shared_memory< + StopType, PrecType, LogType, BatchMatrixType, cu_value_type>(exec_); const int block_size = get_num_threads_per_block( + BatchMatrixType, cu_value_type>( exec_, mat.num_rows); GKO_ASSERT(block_size >= 2 * config::warp_size); @@ -128,18 +129,18 @@ public: padded_num_rows, mat.get_single_item_num_nnz()); const auto sconf = gko::kernels::batch_cg::compute_shared_storage( + cu_value_type>( shmem_per_blk, padded_num_rows, mat.get_single_item_num_nnz(), b.num_rhs); const size_t shared_size = - sconf.n_shared * padded_num_rows * sizeof(value_type) + + sconf.n_shared * padded_num_rows * sizeof(cu_value_type) + (sconf.prec_shared ? prec_size : 0); - auto workspace = gko::array( + auto workspace = gko::array( exec_, - sconf.gmem_stride_bytes * num_batch_items / sizeof(value_type)); - GKO_ASSERT(sconf.gmem_stride_bytes % sizeof(value_type) == 0); + sconf.gmem_stride_bytes * num_batch_items / sizeof(cu_value_type)); + GKO_ASSERT(sconf.gmem_stride_bytes % sizeof(cu_value_type) == 0); - value_type* const workspace_data = workspace.get_data(); + cu_value_type* const workspace_data = workspace.get_data(); // TODO: split compilation // Only instantiate when full optimizations has been enabled. Otherwise, @@ -190,7 +191,7 @@ public: private: std::shared_ptr exec_; - const settings> settings_; + const settings> settings_; }; @@ -203,13 +204,13 @@ void apply(std::shared_ptr exec, batch::MultiVector* const x, batch::log::detail::log_data>& logdata) { - using cu_value_type = cuda_type; auto dispatcher = batch::solver::create_dispatcher( - kernel_caller(exec, settings), settings, mat, precon); + kernel_caller(exec, settings), settings, mat, precon); dispatcher.apply(b, x, logdata); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_CG_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_CG_APPLY_KERNEL); } // namespace batch_cg diff --git a/dpcpp/base/batch_multi_vector_kernels.dp.cpp b/dpcpp/base/batch_multi_vector_kernels.dp.cpp index 0d2662bdccd..db226d08eee 100644 --- a/dpcpp/base/batch_multi_vector_kernels.dp.cpp +++ b/dpcpp/base/batch_multi_vector_kernels.dp.cpp @@ -102,7 +102,7 @@ void scale(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_MULTI_VECTOR_SCALE_KERNEL); @@ -161,7 +161,7 @@ void add_scaled(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_MULTI_VECTOR_ADD_SCALED_KERNEL); @@ -230,7 +230,7 @@ void compute_dot(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_DOT_KERNEL); @@ -275,7 +275,7 @@ void compute_conj_dot(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_CONJ_DOT_KERNEL); @@ -334,7 +334,7 @@ void compute_norm2(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_NORM2_KERNEL); @@ -372,7 +372,8 @@ void copy(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL); } // namespace batch_multi_vector diff --git a/dpcpp/base/batch_multi_vector_kernels.hpp b/dpcpp/base/batch_multi_vector_kernels.hpp index 74abaeda86f..96ada23f42c 100644 --- a/dpcpp/base/batch_multi_vector_kernels.hpp +++ b/dpcpp/base/batch_multi_vector_kernels.hpp @@ -65,25 +65,6 @@ __dpct_inline__ void add_scaled_kernel( } -template -__dpct_inline__ void single_rhs_compute_conj_dot( - const int num_rows, const ValueType* const __restrict__ x, - const ValueType* const __restrict__ y, ValueType& result, - sycl::nd_item<3> item_ct1) -{ - const auto group = item_ct1.get_group(); - const auto group_size = item_ct1.get_local_range().size(); - const auto tid = item_ct1.get_local_linear_id(); - - ValueType val = zero(); - - for (int r = tid; r < num_rows; r += group_size) { - val += conj(x[r]) * y[r]; - } - result = sycl::reduce_over_group(group, val, sycl::plus<>()); -} - - template __dpct_inline__ void single_rhs_compute_conj_dot_sg( const int num_rows, const ValueType* const __restrict__ x, @@ -174,28 +155,6 @@ __dpct_inline__ void single_rhs_compute_norm2_sg( } -template -__dpct_inline__ void single_rhs_compute_norm2( - const int num_rows, const ValueType* const __restrict__ x, - gko::remove_complex& result, sycl::nd_item<3> item_ct1) -{ - const auto group = item_ct1.get_group(); - const auto group_size = item_ct1.get_local_range().size(); - const auto tid = item_ct1.get_local_linear_id(); - - using real_type = typename gko::remove_complex; - real_type val = zero(); - - for (int r = tid; r < num_rows; r += group_size) { - val += squared_norm(x[r]); - } - - val = sycl::reduce_over_group(group, val, sycl::plus<>()); - - result = sqrt(val); -} - - template __dpct_inline__ void compute_norm2_kernel( const gko::batch::multi_vector::batch_item& x, diff --git a/dpcpp/matrix/batch_csr_kernels.dp.cpp b/dpcpp/matrix/batch_csr_kernels.dp.cpp index 1759a959299..ae5122ec7f9 100644 --- a/dpcpp/matrix/batch_csr_kernels.dp.cpp +++ b/dpcpp/matrix/batch_csr_kernels.dp.cpp @@ -73,7 +73,7 @@ void simple_apply(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_CSR_SIMPLE_APPLY_KERNEL); @@ -127,7 +127,7 @@ void advanced_apply(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_CSR_ADVANCED_APPLY_KERNEL); @@ -173,7 +173,7 @@ void scale(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_CSR_SCALE_KERNEL); @@ -215,7 +215,7 @@ void add_scaled_identity(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_CSR_ADD_SCALED_IDENTITY_KERNEL); diff --git a/dpcpp/matrix/batch_dense_kernels.dp.cpp b/dpcpp/matrix/batch_dense_kernels.dp.cpp index 43974589abb..6c0e4b4eb44 100644 --- a/dpcpp/matrix/batch_dense_kernels.dp.cpp +++ b/dpcpp/matrix/batch_dense_kernels.dp.cpp @@ -76,7 +76,7 @@ void simple_apply(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_DENSE_SIMPLE_APPLY_KERNEL); @@ -129,7 +129,7 @@ void advanced_apply(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_DENSE_ADVANCED_APPLY_KERNEL); @@ -173,7 +173,8 @@ void scale(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_SCALE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_DENSE_SCALE_KERNEL); template @@ -215,7 +216,8 @@ void scale_add(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_SCALE_ADD_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_DENSE_SCALE_ADD_KERNEL); template @@ -256,7 +258,7 @@ void add_scaled_identity(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_DENSE_ADD_SCALED_IDENTITY_KERNEL); diff --git a/dpcpp/matrix/batch_ell_kernels.dp.cpp b/dpcpp/matrix/batch_ell_kernels.dp.cpp index d9b819b101e..b4e2627a494 100644 --- a/dpcpp/matrix/batch_ell_kernels.dp.cpp +++ b/dpcpp/matrix/batch_ell_kernels.dp.cpp @@ -73,7 +73,7 @@ void simple_apply(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_ELL_SIMPLE_APPLY_KERNEL); @@ -127,7 +127,7 @@ void advanced_apply(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL); @@ -170,7 +170,7 @@ void scale(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_ELL_SCALE_KERNEL); @@ -212,7 +212,7 @@ void add_scaled_identity(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_ELL_ADD_SCALED_IDENTITY_KERNEL); diff --git a/dpcpp/preconditioner/batch_block_jacobi.hpp b/dpcpp/preconditioner/batch_block_jacobi.hpp index a7431f919a5..04c21f97991 100644 --- a/dpcpp/preconditioner/batch_block_jacobi.hpp +++ b/dpcpp/preconditioner/batch_block_jacobi.hpp @@ -129,8 +129,11 @@ class BlockJacobi final { sum += block_val * r[dense_block_col + idx_start]; } - // reduction - sum = sycl::reduce_over_group(sg, sum, sycl::plus<>()); + // reduction (it does not support half) + // sum = sycl::reduce_over_group(sg, sum, sycl::plus<>()); + for (int i = sg_size / 2; i > 0; i /= 2) { + sum += sg.shuffle_down(sum, i); + } if (sg_tid == 0) { z[row_idx] = sum; diff --git a/dpcpp/preconditioner/batch_jacobi_kernels.dp.cpp b/dpcpp/preconditioner/batch_jacobi_kernels.dp.cpp index d85f93e74f2..bdf1502492c 100644 --- a/dpcpp/preconditioner/batch_jacobi_kernels.dp.cpp +++ b/dpcpp/preconditioner/batch_jacobi_kernels.dp.cpp @@ -105,7 +105,7 @@ void extract_common_blocks_pattern( }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_BLOCK_JACOBI_EXTRACT_PATTERN_KERNEL); @@ -175,7 +175,7 @@ void compute_block_jacobi( cumulative_block_storage, block_pointers, blocks_pattern, blocks, exec); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_BLOCK_JACOBI_COMPUTE_KERNEL); diff --git a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp index 7036b770f1b..9a335103786 100644 --- a/dpcpp/solver/batch_bicgstab_kernels.dp.cpp +++ b/dpcpp/solver/batch_bicgstab_kernels.dp.cpp @@ -250,7 +250,8 @@ void apply(std::shared_ptr exec, } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); } // namespace batch_bicgstab diff --git a/dpcpp/solver/batch_cg_kernels.dp.cpp b/dpcpp/solver/batch_cg_kernels.dp.cpp index 9d3aa14ab2c..e587c1d4143 100644 --- a/dpcpp/solver/batch_cg_kernels.dp.cpp +++ b/dpcpp/solver/batch_cg_kernels.dp.cpp @@ -224,7 +224,8 @@ void apply(std::shared_ptr exec, } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_CG_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_CG_APPLY_KERNEL); } // namespace batch_cg diff --git a/hip/preconditioner/batch_jacobi_kernels.hip.cpp b/hip/preconditioner/batch_jacobi_kernels.hip.cpp index 2380bc6a0bd..df2cdb1ed75 100644 --- a/hip/preconditioner/batch_jacobi_kernels.hip.cpp +++ b/hip/preconditioner/batch_jacobi_kernels.hip.cpp @@ -102,7 +102,7 @@ void extract_common_blocks_pattern( blocks_pattern); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_BLOCK_JACOBI_EXTRACT_PATTERN_KERNEL); @@ -161,7 +161,7 @@ void compute_block_jacobi( cumulative_block_storage, block_pointers, blocks_pattern, blocks); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_BLOCK_JACOBI_COMPUTE_KERNEL); diff --git a/hip/solver/batch_bicgstab_kernels.hip.cpp b/hip/solver/batch_bicgstab_kernels.hip.cpp index 17199d2cd19..f53228bad79 100644 --- a/hip/solver/batch_bicgstab_kernels.hip.cpp +++ b/hip/solver/batch_bicgstab_kernels.hip.cpp @@ -55,13 +55,13 @@ template using settings = gko::kernels::batch_bicgstab::settings; -template +template class kernel_caller { public: - using value_type = HipValueType; + using hip_value_type = hip_type; kernel_caller(std::shared_ptr exec, - const settings> settings) + const settings> settings) : exec_{exec}, settings_{settings} {} @@ -71,16 +71,17 @@ class kernel_caller { void launch_apply_kernel( const gko::kernels::batch_bicgstab::storage_config& sconf, LogType& logger, PrecType& prec, const BatchMatrixType& mat, - const value_type* const __restrict__ b_values, - value_type* const __restrict__ x_values, - value_type* const __restrict__ workspace_data, const int& block_size, - const size_t& shared_size) const + const hip_value_type* const __restrict__ b_values, + hip_value_type* const __restrict__ x_values, + hip_value_type* const __restrict__ workspace_data, + const int& block_size, const size_t& shared_size) const { batch_single_kernels::apply_kernel <<get_stream()>>>(sconf, settings_.max_iterations, - settings_.residual_tol, logger, prec, mat, - b_values, x_values, workspace_data); + as_hip_type(settings_.residual_tol), + logger, prec, mat, b_values, x_values, + workspace_data); } @@ -88,10 +89,10 @@ class kernel_caller { typename LogType> void call_kernel( LogType logger, const BatchMatrixType& mat, PrecType prec, - const gko::batch::multi_vector::uniform_batch& b, - const gko::batch::multi_vector::uniform_batch& x) const + const gko::batch::multi_vector::uniform_batch& b, + const gko::batch::multi_vector::uniform_batch& x) const { - using real_type = gko::remove_complex; + using real_type = gko::remove_complex; const size_type num_batch_items = mat.num_batch_items; constexpr int align_multiple = 8; const int padded_num_rows = @@ -109,22 +110,20 @@ class kernel_caller { // Returns amount required in bytes const size_t prec_size = PrecType::dynamic_work_size( padded_num_rows, mat.get_single_item_num_nnz()); - const auto sconf = - gko::kernels::batch_bicgstab::compute_shared_storage( - shmem_per_blk, padded_num_rows, mat.get_single_item_num_nnz(), - b.num_rhs); + const auto sconf = gko::kernels::batch_bicgstab::compute_shared_storage< + PrecType, hip_value_type>(shmem_per_blk, padded_num_rows, + mat.get_single_item_num_nnz(), b.num_rhs); const size_t shared_size = - sconf.n_shared * padded_num_rows * sizeof(value_type) + + sconf.n_shared * padded_num_rows * sizeof(hip_value_type) + (sconf.prec_shared ? prec_size : 0); - auto workspace = gko::array( + auto workspace = gko::array( exec_, - sconf.gmem_stride_bytes * num_batch_items / sizeof(value_type)); + sconf.gmem_stride_bytes * num_batch_items / sizeof(hip_value_type)); bool is_stride_aligned = - sconf.gmem_stride_bytes % sizeof(value_type) == 0; + sconf.gmem_stride_bytes % sizeof(hip_value_type) == 0; GKO_ASSERT(is_stride_aligned); - value_type* const workspace_data = workspace.get_data(); + hip_value_type* const workspace_data = workspace.get_data(); // Only instantiate when full optimizations has been enabled. Otherwise, // just use the default one with no shared memory. @@ -194,7 +193,7 @@ class kernel_caller { private: std::shared_ptr exec_; - const settings> settings_; + const settings> settings_; }; @@ -207,13 +206,13 @@ void apply(std::shared_ptr exec, batch::MultiVector* const x, batch::log::detail::log_data>& logdata) { - using hip_value_type = hip_type; auto dispatcher = batch::solver::create_dispatcher( - kernel_caller(exec, settings), settings, mat, precon); + kernel_caller(exec, settings), settings, mat, precon); dispatcher.apply(b, x, logdata); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); } // namespace batch_bicgstab diff --git a/hip/solver/batch_cg_kernels.hip.cpp b/hip/solver/batch_cg_kernels.hip.cpp index 6d5e3bff3b3..4583582a216 100644 --- a/hip/solver/batch_cg_kernels.hip.cpp +++ b/hip/solver/batch_cg_kernels.hip.cpp @@ -54,13 +54,13 @@ template using settings = gko::kernels::batch_cg::settings; -template +template class kernel_caller { public: - using value_type = HipValueType; + using hip_value_type = hip_type; kernel_caller(std::shared_ptr exec, - const settings> settings) + const settings> settings) : exec_{exec}, settings_{settings} {} @@ -70,16 +70,17 @@ class kernel_caller { void launch_apply_kernel( const gko::kernels::batch_cg::storage_config& sconf, LogType& logger, PrecType& prec, const BatchMatrixType& mat, - const value_type* const __restrict__ b_values, - value_type* const __restrict__ x_values, - value_type* const __restrict__ workspace_data, const int& block_size, - const size_t& shared_size) const + const hip_value_type* const __restrict__ b_values, + hip_value_type* const __restrict__ x_values, + hip_value_type* const __restrict__ workspace_data, + const int& block_size, const size_t& shared_size) const { batch_single_kernels::apply_kernel <<get_stream()>>>(sconf, settings_.max_iterations, - settings_.residual_tol, logger, prec, mat, - b_values, x_values, workspace_data); + as_hip_type(settings_.residual_tol), + logger, prec, mat, b_values, x_values, + workspace_data); } @@ -87,10 +88,10 @@ class kernel_caller { typename LogType> void call_kernel( LogType logger, const BatchMatrixType& mat, PrecType prec, - const gko::batch::multi_vector::uniform_batch& b, - const gko::batch::multi_vector::uniform_batch& x) const + const gko::batch::multi_vector::uniform_batch& b, + const gko::batch::multi_vector::uniform_batch& x) const { - using real_type = gko::remove_complex; + using real_type = gko::remove_complex; const size_type num_batch_items = mat.num_batch_items; constexpr int align_multiple = 8; const int padded_num_rows = @@ -110,20 +111,20 @@ class kernel_caller { padded_num_rows, mat.get_single_item_num_nnz()); const auto sconf = gko::kernels::batch_cg::compute_shared_storage( + hip_value_type>( shmem_per_blk, padded_num_rows, mat.get_single_item_num_nnz(), b.num_rhs); const size_t shared_size = - sconf.n_shared * padded_num_rows * sizeof(value_type) + + sconf.n_shared * padded_num_rows * sizeof(hip_value_type) + (sconf.prec_shared ? prec_size : 0); - auto workspace = gko::array( + auto workspace = gko::array( exec_, - sconf.gmem_stride_bytes * num_batch_items / sizeof(value_type)); + sconf.gmem_stride_bytes * num_batch_items / sizeof(hip_value_type)); bool is_stride_aligned = - sconf.gmem_stride_bytes % sizeof(value_type) == 0; + sconf.gmem_stride_bytes % sizeof(hip_value_type) == 0; GKO_ASSERT(is_stride_aligned); - value_type* const workspace_data = workspace.get_data(); + hip_value_type* const workspace_data = workspace.get_data(); // Only instantiate when full optimizations has been enabled. Otherwise, // just use the default one with no shared memory. @@ -173,7 +174,7 @@ class kernel_caller { private: std::shared_ptr exec_; - const settings> settings_; + const settings> settings_; }; @@ -186,13 +187,13 @@ void apply(std::shared_ptr exec, batch::MultiVector* const x, batch::log::detail::log_data>& logdata) { - using hip_value_type = hip_type; auto dispatcher = batch::solver::create_dispatcher( - kernel_caller(exec, settings), settings, mat, precon); + kernel_caller(exec, settings), settings, mat, precon); dispatcher.apply(b, x, logdata); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_CG_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_CG_APPLY_KERNEL); } // namespace batch_cg diff --git a/include/ginkgo/core/base/batch_multi_vector.hpp b/include/ginkgo/core/base/batch_multi_vector.hpp index d04e9562fce..bd641f057a1 100644 --- a/include/ginkgo/core/base/batch_multi_vector.hpp +++ b/include/ginkgo/core/base/batch_multi_vector.hpp @@ -52,16 +52,22 @@ template class MultiVector : public EnablePolymorphicObject>, public EnablePolymorphicAssignment>, - public ConvertibleTo>> { +#if GINKGO_ENABLE_HALF + public ConvertibleTo>>>, +#endif + public ConvertibleTo>> { friend class EnablePolymorphicObject; friend class MultiVector>; - friend class MultiVector>; + friend class MultiVector>; public: using EnablePolymorphicAssignment::convert_to; using EnablePolymorphicAssignment::move_to; - using ConvertibleTo>>::convert_to; - using ConvertibleTo>>::move_to; + using ConvertibleTo< + MultiVector>>::convert_to; + using ConvertibleTo< + MultiVector>>::move_to; using value_type = ValueType; using index_type = int32; @@ -78,10 +84,28 @@ class MultiVector static std::unique_ptr create_with_config_of( ptr_param other); + void convert_to(MultiVector>* result) + const override; + + void move_to( + MultiVector>* result) override; + +#if GINKGO_ENABLE_HALF + friend class MultiVector< + previous_precision_with_half>>; + using ConvertibleTo>>>::convert_to; + using ConvertibleTo>>>::move_to; + void convert_to( - MultiVector>* result) const override; + MultiVector< + next_precision_with_half>>* + result) const override; - void move_to(MultiVector>* result) override; + void move_to(MultiVector>>* result) override; +#endif /** * Creates a mutable view (of matrix::Dense type) of one item of the Batch diff --git a/include/ginkgo/core/log/logger.hpp b/include/ginkgo/core/log/logger.hpp index dd9d30249e9..b05b15fcc0c 100644 --- a/include/ginkgo/core/log/logger.hpp +++ b/include/ginkgo/core/log/logger.hpp @@ -18,6 +18,7 @@ namespace gko { +class half; /* Eliminate circular dependencies the hard way */ template @@ -579,6 +580,17 @@ public: \ const array& iters, const array& residual_norms) const {} + /** + * Batch solver's event that records the iteration count and the residual + * norm. + * + * @param iters the array of iteration counts. + * @param residual_norms the array storing the residual norms. + */ + virtual void on_batch_solver_completed( + const array& iters, const array& residual_norms) const + {} + public: #undef GKO_LOGGER_REGISTER_EVENT diff --git a/include/ginkgo/core/matrix/batch_csr.hpp b/include/ginkgo/core/matrix/batch_csr.hpp index e431454063d..49eb5e4d7cd 100644 --- a/include/ginkgo/core/matrix/batch_csr.hpp +++ b/include/ginkgo/core/matrix/batch_csr.hpp @@ -46,10 +46,16 @@ namespace matrix { template class Csr final : public EnableBatchLinOp>, - public ConvertibleTo, IndexType>> { +#if GINKGO_ENABLE_HALF + public ConvertibleTo< + Csr>, + IndexType>>, +#endif + public ConvertibleTo< + Csr, IndexType>> { friend class EnablePolymorphicObject; friend class Csr, IndexType>; - friend class Csr, IndexType>; + friend class Csr, IndexType>; static_assert(std::is_same::value, "IndexType must be a 32 bit integer"); @@ -63,10 +69,31 @@ class Csr final using absolute_type = remove_complex; using complex_type = to_complex; + void convert_to(Csr, IndexType>* result) + const override; + + void move_to( + Csr, IndexType>* result) override; + +#if GINKGO_ENABLE_HALF + friend class Csr< + previous_precision_with_half>, + IndexType>; + using ConvertibleTo< + Csr>, + IndexType>>::convert_to; + using ConvertibleTo< + Csr>, + IndexType>>::move_to; + void convert_to( - Csr, IndexType>* result) const override; + Csr>, + IndexType>* result) const override; - void move_to(Csr, IndexType>* result) override; + void move_to( + Csr>, + IndexType>* result) override; +#endif /** * Creates a mutable view (of matrix::Csr type) of one item of the diff --git a/include/ginkgo/core/matrix/batch_dense.hpp b/include/ginkgo/core/matrix/batch_dense.hpp index 5ea7c3ee128..c1340e482f4 100644 --- a/include/ginkgo/core/matrix/batch_dense.hpp +++ b/include/ginkgo/core/matrix/batch_dense.hpp @@ -45,11 +45,16 @@ namespace matrix { * @ingroup BatchLinOp */ template -class Dense final : public EnableBatchLinOp>, - public ConvertibleTo>> { +class Dense final + : public EnableBatchLinOp>, +#if GINKGO_ENABLE_HALF + public ConvertibleTo< + Dense>>>, +#endif + public ConvertibleTo>> { friend class EnablePolymorphicObject; friend class Dense>; - friend class Dense>; + friend class Dense>; public: using EnableBatchLinOp::convert_to; @@ -62,9 +67,27 @@ class Dense final : public EnableBatchLinOp>, using absolute_type = remove_complex; using complex_type = to_complex; - void convert_to(Dense>* result) const override; + void convert_to( + Dense>* result) const override; - void move_to(Dense>* result) override; + void move_to(Dense>* result) override; + +#if GINKGO_ENABLE_HALF + friend class Dense< + previous_precision_with_half>>; + using ConvertibleTo>>>::convert_to; + using ConvertibleTo>>>::move_to; + + void convert_to( + Dense>>* + result) const override; + + void move_to( + Dense>>* + result) override; +#endif /** * Creates a mutable view (of gko::matrix::Dense type) of one item of the diff --git a/include/ginkgo/core/matrix/batch_ell.hpp b/include/ginkgo/core/matrix/batch_ell.hpp index b760cee795a..872b8ce2db9 100644 --- a/include/ginkgo/core/matrix/batch_ell.hpp +++ b/include/ginkgo/core/matrix/batch_ell.hpp @@ -51,10 +51,16 @@ namespace matrix { template class Ell final : public EnableBatchLinOp>, - public ConvertibleTo, IndexType>> { +#if GINKGO_ENABLE_HALF + public ConvertibleTo< + Ell>, + IndexType>>, +#endif + public ConvertibleTo< + Ell, IndexType>> { friend class EnablePolymorphicObject; friend class Ell, IndexType>; - friend class Ell, IndexType>; + friend class Ell, IndexType>; static_assert(std::is_same::value, "IndexType must be a 32 bit integer"); @@ -68,10 +74,31 @@ class Ell final using absolute_type = remove_complex; using complex_type = to_complex; + void convert_to(Ell, IndexType>* result) + const override; + + void move_to( + Ell, IndexType>* result) override; + +#if GINKGO_ENABLE_HALF + friend class Ell< + previous_precision_with_half>, + IndexType>; + using ConvertibleTo< + Ell>, + IndexType>>::convert_to; + using ConvertibleTo< + Ell>, + IndexType>>::move_to; + void convert_to( - Ell, IndexType>* result) const override; + Ell>, + IndexType>* result) const override; - void move_to(Ell, IndexType>* result) override; + void move_to( + Ell>, + IndexType>* result) override; +#endif /** * Creates a mutable view (of matrix::Ell type) of one item of the diff --git a/omp/base/batch_multi_vector_kernels.cpp b/omp/base/batch_multi_vector_kernels.cpp index f740e3c32f0..17c0b81e1dc 100644 --- a/omp/base/batch_multi_vector_kernels.cpp +++ b/omp/base/batch_multi_vector_kernels.cpp @@ -37,7 +37,7 @@ void scale(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_MULTI_VECTOR_SCALE_KERNEL); @@ -59,7 +59,7 @@ void add_scaled(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_MULTI_VECTOR_ADD_SCALED_KERNEL); @@ -81,7 +81,7 @@ void compute_dot(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_DOT_KERNEL); @@ -103,7 +103,7 @@ void compute_conj_dot(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_CONJ_DOT_KERNEL); @@ -122,7 +122,7 @@ void compute_norm2(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_NORM2_KERNEL); @@ -141,7 +141,8 @@ void copy(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL); } // namespace batch_multi_vector diff --git a/omp/matrix/batch_csr_kernels.cpp b/omp/matrix/batch_csr_kernels.cpp index d4ea6cbd642..b55253e9d4e 100644 --- a/omp/matrix/batch_csr_kernels.cpp +++ b/omp/matrix/batch_csr_kernels.cpp @@ -41,7 +41,7 @@ void simple_apply(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_CSR_SIMPLE_APPLY_KERNEL); @@ -71,7 +71,7 @@ void advanced_apply(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_CSR_ADVANCED_APPLY_KERNEL); @@ -98,7 +98,7 @@ void scale(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_CSR_SCALE_KERNEL); @@ -122,7 +122,7 @@ void add_scaled_identity(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_CSR_ADD_SCALED_IDENTITY_KERNEL); diff --git a/omp/matrix/batch_dense_kernels.cpp b/omp/matrix/batch_dense_kernels.cpp index cd4a7f05b4a..ea7da295bb4 100644 --- a/omp/matrix/batch_dense_kernels.cpp +++ b/omp/matrix/batch_dense_kernels.cpp @@ -41,7 +41,7 @@ void simple_apply(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_DENSE_SIMPLE_APPLY_KERNEL); @@ -71,7 +71,7 @@ void advanced_apply(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_DENSE_ADVANCED_APPLY_KERNEL); @@ -98,7 +98,8 @@ void scale(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_SCALE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_DENSE_SCALE_KERNEL); template @@ -121,7 +122,8 @@ void scale_add(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_SCALE_ADD_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_DENSE_SCALE_ADD_KERNEL); template @@ -144,7 +146,7 @@ void add_scaled_identity(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_DENSE_ADD_SCALED_IDENTITY_KERNEL); diff --git a/omp/matrix/batch_ell_kernels.cpp b/omp/matrix/batch_ell_kernels.cpp index 8b1239565a1..74b8d94cfc8 100644 --- a/omp/matrix/batch_ell_kernels.cpp +++ b/omp/matrix/batch_ell_kernels.cpp @@ -41,7 +41,7 @@ void simple_apply(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_ELL_SIMPLE_APPLY_KERNEL); @@ -71,7 +71,7 @@ void advanced_apply(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL); @@ -98,7 +98,7 @@ void scale(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_ELL_SCALE_KERNEL); @@ -122,7 +122,7 @@ void add_scaled_identity(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_ELL_ADD_SCALED_IDENTITY_KERNEL); diff --git a/omp/preconditioner/batch_jacobi_kernels.cpp b/omp/preconditioner/batch_jacobi_kernels.cpp index 90c8f0c1865..aa081150f29 100644 --- a/omp/preconditioner/batch_jacobi_kernels.cpp +++ b/omp/preconditioner/batch_jacobi_kernels.cpp @@ -75,7 +75,7 @@ void extract_common_blocks_pattern( } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_BLOCK_JACOBI_EXTRACT_PATTERN_KERNEL); @@ -103,7 +103,7 @@ void compute_block_jacobi( } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_BLOCK_JACOBI_COMPUTE_KERNEL); diff --git a/omp/solver/batch_bicgstab_kernels.cpp b/omp/solver/batch_bicgstab_kernels.cpp index ed880507116..f1732f7d129 100644 --- a/omp/solver/batch_bicgstab_kernels.cpp +++ b/omp/solver/batch_bicgstab_kernels.cpp @@ -91,7 +91,8 @@ void apply(std::shared_ptr exec, dispatcher.apply(b, x, logdata); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); } // namespace batch_bicgstab diff --git a/omp/solver/batch_cg_kernels.cpp b/omp/solver/batch_cg_kernels.cpp index 89d4441db64..3d8a8fa8b3c 100644 --- a/omp/solver/batch_cg_kernels.cpp +++ b/omp/solver/batch_cg_kernels.cpp @@ -97,7 +97,8 @@ void apply(std::shared_ptr exec, dispatcher.apply(b, x, logdata); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_CG_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_CG_APPLY_KERNEL); } // namespace batch_cg diff --git a/reference/base/batch_multi_vector_kernels.cpp b/reference/base/batch_multi_vector_kernels.cpp index d7fbf3ce214..4f48a0b6f94 100644 --- a/reference/base/batch_multi_vector_kernels.cpp +++ b/reference/base/batch_multi_vector_kernels.cpp @@ -35,7 +35,7 @@ void scale(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_MULTI_VECTOR_SCALE_KERNEL); @@ -56,7 +56,7 @@ void add_scaled(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_MULTI_VECTOR_ADD_SCALED_KERNEL); @@ -77,7 +77,7 @@ void compute_dot(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_DOT_KERNEL); @@ -98,7 +98,7 @@ void compute_conj_dot(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_CONJ_DOT_KERNEL); @@ -116,7 +116,7 @@ void compute_norm2(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_NORM2_KERNEL); @@ -134,7 +134,8 @@ void copy(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL); } // namespace batch_multi_vector diff --git a/reference/matrix/batch_csr_kernels.cpp b/reference/matrix/batch_csr_kernels.cpp index d3304ab9795..c277d4f0738 100644 --- a/reference/matrix/batch_csr_kernels.cpp +++ b/reference/matrix/batch_csr_kernels.cpp @@ -39,7 +39,7 @@ void simple_apply(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_CSR_SIMPLE_APPLY_KERNEL); @@ -68,7 +68,7 @@ void advanced_apply(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_CSR_ADVANCED_APPLY_KERNEL); @@ -94,7 +94,7 @@ void scale(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_CSR_SCALE_KERNEL); @@ -117,7 +117,7 @@ void add_scaled_identity(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_CSR_ADD_SCALED_IDENTITY_KERNEL); diff --git a/reference/matrix/batch_dense_kernels.cpp b/reference/matrix/batch_dense_kernels.cpp index 599af30ecfb..9c92fb54056 100644 --- a/reference/matrix/batch_dense_kernels.cpp +++ b/reference/matrix/batch_dense_kernels.cpp @@ -39,7 +39,7 @@ void simple_apply(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_DENSE_SIMPLE_APPLY_KERNEL); @@ -68,7 +68,7 @@ void advanced_apply(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_DENSE_ADVANCED_APPLY_KERNEL); @@ -94,7 +94,8 @@ void scale(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_SCALE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_DENSE_SCALE_KERNEL); template @@ -116,7 +117,8 @@ void scale_add(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_SCALE_ADD_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_DENSE_SCALE_ADD_KERNEL); template @@ -138,7 +140,7 @@ void add_scaled_identity(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_BATCH_DENSE_ADD_SCALED_IDENTITY_KERNEL); diff --git a/reference/matrix/batch_ell_kernels.cpp b/reference/matrix/batch_ell_kernels.cpp index 1a4855f389f..bc0eb61e30d 100644 --- a/reference/matrix/batch_ell_kernels.cpp +++ b/reference/matrix/batch_ell_kernels.cpp @@ -39,7 +39,7 @@ void simple_apply(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_ELL_SIMPLE_APPLY_KERNEL); @@ -68,7 +68,7 @@ void advanced_apply(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL); @@ -94,7 +94,7 @@ void scale(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_ELL_SCALE_KERNEL); @@ -117,7 +117,7 @@ void add_scaled_identity(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_ELL_ADD_SCALED_IDENTITY_KERNEL); diff --git a/reference/preconditioner/batch_jacobi_kernels.cpp b/reference/preconditioner/batch_jacobi_kernels.cpp index a012e019b41..7168ae9012e 100644 --- a/reference/preconditioner/batch_jacobi_kernels.cpp +++ b/reference/preconditioner/batch_jacobi_kernels.cpp @@ -71,7 +71,7 @@ void extract_common_blocks_pattern( } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_BLOCK_JACOBI_EXTRACT_PATTERN_KERNEL); @@ -97,7 +97,7 @@ void compute_block_jacobi( } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF( GKO_DECLARE_BATCH_BLOCK_JACOBI_COMPUTE_KERNEL); diff --git a/reference/solver/batch_bicgstab_kernels.cpp b/reference/solver/batch_bicgstab_kernels.cpp index 20883e24434..2a8458c105b 100644 --- a/reference/solver/batch_bicgstab_kernels.cpp +++ b/reference/solver/batch_bicgstab_kernels.cpp @@ -87,7 +87,8 @@ void apply(std::shared_ptr exec, dispatcher.apply(b, x, log_data); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_BICGSTAB_APPLY_KERNEL); } // namespace batch_bicgstab diff --git a/reference/solver/batch_cg_kernels.cpp b/reference/solver/batch_cg_kernels.cpp index f2155f98719..3321a5b825b 100644 --- a/reference/solver/batch_cg_kernels.cpp +++ b/reference/solver/batch_cg_kernels.cpp @@ -87,7 +87,8 @@ void apply(std::shared_ptr exec, dispatcher.apply(b, x, log_data); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_CG_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_BATCH_CG_APPLY_KERNEL); } // namespace batch_cg