Skip to content

Commit

Permalink
use float as the bridge between bfloat16 and half
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Oct 4, 2023
1 parent 50cf5b2 commit 51ab0b0
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
11 changes: 6 additions & 5 deletions common/unified/components/precision_conversion_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,12 @@ void convert_precision(std::shared_ptr<const DefaultExecutor> exec,
run_kernel(
exec,
[] GKO_KERNEL(auto idx, auto in, auto out) {
using in_type = typename std::remove_cv<
typename std::remove_reference<decltype(*in)>::type>::type;
using out_type = typename std::remove_cv<
typename std::remove_reference<decltype(*out)>::type>::type;
out[idx] = static_cast<out_type>(in[idx]);
using target_type = device_type<TargetType>;
using arithmetic_type =
highest_precision<target_type, device_type<SourceType>>;
// use float as the bridge between bfloat16 and half on device
out[idx] =
static_cast<target_type>(static_cast<arithmetic_type>(in[idx]));
},
size, in, out);
}
Expand Down
12 changes: 9 additions & 3 deletions common/unified/matrix/dense_kernels.template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ void copy(std::shared_ptr<const DefaultExecutor> exec,
exec,
[] GKO_KERNEL(auto row, auto col, auto input, auto output) {
using type = device_type<OutValueType>;
output(row, col) = static_cast<type>(input(row, col));
using arithmetic_type =
highest_precision<type, device_type<InValueType>>;
output(row, col) = static_cast<type>(
static_cast<arithmetic_type>(input(row, col)));
},
input->get_size(), input, output);
}
Expand Down Expand Up @@ -405,8 +408,11 @@ void row_gather(std::shared_ptr<const DefaultExecutor> exec,
run_kernel(
exec,
[] GKO_KERNEL(auto row, auto col, auto orig, auto rows, auto gathered) {
gathered(row, col) =
static_cast<device_type<OutputType>>(orig(rows[row], col));
using output_type = device_type<OutputType>;
using arithmetic_type =
highest_precision<output_type, device_type<ValueType>>;
gathered(row, col) = static_cast<output_type>(
static_cast<arithmetic_type>(orig(rows[row], col)));
},
dim<2>{row_idxs->get_num_elems(), orig->get_size()[1]}, orig, *row_idxs,
row_collection);
Expand Down

0 comments on commit 51ab0b0

Please sign in to comment.