Skip to content

Commit

Permalink
fix ell accessor type
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Oct 4, 2023
1 parent 51ab0b0 commit b531ba3
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 31 deletions.
24 changes: 14 additions & 10 deletions common/cuda_hip/matrix/ell_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@ __device__ void spmv_kernel(
acc::range<b_accessor> b, OutputValueType* __restrict__ c,
const size_type c_stride, Closure op)
{
using arithmetic_type = typename a_accessor::arithmetic_type;
const auto tidx = thread::get_thread_id_flat();
const decltype(tidx) column_id = blockIdx.y;
if (num_thread_per_worker == 1) {
// Specialize the num_thread_per_worker = 1. It doesn't need the shared
// memory, __syncthreads, and atomic_add
if (tidx < num_rows) {
auto temp = zero<OutputValueType>();
auto temp = zero<arithmetic_type>();
for (size_type idx = 0; idx < num_stored_elements_per_row; idx++) {
const auto ind = tidx + idx * stride;
const auto col_idx = col[ind];
Expand All @@ -69,13 +70,13 @@ __device__ void spmv_kernel(
const auto worker_id = tidx / num_rows;
const auto step_size = num_worker_per_row * num_thread_per_worker;
__shared__ uninitialized_array<
OutputValueType, default_block_size / num_thread_per_worker>
arithmetic_type, default_block_size / num_thread_per_worker>
storage;
if (idx_in_worker == 0) {
storage[threadIdx.x] = gko::zero<OutputValueType>();
}
__syncthreads();
auto temp = zero<OutputValueType>();
auto temp = zero<arithmetic_type>();
for (size_type idx =
worker_id * num_thread_per_worker + idx_in_worker;
idx < num_stored_elements_per_row; idx += step_size) {
Expand Down Expand Up @@ -114,7 +115,9 @@ __global__ __launch_bounds__(default_block_size) void spmv(
spmv_kernel<num_thread_per_worker, atomic>(
num_rows, num_worker_per_row, val, col, stride,
num_stored_elements_per_row, b, c, c_stride,
[](const OutputValueType& x, const OutputValueType& y) { return x; });
[](const auto& x, const OutputValueType& y) {
return static_cast<OutputValueType>(x);
});
}


Expand All @@ -128,7 +131,8 @@ __global__ __launch_bounds__(default_block_size) void spmv(
const OutputValueType* __restrict__ beta, OutputValueType* __restrict__ c,
const size_type c_stride)
{
const OutputValueType alpha_val = alpha(0);
using arithmetic_type = typename a_accessor::arithmetic_type;
const auto alpha_val = alpha(0);
const OutputValueType beta_val = beta[0];
if (atomic) {
// Because the atomic operation changes the values of c during
Expand All @@ -139,16 +143,16 @@ __global__ __launch_bounds__(default_block_size) void spmv(
spmv_kernel<num_thread_per_worker, atomic>(
num_rows, num_worker_per_row, val, col, stride,
num_stored_elements_per_row, b, c, c_stride,
[&alpha_val](const OutputValueType& x, const OutputValueType& y) {
return alpha_val * x;
[&alpha_val](const auto& x, const OutputValueType& y) {
return static_cast<OutputValueType>(alpha_val * x);
});
} else {
spmv_kernel<num_thread_per_worker, atomic>(
num_rows, num_worker_per_row, val, col, stride,
num_stored_elements_per_row, b, c, c_stride,
[&alpha_val, &beta_val](const OutputValueType& x,
const OutputValueType& y) {
return alpha_val * x + beta_val * y;
[&alpha_val, &beta_val](const auto& x, const OutputValueType& y) {
return static_cast<OutputValueType>(
alpha_val * x + static_cast<arithmetic_type>(beta_val * y));
});
}
}
Expand Down
6 changes: 4 additions & 2 deletions cuda/matrix/ell_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,12 @@ void abstract_spmv(syn::value_list<int, info>,
const matrix::Dense<MatrixValueType>* alpha = nullptr,
const matrix::Dense<OutputValueType>* beta = nullptr)
{
using arithmetic_type =
highest_precision<InputValueType, OutputValueType, MatrixValueType>;
using a_accessor =
gko::acc::reduced_row_major<1, OutputValueType, const MatrixValueType>;
gko::acc::reduced_row_major<1, arithmetic_type, const MatrixValueType>;
using b_accessor =
gko::acc::reduced_row_major<2, OutputValueType, const InputValueType>;
gko::acc::reduced_row_major<2, arithmetic_type, const InputValueType>;

const auto nrows = a->get_size()[0];
const auto stride = a->get_stride();
Expand Down
40 changes: 23 additions & 17 deletions dpcpp/matrix/ell_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,17 @@ void spmv_kernel(
const size_type stride, const size_type num_stored_elements_per_row,
acc::range<b_accessor> b, OutputValueType* __restrict__ c,
const size_type c_stride, Closure op, sycl::nd_item<3> item_ct1,
uninitialized_array<OutputValueType,
uninitialized_array<typename a_accessor::arithmetic_type,
default_block_size / num_thread_per_worker>& storage)
{
using arithmetic_type = typename a_accessor::arithmetic_type;
const auto tidx = thread::get_thread_id_flat(item_ct1);
const decltype(tidx) column_id = item_ct1.get_group(1);
if (num_thread_per_worker == 1) {
// Specialize the num_thread_per_worker = 1. It doesn't need the shared
// memory, __syncthreads, and atomic_add
if (tidx < num_rows) {
auto temp = zero<OutputValueType>();
auto temp = zero<arithmetic_type>();
for (size_type idx = 0; idx < num_stored_elements_per_row; idx++) {
const auto ind = tidx + idx * stride;
const auto col_idx = col[ind];
Expand All @@ -150,11 +151,11 @@ void spmv_kernel(
const auto step_size = num_worker_per_row * num_thread_per_worker;

if (runnable && idx_in_worker == 0) {
storage[item_ct1.get_local_id(2)] = 0;
storage[item_ct1.get_local_id(2)] = zero<arithmetic_type>();
}

item_ct1.barrier(sycl::access::fence_space::local_space);
auto temp = zero<OutputValueType>();
auto temp = zero<arithmetic_type>();
if (runnable) {
for (size_type idx =
worker_id * num_thread_per_worker + idx_in_worker;
Expand Down Expand Up @@ -193,13 +194,15 @@ void spmv(
const size_type stride, const size_type num_stored_elements_per_row,
acc::range<b_accessor> b, OutputValueType* __restrict__ c,
const size_type c_stride, sycl::nd_item<3> item_ct1,
uninitialized_array<OutputValueType,
uninitialized_array<typename a_accessor::arithmetic_type,
default_block_size / num_thread_per_worker>& storage)
{
spmv_kernel<num_thread_per_worker, atomic>(
num_rows, num_worker_per_row, val, col, stride,
num_stored_elements_per_row, b, c, c_stride,
[](const OutputValueType& x, const OutputValueType& y) { return x; },
[](const auto& x, const OutputValueType& y) {
return static_cast<OutputValueType>(x);
},
item_ct1, storage);
}

Expand All @@ -214,7 +217,7 @@ void spmv(dim3 grid, dim3 block, size_type dynamic_shared_memory,
{
queue->submit([&](sycl::handler& cgh) {
sycl::accessor<
uninitialized_array<OutputValueType,
uninitialized_array<typename a_accessor::arithmetic_type,
default_block_size / num_thread_per_worker>,
0, sycl::access_mode::read_write, sycl::access::target::local>
storage_acc_ct1(cgh);
Expand All @@ -239,10 +242,11 @@ void spmv(
const size_type num_stored_elements_per_row, acc::range<b_accessor> b,
const OutputValueType* __restrict__ beta, OutputValueType* __restrict__ c,
const size_type c_stride, sycl::nd_item<3> item_ct1,
uninitialized_array<OutputValueType,
uninitialized_array<typename a_accessor::arithmetic_type,
default_block_size / num_thread_per_worker>& storage)
{
const OutputValueType alpha_val = alpha(0);
using arithmetic_type = typename a_accessor::arithmetic_type;
const auto alpha_val = alpha(0);
const OutputValueType beta_val = beta[0];
if (atomic) {
// Because the atomic operation changes the values of c during
Expand All @@ -253,17 +257,17 @@ void spmv(
spmv_kernel<num_thread_per_worker, atomic>(
num_rows, num_worker_per_row, val, col, stride,
num_stored_elements_per_row, b, c, c_stride,
[&alpha_val](const OutputValueType& x, const OutputValueType& y) {
return alpha_val * x;
[&alpha_val](const auto& x, const OutputValueType& y) {
return static_cast<OutputValueType>(alpha_val * x);
},
item_ct1, storage);
} else {
spmv_kernel<num_thread_per_worker, atomic>(
num_rows, num_worker_per_row, val, col, stride,
num_stored_elements_per_row, b, c, c_stride,
[&alpha_val, &beta_val](const OutputValueType& x,
const OutputValueType& y) {
return alpha_val * x + beta_val * y;
[&alpha_val, &beta_val](const auto& x, const OutputValueType& y) {
return static_cast<OutputValueType>(
alpha_val * x + static_cast<arithmetic_type>(beta_val * y));
},
item_ct1, storage);
}
Expand All @@ -281,7 +285,7 @@ void spmv(dim3 grid, dim3 block, size_type dynamic_shared_memory,
{
queue->submit([&](sycl::handler& cgh) {
sycl::accessor<
uninitialized_array<OutputValueType,
uninitialized_array<typename a_accessor::arithmetic_type,
default_block_size / num_thread_per_worker>,
0, sycl::access_mode::read_write, sycl::access::target::local>
storage_acc_ct1(cgh);
Expand Down Expand Up @@ -316,10 +320,12 @@ void abstract_spmv(syn::value_list<int, info>,
const matrix::Dense<MatrixValueType>* alpha = nullptr,
const matrix::Dense<OutputValueType>* beta = nullptr)
{
using arithmetic_type =
highest_precision<InputValueType, OutputValueType, MatrixValueType>;
using a_accessor =
gko::acc::reduced_row_major<1, OutputValueType, const MatrixValueType>;
gko::acc::reduced_row_major<1, arithmetic_type, const MatrixValueType>;
using b_accessor =
gko::acc::reduced_row_major<2, OutputValueType, const InputValueType>;
gko::acc::reduced_row_major<2, arithmetic_type, const InputValueType>;

const auto nrows = a->get_size()[0];
const auto stride = a->get_stride();
Expand Down
6 changes: 4 additions & 2 deletions hip/matrix/ell_kernels.hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,12 @@ void abstract_spmv(syn::value_list<int, info>,
const matrix::Dense<MatrixValueType>* alpha = nullptr,
const matrix::Dense<OutputValueType>* beta = nullptr)
{
using arithmetic_type =
highest_precision<InputValueType, OutputValueType, MatrixValueType>;
using a_accessor =
acc::reduced_row_major<1, OutputValueType, const MatrixValueType>;
acc::reduced_row_major<1, arithmetic_type, const MatrixValueType>;
using b_accessor =
acc::reduced_row_major<2, OutputValueType, const InputValueType>;
acc::reduced_row_major<2, arithmetic_type, const InputValueType>;

const auto nrows = a->get_size()[0];
const auto stride = a->get_stride();
Expand Down

0 comments on commit b531ba3

Please sign in to comment.