diff --git a/common/unified/components/precision_conversion_kernels.cpp b/common/unified/components/precision_conversion_kernels.cpp index 310189c64c0..df1cd9fa062 100644 --- a/common/unified/components/precision_conversion_kernels.cpp +++ b/common/unified/components/precision_conversion_kernels.cpp @@ -52,11 +52,12 @@ void convert_precision(std::shared_ptr exec, run_kernel( exec, [] GKO_KERNEL(auto idx, auto in, auto out) { - using in_type = typename std::remove_cv< - typename std::remove_reference::type>::type; - using out_type = typename std::remove_cv< - typename std::remove_reference::type>::type; - out[idx] = static_cast(in[idx]); + using target_type = device_type; + using arithmetic_type = + highest_precision>; + // use float as the bridge between bfloat16 and half on device + out[idx] = + static_cast(static_cast(in[idx])); }, size, in, out); } diff --git a/common/unified/matrix/dense_kernels.template.cpp b/common/unified/matrix/dense_kernels.template.cpp index d6cda937fdf..81d7543d79c 100644 --- a/common/unified/matrix/dense_kernels.template.cpp +++ b/common/unified/matrix/dense_kernels.template.cpp @@ -63,7 +63,10 @@ void copy(std::shared_ptr exec, exec, [] GKO_KERNEL(auto row, auto col, auto input, auto output) { using type = device_type; - output(row, col) = static_cast(input(row, col)); + using arithmetic_type = + highest_precision>; + output(row, col) = static_cast( + static_cast(input(row, col))); }, input->get_size(), input, output); } @@ -405,8 +408,11 @@ void row_gather(std::shared_ptr exec, run_kernel( exec, [] GKO_KERNEL(auto row, auto col, auto orig, auto rows, auto gathered) { - gathered(row, col) = - static_cast>(orig(rows[row], col)); + using output_type = device_type; + using arithmetic_type = + highest_precision>; + gathered(row, col) = static_cast( + static_cast(orig(rows[row], col))); }, dim<2>{row_idxs->get_num_elems(), orig->get_size()[1]}, orig, *row_idxs, row_collection);