diff --git a/Cargo.lock b/Cargo.lock
index 0fb3a08317..c4227fa262 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -523,6 +523,7 @@ dependencies = [
  "burn-dataset",
  "burn-derive",
  "burn-ndarray",
+ "burn-sparse",
  "burn-tch",
  "burn-tensor",
  "burn-wgpu",
@@ -687,6 +688,24 @@ dependencies = [
  "serde",
 ]
 
+[[package]]
+name = "burn-sparse"
+version = "0.14.0"
+dependencies = [
+ "burn-common",
+ "burn-tensor",
+ "derive-new",
+ "half",
+ "hashbrown 0.14.5",
+ "num-traits",
+ "proc-macro2",
+ "quote",
+ "rand",
+ "rand_distr",
+ "serde",
+ "syn 2.0.72",
+]
+
 [[package]]
 name = "burn-tch"
 version = "0.14.0"
diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml
index a13700084a..982fcb75fe 100644
--- a/crates/burn-core/Cargo.toml
+++ b/crates/burn-core/Cargo.toml
@@ -71,6 +71,7 @@ vision = ["burn-dataset?/vision", "burn-common/network"]
 # Backend
 autodiff = ["burn-autodiff"]
 fusion = ["burn-wgpu?/fusion"]
+sparse = ["burn-sparse"]
 
 ## Backend features
 metal = ["burn-candle?/metal"]
@@ -116,6 +117,7 @@ burn-cuda = { path = "../burn-cuda", version = "0.14.0", optional = true, defaul
 burn-autodiff = { path = "../burn-autodiff", version = "0.14.0", optional = true }
 burn-tch = { path = "../burn-tch", version = "0.14.0", optional = true }
 burn-candle = { path = "../burn-candle", version = "0.14.0", optional = true }
+burn-sparse = { path = "../burn-sparse", version = "0.14.0", optional = true }
 
 derive-new = { workspace = true }
 log = { workspace = true, optional = true }
diff --git a/crates/burn-core/src/backend.rs b/crates/burn-core/src/backend.rs
index 2203e1d282..bed5a09ddc 100644
--- a/crates/burn-core/src/backend.rs
+++ b/crates/burn-core/src/backend.rs
@@ -33,3 +33,6 @@ pub use burn_tch as libtorch;
 
 #[cfg(feature = "tch")]
 pub use burn_tch::LibTorch;
+
+#[cfg(feature = "sparse")]
+pub use burn_sparse as sparse;
diff --git a/crates/burn-sparse/Cargo.toml b/crates/burn-sparse/Cargo.toml
new file mode 100644
index 0000000000..4b60b96254
--- /dev/null
+++ b/crates/burn-sparse/Cargo.toml
@@ -0,0 +1,43 @@
+[package]
+authors = []
+categories = ["science", "no-std", "embedded", "wasm"]
+description = "Sparse tensor crate that offers a default sparse backend wrapper around burn backends."
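+# Editorial note (not part of the original diff): given the burn-core changes above,
+# a downstream crate would presumably opt in through the new cargo feature, e.g.
+#
+#     burn-core = { version = "0.14.0", features = ["sparse"] }
+#
+# and reach this crate through the `backend::sparse` re-export added in backend.rs.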
+edition.workspace = true
+keywords = ["deep-learning", "machine-learning", "tensor", "sparse"]
+license.workspace = true
+name = "burn-sparse"
+readme.workspace = true
+repository = "https://github.com/tracel-ai/burn/tree/main/burn-sparse"
+version.workspace = true
+
+[features]
+default = ["std"]
+doc = ["default"]
+experimental-named-tensor = []
+std = ["rand/std", "half/std", "num-traits/std"]
+wasm-sync = []
+
+[dependencies]
+burn-common = { path = "../burn-common", version = "0.14.0", default-features = false }
+burn-tensor = { path = "../burn-tensor", version = "0.14.0" }
+
+proc-macro2 = { workspace = true }
+quote = { workspace = true }
+syn = { workspace = true }
+derive-new = { workspace = true }
+half = { workspace = true }
+num-traits = { workspace = true }
+rand = { workspace = true }
+rand_distr = { workspace = true } # use instead of statrs because it supports no_std
+
+# The same implementation of HashMap in std but with no_std support (only needs alloc crate)
+hashbrown = { workspace = true } # no_std compatible
+
+# Serialization
+serde = { workspace = true }
+
+[dev-dependencies]
+rand = { workspace = true, features = ["std", "std_rng"] } # Default enables std
+
+[package.metadata.docs.rs]
+features = ["doc"]
diff --git a/crates/burn-sparse/src/coo.rs b/crates/burn-sparse/src/coo.rs
new file mode 100644
index 0000000000..b195a323b7
--- /dev/null
+++ b/crates/burn-sparse/src/coo.rs
@@ -0,0 +1,76 @@
+use burn_tensor::backend::Backend;
+use burn_tensor::ops::SparseBoolOps;
+use burn_tensor::ops::SparseTensorOps;
+use burn_tensor::Dense;
+use burn_tensor::Device;
+use burn_tensor::Float;
+use burn_tensor::Int;
+use burn_tensor::Shape;
+use burn_tensor::Sparse;
+use burn_tensor::SparseStorage;
+use burn_tensor::Tensor;
+use burn_tensor::TensorData;
+use burn_tensor::TensorKind;
+
+#[derive(Clone, Debug)]
+pub struct COO;
+
+#[derive(Clone, Debug)]
+pub struct SparseCOOTensor<B: Backend, K: TensorKind<B>, const D: usize> {
+    pub coordinates: Option<Tensor<B, 2, Int>>,
+    pub values: Option<Tensor<B, 1, K>>,
+    pub shape: Shape<D>,
+    pub device: Device<B>,
+}
+
+impl<B: Backend> SparseStorage<B> for COO {
+    type SparsePrimitive<K: TensorKind<B>, const D: usize> = SparseCOOTensor<B, K, D>;
+
+    fn name() -> &'static str {
+        "SparseCOO"
+    }
+}
+
+impl<B: Backend> SparseTensorOps<COO, B> for COO {}
+
+pub(crate) fn flatten_coordinates<B: Backend, const D: usize, const S: usize>(
+    coordinates: Tensor<B, 2, Int>,
+    shape: Shape<D>,
+    device: &Device<B>,
+) -> Tensor<B, 2, Int> {
+    let mut strides_data = [[1]; D];
+    for i in (0..D).rev() {
+        if D - 1 - i == S {
+            strides_data[i] = [1];
+        } else if D - 1 - i < S {
+            strides_data[i] = [0];
+        } else {
+            strides_data[i] = [strides_data[i + 1][0] * shape.dims[i + 1] as i64];
+        }
+    }
+    let strides_data: TensorData = TensorData::from(strides_data);
+    let strides: Tensor<B, 2, Int> = Tensor::from_data(strides_data, device);
+    let flat_coordinates: Tensor<B, 1, Int> = strides.mul(coordinates).sum_dim(0).flatten(0, 1);
+
+    flat_coordinates.unsqueeze_dim(0)
+}
+
+pub(crate) fn unflatten_coordinates<B: Backend, const D: usize>(
+    flat_coordinates: Tensor<B, 2, Int>,
+    new_shape: Shape<D>,
+) -> Tensor<B, 2, Int> {
+    let flat_coordinates = flat_coordinates.squeeze::<1>(0);
+    let mut remaining_flat_coordinates = flat_coordinates.clone();
+    let mut new_coordinates = Vec::with_capacity(D);
+
+    for &dim_size in new_shape.dims.iter().rev() {
+        let size = dim_size as i64;
+        let new_coord = remaining_flat_coordinates.clone().remainder_scalar(size);
+        new_coordinates.push(new_coord.clone());
+        remaining_flat_coordinates = remaining_flat_coordinates.div_scalar(size);
+    }
+
+    new_coordinates.reverse();
+
+    Tensor::stack(new_coordinates, 0)
+}
diff --git a/crates/burn-sparse/src/coo_bool.rs b/crates/burn-sparse/src/coo_bool.rs
new file mode 100644 index 0000000000..1f63d837b8 --- /dev/null +++ b/crates/burn-sparse/src/coo_bool.rs @@ -0,0 +1,163 @@ +use super::coo::COO; +use crate::SparseCOOTensor; +use crate::{flatten_coordinates, unflatten_coordinates}; +use burn_tensor::Int; +use burn_tensor::ReprPrimitive; +use burn_tensor::Shape; +use burn_tensor::Tensor; +use burn_tensor::{ + backend::Backend, + ops::{SparseBoolOps, SparseTensorOps}, + SparseStorage, +}; +use burn_tensor::{Bool, Dense}; + +impl SparseBoolOps for COO { + fn bool_to_sparse( + dense: ::BoolTensorPrimitive, + ) -> >::SparsePrimitive { + todo!() + } + + fn bool_empty( + shape: burn_tensor::Shape, + device: &burn_tensor::Device, + ) -> >::SparsePrimitive { + todo!() + } + + fn bool_shape( + tensor: &>::SparsePrimitive, + ) -> burn_tensor::Shape { + todo!() + } + + fn bool_reshape( + tensor: >::SparsePrimitive, + shape: burn_tensor::Shape, + ) -> >::SparsePrimitive { + todo!() + } + + fn bool_transpose( + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { + todo!() + } + + fn bool_swap_dims( + tensor: >::SparsePrimitive, + dim1: usize, + dim2: usize, + ) -> >::SparsePrimitive { + todo!() + } + + fn bool_permute( + tensor: >::SparsePrimitive, + axes: &[usize], + ) -> >::SparsePrimitive { + todo!() + } + + fn bool_flip( + tensor: >::SparsePrimitive, + axes: &[usize], + ) -> >::SparsePrimitive { + todo!() + } + + fn bool_slice( + tensor: >::SparsePrimitive, + indices: [std::ops::Range; D2], + ) -> >::SparsePrimitive { + todo!() + } + + fn bool_slice_assign( + tensor: >::SparsePrimitive, + ranges: [std::ops::Range; D2], + value: >::SparsePrimitive, + ) -> >::SparsePrimitive { + todo!() + } + + fn bool_device( + tensor: &>::SparsePrimitive, + ) -> burn_tensor::Device { + todo!() + } + + fn bool_to_device( + tensor: >::SparsePrimitive, + device: &burn_tensor::Device, + ) -> >::SparsePrimitive { + todo!() + } + + fn bool_repeat_dim( + tensor: >::SparsePrimitive, + dim: usize, + times: usize, + ) -> >::SparsePrimitive { + todo!() + } + + fn bool_cat( + tensors: Vec<>::SparsePrimitive>, + dim: usize, + ) -> >::SparsePrimitive { + todo!() + } + + fn bool_any( + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { + todo!() + } + + fn bool_any_dim( + tensor: >::SparsePrimitive, + dim: usize, + ) -> >::SparsePrimitive { + todo!() + } + + fn bool_all( + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { + todo!() + } + + fn bool_all_dim( + tensor: >::SparsePrimitive, + dim: usize, + ) -> >::SparsePrimitive { + todo!() + } + + fn bool_expand( + tensor: >::SparsePrimitive, + shape: burn_tensor::Shape, + ) -> >::SparsePrimitive { + todo!() + } + + fn bool_coordinates( + mut tensor: >::SparsePrimitive, + ) -> Option> { + tensor.coordinates.map(|c| c.into_primitive()) + } + + fn bool_to_dense( + sparse: >::SparsePrimitive, + ) -> B::BoolTensorPrimitive { + todo!() + } + + fn bool_values( + tensor: ReprPrimitive, D>, + ) -> Option> { + tensor.values.map(|v| v.into_primitive()) + } +} diff --git a/crates/burn-sparse/src/coo_float.rs b/crates/burn-sparse/src/coo_float.rs new file mode 100644 index 0000000000..fe34779812 --- /dev/null +++ b/crates/burn-sparse/src/coo_float.rs @@ -0,0 +1,1012 @@ +use super::coo::{flatten_coordinates, unflatten_coordinates, SparseCOOTensor, COO}; +use burn_tensor::cast::ToElement; +use burn_tensor::ops::{FloatElem, SparseBoolOps}; +use burn_tensor::{backend::Backend, ops::SparseFloatOps, Tensor}; +use burn_tensor::{ + Bool, Dense, ElementConversion, Float, ReprPrimitive, Shape, Sparse, SparseStorage, TensorData, + 
TensorKind, TensorPrimitive, +}; +use burn_tensor::{Device, Int}; + +impl SparseFloatOps for COO { + fn float_to_sparse( + dense: ::FloatTensorPrimitive, + ) -> >::SparsePrimitive { + let dense: Tensor = Tensor::from_primitive(TensorPrimitive::Float(dense)); + + let shape = dense.shape(); + let device = dense.device(); + + let significant = dense.clone().not_equal_elem(0.0); + if !significant.clone().any().into_scalar() { + return Self::float_empty(dense.shape(), &device); + }; + + let coordinates = significant + .clone() + .nonzero() + .into_iter() + .map(|tensor| { + let length = tensor.shape().dims[0]; + let shape = Shape::new([1, length]); + tensor.reshape(shape) + }) + .collect(); + + let coordinates = Tensor::cat(coordinates, 0); + + let dense = dense.flatten(0, D - 1); + + let dims = significant.dims(); + let values = dense.gather( + 0, + significant + .flatten::<1>(0, dims.len() - 1) + .nonzero() + .remove(0), + ); + + let coordinates = Some(coordinates); + let values = Some(values); + + SparseCOOTensor { + coordinates, + values, + shape, + device, + } + } + + fn float_empty( + shape: burn_tensor::Shape, + device: &burn_tensor::Device, + ) -> >::SparsePrimitive { + SparseCOOTensor { + coordinates: None, + values: None, + shape, + device: device.clone(), + } + } + + fn float_to_dense( + sparse: >::SparsePrimitive, + ) -> B::FloatTensorPrimitive { + let SparseCOOTensor { + coordinates, + values, + shape, + device, + } = sparse; + + let (Some(coordinates), Some(values)) = (coordinates, values) else { + return Tensor::::zeros(shape, &device) + .into_primitive() + .tensor(); + }; + + let dense: Tensor = Tensor::zeros(Shape::new([shape.num_elements()]), &device); + let flat_coordinates = + flatten_coordinates::(coordinates, shape.clone(), &device).squeeze(0); + let dense = dense.select_assign(0, flat_coordinates, values); + + dense.reshape(shape).into_primitive().tensor() + } + + fn float_spmm( + lhs: >::SparsePrimitive, + rhs: >::Primitive, + ) -> ::FloatTensorPrimitive { + let SparseCOOTensor { + coordinates, + values, + shape, + device, + } = lhs; + + let rhs: Tensor = Tensor::from_primitive(rhs); + let rhs_shape = rhs.shape(); + let mut out_shape = shape.clone(); + out_shape.dims[D - 1] = rhs_shape.dims[D - 1]; + + let (Some(coordinates), Some(values)) = (coordinates, values) else { + // All zeros, exit early + return Tensor::::zeros(out_shape, &device) + .into_primitive() + .tensor(); + }; + + let nnz = coordinates.shape().dims[1]; + + // Ensure they are of the correct shape to multiply + if shape.dims[D - 1] != rhs_shape.dims[D - 2] { + panic!("Invalid shape for matrix multiplication"); + } + + // Ensure batches are the same + if D > 2 && rhs_shape.dims[0..D - 2] != shape.dims[0..D - 2] { + panic!("Batches must be of the same shape"); + } + + // Compute strides for the dense tensor to match the flattened shape + let mut strides_data = [1; D]; + for i in (0..D - 1).rev() { + strides_data[i] = strides_data[i + 1] * shape.dims[i + 1] as i32; + } + let strides: Tensor = + Tensor::::from_ints(strides_data, &device).unsqueeze_dim(1); + + let column_index = coordinates.clone().slice([D - 1..D, 0..nnz]); + + // the indices into the flat row vector at which the containing matrix starts + let matrix_starts: Tensor = if D > 2 { + coordinates + .clone() + .slice([0..D - 2, 0..nnz]) + .mul(strides.clone().slice([0..D - 2])) + .div_scalar((shape.dims[D - 1]) as i32) + .sum_dim(0) + } else { + Tensor::::zeros(column_index.shape(), &device) + }; + + let row_index = coordinates.slice([D - 2..D - 1, 
0..nnz]); + + let gather_index = matrix_starts.clone() + column_index; + let scatter_index = matrix_starts + row_index; + + let gather_index = gather_index + .transpose() + .repeat_dim(1, rhs_shape.dims[D - 1]); + let scatter_index = scatter_index + .transpose() + .repeat_dim(1, rhs_shape.dims[D - 1]); + let values = values.unsqueeze_dim(1).repeat_dim(1, rhs_shape.dims[D - 1]); + + // Flatten the rhs similarly into 2 dimensions + let rhs: Tensor = rhs.reshape([-1, rhs_shape.dims[D - 1] as i32]); + + // Do the matmul using gather/scatter + let output: Tensor = + Tensor::zeros([out_shape.dims[0], rhs.shape().dims[1]], &device); + let gathered = rhs.gather(0, gather_index); + + let multiplied = gathered.mul(values); + + let scattered = output.scatter(0, scatter_index, multiplied); + + scattered.reshape(out_shape).into_primitive().tensor() + } + + fn float_sddmm( + lhs: ::FloatTensorPrimitive, + rhs: ::FloatTensorPrimitive, + sparse: >::SparsePrimitive, + ) -> >::SparsePrimitive { + if sparse.coordinates.is_none() || sparse.values.is_none() { + return sparse; + } + + // Flatten the lhs and rhs into a tensor of rows and cols respectively + let lhs = Tensor::::new(burn_tensor::TensorPrimitive::Float(lhs)); + let rhs = Tensor::::new(burn_tensor::TensorPrimitive::Float(rhs)).transpose(); + let lhs_dims = lhs.shape().dims; + let rhs_dims = rhs.shape().dims; + + if lhs_dims[D - 1] != rhs_dims[D - 1] + || lhs_dims[D - 2] != sparse.shape.dims[D - 2] + || rhs_dims[D - 2] != sparse.shape.dims[D - 1] + { + panic!("invalid dimensions for sddmm. lhs and rhs must have compatible shapes for matmul, and sparse must have the correct shape for output of matmul between lhs and rhs."); + } + + let lhs = lhs.reshape([-1, lhs_dims[D - 1] as i32]); + let rhs = rhs.reshape([-1, rhs_dims[D - 1] as i32]); + + // Flatten the sparse tensor into + let device = sparse.device.clone(); + let mut shape = sparse.shape.clone(); + let lhs_coordinates = sparse + .coordinates + .clone() + .expect("Expected non-empty sparse tensor"); + + // swap the last two dims so its column-first + let swizzle = Tensor::::arange(0..D as i64, &device) + .slice_assign( + [D - 2..D], + Tensor::::from_ints([D - 1, D - 2], &device), + ) + .unsqueeze_dim(1) + .repeat_dim(1, lhs_coordinates.shape().dims[1]); + let rhs_coordinates = lhs_coordinates.clone().gather(0, swizzle); + + let row_indices = flatten_coordinates::(lhs_coordinates, shape.clone(), &device); + + shape.dims.swap(D - 1, D - 2); + let col_indices = flatten_coordinates::(rhs_coordinates, shape.clone(), &device); + + let row_indices = row_indices.transpose().repeat_dim(1, lhs_dims[D - 1]); + let col_indices = col_indices.transpose().repeat_dim(1, rhs_dims[D - 1]); + + let lhs = lhs.gather(0, row_indices); + let rhs = rhs.gather(0, col_indices); + + let dotted = lhs.mul(rhs).sum_dim(1).squeeze(1); + + SparseCOOTensor { + coordinates: sparse.coordinates, + values: Some(dotted), + shape: sparse.shape, + device, + } + } + + fn float_coalesce_sum( + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { + if tensor.coordinates.as_ref().map(|c| c.shape().dims[1] <= 1) == Some(true) { + return tensor; + } + let original_shape = tensor.shape.clone(); + + if tensor.coordinates.is_none() && tensor.values.is_none() { + return SparseCOOTensor { + coordinates: None, + values: None, + shape: original_shape, + device: tensor.device, + }; + } + + let coordinates = tensor + .coordinates + .expect("Mismatch between coordinates and values"); + let values = tensor + .values + .expect("Mismatch between 
coordinates and values"); + let device = tensor.device; + let nnz = coordinates.shape().dims[1]; + + let coordinates = + flatten_coordinates::(coordinates, original_shape.clone(), &device); + let _flat_shape = Shape::new([original_shape.num_elements()]); + + let (coordinates, indices) = coordinates.sort_with_indices(1); + let values = values.select(0, indices.squeeze(0)); + let range = Tensor::::arange(0..nnz as i64, &device).unsqueeze::<2>(); + + // Get the diff of coordinates, diff[i] = coordinates[i]-coordinates[i-1] + let left_slice = coordinates.clone().slice([0..1, 0..nnz - 1]); + let right_slice = coordinates.clone().slice([0..1, 1..nnz]); + let diff = right_slice - left_slice; + let ones = Tensor::::ones(Shape::new([1, 1]), &device); + let diff = Tensor::cat(vec![ones, diff], 1); + + // TODO this all would be way cleaner with cumsum/max, but that is waiting on a pull request as of writing + // inspiration could be taken from pytorch_scatter for better implementations + let unique_mask = diff.not_equal_elem(0); + let unique_indices = unique_mask.clone().nonzero().remove(1); + let steps = Tensor::cat( + vec![unique_indices.clone(), Tensor::from_data([nnz], &device)], + 0, + ); + let unique = steps.shape().dims[0]; + let steps = steps + .clone() + .slice([1..unique]) + .sub(steps.slice([0..unique - 1])) + .max() + // .sub_scalar(1) + .into_scalar() + .elem::(); + + let mut scatter_indices = range.mul(unique_mask.int()); + + for _ in 0..steps { + scatter_indices = scatter_indices + .clone() + .slice([0..1, 1..nnz]) + .max_pair(scatter_indices.slice([0..1, 0..nnz - 1])); + scatter_indices = Tensor::cat( + vec![Tensor::zeros(Shape::new([1, 1]), &device), scatter_indices], + 1, + ); + } + + // Scatter/Gather everything into place + let zeroed = Tensor::::zeros(Shape::new([nnz]), &device); + let values = zeroed.scatter(0, scatter_indices.squeeze(0), values); + let values = values.gather(0, unique_indices.clone()); + let coordinates = coordinates.gather(1, unique_indices.unsqueeze::<2>()); + let coordinates = unflatten_coordinates(coordinates, original_shape.clone()); + + let coordinates = Some(coordinates); + let values = Some(values); + + // reshape back into the original shape and send it! 
+ SparseCOOTensor { + coordinates, + values, + shape: original_shape, + device, + } + } + + fn float_remove_zeros( + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { + if tensor.coordinates.is_none() && tensor.values.is_none() { + return tensor; + } + + let coordinates = tensor + .coordinates + .expect("Mismatch between coordinates and values"); + let values = tensor + .values + .expect("Mismatch between coordinates and values"); + let device = tensor.device; + let shape = tensor.shape; + + todo!() + } + + fn float_number_nonzero( + tensor: >::SparsePrimitive, + ) -> usize { + match tensor.coordinates { + Some(coordinates) => coordinates.shape().dims[1], + None => 0, + } + } + + fn float_density( + sparse: >::SparsePrimitive, + ) -> f32 { + match sparse.coordinates { + Some(coordinates) => { + coordinates.shape().dims[1] as f32 / sparse.shape.num_elements() as f32 + } + None => 0.0, + } + } + + fn float_slice( + tensor: >::SparsePrimitive, + indices: [std::ops::Range; D2], + ) -> >::SparsePrimitive { + todo!() + } + + fn float_device( + tensor: &>::SparsePrimitive, + ) -> burn_tensor::Device { + tensor.device.clone() + } + + fn float_to_device( + tensor: >::SparsePrimitive, + device: &burn_tensor::Device, + ) -> >::SparsePrimitive { + SparseCOOTensor { + coordinates: tensor.coordinates.map(|t| t.to_device(device)), + values: tensor.values.map(|t| t.to_device(device)), + shape: tensor.shape, + device: device.clone(), + } + } + + fn float_shape( + tensor: &>::SparsePrimitive, + ) -> burn_tensor::Shape { + tensor.shape.clone() + } + + fn float_reshape( + tensor: >::SparsePrimitive, + out_shape: burn_tensor::Shape, + ) -> >::SparsePrimitive { + if tensor.coordinates.is_none() && tensor.values.is_none() { + return SparseCOOTensor { + coordinates: None, + values: None, + shape: out_shape, + device: tensor.device, + }; + } + + let coordinates = tensor + .coordinates + .expect("Mismatch between coordinates and values"); + let values = tensor + .values + .expect("Mismatch between coordinates and values"); + let shape = tensor.shape; + let device = tensor.device; + + // Flatten the coordinates + let flat_coordinates = flatten_coordinates::(coordinates, shape, &device); + + // Unflatten the coordinates to the new shape + let new_coordinates = unflatten_coordinates(flat_coordinates, out_shape.clone()); + + SparseCOOTensor { + coordinates: Some(new_coordinates), + values: Some(values), + shape: out_shape, + device, + } + } + + fn float_transpose( + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { + let d = tensor.shape.dims.len(); + let mut axes: Vec = (0..d).collect(); + axes.swap(d - 1, d - 2); + Self::float_permute(tensor, &axes) + } + + fn float_swap_dims( + tensor: >::SparsePrimitive, + dim1: usize, + dim2: usize, + ) -> >::SparsePrimitive { + let d = tensor.shape.dims.len(); + let mut axes: Vec = (0..d).collect(); + axes.swap(dim1, dim2); + Self::float_permute(tensor, &axes) + } + + fn float_permute( + tensor: >::SparsePrimitive, + axes: &[usize], + ) -> >::SparsePrimitive { + let SparseCOOTensor { + coordinates, + values, + mut shape, + device, + } = tensor; + + for (i, &j) in (0..D).zip(axes).filter(|(i, j)| i < j) { + shape.dims.swap(i, j); + } + + let axes = Tensor::from(axes); + let coordinates = coordinates.map(|coordinates| coordinates.select(0, axes)); + + SparseCOOTensor { + coordinates, + values, + shape, + device, + } + } + + fn float_flip( + tensor: >::SparsePrimitive, + axes: &[usize], + ) -> >::SparsePrimitive { + let SparseCOOTensor { + coordinates, + values, + shape, 
+ device, + } = tensor; + + let (Some(coordinates), Some(values)) = (coordinates, values) else { + // All zeros, exit early + return SparseCOOTensor { + coordinates: None, + values: None, + shape, + device, + }; + }; + + let nnz = coordinates.shape().dims[1]; + + let mut mask = [0; D]; + for &axis in axes { + mask[axis] = 1; + } + let mask: Tensor = Tensor::<_, 1, _>::from_ints(mask, &device) + .unsqueeze_dim(1) + .repeat_dim(1, nnz) + .bool(); + + let flipped: Tensor = Tensor::<_, 1, _>::from_ints(shape.dims, &device) + .unsqueeze_dim(1) + .repeat_dim(1, nnz) + .sub(coordinates.clone()) + .sub_scalar(1); + + let coordinates = coordinates.mask_where(mask, flipped); + + let coordinates = Some(coordinates); + let values = Some(values); + + SparseCOOTensor { + coordinates, + values, + shape, + device, + } + } + + fn float_slice_assign( + tensor: >::SparsePrimitive, + ranges: [std::ops::Range; D2], + mut value: >::SparsePrimitive, + ) -> >::SparsePrimitive { + let value_nnz = value + .coordinates + .as_ref() + .map(|coords| coords.shape().dims[1]) + .unwrap_or(0); + + let mut ranges = Vec::from(ranges); + ranges.extend(tensor.shape.dims[ranges.len()..D1].iter().map(|&l| 0..l)); + let ranges: [core::ops::Range; D1] = ranges.try_into().expect("D2 must be <= D1"); + + let shape = tensor.shape.clone(); + let sliced = Self::float_reshape( + Self::float_slice(tensor.clone(), ranges.clone()), + shape.clone(), + ); + let tensor = Self::float_sub(tensor, sliced); + let offset = Tensor::::from_ints(ranges.map(|r| r.start), &tensor.device); + let offset = offset.unsqueeze_dim::<2>(1).repeat_dim(1, value_nnz); + + value.shape = shape; + value.coordinates = value.coordinates.map(|coords| coords + offset); + + Self::float_add(tensor, value) + } + + fn float_repeat_dim( + tensor: >::SparsePrimitive, + dim: usize, + times: usize, + ) -> >::SparsePrimitive { + let SparseCOOTensor { + coordinates, + values, + shape, + device, + } = tensor; + + let mut out_shape = shape.clone(); + out_shape.dims[dim] *= times; + + let (Some(coordinates), Some(values)) = (coordinates, values) else { + // All zeros, exit early + return SparseCOOTensor { + coordinates: None, + values: None, + shape, + device, + }; + }; + + let device = coordinates.device(); + let nnz = coordinates.shape().dims[1]; + + let values = values.repeat_dim(0, times); + + let coordinates_mask: Tensor = Tensor::zeros(coordinates.shape(), &device); + let ones: Tensor = Tensor::ones(Shape::new([1, nnz]), &device); + let coordinates_mask = coordinates_mask.slice_assign([dim..dim + 1, 0..nnz], ones); + let coordinates = Tensor::cat( + (0..times) + .map(|n| { + coordinates.clone() + + coordinates_mask.clone() * (n as i32) * (shape.dims[dim] as i32) + }) + .collect::>(), + 1, + ); + + let coordinates = Some(coordinates); + let values = Some(values); + + SparseCOOTensor { + coordinates, + values, + shape: out_shape, + device, + } + } + + fn float_cat( + tensors: Vec<>::SparsePrimitive>, + dim: usize, + ) -> >::SparsePrimitive { + todo!() + } + + fn float_any( + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { + let SparseCOOTensor { + coordinates, + values: _, + shape: _, + device: _, + } = tensor; + let any = coordinates.is_some(); + let bool = Tensor::::from([any]).into_primitive(); + >::bool_to_sparse(bool) + } + + fn float_any_dim( + tensor: >::SparsePrimitive, + dim: usize, + ) -> >::SparsePrimitive { + panic!("any_dim is unsupported for COO until scatter supports any-based reduction"); + } + + fn float_all( + tensor: >::SparsePrimitive, + ) -> 
>::SparsePrimitive { + let SparseCOOTensor { + coordinates, + values: _, + shape, + device: _, + } = tensor; + let all = match coordinates { + Some(coordinates) => shape.num_elements() == coordinates.shape().dims[1], + None => false, + }; + let bool = Tensor::::from([all]).into_primitive(); + >::bool_to_sparse(bool) + } + + fn float_all_dim( + tensor: >::SparsePrimitive, + dim: usize, + ) -> >::SparsePrimitive { + panic!("all_dim is unsupported for COO until scatter supports all-based reduction"); + } + + fn float_expand( + tensor: >::SparsePrimitive, + shape: burn_tensor::Shape, + ) -> >::SparsePrimitive { + todo!() + } + + fn float_add( + lhs: >::SparsePrimitive, + rhs: >::SparsePrimitive, + ) -> >::SparsePrimitive { + let SparseCOOTensor { + coordinates: lhs_coordinates, + values: lhs_values, + shape: lhs_shape, + device: lhs_device, + } = lhs; + let (Some(lhs_coordinates), Some(lhs_values)) = (lhs_coordinates, lhs_values) else { + return rhs; + }; + + let SparseCOOTensor { + coordinates: rhs_coordinates, + values: rhs_values, + shape: rhs_shape, + device: rhs_device, + } = rhs; + let (Some(rhs_coordinates), Some(rhs_values)) = (rhs_coordinates, rhs_values) else { + return SparseCOOTensor { + coordinates: Some(lhs_coordinates), + values: Some(lhs_values), + shape: lhs_shape, + device: lhs_device, + }; + }; + + assert_eq!(lhs_shape, rhs_shape); + assert_eq!(lhs_device, rhs_device); + + let coordinates = Some(Tensor::cat(vec![lhs_coordinates, rhs_coordinates], 1)); + let values = Some(Tensor::cat(vec![lhs_values, rhs_values], 0)); + let shape = lhs_shape; + let device = lhs_device; + + let result = SparseCOOTensor { + coordinates, + values, + shape, + device, + }; + + Self::float_coalesce_sum(result) + } + + fn float_sub( + lhs: >::SparsePrimitive, + rhs: >::SparsePrimitive, + ) -> >::SparsePrimitive { + Self::float_add( + lhs, + Self::float_mul_scalar(rhs, FloatElem::::from_elem(-1.0)), + ) + } + + fn float_mul( + lhs: >::SparsePrimitive, + rhs: >::SparsePrimitive, + ) -> >::SparsePrimitive { + panic!("float_mul is unsupported until scatter supports multiplication based reduction"); + } + + fn float_mul_scalar( + mut lhs: >::SparsePrimitive, + rhs: burn_tensor::ops::FloatElem, + ) -> >::SparsePrimitive { + lhs.values = lhs.values.map(|values| values.mul_scalar(rhs)); + lhs + } + + fn float_div( + lhs: >::SparsePrimitive, + rhs: >::SparsePrimitive, + ) -> >::SparsePrimitive { + panic!("float_div is unsupported until scatter supports multiplication based reduction"); + } + + fn float_div_scalar( + mut lhs: >::SparsePrimitive, + rhs: burn_tensor::ops::FloatElem, + ) -> >::SparsePrimitive { + lhs.values = lhs.values.map(|values| values.div_scalar(rhs)); + lhs + } + + fn float_max( + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { + panic!("max is unsupported for COO until scatter supports max reduction"); + } + + fn float_max_dim( + tensor: >::SparsePrimitive, + dim: usize, + ) -> >::SparsePrimitive { + panic!("max_dim is unsupported for COO until scatter supports max reduction"); + } + + fn float_min( + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { + panic!("min is unsupported for COO until scatter supports min reduction"); + } + + fn float_min_dim( + tensor: >::SparsePrimitive, + dim: usize, + ) -> >::SparsePrimitive { + panic!("min_dim is unsupported for COO until scatter supports min reduction"); + } + + fn float_abs( + mut tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { + tensor.values = tensor.values.map(|values| values.abs()); + tensor + } + + fn float_sign( 
+ mut tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { + tensor.values = tensor.values.map(|values| values.sign()); + tensor + } + + fn float_powf( + lhs: >::SparsePrimitive, + rhs: >::SparsePrimitive, + ) -> >::SparsePrimitive { + panic!("float_powf is unsupported for COO until scatter supports other reduction methods"); + } + + fn float_powi( + lhs: >::SparsePrimitive, + rhs: >::SparsePrimitive, + ) -> >::SparsePrimitive { + panic!("float_powi is unsupported for COO until scatter supports other reduction methods"); + } + + fn float_powf_scalar( + mut lhs: >::SparsePrimitive, + rhs: burn_tensor::ops::FloatElem, + ) -> >::SparsePrimitive { + lhs.values = lhs.values.map(|values| values.powf_scalar(rhs)); + lhs + } + + fn float_powi_scalar( + mut lhs: >::SparsePrimitive, + rhs: burn_tensor::ops::FloatElem, + ) -> >::SparsePrimitive { + lhs.values = lhs.values.map(|values| values.powi_scalar(rhs)); + lhs + } + + fn float_clamp( + mut tensor: >::SparsePrimitive, + min: burn_tensor::ops::FloatElem, + max: burn_tensor::ops::FloatElem, + ) -> >::SparsePrimitive { + tensor.values = tensor.values.map(|values| values.clamp(min, max)); + if min.to_f64() == 0f64 || max.to_f64() == 0f64 { + // Clamp can zero elements if a boundary is zero + Self::float_remove_zeros(tensor) + } else { + tensor + } + } + + fn float_clamp_min( + mut tensor: >::SparsePrimitive, + min: burn_tensor::ops::FloatElem, + ) -> >::SparsePrimitive { + tensor.values = tensor.values.map(|values| values.clamp_min(min)); + if min.to_f64() == 0f64 { + // min can zero elements if boundary is 0 + Self::float_remove_zeros(tensor) + } else { + tensor + } + } + + fn float_clamp_max( + mut tensor: >::SparsePrimitive, + max: burn_tensor::ops::FloatElem, + ) -> >::SparsePrimitive { + tensor.values = tensor.values.map(|values| values.clamp_max(max)); + if max.to_f64() == 0f64 { + // max can zero elements if boundary is 0 + Self::float_remove_zeros(tensor) + } else { + tensor + } + } + + fn float_select( + tensor: >::SparsePrimitive, + dim: usize, + indices: burn_tensor::ops::IntTensor, + ) -> >::SparsePrimitive { + if tensor.coordinates.is_none() && tensor.values.is_none() { + return tensor; + } + + let coordinates = tensor + .coordinates + .expect("Mismatch between coordinates and values"); + let values = tensor + .values + .expect("Mismatch between coordinates and values"); + let device = tensor.device; + let mut shape = tensor.shape; + let indices = Tensor::::new(indices); + + let nnz = coordinates.shape().dims[1]; + let dim_coords = coordinates + .clone() + .slice([dim..dim + 1, 0..nnz]) + .squeeze::<1>(0); + let indices = indices.select(0, dim_coords); + let indices_len = indices.shape().num_elements(); + let coordinates = coordinates.slice_assign( + [dim..dim + 1, 0..nnz], + indices.unsqueeze::<2>().repeat_dim(1, D), + ); + + shape.dims[dim] = indices_len; + + SparseCOOTensor { + coordinates: Some(coordinates), + values: Some(values), + shape, + device, + } + } + + fn float_select_assign( + tensor: >::SparsePrimitive, + dim: usize, + indices: burn_tensor::ops::IntTensor, + values: >::SparsePrimitive, + ) -> >::SparsePrimitive { + todo!() + } + + fn float_gather( + dim: usize, + tensor: >::SparsePrimitive, + indices: burn_tensor::ops::IntTensor, + ) -> >::SparsePrimitive { + todo!() + } + + fn float_scatter( + dim: usize, + tensor: >::SparsePrimitive, + indices: burn_tensor::ops::IntTensor, + values: >::SparsePrimitive, + ) -> >::SparsePrimitive { + todo!() + } + + fn float_sum( + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive 
{ + tensor + .values + .map(|values| Self::float_to_sparse(values.sum().into_primitive().tensor())) + .unwrap_or(Self::float_empty(Shape::new([1]), &tensor.device)) + } + + fn float_sum_dim( + tensor: >::SparsePrimitive, + dim: usize, + ) -> >::SparsePrimitive { + panic!("float_sum_dim unsupported for COO"); + } + + fn float_prod_dim( + tensor: >::SparsePrimitive, + dim: usize, + ) -> >::SparsePrimitive { + panic!("float_prod_dim is not supported for COO until scatter supports product reduction") + } + + fn float_mean( + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { + let num_elems = tensor.shape.num_elements(); + Self::float_div_scalar( + Self::float_sum(tensor), + ElementConversion::elem(num_elems as f32), + ) + } + + fn float_mean_dim( + tensor: >::SparsePrimitive, + dim: usize, + ) -> >::SparsePrimitive { + panic!("float_mean_dim is not supported for COO until scatter supports mean reduction"); + } + + fn float_remainder_scalar( + mut lhs: >::SparsePrimitive, + rhs: burn_tensor::ops::FloatElem, + ) -> >::SparsePrimitive { + lhs.values = lhs.values.map(|values| values.remainder_scalar(rhs)); + lhs + } + + fn float_neg( + mut tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { + tensor.values = tensor.values.map(|values| values.neg()); + tensor + } + + fn float_coordinates( + mut tensor: >::SparsePrimitive, + ) -> Option> { + tensor.coordinates.map(|c| c.into_primitive()) + } + + fn float_values( + mut tensor: >::SparsePrimitive, + ) -> Option> { + tensor.values.map(|v| v.into_primitive()) + } +} diff --git a/crates/burn-sparse/src/coo_int.rs b/crates/burn-sparse/src/coo_int.rs new file mode 100644 index 0000000000..67a9df94db --- /dev/null +++ b/crates/burn-sparse/src/coo_int.rs @@ -0,0 +1,175 @@ +use super::coo::COO; +use crate::SparseCOOTensor; +use crate::{flatten_coordinates, unflatten_coordinates}; +use burn_tensor::Dense; +use burn_tensor::Int; +use burn_tensor::ReprPrimitive; +use burn_tensor::Shape; +use burn_tensor::Tensor; +use burn_tensor::{backend::Backend, ops::SparseIntOps, SparseStorage}; + +impl SparseIntOps for COO { + fn int_empty( + shape: burn_tensor::Shape, + device: &burn_tensor::Device, + ) -> >::SparsePrimitive { + todo!() + } + + fn int_shape( + tensor: &>::SparsePrimitive, + ) -> burn_tensor::Shape { + todo!() + } + + fn int_reshape( + tensor: >::SparsePrimitive, + shape: burn_tensor::Shape, + ) -> >::SparsePrimitive { + todo!() + } + + fn int_transpose( + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { + todo!() + } + + fn int_swap_dims( + tensor: >::SparsePrimitive, + dim1: usize, + dim2: usize, + ) -> >::SparsePrimitive { + todo!() + } + + fn int_permute( + tensor: >::SparsePrimitive, + axes: &[usize], + ) -> >::SparsePrimitive { + todo!() + } + + fn int_flip( + tensor: >::SparsePrimitive, + axes: &[usize], + ) -> >::SparsePrimitive { + todo!() + } + + fn int_slice( + tensor: >::SparsePrimitive, + indices: [std::ops::Range; D2], + ) -> >::SparsePrimitive { + todo!() + } + + fn int_slice_assign( + tensor: >::SparsePrimitive, + ranges: [std::ops::Range; D2], + value: >::SparsePrimitive, + ) -> >::SparsePrimitive { + todo!() + } + + fn int_device( + tensor: &>::SparsePrimitive, + ) -> burn_tensor::Device { + todo!() + } + + fn int_to_device( + tensor: >::SparsePrimitive, + device: &burn_tensor::Device, + ) -> >::SparsePrimitive { + todo!() + } + + fn int_repeat_dim( + tensor: >::SparsePrimitive, + dim: usize, + times: usize, + ) -> >::SparsePrimitive { + todo!() + } + + fn int_cat( + tensors: Vec<>::SparsePrimitive>, + dim: usize, + ) -> 
>::SparsePrimitive { + todo!() + } + + fn int_any( + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { + todo!() + } + + fn int_any_dim( + tensor: >::SparsePrimitive, + dim: usize, + ) -> >::SparsePrimitive { + todo!() + } + + fn int_all( + tensor: >::SparsePrimitive, + ) -> >::SparsePrimitive { + todo!() + } + + fn int_all_dim( + tensor: >::SparsePrimitive, + dim: usize, + ) -> >::SparsePrimitive { + todo!() + } + + fn int_expand( + tensor: >::SparsePrimitive, + shape: burn_tensor::Shape, + ) -> >::SparsePrimitive { + todo!() + } + + fn int_coordinates( + mut tensor: >::SparsePrimitive, + ) -> Option> { + tensor.coordinates.map(|c| c.into_primitive()) + } + + fn int_to_dense( + sparse: >::SparsePrimitive, + ) -> B::IntTensorPrimitive { + let SparseCOOTensor { + coordinates, + values, + shape, + device, + } = sparse; + + let (Some(coordinates), Some(values)) = (coordinates, values) else { + return Tensor::::zeros(shape, &device).into_primitive(); + }; + + let dense: Tensor = Tensor::zeros(Shape::new([shape.num_elements()]), &device); + let flat_coordinates = + flatten_coordinates::(coordinates, shape.clone(), &device).squeeze(0); + let dense = dense.select_assign(0, flat_coordinates, values); + + dense.reshape(shape).into_primitive() + } + + fn int_to_sparse( + dense: ::IntTensorPrimitive, + ) -> >::SparsePrimitive { + todo!() + } + + fn int_values( + tensor: ReprPrimitive, D>, + ) -> Option> { + tensor.values.map(|v| v.into_primitive()) + } +} diff --git a/crates/burn-sparse/src/lib.rs b/crates/burn-sparse/src/lib.rs new file mode 100644 index 0000000000..63a079f460 --- /dev/null +++ b/crates/burn-sparse/src/lib.rs @@ -0,0 +1,9 @@ +mod coo; +mod coo_bool; +mod coo_float; +mod coo_int; + +pub use coo::*; +pub use coo_bool::*; +pub use coo_float::*; +pub use coo_int::*; diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs index 1201fcda27..c57257c59a 100644 --- a/crates/burn-tensor/src/tensor/api/base.rs +++ b/crates/burn-tensor/src/tensor/api/base.rs @@ -18,41 +18,64 @@ use crate::check::TensorCheck; use crate::tensor::api::chunk::chunk; use crate::tensor::api::narrow::narrow; use crate::{backend::Backend, check, Bool, Float, Int, Shape, TensorData, TensorKind}; -use crate::{DType, Element, TensorPrimitive}; +use crate::{DType, Dense, Element, ReprPrimitive, TensorPrimitive, TensorRepr, TensorStorage}; /// A tensor with a given backend, shape and data type. #[derive(new, Clone, Debug)] -pub struct Tensor +pub struct Tensor where B: Backend, K: TensorKind, + SR: TensorStorage, + (B, K, SR): TensorRepr, { - pub(crate) primitive: K::Primitive, + pub(crate) primitive: <(B, K, SR) as TensorRepr>::Primitive, } -impl From for Tensor +impl From for Tensor where B: Backend, - K: BasicOps, + K: BasicOps, + SR: TensorStorage, T: Into, + (B, K, SR): TensorRepr, + (B, Bool, SR): TensorRepr, { fn from(value: T) -> Self { Tensor::from_data(value.into(), &Default::default()) } } -impl Tensor +// impl Tensor +// where +// B: Backend, +// R: TensorRepr, +// K: TensorKind, +// { +// fn change_repr, +// R: ChangeRepr, +// { +// R::change_repr(self) +// } +// } + +impl Tensor where B: Backend, - K: BasicOps, + K: BasicOps, + SR: TensorStorage, + (B, K, SR): TensorRepr, + (B, Bool, SR): TensorRepr, { /// Converts the tensor into a primitive tensor. - pub fn into_primitive(self) -> K::Primitive { + pub fn into_primitive(self) -> ReprPrimitive { self.primitive } /// Converts from a primitive tensor into a tensor. 
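+    ///
+    /// Editorial sketch (not part of the original doc), assuming the representation-generic
+    /// aliases work as the signatures here suggest:
+    ///
+    /// ```ignore
+    /// let tensor: Tensor<B, 2> = Tensor::from_data([[3.0, 4.9]], &device);
+    /// // Presumably a ReprPrimitive<B, Float, Dense, 2> under the new scheme.
+    /// let primitive = tensor.into_primitive();
+    /// let tensor = Tensor::<B, 2>::from_primitive(primitive);
+    /// ```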
- pub fn from_primitive(tensor: K::Primitive) -> Self { + pub fn from_primitive(tensor: ReprPrimitive) -> Self { Self::new(tensor) } @@ -106,7 +129,7 @@ where /// println!("{:?}", reshaped_tensor.shape()); /// } /// ``` - pub fn reshape>(self, shape: S) -> Tensor { + pub fn reshape>(self, shape: S) -> Tensor { // Convert reshape args to shape let shape = shape.into_shape(&self); Tensor::new(K::reshape::(self.primitive, shape)) @@ -121,7 +144,7 @@ where /// # Returns /// /// The transposed tensor. - pub fn transpose(self) -> Tensor { + pub fn transpose(self) -> Tensor { Tensor::new(K::transpose(self.primitive)) } @@ -136,7 +159,7 @@ where /// # Returns /// /// The tensor with the dimensions swapped. - pub fn swap_dims(self, dim1: usize, dim2: usize) -> Tensor { + pub fn swap_dims(self, dim1: usize, dim2: usize) -> Tensor { Tensor::new(K::swap_dims(self.primitive, dim1, dim2)) } @@ -152,7 +175,7 @@ where /// # Returns /// /// The tensor with the dimensions permuted. - pub fn permute(self, axes: [isize; D]) -> Tensor { + pub fn permute(self, axes: [isize; D]) -> Tensor { // Convert the axes to usize and handle negative values without using vector let mut transformed_axes: [usize; D] = [0; D]; for (i, &x) in axes.iter().enumerate() { @@ -192,7 +215,11 @@ where /// The tensor with the dimensions moved. // This is a semantic sugar for `permute`. It is used widely enough, so we define a separate Op // for it - pub fn movedim(self, src: S1, dst: S2) -> Tensor { + pub fn movedim( + self, + src: S1, + dst: S2, + ) -> Tensor { let source_dims = src.into_dim_vec::(); let destination_dims = dst.into_dim_vec::(); @@ -233,7 +260,7 @@ where /// # Returns /// /// The tensor with the axes flipped. - pub fn flip(self, axes: [isize; N]) -> Tensor { + pub fn flip(self, axes: [isize; N]) -> Tensor { // Convert the axes to usize and handle negative values without using vector let mut transformed_axes: [usize; N] = [0; N]; for (i, &x) in axes.iter().enumerate() { @@ -287,7 +314,11 @@ where /// } /// /// ``` - pub fn flatten(self, start_dim: usize, end_dim: usize) -> Tensor { + pub fn flatten( + self, + start_dim: usize, + end_dim: usize, + ) -> Tensor { check!(TensorCheck::flatten::(start_dim, end_dim)); let current_dims = self.shape().dims; @@ -338,7 +369,7 @@ where /// println!("{:?}", squeezed_tensor.shape()); /// } /// ``` - pub fn squeeze(self, dim: usize) -> Tensor { + pub fn squeeze(self, dim: usize) -> Tensor { check!(TensorCheck::squeeze::(dim, &self.shape().dims)); let current_dims = self.shape().dims; @@ -387,7 +418,7 @@ where /// println!("{:?}", squeezed_tensor.shape()); /// } /// ``` - pub fn squeeze_dims(self, dims: &[isize]) -> Tensor { + pub fn squeeze_dims(self, dims: &[isize]) -> Tensor { let current_dims = self.shape().dims; let mut dim_indices: Vec; @@ -457,7 +488,7 @@ where /// // Shape { dims: [1, 1, 3, 3] } /// } /// ``` - pub fn unsqueeze(self) -> Tensor { + pub fn unsqueeze(self) -> Tensor { check!(TensorCheck::unsqueeze::()); let mut dims = [1; D2]; @@ -486,7 +517,7 @@ where /// // Shape { dims: [3, 1, 3] } /// } /// ``` - pub fn unsqueeze_dim(self, dim: usize) -> Tensor { + pub fn unsqueeze_dim(self, dim: usize) -> Tensor { check!(TensorCheck::unsqueeze_dim::<{ D }>(dim)); let mut dims = [1; D2]; @@ -523,7 +554,7 @@ where /// // Shape { dims: [1, 3, 4, 5, 1, 1] } /// } /// ``` - pub fn unsqueeze_dims(self, axes: &[isize]) -> Tensor { + pub fn unsqueeze_dims(self, axes: &[isize]) -> Tensor { let mut new_dims = [1; D2]; let old_dims = self.shape().dims; //for checking if the dimension is 
in the acceptable range @@ -635,7 +666,7 @@ where /// This function uses the `RangesArg` trait for flexible range specification. The trait /// handles the conversion of various range formats and applies clamping and negative /// index handling internally. - pub fn slice>(self, ranges: R) -> Self { + pub fn slice>(self, ranges: RA) -> Self { let ranges = ranges.into_ranges(self.shape()); check!(TensorCheck::slice(&self.shape(), &ranges)); @@ -744,8 +775,8 @@ where /// # Panics /// /// If the two tensors don't have the same shape. - pub fn equal(self, other: Self) -> Tensor { - check!(TensorCheck::binary_ops_ew("Equal", &self, &other)); + pub fn equal(self, other: Self) -> Tensor { + // check!(TensorCheck::binary_ops_ew("Equal", &self, &other)); K::equal(self.primitive, other.primitive) } @@ -754,8 +785,8 @@ where /// # Panics /// /// If the two tensors don't have the same shape. - pub fn not_equal(self, other: Self) -> Tensor { - check!(TensorCheck::binary_ops_ew("NotEqual", &self, &other)); + pub fn not_equal(self, other: Self) -> Tensor { + // check!(TensorCheck::binary_ops_ew("NotEqual", &self, &other)); K::not_equal(self.primitive, other.primitive) } @@ -765,7 +796,7 @@ where /// /// If all tensors don't have the same shape. pub fn cat(tensors: Vec, dim: usize) -> Self { - check!(TensorCheck::cat(&tensors, dim)); + // check!(TensorCheck::cat(&tensors, dim)); Self::new(K::cat( tensors.into_iter().map(|vector| vector.primitive).collect(), @@ -779,10 +810,13 @@ where /// /// If all tensors don't have the same shape. /// Given dimension is not with range of 0..D2 - pub fn stack(tensors: Vec>, dim: usize) -> Tensor { - check!(TensorCheck::stack(&tensors, dim)); + pub fn stack( + tensors: Vec>, + dim: usize, + ) -> Tensor { + // check!(TensorCheck::stack(&tensors, dim)); let tensors = tensors.into_iter().map(|t| t.unsqueeze_dim(dim)).collect(); - Tensor::::cat(tensors, dim) + Tensor::::cat(tensors, dim) } /// Iterate over slices of tensors alongside a given dimension. @@ -794,9 +828,9 @@ where /// # Returns /// /// A tensor iterator. - pub fn iter_dim(self, dim: usize) -> DimIter { + pub fn iter_dim(self, dim: usize) -> DimIter { check!(TensorCheck::dim_ops::("iter_dim", dim)); - DimIter::new(self, dim) + DimIter::::new(self, dim) } /// Returns a new tensor with the given dimension narrowed to the given range. @@ -811,8 +845,8 @@ where /// A new tensor with the given dimension narrowed to the given range. pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self { check!(TensorCheck::dim_ops::("narrow", dim)); - check!(TensorCheck::narrow(&self, dim, start, length)); - Self::new(narrow::(self.primitive, dim, start, length)) + // check!(TensorCheck::narrow(&self, dim, start, length)); + Self::new(narrow::(self.primitive, dim, start, length)) } /// Attempts to split the tensor along the given dimension into chunks. @@ -829,7 +863,7 @@ where /// A vector of tensors. pub fn chunk(self, chunks: usize, dim: usize) -> Vec { check!(TensorCheck::dim_ops::("chunk", dim)); - chunk::(self.primitive, chunks, dim) + chunk::(self.primitive, chunks, dim) .into_iter() .map(|v| Self::new(v)) .collect() @@ -845,7 +879,7 @@ where /// /// A boolean tensor `Tensor` containing a single element, True if any element in the input tensor /// evaluates to True, False otherwise. - pub fn any(self) -> Tensor { + pub fn any(self) -> Tensor { K::any(self.primitive) } @@ -861,7 +895,7 @@ where /// A boolean tensor `Tensor` with the same size as input `tensor`, except in the `dim` axis /// where the size is 1. 
The elem in the `dim` axis is True if any element along this dim in the input /// evaluates to True, False otherwise. - pub fn any_dim(self, dim: usize) -> Tensor { + pub fn any_dim(self, dim: usize) -> Tensor { K::any_dim(self.primitive, dim) } @@ -875,7 +909,7 @@ where /// /// A boolean tensor `Tensor` with a single element, True if all elements in the input tensor /// evaluate to True, False otherwise. - pub fn all(self) -> Tensor { + pub fn all(self) -> Tensor { K::all(self.primitive) } @@ -891,7 +925,7 @@ where /// A boolean tensor `Tensor` with the same size as input `tensor`, except in the `dim` axis /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input /// evaluates to True, False otherwise. - pub fn all_dim(self, dim: usize) -> Tensor { + pub fn all_dim(self, dim: usize) -> Tensor { K::all_dim(self.primitive, dim) } @@ -935,29 +969,41 @@ where /// # Returns /// /// A new tensor with the given shape. - pub fn expand>(self, shape: S) -> Tensor { + pub fn expand>( + self, + shape: S, + ) -> Tensor { let shape = shape.into_shape(&self.shape()); check!(TensorCheck::expand("expand", &self.shape(), &shape,)); - Tensor::::new(K::expand(self.primitive, shape)) + Tensor::::new(K::expand(self.primitive, shape)) } } /// Iterator given by (Tensor::iter_dim). -pub struct DimIter +pub struct DimIter where B: Backend, - K: BasicOps, + K: BasicOps, + SR: TensorStorage, + Bool: TensorKind, + (B, K, SR): TensorRepr, + (B, Bool, SR): TensorRepr, { start: usize, end: usize, dim: usize, ranges: [Range; D], - tensor: Tensor, + tensor: Tensor, } -impl> Iterator for DimIter { - type Item = Tensor; +impl, SR: TensorStorage> Iterator + for DimIter +where + (B, K, SR): TensorRepr, + (B, Bool, SR): TensorRepr, +{ + type Item = Tensor; fn next(&mut self) -> Option { if self.start >= self.end { @@ -990,8 +1036,12 @@ impl> DoubleEndedIterator for DimIter } } -impl> DimIter { - fn new(tensor: Tensor, dim: usize) -> Self { +impl, SR: TensorStorage> DimIter +where + (B, K, SR): TensorRepr, + (B, Bool, SR): TensorRepr, +{ + fn new(tensor: Tensor, dim: usize) -> Self { let dims = tensor.dims(); let ranges = dims .iter() @@ -1278,7 +1328,11 @@ impl core::ops::BitXor for Tensor { /// # Warnings /// /// This is an internal trait, use the public API provided by [tensor struct](Tensor). -pub trait BasicOps: TensorKind { +pub trait BasicOps = Dense>: TensorKind +where + (B, Self, SR): TensorRepr, + (B, Bool, SR): TensorRepr, +{ /// The type of the tensor elements. type Elem: Element; @@ -1301,7 +1355,7 @@ pub trait BasicOps: TensorKind { /// /// For creating empty tensors, users should prefer the [Tensor::empty](Tensor::empty) function, /// which is more high-level and designed for public use. - fn empty(shape: Shape, device: &B::Device) -> Self::Primitive; + fn empty(shape: Shape, device: &B::Device) -> ReprPrimitive; /// Returns the shape of the tensor. /// @@ -1321,7 +1375,7 @@ pub trait BasicOps: TensorKind { /// /// For getting the shape of a tensor, users should prefer the [Tensor::shape](Tensor::shape) function, /// which is more high-level and designed for public use. - fn shape(tensor: &Self::Primitive) -> Shape; + fn shape(tensor: &ReprPrimitive) -> Shape; /// Reshapes the tensor. /// @@ -1343,9 +1397,9 @@ pub trait BasicOps: TensorKind { /// For reshaping a tensor, users should prefer the [Tensor::reshape](Tensor::reshape) function, /// which is more high-level and designed for public use. 
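+    ///
+    /// Editorial note (not in the original doc): under the new `TensorRepr` scheme the
+    /// primitive below is presumably `ReprPrimitive<B, Self, SR, D>`, so this one
+    /// `reshape` signature serves every storage (dense or sparse) a tensor kind supports.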
fn reshape( - tensor: Self::Primitive, + tensor: ReprPrimitive, shape: Shape, - ) -> Self::Primitive; + ) -> ReprPrimitive; /// Transposes a tensor. /// @@ -1356,7 +1410,9 @@ pub trait BasicOps: TensorKind { /// # Returns /// /// The transposed tensor. - fn transpose(tensor: Self::Primitive) -> Self::Primitive; + fn transpose( + tensor: ReprPrimitive, + ) -> ReprPrimitive; /// Swaps two dimensions of a tensor. /// @@ -1370,10 +1426,10 @@ pub trait BasicOps: TensorKind { /// /// The tensor with the dimensions swapped. fn swap_dims( - tensor: Self::Primitive, + tensor: ReprPrimitive, dim1: usize, dim2: usize, - ) -> Self::Primitive; + ) -> ReprPrimitive; /// Permutes the dimensions of a tensor. /// @@ -1385,7 +1441,10 @@ pub trait BasicOps: TensorKind { /// # Returns /// /// The tensor with the dimensions permuted. - fn permute(tensor: Self::Primitive, axes: [usize; D]) -> Self::Primitive; + fn permute( + tensor: ReprPrimitive, + axes: [usize; D], + ) -> ReprPrimitive; /// Flips the tensor along the given axes. /// @@ -1397,7 +1456,10 @@ pub trait BasicOps: TensorKind { /// # Returns /// /// The tensor with the axes flipped. - fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive; + fn flip( + tensor: ReprPrimitive, + axes: &[usize], + ) -> ReprPrimitive; /// Select tensor elements corresponding for the given ranges. /// @@ -1419,9 +1481,9 @@ pub trait BasicOps: TensorKind { /// For selecting elements of a tensor, users should prefer the [Tensor::slice](Tensor::slice) function, /// which is more high-level and designed for public use. fn slice( - tensor: Self::Primitive, + tensor: ReprPrimitive, range: [Range; D2], - ) -> Self::Primitive; + ) -> ReprPrimitive; /// Assigns the given value to the tensor elements corresponding for the given ranges. /// @@ -1444,10 +1506,10 @@ pub trait BasicOps: TensorKind { /// For assigning values to elements of a tensor, users should prefer the [Tensor::slice_assign](Tensor::slice_assign) function, /// which is more high-level and designed for public use. fn slice_assign( - tensor: Self::Primitive, + tensor: ReprPrimitive, ranges: [Range; D2], - value: Self::Primitive, - ) -> Self::Primitive; + value: ReprPrimitive, + ) -> ReprPrimitive; /// Returns the device on which the tensor is allocated. /// @@ -1467,7 +1529,7 @@ pub trait BasicOps: TensorKind { /// /// For getting the device of a tensor, users should prefer the [Tensor::device](Tensor::device) function, /// which is more high-level and designed for public use. - fn device(tensor: &Self::Primitive) -> B::Device; + fn device(tensor: &ReprPrimitive) -> B::Device; /// Moves the tensor to the given device. /// @@ -1489,9 +1551,9 @@ pub trait BasicOps: TensorKind { /// For moving a tensor to a device, users should prefer the [Tensor::to_device](Tensor::to_device) function, /// which is more high-level and designed for public use. fn to_device( - tensor: Self::Primitive, + tensor: ReprPrimitive, device: &B::Device, - ) -> Self::Primitive; + ) -> ReprPrimitive; /// Extracts the data from the tensor asynchronously. /// @@ -1512,7 +1574,7 @@ pub trait BasicOps: TensorKind { /// For extracting the data of a tensor, users should prefer the [Tensor::into_data](Tensor::into_data) function, /// which is more high-level and designed for public use. fn into_data_async( - tensor: Self::Primitive, + tensor: ReprPrimitive, ) -> impl Future + Send; /// Creates a tensor from the given data. 
@@ -1534,7 +1596,10 @@ pub trait BasicOps: TensorKind { /// /// For creating a tensor from data, users should prefer the [Tensor::from_data](Tensor::from_data) function, /// which is more high-level and designed for public use. - fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive; + fn from_data( + data: TensorData, + device: &B::Device, + ) -> ReprPrimitive; /// Repeat the tensor along the given dimension. /// @@ -1557,10 +1622,10 @@ pub trait BasicOps: TensorKind { /// For repeating a tensor, users should prefer the [Tensor::repeat_dim](Tensor::repeat_dim) function, /// which is more high-level and designed for public use. fn repeat_dim( - tensor: Self::Primitive, + tensor: ReprPrimitive, dim: usize, times: usize, - ) -> Self::Primitive; + ) -> ReprPrimitive; /// Concatenates the given tensors along the given dimension. /// @@ -1581,7 +1646,10 @@ pub trait BasicOps: TensorKind { /// /// For concatenating tensors, users should prefer the [Tensor::cat](Tensor::cat) function, /// which is more high-level and designed for public use. - fn cat(vectors: Vec>, dim: usize) -> Self::Primitive; + fn cat( + vectors: Vec>, + dim: usize, + ) -> ReprPrimitive; /// Equates the given tensors. /// @@ -1603,9 +1671,9 @@ pub trait BasicOps: TensorKind { /// For equating tensors, users should prefer the [Tensor::equal](Tensor::equal) function, /// which is more high-level and designed for public use. fn equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor; + lhs: ReprPrimitive, + rhs: ReprPrimitive, + ) -> Tensor; /// Applies element-wise non-equality comparison between the given tensors. /// @@ -1627,9 +1695,9 @@ pub trait BasicOps: TensorKind { /// For non-equality comparison of tensors, users should prefer the [Tensor::not_equal](Tensor::not_equal) /// function, which is more high-level and designed for public use. fn not_equal( - lhs: Self::Primitive, - rhs: Self::Primitive, - ) -> Tensor; + lhs: ReprPrimitive, + rhs: ReprPrimitive, + ) -> Tensor; /// Returns the name of the element type. fn elem_type_name() -> &'static str { @@ -1652,7 +1720,7 @@ pub trait BasicOps: TensorKind { /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. Users should prefer the [Tensor::any](Tensor::any) function /// which is more high-level and designed for public use. - fn any(tensor: Self::Primitive) -> Tensor; + fn any(tensor: ReprPrimitive) -> Tensor; /// Tests if any element in the tensor evaluates to True along a given dimension dim. /// @@ -1672,7 +1740,10 @@ pub trait BasicOps: TensorKind { /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. Users should prefer the [Tensor::any_dim](Tensor::any_dim) function, /// which is more high-level and designed for public use. - fn any_dim(tensor: Self::Primitive, dim: usize) -> Tensor; + fn any_dim( + tensor: ReprPrimitive, + dim: usize, + ) -> Tensor; /// Tests if all elements in the `tensor` evaluate to True. /// @@ -1690,7 +1761,7 @@ pub trait BasicOps: TensorKind { /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. Users should prefer the [Tensor::all](Tensor::all) function, /// which is more high-level and designed for public use. 
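+    ///
+    /// Editorial note (not in the original doc): the COO storage added by this PR answers
+    /// `all` from metadata alone; see `float_all` in `coo_float.rs`, which returns true
+    /// only when the number of stored coordinates equals `shape.num_elements()`.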
- fn all(tensor: Self::Primitive) -> Tensor; + fn all(tensor: ReprPrimitive) -> Tensor; /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`. /// @@ -1709,7 +1780,10 @@ pub trait BasicOps: TensorKind { /// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. Users should prefer the [Tensor::all_dim](Tensor::all_dim) function, /// which is more high-level and designed for public use. - fn all_dim(tensor: Self::Primitive, dim: usize) -> Tensor; + fn all_dim( + tensor: ReprPrimitive, + dim: usize, + ) -> Tensor; /// Broadcasts the given tensor to the specified shape. /// @@ -1722,9 +1796,9 @@ pub trait BasicOps: TensorKind { /// /// The broadcasted tensor. fn expand( - tensor: Self::Primitive, + tensor: ReprPrimitive, shape: Shape, - ) -> Self::Primitive; + ) -> ReprPrimitive; } impl BasicOps for Float { @@ -2260,27 +2334,39 @@ impl RangesArg for [(i64, i64); D2] { /// Trait used for reshape arguments. pub trait ReshapeArgs { /// Converts to a shape. - fn into_shape>( + fn into_shape, SR: TensorStorage>( self, - tensor: &Tensor, - ) -> Shape; + tensor: &Tensor, + ) -> Shape + where + (B, K, SR): TensorRepr, + (B, Bool, SR): TensorRepr; } impl ReshapeArgs for Shape { - fn into_shape>( + fn into_shape, SR: TensorStorage>( self, - tensor: &Tensor, - ) -> Shape { + tensor: &Tensor, + ) -> Shape + where + (B, K, SR): TensorRepr, + (B, Bool, SR): TensorRepr, + { check!(TensorCheck::reshape_args_usize(&tensor.shape(), &self)); self } } impl ReshapeArgs for [usize; D2] { - fn into_shape>( + fn into_shape, SR: TensorStorage>( self, - tensor: &Tensor, - ) -> Shape { + tensor: &Tensor, + ) -> Shape + where + Bool: TensorKind, + (B, K, SR): TensorRepr, + (B, Bool, SR): TensorRepr, + { let shape = Shape::from(self); check!(TensorCheck::reshape_args_usize(&tensor.shape(), &shape)); @@ -2290,10 +2376,14 @@ impl ReshapeArgs for [usize; D2] { } impl ReshapeArgs for [i32; D2] { - fn into_shape>( + fn into_shape, SR: TensorStorage>( self, - tensor: &Tensor, - ) -> Shape { + tensor: &Tensor, + ) -> Shape + where + (B, K, SR): TensorRepr, + (B, Bool, SR): TensorRepr, + { // Validate the reshape arguments check!(TensorCheck::reshape_args_i32(&self)); diff --git a/crates/burn-tensor/src/tensor/api/check.rs b/crates/burn-tensor/src/tensor/api/check.rs index 754fab4c6c..a23b2c908d 100644 --- a/crates/burn-tensor/src/tensor/api/check.rs +++ b/crates/burn-tensor/src/tensor/api/check.rs @@ -1,4 +1,5 @@ use crate::{backend::Backend, BasicOps, Shape, Tensor}; +use crate::{Dense, Float, Sparse, TensorRepr}; use alloc::format; use alloc::string::{String, ToString}; use alloc::vec; @@ -548,6 +549,41 @@ impl TensorCheck { check } + // pub(crate) fn spmm, const D: usize>( + // lhs: &Tensor>, + // rhs: &Tensor, + // ) -> Self { + // let mut check = Self::Ok; + + // check = check.binary_ops_device("Matmul", &lhs.device(), &rhs.device()); + + // if D < 2 { + // return check; + // } + + // let shape_lhs = lhs.shape(); + // let shape_rhs = rhs.shape(); + + // let dim_lhs = shape_lhs.dims[D - 1]; + // let dim_rhs = shape_rhs.dims[D - 2]; + + // if dim_lhs != dim_rhs { + // check = check.register( + // "Matmul", + // TensorError::new(format!( + // "The inner dimension of matmul should be the same, but got {dim_lhs} and \ + // {dim_rhs}." 
+ // )) + // .details(format!( + // "Lhs shape {:?}, rhs shape {:?}.", + // shape_lhs.dims, shape_rhs.dims + // )), + // ); + // } + + // check + // } + pub(crate) fn stack>( tensors: &[Tensor], dim: usize, diff --git a/crates/burn-tensor/src/tensor/api/chunk.rs b/crates/burn-tensor/src/tensor/api/chunk.rs index e247485c4a..07a266dd8d 100644 --- a/crates/burn-tensor/src/tensor/api/chunk.rs +++ b/crates/burn-tensor/src/tensor/api/chunk.rs @@ -1,5 +1,7 @@ use super::narrow::narrow; -use crate::{backend::Backend, BasicOps, TensorKind}; +use crate::{ + backend::Backend, BasicOps, Bool, Dense, ReprPrimitive, TensorKind, TensorRepr, TensorStorage, +}; use alloc::vec::Vec; /// Split the tensor along the given dimension into chunks. @@ -20,15 +22,19 @@ use alloc::vec::Vec; /// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved /// by static dispatch. It is not designed for direct usage by users, and not recommended to import /// or use this function directly. -pub fn chunk + BasicOps>( - tensor: K::Primitive, +pub fn chunk + BasicOps, SR: TensorStorage>( + tensor: ReprPrimitive, chunks: usize, dim: usize, -) -> Vec> { +) -> Vec> +where + (B, K, SR): TensorRepr, + (B, Bool, SR): TensorRepr, +{ let size = K::shape(&tensor).dims[dim]; if size < chunks { return (0..size) - .map(|i| narrow::(tensor.clone(), dim, i, 1)) + .map(|i| narrow::(tensor.clone(), dim, i, 1)) .collect(); } @@ -37,7 +43,7 @@ pub fn chunk + BasicOps>( if size % chunks == 0 { let chunk_size = size / chunks; for _ in 0..chunks { - tensors.push(narrow::( + tensors.push(narrow::( tensor.clone(), dim, sum_chunk_size, @@ -48,7 +54,7 @@ pub fn chunk + BasicOps>( } else { let chunk_size = (size / chunks) + 1; // assumes not divisible for _ in 0..chunks - 1 { - tensors.push(narrow::( + tensors.push(narrow::( tensor.clone(), dim, sum_chunk_size, @@ -57,7 +63,7 @@ pub fn chunk + BasicOps>( sum_chunk_size += chunk_size; } let remainder = size % chunk_size; - tensors.push(narrow::( + tensors.push(narrow::( tensor.clone(), dim, sum_chunk_size, diff --git a/crates/burn-tensor/src/tensor/api/kind.rs b/crates/burn-tensor/src/tensor/api/kind.rs index 7afe1d2c36..697a1a7ae0 100644 --- a/crates/burn-tensor/src/tensor/api/kind.rs +++ b/crates/burn-tensor/src/tensor/api/kind.rs @@ -1,4 +1,6 @@ use crate::backend::Backend; +use crate::{Dense, Sparse, TensorRepr}; +use core::marker::PhantomData; /// A type-level representation of the kind of a float tensor #[derive(Clone, Debug)] diff --git a/crates/burn-tensor/src/tensor/api/mod.rs b/crates/burn-tensor/src/tensor/api/mod.rs index 60272d80bd..4e4f94bb1b 100644 --- a/crates/burn-tensor/src/tensor/api/mod.rs +++ b/crates/burn-tensor/src/tensor/api/mod.rs @@ -11,7 +11,13 @@ mod int; mod kind; mod narrow; mod numeric; +mod repr; mod sort; +mod sparse; +mod sparse_float; +mod sparse_numeric; +mod sparse_tensor; +mod storage; pub use argwhere::argwhere_data; pub use autodiff::*; @@ -21,4 +27,9 @@ pub use chunk::chunk; pub use kind::*; pub use narrow::narrow; pub use numeric::*; +pub use repr::*; pub use sort::{argsort, sort, sort_with_indices}; +pub use sparse::*; +pub use sparse_numeric::*; +pub use sparse_tensor::*; +pub use storage::*; diff --git a/crates/burn-tensor/src/tensor/api/narrow.rs b/crates/burn-tensor/src/tensor/api/narrow.rs index 88290bd388..be7299e385 100644 --- a/crates/burn-tensor/src/tensor/api/narrow.rs +++ b/crates/burn-tensor/src/tensor/api/narrow.rs @@ -1,4 +1,6 @@ -use crate::{backend::Backend, BasicOps, TensorKind}; +use crate::{ + 
backend::Backend, BasicOps, Bool, Dense, ReprPrimitive, TensorKind, TensorRepr, TensorStorage, +}; use alloc::vec::Vec; /// Returns a new tensor with the given dimension narrowed to the given range. @@ -17,12 +19,21 @@ use alloc::vec::Vec; /// # Returns /// /// A new tensor with the given dimension narrowed to the given range. -pub fn narrow + BasicOps>( - tensor: K::Primitive, +pub fn narrow< + B: Backend, + const D: usize, + K: TensorKind + BasicOps, + SR: TensorStorage, +>( + tensor: ReprPrimitive, dim: usize, start: usize, length: usize, -) -> K::Primitive { +) -> ReprPrimitive +where + (B, K, SR): TensorRepr, + (B, Bool, SR): TensorRepr, +{ let shape = K::shape(&tensor); let ranges: Vec<_> = (0..D) diff --git a/crates/burn-tensor/src/tensor/api/repr.rs b/crates/burn-tensor/src/tensor/api/repr.rs new file mode 100644 index 0000000000..68dfef6b34 --- /dev/null +++ b/crates/burn-tensor/src/tensor/api/repr.rs @@ -0,0 +1,17 @@ +use crate::{ + backend::Backend, Dense, Float, Sparse, SparseStorage, Tensor, TensorKind, TensorStorage, +}; + +pub type ReprPrimitive = <(B, K, S) as TensorRepr>::Primitive; + +pub trait TensorRepr { + type Primitive: Clone + core::fmt::Debug + Send; +} + +impl> TensorRepr for (B, K, Dense) { + type Primitive = K::Primitive; +} + +impl, SR: SparseStorage> TensorRepr for (B, K, Sparse) { + type Primitive = SR::SparsePrimitive; +} diff --git a/crates/burn-tensor/src/tensor/api/sparse.rs b/crates/burn-tensor/src/tensor/api/sparse.rs new file mode 100644 index 0000000000..457c701f80 --- /dev/null +++ b/crates/burn-tensor/src/tensor/api/sparse.rs @@ -0,0 +1,581 @@ +use crate::{ + backend::Backend, check::TensorCheck, BasicOps, Bool, DType, Dense, Device, Element, Float, + Int, ReprPrimitive, Shape, Sparse, SparseStorage, Tensor, TensorData, TensorKind, + TensorPrimitive, TensorRepr, TensorStorage, +}; +use core::{future::Future, ops::Range}; + +use crate::check; + +pub trait BasicSparseOps, SR: SparseStorage> +where + (B, K, Sparse): TensorRepr, +{ + fn into_dense( + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive; + + fn into_sparse( + tensor: ReprPrimitive, + ) -> ReprPrimitive, D>; + + fn coordinates( + tensor: ReprPrimitive, D>, + ) -> Option>; + + fn values( + tensor: ReprPrimitive, D>, + ) -> Option>; +} + +impl> BasicSparseOps for SR +where + (B, Float, Sparse): TensorRepr, +{ + fn into_dense( + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive { + TensorPrimitive::Float(SR::float_to_dense(tensor)) + } + + fn into_sparse( + tensor: ReprPrimitive, + ) -> ReprPrimitive, D> { + SR::float_to_sparse(tensor.tensor()) + } + + fn coordinates( + tensor: ReprPrimitive, D>, + ) -> Option> { + SR::float_coordinates(tensor) + } + + fn values( + tensor: ReprPrimitive, D>, + ) -> Option> { + SR::float_values(tensor) + } +} + +impl> BasicSparseOps for SR +where + (B, Int, Sparse): TensorRepr, +{ + fn into_dense( + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive { + SR::int_to_dense(tensor) + } + + fn into_sparse( + tensor: ReprPrimitive, + ) -> ReprPrimitive, D> { + SR::int_to_sparse(tensor) + } + + fn coordinates( + tensor: ReprPrimitive, D>, + ) -> Option> { + SR::int_coordinates(tensor) + } + + fn values( + tensor: ReprPrimitive, D>, + ) -> Option> { + SR::int_values(tensor) + } +} + +impl> BasicSparseOps for SR +where + (B, Bool, Sparse): TensorRepr, +{ + fn into_dense( + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive { + SR::bool_to_dense(tensor) + } + + fn into_sparse( + tensor: ReprPrimitive, + ) -> ReprPrimitive, D> { + SR::bool_to_sparse(tensor) + } + + fn coordinates( + 
tensor: ReprPrimitive, D>, + ) -> Option> { + SR::bool_coordinates(tensor) + } + + fn values( + tensor: ReprPrimitive, D>, + ) -> Option> { + SR::bool_values(tensor) + } +} + +impl> BasicOps> for Float { + type Elem = B::FloatElem; + + fn empty( + shape: Shape, + device: &::Device, + ) -> SR::SparsePrimitive { + SR::float_empty(shape, device) + } + + fn shape(tensor: &ReprPrimitive, D>) -> Shape { + SR::float_shape(tensor) + } + + fn reshape( + tensor: ReprPrimitive, D1>, + shape: Shape, + ) -> ReprPrimitive, D2> { + SR::float_reshape(tensor, shape) + } + + fn transpose( + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive, D> { + SR::float_transpose(tensor) + } + + fn swap_dims( + tensor: ReprPrimitive, D>, + dim1: usize, + dim2: usize, + ) -> ReprPrimitive, D> { + SR::float_swap_dims(tensor, dim1, dim2) + } + + fn permute( + tensor: ReprPrimitive, D>, + axes: [usize; D], + ) -> ReprPrimitive, D> { + SR::float_permute(tensor, &axes) + } + + fn flip( + tensor: ReprPrimitive, D>, + axes: &[usize], + ) -> ReprPrimitive, D> { + SR::float_flip(tensor, axes) + } + + fn slice( + tensor: ReprPrimitive, D1>, + range: [Range; D2], + ) -> ReprPrimitive, D1> { + SR::float_slice(tensor, range) + } + + fn slice_assign( + tensor: ReprPrimitive, D1>, + ranges: [Range; D2], + value: ReprPrimitive, D1>, + ) -> ReprPrimitive, D1> { + SR::float_slice_assign(tensor, ranges, value) + } + + fn device( + tensor: &ReprPrimitive, D>, + ) -> ::Device { + SR::float_device(tensor) + } + + fn to_device( + tensor: ReprPrimitive, D>, + device: &::Device, + ) -> ReprPrimitive, D> { + SR::float_to_device(tensor, device) + } + + fn into_data_async( + tensor: ReprPrimitive, D>, + ) -> impl Future + Send { + async { + panic!("into_data not supported for sparse tensors, convert to dense first."); + } + } + + fn from_data( + data: TensorData, + device: &::Device, + ) -> ReprPrimitive, D> { + panic!("from_data not supported for sparse tensors, convert from dense."); + } + + fn repeat_dim( + tensor: ReprPrimitive, D>, + dim: usize, + times: usize, + ) -> ReprPrimitive, D> { + SR::float_repeat_dim(tensor, dim, times) + } + + fn cat( + vectors: Vec, D>>, + dim: usize, + ) -> ReprPrimitive, D> { + SR::float_cat(vectors, dim) + } + + fn expand( + tensor: ReprPrimitive, D1>, + shape: Shape, + ) -> ReprPrimitive, D2> { + SR::float_expand(tensor, shape) + } + + fn equal( + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> Tensor> { + panic!("equal is unsupported for sparse tensors as it is not zero-preserving"); + } + + fn not_equal( + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> Tensor> { + panic!("not_equal is unsupported for sparse tensors as it is not zero-preserving"); + } + + fn any( + tensor: ReprPrimitive, D>, + ) -> Tensor> { + Tensor::new(SR::float_any(tensor)) + } + + fn any_dim( + tensor: ReprPrimitive, D>, + dim: usize, + ) -> Tensor> { + Tensor::new(SR::float_any_dim(tensor, dim)) + } + + fn all( + tensor: ReprPrimitive, D>, + ) -> Tensor> { + Tensor::new(SR::float_all(tensor)) + } + + fn all_dim( + tensor: ReprPrimitive, D>, + dim: usize, + ) -> Tensor> { + Tensor::new(SR::float_all_dim(tensor, dim)) + } +} + +impl> BasicOps> for Bool { + type Elem = bool; + + fn empty( + shape: Shape, + device: &::Device, + ) -> ReprPrimitive, D> { + SR::bool_empty(shape, device) + } + + fn shape(tensor: &ReprPrimitive, D>) -> Shape { + SR::bool_shape(tensor) + } + + fn reshape( + tensor: ReprPrimitive, D1>, + shape: Shape, + ) -> ReprPrimitive, D2> { + SR::bool_reshape(tensor, shape) + } + + fn transpose( + tensor:
ReprPrimitive, D>, + ) -> ReprPrimitive, D> { + SR::bool_transpose(tensor) + } + + fn swap_dims( + tensor: ReprPrimitive, D>, + dim1: usize, + dim2: usize, + ) -> ReprPrimitive, D> { + SR::bool_swap_dims(tensor, dim1, dim2) + } + + fn permute( + tensor: ReprPrimitive, D>, + axes: [usize; D], + ) -> ReprPrimitive, D> { + SR::bool_permute(tensor, &axes) + } + + fn flip( + tensor: ReprPrimitive, D>, + axes: &[usize], + ) -> ReprPrimitive, D> { + SR::bool_flip(tensor, axes) + } + + fn slice( + tensor: ReprPrimitive, D1>, + range: [Range; D2], + ) -> ReprPrimitive, D1> { + SR::bool_slice(tensor, range) + } + + fn slice_assign( + tensor: ReprPrimitive, D1>, + ranges: [Range; D2], + value: ReprPrimitive, D1>, + ) -> ReprPrimitive, D1> { + SR::bool_slice_assign(tensor, ranges, value) + } + + fn device( + tensor: &ReprPrimitive, D>, + ) -> ::Device { + SR::bool_device(tensor) + } + + fn to_device( + tensor: ReprPrimitive, D>, + device: &::Device, + ) -> ReprPrimitive, D> { + SR::bool_to_device(tensor, device) + } + + fn into_data_async( + tensor: ReprPrimitive, D>, + ) -> impl Future + Send { + async { + panic!("into_data not supported for sparse tensors, convert to dense first."); + } + } + + fn from_data( + data: TensorData, + device: &::Device, + ) -> ReprPrimitive, D> { + panic!("from_data not supported for sparse tensors, convert from dense."); + } + + fn repeat_dim( + tensor: ReprPrimitive, D>, + dim: usize, + times: usize, + ) -> ReprPrimitive, D> { + SR::bool_repeat_dim(tensor, dim, times) + } + + fn cat( + vectors: Vec, D>>, + dim: usize, + ) -> ReprPrimitive, D> { + SR::bool_cat(vectors, dim) + } + + fn equal( + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> Tensor> { + panic!("equal is unsupported for sparse tensors as it is not zero-preserving"); + } + + fn not_equal( + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> Tensor> { + panic!("not_equal is unsupported for sparse tensors as it is not zero-preserving"); + } + + fn any( + tensor: ReprPrimitive, D>, + ) -> Tensor> { + Tensor::new(SR::bool_any(tensor)) + } + + fn any_dim( + tensor: ReprPrimitive, D>, + dim: usize, + ) -> Tensor> { + Tensor::new(SR::bool_any_dim(tensor, dim)) + } + + fn all( + tensor: ReprPrimitive, D>, + ) -> Tensor> { + Tensor::new(SR::bool_all(tensor)) + } + + fn all_dim( + tensor: ReprPrimitive, D>, + dim: usize, + ) -> Tensor> { + Tensor::new(SR::bool_all_dim(tensor, dim)) + } + + fn expand( + tensor: ReprPrimitive, D1>, + shape: Shape, + ) -> ReprPrimitive, D2> { + SR::bool_expand(tensor, shape) + } +} + +impl> BasicOps> for Int { + type Elem = i32; + + fn empty( + shape: Shape, + device: &::Device, + ) -> ReprPrimitive, D> { + SR::int_empty(shape, device) + } + + fn shape(tensor: &ReprPrimitive, D>) -> Shape { + SR::int_shape(tensor) + } + + fn reshape( + tensor: ReprPrimitive, D1>, + shape: Shape, + ) -> ReprPrimitive, D2> { + SR::int_reshape(tensor, shape) + } + + fn transpose( + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive, D> { + SR::int_transpose(tensor) + } + + fn swap_dims( + tensor: ReprPrimitive, D>, + dim1: usize, + dim2: usize, + ) -> ReprPrimitive, D> { + SR::int_swap_dims(tensor, dim1, dim2) + } + + fn permute( + tensor: ReprPrimitive, D>, + axes: [usize; D], + ) -> ReprPrimitive, D> { + SR::int_permute(tensor, &axes) + } + + fn flip( + tensor: ReprPrimitive, D>, + axes: &[usize], + ) -> ReprPrimitive, D> { + SR::int_flip(tensor, axes) + } + + fn slice( + tensor: ReprPrimitive, D1>, + range: [Range; D2], + ) -> ReprPrimitive, D1> { + SR::int_slice(tensor, range) + } + + fn
slice_assign( + tensor: ReprPrimitive, D1>, + ranges: [Range; D2], + value: ReprPrimitive, D1>, + ) -> ReprPrimitive, D1> { + SR::int_slice_assign(tensor, ranges, value) + } + + fn device( + tensor: &ReprPrimitive, D>, + ) -> ::Device { + SR::int_device(tensor) + } + + fn to_device( + tensor: ReprPrimitive, D>, + device: &::Device, + ) -> ReprPrimitive, D> { + SR::int_to_device(tensor, device) + } + + fn into_data_async( + tensor: ReprPrimitive, D>, + ) -> impl Future + Send { + async { + panic!("into_data not supported for sparse tensors, convert to dense first."); + } + } + + fn from_data( + data: TensorData, + device: &::Device, + ) -> ReprPrimitive, D> { + panic!("from_data not supported for sparse tensors, convert from dense."); + } + + fn repeat_dim( + tensor: ReprPrimitive, D>, + dim: usize, + times: usize, + ) -> ReprPrimitive, D> { + SR::int_repeat_dim(tensor, dim, times) + } + + fn cat( + vectors: Vec, D>>, + dim: usize, + ) -> ReprPrimitive, D> { + SR::int_cat(vectors, dim) + } + + fn equal( + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> Tensor> { + panic!("equal is unsupported for sparse tensors as it is not zero-preserving"); + } + + fn not_equal( + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> Tensor> { + panic!("not_equal is unsupported for sparse tensors as it is not zero-preserving"); + } + + fn any( + tensor: ReprPrimitive, D>, + ) -> Tensor> { + Tensor::new(SR::int_any(tensor)) + } + + fn any_dim( + tensor: ReprPrimitive, D>, + dim: usize, + ) -> Tensor> { + Tensor::new(SR::int_any_dim(tensor, dim)) + } + + fn all( + tensor: ReprPrimitive, D>, + ) -> Tensor> { + Tensor::new(SR::int_all(tensor)) + } + + fn all_dim( + tensor: ReprPrimitive, D>, + dim: usize, + ) -> Tensor> { + Tensor::new(SR::int_all_dim(tensor, dim)) + } + + fn expand( + tensor: ReprPrimitive, D1>, + shape: Shape, + ) -> ReprPrimitive, D2> { + SR::int_expand(tensor, shape) + } +} diff --git a/crates/burn-tensor/src/tensor/api/sparse_float.rs b/crates/burn-tensor/src/tensor/api/sparse_float.rs new file mode 100644 index 0000000000..ee35c5e398 --- /dev/null +++ b/crates/burn-tensor/src/tensor/api/sparse_float.rs @@ -0,0 +1,50 @@ +use crate::{backend::Backend, check::TensorCheck, Dense, Float, Sparse, Tensor, TensorKind}; +use crate::{check, Bool, SparseStorage, TensorPrimitive, TensorRepr}; + +impl Tensor> +where + B: Backend, + SR: SparseStorage, + (B, Float, Sparse): TensorRepr, +{ + /// Executes an operation on the tensor and modifies its value. + /// + /// # Notes + /// + /// This won't necessarily reuse the same tensor data/buffer, but it should if there is + /// no other reference pointing to the same tensor. + /// + /// Wrapping operations with inplace is not an optimization; it's mainly there if you + /// want to mutate a tensor by using owned operations. A plausible usage would be to + /// update the weights of a mutable model reference. + pub fn inplace Self>(&mut self, func: F) { + let mut tensor_owned = Tensor::empty([0; D], &self.device()); + core::mem::swap(&mut tensor_owned, self); + + let mut tensor_new = func(tensor_owned); + core::mem::swap(&mut tensor_new, self); + } + + /// Applies the matrix multiplication operation. + /// + /// `C = AB` + /// + /// # Panics + /// + /// If the two tensors don't have a compatible shape.
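+ /// # Example
+ ///
+ /// A hedged sketch (types and turbofish assumed):
+ ///
+ /// ```ignore
+ /// // `a` is a sparse float tensor, `b` a dense one with a matching inner dimension:
+ /// let c = a.spmm(b); // computes C = AB, returning a dense tensor
+ /// ```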
+ pub fn spmm(self, rhs: Tensor) -> Tensor { + // check!(TensorCheck::spmm(&self, &rhs)); + Tensor::::new(TensorPrimitive::Float(SR::float_spmm( + self.into_primitive(), + rhs.into_primitive(), + ))) + } + + pub fn sddmm(self, lhs: Tensor, rhs: Tensor) -> Self { + Tensor::new(SR::float_sddmm( + lhs.into_primitive().tensor(), + rhs.into_primitive().tensor(), + self.into_primitive(), + )) + } +} diff --git a/crates/burn-tensor/src/tensor/api/sparse_int.rs b/crates/burn-tensor/src/tensor/api/sparse_int.rs new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/crates/burn-tensor/src/tensor/api/sparse_int.rs @@ -0,0 +1 @@ + diff --git a/crates/burn-tensor/src/tensor/api/sparse_numeric.rs b/crates/burn-tensor/src/tensor/api/sparse_numeric.rs new file mode 100644 index 0000000000..100cd6b38a --- /dev/null +++ b/crates/burn-tensor/src/tensor/api/sparse_numeric.rs @@ -0,0 +1,30 @@ +use crate::check; + +use crate::{ + backend::Backend, check::TensorCheck, BasicOps, Bool, Element, ElementConversion, Int, Shape, + Sparse, SparseStorage, Tensor, TensorKind, TensorRepr, +}; + +/// Trait that list all operations that can be applied on all sparse numerical tensors. +/// +/// # Warnings +/// +/// This is an internal trait, use the public API provided by [tensor struct](Tensor). +pub trait SparseNumeric: TensorRepr +where + B: Backend, + K: TensorKind + BasicOps, + SR: SparseStorage, + K::Elem: Element, +{ +} + +impl Tensor> +where + B: Backend, + K: TensorKind + BasicOps, + SR: SparseStorage, + (B, K, SR): SparseNumeric, + K::Elem: Element, +{ +} diff --git a/crates/burn-tensor/src/tensor/api/sparse_tensor.rs b/crates/burn-tensor/src/tensor/api/sparse_tensor.rs new file mode 100644 index 0000000000..bf690f3e94 --- /dev/null +++ b/crates/burn-tensor/src/tensor/api/sparse_tensor.rs @@ -0,0 +1,44 @@ +use crate::{backend::Backend, check::TensorCheck, Dense, Float, Sparse, Tensor, TensorKind}; +use crate::{ + check, BasicOps, BasicSparseOps, Bool, Int, SparseStorage, TensorPrimitive, TensorRepr, +}; + +impl Tensor +where + B: Backend, + K: TensorKind, +{ + pub fn into_sparse + BasicSparseOps>( + self, + ) -> Tensor> + where + K: BasicOps>, + (B, K, Sparse): TensorRepr, + { + Tensor::>::from_primitive(SR::into_sparse(self.primitive)) + } +} + +impl Tensor> +where + B: Backend, + K: TensorKind + BasicOps> + BasicOps, + SR: SparseStorage + BasicSparseOps, + (B, K, Sparse): TensorRepr, +{ + pub fn into_dense(self) -> Tensor { + Tensor::::from_primitive(SR::into_dense(self.into_primitive())) + } + + pub fn coordinates(self) -> Option> { + Some(Tensor::::from_primitive(SR::coordinates( + self.into_primitive(), + )?)) + } + + pub fn values(self) -> Option> { + Some(Tensor::::from_primitive(SR::values( + self.into_primitive(), + )?)) + } +} diff --git a/crates/burn-tensor/src/tensor/api/storage.rs b/crates/burn-tensor/src/tensor/api/storage.rs new file mode 100644 index 0000000000..ec54cfc434 --- /dev/null +++ b/crates/burn-tensor/src/tensor/api/storage.rs @@ -0,0 +1,30 @@ +use crate::{backend::Backend, ops::SparseTensorOps, Bool, Float, Int, Tensor, TensorKind}; +use core::marker::PhantomData; + +pub trait TensorStorage: Clone + core::fmt::Debug { + fn name() -> &'static str; +} + +pub trait SparseStorage: Clone + core::fmt::Debug + SparseTensorOps { + type SparsePrimitive, const D: usize>: Clone + core::fmt::Debug + Send; + + fn name() -> &'static str; +} + +#[derive(Clone, Debug)] +pub struct Dense; + +#[derive(Clone, Debug)] +pub struct Sparse>(PhantomData<(B, SR)>); + +impl TensorStorage for Dense { 
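+ // Marker impl: Dense carries no runtime data; only a display name is provided.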
+ fn name() -> &'static str { + "Dense" + } +} + +impl> TensorStorage for Sparse { + fn name() -> &'static str { + SR::name() + } +} diff --git a/crates/burn-tensor/src/tensor/ops/bool_tensor.rs b/crates/burn-tensor/src/tensor/ops/bool_tensor.rs index b1718ad5c0..81560696f0 100644 --- a/crates/burn-tensor/src/tensor/ops/bool_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/bool_tensor.rs @@ -3,8 +3,8 @@ use super::{ FloatTensor, IntTensor, }; use crate::{ - argwhere_data, backend::Backend, chunk, narrow, tensor::Shape, Bool, ElementConversion, Tensor, - TensorData, + argwhere_data, backend::Backend, chunk, narrow, tensor::Shape, Bool, Dense, ElementConversion, + Tensor, TensorData, }; use alloc::vec::Vec; use core::{future::Future, ops::Range}; @@ -306,7 +306,7 @@ pub trait BoolTensorOps { start: usize, length: usize, ) -> BoolTensor { - narrow::(tensor, dim, start, length) + narrow::(tensor, dim, start, length) } /// Split the tensor along the given dimension into chunks. @@ -325,7 +325,7 @@ pub trait BoolTensorOps { chunks: usize, dim: usize, ) -> Vec> { - chunk::(tensor, chunks, dim) + chunk::(tensor, chunks, dim) } /// Tests if any element in the boolean `tensor` evaluates to True. diff --git a/crates/burn-tensor/src/tensor/ops/int_tensor.rs b/crates/burn-tensor/src/tensor/ops/int_tensor.rs index 28a2eb803b..2aa4d21dad 100644 --- a/crates/burn-tensor/src/tensor/ops/int_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/int_tensor.rs @@ -3,7 +3,7 @@ use super::repeat_dim::repeat_with_slice_assign; use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor}; use crate::cast::ToElement; use crate::{backend::Backend, tensor::Shape, Distribution, ElementConversion, Int, TensorData}; -use crate::{cartesian_grid, Tensor}; +use crate::{cartesian_grid, Dense, Tensor}; use crate::{tensor::api::chunk, tensor::api::narrow}; use alloc::vec::Vec; use core::future::Future; @@ -1011,7 +1011,7 @@ pub trait IntTensorOps { start: usize, length: usize, ) -> IntTensor { - narrow::(tensor, dim, start, length) + narrow::(tensor, dim, start, length) } /// Generates a cartesian grid for the given tensor shape on the specified device. @@ -1060,7 +1060,7 @@ pub trait IntTensorOps { chunks: usize, dim: usize, ) -> Vec> { - chunk::(tensor, chunks, dim) + chunk::(tensor, chunks, dim) } /// Creates a new int tensor with random values. 
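The two hunks above follow one pattern: backend-provided default implementations pin the new storage parameter to `Dense`, so existing backends keep compiling unchanged while `narrow`/`chunk` themselves become storage-generic. A minimal sketch of that shape (re-export paths and the exact turbofish are assumptions):

```rust
use burn_tensor::backend::Backend;
use burn_tensor::ops::IntTensor;
use burn_tensor::{narrow, Dense, Int};

// Hedged sketch: how a backend's default int_narrow delegates to the
// storage-generic helper while pinning the storage type to Dense.
fn int_narrow_default<B: Backend, const D: usize>(
    tensor: IntTensor<B, D>,
    dim: usize,
    start: usize,
    length: usize,
) -> IntTensor<B, D> {
    narrow::<B, D, Int, Dense>(tensor, dim, start, length)
}
```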
diff --git a/crates/burn-tensor/src/tensor/ops/mod.rs b/crates/burn-tensor/src/tensor/ops/mod.rs index 1cce562586..af59004a81 100644 --- a/crates/burn-tensor/src/tensor/ops/mod.rs +++ b/crates/burn-tensor/src/tensor/ops/mod.rs @@ -4,6 +4,7 @@ mod bool_tensor; mod int_tensor; mod modules; mod qtensor; +mod sparse_tensor; mod tensor; pub use activation::*; @@ -12,4 +13,5 @@ pub use bool_tensor::*; pub use int_tensor::*; pub use modules::*; pub use qtensor::*; +pub use sparse_tensor::*; pub use tensor::*; diff --git a/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs new file mode 100644 index 0000000000..ec835f5e54 --- /dev/null +++ b/crates/burn-tensor/src/tensor/ops/sparse_tensor.rs @@ -0,0 +1,595 @@ +use super::{BoolTensor, FloatElem, FloatTensor, IntTensor, QuantizedTensor}; +use crate::{ + backend::Backend, Bool, Device, Float, Int, ReprPrimitive, Shape, Sparse, SparseStorage, + TensorData, TensorKind, +}; +use crate::{Dense, TensorRepr}; +use core::{future::Future, ops::Range}; + +pub trait SparseTensorOps, B: Backend>: + SparseFloatOps + SparseBoolOps + SparseIntOps +{ +} + +pub trait SparseFloatOps, B: Backend> +where + (B, Float, Sparse): TensorRepr, + (B, Bool, Sparse): TensorRepr, + (B, Int, Sparse): TensorRepr, +{ + fn float_values( + tensor: ReprPrimitive, D>, + ) -> Option>; + + fn float_coordinates( + sparse: ReprPrimitive, D>, + ) -> Option>; + + fn float_to_sparse( + dense: B::FloatTensorPrimitive, + ) -> ReprPrimitive, D>; + + fn float_empty( + shape: Shape, + device: &Device, + ) -> ReprPrimitive, D>; + + fn float_to_dense( + sparse: ReprPrimitive, D>, + ) -> B::FloatTensorPrimitive; + + fn float_spmm( + lhs: ReprPrimitive, D>, + rhs: >::Primitive, + ) -> B::FloatTensorPrimitive; + + fn float_sddmm( + lhs: B::FloatTensorPrimitive, + rhs: B::FloatTensorPrimitive, + sparse: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; + + fn float_coalesce_sum( + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; + + fn float_remove_zeros( + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; + + fn float_number_nonzero( + tensor: ReprPrimitive, D>, + ) -> usize; + + fn float_density(sparse: ReprPrimitive, D>) -> f32; + + /// Gets the elements at the given indices. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `indices` - The indices. + /// + /// # Returns + /// + /// The elements at the given indices. + fn float_slice( + tensor: SR::SparsePrimitive, + indices: [Range; D2], + ) -> SR::SparsePrimitive; + + /// Gets the device of the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The device of the tensor. + fn float_device( + tensor: &ReprPrimitive, D>, + ) -> Device; + + /// Moves the tensor to the given device. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// * `device` - The device to move the tensor to. + /// + /// # Returns + /// + /// The tensor on the given device. + fn float_to_device( + tensor: ReprPrimitive, D>, + device: &Device, + ) -> ReprPrimitive, D>; + + /// Gets the shape of the tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor. + /// + /// # Returns + /// + /// The shape of the tensor.
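+ /// # Example
+ ///
+ /// Minimal sketch (COO-style storage `SR` assumed):
+ ///
+ /// ```ignore
+ /// // A 512x512 sparse tensor still reports its full logical shape:
+ /// assert_eq!(SR::float_shape(&sparse).dims, [512, 512]);
+ /// ```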
+ fn float_shape(tensor: &ReprPrimitive, D>) -> Shape; + + fn float_reshape( + tensor: SR::SparsePrimitive, + shape: Shape, + ) -> SR::SparsePrimitive; + + fn float_transpose( + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; + + fn float_swap_dims( + tensor: ReprPrimitive, D>, + dim1: usize, + dim2: usize, + ) -> ReprPrimitive, D>; + + fn float_permute( + tensor: ReprPrimitive, D>, + axes: &[usize], + ) -> ReprPrimitive, D>; + + fn float_flip( + tensor: ReprPrimitive, D>, + axes: &[usize], + ) -> ReprPrimitive, D>; + + fn float_slice_assign( + tensor: SR::SparsePrimitive, + ranges: [Range; D2], + value: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; + + fn float_repeat_dim( + tensor: ReprPrimitive, D>, + dim: usize, + times: usize, + ) -> ReprPrimitive, D>; + + fn float_cat( + tensors: Vec, D>>, + dim: usize, + ) -> ReprPrimitive, D>; + + fn float_any( + tensor: ReprPrimitive, D>, + ) -> SR::SparsePrimitive; + + fn float_any_dim( + tensor: ReprPrimitive, D>, + dim: usize, + ) -> SR::SparsePrimitive; + + fn float_all( + tensor: ReprPrimitive, D>, + ) -> SR::SparsePrimitive; + + fn float_all_dim( + tensor: ReprPrimitive, D>, + dim: usize, + ) -> SR::SparsePrimitive; + + fn float_expand( + tensor: SR::SparsePrimitive, + shape: Shape, + ) -> SR::SparsePrimitive; + + /// Adds two sparse tensors together. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of adding the two tensors together. + fn float_add( + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; + + /// Subtracts two tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of subtracting the two tensors. + fn float_sub( + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; + + /// Multiplies two sparse tensors together. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of multiplying the two tensors together. + fn float_mul( + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; + + /// Multiplies a tensor by a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of multiplying the tensor by the scalar. + fn float_mul_scalar( + lhs: ReprPrimitive, D>, + rhs: FloatElem, + ) -> ReprPrimitive, D>; + + /// Divides two sparse tensors. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side tensor. + /// + /// # Returns + /// + /// The result of dividing the two tensors. + fn float_div( + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; + + /// Divides a tensor by a scalar. + /// + /// # Arguments + /// + /// * `lhs` - The left hand side tensor. + /// * `rhs` - The right hand side scalar. + /// + /// # Returns + /// + /// The result of dividing the tensor by the scalar.
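+ /// # Example
+ ///
+ /// Minimal sketch; the `2.0.elem()` conversion via `ElementConversion` is assumed:
+ ///
+ /// ```ignore
+ /// // Zero-preserving: only the stored values change, the sparsity
+ /// // pattern (coordinates) stays the same.
+ /// let halved = SR::float_div_scalar(tensor, 2.0.elem());
+ /// ```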
+ fn float_div_scalar( + lhs: ReprPrimitive, D>, + rhs: FloatElem, + ) -> ReprPrimitive, D>; + + fn float_max( + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; + + fn float_max_dim( + tensor: ReprPrimitive, D>, + dim: usize, + ) -> ReprPrimitive, D>; + + fn float_min( + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; + + fn float_min_dim( + tensor: ReprPrimitive, D>, + dim: usize, + ) -> ReprPrimitive, D>; + + fn float_abs( + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; + fn float_sign( + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; + + fn float_powf( + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; + + fn float_powi( + lhs: ReprPrimitive, D>, + rhs: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; + + fn float_powf_scalar( + lhs: ReprPrimitive, D>, + rhs: FloatElem, + ) -> ReprPrimitive, D>; + + fn float_powi_scalar( + lhs: ReprPrimitive, D>, + rhs: FloatElem, + ) -> ReprPrimitive, D>; + + fn float_clamp( + tensor: ReprPrimitive, D>, + min: FloatElem, + max: FloatElem, + ) -> ReprPrimitive, D>; + + fn float_clamp_min( + tensor: ReprPrimitive, D>, + min: FloatElem, + ) -> ReprPrimitive, D>; + + fn float_clamp_max( + tensor: ReprPrimitive, D>, + max: FloatElem, + ) -> ReprPrimitive, D>; + + fn float_select( + tensor: ReprPrimitive, D>, + dim: usize, + indices: IntTensor, + ) -> ReprPrimitive, D>; + + fn float_select_assign( + tensor: ReprPrimitive, D>, + dim: usize, + indices: IntTensor, + values: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; + + fn float_gather( + dim: usize, + tensor: ReprPrimitive, D>, + indices: IntTensor, + ) -> ReprPrimitive, D>; + + fn float_scatter( + dim: usize, + tensor: ReprPrimitive, D>, + indices: IntTensor, + values: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; + + fn float_sum( + tensor: ReprPrimitive, D>, + ) -> SR::SparsePrimitive; + + fn float_sum_dim( + tensor: ReprPrimitive, D>, + dim: usize, + ) -> ReprPrimitive, D>; + + fn float_prod_dim( + tensor: ReprPrimitive, D>, + dim: usize, + ) -> ReprPrimitive, D>; + + fn float_mean( + tensor: ReprPrimitive, D>, + ) -> SR::SparsePrimitive; + + fn float_mean_dim( + tensor: ReprPrimitive, D>, + dim: usize, + ) -> ReprPrimitive, D>; + + fn float_remainder_scalar( + lhs: ReprPrimitive, D>, + rhs: FloatElem, + ) -> ReprPrimitive, D>; + + fn float_neg( + tensor: ReprPrimitive, D>, + ) -> ReprPrimitive, D>; +} + +pub trait SparseBoolOps, B: Backend> { + fn bool_values( + tensor: ReprPrimitive, D>, + ) -> Option>; + + fn bool_coordinates( + sparse: ReprPrimitive, D>, + ) -> Option>; + + fn bool_to_dense( + sparse: ReprPrimitive, D>, + ) -> B::BoolTensorPrimitive; + + fn bool_to_sparse( + dense: B::BoolTensorPrimitive, + ) -> ReprPrimitive, D>; + + fn bool_empty( + shape: Shape, + device: &Device, + ) -> SR::SparsePrimitive; + + fn bool_shape(tensor: &SR::SparsePrimitive) -> Shape; + + fn bool_reshape( + tensor: SR::SparsePrimitive, + shape: Shape, + ) -> SR::SparsePrimitive; + + fn bool_transpose( + tensor: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; + + fn bool_swap_dims( + tensor: SR::SparsePrimitive, + dim1: usize, + dim2: usize, + ) -> SR::SparsePrimitive; + + fn bool_permute( + tensor: SR::SparsePrimitive, + axes: &[usize], + ) -> SR::SparsePrimitive; + + fn bool_flip( + tensor: SR::SparsePrimitive, + axes: &[usize], + ) -> SR::SparsePrimitive; + + fn bool_slice( + tensor: SR::SparsePrimitive, + indices: [Range; D2], + ) -> SR::SparsePrimitive; + + fn bool_slice_assign( + tensor: SR::SparsePrimitive, + ranges: [Range; D2], + value: SR::SparsePrimitive, + ) -> 
SR::SparsePrimitive; + + fn bool_device(tensor: &SR::SparsePrimitive) -> Device; + + fn bool_to_device( + tensor: SR::SparsePrimitive, + device: &Device, + ) -> SR::SparsePrimitive; + + fn bool_repeat_dim( + tensor: SR::SparsePrimitive, + dim: usize, + times: usize, + ) -> SR::SparsePrimitive; + + fn bool_cat( + tensors: Vec>, + dim: usize, + ) -> SR::SparsePrimitive; + + fn bool_any( + tensor: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; + + fn bool_any_dim( + tensor: SR::SparsePrimitive, + dim: usize, + ) -> SR::SparsePrimitive; + + fn bool_all( + tensor: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; + + fn bool_all_dim( + tensor: SR::SparsePrimitive, + dim: usize, + ) -> SR::SparsePrimitive; + + fn bool_expand( + tensor: SR::SparsePrimitive, + shape: Shape, + ) -> SR::SparsePrimitive; +} + +pub trait SparseIntOps, B: Backend> { + fn int_values( + tensor: ReprPrimitive, D>, + ) -> Option>; + + fn int_coordinates( + sparse: ReprPrimitive, D>, + ) -> Option>; + + fn int_to_dense( + sparse: ReprPrimitive, D>, + ) -> B::IntTensorPrimitive; + + fn int_to_sparse( + dense: B::IntTensorPrimitive, + ) -> ReprPrimitive, D>; + + fn int_empty( + shape: Shape, + device: &Device, + ) -> SR::SparsePrimitive; + + fn int_shape(tensor: &SR::SparsePrimitive) -> Shape; + + fn int_reshape( + tensor: SR::SparsePrimitive, + shape: Shape, + ) -> SR::SparsePrimitive; + + fn int_transpose( + tensor: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; + + fn int_swap_dims( + tensor: SR::SparsePrimitive, + dim1: usize, + dim2: usize, + ) -> SR::SparsePrimitive; + + fn int_permute( + tensor: SR::SparsePrimitive, + axes: &[usize], + ) -> SR::SparsePrimitive; + + fn int_flip( + tensor: SR::SparsePrimitive, + axes: &[usize], + ) -> SR::SparsePrimitive; + + fn int_slice( + tensor: SR::SparsePrimitive, + indices: [Range; D2], + ) -> SR::SparsePrimitive; + + fn int_slice_assign( + tensor: SR::SparsePrimitive, + ranges: [Range; D2], + value: SR::SparsePrimitive, + ) -> SR::SparsePrimitive; + + fn int_device(tensor: &SR::SparsePrimitive) -> Device; + + fn int_to_device( + tensor: SR::SparsePrimitive, + device: &Device, + ) -> SR::SparsePrimitive; + + fn int_repeat_dim( + tensor: SR::SparsePrimitive, + dim: usize, + times: usize, + ) -> SR::SparsePrimitive; + + fn int_cat( + tensors: Vec>, + dim: usize, + ) -> SR::SparsePrimitive; + + fn int_any(tensor: SR::SparsePrimitive) + -> SR::SparsePrimitive; + + fn int_any_dim( + tensor: SR::SparsePrimitive, + dim: usize, + ) -> SR::SparsePrimitive; + + fn int_all(tensor: SR::SparsePrimitive) + -> SR::SparsePrimitive; + + fn int_all_dim( + tensor: SR::SparsePrimitive, + dim: usize, + ) -> SR::SparsePrimitive; + + fn int_expand( + tensor: SR::SparsePrimitive, + shape: Shape, + ) -> SR::SparsePrimitive; +} diff --git a/crates/burn-tensor/src/tensor/ops/tensor.rs b/crates/burn-tensor/src/tensor/ops/tensor.rs index 0edd5c8ee4..e906760743 100644 --- a/crates/burn-tensor/src/tensor/ops/tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/tensor.rs @@ -5,7 +5,7 @@ use crate::backend::BackendBridge; use crate::tensor::cast::ToElement; use crate::{backend::Backend, tensor::Shape, Distribution, ElementConversion, Float, TensorData}; use crate::{tensor::api::chunk, tensor::api::narrow}; -use crate::{Tensor, TensorPrimitive}; +use crate::{Dense, Tensor, TensorPrimitive}; use alloc::vec::Vec; use core::future::Future; use core::ops::Range; @@ -1251,7 +1251,7 @@ pub trait FloatTensorOps { start: usize, length: usize, ) -> FloatTensor { - narrow::(TensorPrimitive::Float(tensor), dim, start, 
length).tensor() + narrow::(TensorPrimitive::Float(tensor), dim, start, length).tensor() } /// Split the tensor along the given dimension into chunks. @@ -1270,7 +1270,7 @@ pub trait FloatTensorOps { chunks: usize, dim: usize, ) -> Vec> { - chunk::(TensorPrimitive::Float(tensor), chunks, dim) + chunk::(TensorPrimitive::Float(tensor), chunks, dim) .into_iter() .map(|t| t.tensor()) .collect() diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml index f08b6bbebb..3e07f764e4 100644 --- a/crates/burn/Cargo.toml +++ b/crates/burn/Cargo.toml @@ -36,6 +36,7 @@ vision = ["burn-core/vision"] # Backends autodiff = ["burn-core/autodiff"] fusion = ["burn-core/fusion"] +sparse = ["burn-core/sparse"] ## Backend features candle-cuda = ["burn-core/candle-cuda"]
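Taken together, the pieces above assemble the following end-to-end flow. Everything here is a hedged sketch: the `NdArray` backend, the `sparse` feature wiring, the `COO` re-export path, and the `into_sparse::<COO>()` turbofish are assumptions; only the method names (`into_sparse`, `spmm`, `into_dense`, `coordinates`, `values`) come from the patch itself.

```rust
use burn::backend::NdArray;
use burn::backend::sparse::COO; // burn_sparse re-export added in backend.rs above
use burn_tensor::Tensor;

type B = NdArray;

fn sparse_roundtrip(device: &<B as burn_tensor::backend::Backend>::Device) {
    let a = Tensor::<B, 2>::from_floats([[1.0, 0.0], [0.0, 2.0]], device);
    let b = Tensor::<B, 2>::from_floats([[3.0], [4.0]], device);

    // Dense -> sparse (COO), sparse-dense matmul, result comes back dense.
    let a_sparse = a.into_sparse::<COO>();
    let c = a_sparse.spmm(b);

    assert_eq!(c.shape().dims, [2, 1]);
}
```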