From 3d8675b2e6bea6cf41e4075233a3fd4bf75840b3 Mon Sep 17 00:00:00 2001 From: John Elizarraras Date: Mon, 18 Mar 2024 15:01:03 -0500 Subject: [PATCH] no selection of top nodes, just the string and score --- webgestalt_lib/src/methods/nta.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/webgestalt_lib/src/methods/nta.rs b/webgestalt_lib/src/methods/nta.rs index 15376b1..52fa072 100644 --- a/webgestalt_lib/src/methods/nta.rs +++ b/webgestalt_lib/src/methods/nta.rs @@ -7,8 +7,6 @@ pub struct NTAOptions { pub edge_list: Vec>, /// A vector of strings representing the seeds pub seeds: Vec, - /// An integer representing the neighborhood size - pub neighborhood_size: usize, /// A float representing the reset probability during random walk (default: 0.5) pub reset_probability: f64, /// A float representing the tolerance for probability calculation @@ -20,19 +18,24 @@ impl Default for NTAOptions { NTAOptions { edge_list: vec![], seeds: vec![], - neighborhood_size: 50, reset_probability: 0.5, tolerance: 0.000001, } } } +pub struct NTAResult { + pub neighborhood: Vec, + pub scores: Vec, + pub candidates: Vec, +} + /// Uses random walk to calculate the neighborhood of a set of nodes /// Returns [`Vec`]representing the nodes in the neighborhood /// /// # Parameters /// - `config` - A [`NTAOptions`] struct containing the edge list, seeds, neighborhood size, reset probability, and tolerance -pub fn nta(config: NTAOptions) -> Vec { +pub fn nta(config: NTAOptions) -> Vec<(String, f64)> { println!("Building Graph"); let unique_nodes = ahash::AHashSet::from_iter(config.edge_list.iter().flatten().cloned()); let mut node_map: ahash::AHashMap = ahash::AHashMap::default(); @@ -60,13 +63,10 @@ pub fn nta(config: NTAOptions) -> Vec { config.reset_probability, config.reset_probability, ); - let walk = walk_res.to_vec(); - let mut top_n = walk.iter().enumerate().collect::>(); - top_n.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap()); - top_n.truncate(config.neighborhood_size); - top_n - .iter() - .map(|(i, _p)| reverse_map.get(i).unwrap().clone()) + let mut walk = walk_res.iter().enumerate().collect::>(); + walk.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap()); + walk.iter() + .map(|(i, p)| (reverse_map.get(&i).unwrap().clone(), **p)) .collect() }