Skip to content

Commit

Permalink
Minor updates in Graph Separation Boosting.
Browse files Browse the repository at this point in the history
  • Loading branch information
rmitsuboshi committed Aug 31, 2024
1 parent ede4758 commit 9ee4dbb
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ pub struct GraphSepBoost<'a, F> {

// Hypohteses obtained by the weak-learner.
hypotheses: Vec<F>,


// The number of edges at the end of the previous round.
n_edges: usize,
}


Expand All @@ -104,6 +108,7 @@ impl<'a, F> GraphSepBoost<'a, F> {
sample,
hypotheses: Vec::new(),
edges: Vec::new(),
n_edges: usize::MAX,
}
}
}
Expand Down Expand Up @@ -180,6 +185,11 @@ impl<F> Booster<F> for GraphSepBoost<'_, F>
}
}

self.n_edges = self.edges
.iter()
.map(|edges| edges.len())
.sum();

self.hypotheses = Vec::new();
}

Expand All @@ -191,23 +201,30 @@ impl<F> Booster<F> for GraphSepBoost<'_, F>
) -> ControlFlow<usize>
where W: WeakLearner<Hypothesis = F>,
{
let n_edges_2 = self.edges.iter()
.map(|edge| edge.len())
.sum::<usize>();
if n_edges_2 == 0 {
if self.n_edges == 0 {
return ControlFlow::Break(iteration);
}
let denom = n_edges_2 as f64;

let dist = self.edges.iter()
.map(|edge| edge.len() as f64 / denom)
.map(|edge| edge.len() as f64 / self.n_edges as f64)
.collect::<Vec<_>>();

// Get a new hypothesis
let h = weak_learner.produce(self.sample, &dist);
self.update_params(&h);
self.hypotheses.push(h);


let n_edges = self.edges
.iter()
.map(|edges| edges.len())
.sum::<usize>();
if self.n_edges == n_edges {
eprintln!("[WARN] number of edges does not decrease.");
return ControlFlow::Break(iteration+1);
}
self.n_edges = n_edges;

ControlFlow::Continue(())
}

Expand Down
2 changes: 1 addition & 1 deletion src/sample/feature_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ impl DenseFeature {


impl SparseFeature {
/// Construct an empty dense feature with `name`.
/// Construct an empty sparse feature with `name`.
pub fn new<T: ToString>(name: T) -> Self {
Self {
name: name.to_string(),
Expand Down
6 changes: 3 additions & 3 deletions tests/graphsepboost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ pub mod graphsepboost_tests {
#[test]
fn iris() {
let mut path = env::current_dir().unwrap();
path.push("tests/dataset/iris_binary.csv");
path.push("tests/dataset/german.csv");
// path.push("tests/dataset/iris_binary.csv");

let sample = SampleReader::new()
.file(path)
Expand All @@ -20,8 +21,7 @@ pub mod graphsepboost_tests {
.unwrap();


let mut booster = GraphSepBoost::init(&sample)
.tolerance(0.01);
let mut booster = GraphSepBoost::init(&sample);

let wl = DecisionTreeBuilder::new(&sample)
.max_depth(1)
Expand Down

0 comments on commit 9ee4dbb

Please sign in to comment.