diff --git a/Cargo.toml b/Cargo.toml index 16c9ccb..7f345be 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ members = [ "harmonium-fft", "harmonium-resample", "harmonium-window", - #"harmonium-stft", + "harmonium-stft", ] [workspace.dependencies] @@ -24,7 +24,7 @@ harmonium-io = { path = "harmonium-io", default-features = false } harmonium-fft = { path = "harmonium-fft", default-features = false } harmonium-resample = { path = "harmonium-resample", default-features = false } harmonium-window = { path = "harmonium-window", default-features = false } -#harmonium-stft = { path = "harmonium-stft", default-features = false } +harmonium-stft = { path = "harmonium-stft", default-features = false } [profile.release] opt-level = 3 diff --git a/harmonium-core/src/array.rs b/harmonium-core/src/array.rs index 007f900..c260105 100644 --- a/harmonium-core/src/array.rs +++ b/harmonium-core/src/array.rs @@ -59,6 +59,11 @@ where pub fn as_slice_mut(&mut self) -> Option<&mut [T]> { self.0.as_slice_mut() } + + /// Returns `true` if the `HArray` shares the inner arc with another one. + pub fn is_shared(&self) -> bool { + todo!() + } } #[cfg(test)] diff --git a/harmonium-fft/src/fft.rs b/harmonium-fft/src/fft.rs index d6302a8..380605b 100644 --- a/harmonium-fft/src/fft.rs +++ b/harmonium-fft/src/fft.rs @@ -1,13 +1,12 @@ -use std::sync::Arc; - use harmonium_core::{array::HArray, errors::HError, errors::HResult}; use ndarray::{ArcArray1, ArcArray2, Axis, Dimension, Ix1, Ix2, IxDyn, Zip}; use realfft::{ComplexToReal, RealFftPlanner, RealToComplex}; use rustfft::{ num_complex::Complex, - num_traits::{Float, FloatConst}, + num_traits::{ConstZero, Float, FloatConst}, FftNum, FftPlanner, }; +use std::sync::Arc; #[derive(Clone)] pub struct Fft { @@ -17,23 +16,22 @@ pub struct Fft { #[derive(Clone)] pub struct RealFftForward { - fft: Arc>, - scratch_buffer: Arc<[Complex]>, + pub fft: Arc>, + pub scratch_buffer: Arc<[Complex]>, } #[derive(Clone)] pub struct RealFftInverse { - fft: Arc>, - scratch_buffer: Arc<[Complex]>, + pub fft: Arc>, + pub scratch_buffer: Arc<[Complex]>, } -impl Fft { +impl Fft { pub fn new_fft_forward(length: usize) -> Self { let mut planner = FftPlanner::new(); let fft = planner.plan_fft_forward(length); let scratch_len = fft.get_inplace_scratch_len(); - let zero = T::zero(); - let scratch_buffer = vec![Complex::new(zero, zero); scratch_len]; + let scratch_buffer = vec![Complex::::ZERO; scratch_len]; let scratch_buffer: Arc<[Complex]> = Arc::from(scratch_buffer); Self { @@ -46,8 +44,7 @@ impl Fft { let mut planner = FftPlanner::new(); let fft = planner.plan_fft_inverse(length); let scratch_len = fft.get_inplace_scratch_len(); - let zero = T::zero(); - let scratch_buffer = vec![Complex::new(zero, zero); scratch_len]; + let scratch_buffer = vec![Complex::::ZERO; scratch_len]; let scratch_buffer: Arc<[Complex]> = Arc::from(scratch_buffer); Self { @@ -57,13 +54,12 @@ impl Fft { } } -impl RealFftForward { +impl RealFftForward { pub fn new_real_fft_forward(length: usize) -> Self { let mut planner = RealFftPlanner::new(); let fft = planner.plan_fft_forward(length); - let zero = T::zero(); let scratch_len = fft.get_scratch_len(); - let scratch_buffer = vec![Complex::new(zero, zero); scratch_len]; + let scratch_buffer = vec![Complex::::ZERO; scratch_len]; let scratch_buffer: Arc<[Complex]> = Arc::from(scratch_buffer); Self { @@ -73,13 +69,12 @@ impl RealFftForward { } } -impl RealFftInverse { +impl RealFftInverse { pub fn new_real_fft_inverse(length: usize) -> Self { let mut planner = RealFftPlanner::new(); let fft = planner.plan_fft_inverse(length); - let zero = T::zero(); let scratch_len = fft.get_scratch_len(); - let scratch_buffer = vec![Complex::new(zero, zero); scratch_len]; + let scratch_buffer = vec![Complex::::ZERO; scratch_len]; let scratch_buffer: Arc<[Complex]> = Arc::from(scratch_buffer); Self { @@ -127,12 +122,11 @@ where impl ProcessRealFftForward for RealFftForward where - T: FftNum + Float + FloatConst, + T: FftNum + Float + FloatConst + ConstZero, { fn process(&mut self, harray: &mut HArray) -> HResult, Ix1>> { - let zero = T::zero(); let length = harray.len(); - let mut ndarray = ArcArray1::from_elem(length / 2 + 1, Complex::new(zero, zero)); + let mut ndarray = ArcArray1::from_elem(length / 2 + 1, Complex::::ZERO); let scratch_buffer = make_mut_slice(&mut self.scratch_buffer); self.fft .process_with_scratch( @@ -147,13 +141,12 @@ where impl ProcessRealFftInverse for RealFftInverse where - T: FftNum + Float + FloatConst, + T: FftNum + Float + FloatConst + ConstZero, { fn process(&mut self, harray: &mut HArray, Ix1>) -> HResult> { - let zero = T::zero(); let length = self.fft.len(); let scratch_buffer = make_mut_slice(&mut self.scratch_buffer); - let mut ndarray = ArcArray1::from_elem(length, zero); + let mut ndarray = ArcArray1::from_elem(length, T::ZERO); self.fft .process_with_scratch( harray.as_slice_mut().unwrap(), @@ -172,24 +165,23 @@ where fn process(&mut self, harray: &mut HArray, Ix2>) -> HResult<()> { let scratch_buffer = make_mut_slice(&mut self.scratch_buffer); - Zip::from(harray.0.lanes_mut(Axis(1))).for_each(|mut row| { + for mut row in harray.0.lanes_mut(Axis(1)) { self.fft .process_with_scratch(row.as_slice_mut().unwrap(), scratch_buffer); - }); + } Ok(()) } } impl ProcessRealFftForward for RealFftForward where - T: FftNum + Float + FloatConst, + T: FftNum + Float + FloatConst + ConstZero, { fn process(&mut self, harray: &mut HArray) -> HResult, Ix2>> { - let zero = T::zero(); let nrows = harray.0.nrows(); let ncols = harray.0.ncols(); let scratch_buffer = make_mut_slice(&mut self.scratch_buffer); - let mut ndarray = ArcArray2::from_elem((nrows, ncols / 2 + 1), Complex::new(zero, zero)); + let mut ndarray = ArcArray2::from_elem((nrows, ncols / 2 + 1), Complex::::ZERO); Zip::from(ndarray.lanes_mut(Axis(1))) .and(harray.0.lanes_mut(Axis(1))) @@ -209,14 +201,13 @@ where impl ProcessRealFftInverse for RealFftInverse where - T: FftNum + Float + FloatConst, + T: FftNum + Float + FloatConst + ConstZero, { fn process(&mut self, harray: &mut HArray, Ix2>) -> HResult> { - let zero = T::zero(); let length = self.fft.len(); let nrows = harray.0.nrows(); let scratch_buffer = make_mut_slice(&mut self.scratch_buffer); - let mut ndarray = ArcArray2::from_elem((nrows, length), zero); + let mut ndarray = ArcArray2::from_elem((nrows, length), T::ZERO); Zip::from(ndarray.lanes_mut(Axis(1))) .and(harray.0.lanes_mut(Axis(1))) @@ -262,15 +253,14 @@ where impl ProcessRealFftForward for RealFftForward where - T: FftNum + Float + FloatConst, + T: FftNum + Float + FloatConst + ConstZero, { fn process(&mut self, harray: &mut HArray) -> HResult, IxDyn>> { - let zero = T::zero(); let scratch_buffer = make_mut_slice(&mut self.scratch_buffer); match harray.ndim() { 1 => { let length = harray.len(); - let mut ndarray = ArcArray1::from_elem(length / 2 + 1, Complex::new(zero, zero)); + let mut ndarray = ArcArray1::from_elem(length / 2 + 1, Complex::::ZERO); self.fft .process_with_scratch( harray.as_slice_mut().unwrap(), @@ -284,8 +274,7 @@ where let nrows = harray.0.len_of(Axis(0)); let ncols = harray.0.len_of(Axis(1)); let mut ndarray = - ArcArray2::from_elem((nrows, ncols / 2 + 1), Complex::new(zero, zero)) - .into_dyn(); + ArcArray2::from_elem((nrows, ncols / 2 + 1), Complex::::ZERO).into_dyn(); Zip::from(ndarray.lanes_mut(Axis(1))) .and(harray.0.lanes_mut(Axis(1))) @@ -310,15 +299,14 @@ where impl ProcessRealFftInverse for RealFftInverse where - T: FftNum + Float + FloatConst, + T: FftNum + Float + FloatConst + ConstZero, { fn process(&mut self, harray: &mut HArray, IxDyn>) -> HResult> { - let zero = T::zero(); let length = self.fft.len(); let scratch_buffer = make_mut_slice(&mut self.scratch_buffer); match harray.ndim() { 1 => { - let mut ndarray = ArcArray1::from_elem(length, zero); + let mut ndarray = ArcArray1::from_elem(length, T::ZERO); self.fft .process_with_scratch( harray.as_slice_mut().unwrap(), @@ -330,7 +318,7 @@ where } 2 => { let nrows = harray.0.len_of(Axis(0)); - let mut ndarray = ArcArray2::from_elem((nrows, length), zero).into_dyn(); + let mut ndarray = ArcArray2::from_elem((nrows, length), T::ZERO).into_dyn(); Zip::from(ndarray.lanes_mut(Axis(1))) .and(harray.0.lanes_mut(Axis(1))) @@ -354,11 +342,11 @@ where } // replace this function by make_mut when in stable (it is currently, but doesn't work for slices.) -fn make_mut_slice(arc: &mut Arc<[T]>) -> &mut [T] { +pub fn make_mut_slice(arc: &mut Arc<[T]>) -> &mut [T] { if Arc::get_mut(arc).is_none() { *arc = Arc::from(&arc[..]); } - // Replace by get_mut_unchecked when available in stable. This can't failed since get_mut was + // Replace by get_mut_unchecked when available in stable. This can't fail since get_mut was // checked above. unsafe { Arc::get_mut(arc).unwrap_unchecked() } } diff --git a/harmonium-stft/Cargo.toml b/harmonium-stft/Cargo.toml index 9f1e3bd..f6f5d67 100644 --- a/harmonium-stft/Cargo.toml +++ b/harmonium-stft/Cargo.toml @@ -5,5 +5,8 @@ edition = "2021" [dependencies] harmonium-core = { workspace = true } +harmonium-fft = { workspace = true } rustfft = { workspace = true } +realfft = { workspace = true } ndarray = { workspace = true } + diff --git a/harmonium-stft/src/lib.rs b/harmonium-stft/src/lib.rs index 0a6ba6c..fc8194a 100644 --- a/harmonium-stft/src/lib.rs +++ b/harmonium-stft/src/lib.rs @@ -1 +1 @@ -mod stft; +pub mod stft; diff --git a/harmonium-stft/src/stft.rs b/harmonium-stft/src/stft.rs index 1b65093..863ef51 100644 --- a/harmonium-stft/src/stft.rs +++ b/harmonium-stft/src/stft.rs @@ -1,71 +1,603 @@ -use harmonium_core::array::HArray; -use ndarray::{ArcArray2, Dimension, Ix1, Ix2, IxDyn}; +use harmonium_core::{ + array::HArray, + errors::{HError, HResult}, +}; +use harmonium_fft::fft::{make_mut_slice, Fft, ProcessFft}; +use ndarray::{s, ArcArray, ArcArray2, Axis, Dimension, Ix1, Ix2, Ix3, IxDyn}; use rustfft::{ num_complex::{Complex, ComplexFloat}, - num_traits::Float, - Fft, FftNum, FftPlanner, + num_traits::{ConstZero, Float, FloatConst}, + FftNum, }; -use std::sync::Arc; +use std::num::NonZero; + +pub struct Stft(Fft); + +//pub struct RealStftForward { +// inner: RealFftForward, +// scratch_real_buffer: Arc<[T]>, +//} -struct StftPlanner +#[allow(clippy::len_without_is_empty)] +/// An `Stft` is used to process stft. It caches results internally, so when making more than one stft it is advisable to reuse the same `Stft` instance. +impl Stft where - T: FftNum, + T: FftNum + Float + FloatConst + ConstZero, { - fft: Arc>, - buffer: Vec, + pub fn new_stft_forward(length: usize) -> Self { + Stft(Fft::new_fft_forward(length)) + } + + pub fn len(&self) -> usize { + self.0.fft.len() + } } -impl StftPlanner +/// An `RealStftForward` is used to process real stft. It caches results internally, so when making more than one stft it is advisable to reuse the same `RealdStftForward` instance. +//impl RealStftForward +//where +// T: FftNum + Float + FloatConst + ConstZero, +//{ +// pub fn new_real_stft_forward(length: usize) -> Self { +// let scratch_real_buffer = vec![T::ZERO; length]; +// let scratch_real_buffer: Arc<[T]> = Arc::from(scratch_real_buffer); +// +// RealStftForward { +// inner: RealFftForward::new_real_fft_forward(length), +// scratch_real_buffer, +// } +// } +// +// pub fn len(&self) -> usize { +// self.inner.fft.len() +// } +//} + +pub trait ProcessStft where - T: FftNum + ComplexFloat, + T: FftNum + Float + FloatConst, + D: Dimension, { - fn new(fft_length: usize) -> Self { - let mut fft_planner = FftPlanner::new(); - let fft = fft_planner.plan_fft_forward(fft_length); - let buffer = vec![T::zero(); fft_length]; - StftPlanner { fft, buffer } - } + /// Computes the Fourier transform of short overlapping windows of the input. + /// The function does not normalize outputs. + /// + /// # Arguments + /// `harray` - A complex 1D or 2D HArray to be used as input. + /// `hop_length` - The distance between neighboring sliding window frames. + /// `window_length` - Size of window frame. Must be greater than the fft length. + /// `window` - An optional window function. `window.len()` must be equal to `window_length`. + fn process( + &mut self, + harray: &HArray, D>, + hop_length: NonZero, + window_length: NonZero, + window: Option<&[T]>, + ) -> HResult, D::Larger>>; +} + +pub trait ProcessRealStftForward +where + T: FftNum + Float + FloatConst, + D: Dimension, +{ + /// Computes the Fourier transform of short overlapping windows of the input. + /// For each forward FFT, transforms a real signal of length `N` to a complex-valued spectrum of length `N/2+1` (with `N/2` rounded down). + /// The function does not normalize outputs. + /// + /// # Arguments + /// `harray` - A real-valued 1D or 2D HArray to be used as input. + /// `hop_length` - The distance between neighboring sliding window frames. + /// `window_length` - Size of window frame. Must be greater than the fft length. + /// `window` - An optional window function. `window.len()` must be equal to `window_length`. + fn process( + &mut self, + harray: &HArray, + hop_length: NonZero, + window_length: NonZero, + window: Option<&[T]>, + ) -> HResult, D::Larger>>; +} + +impl ProcessStft for Stft +where + T: FftNum + Float + FloatConst + ConstZero, +{ + fn process( + &mut self, + harray: &HArray, Ix1>, + hop_length: NonZero, + window_length: NonZero, + window: Option<&[T]>, + ) -> HResult, Ix2>> { + let fft_length = self.len(); // Since fft_length is checked to be >= window_length and window_length is NonZero, we can be sure fft_length > 0. + let window_length = window_length.get(); + let hop_length = hop_length.get(); + let length = harray.len(); + + if fft_length < window_length || fft_length > length { + return Err(HError::OutOfSpecError( + "Expected harray.len() >= fft_length >= window_length.".to_string(), + )); + } + if let Some(slice) = window { + if slice.len() != window_length { + return Err(HError::OutOfSpecError( + "Expected window.len() == window_length.".to_string(), + )); + } + } + + let n_fft = 1 + (length - fft_length) / hop_length; + let mut stft_ndarray = ArcArray2::>::zeros((n_fft, fft_length)); + + // Center PAD the window if fft_length > window_length. + let left = (fft_length - window_length) / 2; + let right = left + window_length; + let slice_info = s![.., left..right]; + let slice_info_1d = s![left..right]; + + for (mut row, win) in stft_ndarray + .slice_mut(slice_info) + .lanes_mut(Axis(1)) + .into_iter() + .zip(harray.0.windows(fft_length).into_iter().step_by(hop_length)) + { + row.assign(&win.slice(slice_info_1d)); + if let Some(w) = window { + row.as_slice_mut().unwrap().apply_window(w); + } + } + + let mut output = HArray(stft_ndarray); + self.0.process(&mut output)?; - fn len(&self) -> usize { - self.fft.len() + Ok(output) } } -// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling. -pub trait Stft +impl ProcessStft for Stft where - T: FftNum + ComplexFloat, - D: Dimension, + T: FftNum + Float + FloatConst + ConstZero, { fn process( &mut self, - harray: HArray, - hop_length: usize, - window_length: usize, + harray: &HArray, Ix2>, + hop_length: NonZero, + window_length: NonZero, window: Option<&[T]>, - ) -> HArray; + ) -> HResult, Ix3>> { + let fft_length = self.len(); // Since fft_length is checked to be >= window_length and window_length is NonZero, we can be sure fft_length > 0. + let window_length = window_length.get(); + let hop_length = hop_length.get(); + let nrows = harray.0.len_of(Axis(0)); + let ncols = harray.0.len_of(Axis(1)); + + if fft_length < window_length || fft_length > ncols { + return Err(HError::OutOfSpecError( + "Expected ncols >= fft_length >= window_length.".to_string(), + )); + } + if let Some(slice) = window { + if slice.len() != window_length { + return Err(HError::OutOfSpecError( + "Expected window.len() == window_length.".to_string(), + )); + } + } + + let n_fft = 1 + (ncols - fft_length) / hop_length; + let mut stft_ndarray = ArcArray::, Ix3>::zeros((nrows, n_fft, fft_length)); + + // Center PAD the window if fft_length > window_length. + let left = (fft_length - window_length) / 2; + let right = left + window_length; + let slice_info = s![.., left..right]; + let slice_info_1d = s![left..right]; + let scratch_buffer = make_mut_slice(&mut self.0.scratch_buffer); + + for (mut matrix, win) in stft_ndarray.axis_iter_mut(Axis(1)).zip( + harray + .0 + .windows((nrows, fft_length)) + .into_iter() + .step_by(hop_length), + ) { + matrix.slice_mut(slice_info).assign(&win.slice(slice_info)); + + for mut col in matrix.lanes_mut(Axis(1)) { + if let Some(w) = window { + col.slice_mut(slice_info_1d) + .as_slice_mut() + .unwrap() + .apply_window(w); + } + self.0 + .fft + .process_with_scratch(col.as_slice_mut().unwrap(), scratch_buffer); + } + } + + let output = HArray(stft_ndarray); + + Ok(output) + } } -impl Stft for StftPlanner +impl ProcessStft for Stft where - T: FftNum + ComplexFloat, + T: FftNum + Float + FloatConst + ConstZero, { fn process( &mut self, - harray: HArray, - hop_length: usize, - window_length: usize, + harray: &HArray, IxDyn>, + hop_length: NonZero, + window_length: NonZero, window: Option<&[T]>, - ) -> HArray { - assert!(hop_length > 0); + ) -> HResult, IxDyn>> { + let fft_length = self.len(); // Since fft_length is checked to be >= window_length and window_length is NonZero, we can be sure fft_length > 0. + let window_length = window_length.get(); + let hop_length = hop_length.get(); + + // Center PAD the window if fft_length > window_length. + let left = (fft_length - window_length) / 2; + let right = left + window_length; + + match harray.ndim() { + 1 => { + let length = harray.len(); + + if fft_length < window_length || fft_length > length { + return Err(HError::OutOfSpecError( + "Expected harray.len() >= fft_length >= window_length.".to_string(), + )); + } + if let Some(slice) = window { + if slice.len() != window_length { + return Err(HError::OutOfSpecError( + "Expected window.len() == window_length.".to_string(), + )); + } + } + + let n_fft = 1 + (length - fft_length) / hop_length; + let mut stft_ndarray = ArcArray2::>::zeros((n_fft, fft_length)); + + let slice_info = s![.., left..right]; + let slice_info_1d = s![left..right]; + + for (mut row, win) in stft_ndarray + .slice_mut(slice_info) + .lanes_mut(Axis(1)) + .into_iter() + .zip( + harray + .0 + .windows(IxDyn(&[fft_length])) + .into_iter() + .step_by(hop_length), + ) + { + row.assign(&win.slice(slice_info_1d)); + if let Some(w) = window { + row.as_slice_mut().unwrap().apply_window(w); + } + } + + let mut output = HArray(stft_ndarray.into_dyn()); + self.0.process(&mut output)?; + + Ok(output) + } + 2 => { + let nrows = harray.0.len_of(Axis(0)); + let ncols = harray.0.len_of(Axis(1)); - let fft_length = self.len(); - let n_fft = (harray.len() - window_length) / hop_length + 1; - //let stft_ndarray = ArcArray2::zeros((fft_length, n_fft)); - todo!() + if fft_length < window_length || fft_length > ncols { + return Err(HError::OutOfSpecError( + "Expected ncols >= fft_length >= window_length.".to_string(), + )); + } + if let Some(slice) = window { + if slice.len() != window_length { + return Err(HError::OutOfSpecError( + "Expected window.len() == window_length.".to_string(), + )); + } + } + + let n_fft = 1 + (ncols - fft_length) / hop_length; + let mut stft_ndarray = + ArcArray::, Ix3>::zeros((nrows, n_fft, fft_length)); + + let slice_info = s![.., left..right]; + let slice_info_1d = s![left..right]; + let scratch_buffer = make_mut_slice(&mut self.0.scratch_buffer); + + for (mut matrix, win) in stft_ndarray.axis_iter_mut(Axis(1)).zip( + harray + .0 + .windows(IxDyn(&[nrows, fft_length])) + .into_iter() + .step_by(hop_length), + ) { + matrix.slice_mut(slice_info).assign(&win.slice(slice_info)); + + for mut col in matrix.lanes_mut(Axis(1)) { + if let Some(w) = window { + col.slice_mut(slice_info_1d) + .as_slice_mut() + .unwrap() + .apply_window(w); + } + self.0 + .fft + .process_with_scratch(col.as_slice_mut().unwrap(), scratch_buffer); + } + } + + let output = HArray(stft_ndarray.into_dyn()); + + Ok(output) + } + _ => Err(HError::OutOfSpecError( + "The HArray's ndim should be 1 or 2.".into(), + )), + } } } +//impl ProcessRealStftForward for RealStftForward +//where +// T: FftNum + Float + FloatConst + ConstZero, +//{ +// fn process( +// &mut self, +// harray: &HArray, +// hop_length: NonZero, +// window_length: NonZero, +// window: Option<&[T]>, +// ) -> HResult, Ix2>> { +// let fft_length = self.len(); // Since fft_length is checked to be >= window_length and window_length is NonZero, we can be sure fft_length > 0. +// let real_fft_length = fft_length / 2 + 1; +// let window_length = window_length.get(); +// let hop_length = hop_length.get(); +// let length = harray.len(); +// let scratch_real_buffer = make_mut_slice(&mut self.scratch_real_buffer); +// let scratch_buffer = make_mut_slice(&mut self.inner.scratch_buffer); +// +// if fft_length < window_length || fft_length > length { +// return Err(HError::OutOfSpecError( +// "Expected harray.len() >= fft_length >= window_length.".to_string(), +// )); +// } +// if let Some(slice) = window { +// if slice.len() != window_length { +// return Err(HError::OutOfSpecError( +// "Expected window.len() == window_length.".to_string(), +// )); +// } +// } +// +// let n_fft = 1 + (length - fft_length) / hop_length; +// let mut stft_ndarray = ArcArray2::>::zeros((n_fft, real_fft_length)); +// +// // Center PAD the window if fft_length > window_length. +// let left = (fft_length - window_length) / 2; +// let right = left + window_length; +// let slice_info = s![.., left..right]; +// let slice_info_1d = s![left..right]; +// +// for (mut row, win) in stft_ndarray +// .slice_mut(slice_info) +// .lanes_mut(Axis(1)) +// .into_iter() +// .zip(harray.0.windows(fft_length).into_iter().step_by(hop_length)) +// { +// let scratch_real_buffer_slice = &mut scratch_real_buffer[left..right]; +// scratch_real_buffer_slice.copy_from_slice(win.slice(slice_info_1d).as_slice().unwrap()); +// if let Some(w) = window { +// scratch_real_buffer_slice.apply_window(w); +// } +// self.inner.fft.process_with_scratch(scratch_real_buffer, row.as_slice_mut().unwrap(), scratch_buffer).unwrap(); +// } +// +// let output = HArray(stft_ndarray); +// +// Ok(output) +// } +//} + +//impl ProcessRealStftForward for RealStftForward +//where +// T: FftNum + Float + FloatConst + ConstZero, +//{ +// fn process( +// &mut self, +// harray: &HArray, +// hop_length: NonZero, +// window_length: NonZero, +// window: Option<&[T]>, +// ) -> HResult, Ix3>> { +// let fft_length = self.len(); // Since fft_length is checked to be >= window_length and window_length is NonZero, we can be sure fft_length > 0. +// let window_length = window_length.get(); +// let hop_length = hop_length.get(); +// let nrows = harray.0.len_of(Axis(0)); +// let ncols = harray.0.len_of(Axis(1)); +// +// if fft_length < window_length || fft_length > ncols { +// return Err(HError::OutOfSpecError( +// "Expected ncols >= fft_length >= window_length.".to_string(), +// )); +// } +// if let Some(slice) = window { +// if slice.len() != window_length { +// return Err(HError::OutOfSpecError( +// "Expected window.len() == window_length.".to_string(), +// )); +// } +// } +// +// let n_fft = 1 + (ncols - fft_length) / hop_length; +// let mut stft_ndarray = ArcArray::, Ix3>::zeros((nrows, n_fft, fft_length)); +// +// // Center PAD the window if fft_length > window_length. +// let left = (fft_length - window_length) / 2; +// let right = left + window_length; +// let slice_info = s![.., left..right]; +// let slice_info_1d = s![left..right]; +// let scratch_buffer = make_mut_slice(&mut self.0.scratch_buffer); +// +// for (mut matrix, win) in stft_ndarray.axis_iter_mut(Axis(1)).zip( +// harray +// .0 +// .windows((nrows, fft_length)) +// .into_iter() +// .step_by(hop_length), +// ) { +// matrix.slice_mut(slice_info).assign(&win.slice(slice_info)); +// +// for mut col in matrix.lanes_mut(Axis(1)) { +// if let Some(w) = window { +// col.slice_mut(slice_info_1d) +// .as_slice_mut() +// .unwrap() +// .apply_window(w); +// } +// self.0 +// .fft +// .process_with_scratch(col.as_slice_mut().unwrap(), scratch_buffer); +// } +// } +// +// let output = HArray(stft_ndarray); +// +// Ok(output) +// } +//} +// +//impl ProcessRealStftForward for RealStftForward +//where +// T: FftNum + Float + FloatConst + ConstZero, +//{ +// fn process( +// &mut self, +// harray: &HArray, +// hop_length: NonZero, +// window_length: NonZero, +// window: Option<&[T]>, +// ) -> HResult, IxDyn>> { +// let fft_length = self.len(); // Since fft_length is checked to be >= window_length and window_length is NonZero, we can be sure fft_length > 0. +// let window_length = window_length.get(); +// let hop_length = hop_length.get(); +// +// // Center PAD the window if fft_length > window_length. +// let left = (fft_length - window_length) / 2; +// let right = left + window_length; +// +// match harray.ndim() { +// 1 => { +// let length = harray.len(); +// +// if fft_length < window_length || fft_length > length { +// return Err(HError::OutOfSpecError( +// "Expected harray.len() >= fft_length >= window_length.".to_string(), +// )); +// } +// if let Some(slice) = window { +// if slice.len() != window_length { +// return Err(HError::OutOfSpecError( +// "Expected window.len() == window_length.".to_string(), +// )); +// } +// } +// +// let n_fft = 1 + (length - fft_length) / hop_length; +// let mut stft_ndarray = ArcArray2::>::zeros((n_fft, fft_length)); +// +// let slice_info = s![.., left..right]; +// let slice_info_1d = s![left..right]; +// +// for (mut row, win) in stft_ndarray +// .slice_mut(slice_info) +// .lanes_mut(Axis(1)) +// .into_iter() +// .zip( +// harray +// .0 +// .windows(IxDyn(&[fft_length])) +// .into_iter() +// .step_by(hop_length), +// ) +// { +// row.assign(&win.slice(slice_info_1d)); +// if let Some(w) = window { +// row.as_slice_mut().unwrap().apply_window(w); +// } +// } +// +// let mut output = HArray(stft_ndarray.into_dyn()); +// self.0.process(&mut output)?; +// +// Ok(output) +// } +// 2 => { +// let nrows = harray.0.len_of(Axis(0)); +// let ncols = harray.0.len_of(Axis(1)); +// +// if fft_length < window_length || fft_length > ncols { +// return Err(HError::OutOfSpecError( +// "Expected ncols >= fft_length >= window_length.".to_string(), +// )); +// } +// if let Some(slice) = window { +// if slice.len() != window_length { +// return Err(HError::OutOfSpecError( +// "Expected window.len() == window_length.".to_string(), +// )); +// } +// } +// +// let n_fft = 1 + (ncols - fft_length) / hop_length; +// let mut stft_ndarray = +// ArcArray::, Ix3>::zeros((nrows, n_fft, fft_length)); +// +// let slice_info = s![.., left..right]; +// let slice_info_1d = s![left..right]; +// let scratch_buffer = make_mut_slice(&mut self.0.scratch_buffer); +// +// for (mut matrix, win) in stft_ndarray.axis_iter_mut(Axis(1)).zip( +// harray +// .0 +// .windows(IxDyn(&[nrows, fft_length])) +// .into_iter() +// .step_by(hop_length), +// ) { +// matrix.slice_mut(slice_info).assign(&win.slice(slice_info)); +// +// for mut col in matrix.lanes_mut(Axis(1)) { +// if let Some(w) = window { +// col.slice_mut(slice_info_1d) +// .as_slice_mut() +// .unwrap() +// .apply_window(w); +// } +// self.0 +// .fft +// .process_with_scratch(col.as_slice_mut().unwrap(), scratch_buffer); +// } +// } +// +// let output = HArray(stft_ndarray.into_dyn()); +// +// Ok(output) +// } +// _ => Err(HError::OutOfSpecError( +// "The HArray's ndim should be 1 or 2.".into(), +// )), +// } +// } +//} + trait ApplyWindow { fn apply_window(&mut self, window: &[T]); } @@ -83,7 +615,7 @@ where impl ApplyWindow for [T] where - T: Float, + T: ComplexFloat, { fn apply_window(&mut self, window: &[T]) { for (i, w) in self.iter_mut().zip(window.iter()) { @@ -95,7 +627,350 @@ where #[cfg(test)] mod tests { use super::*; + use harmonium_core::{comparison::compare_harray_complex, conversions::IntoDynamic}; #[test] - fn it_works() {} + fn stft_1d_test() { + let fft_length = [3_usize, 5, 5, 5]; + let one_hop_length = NonZero::::new(1).unwrap(); + let two_hop_length = NonZero::::new(2).unwrap(); + let hop_length = [ + one_hop_length, + one_hop_length, + two_hop_length, + two_hop_length, + ]; + let result_no_pad = vec![ + Complex::new(9.0, 12.0), + Complex::new(-4.732051, -1.2679492), + Complex::new(-1.2679492, -4.732051), + Complex::new(15.0, 18.0), + Complex::new(-4.732051, -1.2679492), + Complex::new(-1.2679492, -4.732051), + Complex::new(21.0, 24.0), + Complex::new(-4.732051, -1.2679492), + Complex::new(-1.2679492, -4.732051), + Complex::new(27.0, 30.0), + Complex::new(-4.732051, -1.2679492), + Complex::new(-1.2679492, -4.732051), + ]; + let result_pad = vec![ + Complex::new(15.0, 18.0), + Complex::new(-6.15250, -11.76777), + Complex::new(5.534407, -2.575299), + Complex::new(-2.972101, 4.755639), + Complex::new(-11.40981, -8.41257), + Complex::new(21.0, 24.0), + Complex::new(-6.86842, -16.28792), + Complex::new(6.328012, -4.132835), + Complex::new(-4.529637, 5.549243), + Complex::new(-15.92996, -9.12849), + ]; + let result_pad_hop_length = vec![ + Complex::new(15.0, 18.0), + Complex::new(-6.1525, -11.76777), + Complex::new(5.534407, -2.575299), + Complex::new(-2.972101, 4.755639), + Complex::new(-11.40981, -8.41257), + ]; + let result_pad_hop_length_window = vec![ + Complex::new(34.0, 40.0), + Complex::new(-27.40167, -24.27608), + Complex::new(20.9163, -4.33643), + Complex::new(-6.61134, 20.11352), + Complex::new(-20.90328, -31.50101), + ]; + + let result = [ + result_no_pad, + result_pad, + result_pad_hop_length, + result_pad_hop_length_window, + ]; + + let window = [None, None, None, Some([1., 2., 3.].as_slice())]; + + let input = vec![ + Complex::new(1_f32, 2_f32), + Complex::new(3_f32, 4_f32), + Complex::new(5_f32, 6_f32), + Complex::new(7_f32, 8_f32), + Complex::new(9_f32, 10_f32), + Complex::new(11_f32, 12_f32), + ]; + let length = input.len(); + let window_length = NonZero::new(3).unwrap(); + + for (((fft_length, hop_length), result), window) in fft_length + .into_iter() + .zip(hop_length.into_iter()) + .zip(result.iter()) + .zip(window.into_iter()) + { + // Ix1 test. + let harray = HArray::new_from_shape_vec(length, input.clone()).unwrap(); + let mut stft = Stft::::new_stft_forward(fft_length); + let stft_harray = stft + .process(&harray, hop_length, window_length, window) + .unwrap(); + let n_fft = 1 + (harray.len() - fft_length) / hop_length; + let rhs = HArray::new_from_shape_vec((n_fft, fft_length), result.clone()).unwrap(); + assert!(compare_harray_complex(&stft_harray, &rhs)); + + // IxDyn test. + let harray = HArray::new_from_shape_vec(length, input.clone()) + .unwrap() + .into_dynamic(); + let mut stft = Stft::::new_stft_forward(fft_length); + let stft_harray = stft + .process(&harray, hop_length, window_length, window) + .unwrap(); + let n_fft = 1 + (harray.len() - fft_length) / hop_length; + let rhs = HArray::new_from_shape_vec((n_fft, fft_length), result.clone()) + .unwrap() + .into_dynamic(); + assert!(compare_harray_complex(&stft_harray, &rhs)); + } + } + + #[test] + fn stft_2d_test() { + let fft_length = [3_usize, 5, 5, 5]; + let one_hop_length = NonZero::::new(1).unwrap(); + let two_hop_length = NonZero::::new(2).unwrap(); + let hop_length = [ + one_hop_length, + one_hop_length, + two_hop_length, + two_hop_length, + ]; + let result_no_pad = vec![ + Complex::new(9.0, 12.0), + Complex::new(-4.732051, -1.2679492), + Complex::new(-1.2679492, -4.732051), + Complex::new(15.0, 18.0), + Complex::new(-4.732051, -1.2679492), + Complex::new(-1.2679492, -4.732051), + Complex::new(21.0, 24.0), + Complex::new(-4.732051, -1.2679492), + Complex::new(-1.2679492, -4.732051), + Complex::new(27.0, 30.0), + Complex::new(-4.732051, -1.2679492), + Complex::new(-1.2679492, -4.732051), + Complex::new(9.0, 12.0), + Complex::new(-4.732051, -1.2679492), + Complex::new(-1.2679492, -4.732051), + Complex::new(15.0, 18.0), + Complex::new(-4.732051, -1.2679492), + Complex::new(-1.2679492, -4.732051), + Complex::new(21.0, 24.0), + Complex::new(-4.732051, -1.2679492), + Complex::new(-1.2679492, -4.732051), + Complex::new(27.0, 30.0), + Complex::new(-4.732051, -1.2679492), + Complex::new(-1.2679492, -4.732051), + ]; + let result_pad = vec![ + Complex::new(15.0, 18.0), + Complex::new(-6.15250, -11.76777), + Complex::new(5.534407, -2.575299), + Complex::new(-2.972101, 4.755639), + Complex::new(-11.40981, -8.41257), + Complex::new(21.0, 24.0), + Complex::new(-6.86842, -16.28792), + Complex::new(6.328012, -4.132835), + Complex::new(-4.529637, 5.549243), + Complex::new(-15.92996, -9.12849), + Complex::new(15.0, 18.0), + Complex::new(-6.15250, -11.76777), + Complex::new(5.534407, -2.575299), + Complex::new(-2.972101, 4.755639), + Complex::new(-11.40981, -8.41257), + Complex::new(21.0, 24.0), + Complex::new(-6.86842, -16.28792), + Complex::new(6.328012, -4.132835), + Complex::new(-4.529637, 5.549243), + Complex::new(-15.92996, -9.12849), + ]; + let result_pad_hop_length = vec![ + Complex::new(15.0, 18.0), + Complex::new(-6.1525, -11.76777), + Complex::new(5.534407, -2.575299), + Complex::new(-2.972101, 4.755639), + Complex::new(-11.40981, -8.41257), + Complex::new(15.0, 18.0), + Complex::new(-6.1525, -11.76777), + Complex::new(5.534407, -2.575299), + Complex::new(-2.972101, 4.755639), + Complex::new(-11.40981, -8.41257), + ]; + let result_pad_hop_length_window = vec![ + Complex::new(34.0, 40.0), + Complex::new(-27.40167, -24.27608), + Complex::new(20.9163, -4.33643), + Complex::new(-6.61134, 20.11352), + Complex::new(-20.90328, -31.50101), + Complex::new(34.0, 40.0), + Complex::new(-27.40167, -24.27608), + Complex::new(20.9163, -4.33643), + Complex::new(-6.61134, 20.11352), + Complex::new(-20.90328, -31.50101), + ]; + + let result = [ + result_no_pad, + result_pad, + result_pad_hop_length, + result_pad_hop_length_window, + ]; + + let window = [None, None, None, Some([1., 2., 3.].as_slice())]; + + let input = vec![ + Complex::new(1_f32, 2_f32), + Complex::new(3_f32, 4_f32), + Complex::new(5_f32, 6_f32), + Complex::new(7_f32, 8_f32), + Complex::new(9_f32, 10_f32), + Complex::new(11_f32, 12_f32), + Complex::new(1_f32, 2_f32), + Complex::new(3_f32, 4_f32), + Complex::new(5_f32, 6_f32), + Complex::new(7_f32, 8_f32), + Complex::new(9_f32, 10_f32), + Complex::new(11_f32, 12_f32), + ]; + let length = input.len(); + let window_length = NonZero::new(3).unwrap(); + + for (((fft_length, hop_length), result), window) in fft_length + .into_iter() + .zip(hop_length.into_iter()) + .zip(result.iter()) + .zip(window.into_iter()) + { + // Ix2 test. + let harray = HArray::new_from_shape_vec((2, length / 2), input.clone()).unwrap(); + let mut stft = Stft::::new_stft_forward(fft_length); + let stft_harray = stft + .process(&harray, hop_length, window_length, window) + .unwrap(); + let ncols = harray.0.len_of(Axis(1)); + let n_fft = 1 + (ncols - fft_length) / hop_length; + let lhs = HArray::new_from_shape_vec((2, n_fft, fft_length), result.clone()).unwrap(); + assert!(compare_harray_complex(&stft_harray, &lhs)); + + // IxDyn test. + let harray = HArray::new_from_shape_vec((2, length / 2), input.clone()) + .unwrap() + .into_dynamic(); + let mut stft = Stft::::new_stft_forward(fft_length); + let stft_harray = stft + .process(&harray, hop_length, window_length, window) + .unwrap(); + let ncols = harray.0.len_of(Axis(1)); + let n_fft = 1 + (ncols - fft_length) / hop_length; + let lhs = HArray::new_from_shape_vec((2, n_fft, fft_length), result.clone()) + .unwrap() + .into_dynamic(); + assert!(compare_harray_complex(&stft_harray, &lhs)); + } + } + + //#[test] + //fn real_stft_1d_test() { + // let fft_length = [3_usize, 5, 5, 5]; + // let one_hop_length = NonZero::::new(1).unwrap(); + // let two_hop_length = NonZero::::new(2).unwrap(); + // let hop_length = [ + // one_hop_length, + // one_hop_length, + // two_hop_length, + // two_hop_length, + // ]; + // let result_no_pad = vec![ + // Complex::new(9.0, 12.0), + // Complex::new(-4.732051, -1.2679492), + // Complex::new(-1.2679492, -4.732051), + // Complex::new(15.0, 18.0), + // Complex::new(-4.732051, -1.2679492), + // Complex::new(-1.2679492, -4.732051), + // Complex::new(21.0, 24.0), + // Complex::new(-4.732051, -1.2679492), + // Complex::new(-1.2679492, -4.732051), + // Complex::new(27.0, 30.0), + // Complex::new(-4.732051, -1.2679492), + // Complex::new(-1.2679492, -4.732051), + // ]; + // let result_pad = vec![ + // Complex::new(15.0, 18.0), + // Complex::new(-6.15250, -11.76777), + // Complex::new(5.534407, -2.575299), + // Complex::new(-2.972101, 4.755639), + // Complex::new(-11.40981, -8.41257), + // Complex::new(21.0, 24.0), + // Complex::new(-6.86842, -16.28792), + // Complex::new(6.328012, -4.132835), + // Complex::new(-4.529637, 5.549243), + // Complex::new(-15.92996, -9.12849), + // ]; + // let result_pad_hop_length = vec![ + // Complex::new(15.0, 18.0), + // Complex::new(-6.1525, -11.76777), + // Complex::new(5.534407, -2.575299), + // Complex::new(-2.972101, 4.755639), + // Complex::new(-11.40981, -8.41257), + // ]; + // let result_pad_hop_length_window = vec![ + // Complex::new(34.0, 40.0), + // Complex::new(-27.40167, -24.27608), + // Complex::new(20.9163, -4.33643), + // Complex::new(-6.61134, 20.11352), + // Complex::new(-20.90328, -31.50101), + // ]; + + // let result = [ + // result_no_pad, + // result_pad, + // result_pad_hop_length, + // result_pad_hop_length_window, + // ]; + + // let window = [None, None, None, Some([1., 2., 3.].as_slice())]; + + // let input = vec![1.,2.,3.,4.,5.,6.]; + // let length = input.len(); + // let window_length = NonZero::new(3).unwrap(); + + // for (((fft_length, hop_length), result), window) in fft_length + // .into_iter() + // .zip(hop_length.into_iter()) + // .zip(result.iter()) + // .zip(window.into_iter()) + // { + // // Ix1 test. + // let harray = HArray::new_from_shape_vec(length, input.clone()).unwrap(); + // let mut stft = RealStftForward::::new_real_stft_forward(fft_length); + // let stft_harray = stft + // .process(&harray, hop_length, window_length, window) + // .unwrap(); + // let n_fft = 1 + (harray.len() - fft_length) / hop_length; + // let rhs = HArray::new_from_shape_vec((n_fft, fft_length / 2 + 1), result.clone()).unwrap(); + // assert!(compare_harray_complex(&stft_harray, &rhs)); + + // //// IxDyn test. + // //let harray = HArray::new_from_shape_vec(length, input.clone()) + // // .unwrap() + // // .into_dynamic(); + // //let mut stft = RealStftForward::::new_real_stft_forward(fft_length); + // //let stft_harray = stft + // // .process(&harray, hop_length, window_length, window) + // // .unwrap(); + // //let n_fft = 1 + (harray.len() - fft_length) / hop_length; + // //let rhs = HArray::new_from_shape_vec((n_fft, fft_length / 2 + 1), result.clone()) + // // .unwrap() + // // .into_dynamic(); + // //assert!(compare_harray_complex(&stft_harray, &rhs)); + // } + //} } diff --git a/harmonium-window/src/windows.rs b/harmonium-window/src/windows.rs index 48e7608..3e05aeb 100644 --- a/harmonium-window/src/windows.rs +++ b/harmonium-window/src/windows.rs @@ -3,7 +3,7 @@ use harmonium_core::{ errors::{HError, HResult}, }; use ndarray::Ix1; -use num_traits::{Float, FloatConst}; +use num_traits::{ConstOne, ConstZero, Float, FloatConst}; use realfft::{num_complex::Complex, FftNum, RealFftPlanner}; use rustfft::FftPlanner; @@ -216,7 +216,7 @@ where /// When `WindowType::Periodic`, generates a periodic window, for use in spectral analysis. pub fn bohman(npoints: usize, window_type: WindowType) -> HArray where - T: Float + FloatConst, + T: Float + FloatConst + ConstZero + ConstOne, { let np_f = match window_type { WindowType::Symmetric => T::from(npoints - 1).unwrap(), @@ -224,9 +224,9 @@ where }; let pi = T::PI(); - let zero = T::zero(); - let one = T::one(); - let two = T::from(2.0).unwrap(); + let zero = T::ZERO; + let one = T::ONE; + let two = one + one; let step = two / (np_f); let mut fac = -one; @@ -253,9 +253,9 @@ where /// `npoints` - Number of points in the output window. pub fn boxcar(npoints: usize) -> HArray where - T: Float + FloatConst, + T: Float + FloatConst + ConstOne, { - let one = T::one(); + let one = T::ONE; let window: Vec = (0..npoints).map(|_| one).collect(); HArray::new_from_shape_vec(npoints, window).unwrap() @@ -326,7 +326,7 @@ where /// When `WindowType::Periodic`, generates a periodic window, for use in spectral analysis. pub fn chebwin(npoints: usize, at: T, window_type: WindowType) -> HArray where - T: Float + FloatConst + FftNum, + T: Float + FloatConst + FftNum + ConstZero + ConstOne, { let np_f = match window_type { WindowType::Symmetric => T::from(npoints).unwrap(), @@ -335,9 +335,9 @@ where let np = np_f.to_usize().unwrap(); let pi = T::PI(); - let zero = T::zero(); - let one = T::one(); - let two = T::from(2.0).unwrap(); + let zero = T::ZERO; + let one = T::ONE; + let two = one + one; let ten = T::from(10.0).unwrap(); let twenty = ten + ten; let expr = T::from(two * T::from(np % 2).unwrap() - one).unwrap(); @@ -445,7 +445,7 @@ where WindowType::Symmetric if center.is_none() => T::from(npoints - 1).unwrap(), WindowType::Symmetric => { return Err(HError::OutOfSpecError( - "center must be none for symmetric windows".into(), + "center must be None for symmetric windows".into(), )); } WindowType::Periodic => T::from(npoints).unwrap(), @@ -509,11 +509,11 @@ where /// When `WindowType::Periodic`, generates a periodic window, for use in spectral analysis. pub fn triangle(npoints: usize, window_type: WindowType) -> HArray where - T: Float + FloatConst, + T: Float + FloatConst + ConstOne, { let np_f = T::from(npoints).unwrap(); - let one = T::one(); - let two = T::from(2.0).unwrap(); + let one = T::ONE; + let two = one + one; let mut window: Vec = Vec::with_capacity(npoints); diff --git a/r-harmonium/.testtest.R b/r-harmonium/.testtest.R index 5fb4814..084e557 100644 --- a/r-harmonium/.testtest.R +++ b/r-harmonium/.testtest.R @@ -1,4 +1,27 @@ -pkgbuild::clean_dll() +#pkgbuild::clean_dll() devtools::load_all(".", export_all = FALSE) devtools::test() -devtools::check(document = FALSE, cran = FALSE, args = c("--no-manual", "--no-build-vignettes", "--no-codoc", "--no-examples", "--no-tests")) +#devtools::check(document = FALSE, cran = FALSE, args = c("--no-manual", "--no-build-vignettes", "--no-codoc", "--no-examples", "--no-tests")) + +library(torch) +v <- c( + complex(real = 1, imaginary = 2), + complex(real = 3, imaginary = 4), + complex(real = 5, imaginary = 6), + complex(real = 7, imaginary = 8), + complex(real = 9, imaginary = 10), + complex(real = 11, imaginary = 12) +) + +complex_tensor = torch_tensor(v) +a = torch_stft( + input = complex_tensor, + n_fft = 5, + hop_length = 2, + win_length = 3, + window = torch_tensor(c(1,2,3)), + center = FALSE, + onesided = FALSE, + return_complex = TRUE +) +t(as_array(a)) \ No newline at end of file diff --git a/r-harmonium/src/rust/src/harray.rs b/r-harmonium/src/rust/src/harray.rs index 5d649f6..cbce5f9 100644 --- a/r-harmonium/src/rust/src/harray.rs +++ b/r-harmonium/src/rust/src/harray.rs @@ -5,7 +5,8 @@ use crate::{ use ndarray::{IxDyn, ShapeError, SliceInfo, SliceInfoElem}; use num_complex::Complex; use savvy::{ - savvy, ListSexp, OwnedIntegerSexp, OwnedLogicalSexp, OwnedStringSexp, Sexp, TypedSexp, + savvy, IntegerSexp, ListSexp, NotAvailableValue, OwnedIntegerSexp, OwnedLogicalSexp, + OwnedStringSexp, Sexp, TypedSexp, }; use std::sync::Arc; @@ -207,7 +208,13 @@ impl HArray { /// /// The number of vectors in the list must be equal to the number of dimensions in the original HArray as they represent the slice information for each axis. /// - /// Each vector must be composed of 3 elements: [start, end, step]. All 3 values can be positive or negative, although step can't be 0. + /// Each vector must be composed of 1 or 3 elements + /// + /// For 1 element: A single index. An index to use for taking a subview with respect to that axis. The index is selected, then the axis is removed. + /// + /// For 3 elements: [start, end, step]. All 3 values can be positive or negative, although step can't be 0. + /// Negative start or end indexes are counted from the back of the axis. If end is None, the slice extends to the end of the axis. + /// A `c(NA_integer_, NA_integer_, NA_integer_)` value for start will mean start = 0, end = axis_length, step = 1. /// /// #### Returns /// @@ -217,10 +224,15 @@ impl HArray { /// /// ```r /// library(harmonium) - /// arr = array(c(1,2,3,4,5,6,7,8,9,10,11,12), c(3,4)) + /// arr = array(c(1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20), c(4,5)) /// dtype = HDataType$Float32 /// harray = HArray$new_from_values(arr, dtype) /// harray$slice(list(c(0L, 2L, 1L), c(1L, 3L, 1L))) + /// harray$slice(list(c(0L, 4L, 1L), c(1L, NA, 1L))) + /// harray$slice(list(c(0L, NA, 1L), c(1L, 3L, 1L))) + /// harray$slice(list(0L, c(NA_integer_, NA, NA))) # using index + /// x = c(NA_integer_, NA_integer_, NA_integer_) + /// harray$slice(list(x, x)) == harray # TRUE /// ``` /// /// _________ @@ -236,20 +248,33 @@ impl HArray { let mut vec_ranges: Vec = Vec::with_capacity(list_len); for obj in range.values_iter() { - match obj.into_typed() { - TypedSexp::Integer(integer_sexp) => { - if integer_sexp.len() != 3 { - return Err("Each element must have a length of 3.".into()); - } - let slice: &[i32] = integer_sexp.as_slice(); - let slice_info_elem = SliceInfoElem::Slice { - start: slice[0] as isize, - end: Some(slice[1] as isize), - step: slice[2] as isize, - }; - vec_ranges.push(slice_info_elem); - } - _ => return Err("Each element in the list must be a vector of integers.".into()), + let integer_sexp = IntegerSexp::try_from(obj)?; + let slice: &[i32] = integer_sexp.as_slice(); + if slice.len() == 1 { + // Safety: the vector is checked to be length 1. + let index = unsafe { *slice.get_unchecked(0) as isize }; + let slice_info_elem = SliceInfoElem::Index(index); + vec_ranges.push(slice_info_elem); + } else if slice.len() == 3 { + // Safety: the vector is checked to be length 3. + let (start, end, step) = ( + unsafe { *slice.get_unchecked(0) }, + unsafe { *slice.get_unchecked(1) }, + unsafe { *slice.get_unchecked(2) }, + ); + + let start = if start.is_na() { 0 } else { start as isize }; + let end = if end.is_na() { + None + } else { + Some(end as isize) + }; + let step = if step.is_na() { 1 } else { step as isize }; + + let slice_info_elem = SliceInfoElem::Slice { start, end, step }; + vec_ranges.push(slice_info_elem); + } else { + return Err("Each element must have a length of 1 or 3.".into()); } } @@ -589,9 +614,12 @@ impl HArray { impl HArray { #[doc(hidden)] pub fn get_inner_mut(&mut self) -> &mut dyn HArrayR { - if Arc::weak_count(&self.0) + Arc::strong_count(&self.0) != 1 { + // Weak references are never used. + if Arc::strong_count(&self.0) != 1 { self.0 = self.0.clone_inner(); } - Arc::get_mut(&mut self.0).expect("implementation error") + // Safety: reference count was checked. + // Use get_mut_unchecked when stable. + unsafe { Arc::get_mut(&mut self.0).unwrap_unchecked() } } } diff --git a/r-harmonium/src/rust/src/hfft.rs b/r-harmonium/src/rust/src/hfft.rs index 5ee77f4..eadfe3e 100644 --- a/r-harmonium/src/rust/src/hfft.rs +++ b/r-harmonium/src/rust/src/hfft.rs @@ -752,10 +752,12 @@ impl_hrealfftinverse!( impl HFft { #[doc(hidden)] pub fn get_inner_mut(&mut self) -> &mut dyn HFftR { - if Arc::get_mut(&mut self.0).is_none() { + // Weak references are never used. + if Arc::strong_count(&self.0) != 1 { self.0 = self.0.clone_inner(); } - // Safe to unwrap_unchecked since get_mut was checked above. + // Safety: reference count was checked. + // Use get_mut_unchecked when stable. unsafe { Arc::get_mut(&mut self.0).unwrap_unchecked() } } } @@ -763,10 +765,12 @@ impl HFft { impl HRealFft { #[doc(hidden)] pub fn get_inner_mut(&mut self) -> &mut dyn HRealFftR { - if Arc::get_mut(&mut self.0).is_none() { + // Weak references are never used. + if Arc::strong_count(&self.0) != 1 { self.0 = self.0.clone_inner(); } - // Safe to unwrap_unchecked since get_mut was checked above. + // Safety: reference count was checked. + // Use get_mut_unchecked when stable. unsafe { Arc::get_mut(&mut self.0).unwrap_unchecked() } } }