Skip to content

Commit

Permalink
no selection of top nodes, just the string and score
Browse files Browse the repository at this point in the history
  • Loading branch information
iblacksand committed Mar 18, 2024
1 parent 9f08253 commit 3d8675b
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions webgestalt_lib/src/methods/nta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ pub struct NTAOptions {
pub edge_list: Vec<Vec<String>>,
/// A vector of strings representing the seeds
pub seeds: Vec<String>,
/// An integer representing the neighborhood size
pub neighborhood_size: usize,
/// A float representing the reset probability during random walk (default: 0.5)
pub reset_probability: f64,
/// A float representing the tolerance for probability calculation
Expand All @@ -20,19 +18,24 @@ impl Default for NTAOptions {
NTAOptions {
edge_list: vec![],
seeds: vec![],
neighborhood_size: 50,
reset_probability: 0.5,
tolerance: 0.000001,
}
}
}

pub struct NTAResult {
pub neighborhood: Vec<String>,
pub scores: Vec<f64>,
pub candidates: Vec<String>,
}

/// Uses random walk to calculate the neighborhood of a set of nodes
/// Returns [`Vec<String>`]representing the nodes in the neighborhood
///
/// # Parameters
/// - `config` - A [`NTAOptions`] struct containing the edge list, seeds, neighborhood size, reset probability, and tolerance
pub fn nta(config: NTAOptions) -> Vec<String> {
pub fn nta(config: NTAOptions) -> Vec<(String, f64)> {
println!("Building Graph");
let unique_nodes = ahash::AHashSet::from_iter(config.edge_list.iter().flatten().cloned());
let mut node_map: ahash::AHashMap<String, usize> = ahash::AHashMap::default();
Expand Down Expand Up @@ -60,13 +63,10 @@ pub fn nta(config: NTAOptions) -> Vec<String> {
config.reset_probability,
config.reset_probability,
);
let walk = walk_res.to_vec();
let mut top_n = walk.iter().enumerate().collect::<Vec<_>>();
top_n.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
top_n.truncate(config.neighborhood_size);
top_n
.iter()
.map(|(i, _p)| reverse_map.get(i).unwrap().clone())
let mut walk = walk_res.iter().enumerate().collect::<Vec<(usize, &f64)>>();
walk.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
walk.iter()
.map(|(i, p)| (reverse_map.get(&i).unwrap().clone(), **p))
.collect()
}

Expand Down

0 comments on commit 3d8675b

Please sign in to comment.