From 31563467172c88bc5bd07a7f83be9707b0d0893e Mon Sep 17 00:00:00 2001 From: Adam Kern Date: Mon, 14 Oct 2024 16:56:08 -0400 Subject: [PATCH] Adds some aliases to avoid breaking changes, and adds an example of how to write functionality with the new types --- .gitignore | 3 + examples/functions_and_traits.rs | 178 ++++++++++++++++++++++ src/{alias_slicing.rs => alias_asref.rs} | 142 +++++++++++++++++- src/impl_2d.rs | 6 +- src/impl_methods.rs | 33 ++-- src/impl_owned_array.rs | 16 +- src/impl_raw_views.rs | 6 +- src/impl_ref_types.rs | 183 ++++++++++++++++++----- src/iterators/chunks.rs | 16 +- src/lib.rs | 2 +- src/linalg/impl_linalg.rs | 4 +- src/zip/mod.rs | 1 - src/zip/ndproducer.rs | 16 +- tests/raw_views.rs | 4 +- tests/test_ref_structure.rs | 39 ----- 15 files changed, 513 insertions(+), 136 deletions(-) create mode 100644 examples/functions_and_traits.rs rename src/{alias_slicing.rs => alias_asref.rs} (56%) delete mode 100644 tests/test_ref_structure.rs diff --git a/.gitignore b/.gitignore index dd9ffd9fe..0745da3b8 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,6 @@ target/ # Editor settings .vscode + +# Apple details +**/.DS_Store diff --git a/examples/functions_and_traits.rs b/examples/functions_and_traits.rs new file mode 100644 index 000000000..dc8f73da4 --- /dev/null +++ b/examples/functions_and_traits.rs @@ -0,0 +1,178 @@ +//! Examples of how to write functions and traits that operate on `ndarray` types. +//! +//! `ndarray` has four kinds of array types that users may interact with: +//! 1. [`ArrayBase`], the owner of the layout that describes an array in memory; +//! this includes [`ndarray::Array`], [`ndarray::ArcArray`], [`ndarray::ArrayView`], +//! [`ndarray::RawArrayView`], and other variants. +//! 2. [`ArrayRef`], which represents a read-safe, uniquely-owned look at an array. +//! 3. [`RawRef`], which represents a read-unsafe, possibly-shared look at an array. +//! 4. [`LayoutRef`], which represents a look at an array's underlying structure, +//! but does not allow data reading of any kind. +//! +//! Below, we illustrate how to write functions and traits for most variants of these types. + +use ndarray::{ArrayBase, ArrayRef, Data, DataMut, Dimension, LayoutRef, RawData, RawDataMut, RawRef}; + +/// Take an array with the most basic requirements. +/// +/// This function takes its data as owning. It is very rare that a user will need to specifically +/// take a reference to an `ArrayBase`, rather than to one of the other four types. +#[rustfmt::skip] +fn takes_base_raw(arr: ArrayBase) -> ArrayBase +{ + // These skip from a possibly-raw array to `RawRef` and `LayoutRef`, and so must go through `AsRef` + takes_rawref(arr.as_ref()); // Caller uses `.as_ref` + takes_rawref_asref(&arr); // Implementor uses `.as_ref` + takes_layout(arr.as_ref()); // Caller uses `.as_ref` + takes_layout_asref(&arr); // Implementor uses `.as_ref` + + arr +} + +/// Similar to above, but allow us to read the underlying data. +#[rustfmt::skip] +fn takes_base_raw_mut(mut arr: ArrayBase) -> ArrayBase +{ + // These skip from a possibly-raw array to `RawRef` and `LayoutRef`, and so must go through `AsMut` + takes_rawref_mut(arr.as_mut()); // Caller uses `.as_mut` + takes_rawref_asmut(&mut arr); // Implementor uses `.as_mut` + takes_layout_mut(arr.as_mut()); // Caller uses `.as_mut` + takes_layout_asmut(&mut arr); // Implementor uses `.as_mut` + + arr +} + +/// Now take an array whose data is safe to read. +#[allow(dead_code)] +fn takes_base(mut arr: ArrayBase) -> ArrayBase +{ + // Raw call + arr = takes_base_raw(arr); + + // No need for AsRef, since data is safe + takes_arrref(&arr); + takes_rawref(&arr); + takes_rawref_asref(&arr); + takes_layout(&arr); + takes_layout_asref(&arr); + + arr +} + +/// Now, an array whose data is safe to read and that we can mutate. +/// +/// Notice that we include now a trait bound on `D: Dimension`; this is necessary in order +/// for the `ArrayBase` to dereference to an `ArrayRef` (or to any of the other types). +#[allow(dead_code)] +fn takes_base_mut(mut arr: ArrayBase) -> ArrayBase +{ + // Raw call + arr = takes_base_raw_mut(arr); + + // No need for AsMut, since data is safe + takes_arrref_mut(&mut arr); + takes_rawref_mut(&mut arr); + takes_rawref_asmut(&mut arr); + takes_layout_mut(&mut arr); + takes_layout_asmut(&mut arr); + + arr +} + +/// Now for new stuff: we want to read (but not alter) any array whose data is safe to read. +/// +/// This is probably the most common functionality that one would want to write. +/// As we'll see below, calling this function is very simple for `ArrayBase`. +fn takes_arrref(arr: &ArrayRef) +{ + // No need for AsRef, since data is safe + takes_rawref(arr); + takes_rawref_asref(arr); + takes_layout(arr); + takes_layout_asref(arr); +} + +/// Now we want any array whose data is safe to mutate. +/// +/// **Importantly**, any array passed to this function is guaranteed to uniquely point to its data. +/// As a result, passing a shared array to this function will **silently** un-share the array. +#[allow(dead_code)] +fn takes_arrref_mut(arr: &mut ArrayRef) +{ + // Immutable call + takes_arrref(arr); + + // No need for AsMut, since data is safe + takes_rawref_mut(arr); + takes_rawref_asmut(arr); + takes_layout_mut(arr); + takes_rawref_asmut(arr); +} + +/// Now, we no longer care about whether we can safely read data. +/// +/// This is probably the rarest type to deal with, since `LayoutRef` can access and modify an array's +/// shape and strides, and even do in-place slicing. As a result, `RawRef` is only for functionality +/// that requires unsafe data access, something that `LayoutRef` can't do. +/// +/// Writing functions and traits that deal with `RawRef`s and `LayoutRef`s can be done two ways: +/// 1. Directly on the types; calling these functions on arrays whose data are not known to be safe +/// to dereference (i.e., raw array views or `ArrayBase`) must explicitly call `.as_ref()`. +/// 2. Via a generic with `: AsRef>`; doing this will allow direct calling for all `ArrayBase` and +/// `ArrayRef` instances. +/// We'll demonstrate #1 here for both immutable and mutable references, then #2 directly below. +#[allow(dead_code)] +fn takes_rawref(arr: &RawRef) +{ + takes_layout(arr); + takes_layout_asref(arr); +} + +/// Mutable, directly take `RawRef` +#[allow(dead_code)] +fn takes_rawref_mut(arr: &mut RawRef) +{ + takes_layout(arr); + takes_layout_asmut(arr); +} + +/// Immutable, take a generic that implements `AsRef` to `RawRef` +#[allow(dead_code)] +fn takes_rawref_asref(_arr: &T) +where T: AsRef> +{ + takes_layout(_arr.as_ref()); + takes_layout_asref(_arr.as_ref()); +} + +/// Mutable, take a generic that implements `AsMut` to `RawRef` +#[allow(dead_code)] +fn takes_rawref_asmut(_arr: &mut T) +where T: AsMut> +{ + takes_layout_mut(_arr.as_mut()); + takes_layout_asmut(_arr.as_mut()); +} + +/// Finally, there's `LayoutRef`: this type provides read and write access to an array's *structure*, but not its *data*. +/// +/// Practically, this means that functions that only read/modify an array's shape or strides, +/// such as checking dimensionality or slicing, should take `LayoutRef`. +/// +/// Like `RawRef`, functions can be written either directly on `LayoutRef` or as generics with `: AsRef>>`. +#[allow(dead_code)] +fn takes_layout(_arr: &LayoutRef) {} + +/// Mutable, directly take `LayoutRef` +#[allow(dead_code)] +fn takes_layout_mut(_arr: &mut LayoutRef) {} + +/// Immutable, take a generic that implements `AsRef` to `LayoutRef` +#[allow(dead_code)] +fn takes_layout_asref>, A, D>(_arr: &T) {} + +/// Mutable, take a generic that implements `AsMut` to `LayoutRef` +#[allow(dead_code)] +fn takes_layout_asmut>, A, D>(_arr: &mut T) {} + +fn main() {} diff --git a/src/alias_slicing.rs b/src/alias_asref.rs similarity index 56% rename from src/alias_slicing.rs rename to src/alias_asref.rs index 7c9b3d26b..60435177c 100644 --- a/src/alias_slicing.rs +++ b/src/alias_asref.rs @@ -1,4 +1,16 @@ -use crate::{iter::Axes, ArrayBase, Axis, AxisDescription, Dimension, LayoutRef, RawData, Slice, SliceArg}; +use crate::{ + iter::Axes, + ArrayBase, + Axis, + AxisDescription, + Dimension, + LayoutRef, + RawArrayView, + RawData, + RawRef, + Slice, + SliceArg, +}; impl ArrayBase { @@ -69,19 +81,19 @@ impl ArrayBase /// contiguous in memory, it has custom strides, etc. pub fn is_standard_layout(&self) -> bool { - self.as_ref().is_standard_layout() + >>::as_ref(self).is_standard_layout() } /// Return true if the array is known to be contiguous. pub(crate) fn is_contiguous(&self) -> bool { - self.as_ref().is_contiguous() + >>::as_ref(self).is_contiguous() } /// Return an iterator over the length and stride of each axis. pub fn axes(&self) -> Axes<'_, D> { - self.as_ref().axes() + >>::as_ref(self).axes() } /* @@ -170,9 +182,129 @@ impl ArrayBase self.as_mut().merge_axes(take, into) } + /// Return a raw view of the array. + #[inline] + pub fn raw_view(&self) -> RawArrayView + { + >>::as_ref(self).raw_view() + } + + /// Return a pointer to the first element in the array. + /// + /// Raw access to array elements needs to follow the strided indexing + /// scheme: an element at multi-index *I* in an array with strides *S* is + /// located at offset + /// + /// *Σ0 ≤ k < d Ik × Sk* + /// + /// where *d* is `self.ndim()`. + #[inline(always)] + pub fn as_ptr(&self) -> *const S::Elem + { + >>::as_ref(self).as_ptr() + } + + /// Return the total number of elements in the array. + pub fn len(&self) -> usize + { + >>::as_ref(self).len() + } + + /// Return the length of `axis`. + /// + /// The axis should be in the range `Axis(` 0 .. *n* `)` where *n* is the + /// number of dimensions (axes) of the array. + /// + /// ***Panics*** if the axis is out of bounds. + #[track_caller] + pub fn len_of(&self, axis: Axis) -> usize + { + >>::as_ref(self).len_of(axis) + } + + /// Return whether the array has any elements + pub fn is_empty(&self) -> bool + { + >>::as_ref(self).is_empty() + } + + /// Return the number of dimensions (axes) in the array + pub fn ndim(&self) -> usize + { + >>::as_ref(self).ndim() + } + + /// Return the shape of the array in its “pattern” form, + /// an integer in the one-dimensional case, tuple in the n-dimensional cases + /// and so on. + pub fn dim(&self) -> D::Pattern + { + >>::as_ref(self).dim() + } + + /// Return the shape of the array as it's stored in the array. + /// + /// This is primarily useful for passing to other `ArrayBase` + /// functions, such as when creating another array of the same + /// shape and dimensionality. + /// + /// ``` + /// use ndarray::Array; + /// + /// let a = Array::from_elem((2, 3), 5.); + /// + /// // Create an array of zeros that's the same shape and dimensionality as `a`. + /// let b = Array::::zeros(a.raw_dim()); + /// ``` + pub fn raw_dim(&self) -> D + { + >>::as_ref(self).raw_dim() + } + + /// Return the shape of the array as a slice. + /// + /// Note that you probably don't want to use this to create an array of the + /// same shape as another array because creating an array with e.g. + /// [`Array::zeros()`](ArrayBase::zeros) using a shape of type `&[usize]` + /// results in a dynamic-dimensional array. If you want to create an array + /// that has the same shape and dimensionality as another array, use + /// [`.raw_dim()`](ArrayBase::raw_dim) instead: + /// + /// ```rust + /// use ndarray::{Array, Array2}; + /// + /// let a = Array2::::zeros((3, 4)); + /// let shape = a.shape(); + /// assert_eq!(shape, &[3, 4]); + /// + /// // Since `a.shape()` returned `&[usize]`, we get an `ArrayD` instance: + /// let b = Array::zeros(shape); + /// assert_eq!(a.clone().into_dyn(), b); + /// + /// // To get the same dimension type, use `.raw_dim()` instead: + /// let c = Array::zeros(a.raw_dim()); + /// assert_eq!(a, c); + /// ``` + pub fn shape(&self) -> &[usize] + { + >>::as_ref(self).shape() + } + /// Return the strides of the array as a slice. pub fn strides(&self) -> &[isize] { - self.as_ref().strides() + >>::as_ref(self).strides() + } + + /// Return the stride of `axis`. + /// + /// The axis should be in the range `Axis(` 0 .. *n* `)` where *n* is the + /// number of dimensions (axes) of the array. + /// + /// ***Panics*** if the axis is out of bounds. + #[track_caller] + pub fn stride_of(&self, axis: Axis) -> isize + { + >>::as_ref(self).stride_of(axis) } } diff --git a/src/impl_2d.rs b/src/impl_2d.rs index 061bf45fd..27358dca9 100644 --- a/src/impl_2d.rs +++ b/src/impl_2d.rs @@ -65,7 +65,7 @@ impl LayoutRef /// ``` pub fn nrows(&self) -> usize { - self.as_ref().len_of(Axis(0)) + self.len_of(Axis(0)) } } @@ -124,7 +124,7 @@ impl LayoutRef /// ``` pub fn ncols(&self) -> usize { - self.as_ref().len_of(Axis(1)) + self.len_of(Axis(1)) } /// Return true if the array is square, false otherwise. @@ -144,7 +144,7 @@ impl LayoutRef /// ``` pub fn is_square(&self) -> bool { - let (m, n) = self.as_ref().dim(); + let (m, n) = self.dim(); m == n } } diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 2c87c072e..4bf86e96a 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -578,7 +578,7 @@ where { assert_eq!( info.in_ndim(), - self.as_ref().ndim(), + self.ndim(), "The input dimension of `info` must match the array to be sliced.", ); let out_ndim = info.out_ndim(); @@ -599,7 +599,7 @@ where } SliceInfoElem::Index(index) => { // Collapse the axis in-place to update the `ptr`. - let i_usize = abs_index(self.as_ref().len_of(Axis(old_axis)), index); + let i_usize = abs_index(self.len_of(Axis(old_axis)), index); self.collapse_axis(Axis(old_axis), i_usize); // Skip copying the axis since it should be removed. Note that // removing this axis is safe because `.collapse_axis()` panics @@ -614,7 +614,7 @@ where new_axis += 1; } }); - debug_assert_eq!(old_axis, self.as_ref().ndim()); + debug_assert_eq!(old_axis, self.ndim()); debug_assert_eq!(new_axis, out_ndim); // safe because new dimension, strides allow access to a subset of old data @@ -1568,7 +1568,7 @@ where { /* empty shape has len 1 */ let len = self.layout.dim.slice().iter().cloned().min().unwrap_or(1); - let stride = LayoutRef::strides(self.as_ref()).iter().sum(); + let stride = self.strides().iter().sum(); (len, stride) } @@ -2039,12 +2039,7 @@ where match order { Order::RowMajor if self.is_standard_layout() => Ok(self.with_strides_dim(shape.default_strides(), shape)), - Order::ColumnMajor - if self - .as_ref() - .raw_view() - .reversed_axes() - .is_standard_layout() => + Order::ColumnMajor if self.raw_view().reversed_axes().is_standard_layout() => Ok(self.with_strides_dim(shape.fortran_strides(), shape)), _otherwise => Err(error::from_kind(error::ErrorKind::IncompatibleLayout)), } @@ -2087,13 +2082,7 @@ where // safe because arrays are contiguous and len is unchanged if self.is_standard_layout() { Ok(self.with_strides_dim(shape.default_strides(), shape)) - } else if self.as_ref().ndim() > 1 - && self - .as_ref() - .raw_view() - .reversed_axes() - .is_standard_layout() - { + } else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() { Ok(self.with_strides_dim(shape.fortran_strides(), shape)) } else { Err(error::from_kind(error::ErrorKind::IncompatibleLayout)) @@ -2530,7 +2519,7 @@ where { let axes = axes.into_dimension(); // Ensure that each axis is used exactly once. - let mut usage_counts = D::zeros(self.as_ref().ndim()); + let mut usage_counts = D::zeros(self.ndim()); for axis in axes.slice() { usage_counts[*axis] += 1; } @@ -2539,7 +2528,7 @@ where } // Determine the new shape and strides. let mut new_dim = usage_counts; // reuse to avoid an allocation - let mut new_strides = D::zeros(self.as_ref().ndim()); + let mut new_strides = D::zeros(self.ndim()); { let dim = self.layout.dim.slice(); let strides = self.layout.strides.slice(); @@ -2686,7 +2675,7 @@ where #[track_caller] pub fn insert_axis(self, axis: Axis) -> ArrayBase { - assert!(axis.index() <= self.as_ref().ndim()); + assert!(axis.index() <= self.ndim()); // safe because a new axis of length one does not affect memory layout unsafe { let strides = self.layout.strides.insert_axis(axis); @@ -2710,7 +2699,7 @@ where pub(crate) fn pointer_is_inbounds(&self) -> bool { - self.data._is_pointer_inbounds(self.as_ref().as_ptr()) + self.data._is_pointer_inbounds(self.as_ptr()) } } @@ -3172,7 +3161,7 @@ impl ArrayRef return; } let mut curr = self.raw_view_mut(); // mut borrow of the array here - let mut prev = curr.as_ref().raw_view(); // derive further raw views from the same borrow + let mut prev = curr.raw_view(); // derive further raw views from the same borrow prev.slice_axis_inplace(axis, Slice::from(..-1)); curr.slice_axis_inplace(axis, Slice::from(1..)); // This implementation relies on `Zip` iterating along `axis` in order. diff --git a/src/impl_owned_array.rs b/src/impl_owned_array.rs index 3b0ef02be..dc79ecda0 100644 --- a/src/impl_owned_array.rs +++ b/src/impl_owned_array.rs @@ -743,11 +743,11 @@ where D: Dimension let tail_ptr = self.data.as_end_nonnull(); let mut tail_view = RawArrayViewMut::new(tail_ptr, array_dim, tail_strides); - if tail_view.as_ref().ndim() > 1 { + if tail_view.ndim() > 1 { sort_axes_in_default_order_tandem(&mut tail_view, &mut array); debug_assert!(tail_view.is_standard_layout(), "not std layout dim: {:?}, strides: {:?}", - tail_view.as_ref().shape(), LayoutRef::strides(tail_view.as_ref())); + tail_view.shape(), LayoutRef::strides(tail_view.as_ref())); } // Keep track of currently filled length of `self.data` and update it @@ -872,10 +872,10 @@ pub(crate) unsafe fn drop_unreachable_raw( mut self_: RawArrayViewMut, data_ptr: NonNull, data_len: usize, ) where D: Dimension { - let self_len = self_.as_ref().len(); + let self_len = self_.len(); - for i in 0..self_.as_ref().ndim() { - if self_.as_ref().stride_of(Axis(i)) < 0 { + for i in 0..self_.ndim() { + if self_.stride_of(Axis(i)) < 0 { self_.invert_axis(Axis(i)); } } @@ -898,7 +898,7 @@ pub(crate) unsafe fn drop_unreachable_raw( // As an optimization, the innermost axis is removed if it has stride 1, because // we then have a long stretch of contiguous elements we can skip as one. let inner_lane_len; - if self_.as_ref().ndim() > 1 && self_.layout.strides.last_elem() == 1 { + if self_.ndim() > 1 && self_.layout.strides.last_elem() == 1 { self_.layout.dim.slice_mut().rotate_right(1); self_.layout.strides.slice_mut().rotate_right(1); inner_lane_len = self_.layout.dim[0]; @@ -946,7 +946,7 @@ where S: RawData, D: Dimension, { - if a.as_ref().ndim() <= 1 { + if a.ndim() <= 1 { return; } sort_axes1_impl(&mut a.layout.dim, &mut a.layout.strides); @@ -986,7 +986,7 @@ where S2: RawData, D: Dimension, { - if a.as_ref().ndim() <= 1 { + if a.ndim() <= 1 { return; } sort_axes2_impl(&mut a.layout.dim, &mut a.layout.strides, &mut b.layout.dim, &mut b.layout.strides); diff --git a/src/impl_raw_views.rs b/src/impl_raw_views.rs index 049bdc536..5bb2a0e42 100644 --- a/src/impl_raw_views.rs +++ b/src/impl_raw_views.rs @@ -112,9 +112,9 @@ where D: Dimension #[inline] pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) { - assert!(index <= self.as_ref().len_of(axis)); + assert!(index <= self.len_of(axis)); let left_ptr = self.layout.ptr.as_ptr(); - let right_ptr = if index == self.as_ref().len_of(axis) { + let right_ptr = if index == self.len_of(axis) { self.layout.ptr.as_ptr() } else { let offset = stride_offset(index, self.layout.strides.axis(axis)); @@ -186,7 +186,7 @@ where D: Dimension } let ptr_re: *mut T = self.layout.ptr.as_ptr().cast(); - let ptr_im: *mut T = if self.as_ref().is_empty() { + let ptr_im: *mut T = if self.is_empty() { // In the empty case, we can just reuse the existing pointer since // it won't be dereferenced anyway. It is not safe to offset by // one, since the allocation may be empty. diff --git a/src/impl_ref_types.rs b/src/impl_ref_types.rs index 48cb027fc..92e916887 100644 --- a/src/impl_ref_types.rs +++ b/src/impl_ref_types.rs @@ -1,9 +1,39 @@ -//! Code for the array reference type +//! Implementations that connect arrays to their reference types. +//! +//! `ndarray` has four kinds of array types that users may interact with: +//! 1. [`ArrayBase`], which represents arrays which own their layout (shape and strides) +//! 2. [`ArrayRef`], which represents a read-safe, uniquely-owned look at an array +//! 3. [`RawRef`], which represents a read-unsafe, possibly-shared look at an array +//! 4. [`LayoutRef`], which represents a look at an array's underlying structure, +//! but does not allow data reading of any kind +//! +//! These types are connected through a number of `Deref` and `AsRef` implementations. +//! 1. `ArrayBase` dereferences to `ArrayRef` when `S: Data` +//! 2. `ArrayBase` mutably dereferences to `ArrayRef` when `S: DataMut`, and ensures uniqueness +//! 3. `ArrayRef` mutably dereferences to `RawRef` +//! 4. `RawRef` mutably dereferences to `LayoutRef` +//! This chain works very well for arrays whose data is safe to read and is uniquely held. +//! Because raw views do not meet `S: Data`, they cannot dereference to `ArrayRef`; furthermore, +//! technical limitations of Rust's compiler means that `ArrayBase` cannot have multiple `Deref` implementations. +//! In addition, shared-data arrays do not want to go down the `Deref` path to get to methods on `RawRef` +//! or `LayoutRef`, since that would unecessarily ensure their uniqueness. +//! +//! To mitigate these problems, `ndarray` also provides `AsRef` and `AsMut` implementations as follows: +//! 1. `ArrayBase` implements `AsRef` to `RawRef` and `LayoutRef` when `S: RawData` +//! 2. `ArrayBase` implements `AsMut` to `RawRef` when `S: RawDataMut` +//! 3. `ArrayBase` implements `AsMut` to `LayoutRef` unconditionally +//! 4. `ArrayRef` implements `AsMut` to `RawRef` and `LayoutRef` unconditionally +//! 5. `RawRef` implements `AsMut` to `LayoutRef` +//! 6. `RawRef` and `LayoutRef` implement `AsMut` to themselves +//! +//! This allows users to write a single method or trait implementation that takes `T: AsRef>` +//! or `T: AsRef>` and have that functionality work on any of the relevant array types. use core::ops::{Deref, DerefMut}; -use crate::{ArrayBase, ArrayRef, Data, DataMut, Dimension, LayoutRef, RawData, RawRef}; +use crate::{ArrayBase, ArrayRef, Data, DataMut, Dimension, LayoutRef, RawData, RawDataMut, RawRef}; +// D1: &ArrayBase -> &ArrayRef when data is safe to read impl Deref for ArrayBase where S: Data { @@ -24,6 +54,7 @@ where S: Data } } +// D2: &mut ArrayBase -> &mut ArrayRef when data is safe to read; ensure uniqueness impl DerefMut for ArrayBase where S: DataMut, @@ -45,13 +76,15 @@ where } } -impl AsRef> for ArrayBase -where S: RawData +// D3: &ArrayRef -> &RawRef +impl Deref for ArrayRef { - fn as_ref(&self) -> &RawRef + type Target = RawRef; + + fn deref(&self) -> &Self::Target { unsafe { - (&self.layout as *const LayoutRef) + (self as *const ArrayRef) .cast::>() .as_ref() } @@ -59,13 +92,13 @@ where S: RawData } } -impl AsMut> for ArrayBase -where S: RawData +// D4: &mut ArrayRef -> &mut RawRef +impl DerefMut for ArrayRef { - fn as_mut(&mut self) -> &mut RawRef + fn deref_mut(&mut self) -> &mut Self::Target { unsafe { - (&mut self.layout as *mut LayoutRef) + (self as *mut ArrayRef) .cast::>() .as_mut() } @@ -73,30 +106,34 @@ where S: RawData } } -impl AsRef> for RawRef +// D5: &RawRef -> &LayoutRef +impl Deref for RawRef { - fn as_ref(&self) -> &RawRef + type Target = LayoutRef; + + fn deref(&self) -> &Self::Target { - self + &self.0 } } -impl AsMut> for RawRef +// D5: &mut RawRef -> &mut LayoutRef +impl DerefMut for RawRef { - fn as_mut(&mut self) -> &mut RawRef + fn deref_mut(&mut self) -> &mut Self::Target { - self + &mut self.0 } } -impl Deref for ArrayRef +// A1: &ArrayBase -AR-> &RawRef +impl AsRef> for ArrayBase +where S: RawData { - type Target = RawRef; - - fn deref(&self) -> &Self::Target + fn as_ref(&self) -> &RawRef { unsafe { - (self as *const ArrayRef) + (&self.layout as *const LayoutRef) .cast::>() .as_ref() } @@ -104,12 +141,14 @@ impl Deref for ArrayRef } } -impl DerefMut for ArrayRef +// A2: &mut ArrayBase -AM-> &mut RawRef +impl AsMut> for ArrayBase +where S: RawDataMut { - fn deref_mut(&mut self) -> &mut Self::Target + fn as_mut(&mut self) -> &mut RawRef { unsafe { - (self as *mut ArrayRef) + (&mut self.layout as *mut LayoutRef) .cast::>() .as_mut() } @@ -117,37 +156,113 @@ impl DerefMut for ArrayRef } } -impl AsRef> for LayoutRef +// A3: &ArrayBase -AR-> &LayoutRef +impl AsRef> for ArrayBase +where S: RawData { fn as_ref(&self) -> &LayoutRef { - self + &self.layout } } -impl AsMut> for LayoutRef +// A3: &mut ArrayBase -AM-> &mut LayoutRef +impl AsMut> for ArrayBase +where S: RawData { fn as_mut(&mut self) -> &mut LayoutRef + { + &mut self.layout + } +} + +// A4: &ArrayRef -AR-> &RawRef +impl AsRef> for ArrayRef +{ + fn as_ref(&self) -> &RawRef + { + &**self + } +} + +// A4: &mut ArrayRef -AM-> &mut RawRef +impl AsMut> for ArrayRef +{ + fn as_mut(&mut self) -> &mut RawRef + { + &mut **self + } +} + +// A4: &ArrayRef -AR-> &LayoutRef +impl AsRef> for ArrayRef +{ + fn as_ref(&self) -> &LayoutRef + { + &***self + } +} + +// A4: &mut ArrayRef -AM-> &mut LayoutRef +impl AsMut> for ArrayRef +{ + fn as_mut(&mut self) -> &mut LayoutRef + { + &mut ***self + } +} + +// A5: &RawRef -AR-> &LayoutRef +impl AsRef> for RawRef +{ + fn as_ref(&self) -> &LayoutRef + { + &**self + } +} + +// A5: &mut RawRef -AM-> &mut LayoutRef +impl AsMut> for RawRef +{ + fn as_mut(&mut self) -> &mut LayoutRef + { + &mut **self + } +} + +// A6: &RawRef -AR-> &RawRef +impl AsRef> for RawRef +{ + fn as_ref(&self) -> &RawRef { self } } -impl Deref for RawRef +// A6: &mut RawRef -AM-> &mut RawRef +impl AsMut> for RawRef { - type Target = LayoutRef; + fn as_mut(&mut self) -> &mut RawRef + { + self + } +} - fn deref(&self) -> &Self::Target +// A6: &LayoutRef -AR-> &LayoutRef +impl AsRef> for LayoutRef +{ + fn as_ref(&self) -> &LayoutRef { - &self.0 + self } } -impl DerefMut for RawRef +// A6: &mut LayoutRef -AR-> &mut LayoutRef +impl AsMut> for LayoutRef { - fn deref_mut(&mut self) -> &mut Self::Target + fn as_mut(&mut self) -> &mut LayoutRef { - &mut self.0 + self } } diff --git a/src/iterators/chunks.rs b/src/iterators/chunks.rs index 909377d5e..be6a763a7 100644 --- a/src/iterators/chunks.rs +++ b/src/iterators/chunks.rs @@ -49,16 +49,16 @@ impl<'a, A, D: Dimension> ExactChunks<'a, A, D> let mut a = a.into_raw_view(); let chunk = chunk.into_dimension(); ndassert!( - AsRef::as_ref(&a).ndim() == chunk.ndim(), + a.ndim() == chunk.ndim(), concat!( "Chunk dimension {} does not match array dimension {} ", "(with array of shape {:?})" ), chunk.ndim(), - AsRef::as_ref(&a).ndim(), - AsRef::as_ref(&a).shape() + a.ndim(), + a.shape() ); - for i in 0..AsRef::as_ref(&a).ndim() { + for i in 0..a.ndim() { a.layout.dim[i] /= chunk[i]; } let inner_strides = a.layout.strides.clone(); @@ -148,16 +148,16 @@ impl<'a, A, D: Dimension> ExactChunksMut<'a, A, D> let mut a = a.into_raw_view_mut(); let chunk = chunk.into_dimension(); ndassert!( - AsRef::as_ref(&a).ndim() == chunk.ndim(), + a.ndim() == chunk.ndim(), concat!( "Chunk dimension {} does not match array dimension {} ", "(with array of shape {:?})" ), chunk.ndim(), - AsRef::as_ref(&a).ndim(), - AsRef::as_ref(&a).shape() + a.ndim(), + a.shape() ); - for i in 0..AsRef::as_ref(&a).ndim() { + for i in 0..a.ndim() { a.layout.dim[i] /= chunk[i]; } let inner_strides = a.layout.strides.clone(); diff --git a/src/lib.rs b/src/lib.rs index 7183c096f..bbf88bf2c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1557,7 +1557,7 @@ mod impl_internal_constructors; mod impl_constructors; mod impl_methods; -mod alias_slicing; +mod alias_asref; mod impl_owned_array; mod impl_special_element_types; diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 32112c896..04584ed0f 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -668,7 +668,7 @@ unsafe fn general_mat_vec_mul_impl( ) where A: LinalgScalar { let ((m, k), k2) = (a.dim(), x.dim()); - let m2 = y.as_ref().dim(); + let m2 = y.dim(); if k != k2 || m != m2 { general_dot_shape_error(m, k, k2, 1, m2, 1); } else { @@ -790,7 +790,7 @@ fn complex_array(z: Complex) -> [A; 2] } #[cfg(feature = "blas")] -fn blas_compat_1d(a: &LayoutRef) -> bool +fn blas_compat_1d(a: &RawRef) -> bool where A: 'static, B: 'static, diff --git a/src/zip/mod.rs b/src/zip/mod.rs index df043b3ec..640a74d1b 100644 --- a/src/zip/mod.rs +++ b/src/zip/mod.rs @@ -18,7 +18,6 @@ use crate::partial::Partial; use crate::AssignElem; use crate::IntoDimension; use crate::Layout; -use crate::LayoutRef; use crate::dimension; use crate::indexes::{indices, Indices}; diff --git a/src/zip/ndproducer.rs b/src/zip/ndproducer.rs index 91fb8602a..82f3f43a7 100644 --- a/src/zip/ndproducer.rs +++ b/src/zip/ndproducer.rs @@ -380,7 +380,7 @@ impl NdProducer for RawArrayView fn raw_dim(&self) -> Self::Dim { - AsRef::as_ref(self).raw_dim() + self.raw_dim() } fn equal_dim(&self, dim: &Self::Dim) -> bool @@ -390,12 +390,12 @@ impl NdProducer for RawArrayView fn as_ptr(&self) -> *const A { - AsRef::as_ref(self).as_ptr() as _ + self.as_ptr() as _ } fn layout(&self) -> Layout { - AsRef::as_ref(self).layout_impl() + AsRef::>::as_ref(self).layout_impl() } unsafe fn as_ref(&self, ptr: *const A) -> *const A @@ -413,7 +413,7 @@ impl NdProducer for RawArrayView fn stride_of(&self, axis: Axis) -> isize { - AsRef::as_ref(self).stride_of(axis) + self.stride_of(axis) } #[inline(always)] @@ -439,7 +439,7 @@ impl NdProducer for RawArrayViewMut fn raw_dim(&self) -> Self::Dim { - AsRef::as_ref(self).raw_dim() + self.raw_dim() } fn equal_dim(&self, dim: &Self::Dim) -> bool @@ -449,12 +449,12 @@ impl NdProducer for RawArrayViewMut fn as_ptr(&self) -> *mut A { - AsRef::as_ref(self).as_ptr() as _ + self.as_ptr() as _ } fn layout(&self) -> Layout { - AsRef::as_ref(self).layout_impl() + AsRef::>::as_ref(self).layout_impl() } unsafe fn as_ref(&self, ptr: *mut A) -> *mut A @@ -472,7 +472,7 @@ impl NdProducer for RawArrayViewMut fn stride_of(&self, axis: Axis) -> isize { - AsRef::as_ref(self).stride_of(axis) + self.stride_of(axis) } #[inline(always)] diff --git a/tests/raw_views.rs b/tests/raw_views.rs index be20aff52..929e969d7 100644 --- a/tests/raw_views.rs +++ b/tests/raw_views.rs @@ -39,8 +39,8 @@ fn raw_view_cast_zst() let a = Array::<(), _>::default((250, 250)); let b: RawArrayView = a.raw_view().cast::(); - assert_eq!(a.shape(), b.as_ref().shape()); - assert_eq!(a.as_ptr() as *const u8, b.as_ref().as_ptr() as *const u8); + assert_eq!(a.shape(), b.shape()); + assert_eq!(a.as_ptr() as *const u8, b.as_ptr() as *const u8); } #[test] diff --git a/tests/test_ref_structure.rs b/tests/test_ref_structure.rs deleted file mode 100644 index 6778097a9..000000000 --- a/tests/test_ref_structure.rs +++ /dev/null @@ -1,39 +0,0 @@ -use ndarray::{array, ArrayBase, ArrayRef, Data, LayoutRef, RawData, RawRef}; - -fn takes_base_raw(arr: &ArrayBase) -{ - takes_rawref(arr.as_ref()); // Doesn't work - takes_layout(arr.as_ref()); -} - -#[allow(dead_code)] -fn takes_base(arr: &ArrayBase) -{ - takes_base_raw(arr); - takes_arrref(arr); // Doesn't work - takes_rawref(arr); // Doesn't work - takes_layout(arr); -} - -fn takes_arrref(_arr: &ArrayRef) -{ - takes_rawref(_arr); - takes_layout(_arr); -} - -fn takes_rawref(_arr: &RawRef) -{ - takes_layout(_arr); -} - -fn takes_layout(_arr: &LayoutRef) {} - -#[test] -fn tester() -{ - let arr = array![1, 2, 3]; - takes_base_raw(&arr); - takes_arrref(&arr); - takes_rawref(&arr); // Doesn't work - takes_layout(&arr); -}