Skip to content

Commit

Permalink
make half not rely on type
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Oct 24, 2024
1 parent 391bf03 commit 409a92e
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 27 deletions.
2 changes: 1 addition & 1 deletion core/test/base/half.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class ExtendedFloatTestBase : public ::testing::Test {
protected:
using half = gko::half;

static constexpr auto byte_size = gko::byte_size;
static constexpr auto byte_size = gko::detail::byte_size;

template <std::size_t N>
static floating<N - 1> create_from_bits(const char (&s)[N])
Expand Down
51 changes: 28 additions & 23 deletions include/ginkgo/core/base/half.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,47 @@
#define GKO_PUBLIC_CORE_BASE_HALF_HPP_


#include <climits>
#include <complex>
#include <cstdint>
#include <cstring>
#include <type_traits>

#include <ginkgo/core/base/std_extensions.hpp>
#include <ginkgo/core/base/types.hpp>


class __half;


namespace gko {


template <typename, size_type, size_type>
template <typename, std::size_t, std::size_t>
class truncated;


class half;


namespace detail {


constexpr std::size_t byte_size = CHAR_BIT;

template <std::size_t, typename = void>
struct uint_of_impl {};

template <std::size_t Bits>
struct uint_of_impl<Bits, std::enable_if_t<(Bits <= 16)>> {
using type = uint16;
using type = std::uint16_t;
};

template <std::size_t Bits>
struct uint_of_impl<Bits, std::enable_if_t<(16 < Bits && Bits <= 32)>> {
using type = uint32;
using type = std::uint32_t;
};

template <std::size_t Bits>
struct uint_of_impl<Bits, std::enable_if_t<(32 < Bits)>> {
using type = uint64;
using type = std::uint64_t;
};

template <std::size_t Bits>
Expand All @@ -53,8 +57,8 @@ template <typename T>
struct basic_float_traits {};

template <>
struct basic_float_traits<float16> {
using type = float16;
struct basic_float_traits<half> {
using type = half;
static constexpr int sign_bits = 1;
static constexpr int significand_bits = 10;
static constexpr int exponent_bits = 5;
Expand All @@ -71,24 +75,25 @@ struct basic_float_traits<__half> {
};

template <>
struct basic_float_traits<float32> {
using type = float32;
struct basic_float_traits<float> {
using type = float;
static constexpr int sign_bits = 1;
static constexpr int significand_bits = 23;
static constexpr int exponent_bits = 8;
static constexpr bool rounds_to_nearest = true;
};

template <>
struct basic_float_traits<float64> {
using type = float64;
struct basic_float_traits<double> {
using type = double;
static constexpr int sign_bits = 1;
static constexpr int significand_bits = 52;
static constexpr int exponent_bits = 11;
static constexpr bool rounds_to_nearest = true;
};

template <typename FloatType, size_type NumComponents, size_type ComponentId>
template <typename FloatType, std::size_t NumComponents,
std::size_t ComponentId>
struct basic_float_traits<truncated<FloatType, NumComponents, ComponentId>> {
using type = truncated<FloatType, NumComponents, ComponentId>;
static constexpr int sign_bits = ComponentId == 0 ? 1 : 0;
Expand Down Expand Up @@ -281,7 +286,7 @@ struct precision_converter<SourceType, ResultType, false> {
class half {
public:
// create half value from the bits directly.
static constexpr half create_from_bits(uint16 bits) noexcept
static constexpr half create_from_bits(std::uint16_t bits) noexcept
{
half result;
result.data_ = bits;
Expand Down Expand Up @@ -378,19 +383,19 @@ class half {
}

private:
using f16_traits = detail::float_traits<float16>;
using f32_traits = detail::float_traits<float32>;
using f16_traits = detail::float_traits<half>;
using f32_traits = detail::float_traits<float>;

void float2half(float val) noexcept
{
uint32 bit_val(0);
std::uint32_t bit_val(0);
std::memcpy(&bit_val, &val, sizeof(float));
data_ = float2half(bit_val);
}

static constexpr uint16 float2half(uint32 data_) noexcept
static constexpr std::uint16_t float2half(std::uint32_t data_) noexcept
{
using conv = detail::precision_converter<float32, float16>;
using conv = detail::precision_converter<float, half>;
if (f32_traits::is_inf(data_)) {
return conv::shift_sign(data_) | f16_traits::exponent_mask;
} else if (f32_traits::is_nan(data_)) {
Expand Down Expand Up @@ -419,9 +424,9 @@ class half {
}
}

static constexpr uint32 half2float(uint16 data_) noexcept
static constexpr std::uint32_t half2float(std::uint16_t data_) noexcept
{
using conv = detail::precision_converter<float16, float32>;
using conv = detail::precision_converter<half, float>;
if (f16_traits::is_inf(data_)) {
return conv::shift_sign(data_) | f32_traits::exponent_mask;
} else if (f16_traits::is_nan(data_)) {
Expand All @@ -436,7 +441,7 @@ class half {
}
}

uint16 data_;
std::uint16_t data_;
};


Expand Down
5 changes: 2 additions & 3 deletions include/ginkgo/core/base/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include <string>
#include <type_traits>

#include <ginkgo/core/base/half.hpp>


#ifdef __HIPCC__
#include <hip/hip_runtime.h>
Expand Down Expand Up @@ -138,9 +140,6 @@ using uint64 = std::uint64_t;
using uintptr = std::uintptr_t;


class half;


/**
* Half precision floating point type.
*/
Expand Down

0 comments on commit 409a92e

Please sign in to comment.