diff --git a/crates/burn-core/src/backend.rs b/crates/burn-core/src/backend.rs index a1a5813491..1c05f79617 100644 --- a/crates/burn-core/src/backend.rs +++ b/crates/burn-core/src/backend.rs @@ -29,4 +29,4 @@ pub use burn_tch as libtorch; pub use burn_tch::LibTorch; #[cfg(feature = "sparse")] -pub use burn_sparse as sparse; +pub use burn_sparse::decorator as sparse; diff --git a/crates/burn-core/src/tensor.rs b/crates/burn-core/src/tensor.rs index 074606bb14..ecc858ebbe 100644 --- a/crates/burn-core/src/tensor.rs +++ b/crates/burn-core/src/tensor.rs @@ -1 +1,6 @@ pub use burn_tensor::*; + +#[cfg(feature = "sparse")] +pub mod sparse { + pub use burn_sparse::backend::*; +} diff --git a/crates/burn-sparse/src/backend/api.rs b/crates/burn-sparse/src/backend/api.rs index 60e2023af3..146f4bbce4 100644 --- a/crates/burn-sparse/src/backend/api.rs +++ b/crates/burn-sparse/src/backend/api.rs @@ -1,7 +1,14 @@ use crate::backend::{Sparse, SparseBackend}; use burn_tensor::{Int, Tensor, TensorPrimitive}; -pub trait SparseTensor +pub trait ToSparse +where + B: SparseBackend, +{ + fn into_sparse(self) -> Tensor; +} + +pub trait SparseTensorApi where B: SparseBackend, { @@ -10,7 +17,16 @@ where fn dense(self) -> Tensor; } -impl SparseTensor for Tensor +impl ToSparse for Tensor +where + B: SparseBackend, +{ + fn into_sparse(self) -> Tensor { + Tensor::new(B::sparse_to_sparse(self.into_primitive().tensor())) + } +} + +impl SparseTensorApi for Tensor where B: SparseBackend, { diff --git a/crates/burn-sparse/src/backend/mod.rs b/crates/burn-sparse/src/backend/mod.rs index 20c24353ed..741143a850 100644 --- a/crates/burn-sparse/src/backend/mod.rs +++ b/crates/burn-sparse/src/backend/mod.rs @@ -4,5 +4,6 @@ mod kind; mod sparse_backend; pub use alias::*; +pub use api::*; pub use kind::*; pub use sparse_backend::*;