Skip to content

Commit

Permalink
Merge Fix batch ell const apply infinite loop
Browse files Browse the repository at this point in the history
This PR fixes batch ell const apply infinite loop

Related PR: #1437
  • Loading branch information
yhmtsai authored Oct 21, 2023
2 parents 612a732 + 59f099d commit 74d0744
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 13 deletions.
26 changes: 13 additions & 13 deletions core/matrix/batch_ell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,7 @@ Ell<ValueType, IndexType>* Ell<ValueType, IndexType>::apply(
ptr_param<const MultiVector<ValueType>> b,
ptr_param<MultiVector<ValueType>> x)
{
this->validate_application_parameters(b.get(), x.get());
auto exec = this->get_executor();
this->apply_impl(make_temporary_clone(exec, b).get(),
make_temporary_clone(exec, x).get());
static_cast<const Ell*>(this)->apply(b, x);
return this;
}

Expand All @@ -147,7 +144,10 @@ const Ell<ValueType, IndexType>* Ell<ValueType, IndexType>::apply(
ptr_param<const MultiVector<ValueType>> b,
ptr_param<MultiVector<ValueType>> x) const
{
this->apply(b, x);
this->validate_application_parameters(b.get(), x.get());
auto exec = this->get_executor();
this->apply_impl(make_temporary_clone(exec, b).get(),
make_temporary_clone(exec, x).get());
return this;
}

Expand All @@ -159,13 +159,7 @@ Ell<ValueType, IndexType>* Ell<ValueType, IndexType>::apply(
ptr_param<const MultiVector<ValueType>> beta,
ptr_param<MultiVector<ValueType>> x)
{
this->validate_application_parameters(alpha.get(), b.get(), beta.get(),
x.get());
auto exec = this->get_executor();
this->apply_impl(make_temporary_clone(exec, alpha).get(),
make_temporary_clone(exec, b).get(),
make_temporary_clone(exec, beta).get(),
make_temporary_clone(exec, x).get());
static_cast<const Ell*>(this)->apply(alpha, b, beta, x);
return this;
}

Expand All @@ -177,7 +171,13 @@ const Ell<ValueType, IndexType>* Ell<ValueType, IndexType>::apply(
ptr_param<const MultiVector<ValueType>> beta,
ptr_param<MultiVector<ValueType>> x) const
{
this->apply(alpha, b, beta, x);
this->validate_application_parameters(alpha.get(), b.get(), beta.get(),
x.get());
auto exec = this->get_executor();
this->apply_impl(make_temporary_clone(exec, alpha).get(),
make_temporary_clone(exec, b).get(),
make_temporary_clone(exec, beta).get(),
make_temporary_clone(exec, x).get());
return this;
}

Expand Down
41 changes: 41 additions & 0 deletions reference/test/matrix/batch_ell_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,21 @@ TYPED_TEST(Ell, AppliesToBatchMultiVector)
}


TYPED_TEST(Ell, ConstAppliesToBatchMultiVector)
{
using T = typename TestFixture::value_type;
using BMtx = typename TestFixture::BMtx;

static_cast<const BMtx*>(this->mtx_0.get())->apply(this->b_0, this->x_0);

this->mtx_00->apply(this->b_00.get(), this->x_00.get());
this->mtx_01->apply(this->b_01.get(), this->x_01.get());
auto res = gko::batch::unbatch<gko::batch::MultiVector<T>>(this->x_0.get());
GKO_ASSERT_MTX_NEAR(res[0].get(), this->x_00.get(), r<T>::value);
GKO_ASSERT_MTX_NEAR(res[1].get(), this->x_01.get(), r<T>::value);
}


TYPED_TEST(Ell, AppliesLinearCombinationToBatchMultiVector)
{
using BMtx = typename TestFixture::BMtx;
Expand All @@ -154,6 +169,32 @@ TYPED_TEST(Ell, AppliesLinearCombinationToBatchMultiVector)
}


TYPED_TEST(Ell, ConstAppliesLinearCombinationToBatchMultiVector)
{
using BMtx = typename TestFixture::BMtx;
using BMVec = typename TestFixture::BMVec;
using DenseMtx = typename TestFixture::DenseMtx;
using T = typename TestFixture::value_type;
auto alpha = gko::batch::initialize<BMVec>({{1.5}, {-1.0}}, this->exec);
auto beta = gko::batch::initialize<BMVec>({{2.5}, {-4.0}}, this->exec);
auto alpha0 = gko::initialize<DenseMtx>({1.5}, this->exec);
auto alpha1 = gko::initialize<DenseMtx>({-1.0}, this->exec);
auto beta0 = gko::initialize<DenseMtx>({2.5}, this->exec);
auto beta1 = gko::initialize<DenseMtx>({-4.0}, this->exec);

static_cast<const BMtx*>(this->mtx_0.get())
->apply(alpha.get(), this->b_0.get(), beta.get(), this->x_0.get());

this->mtx_00->apply(alpha0.get(), this->b_00.get(), beta0.get(),
this->x_00.get());
this->mtx_01->apply(alpha1.get(), this->b_01.get(), beta1.get(),
this->x_01.get());
auto res = gko::batch::unbatch<gko::batch::MultiVector<T>>(this->x_0.get());
GKO_ASSERT_MTX_NEAR(res[0].get(), this->x_00.get(), r<T>::value);
GKO_ASSERT_MTX_NEAR(res[1].get(), this->x_01.get(), r<T>::value);
}


TYPED_TEST(Ell, ApplyFailsOnWrongNumberOfResultCols)
{
using BMVec = typename TestFixture::BMVec;
Expand Down

0 comments on commit 74d0744

Please sign in to comment.