Skip to content

Commit

Permalink
Merge Fixing parameter ordering/naming for reference (omp) batch CG
Browse files Browse the repository at this point in the history
This merge fixes the parameter ordering of the `initialize` function for the reference+omp implementation of the batch CG. Additionally, a parameter name in `update_x_and_r` is changed to better reflect the actually used variable.

Related PR: #1701
  • Loading branch information
MarcelKoch authored Oct 25, 2024
2 parents 9e7bfa7 + 30c2d86 commit 568a759
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions reference/solver/batch_cg_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ inline void initialize(
const BatchMatrixType_entry& A_entry,
const gko::batch::multi_vector::batch_item<const ValueType>& b_entry,
const gko::batch::multi_vector::batch_item<const ValueType>& x_entry,
const gko::batch::multi_vector::batch_item<ValueType>& rho_new_entry,
const gko::batch::multi_vector::batch_item<ValueType>& rho_old_entry,
const gko::batch::multi_vector::batch_item<ValueType>& rho_new_entry,
const gko::batch::multi_vector::batch_item<ValueType>& r_entry,
const gko::batch::multi_vector::batch_item<ValueType>& p_entry,
const gko::batch::multi_vector::batch_item<ValueType>& z_entry,
Expand Down Expand Up @@ -86,7 +86,7 @@ inline void update_p(

template <typename ValueType>
inline void update_x_and_r(
const gko::batch::multi_vector::batch_item<const ValueType>& rho_old_entry,
const gko::batch::multi_vector::batch_item<const ValueType>& rho_new_entry,
const gko::batch::multi_vector::batch_item<const ValueType>& p_entry,
const gko::batch::multi_vector::batch_item<const ValueType>& Ap_entry,
const gko::batch::multi_vector::batch_item<ValueType>& alpha_entry,
Expand All @@ -96,7 +96,7 @@ inline void update_x_and_r(
batch_single_kernels::compute_conj_dot_product_kernel<ValueType>(
p_entry, Ap_entry, alpha_entry);

const ValueType temp = rho_old_entry.values[0] / alpha_entry.values[0];
const ValueType temp = rho_new_entry.values[0] / alpha_entry.values[0];
for (int row = 0; row < r_entry.num_rows; row++) {
x_entry.values[row * x_entry.stride] +=
temp * p_entry.values[row * p_entry.stride];
Expand Down

0 comments on commit 568a759

Please sign in to comment.