From 78684e1bf243a023709355bcec3d7ad756a92185 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9o=20Coletta?= Date: Mon, 21 Oct 2024 07:38:03 +0200 Subject: [PATCH] perf: improve get_subset_sum performance --- crates/subset-sum/src/get_subset_sum.rs | 27 +++++++++++++++---------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/crates/subset-sum/src/get_subset_sum.rs b/crates/subset-sum/src/get_subset_sum.rs index 76eb1aa..e46214c 100644 --- a/crates/subset-sum/src/get_subset_sum.rs +++ b/crates/subset-sum/src/get_subset_sum.rs @@ -10,8 +10,8 @@ use std::iter::Sum; use std::time::Instant; #[derive(Hash, Eq, PartialEq, Clone)] -pub struct SubsetSumArg { - integer_list: Vec, +pub struct SubsetSumArg<'a, N: Num + Copy> { + integer_list: &'a [N], sum: N, } @@ -22,12 +22,15 @@ struct SubsetSum { pub type SubsetSumResult = Result>, SubsetSumError>; -impl RecurFn, SubsetSumResult> for SubsetSum { +impl<'a, N> RecurFn, SubsetSumResult> for SubsetSum +where + N: Num + Copy + Hash + Sum + Eq + Ord, +{ #[inline] fn body( &self, - subset_sum: impl Fn(SubsetSumArg) -> SubsetSumResult, - arg: SubsetSumArg, + subset_sum: impl Fn(SubsetSumArg<'a, N>) -> SubsetSumResult, + arg: SubsetSumArg<'a, N>, ) -> SubsetSumResult { if let Some(timeout) = self.timeout_in_ms { if self.now.elapsed().as_millis() >= timeout { @@ -44,7 +47,7 @@ impl RecurFn, SubsetSumResult> fo } if arg.integer_list.iter().copied().sum::() == arg.sum { - return Ok(Some(arg.integer_list)); + return Ok(Some(arg.integer_list.to_vec())); } if arg.integer_list.contains(&arg.sum) { @@ -52,8 +55,7 @@ impl RecurFn, SubsetSumResult> fo } for (index, ¤t) in arg.integer_list.iter().enumerate() { - let mut subset = arg.integer_list.clone(); - subset.remove(index); + let subset = &arg.integer_list[index + 1..]; if let Some(mut result) = subset_sum(SubsetSumArg { integer_list: subset, @@ -68,11 +70,14 @@ impl RecurFn, SubsetSumResult> fo } } -pub fn get_subset_sum( +pub fn get_subset_sum( mut list: Vec, sum: N, timeout_in_ms: Option, -) -> SubsetSumResult { +) -> SubsetSumResult +where + N: Num + Copy + Hash + Eq + Ord + Sum, +{ let subset_sum = unsync::memoize(SubsetSum { now: Instant::now(), timeout_in_ms, @@ -81,7 +86,7 @@ pub fn get_subset_sum( list.sort_unstable(); subset_sum.call(SubsetSumArg { - integer_list: list, + integer_list: &list, sum, }) }