Skip to content

Commit

Permalink
feat: rsa 1-2 ot
Browse files Browse the repository at this point in the history
  • Loading branch information
lonerapier committed Jul 19, 2024
1 parent 18bc1ad commit cf34f16
Show file tree
Hide file tree
Showing 6 changed files with 218 additions and 55 deletions.
86 changes: 59 additions & 27 deletions src/encryption/asymmetric/rsa/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,56 +17,88 @@ const fn mod_inverse(e: u64, totient: u64) -> u64 {
}
d
}
/// RSAKey struct
pub struct RSA {
/// pub key (e,n)
pub private_key: PrivateKey,
/// priv key (d,n)
pub public_key: PublicKey,
}

/// private key
pub struct PrivateKey {
/// gcd(e, totient) = 1
e: usize,
/// d x e mod totient = 1
d: usize,
/// modulus
n: usize,
pub n: usize,
}

/// public key
pub struct PublicKey {
/// d x e mod totient = 1
d: usize,
/// gcd(e, totient) = 1
e: usize,
/// modulus
n: usize,
pub n: usize,
}

impl RSA {
/// Encrypts a message using the RSA algorithm
#[allow(dead_code)]
/// RSA encryption
pub struct RSAEncryption {
/// RSA public key
pub public_key: PublicKey,
}

/// RSA decryption
pub struct RSADecryption {
/// RSA private key
pub private_key: PrivateKey,
}

impl RSAEncryption {
/// Encrypts a message using the RSA algorithm
/// C = P^e mod n
const fn encrypt(&self, message: u32) -> u32 {
message.pow(self.private_key.e as u32) % self.private_key.n as u32
pub const fn encrypt(&self, plaintext: u32) -> u32 {
let mut plaintext = plaintext;
let mut res = 1;
let mut exp = self.public_key.e as u32;

while exp > 0 {
if exp % 2 == 1 {
res = ((res as u64 * plaintext as u64) % self.public_key.n as u64) as u32;
}
plaintext = ((plaintext as u64).pow(2) % self.public_key.n as u64) as u32;
exp >>= 1;
}

res
}
}

#[allow(dead_code)]
impl RSADecryption {
/// Decrypts a cipher using the RSA algorithm
/// P = C^d mod n
const fn decrypt(&self, cipher: u32) -> u32 {
cipher.pow(self.public_key.d as u32) % self.public_key.n as u32
pub const fn decrypt(&self, ciphertext: u32) -> u32 {
let mut res = 1;
let mut ciphertext = ciphertext;
let mut exp = self.private_key.d as u32;

while exp > 0 {
if exp % 2 == 1 {
res = ((res as u64 * ciphertext as u64) % self.private_key.n as u64) as u32;
}
ciphertext = ((ciphertext as u64).pow(2) % self.private_key.n as u64) as u32;
exp >>= 1;
}

res
// ((ciphertext as u64).pow(self.private_key.d as u32) % self.private_key.n as u64) as u32
}
}

/// Key generation for the RSA algorithm
/// TODO: Implement a secure key generation algorithm using miller rabin primality test
pub fn rsa_key_gen(p: usize, q: usize) -> RSA {
pub fn rsa_key_gen(p: usize, q: usize) -> (RSAEncryption, RSADecryption) {
assert!(is_prime(p));
assert!(is_prime(q));
let n = p * q;
let e = generate_e(p, q);
let totient = euler_totient(p as u64, q as u64);
let d = mod_inverse(e, totient);
RSA { private_key: PrivateKey { e: e as usize, n }, public_key: PublicKey { d: d as usize, n } }
(RSAEncryption { public_key: PublicKey { e: e as usize, n } }, RSADecryption {
private_key: PrivateKey { d: d as usize, n },
})
}

/// Generates e value for the RSA algorithm
Expand All @@ -86,10 +118,10 @@ const fn generate_e(p: usize, q: usize) -> u64 {
panic!("Failed to find coprime e; totient should be greater than 1")
}

/// Generates a random prime number bigger than 1_000_000
pub fn random_prime(first_prime: usize) -> usize {
let mut n = 1_000_000;
while !is_prime(n) && n != first_prime {
/// Generates a random prime number bigger than `begin`
pub fn random_prime(begin: usize) -> usize {
let mut n = begin;
while !is_prime(n) {
n += 1;
}
n
Expand Down
64 changes: 36 additions & 28 deletions src/encryption/asymmetric/rsa/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,26 @@ fn test_euler_totient() {

#[test]
fn key_gen() {
let key = rsa_key_gen(PRIME_1, PRIME_2);
assert_eq!(key.public_key.n, PRIME_1 * PRIME_2);
assert_eq!(gcd(key.private_key.e as u64, euler_totient(PRIME_1 as u64, PRIME_2 as u64)), 1);

let key = rsa_key_gen(PRIME_2, PRIME_3);
assert_eq!(key.public_key.n, PRIME_2 * PRIME_3);
assert_eq!(gcd(key.private_key.e as u64, euler_totient(PRIME_2 as u64, PRIME_3 as u64)), 1);

let key = rsa_key_gen(PRIME_3, PRIME_1);
assert_eq!(key.public_key.n, PRIME_3 * PRIME_1);
assert_eq!(gcd(key.private_key.e as u64, euler_totient(PRIME_3 as u64, PRIME_1 as u64)), 1);
let (rsa_encrypt, rsa_decrypt) = rsa_key_gen(PRIME_1, PRIME_2);
assert_eq!(rsa_encrypt.public_key.n, PRIME_1 * PRIME_2);
assert_eq!(
gcd(rsa_decrypt.private_key.d as u64, euler_totient(PRIME_1 as u64, PRIME_2 as u64)),
1
);

let (rsa_encrypt, rsa_decrypt) = rsa_key_gen(PRIME_2, PRIME_3);
assert_eq!(rsa_encrypt.public_key.n, PRIME_2 * PRIME_3);
assert_eq!(
gcd(rsa_decrypt.private_key.d as u64, euler_totient(PRIME_2 as u64, PRIME_3 as u64)),
1
);

let (rsa_encrypt, rsa_decrypt) = rsa_key_gen(PRIME_3, PRIME_1);
assert_eq!(rsa_encrypt.public_key.n, PRIME_3 * PRIME_1);
assert_eq!(
gcd(rsa_decrypt.private_key.d as u64, euler_totient(PRIME_3 as u64, PRIME_1 as u64)),
1
);
}

#[test]
Expand Down Expand Up @@ -58,33 +67,32 @@ fn test_mod_inverse() {
#[test]
fn test_encrypt_decrypt() {
let message = 10;
let key = rsa_key_gen(PRIME_1, PRIME_2);
let cipher = key.encrypt(message);
let decrypted = key.decrypt(cipher);
let (rsa_encrypt, rsa_decrypt) = rsa_key_gen(PRIME_1, PRIME_2);
let cipher = rsa_encrypt.encrypt(message);
let decrypted = rsa_decrypt.decrypt(cipher);
assert_eq!(decrypted, message);

let key = rsa_key_gen(PRIME_2, PRIME_3);
let cipher = key.encrypt(message);
let decrypted = key.decrypt(cipher);
let (rsa_encrypt, rsa_decrypt) = rsa_key_gen(PRIME_2, PRIME_3);
let cipher = rsa_encrypt.encrypt(message);
let decrypted = rsa_decrypt.decrypt(cipher);
assert_eq!(decrypted, message);

let message = 10;
let key = rsa_key_gen(PRIME_3, PRIME_1);
let cipher = key.encrypt(message);
let decrypted = key.decrypt(cipher);
let (rsa_encrypt, rsa_decrypt) = rsa_key_gen(PRIME_3, PRIME_1);
let cipher = rsa_encrypt.encrypt(message);
let decrypted = rsa_decrypt.decrypt(cipher);
assert_eq!(decrypted, message);
}

#[test]
fn test_random_prime() {
let prime = random_prime(2);
assert!(is_prime(prime));
assert!(prime >= 1_000_000);
let message = u16::MAX as u32;
let (rsa_encrypt, rsa_decrypt) = rsa_key_gen(10007, 49999);
let cipher = rsa_encrypt.encrypt(message);
let decrypted = rsa_decrypt.decrypt(cipher);
assert_eq!(decrypted, message);
}

#[test]
fn test_random_prime_generation() {
let prime = random_prime(2);
fn test_random_prime() {
let prime = random_prime(1_000_000);
assert!(is_prime(prime));
assert!(prime >= 1_000_000);
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub mod encryption;
pub mod field;
pub mod hashes;
pub mod kzg;
pub mod ot;
pub mod polynomial;
pub mod tree;

Expand Down
Empty file added src/ot/README.md
Empty file.
2 changes: 2 additions & 0 deletions src/ot/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
//! Contains implementation of oblivious transfer and various extensions.
pub mod ot_rsa;
120 changes: 120 additions & 0 deletions src/ot/ot_rsa.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
//! Contains implementation of 1-out-of-2 OT using RSA encryption.
use rand::{thread_rng, Rng};

use crate::encryption::asymmetric::rsa::{rsa_key_gen, RSADecryption, RSAEncryption};

/// Sender that has two messages and wants to send one of it to [`OTReceiver`] without knowledge of
/// which one.
pub struct OTSender {
messages: [usize; 2],
random_messages: [usize; 2],
rsa_decrypt: RSADecryption,
}

/// Receiver wants to get access to one of the message that [`OTSender`] has without knowledge of
/// the other.
pub struct OTReceiver {
choice: bool,
key: usize,
}

impl OTSender {
/// create a new [`OTSender`] object.
/// ## Arguments
/// - `messages`: message that sender has access to
/// - `primes`: [`RSAEncryption`] primes
pub fn new(messages: [usize; 2], primes: [usize; 2]) -> (Self, RSAEncryption, [usize; 2]) {
let (rsa_encrypt, rsa_decrypt) = rsa_key_gen(primes[0], primes[1]);

let random_messages: [usize; 2] = rand::random();
(OTSender { messages, rsa_decrypt, random_messages }, rsa_encrypt, random_messages)
}

/// Encrypt messages with receiver's choice
pub fn encrypt(&self, v: usize) -> [usize; 2] {
let k0 = if v < self.random_messages[0] {
v + self.rsa_decrypt.private_key.n
- (self.random_messages[0] % self.rsa_decrypt.private_key.n)
} else {
v - self.random_messages[0]
};
let k1 = if v < self.random_messages[1] {
v + self.rsa_decrypt.private_key.n
- (self.random_messages[1] % self.rsa_decrypt.private_key.n)
} else {
v - self.random_messages[1]
};

let k0 = self.rsa_decrypt.decrypt((k0) as u32);
let k1 = self.rsa_decrypt.decrypt((k1) as u32);

println!("k0: {}, k1: {}", k0, k1);

let m0 = (self.messages[0] + k0 as usize) % self.rsa_decrypt.private_key.n;
let m1 = (self.messages[1] + k1 as usize) % self.rsa_decrypt.private_key.n;

[m0, m1]
}
}

impl OTReceiver {
/// create new [`OTReceiver`] object
/// ## Arguments
/// - `choice`: receiver message choice
pub fn new(choice: bool) -> Self {
let mut rng = thread_rng();
Self { choice, key: rng.gen::<u32>() as usize }
}

/// Encrypts receiver's choice out of sender's messages.
///
/// v = (x_b + k^e) mod N
pub fn encrypt(&self, rsa_encrypt: RSAEncryption, sender_messages: [usize; 2]) -> usize {
println!("key: {}", self.key % rsa_encrypt.public_key.n);
(rsa_encrypt.encrypt(self.key as u32) as usize + sender_messages[self.choice as usize])
% rsa_encrypt.public_key.n
}

/// Decrypts sender's encrypted message
///
/// m_b = (m'_b - k) mod N
/// ## Arguments:
/// - `messages`: sender's encrypted messages: m'_0, m'_1
/// - `modulus`: RSA modulus
pub fn decrypt(&self, messages: [usize; 2], modulus: usize) -> usize {
if messages[self.choice as usize] < self.key {
(messages[self.choice as usize] + modulus - (self.key % modulus)) % modulus
} else {
(messages[self.choice as usize] - self.key) % modulus
}
}
}

#[cfg(test)]
mod tests {

use super::*;

#[test]
fn ot_rsa() {
let mut rng = thread_rng();
let messages = [10, 2];
let random_primes = [19, 13];

let (ot_sender, rsa_encrypt, random_messages) = OTSender::new(messages, random_primes);

let modulus = rsa_encrypt.public_key.n;

let bit = rng.gen::<bool>();
let ot_receiver = OTReceiver::new(bit);

let v = ot_receiver.encrypt(rsa_encrypt, random_messages);

let encrypted_messages = ot_sender.encrypt(v);

let message = ot_receiver.decrypt(encrypted_messages, modulus);

assert_eq!(message, messages[bit as usize]);
}
}

0 comments on commit cf34f16

Please sign in to comment.