Skip to content

Commit

Permalink
perf: improve get_subset_sum performance
Browse files Browse the repository at this point in the history
  • Loading branch information
leo91000 committed Oct 21, 2024
1 parent d117859 commit 22e4b6d
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions crates/subset-sum/src/get_subset_sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ use std::iter::Sum;
use std::time::Instant;

#[derive(Hash, Eq, PartialEq, Clone)]
pub struct SubsetSumArg<N: Num + Copy> {
integer_list: Vec<N>,
pub struct SubsetSumArg<'a, N: Num + Copy> {
integer_list: &'a [N],
sum: N,
}

Expand All @@ -22,12 +22,15 @@ struct SubsetSum {

pub type SubsetSumResult<N> = Result<Option<Vec<N>>, SubsetSumError>;

impl<N: Num + Copy + Hash + Sum> RecurFn<SubsetSumArg<N>, SubsetSumResult<N>> for SubsetSum {
impl<'a, N> RecurFn<SubsetSumArg<'a, N>, SubsetSumResult<N>> for SubsetSum
where
N: Num + Copy + Hash + Sum + Eq + Ord,
{
#[inline]
fn body(
&self,
subset_sum: impl Fn(SubsetSumArg<N>) -> SubsetSumResult<N>,
arg: SubsetSumArg<N>,
subset_sum: impl Fn(SubsetSumArg<'a, N>) -> SubsetSumResult<N>,
arg: SubsetSumArg<'a, N>,
) -> SubsetSumResult<N> {
if let Some(timeout) = self.timeout_in_ms {
if self.now.elapsed().as_millis() >= timeout {
Expand All @@ -44,16 +47,15 @@ impl<N: Num + Copy + Hash + Sum> RecurFn<SubsetSumArg<N>, SubsetSumResult<N>> fo
}

if arg.integer_list.iter().copied().sum::<N>() == arg.sum {
return Ok(Some(arg.integer_list));
return Ok(Some(arg.integer_list.to_vec()));
}

if arg.integer_list.contains(&arg.sum) {
return Ok(Some(vec![arg.sum]));
}

for (index, &current) in arg.integer_list.iter().enumerate() {
let mut subset = arg.integer_list.clone();
subset.remove(index);
let subset = &arg.integer_list[index + 1..];

if let Some(mut result) = subset_sum(SubsetSumArg {
integer_list: subset,
Expand All @@ -68,11 +70,14 @@ impl<N: Num + Copy + Hash + Sum> RecurFn<SubsetSumArg<N>, SubsetSumResult<N>> fo
}
}

pub fn get_subset_sum<N: Num + Copy + Hash + Eq + Ord + Sum>(
pub fn get_subset_sum<N>(
mut list: Vec<N>,
sum: N,
timeout_in_ms: Option<u128>,
) -> SubsetSumResult<N> {
) -> SubsetSumResult<N>
where
N: Num + Copy + Hash + Eq + Ord + Sum,
{
let subset_sum = unsync::memoize(SubsetSum {
now: Instant::now(),
timeout_in_ms,
Expand All @@ -81,7 +86,7 @@ pub fn get_subset_sum<N: Num + Copy + Hash + Eq + Ord + Sum>(
list.sort_unstable();

subset_sum.call(SubsetSumArg {
integer_list: list,
integer_list: &list,
sum,
})
}
Expand Down

0 comments on commit 22e4b6d

Please sign in to comment.