Skip to content

Commit

Permalink
matrix with half
Browse files: browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Oct 24, 2024
1 parent c503078 commit 013584d
Show file tree
Hide file tree
Showing 78 changed files with 1,511 additions and 862 deletions.
10 changes: 6 additions & 4 deletions common/cuda_hip/matrix/coo_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ void spmv(std::shared_ptr<const DefaultExecutor> exec,
spmv2(exec, a, b, c);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_COO_SPMV_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_COO_SPMV_KERNEL);


template <typename ValueType, typename IndexType>
Expand All @@ -253,7 +254,7 @@ void advanced_spmv(std::shared_ptr<const DefaultExecutor> exec,
advanced_spmv2(exec, alpha, a, b, c);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_COO_ADVANCED_SPMV_KERNEL);


Expand Down Expand Up @@ -295,7 +296,8 @@ void spmv2(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_COO_SPMV2_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_COO_SPMV2_KERNEL);


template <typename ValueType, typename IndexType>
Expand Down Expand Up @@ -338,7 +340,7 @@ void advanced_spmv2(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_COO_ADVANCED_SPMV2_KERNEL);


Expand Down
124 changes: 76 additions & 48 deletions common/cuda_hip/matrix/csr_kernels.instantiate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,108 +17,136 @@ namespace csr {


// begin
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_CONVERT_TO_FBCSR_KERNEL);


// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT1(GKO_DECLARE_CSR_SPMV_KERNEL,
int32);
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT1_WITH_HALF(
GKO_DECLARE_CSR_SPMV_KERNEL, int32);
// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT2(GKO_DECLARE_CSR_SPMV_KERNEL,
int32);
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT2_WITH_HALF(
GKO_DECLARE_CSR_SPMV_KERNEL, int32);
// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3(GKO_DECLARE_CSR_SPMV_KERNEL,
int32);
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3_WITH_HALF(
GKO_DECLARE_CSR_SPMV_KERNEL, int32);
// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4(GKO_DECLARE_CSR_SPMV_KERNEL,
int32);
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4_WITH_HALF(
GKO_DECLARE_CSR_SPMV_KERNEL, int32);
// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT1(GKO_DECLARE_CSR_SPMV_KERNEL,
int64);
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT5_WITH_HALF(
GKO_DECLARE_CSR_SPMV_KERNEL, int32);
// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT2(GKO_DECLARE_CSR_SPMV_KERNEL,
int64);
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT6_WITH_HALF(
GKO_DECLARE_CSR_SPMV_KERNEL, int32);
// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3(GKO_DECLARE_CSR_SPMV_KERNEL,
int64);
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT1_WITH_HALF(
GKO_DECLARE_CSR_SPMV_KERNEL, int64);
// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4(GKO_DECLARE_CSR_SPMV_KERNEL,
int64);
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT2_WITH_HALF(
GKO_DECLARE_CSR_SPMV_KERNEL, int64);
// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3_WITH_HALF(
GKO_DECLARE_CSR_SPMV_KERNEL, int64);
// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4_WITH_HALF(
GKO_DECLARE_CSR_SPMV_KERNEL, int64);
// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT5_WITH_HALF(
GKO_DECLARE_CSR_SPMV_KERNEL, int64);
// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT6_WITH_HALF(
GKO_DECLARE_CSR_SPMV_KERNEL, int64);


// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT1(
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT1_WITH_HALF(
GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int32);
// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT2_WITH_HALF(
GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int32);
// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT2(
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3_WITH_HALF(
GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int32);
// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3(
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4_WITH_HALF(
GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int32);
// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4(
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT5_WITH_HALF(
GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int32);
// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT1(
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT6_WITH_HALF(
GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int32);
// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT1_WITH_HALF(
GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int64);
// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT2_WITH_HALF(
GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int64);
// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3_WITH_HALF(
GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int64);
// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT2(
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4_WITH_HALF(
GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int64);
// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3(
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT5_WITH_HALF(
GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int64);
// split
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4(
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT6_WITH_HALF(
GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int64);


// split
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_TRANSPOSE_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_TRANSPOSE_KERNEL);
// split
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_CONJ_TRANSPOSE_KERNEL);
// split
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_SORT_BY_COLUMN_INDEX);
// split
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_SPGEMM_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_SPGEMM_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_ADVANCED_SPGEMM_KERNEL);
GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_CSR_BUILD_LOOKUP_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_SPGEAM_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_SPGEAM_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_FILL_IN_DENSE_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_INV_NONSYMM_PERMUTE_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_INV_SYMM_PERMUTE_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_ROW_PERMUTE_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_INV_ROW_PERMUTE_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_INV_NONSYMM_SCALE_PERMUTE_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_INV_SYMM_SCALE_PERMUTE_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_ROW_SCALE_PERMUTE_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_INV_ROW_SCALE_PERMUTE_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_CALC_NNZ_PER_ROW_IN_SPAN_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_COMPUTE_SUB_MATRIX_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_CALC_NNZ_PER_ROW_IN_INDEX_SET_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_COMPUTE_SUB_MATRIX_FROM_INDEX_SET_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_IS_SORTED_BY_COLUMN_INDEX);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_EXTRACT_DIAGONAL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_EXTRACT_DIAGONAL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_CHECK_DIAGONAL_ENTRIES_EXIST);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_CSR_ADD_SCALED_IDENTITY_KERNEL);
// end

Expand Down
6 changes: 3 additions & 3 deletions common/cuda_hip/matrix/csr_kernels.template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ __global__ __launch_bounds__(spmv_block_size) void abstract_spmv(
{
using arithmetic_type = typename output_accessor::arithmetic_type;
using output_type = typename output_accessor::storage_type;
const arithmetic_type scale_factor = alpha[0];
const auto scale_factor = static_cast<arithmetic_type>(alpha[0]);
spmv_kernel(nwarps, num_rows, val, col_idxs, row_ptrs, srow, b, c,
[&scale_factor](const arithmetic_type& x) {
return static_cast<output_type>(scale_factor * x);
Expand Down Expand Up @@ -486,7 +486,7 @@ __global__ __launch_bounds__(spmv_block_size) void abstract_reduce(
const IndexType* __restrict__ last_row,
const MatrixValueType* __restrict__ alpha, acc::range<output_accessor> c)
{
const arithmetic_type alpha_val = alpha[0];
const auto alpha_val = static_cast<arithmetic_type>(alpha[0]);
merge_path_reduce(
nwarps, last_val, last_row, c,
[&alpha_val](const arithmetic_type& x) { return alpha_val * x; });
Expand Down Expand Up @@ -1193,7 +1193,7 @@ __global__ __launch_bounds__(default_block_size) void build_csr_lookup(
const auto i = base_i + lane;
const auto col = i < row_len
? local_cols[i]
: device_numeric_limits<IndexType>::max;
: device_numeric_limits<IndexType>::max();
const auto rel_col = static_cast<int32>(col - min_col);
const auto block = rel_col / bitmap_block_size;
const auto col_in_block = rel_col % bitmap_block_size;
Expand Down
33 changes: 18 additions & 15 deletions common/cuda_hip/matrix/dense_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ void convert_to_coo(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_DENSE_CONVERT_TO_COO_KERNEL);


Expand Down Expand Up @@ -491,7 +491,7 @@ void convert_to_csr(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_DENSE_CONVERT_TO_CSR_KERNEL);


Expand Down Expand Up @@ -521,7 +521,7 @@ void convert_to_ell(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_DENSE_CONVERT_TO_ELL_KERNEL);


Expand All @@ -544,7 +544,7 @@ void convert_to_fbcsr(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_DENSE_CONVERT_TO_FBCSR_KERNEL);


Expand All @@ -565,7 +565,7 @@ void count_nonzero_blocks_per_row(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_DENSE_COUNT_NONZERO_BLOCKS_PER_ROW_KERNEL);


Expand Down Expand Up @@ -598,7 +598,7 @@ void convert_to_hybrid(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_DENSE_CONVERT_TO_HYBRID_KERNEL);


Expand Down Expand Up @@ -629,7 +629,7 @@ void convert_to_sellp(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_DENSE_CONVERT_TO_SELLP_KERNEL);


Expand Down Expand Up @@ -657,7 +657,7 @@ void convert_to_sparsity_csr(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_DENSE_CONVERT_TO_SPARSITY_CSR_KERNEL);


Expand All @@ -681,7 +681,7 @@ void compute_dot_dispatch(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_DENSE_COMPUTE_DOT_DISPATCH_KERNEL);


Expand All @@ -706,7 +706,7 @@ void compute_conj_dot_dispatch(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_DENSE_COMPUTE_CONJ_DOT_DISPATCH_KERNEL);


Expand All @@ -729,7 +729,7 @@ void compute_norm2_dispatch(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_DENSE_COMPUTE_NORM2_DISPATCH_KERNEL);


Expand Down Expand Up @@ -760,7 +760,8 @@ void simple_apply(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_SIMPLE_APPLY_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_DENSE_SIMPLE_APPLY_KERNEL);


template <typename ValueType>
Expand All @@ -787,7 +788,7 @@ void apply(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_APPLY_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_APPLY_KERNEL);


template <typename ValueType>
Expand All @@ -812,7 +813,8 @@ void transpose(std::shared_ptr<const DefaultExecutor> exec,
}
};

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_TRANSPOSE_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_DENSE_TRANSPOSE_KERNEL);


template <typename ValueType>
Expand All @@ -837,7 +839,8 @@ void conj_transpose(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_CONJ_TRANSPOSE_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_DENSE_CONJ_TRANSPOSE_KERNEL);


} // namespace dense
Expand Down
2 changes: 1 addition & 1 deletion common/cuda_hip/matrix/diagonal_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ void apply_to_csr(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(
GKO_DECLARE_DIAGONAL_APPLY_TO_CSR_KERNEL);


Expand Down
Loading

0 comments on commit 013584d

Please sign in to comment.