diff --git a/src/clustering.rs b/src/clustering.rs index 1fe3083..3edd929 100644 --- a/src/clustering.rs +++ b/src/clustering.rs @@ -126,6 +126,17 @@ impl Clustering { *x = self.positives.clusters.find(*x); }); } + + /// Map a vector of node ids to their representative node ids, replacing + /// nodes that have been filtered out with 0. + pub fn filter_map(&self, seeds: &mut Array, filtered_nodes: HashSet) { + seeds + .iter_mut() + .for_each(|x| match filtered_nodes.contains(x) { + false => *x = self.positives.clusters.find(*x), + true => *x = self.positives.clusters.len(), + }); + } } #[cfg(test)] diff --git a/src/lib.rs b/src/lib.rs index e24ba35..2e8ba56 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -26,7 +26,7 @@ pub fn get_edges( affinities: &Array, offsets: Vec>, seeds: &Array, -) -> Vec { +) -> (Vec, HashSet) { // let (_, array_shape) = get_dims::(seeds.dim(), 0); let offsets: Vec<[isize; D]> = offsets .into_iter() @@ -91,10 +91,13 @@ pub fn get_edges( to_filter.remove(v); } }); - agglom_edges - .into_iter() - .filter(|edge| !(to_filter.contains(&edge.1) || to_filter.contains(&edge.2))) - .collect() + ( + agglom_edges + .into_iter() + .filter(|edge| !(to_filter.contains(&edge.1) || to_filter.contains(&edge.2))) + .collect(), + to_filter, + ) } pub fn agglomerate( @@ -122,18 +125,21 @@ pub fn agglomerate( }); // main algorithm - let sorted_edges = get_edges::(affinities, offsets, &seeds); + let (sorted_edges, mut filtered_background) = get_edges::(affinities, offsets, &seeds); edges.extend(sorted_edges); + lookup.values().for_each(|node_id| { + filtered_background.remove(node_id); + }); let mut clustering = Clustering::new(seeds.len()); clustering.process_edges(edges); - clustering.map(&mut seeds); + clustering.filter_map(&mut seeds, filtered_background); // TODO: Fix seed handling // now we have to remap seeded entries back onto the original ids let mut rev_lookup = HashMap::with_capacity(lookup.len()); - //rev_lookup.insert(0, seeds.len()); + rev_lookup.insert(seeds.len(), 0); lookup.iter().for_each(|(seed, id)| { let rep_id = clustering.positives.clusters.find(*id); if *seed != rep_id { @@ -337,6 +343,56 @@ mod tests { assert!(!ids.contains(&0), "{:?}", components); assert!(ids.len() == 4, "{:?}", components); } + + /// Seeds + /// 1 2 0 + /// 4 0 0 + /// 0 0 0 + /// + /// Affs + /// offset [0, -1] + /// 0 0 0 + /// 0 0 0 + /// 0 0 0 + /// + /// offset [-1, 0] + /// 0 0 0 + /// 0 0 0 + /// 0 0 0 + /// + /// Expected Components + /// 1 2 0 + /// 4 0 0 + /// 0 0 0 + /// + #[test] + fn test_filtered_background() { + let affinities = array![ + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]] + ] + .into_dyn() + - 0.5; + let seeds = array![[1, 2, 0], [4, 0, 0], [0, 0, 0]].into_dyn(); + let offsets = vec![vec![0, -1], vec![-1, 0]]; + let components = agglomerate::<2>(&affinities, offsets, vec![], seeds); + let ids = components + .clone() + .into_iter() + .unique() + .collect::>(); + for id in [1, 2, 4].iter() { + assert!(ids.contains(id), "{:?}", components); + } + assert!(ids.contains(&0), "{:?}", components); + assert!(ids.len() == 4, "{:?}", components); + assert!( + components.iter().counts().get(&0).unwrap() == &6, + "{:?}", + components + ); + } + #[test] fn test_cluster() { let edges = vec![