diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 61b91eaed..2199c48fd 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -691,6 +691,30 @@ unsafe fn general_mat_vec_mul_impl( } } +/// Kronecker product of 2D matrices. +/// +/// The kronecker product of a LxN matrix A and a MxR matrix B is a (L*M)x(N*R) +/// matrix K formed by the block multiplication A_ij * B. +pub fn kron(a: &Array2, b: &Array2) -> Array2 +where + T: LinalgScalar, +{ + let dimar = a.shape()[0]; + let dimac = a.shape()[1]; + let dimbr = b.shape()[0]; + let dimbc = b.shape()[1]; + let mut out = Array2::zeros((dimar * dimbr, dimac * dimbc)); + for (mut chunk, elem) in out + .exact_chunks_mut((dimbr, dimbc)) + .into_iter() + .zip(a.iter()) + { + let v: Array2 = Array2::from_elem((dimbr, dimbc), *(elem)) * b; + chunk.assign(&v); + } + out +} + #[inline(always)] /// Return `true` if `A` and `B` are the same type fn same_type() -> bool { diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs index 8575905cd..dc6964f9b 100644 --- a/src/linalg/mod.rs +++ b/src/linalg/mod.rs @@ -10,6 +10,7 @@ pub use self::impl_linalg::general_mat_mul; pub use self::impl_linalg::general_mat_vec_mul; +pub use self::impl_linalg::kron; pub use self::impl_linalg::Dot; mod impl_linalg; diff --git a/tests/oper.rs b/tests/oper.rs index ed612bad2..051728680 100644 --- a/tests/oper.rs +++ b/tests/oper.rs @@ -6,6 +6,7 @@ )] #![cfg(feature = "std")] use ndarray::linalg::general_mat_mul; +use ndarray::linalg::kron; use ndarray::prelude::*; use ndarray::{rcarr1, rcarr2}; use ndarray::{Data, LinalgScalar}; @@ -820,3 +821,65 @@ fn vec_mat_mul() { } } } + +#[test] +fn kron_square_f64() { + let a = arr2(&[[1.0, 0.0], [0.0, 1.0]]); + let b = arr2(&[[0.0, 1.0], [1.0, 0.0]]); + + assert_eq!( + kron(&a, &b), + arr2(&[ + [0.0, 1.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 0.0] + ]), + ); + + assert_eq!( + kron(&b, &a), + arr2(&[ + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0] + ]), + ) +} + +#[test] +fn kron_square_i64() { + let a = arr2(&[[1, 0], [0, 1]]); + let b = arr2(&[[0, 1], [1, 0]]); + + assert_eq!( + kron(&a, &b), + arr2(&[[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]), + ); + + assert_eq!( + kron(&b, &a), + arr2(&[[0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0]]), + ) +} + +#[test] +fn kron_i64() { + let a = arr2(&[[1, 0]]); + let b = arr2(&[[0, 1], [1, 0]]); + let r = arr2(&[[0, 1, 0, 0], [1, 0, 0, 0]]); + assert_eq!(kron(&a, &b), r); + + let a = arr2(&[[1, 0], [0, 0], [0, 1]]); + let b = arr2(&[[0, 1], [1, 0]]); + let r = arr2(&[ + [0, 1, 0, 0], + [1, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 1], + [0, 0, 1, 0], + ]); + assert_eq!(kron(&a, &b), r); +}