Skip to content

Commit

Permalink
Add Math tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ninehusky committed Jan 9, 2025
1 parent 8ea08a7 commit 2b35859
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 63 deletions.
103 changes: 43 additions & 60 deletions src/chomper.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use crate::language::MathLang;
use crate::PredicateInterpreter;
use std::{fmt::Display, str::FromStr, sync::Arc};

use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

use crate::{
ite::DummySort,
language::{CVec, ChompyLanguage, ValidationResult},
PredicateInterpreter,
};
use egglog::{sort::EqSort, EGraph};
use log::info;
Expand Down Expand Up @@ -197,7 +201,7 @@ pub trait Chomper {

// terms is a vector of (lhs, rhs) pairs with NO variables--not even 1...
let terms: Vec<(Sexp, Sexp)> = self.get_language().concretize_rule(rule);
const MAX_DERIVABILITY_ITERATIONS: usize = 10;
const MAX_DERIVABILITY_ITERATIONS: usize = 7;
let mut egraph = initial_egraph.clone();
for rule in ruleset {
self.add_rewrite(&mut egraph, rule);
Expand Down Expand Up @@ -425,68 +429,47 @@ fn all_variables_bound(rule: &Rule) -> bool {
.all(|var| lhs_vars.contains(var))
}

#[allow(unused_imports)]
pub mod tests {
use crate::language::MathLang;
use crate::PredicateInterpreter;

use crate::chomper::Chomper;
use crate::language::*;

use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

#[test]
fn test_chomper() {
struct MathChomper;

impl Chomper for MathChomper {
type Constant = i64;

fn make_pred_interpreter() -> impl crate::PredicateInterpreter {
#[derive(Debug)]
struct DummyPredicateInterpreter;
impl PredicateInterpreter for DummyPredicateInterpreter {
fn interp_cond(&self, sexp: &ruler::enumo::Sexp) -> bool {
let dummy_term = MathLang::Var("dummy".to_string());
match dummy_term.eval(sexp, &Default::default()).get(0).unwrap() {
Some(val) => *val > 0,
None => false,
}
}
/// A sample implementation of the Chomper trait for the MathLang language.
pub struct MathChomper;

impl Chomper for MathChomper {
type Constant = i64;

fn make_pred_interpreter() -> impl crate::PredicateInterpreter {
#[derive(Debug)]
struct DummyPredicateInterpreter;
impl PredicateInterpreter for DummyPredicateInterpreter {
fn interp_cond(&self, sexp: &ruler::enumo::Sexp) -> bool {
let dummy_term = MathLang::Var("dummy".to_string());
match dummy_term.eval(sexp, &Default::default()).get(0).unwrap() {
Some(val) => *val > 0,
None => false,
}
DummyPredicateInterpreter
}

fn initialize_env(
&self,
) -> ruler::HashMap<String, CVec<dyn ChompyLanguage<Constant = Self::Constant>>>
{
let mut env = ruler::HashMap::default();
// make seedable rng
let seed = 0b1001;
// TODO: this should be part of the interface for eval?
let cvec_len = 10;
let mut rng = StdRng::seed_from_u64(seed);

for var in self.get_language().get_vars() {
let cvec = (0..cvec_len)
.map(|_| Some(rng.gen_range(-10..10)))
.collect::<CVec<MathLang>>();
env.insert(var.clone(), cvec);
}
env
}

fn get_language(&self) -> Box<impl ChompyLanguage<Constant = Self::Constant>> {
Box::new(MathLang::Var("dummy".to_string()))
}
}
DummyPredicateInterpreter
}

let chomper = MathChomper;
let rules = chomper.run_chompy(10);
for rule in rules {
println!("{}", rule);
fn initialize_env(
&self,
) -> ruler::HashMap<String, CVec<dyn ChompyLanguage<Constant = Self::Constant>>> {
let mut env = ruler::HashMap::default();
// make seedable rng
let seed = 0b1001;
// TODO: this should be part of the interface for eval?
let cvec_len = 10;
let mut rng = StdRng::seed_from_u64(seed);

for var in self.get_language().get_vars() {
let cvec = (0..cvec_len)
.map(|_| Some(rng.gen_range(-10..10)))
.collect::<CVec<MathLang>>();
env.insert(var.clone(), cvec);
}
env
}

fn get_language(&self) -> Box<impl ChompyLanguage<Constant = Self::Constant>> {
Box::new(MathLang::Var("dummy".to_string()))
}
}
9 changes: 7 additions & 2 deletions src/language.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,11 +336,16 @@ impl ChompyLanguage for MathLang {
}

fn get_vals(&self) -> Vec<Self::Constant> {
vec![]
vec![-1, 0, 1]
}

fn get_vars(&self) -> Vec<String> {
vec!["x".to_string(), "y".to_string(), "z".to_string()]
vec![
"a".to_string(),
"b".to_string(),
"c".to_string(),
"d".to_string(),
]
}

fn const_type_as_str(&self) -> String {
Expand Down
43 changes: 43 additions & 0 deletions tests/math/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
pub mod tests {
use chompy::{
chomper::{Chomper, MathChomper, Rule},
language::{ChompyLanguage, MathLang},
};
use ruler::enumo::Sexp;
use std::str::FromStr;

use super::super::evaluate_ruleset;

#[test]
fn check_math_rules() {
let chomper = MathChomper;
let predicates = chomper.get_language().get_predicates();
let values = chomper.get_language().get_vals();
for v in values {
println!("Value: {}", chomper.get_language().make_val(v));
}
for p in predicates.force() {
println!("Predicate: {}", p);
}
let rules = chomper.run_chompy(5);
let hand_picked_rules = vec![
Rule {
condition: Sexp::from_str("(Neq ?x (Const 0))").ok(),
lhs: Sexp::from_str("(Div ?x ?x)").unwrap(),
rhs: Sexp::from_str("(Const 1)").unwrap(),
},
Rule {
condition: Sexp::from_str("(Gt ?x (Const -1))").ok(),
lhs: Sexp::from_str("(Abs ?x)").unwrap(),
rhs: Sexp::from_str("?x").unwrap(),
},
];
evaluate_ruleset::<MathChomper, MathLang>(
&rules,
&hand_picked_rules,
chomper,
MathLang::Var("dummy".into()),
);
}
}

58 changes: 57 additions & 1 deletion tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,57 @@
mod halide;
use chompy::chomper::{Chomper, Rule};
use chompy::language::ChompyLanguage;
use ruler::enumo::Sexp;

// mod halide;
mod math;

// This is a hack. `rule_is_derivable` calls
// `concretize_rule`, which expects the rules which are
// discovered via observational equivalence, i.e., rules
// before they have been generalized. Going down a level, this
// means `concretize_rule` expects the rules to have variables
// as `(Var blah)` vs. what gets output by Chompy, which is
// `?blah`. This function transforms the rules from the latter
// format to the former.
pub fn transform_rule(rule: &Rule, language: &Box<impl ChompyLanguage>) -> Rule {
fn transform_sexp(sexp: &Sexp, language: &Box<impl ChompyLanguage>) -> Sexp {
match sexp {
Sexp::Atom(a) => {
if a.starts_with("?") {
language.make_var(&a[1..])
} else {
sexp.clone()
}
}
Sexp::List(l) => Sexp::List(l.iter().map(|s| transform_sexp(s, language)).collect()),
}
}
Rule {
condition: if let Some(cond) = &rule.condition {
Some(transform_sexp(cond, language))
} else {
None
},
lhs: transform_sexp(&rule.lhs, language),
rhs: transform_sexp(&rule.rhs, language),
}
}

pub fn evaluate_ruleset<C: Chomper + Sized, L: ChompyLanguage + Sized>(
chompy_rules: &Vec<Rule>,
other_rules: &Vec<Rule>,
chomper: C,
language: L,
) {
let egraph = chomper.get_initial_egraph();
let b = Box::new(language);
for rule in other_rules {
let rule = transform_rule(rule, &b);
let result = chomper.rule_is_derivable(&egraph, &chompy_rules, &rule);
if result {
println!("Rule is derivable: {:?}", rule);
} else {
println!("Rule is not derivable: {:?}", rule);
}
}
}

0 comments on commit 2b35859

Please sign in to comment.