diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 4a00ea000..7b19af2a4 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -1493,6 +1493,19 @@ where /// ``` pub fn axis_windows(&self, axis: Axis, window_size: usize) -> AxisWindows<'_, A, D> where S: Data + { + self.axis_windows_with_stride(axis, window_size, 1) + } + + /// Returns a producer which traverses over windows of a given length and + /// stride along an axis. + /// + /// Note that a calling this method with a stride of 1 is equivalent to + /// calling [`ArrayBase::axis_windows()`]. + pub fn axis_windows_with_stride( + &self, axis: Axis, window_size: usize, stride_size: usize, + ) -> AxisWindows<'_, A, D> + where S: Data { let axis_index = axis.index(); @@ -1507,7 +1520,12 @@ where self.shape() ); - AxisWindows::new(self.view(), axis, window_size) + ndassert!( + stride_size >0, + "Stride size must be greater than zero" + ); + + AxisWindows::new_with_stride(self.view(), axis, window_size, stride_size) } // Return (length, stride) for diagonal diff --git a/src/iterators/windows.rs b/src/iterators/windows.rs index 1c2ab6a85..6451b8901 100644 --- a/src/iterators/windows.rs +++ b/src/iterators/windows.rs @@ -141,7 +141,7 @@ pub struct AxisWindows<'a, A, D> impl<'a, A, D: Dimension> AxisWindows<'a, A, D> { - pub(crate) fn new(a: ArrayView<'a, A, D>, axis: Axis, window_size: usize) -> Self + pub(crate) fn new_with_stride(a: ArrayView<'a, A, D>, axis: Axis, window_size: usize, stride_size: usize) -> Self { let window_strides = a.strides.clone(); let axis_idx = axis.index(); @@ -150,10 +150,11 @@ impl<'a, A, D: Dimension> AxisWindows<'a, A, D> window[axis_idx] = window_size; let ndim = window.ndim(); - let mut unit_stride = D::zeros(ndim); - unit_stride.slice_mut().fill(1); + let mut stride = D::zeros(ndim); + stride.slice_mut().fill(1); + stride[axis_idx] = stride_size; - let base = build_base(a, window.clone(), unit_stride); + let base = build_base(a, window.clone(), stride); AxisWindows { base, axis_idx, diff --git a/tests/windows.rs b/tests/windows.rs index 6506d8301..4d4d0d7d7 100644 --- a/tests/windows.rs +++ b/tests/windows.rs @@ -294,6 +294,148 @@ fn tests_axis_windows_3d_zips_with_1d() assert_eq!(b,arr1(&[207, 261])); } +/// Test verifies that non existent Axis results in panic +#[test] +#[should_panic] +fn axis_windows_with_stride_outofbound() +{ + let a = Array::from_iter(10..37) + .into_shape_with_order((3, 3, 3)) + .unwrap(); + a.axis_windows_with_stride(Axis(4), 2, 2); +} + +/// Test verifies that zero sizes results in panic +#[test] +#[should_panic] +fn axis_windows_with_stride_zero_size() +{ + let a = Array::from_iter(10..37) + .into_shape_with_order((3, 3, 3)) + .unwrap(); + a.axis_windows_with_stride(Axis(0), 0, 2); +} + +/// Test verifies that zero stride results in panic +#[test] +#[should_panic] +fn axis_windows_with_stride_zero_stride() +{ + let a = Array::from_iter(10..37) + .into_shape_with_order((3, 3, 3)) + .unwrap(); + a.axis_windows_with_stride(Axis(0), 2, 0); +} + +/// Test verifies that over sized windows yield nothing +#[test] +fn axis_windows_with_stride_oversized() +{ + let a = Array::from_iter(10..37) + .into_shape_with_order((3, 3, 3)) + .unwrap(); + let mut iter = a.axis_windows_with_stride(Axis(2), 4, 2).into_iter(); + assert_eq!(iter.next(), None); +} + +/// Simple test for iterating 1d-arrays via `Axis Windows`. +#[test] +fn test_axis_windows_with_stride_1d() +{ + let a = Array::from_iter(10..20).into_shape_with_order(10).unwrap(); + + itertools::assert_equal(a.axis_windows_with_stride(Axis(0), 5, 2), vec![ + arr1(&[10, 11, 12, 13, 14]), + arr1(&[12, 13, 14, 15, 16]), + arr1(&[14, 15, 16, 17, 18]), + ]); + + itertools::assert_equal(a.axis_windows_with_stride(Axis(0), 5, 3), vec![ + arr1(&[10, 11, 12, 13, 14]), + arr1(&[13, 14, 15, 16, 17]), + ]); +} + +/// Simple test for iterating 2d-arrays via `Axis Windows`. +#[test] +fn test_axis_windows_with_stride_2d() +{ + let a = Array::from_iter(10..30) + .into_shape_with_order((5, 4)) + .unwrap(); + + itertools::assert_equal(a.axis_windows_with_stride(Axis(0), 2, 1), vec![ + arr2(&[[10, 11, 12, 13], [14, 15, 16, 17]]), + arr2(&[[14, 15, 16, 17], [18, 19, 20, 21]]), + arr2(&[[18, 19, 20, 21], [22, 23, 24, 25]]), + arr2(&[[22, 23, 24, 25], [26, 27, 28, 29]]), + ]); + + itertools::assert_equal(a.axis_windows_with_stride(Axis(0), 2, 2), vec![ + arr2(&[[10, 11, 12, 13], [14, 15, 16, 17]]), + arr2(&[[18, 19, 20, 21], [22, 23, 24, 25]]), + ]); + + itertools::assert_equal(a.axis_windows_with_stride(Axis(0), 2, 3), vec![ + arr2(&[[10, 11, 12, 13], [14, 15, 16, 17]]), + arr2(&[[22, 23, 24, 25], [26, 27, 28, 29]]), + ]); +} + +/// Simple test for iterating 3d-arrays via `Axis Windows`. +#[test] +fn test_axis_windows_with_stride_3d() +{ + let a = Array::from_iter(0..27) + .into_shape_with_order((3, 3, 3)) + .unwrap(); + + itertools::assert_equal(a.axis_windows_with_stride(Axis(1), 2, 1), vec![ + arr3(&[ + [[0, 1, 2], [3, 4, 5]], + [[9, 10, 11], [12, 13, 14]], + [[18, 19, 20], [21, 22, 23]], + ]), + arr3(&[ + [[3, 4, 5], [6, 7, 8]], + [[12, 13, 14], [15, 16, 17]], + [[21, 22, 23], [24, 25, 26]], + ]), + ]); + + itertools::assert_equal(a.axis_windows_with_stride(Axis(1), 2, 2), vec![ + arr3(&[ + [[0, 1, 2], [3, 4, 5]], + [[9, 10, 11], [12, 13, 14]], + [[18, 19, 20], [21, 22, 23]], + ]), + ]); +} + +#[test] +fn tests_axis_windows_with_stride_3d_zips_with_1d() +{ + let a = Array::from_iter(0..27) + .into_shape_with_order((3, 3, 3)) + .unwrap(); + let mut b1 = Array::zeros(2); + let mut b2 = Array::zeros(1); + + Zip::from(b1.view_mut()) + .and(a.axis_windows_with_stride(Axis(1), 2, 1)) + .for_each(|b, a| { + *b = a.sum(); + }); + assert_eq!(b1,arr1(&[207, 261])); + + Zip::from(b2.view_mut()) + .and(a.axis_windows_with_stride(Axis(1), 2, 2)) + .for_each(|b, a| { + *b = a.sum(); + }); + assert_eq!(b2,arr1(&[207])); +} + #[test] fn test_window_neg_stride() {