From 59a1f128ab58aab720d0a8d3470fc86eedc57d0d Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Wed, 18 Dec 2024 08:46:26 +0100 Subject: [PATCH] only go through dense mkl::gemm when all data pointers are valid. --- dpcpp/matrix/dense_kernels.dp.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dpcpp/matrix/dense_kernels.dp.cpp b/dpcpp/matrix/dense_kernels.dp.cpp index f4654e7ca06..4f314ec92c8 100644 --- a/dpcpp/matrix/dense_kernels.dp.cpp +++ b/dpcpp/matrix/dense_kernels.dp.cpp @@ -223,7 +223,8 @@ void simple_apply(std::shared_ptr exec, using namespace oneapi::mkl; if constexpr (onemkl::is_supported::value) { if (b->get_stride() != 0 && c->get_stride() != 0) { - if (a->get_size()[1] > 0) { + if (a->get_const_values() && b->get_const_values() && + c->get_const_values()) { oneapi::mkl::blas::row_major::gemm( *exec->get_queue(), transpose::nontrans, transpose::nontrans, c->get_size()[0], c->get_size()[1], @@ -253,7 +254,8 @@ void apply(std::shared_ptr exec, using namespace oneapi::mkl; if constexpr (onemkl::is_supported::value) { if (b->get_stride() != 0 && c->get_stride() != 0) { - if (a->get_size()[1] > 0) { + if (a->get_const_values() && b->get_const_values() && + c->get_const_values()) { oneapi::mkl::blas::row_major::gemm( *exec->get_queue(), transpose::nontrans, transpose::nontrans, c->get_size()[0], c->get_size()[1],