diff --git a/src/main/scala/yunsuan/fpu/FloatFMA.scala b/src/main/scala/yunsuan/fpu/FloatFMA.scala index 8ed8ae0..00c100a 100644 --- a/src/main/scala/yunsuan/fpu/FloatFMA.scala +++ b/src/main/scala/yunsuan/fpu/FloatFMA.scala @@ -184,15 +184,34 @@ class FloatFMA() extends Module{ val Ec_is_medium_f32 = !Ec_is_too_big_f32 & !Ec_is_too_small_f32 val Ec_is_medium_f16 = !Ec_is_too_big_f16 & !Ec_is_too_small_f16 - val rshift_guard_f64 = RegEnable(Mux(Ec_is_medium_f64, rshift_result_with_grs_f64(2), 0.U), fire) - val rshift_guard_f32 = RegEnable(Mux(Ec_is_medium_f32, rshift_result_with_grs_f32(2), 0.U), fire) - val rshift_guard_f16 = RegEnable(Mux(Ec_is_medium_f16, rshift_result_with_grs_f16(2), 0.U), fire) - val rshift_round_f64 = RegEnable(Mux(Ec_is_medium_f64, rshift_result_with_grs_f64(1), 0.U), fire) - val rshift_round_f32 = RegEnable(Mux(Ec_is_medium_f32, rshift_result_with_grs_f32(1), 0.U), fire) - val rshift_round_f16 = RegEnable(Mux(Ec_is_medium_f16, rshift_result_with_grs_f16(1), 0.U), fire) - val rshift_sticky_f64 = RegEnable(Mux(Ec_is_medium_f64, rshift_result_with_grs_f64(0), Mux(Ec_is_too_big_f64, 0.U, fp_c_significand_f64.orR)), fire) - val rshift_sticky_f32 = RegEnable(Mux(Ec_is_medium_f32, rshift_result_with_grs_f32(0), Mux(Ec_is_too_big_f32, 0.U, fp_c_significand_f32.orR)), fire) - val rshift_sticky_f16 = RegEnable(Mux(Ec_is_medium_f16, rshift_result_with_grs_f16(0), Mux(Ec_is_too_big_f16, 0.U, fp_c_significand_f16.orR)), fire) + // save 6 bit reg + val rshift_guard_f64_reg_d = Mux(Ec_is_medium_f64, rshift_result_with_grs_f64(2), 0.U) + val rshift_guard_f32_reg_d = Mux(Ec_is_medium_f32, rshift_result_with_grs_f32(2), 0.U) + val rshift_guard_f16_reg_d = Mux(Ec_is_medium_f16, rshift_result_with_grs_f16(2), 0.U) + val rshift_guard_reg_d = Mux(is_fp64, rshift_guard_f64_reg_d, Mux(is_fp32, rshift_guard_f32_reg_d, rshift_guard_f16_reg_d)) + val rshift_guard_reg = RegEnable(rshift_guard_reg_d, fire) + + val rshift_round_f64_reg_d = Mux(Ec_is_medium_f64, rshift_result_with_grs_f64(1), 0.U) + val rshift_round_f32_reg_d = Mux(Ec_is_medium_f32, rshift_result_with_grs_f32(1), 0.U) + val rshift_round_f16_reg_d = Mux(Ec_is_medium_f16, rshift_result_with_grs_f16(1), 0.U) + val rshift_round_reg_d = Mux(is_fp64, rshift_round_f64_reg_d, Mux(is_fp32, rshift_round_f32_reg_d, rshift_round_f16_reg_d)) + val rshift_round_reg = RegEnable(rshift_round_reg_d, fire) + + val rshift_sticky_f64_reg_d = Mux(Ec_is_medium_f64, rshift_result_with_grs_f64(0), Mux(Ec_is_too_big_f64, 0.U, fp_c_significand_f64.orR)) + val rshift_sticky_f32_reg_d = Mux(Ec_is_medium_f32, rshift_result_with_grs_f32(0), Mux(Ec_is_too_big_f32, 0.U, fp_c_significand_f32.orR)) + val rshift_sticky_f16_reg_d = Mux(Ec_is_medium_f16, rshift_result_with_grs_f16(0), Mux(Ec_is_too_big_f16, 0.U, fp_c_significand_f16.orR)) + val rshift_sticky_reg_d = Mux(is_fp64, rshift_sticky_f64_reg_d, Mux(is_fp32, rshift_sticky_f32_reg_d, rshift_sticky_f16_reg_d)) + val rshift_sticky_reg = RegEnable(rshift_sticky_reg_d, fire) + + val rshift_guard_f64 = rshift_guard_reg + val rshift_guard_f32 = rshift_guard_reg + val rshift_guard_f16 = rshift_guard_reg + val rshift_round_f64 = rshift_round_reg + val rshift_round_f32 = rshift_round_reg + val rshift_round_f16 = rshift_round_reg + val rshift_sticky_f64 = rshift_sticky_reg + val rshift_sticky_f32 = rshift_sticky_reg + val rshift_sticky_f16 = rshift_sticky_reg val rshift_result_temp_f64 = rshift_result_with_grs_f64.head(rshiftMaxF64-2) val rshift_result_temp_f32 = rshift_result_with_grs_f32.head(rshiftMaxF32-2) @@ -210,10 +229,15 @@ class FloatFMA() extends Module{ rshift_result_temp_f16, Mux(Ec_is_too_big_f16, fp_c_significand_cat0_f16.head(rshiftMaxF16-2), 0.U((rshiftMaxF16-2).W)) ) - - val fp_c_rshiftValue_inv_f64_reg0 = RegEnable(Mux(is_sub_f64.asBool ,Cat(1.U,~rshift_result_f64),Cat(0.U,rshift_result_f64)), fire) - val fp_c_rshiftValue_inv_f32_reg0 = RegEnable(Mux(is_sub_f32.asBool ,Cat(1.U,~rshift_result_f32),Cat(0.U,rshift_result_f32)), fire) - val fp_c_rshiftValue_inv_f16_reg0 = RegEnable(Mux(is_sub_f16.asBool ,Cat(1.U,~rshift_result_f16),Cat(0.U,rshift_result_f16)), fire) + // save 111 bit reg + val fp_c_rshiftValue_inv_f64_reg_d = Mux(is_sub_f64.asBool ,Cat(1.U,~rshift_result_f64),Cat(0.U,rshift_result_f64)) + val fp_c_rshiftValue_inv_f32_reg_d = Mux(is_sub_f32.asBool ,Cat(1.U,~rshift_result_f32),Cat(0.U,rshift_result_f32)) + val fp_c_rshiftValue_inv_f16_reg_d = Mux(is_sub_f16.asBool ,Cat(1.U,~rshift_result_f16),Cat(0.U,rshift_result_f16)) + val fp_c_rshiftValue_inv_reg_d = Mux(is_fp64, fp_c_rshiftValue_inv_f64_reg_d, Mux(is_fp32, fp_c_rshiftValue_inv_f32_reg_d, fp_c_rshiftValue_inv_f16_reg_d)) + val fp_c_rshiftValue_inv_reg = RegEnable(fp_c_rshiftValue_inv_reg_d, fire) + val fp_c_rshiftValue_inv_f64_reg0 = fp_c_rshiftValue_inv_reg + val fp_c_rshiftValue_inv_f32_reg0 = fp_c_rshiftValue_inv_reg(74, 0) + val fp_c_rshiftValue_inv_f16_reg0 = fp_c_rshiftValue_inv_reg(35, 0) val booth_in_a = Mux( is_fp64, fp_a_significand_f64, @@ -318,9 +342,14 @@ class FloatFMA() extends Module{ val adder_is_negative_f32 = adder_f32.head(1).asBool val adder_is_negative_f16 = adder_f16.head(1).asBool - val adder_is_negative_f64_reg2 = RegEnable(RegEnable(adder_is_negative_f64, fire_reg0), fire_reg1) - val adder_is_negative_f32_reg2 = RegEnable(RegEnable(adder_is_negative_f32, fire_reg0), fire_reg1) - val adder_is_negative_f16_reg2 = RegEnable(RegEnable(adder_is_negative_f16, fire_reg0), fire_reg1) + // save 4 + 2 = 6 bit reg + val adder_is_negative_reg_d = Mux(is_fp64_reg0, adder_is_negative_f64, Mux(is_fp32_reg0, adder_is_negative_f32, adder_is_negative_f16)) + val adder_is_negative_reg1 = RegEnable(adder_is_negative_reg_d, fire_reg0) + val adder_is_negative_reg2 = RegEnable(adder_is_negative_reg1, fire_reg1) + + val adder_is_negative_f64_reg2 = adder_is_negative_reg2 + val adder_is_negative_f32_reg2 = adder_is_negative_reg2 + val adder_is_negative_f16_reg2 = adder_is_negative_reg2 val adder_inv_f64 = Mux(adder_is_negative_f64, (~adder_f64.tail(1)).asUInt, adder_f64.tail(1)) val adder_inv_f32 = Mux(adder_is_negative_f32, (~adder_f32.tail(1)).asUInt, adder_f32.tail(1)) @@ -330,13 +359,27 @@ class FloatFMA() extends Module{ val Eab_is_greater_f32 = rshift_value_f32 > 0.S val Eab_is_greater_f16 = rshift_value_f16 > 0.S - val E_greater_f64_reg2 = RegEnable(RegEnable(RegEnable(Mux(Eab_is_greater_f64, Eab_f64(exponentWidth,0).asUInt, Cat(0.U(1.W),Ec_fix_f64)), fire), fire_reg0), fire_reg1) - val E_greater_f32_reg2 = RegEnable(RegEnable(RegEnable(Mux(Eab_is_greater_f32, Eab_f32(8,0).asUInt, Cat(0.U(1.W),Ec_fix_f32)), fire), fire_reg0), fire_reg1) - val E_greater_f16_reg2 = RegEnable(RegEnable(RegEnable(Mux(Eab_is_greater_f16, Eab_f16(5,0).asUInt, Cat(0.U(1.W),Ec_fix_f16)), fire), fire_reg0), fire_reg1) + // save 45bit reg + val E_greater_f64_reg_d = Mux(Eab_is_greater_f64, Eab_f64(exponentWidth,0).asUInt, Cat(0.U(1.W),Ec_fix_f64)) + val E_greater_f32_reg_d = Mux(Eab_is_greater_f32, Eab_f32(8,0).asUInt, Cat(0.U(1.W),Ec_fix_f32)) + val E_greater_f16_reg_d = Mux(Eab_is_greater_f16, Eab_f16(5,0).asUInt, Cat(0.U(1.W),Ec_fix_f16)) + val E_greater_reg_d = Mux(is_fp64, E_greater_f64_reg_d, Mux(is_fp32, E_greater_f32_reg_d, E_greater_f16_reg_d)) + val E_greater_reg2 = RegEnable(RegEnable(RegEnable(E_greater_reg_d, fire), fire_reg0), fire_reg1) + + val E_greater_f64_reg2 = E_greater_reg2 + val E_greater_f32_reg2 = E_greater_reg2(8,0) + val E_greater_f16_reg2 = E_greater_reg2(5,0) + + // save 15 bit reg + val lshift_value_max_f64_reg_d = Mux(Eab_is_greater_f64, Eab_f64(exponentWidth,0).asUInt - 1.U, Cat(0.U,Ec_fix_f64 - 1.U)) + val lshift_value_max_f32_reg_d = Mux(Eab_is_greater_f32, Eab_f32(8,0).asUInt - 1.U, Cat(0.U,Ec_fix_f32 - 1.U)) + val lshift_value_max_f16_reg_d = Mux(Eab_is_greater_f16, Eab_f16(5,0).asUInt - 1.U, Cat(0.U,Ec_fix_f16 - 1.U)) + val lshift_value_max_reg_d = Mux(is_fp64, lshift_value_max_f64_reg_d, Mux(is_fp32, lshift_value_max_f32_reg_d, lshift_value_max_f16_reg_d)) + val lshift_value_max_reg0 = RegEnable(lshift_value_max_reg_d, fire) - val lshift_value_max_f64_reg0 = RegEnable(Mux(Eab_is_greater_f64, Eab_f64(exponentWidth,0).asUInt - 1.U, Cat(0.U,Ec_fix_f64 - 1.U)), fire) - val lshift_value_max_f32_reg0 = RegEnable(Mux(Eab_is_greater_f32, Eab_f32(8,0).asUInt - 1.U, Cat(0.U,Ec_fix_f32 - 1.U)), fire) - val lshift_value_max_f16_reg0 = RegEnable(Mux(Eab_is_greater_f16, Eab_f16(5,0).asUInt - 1.U, Cat(0.U,Ec_fix_f16 - 1.U)), fire) + val lshift_value_max_f64_reg0 = lshift_value_max_reg0 + val lshift_value_max_f32_reg0 = lshift_value_max_reg0(8,0) + val lshift_value_max_f16_reg0 = lshift_value_max_reg0(5,0) val LZDWidth_f64 = adder_inv_f64.getWidth.U.getWidth val LZDWidth_f32 = adder_inv_f32.getWidth.U.getWidth @@ -356,14 +399,28 @@ class FloatFMA() extends Module{ Fill(adder_inv_f16.getWidth, 1.U) >> lshift_value_max_f16_reg0.tail(lshift_value_max_f16_reg0.getWidth-LZDWidth_f16) ).asUInt + //save 115 bit reg + val tzd_adder_f64_reg_d = Reverse(adder_f64.asUInt) + val tzd_adder_f32_reg_d = Reverse(adder_f32.asUInt) + val tzd_adder_f16_reg_d = Reverse(adder_f16.asUInt) + val tzd_adder_reg_d = Mux(is_fp64_reg0, tzd_adder_f64_reg_d, Mux(is_fp32_reg0, tzd_adder_f32_reg_d, tzd_adder_f16_reg_d)) + val tzd_adder_reg1 = RegEnable(tzd_adder_reg_d, fire_reg0) + //tail - val tzd_adder_f64_reg1 = LZD(RegEnable(Reverse(adder_f64.asUInt), fire_reg0).asTypeOf(adder_f64)) - val tzd_adder_f32_reg1 = LZD(RegEnable(Reverse(adder_f32.asUInt), fire_reg0).asTypeOf(adder_f32)) - val tzd_adder_f16_reg1 = LZD(RegEnable(Reverse(adder_f16.asUInt), fire_reg0).asTypeOf(adder_f16)) + val tzd_adder_f64_reg1 = LZD(tzd_adder_reg1.asTypeOf(adder_f64)) + val tzd_adder_f32_reg1 = LZD(tzd_adder_reg1(76,0).asTypeOf(adder_f32)) + val tzd_adder_f16_reg1 = LZD(tzd_adder_reg1(37,0).asTypeOf(adder_f16)) + + //save 112 bit reg + val lzd_adder_inv_mask_f64_reg_d = adder_inv_f64 | lshift_value_mask_f64 + val lzd_adder_inv_mask_f32_reg_d = adder_inv_f32 | lshift_value_mask_f32 + val lzd_adder_inv_mask_f16_reg_d = adder_inv_f16 | lshift_value_mask_f16 + val lzd_adder_inv_mask_reg_d = Mux(is_fp64_reg0, lzd_adder_inv_mask_f64_reg_d, Mux(is_fp32_reg0, lzd_adder_inv_mask_f32_reg_d, lzd_adder_inv_mask_f16_reg_d)) + val lzd_adder_inv_mask_reg1 = RegEnable(lzd_adder_inv_mask_reg_d, fire_reg0) - val lzd_adder_inv_mask_f64 = LZD(RegEnable(adder_inv_f64 | lshift_value_mask_f64, fire_reg0).asTypeOf(adder_inv_f64)) - val lzd_adder_inv_mask_f32 = LZD(RegEnable(adder_inv_f32 | lshift_value_mask_f32, fire_reg0).asTypeOf(adder_inv_f32)) - val lzd_adder_inv_mask_f16 = LZD(RegEnable(adder_inv_f16 | lshift_value_mask_f16, fire_reg0).asTypeOf(adder_inv_f16)) + val lzd_adder_inv_mask_f64 = LZD(lzd_adder_inv_mask_reg1.asTypeOf(adder_inv_f64)) + val lzd_adder_inv_mask_f32 = LZD(lzd_adder_inv_mask_reg1(75,0).asTypeOf(adder_inv_f32)) + val lzd_adder_inv_mask_f16 = LZD(lzd_adder_inv_mask_reg1(36,0).asTypeOf(adder_inv_f16)) val lzd_adder_inv_mask_f64_reg1 = Wire(UInt(lzd_adder_inv_mask_f64.getWidth.W)) val lzd_adder_inv_mask_f32_reg1 = Wire(UInt(lzd_adder_inv_mask_f32.getWidth.W)) @@ -372,12 +429,23 @@ class FloatFMA() extends Module{ lzd_adder_inv_mask_f32_reg1 := lzd_adder_inv_mask_f32 lzd_adder_inv_mask_f16_reg1 := lzd_adder_inv_mask_f16 - val lshift_mask_valid_f64_reg1 = (RegEnable(adder_inv_f64, fire_reg0) | RegEnable(lshift_value_mask_f64 , fire_reg0)) === RegEnable(lshift_value_mask_f64 , fire_reg0) - val lshift_mask_valid_f32_reg1 = (RegEnable(adder_inv_f32, fire_reg0) | RegEnable(lshift_value_mask_f32, fire_reg0)) === RegEnable(lshift_value_mask_f32, fire_reg0) - val lshift_mask_valid_f16_reg1 = (RegEnable(adder_inv_f16, fire_reg0) | RegEnable(lshift_value_mask_f16, fire_reg0)) === RegEnable(lshift_value_mask_f16, fire_reg0) - val lshift_value_f64_reg1 = lzd_adder_inv_mask_f64_reg1 - val lshift_value_f32_reg1 = lzd_adder_inv_mask_f32_reg1 - val lshift_value_f16_reg1 = lzd_adder_inv_mask_f16_reg1 + // save 828 bit reg + val lshift_mask_valid_f64_reg_d = (adder_inv_f64 | lshift_value_mask_f64) === lshift_value_mask_f64 + val lshift_mask_valid_f32_reg_d = (adder_inv_f32 | lshift_value_mask_f32) === lshift_value_mask_f32 + val lshift_mask_valid_f16_reg_d = (adder_inv_f16 | lshift_value_mask_f16) === lshift_value_mask_f16 + val lshift_mask_valid_reg_d = Mux(is_fp64_reg0, lshift_mask_valid_f64_reg_d, Mux(is_fp32_reg0, lshift_mask_valid_f32_reg_d, lshift_mask_valid_f16_reg_d)) + val lshift_mask_valid_reg = RegEnable(lshift_mask_valid_reg_d, fire_reg0) + + val lshift_mask_valid_f64_reg1 = lshift_mask_valid_reg + val lshift_mask_valid_f32_reg1 = lshift_mask_valid_reg + val lshift_mask_valid_f16_reg1 = lshift_mask_valid_reg + val lshift_value_f64_reg1 = lzd_adder_inv_mask_f64_reg1 + val lshift_value_f32_reg1 = lzd_adder_inv_mask_f32_reg1 + val lshift_value_f16_reg1 = lzd_adder_inv_mask_f16_reg1 + + // save 112 bit reg + val adder_reg_d = Mux(is_fp64_reg0, adder_f64, Mux(is_fp32_reg0, adder_f32, adder_f16)) + val adder_reg1 = RegEnable(adder_reg_d, fire_reg0) val adder_f64_reg1 = RegEnable(adder_f64, fire_reg0) val adder_f32_reg1 = RegEnable(adder_f32, fire_reg0) @@ -388,9 +456,9 @@ class FloatFMA() extends Module{ val lshift_adder_f32 = shiftLeftWithMux(adder_f32_reg1, lshift_value_f32_reg1) val lshift_adder_f16 = shiftLeftWithMux(adder_f16_reg1, lshift_value_f16_reg1) - val lshift_adder_inv_f64 = Cat(Mux(RegEnable(adder_is_negative_f64, fire_reg0),~lshift_adder_f64.head(significandWidth+4),lshift_adder_f64.head(significandWidth+4)),lshift_adder_f64.tail(significandWidth+4)) - val lshift_adder_inv_f32 = Cat(Mux(RegEnable(adder_is_negative_f32, fire_reg0),~lshift_adder_f32.head(24+4),lshift_adder_f32.head(24+4)),lshift_adder_f32.tail(24+4)) - val lshift_adder_inv_f16 = Cat(Mux(RegEnable(adder_is_negative_f16, fire_reg0),~lshift_adder_f16.head(11+4),lshift_adder_f16.head(11+4)),lshift_adder_f16.tail(11+4)) + val lshift_adder_inv_f64 = Cat(Mux(adder_is_negative_reg1,~lshift_adder_f64.head(significandWidth+4),lshift_adder_f64.head(significandWidth+4)),lshift_adder_f64.tail(significandWidth+4)) + val lshift_adder_inv_f32 = Cat(Mux(adder_is_negative_reg1,~lshift_adder_f32.head(24+4),lshift_adder_f32.head(24+4)),lshift_adder_f32.tail(24+4)) + val lshift_adder_inv_f16 = Cat(Mux(adder_is_negative_reg1,~lshift_adder_f16.head(11+4),lshift_adder_f16.head(11+4)),lshift_adder_f16.tail(11+4)) val is_fix_f64 = (tzd_adder_f64_reg1 + lzd_adder_inv_mask_f64_reg1) === adder_inv_f64.getWidth.U val is_fix_f32 = (tzd_adder_f32_reg1 + lzd_adder_inv_mask_f32_reg1) === adder_inv_f32.getWidth.U @@ -400,14 +468,23 @@ class FloatFMA() extends Module{ val lshift_adder_inv_fix_f32 = Mux(is_fix_f32, lshift_adder_inv_f32.head(adder_inv_f32.getWidth), lshift_adder_inv_f32.tail(1)) val lshift_adder_inv_fix_f16 = Mux(is_fix_f16, lshift_adder_inv_f16.head(adder_inv_f16.getWidth), lshift_adder_inv_f16.tail(1)) - val fraction_result_no_round_f64_reg2 = RegEnable(lshift_adder_inv_fix_f64.tail(1).head(significandWidth-1), fire_reg1) - val fraction_result_no_round_f32_reg2 = RegEnable(lshift_adder_inv_fix_f32.tail(1).head(24-1), fire_reg1) - val fraction_result_no_round_f16_reg2 = RegEnable(lshift_adder_inv_fix_f16.tail(1).head(11-1), fire_reg1) + // save 33 bit reg + val fraction_result_no_round_f64_reg_d = lshift_adder_inv_fix_f64.tail(1).head(significandWidth-1) + val fraction_result_no_round_f32_reg_d = lshift_adder_inv_fix_f32.tail(1).head(24-1) + val fraction_result_no_round_f16_reg_d = lshift_adder_inv_fix_f16.tail(1).head(11-1) + val fraction_result_no_round_reg_d = Mux(is_fp64_reg1, fraction_result_no_round_f64_reg_d, + Mux(is_fp32_reg1, fraction_result_no_round_f32_reg_d, fraction_result_no_round_f16_reg_d)) + val fraction_result_no_round_reg = RegEnable(fraction_result_no_round_reg_d, fire_reg1) + + val fraction_result_no_round_f64_reg2 = fraction_result_no_round_reg_d + val fraction_result_no_round_f32_reg2 = fraction_result_no_round_reg_d(22,0) + val fraction_result_no_round_f16_reg2 = fraction_result_no_round_reg_d(9,0) val fraction_result_round_f64 = fraction_result_no_round_f64_reg2 +& 1.U val fraction_result_round_f32 = fraction_result_no_round_f32_reg2 +& 1.U val fraction_result_round_f16 = fraction_result_no_round_f16_reg2 +& 1.U + // todo save 4 bit reg val sign_result_temp_f64_reg2 = RegEnable(RegEnable(Mux(adder_is_negative_f64 , RegEnable(sign_c_f64 , fire), RegEnable(sign_a_b_f64 , fire)), fire_reg0), fire_reg1) val sign_result_temp_f32_reg2 = RegEnable(RegEnable(Mux(adder_is_negative_f32, RegEnable(sign_c_f32, fire), RegEnable(sign_a_b_f32, fire)), fire_reg0), fire_reg1) val sign_result_temp_f16_reg2 = RegEnable(RegEnable(Mux(adder_is_negative_f16, RegEnable(sign_c_f16, fire), RegEnable(sign_a_b_f16, fire)), fire_reg0), fire_reg1) @@ -423,6 +500,7 @@ class FloatFMA() extends Module{ val RUP_reg2 = RegEnable(RegEnable(RegEnable(RUP, fire), fire_reg0), fire_reg1) val RMM_reg2 = RegEnable(RegEnable(RegEnable(RMM, fire), fire_reg0), fire_reg1) + // todo save 8 bit reg val sticky_f64_reg2 = RegEnable(RegEnable(rshift_sticky_f64, fire_reg0) | (lzd_adder_inv_mask_f64_reg1 + tzd_adder_f64_reg1 < (adder_inv_f64.getWidth-significandWidth-2).U), fire_reg1) val sticky_f32_reg2 = RegEnable(RegEnable(rshift_sticky_f32, fire_reg0) | (lzd_adder_inv_mask_f32_reg1 + tzd_adder_f32_reg1 < (adder_inv_f32.getWidth-24-2).U), fire_reg1) val sticky_f16_reg2 = RegEnable(RegEnable(rshift_sticky_f16, fire_reg0) | (lzd_adder_inv_mask_f16_reg1 + tzd_adder_f16_reg1 < (adder_inv_f16.getWidth-11-2).U), fire_reg1) @@ -431,6 +509,7 @@ class FloatFMA() extends Module{ val sticky_uf_f32_reg2 = RegEnable(RegEnable(rshift_sticky_f32, fire_reg0) | (lzd_adder_inv_mask_f32_reg1 + tzd_adder_f32_reg1 < (adder_inv_f32.getWidth-24-3).U), fire_reg1) val sticky_uf_f16_reg2 = RegEnable(RegEnable(rshift_sticky_f16, fire_reg0) | (lzd_adder_inv_mask_f16_reg1 + tzd_adder_f16_reg1 < (adder_inv_f16.getWidth-11-3).U), fire_reg1) + // todo save 4 bit reg val round_lshift_f64_reg2 = RegEnable(lshift_adder_inv_fix_f64.tail(significandWidth+1).head(1), fire_reg1) val round_lshift_f32_reg2 = RegEnable(lshift_adder_inv_fix_f32.tail(24+1).head(1), fire_reg1) val round_lshift_f16_reg2 = RegEnable(lshift_adder_inv_fix_f16.tail(11+1).head(1), fire_reg1) @@ -451,6 +530,7 @@ class FloatFMA() extends Module{ val guard_uf_f32 = round_f32 val guard_uf_f16 = round_f16 + // todo save 2 bit reg val round_lshift_uf_f64_reg2 = RegEnable(lshift_adder_inv_fix_f64.tail(significandWidth+2).head(1), fire_reg1) val round_lshift_uf_f32_reg2 = RegEnable(lshift_adder_inv_fix_f32.tail(24+2).head(1), fire_reg1) val round_lshift_uf_f16_reg2 = RegEnable(lshift_adder_inv_fix_f16.tail(11+2).head(1), fire_reg1) @@ -494,24 +574,31 @@ class FloatFMA() extends Module{ val exponent_add_1_f32 = fraction_result_no_round_f32_reg2.andR & round_add1_f32.asBool val exponent_add_1_f16 = fraction_result_no_round_f16_reg2.andR & round_add1_f16.asBool + // save 2 + 26 = 28bit reg + val is_fix_reg_d = Mux(is_fp64_reg1, is_fix_f64, Mux(is_fp32_reg1, is_fix_f32, is_fix_f16)) + val is_fix_reg2 = RegEnable(is_fix_reg_d, fire_reg1) + + val lshift_value_reg_d = Mux(is_fp64_reg1, lshift_value_f64_reg1, Mux(is_fp32_reg1, lshift_value_f32_reg1, lshift_value_f16_reg1)) + val lshift_value_reg2 = RegEnable(lshift_value_reg_d, fire_reg1) - val exponent_result_add_value_f64 = Mux(exponent_add_1_f64 | RegEnable(is_fix_f64, fire_reg1), - E_greater_f64_reg2 - RegEnable(lshift_value_f64_reg1, fire_reg1) + 1.U, - E_greater_f64_reg2 - RegEnable(lshift_value_f64_reg1, fire_reg1) + val exponent_result_add_value_f64 = Mux(exponent_add_1_f64 | is_fix_reg2, + E_greater_f64_reg2 - lshift_value_reg2 + 1.U, + E_greater_f64_reg2 - lshift_value_reg2 ) - val exponent_result_add_value_f32 = Mux(exponent_add_1_f32 | RegEnable(is_fix_f32, fire_reg1), - E_greater_f32_reg2 - RegEnable(lshift_value_f32_reg1, fire_reg1) + 1.U, - E_greater_f32_reg2 - RegEnable(lshift_value_f32_reg1, fire_reg1) + val exponent_result_add_value_f32 = Mux(exponent_add_1_f32 | is_fix_reg2, + E_greater_f32_reg2 - lshift_value_reg2(6,0) + 1.U, + E_greater_f32_reg2 - lshift_value_reg2(6,0) ) - val exponent_result_add_value_f16 = Mux(exponent_add_1_f16 | RegEnable(is_fix_f16, fire_reg1), - E_greater_f16_reg2 - RegEnable(lshift_value_f16_reg1, fire_reg1) + 1.U, - E_greater_f16_reg2 - RegEnable(lshift_value_f16_reg1, fire_reg1) + val exponent_result_add_value_f16 = Mux(exponent_add_1_f16 | is_fix_reg2, + E_greater_f16_reg2 - lshift_value_reg2(5,0) + 1.U, + E_greater_f16_reg2 - lshift_value_reg2(5,0) ) val exponent_overflow_f64 = exponent_result_add_value_f64.head(1).asBool | exponent_result_add_value_f64.tail(1).andR val exponent_overflow_f32 = exponent_result_add_value_f32.head(1).asBool | exponent_result_add_value_f32.tail(1).andR val exponent_overflow_f16 = exponent_result_add_value_f16.head(1).asBool | exponent_result_add_value_f16.tail(1).andR + // todo save 2bit reg val exponent_is_min_f64 = RegEnable(!lshift_adder_inv_fix_f64.head(1).asBool & lshift_mask_valid_f64_reg1 & !is_fix_f64, fire_reg1) val exponent_is_min_f32 = RegEnable(!lshift_adder_inv_fix_f32.head(1).asBool & lshift_mask_valid_f32_reg1 & !is_fix_f32, fire_reg1) val exponent_is_min_f16 = RegEnable(!lshift_adder_inv_fix_f16.head(1).asBool & lshift_mask_valid_f16_reg1 & !is_fix_f16, fire_reg1) @@ -556,6 +643,7 @@ class FloatFMA() extends Module{ val fp_c_is_zero_f32 = !io.fp_cIsFpCanonicalNAN & !fp_c_significand_f32.orR val fp_c_is_zero_f16 = !io.fp_cIsFpCanonicalNAN & !fp_c_significand_f16.orR + // todo save 4bit reg val normal_result_is_zero_f64_reg2 = RegEnable(RegEnable(!adder_f64.orR , fire_reg0), fire_reg1) val normal_result_is_zero_f32_reg2 = RegEnable(RegEnable(!adder_f32.orR, fire_reg0), fire_reg1) val normal_result_is_zero_f16_reg2 = RegEnable(RegEnable(!adder_f16.orR, fire_reg0), fire_reg1) @@ -632,6 +720,36 @@ class FloatFMA() extends Module{ val fp_result_f32 = Wire(UInt(32.W)) val fp_result_f16 = Wire(UInt(16.W)) + // save 144 bit reg + val fp_result_f64_fp_a_or_b_is_zero = Cat( + Mux( + fp_c_is_zero_f64, + Mux(is_fmul, sign_a_b_f64, (sign_a_b_f64 & sign_c_f64) | (RDN & (sign_a_b_f64 ^ sign_c_f64))), + fp_c_f64.head(1) + ), + fp_c_f64.tail(1) + ) + val fp_result_f32_fp_a_or_b_is_zero = Cat( + Mux( + fp_c_is_zero_f32, + Mux(is_fmul, sign_a_b_f32, (sign_a_b_f32 & sign_c_f32) | (RDN & (sign_a_b_f32 ^ sign_c_f32)) ), + fp_c_f32.head(1) + ), + fp_c_f32.tail(1) + ) + val fp_result_f16_fp_a_or_b_is_zero = Cat( + Mux( + fp_c_is_zero_f16, + Mux(is_fmul, sign_a_b_f16, (sign_a_b_f16 & sign_c_f16) | (RDN & (sign_a_b_f16 ^ sign_c_f16)) ), + fp_c_f16.head(1) + ), + fp_c_f16.tail(1) + ) + val fp_result_fp_a_or_b_is_zero_reg_d = Mux(is_fp64, fp_result_f64_fp_a_or_b_is_zero, + Mux(is_fp32, fp_result_f32_fp_a_or_b_is_zero, fp_result_f16_fp_a_or_b_is_zero)) + val fp_result_fp_a_or_b_is_zero_reg = RegEnable(RegEnable(RegEnable(fp_result_fp_a_or_b_is_zero_reg_d, fire), fire_reg0), fire_reg1) + + // todo save 18*2 = 36bitreg val has_nan_f64_reg2 = RegEnable(RegEnable(RegEnable(has_nan_f64, fire), fire_reg0), fire_reg1) val has_nan_f64_is_NV_reg2 = RegEnable(RegEnable(RegEnable( has_snan_f64.asBool | (fp_a_is_inf_f64 & fp_b_is_zero_f64) | (fp_a_is_zero_f64 & fp_b_is_inf_f64), @@ -646,17 +764,7 @@ class FloatFMA() extends Module{ val is_overflow_f64_down_reg2 = RTZ_reg2 | (RDN_reg2 & !sign_result_temp_f64_reg2.asBool) | (RUP_reg2 & sign_result_temp_f64_reg2.asBool) val fp_a_or_b_is_zero_f64_reg2 = RegEnable(RegEnable(RegEnable(fp_a_is_zero_f64 | fp_b_is_zero_f64, fire), fire_reg0), fire_reg1) - val fp_result_f64_fp_a_or_b_is_zero_reg2 = RegEnable(RegEnable(RegEnable( - Cat( - Mux( - fp_c_is_zero_f64, - Mux(is_fmul, sign_a_b_f64, (sign_a_b_f64 & sign_c_f64) | (RDN & (sign_a_b_f64 ^ sign_c_f64))), - fp_c_f64.head(1) - ), - fp_c_f64.tail(1) - ), - fire), fire_reg0), fire_reg1 - ) + val fp_result_f64_fp_a_or_b_is_zero_reg2 = fp_result_fp_a_or_b_is_zero_reg when(has_nan_f64_reg2){ fp_result_f64 := result_nan_f64 @@ -690,17 +798,7 @@ class FloatFMA() extends Module{ fire), fire_reg0), fire_reg1) val is_overflow_f32_down_reg2 = RTZ_reg2 | (RDN_reg2 & !sign_result_temp_f32_reg2.asBool) | (RUP_reg2 & sign_result_temp_f32_reg2.asBool) val fp_a_or_b_is_zero_f32_reg2 = RegEnable(RegEnable(RegEnable(fp_a_is_zero_f32 | fp_b_is_zero_f32, fire), fire_reg0), fire_reg1) - val fp_result_f32_fp_a_or_b_is_zero_reg2 = RegEnable(RegEnable(RegEnable( - Cat( - Mux( - fp_c_is_zero_f32, - Mux(is_fmul, sign_a_b_f32, (sign_a_b_f32 & sign_c_f32) | (RDN & (sign_a_b_f32 ^ sign_c_f32)) ), - fp_c_f32.head(1) - ), - fp_c_f32.tail(1) - ), - fire), fire_reg0), fire_reg1 - ) + val fp_result_f32_fp_a_or_b_is_zero_reg2 = fp_result_fp_a_or_b_is_zero_reg(31,0) when(has_nan_f32_reg2){ fp_result_f32 := result_nan_f32 fflags_f32 := Mux(has_nan_f32_is_NV_reg2,"b10000".U,"b00000".U) @@ -733,17 +831,7 @@ class FloatFMA() extends Module{ fire), fire_reg0), fire_reg1) val is_overflow_f16_down_reg2 = RTZ_reg2 | (RDN_reg2 & !sign_result_temp_f16_reg2.asBool) | (RUP_reg2 & sign_result_temp_f16_reg2.asBool) val fp_a_or_b_is_zero_f16_reg2 = RegEnable(RegEnable(RegEnable(fp_a_is_zero_f16 | fp_b_is_zero_f16, fire), fire_reg0), fire_reg1) - val fp_result_f16_fp_a_or_b_is_zero_reg2 = RegEnable(RegEnable(RegEnable( - Cat( - Mux( - fp_c_is_zero_f16, - Mux(is_fmul, sign_a_b_f16, (sign_a_b_f16 & sign_c_f16) | (RDN & (sign_a_b_f16 ^ sign_c_f16)) ), - fp_c_f16.head(1) - ), - fp_c_f16.tail(1) - ), - fire), fire_reg0), fire_reg1 - ) + val fp_result_f16_fp_a_or_b_is_zero_reg2 = fp_result_fp_a_or_b_is_zero_reg(15,0) when(has_nan_f16_reg2){ fp_result_f16 := result_nan_f16 fflags_f16 := Mux(has_nan_f16_is_NV_reg2,"b10000".U,"b00000".U)