Skip to content

Commit

Permalink
RSA: Remove QQ from RsaKeyPair.
Browse files Browse the repository at this point in the history
QQ comprised almost 25% of the bulk of RsaKeyPair and is actually
completely unnecessary since `elem_reduced` can do the whole
reduction itself.

This has the nice and important side effect of eliminating some
conversion operations between `bigint` types.

This is also a step towards eliminating some of the `unsafe trait`
stuff that kinda-but-not-really modeled modulus relationships.
  • Loading branch information
briansmith committed Nov 9, 2023
1 parent cbcac26 commit 946ce87
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 99 deletions.
59 changes: 25 additions & 34 deletions src/arithmetic/bigint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ use super::n0::N0;
pub(crate) use super::nonnegative::Nonnegative;
use crate::{
arithmetic::montgomery::*,
bits::BitLength,
c, cpu, error,
limb::{self, Limb, LimbMask, LIMB_BITS},
polyfill::u64_from_usize,
Expand Down Expand Up @@ -80,31 +81,6 @@ pub unsafe trait Prime {}
/// trait preemptively.)
pub unsafe trait SmallerModulus<L> {}

/// A modulus *s* where s < l < 2*s for the given larger modulus *l*. This is
/// the precondition for reduction by conditional subtraction,
/// `elem_reduce_once()`.
///
/// # Safety
///
/// Some logic may assume that the invariant holds when accessing limbs within
/// a value, e.g. by assuming that the smaller modulus is at most one limb
/// smaller than the larger modulus. TODO: Any such logic should be
/// encapsulated here, or this trait should be made non-`unsafe`. (In retrospect,
/// this shouldn't have been made an `unsafe` trait preemptively.)
pub unsafe trait SlightlySmallerModulus<L>: SmallerModulus<L> {}

/// A modulus *s* where √l <= s < l for the given larger modulus *l*. This is
/// the precondition for the more general Montgomery reduction from ℤ/lℤ to
/// ℤ/sℤ.
///
/// # Safety
///
/// Some logic may assume that the invariant holds when accessing limbs within
/// a value. TODO: Any such logic should be encapsulated here, or this trait
/// should be made non-`unsafe`. (In retrospect, this shouldn't have been made
/// an `unsafe` trait preemptively.)
pub unsafe trait NotMuchSmallerModulus<L>: SmallerModulus<L> {}

pub trait PublicModulus {}

/// Elements of ℤ/mℤ for some modulus *m*.
Expand Down Expand Up @@ -214,12 +190,20 @@ fn elem_mul_by_2<M, AF>(a: &mut Elem<M, AF>, m: &Modulus<M>) {
}
}

pub fn elem_reduced_once<Larger, Smaller: SlightlySmallerModulus<Larger>>(
// TODO: This is currently unused, but we intend to eventually use this to
// reduce elements (x mod q) mod p in the RSA CRT. If/when we do so, we
// should update the testing so it is reflective of that usage, instead of
// the old usage.
#[cfg(test)]
pub fn elem_reduced_once<Larger, Smaller>(
a: &Elem<Larger, Unencoded>,
m: &Modulus<Smaller>,
) -> Elem<Smaller, Unencoded> {
// `limbs_reduce_once_constant_time` requires `r` and `m` to have the same
// number of limbs.
assert_eq!(a.limbs.len(), m.limbs().len());

let mut r = a.limbs.clone();
assert!(r.len() <= m.limbs().len());
limb::limbs_reduce_once_constant_time(&mut r, m.limbs());
Elem {
limbs: BoxedLimbs::new_unchecked(r.into_limbs()),
Expand All @@ -228,10 +212,19 @@ pub fn elem_reduced_once<Larger, Smaller: SlightlySmallerModulus<Larger>>(
}

#[inline]
pub fn elem_reduced<Larger, Smaller: NotMuchSmallerModulus<Larger>>(
pub fn elem_reduced<Larger, Smaller>(
a: &Elem<Larger, Unencoded>,
m: &Modulus<Smaller>,
other_prime_len_bits: BitLength,
) -> Elem<Smaller, RInverse> {
// This is stricter than required mathematically but this is what we
// guarantee and this is easier to check. The real requirement is that
// that `a < m*R` where `R` is the Montgomery `R` for `m`.
assert_eq!(other_prime_len_bits, m.len_bits());

// `limbs_from_mont_in_place` requires this.
assert_eq!(a.limbs.len(), m.limbs().len() * 2);

let mut tmp = [0; MODULUS_MAX_LIMBS];
let tmp = &mut tmp[..a.limbs.len()];
tmp.copy_from_slice(&a.limbs);
Expand Down Expand Up @@ -919,17 +912,16 @@ mod tests {
|section, test_case| {
assert_eq!(section, "");

struct MM {}
unsafe impl SmallerModulus<MM> for M {}
unsafe impl NotMuchSmallerModulus<MM> for M {}
struct M {}

let m_ = consume_modulus::<M>(test_case, "M", cpu_features);
let m = m_.modulus();
let expected_result = consume_elem(test_case, "R", &m);
let a =
consume_elem_unchecked::<MM>(test_case, "A", expected_result.limbs.len() * 2);
consume_elem_unchecked::<M>(test_case, "A", expected_result.limbs.len() * 2);
let other_modulus_len_bits = m_.len_bits();

let actual_result = elem_reduced(&a, &m);
let actual_result = elem_reduced(&a, &m, other_modulus_len_bits);
let oneRR = m_.oneRR();
let actual_result = elem_mul(oneRR.as_ref(), actual_result, &m);
assert_elem_eq(&actual_result, &expected_result);
Expand All @@ -950,7 +942,6 @@ mod tests {
struct N {}
struct QQ {}
unsafe impl SmallerModulus<N> for QQ {}
unsafe impl SlightlySmallerModulus<N> for QQ {}

let qq = consume_modulus::<QQ>(test_case, "QQ", cpu_features);
let expected_result = consume_elem::<QQ>(test_case, "R", &qq.modulus());
Expand Down
10 changes: 1 addition & 9 deletions src/arithmetic/bigint/boxed_limbs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::{
error,
limb::{self, Limb, LimbMask, LIMB_BYTES},
};
use alloc::{borrow::ToOwned, boxed::Box, vec};
use alloc::{boxed::Box, vec};
use core::{
marker::PhantomData,
ops::{Deref, DerefMut},
Expand Down Expand Up @@ -82,14 +82,6 @@ impl<M> BoxedLimbs<M> {
Ok(r)
}

pub(super) fn minimal_width_from_unpadded(limbs: &[Limb]) -> Self {
debug_assert_ne!(limbs.last(), Some(&0));
Self {
limbs: limbs.to_owned().into_boxed_slice(),
m: PhantomData,
}
}

pub(super) fn from_be_bytes_padded_less_than(
input: untrusted::Input,
m: &Modulus<M>,
Expand Down
15 changes: 1 addition & 14 deletions src/arithmetic/bigint/modulus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use super::{
montgomery::{Unencoded, RR},
n0::N0,
},
BoxedLimbs, Elem, Nonnegative, One, PublicModulus, SlightlySmallerModulus, SmallerModulus,
BoxedLimbs, Elem, Nonnegative, One, PublicModulus, SmallerModulus,
};
use crate::{
bits::BitLength,
Expand Down Expand Up @@ -124,19 +124,6 @@ impl<M> OwnedModulusWithOne<M> {
Self::from_boxed_limbs(limbs, cpu_features)
}

pub(crate) fn from_elem<L>(
elem: Elem<L, Unencoded>,
cpu_features: cpu::Features,
) -> Result<Self, error::KeyRejected>
where
M: SlightlySmallerModulus<L>,
{
Self::from_boxed_limbs(
BoxedLimbs::minimal_width_from_unpadded(&elem.limbs),
cpu_features,
)
}

fn from_boxed_limbs(
n: BoxedLimbs<M>,
cpu_features: cpu::Features,
Expand Down
15 changes: 15 additions & 0 deletions src/arithmetic/bigint_elem_reduced_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,21 @@ R = fffffffdfffffd01000009000002f6fffdf403000312000402f3fff5f602fe080a0005fdfaff
A = fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe00000000000001fffffffe00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000fffffffffffffe00000002000000fffffffe00000200fffffffffffdfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe00000000000003fffffffdfffffe00000002000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000fffffffffffffe000000000000010000000000000000
M = ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff00000000000000ffffffff00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000ffffffffffffff

# m = 2**1023 + 1 (the smallest 1024 bit odd value), a = (m * 2**1024) - 1 (m*R - 1)
R = 8000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
A = 8000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
M = 8000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001

# m = 2**1023 + 1 (the smallest 1024 bit odd value), a = 2**(2*1024) - 1 (the largest 2048-bit value).
R = 03
A = ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
M = 8000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001

# m = 2**1024 - 1, a = 2**(2*1024) - 1
R = 00
A = ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
M = ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff


# These vectors were adapted from BoringSSL's "Quotient" BIGNUM tests in the
# following ways:
Expand Down
50 changes: 8 additions & 42 deletions src/rsa/keypair.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,6 @@ pub struct KeyPair {
q: PrivatePrime<Q>,
qInv: bigint::Elem<P, R>,

// XXX: qq's `oneRR` isn't used and thus this is about twice as large as
// it needs to be, according to how it is used. Further, it appears to be
// completely unnecessary since `elem_reduced` seems to be able to reduce
// an `Elem<N>` directly to an `Elem<Q>`. TODO: Verify that is true and
// eliminate this.
qq: bigint::OwnedModulusWithOne<QQ>,

// TODO: Eliminate `q_mod_n` entirely since it is a bad space:time trade-off.
// Also, this is the only non-temporary `Elem` so if we eliminate this, we
// can make all `Elem`s temporary (borrowed) values.
Expand Down Expand Up @@ -336,8 +329,6 @@ impl KeyPair {
let p = PrivatePrime::new(p, dP, n_bits, cpu_features)?;
let q = PrivatePrime::new(q, dQ, n_bits, cpu_features)?;

let q_mod_n_decoded = q.modulus.to_elem(n);

// TODO: Step 5.i
//
// 3.b is unneeded since `n_bits` is derived here from `n`.
Expand All @@ -349,7 +340,8 @@ impl KeyPair {
// 0 < q < p < n. We check that q and p are close to sqrt(n) and then
// assume that these preconditions are enough to let us assume that
// checking p * q == 0 (mod n) is equivalent to checking p * q == n.
let q_mod_n = bigint::elem_mul(n_one.as_ref(), q_mod_n_decoded.clone(), n);
let q_mod_n = q.modulus.to_elem(n);
let q_mod_n = bigint::elem_mul(n_one.as_ref(), q_mod_n, n);
let p_mod_n = p.modulus.to_elem(n);
let pq_mod_n = bigint::elem_mul(&q_mod_n, p_mod_n, n);
if !pq_mod_n.is_zero() {
Expand Down Expand Up @@ -405,19 +397,13 @@ impl KeyPair {
bigint::verify_inverses_consttime(&qInv, q_mod_p, pm)
.map_err(|error::Unspecified| KeyRejected::inconsistent_components())?;

let qq = bigint::OwnedModulusWithOne::from_elem(
bigint::elem_mul(&q_mod_n, q_mod_n_decoded, n),
cpu_features,
)?;

// This should never fail since `n` and `e` were validated above.

Ok(Self {
p,
q,
qInv,
q_mod_n,
qq,
public: public_key,
})
}
Expand Down Expand Up @@ -501,16 +487,16 @@ impl<M: Prime> PrivatePrime<M> {
}
}

fn elem_exp_consttime<M, MM>(
c: &bigint::Elem<MM>,
fn elem_exp_consttime<M>(
c: &bigint::Elem<N>,
p: &PrivatePrime<M>,
other_prime_len_bits: BitLength,
) -> Result<bigint::Elem<M>, error::Unspecified>
where
M: bigint::NotMuchSmallerModulus<MM>,
M: Prime,
{
let m = &p.modulus.modulus();
let c_mod_m = bigint::elem_reduced(c, m);
let c_mod_m = bigint::elem_reduced(c, m, other_prime_len_bits);
// We could precompute `oneRRR = elem_squared(&p.oneRR`) as mentioned
// in the Smooth CRT-RSA paper.
let c_mod_m = bigint::elem_mul(p.modulus.oneRR().as_ref(), c_mod_m, m);
Expand All @@ -525,32 +511,13 @@ where
enum P {}
unsafe impl Prime for P {}
unsafe impl bigint::SmallerModulus<N> for P {}
unsafe impl bigint::NotMuchSmallerModulus<N> for P {}

#[derive(Copy, Clone)]
enum QQ {}
unsafe impl bigint::SmallerModulus<N> for QQ {}
unsafe impl bigint::NotMuchSmallerModulus<N> for QQ {}

// `q < p < 2*q` since `q` is slightly smaller than `p` (see below). Thus:
//
// q < p < 2*q
// q*q < p*q < 2*q*q.
// q**2 < n < 2*(q**2).
unsafe impl bigint::SlightlySmallerModulus<N> for QQ {}

#[derive(Copy, Clone)]
enum Q {}
unsafe impl Prime for Q {}
unsafe impl bigint::SmallerModulus<N> for Q {}
unsafe impl bigint::SmallerModulus<P> for Q {}

// q < p && `p.bit_length() == q.bit_length()` implies `q < p < 2*q`.
unsafe impl bigint::SlightlySmallerModulus<P> for Q {}

unsafe impl bigint::SmallerModulus<QQ> for Q {}
unsafe impl bigint::NotMuchSmallerModulus<QQ> for Q {}

impl KeyPair {
/// Computes the signature of `msg` and writes it into `signature`.
///
Expand Down Expand Up @@ -620,9 +587,8 @@ impl KeyPair {
let c = base;

// Step 2.b.i.
let m_1 = elem_exp_consttime(&c, &self.p)?;
let c_mod_qq = bigint::elem_reduced_once(&c, &self.qq.modulus());
let m_2 = elem_exp_consttime(&c_mod_qq, &self.q)?;
let m_1 = elem_exp_consttime(&c, &self.p, self.q.modulus.len_bits())?;
let m_2 = elem_exp_consttime(&c, &self.q, self.p.modulus.len_bits())?;

// Step 2.b.ii isn't needed since there are only two primes.

Expand Down

0 comments on commit 946ce87

Please sign in to comment.