diff --git a/reference/solver/batch_cg_kernels.hpp b/reference/solver/batch_cg_kernels.hpp index 2f8e5990931..6208a048972 100644 --- a/reference/solver/batch_cg_kernels.hpp +++ b/reference/solver/batch_cg_kernels.hpp @@ -32,8 +32,8 @@ inline void initialize( const BatchMatrixType_entry& A_entry, const gko::batch::multi_vector::batch_item& b_entry, const gko::batch::multi_vector::batch_item& x_entry, - const gko::batch::multi_vector::batch_item& rho_new_entry, const gko::batch::multi_vector::batch_item& rho_old_entry, + const gko::batch::multi_vector::batch_item& rho_new_entry, const gko::batch::multi_vector::batch_item& r_entry, const gko::batch::multi_vector::batch_item& p_entry, const gko::batch::multi_vector::batch_item& z_entry, @@ -86,7 +86,7 @@ inline void update_p( template inline void update_x_and_r( - const gko::batch::multi_vector::batch_item& rho_old_entry, + const gko::batch::multi_vector::batch_item& rho_new_entry, const gko::batch::multi_vector::batch_item& p_entry, const gko::batch::multi_vector::batch_item& Ap_entry, const gko::batch::multi_vector::batch_item& alpha_entry, @@ -96,7 +96,7 @@ inline void update_x_and_r( batch_single_kernels::compute_conj_dot_product_kernel( p_entry, Ap_entry, alpha_entry); - const ValueType temp = rho_old_entry.values[0] / alpha_entry.values[0]; + const ValueType temp = rho_new_entry.values[0] / alpha_entry.values[0]; for (int row = 0; row < r_entry.num_rows; row++) { x_entry.values[row * x_entry.stride] += temp * p_entry.values[row * p_entry.stride];