From 1b7e524c6119c64adaad20ea2c6ae5a79ae470cb Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Thu, 12 Oct 2023 17:25:24 -0700 Subject: [PATCH] Eliminate gathering during table construction. When `elem_exp_consttime` replaced `BN_mod_exp_mont_consttime` I did not fully understand the way the table was constructed in the original function. Recent BoringSSL changes clarify the table construction. Do it the same way, to restore performance to what it was previously. This addresses the `// TODO: Optimize this to avoid gathering`. --- src/arithmetic/bigint.rs | 58 +++++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/src/arithmetic/bigint.rs b/src/arithmetic/bigint.rs index 9e2548bf5d..12fb864737 100644 --- a/src/arithmetic/bigint.rs +++ b/src/arithmetic/bigint.rs @@ -576,20 +576,7 @@ pub fn elem_exp_consttime( unsafe { bn_gather5(acc.as_mut_ptr(), num_limbs, table.as_ptr(), i) } } - fn gather_square( - table: &[Limb], - acc: &mut [Limb], - m: &[Limb], - n0: &N0, - i: Window, - num_limbs: usize, - cpu_features: cpu::Features, - ) { - gather(table, acc, i, num_limbs); - limbs_mont_square(acc, m, n0, cpu_features); - } - - fn gather_mul_base_amm( + fn limbs_mul_mont_gather5_amm( table: &[Limb], acc: &mut [Limb], base: &[Limb], @@ -670,6 +657,29 @@ pub fn elem_exp_consttime( (acc, base_cached, m_cached) }; + let n0 = m.n0(); + + // Fill in all the powers of 2 of `acc` into the table using only squaring and without any + // gathering, storing the last calculated power into `acc`. + fn scatter_powers_of_2( + table: &mut [Limb], + acc: &mut [Limb], + m_cached: &[Limb], + n0: &N0, + mut i: Window, + num_limbs: usize, + cpu_features: cpu::Features, + ) { + loop { + scatter(table, acc, i, num_limbs); + i *= 2; + if i >= (TABLE_ENTRIES as Window) { + break; + } + limbs_mont_square(acc, m_cached, n0, cpu_features); + } + } + // All entries in `table` will be Montgomery encoded. // acc = table[0] = base**0 (i.e. 1). 
@@ -677,21 +687,19 @@ pub fn elem_exp_consttime( // encode it. debug_assert!(acc.iter().all(|&value| value == 0)); acc[0] = 1; - limbs_mont_mul(acc, &m.oneRR().0.limbs, m_cached, m.n0(), cpu_features); + limbs_mont_mul(acc, &m.oneRR().0.limbs, m_cached, n0, cpu_features); scatter(table, acc, 0, num_limbs); // acc = table[1] = base**1 (i.e. base). acc.copy_from_slice(base_cached); - scatter(table, acc, 1, num_limbs); - for i in 2..(TABLE_ENTRIES as Window) { - if i % 2 == 0 { - // TODO: Optimize this to avoid gathering - gather_square(table, acc, m_cached, m.n0(), i / 2, num_limbs, cpu_features); - } else { - gather_mul_base_amm(table, acc, base_cached, m_cached, m.n0(), i - 1, num_limbs) - }; - scatter(table, acc, i, num_limbs); + // Fill in entries 1, 2, 4, 8, 16. + scatter_powers_of_2(table, acc, m_cached, n0, 1, num_limbs, cpu_features); + // Fill in entries 3, 6, 12, 24; 5, 10, 20; 7, 14, 28; 9, 18; 11, 22; 13, 26; 15, 30; + // 17; 19; 21; 23; 25; 27; 29; 31. + for i in (3..(TABLE_ENTRIES as Window)).step_by(2) { + limbs_mul_mont_gather5_amm(table, acc, base_cached, m_cached, n0, i - 1, num_limbs); + scatter_powers_of_2(table, acc, m_cached, n0, i, num_limbs, cpu_features); } let acc = limb::fold_5_bit_windows( @@ -701,7 +709,7 @@ pub fn elem_exp_consttime( acc }, |acc, window| { - power_amm(table, acc, m_cached, m.n0(), window, num_limbs); + power_amm(table, acc, m_cached, n0, window, num_limbs); acc }, );