Skip to content

Commit

Permalink
Rename PartialModulus to Modulus, Modulus to OwnedModulusWithOne.
Browse files Browse the repository at this point in the history
Originally we only had `Modulus`. Then we had a need for a
temporary `Modulus` without `oneRR` so we created `PartialModulus`.
However, there is really nothing "partial" about them. So, improve
the naming by renaming `PartialModulus` to `Modulus` and `Modulus`
to `OwnedModulusWithOne`. In the future we may refactor things
further to separate the ownership aspect from the "has oneRR"
aspect.

Instead of just doing a straightforward rename, take this
opportunity to refactor the code so that it uses the new `Modulus`
whenever `oneRR()` isn't used. This eliminates the duplication of
the APIs of the two modulus types, and the duplication of
`elem_mul` and `elem_mul_`.
  • Loading branch information
briansmith committed Nov 6, 2023
1 parent 69d1dd3 commit e51c88a
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 115 deletions.
93 changes: 47 additions & 46 deletions src/arithmetic/bigint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
use self::boxed_limbs::BoxedLimbs;
pub(crate) use self::{
modulus::{Modulus, PartialModulus, MODULUS_MAX_LIMBS},
modulus::{Modulus, OwnedModulusWithOne, MODULUS_MAX_LIMBS},
private_exponent::PrivateExponent,
};
use super::n0::N0;
Expand Down Expand Up @@ -186,20 +186,9 @@ impl<M> Elem<M, Unencoded> {
}

pub fn elem_mul<M, AF, BF>(
a: &Elem<M, AF>,
b: Elem<M, BF>,
m: &Modulus<M>,
) -> Elem<M, <(AF, BF) as ProductEncoding>::Output>
where
(AF, BF): ProductEncoding,
{
elem_mul_(a, b, &m.as_partial())
}

fn elem_mul_<M, AF, BF>(
a: &Elem<M, AF>,
mut b: Elem<M, BF>,
m: &PartialModulus<M>,
m: &Modulus<M>,
) -> Elem<M, <(AF, BF) as ProductEncoding>::Output>
where
(AF, BF): ProductEncoding,
Expand All @@ -211,7 +200,7 @@ where
}
}

fn elem_mul_by_2<M, AF>(a: &mut Elem<M, AF>, m: &PartialModulus<M>) {
fn elem_mul_by_2<M, AF>(a: &mut Elem<M, AF>, m: &Modulus<M>) {
prefixed_extern! {
fn LIMBS_shl_mod(r: *mut Limb, a: *const Limb, m: *const Limb, num_limbs: c::size_t);
}
Expand Down Expand Up @@ -254,7 +243,7 @@ pub fn elem_reduced<Larger, Smaller: NotMuchSmallerModulus<Larger>>(

fn elem_squared<M, E>(
mut a: Elem<M, E>,
m: &PartialModulus<M>,
m: &Modulus<M>,
) -> Elem<M, <(E, E) as ProductEncoding>::Output>
where
(E, E): ProductEncoding,
Expand Down Expand Up @@ -316,7 +305,7 @@ impl<M> One<M, RR> {
// values, using `LIMB_BITS` here, rather than `N0::LIMBS_USED * LIMB_BITS`,
// is correct because R**2 will still be a multiple of the latter as
// `N0::LIMBS_USED` is either one or two.
fn newRR(m: &PartialModulus<M>, m_bits: bits::BitLength) -> Self {
fn newRR(m: &Modulus<M>, m_bits: bits::BitLength) -> Self {
let m_bits = m_bits.as_usize_bits();
let r = (m_bits + (LIMB_BITS - 1)) / LIMB_BITS * LIMB_BITS;

Expand Down Expand Up @@ -390,7 +379,7 @@ impl<M: PublicModulus, E> Clone for One<M, E> {
pub(crate) fn elem_exp_vartime<M>(
base: Elem<M, R>,
exponent: NonZeroU64,
m: &PartialModulus<M>,
m: &Modulus<M>,
) -> Elem<M, R> {
// Use what [Knuth] calls the "S-and-X binary method", i.e. variable-time
// square-and-multiply that scans the exponent from the most significant
Expand All @@ -417,7 +406,7 @@ pub(crate) fn elem_exp_vartime<M>(
bit >>= 1;
acc = elem_squared(acc, m);
if (exponent & bit) != 0 {
acc = elem_mul_(&base, acc, m);
acc = elem_mul(&base, acc, m);
}
}
acc
Expand All @@ -426,17 +415,20 @@ pub(crate) fn elem_exp_vartime<M>(
/// Uses Fermat's Little Theorem to calculate modular inverse in constant time.
pub fn elem_inverse_consttime<M: Prime>(
a: Elem<M, R>,
m: &Modulus<M>,
m: &OwnedModulusWithOne<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
elem_exp_consttime(a, &PrivateExponent::for_flt(m), m)
elem_exp_consttime(a, &PrivateExponent::for_flt(&m.modulus()), m)
}

#[cfg(not(target_arch = "x86_64"))]
pub fn elem_exp_consttime<M>(
base: Elem<M, R>,
exponent: &PrivateExponent,
m: &Modulus<M>,
m: &OwnedModulusWithOne<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
let oneRR = m.oneRR();
let m = &m.modulus();

use crate::{bssl, limb::Window};

const WINDOW_BITS: usize = 5;
Expand Down Expand Up @@ -469,7 +461,7 @@ pub fn elem_exp_consttime<M>(
mut tmp: Elem<M, R>,
) -> (Elem<M, R>, Elem<M, R>) {
for _ in 0..WINDOW_BITS {
acc = elem_squared(acc, &m.as_partial());
acc = elem_squared(acc, m);
}
gather(table, &mut tmp, i);
let acc = elem_mul(&tmp, acc, m);
Expand All @@ -489,7 +481,7 @@ pub fn elem_exp_consttime<M>(
// `table` was initialized to zero and hasn't changed.
debug_assert!(acc.iter().all(|&value| value == 0));
acc[0] = 1;
limbs_mont_mul(acc, &m.oneRR().0.limbs, m.limbs(), m.n0(), m.cpu_features());
limbs_mont_mul(acc, &oneRR.0.limbs, m.limbs(), m.n0(), m.cpu_features());
}

entry_mut(&mut table, 1, num_limbs).copy_from_slice(&base.limbs);
Expand Down Expand Up @@ -527,10 +519,13 @@ pub fn elem_exp_consttime<M>(
pub fn elem_exp_consttime<M>(
base: Elem<M, R>,
exponent: &PrivateExponent,
m: &Modulus<M>,
m: &OwnedModulusWithOne<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
use crate::limb::LIMB_BYTES;

let oneRR = m.oneRR();
let m = &m.modulus();

// Pretty much all the math here requires CPU feature detection to have
// been done. `cpu_features` isn't threaded through all the internal
// functions, so just make it clear that it has been done at this point.
Expand Down Expand Up @@ -685,7 +680,7 @@ pub fn elem_exp_consttime<M>(
// encode it.
debug_assert!(acc.iter().all(|&value| value == 0));
acc[0] = 1;
limbs_mont_mul(acc, &m.oneRR().0.limbs, m_cached, n0, cpu_features);
limbs_mont_mul(acc, &oneRR.0.limbs, m_cached, n0, cpu_features);
scatter(table, acc, 0, num_limbs);

// acc = base**1 (i.e. base).
Expand Down Expand Up @@ -853,16 +848,17 @@ mod tests {
|section, test_case| {
assert_eq!(section, "");

let m = consume_modulus::<M>(test_case, "M", cpu_features);
let m_ = consume_modulus::<M>(test_case, "M", cpu_features);
let m = m_.modulus();
let expected_result = consume_elem(test_case, "ModExp", &m);
let base = consume_elem(test_case, "A", &m);
let e = {
let bytes = test_case.consume_bytes("E");
PrivateExponent::from_be_bytes_for_test_only(untrusted::Input::from(&bytes), &m)
.expect("valid exponent")
};
let base = into_encoded(base, &m);
let actual_result = elem_exp_consttime(base, &e, &m).unwrap();
let base = into_encoded(base, &m_);
let actual_result = elem_exp_consttime(base, &e, &m_).unwrap();
assert_elem_eq(&actual_result, &expected_result);

Ok(())
Expand All @@ -882,13 +878,14 @@ mod tests {
|section, test_case| {
assert_eq!(section, "");

let m = consume_modulus::<M>(test_case, "M", cpu_features);
let m_ = consume_modulus::<M>(test_case, "M", cpu_features);
let m = m_.modulus();
let expected_result = consume_elem(test_case, "ModMul", &m);
let a = consume_elem(test_case, "A", &m);
let b = consume_elem(test_case, "B", &m);

let b = into_encoded(b, &m);
let a = into_encoded(a, &m);
let b = into_encoded(b, &m_);
let a = into_encoded(a, &m_);
let actual_result = elem_mul(&a, b, &m);
let actual_result = actual_result.into_unencoded(&m);
assert_elem_eq(&actual_result, &expected_result);
Expand All @@ -906,12 +903,13 @@ mod tests {
|section, test_case| {
assert_eq!(section, "");

let m = consume_modulus::<M>(test_case, "M", cpu_features);
let m_ = consume_modulus::<M>(test_case, "M", cpu_features);
let m = m_.modulus();
let expected_result = consume_elem(test_case, "ModSquare", &m);
let a = consume_elem(test_case, "A", &m);

let a = into_encoded(a, &m);
let actual_result = elem_squared(a, &m.as_partial());
let a = into_encoded(a, &m_);
let actual_result = elem_squared(a, &m);
let actual_result = actual_result.into_unencoded(&m);
assert_elem_eq(&actual_result, &expected_result);

Expand All @@ -932,13 +930,14 @@ mod tests {
unsafe impl SmallerModulus<MM> for M {}
unsafe impl NotMuchSmallerModulus<MM> for M {}

let m = consume_modulus::<M>(test_case, "M", cpu_features);
let m_ = consume_modulus::<M>(test_case, "M", cpu_features);
let m = m_.modulus();
let expected_result = consume_elem(test_case, "R", &m);
let a =
consume_elem_unchecked::<MM>(test_case, "A", expected_result.limbs.len() * 2);

let actual_result = elem_reduced(&a, &m);
let oneRR = m.oneRR();
let oneRR = m_.oneRR();
let actual_result = elem_mul(oneRR.as_ref(), actual_result, &m);
assert_elem_eq(&actual_result, &expected_result);

Expand All @@ -961,11 +960,11 @@ mod tests {
unsafe impl SlightlySmallerModulus<N> for QQ {}

let qq = consume_modulus::<QQ>(test_case, "QQ", cpu_features);
let expected_result = consume_elem::<QQ>(test_case, "R", &qq);
let expected_result = consume_elem::<QQ>(test_case, "R", &qq.modulus());
let n = consume_modulus::<N>(test_case, "N", cpu_features);
let a = consume_elem::<N>(test_case, "A", &n);
let a = consume_elem::<N>(test_case, "A", &n.modulus());

let actual_result = elem_reduced_once(&a, &qq);
let actual_result = elem_reduced_once(&a, &qq.modulus());
assert_elem_eq(&actual_result, &expected_result);

Ok(())
Expand All @@ -975,7 +974,7 @@ mod tests {

#[test]
fn test_modulus_debug() {
let (modulus, _) = Modulus::<M>::from_be_bytes_with_bit_length(
let (modulus, _) = OwnedModulusWithOne::<M>::from_be_bytes_with_bit_length(
untrusted::Input::from(&[0xff; LIMB_BYTES * MODULUS_MIN_LIMBS]),
cpu::features(),
)
Expand Down Expand Up @@ -1010,11 +1009,13 @@ mod tests {
test_case: &mut test::TestCase,
name: &str,
cpu_features: cpu::Features,
) -> Modulus<M> {
) -> OwnedModulusWithOne<M> {
let value = test_case.consume_bytes(name);
let (value, _) =
Modulus::from_be_bytes_with_bit_length(untrusted::Input::from(&value), cpu_features)
.unwrap();
let (value, _) = OwnedModulusWithOne::from_be_bytes_with_bit_length(
untrusted::Input::from(&value),
cpu_features,
)
.unwrap();
value
}

Expand All @@ -1031,7 +1032,7 @@ mod tests {
}
}

fn into_encoded<M>(a: Elem<M, Unencoded>, m: &Modulus<M>) -> Elem<M, R> {
elem_mul(m.oneRR().as_ref(), a, m)
fn into_encoded<M>(a: Elem<M, Unencoded>, m: &OwnedModulusWithOne<M>) -> Elem<M, R> {
elem_mul(m.oneRR().as_ref(), a, &m.modulus())
}
}
Loading

0 comments on commit e51c88a

Please sign in to comment.