Skip to content

Commit

Permalink
Merge pull request #46 from ninehusky/ninehusky-add-halide-tests
Browse files Browse the repository at this point in the history
Add initial experiments
  • Loading branch information
ninehusky authored Dec 17, 2024
2 parents 7a332f6 + c96d9b5 commit bbac1ec
Show file tree
Hide file tree
Showing 27 changed files with 7,462 additions and 518 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@
path = extraction-gym
url = [email protected]:ninehusky/extraction-gym.git
branch = ninehusky-use-extraction-gym-as-lib
[submodule "Halide"]
path = Halide
url = [email protected]:halide/Halide.git
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ edition = "2021"
[dependencies]
egglog = "0.3.0"
egraph-serialize = "0.2.0"
env_logger = "0.11.5"
indexmap = "2.6.0"
lazy_static = "1.5.0"
ruler = { path = "./ruler" }
Expand Down
1 change: 1 addition & 0 deletions Halide
Submodule Halide added at 166cd9
97 changes: 87 additions & 10 deletions tests/halide.rs → src/halide/halide.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ use chompy::{CVec, Chomper};
use rand::{rngs::StdRng, Rng, SeedableRng};
use ruler::{
enumo::{Sexp, Workload},
HashMap, ValidationResult,
HashMap, HashSet, ValidationResult,
};

use std::str::FromStr;

use z3::ast::Ast;

use chompy::utils::TERM_PLACEHOLDER;
Expand Down Expand Up @@ -55,10 +57,10 @@ impl Chomper for HalideChomper {

fn productions(&self) -> ruler::enumo::Workload {
Workload::new(&[
// format!(
// "(ternary {} {} {})",
// TERM_PLACEHOLDER, TERM_PLACEHOLDER, TERM_PLACEHOLDER
// ),
format!(
"(ternary {} {} {})",
TERM_PLACEHOLDER, TERM_PLACEHOLDER, TERM_PLACEHOLDER
),
format!("(binary {} {})", TERM_PLACEHOLDER, TERM_PLACEHOLDER),
format!("(unary {})", TERM_PLACEHOLDER),
])
Expand All @@ -74,8 +76,8 @@ impl Chomper for HalideChomper {
}

fn atoms(&self) -> Workload {
// Workload::new(&["(Var a)", "(Var b)", "(Lit 1)", "(Lit 0)"])
Workload::new(&["(Var a)", "(Var b)"])
// Workload::new(&["(Var a)", "(Var b)", "(Lit 1)", "(Lit 0)"])
}

fn matches_var_pattern(&self, term: &ruler::enumo::Sexp) -> bool {
Expand Down Expand Up @@ -308,7 +310,7 @@ impl Chomper for HalideChomper {
}

impl HalideChomper {
fn make_env(rng: &mut StdRng) -> HashMap<String, Vec<Option<i64>>> {
pub fn make_env(rng: &mut StdRng) -> HashMap<String, Vec<Option<i64>>> {
let mut env = HashMap::default();
let dummy = HalideChomper { env: env.clone() };
for atom in &dummy.atoms().force() {
Expand All @@ -330,6 +332,76 @@ impl HalideChomper {
}
}

pub fn get_vars(sexp: &Sexp) -> HashSet<String> {
let mut vars: HashSet<String> = HashSet::default();
get_vars_internal(sexp, &mut vars);
vars
}

fn get_vars_internal(sexp: &Sexp, vars: &mut HashSet<String>) {
match sexp {
Sexp::Atom(a) => {
if a.starts_with("?") {
// remove the "?"
vars.insert(a[1..].to_string());
}
}
Sexp::List(l) => {
for term in l {
get_vars_internal(term, vars);
}
}
}
}

// fn get_vars(ast: z3::ast::Int) -> Vec<String> {
// let mut vars: HashSet<String> = HashSet::default();
// match ast {
// z3::ast::Int::
// }
// }

// generates a binding from variables -> values that satisfy
// the given condition.
pub fn generate_environment(cond: &Sexp) -> HashMap<String, Sexp> {
let mut cfg = z3::Config::new();
cfg.set_timeout_msec(1000);
let ctx = z3::Context::new(&cfg);
let solver = z3::Solver::new(&ctx);
let constraint = sexp_to_z3(&ctx, cond);
solver.assert(&constraint._eq(&z3::ast::Int::from_i64(&ctx, 1)));
let result = solver.check();
let mut env = HashMap::default();
if result == z3::SatResult::Sat {
let model = solver.get_model();
// TODO: clean this up. this kind of string manipulation is not ideal.
// we should be able to parse the z3 expression for
// constants and extract the values through interpretation of the model.
let model_str = model.unwrap().to_string();
let lines: Vec<&str> = model_str.lines().collect();
for line in lines {
println!("line: {}", line);
// split on "->"
let parts = line.split("->").collect::<Vec<&str>>();
if parts.len() != 2 {
panic!("Unexpected number of parts: {}", parts.len());
}
// remove parens
let var = parts[0].trim().trim_matches(|c| c == '(' || c == ')');
let val = parts[1]
.trim()
.trim_matches(|c| c == '(' || c == ')')
.replace(' ', "");
println!("val: {}", val);
let val_num: i64 = val.parse().unwrap();
let sexp = Sexp::from_str(&format!("(Lit {})", val_num).to_string()).unwrap();
// do we need CVEC_LEN here?
env.insert(var.to_string(), sexp);
}
}
env
}

fn sexp_to_z3<'a>(ctx: &'a z3::Context, sexp: &Sexp) -> z3::ast::Int<'a> {
match sexp {
Sexp::Atom(a) => {
Expand Down Expand Up @@ -448,14 +520,19 @@ pub mod tests {

use super::*;

#[test]
// not running for now because is expensive.
// #[test]
fn run_halide_chomper() {
let env = HalideChomper::make_env(&mut StdRng::seed_from_u64(0));
let mut chomper = HalideChomper { env };
let mut chomper = HalideChomper {
env,
memo: Default::default(),
};
let mut egraph = EGraph::default();

#[derive(Debug)]
struct HalidePredicateInterpreter {
memo: Default::default(),
chomper: HalideChomper,
}

Expand All @@ -480,7 +557,7 @@ pub mod tests {
});
egraph.add_arcsort(halide_sort.clone()).unwrap();
egraph.add_arcsort(dummy_sort).unwrap();
init_egraph!(egraph, "./egglog/halide.egg");
init_egraph!(egraph, "../../tests/egglog/halide.egg");

chomper.run_chompy(&mut egraph);
}
Expand Down
98 changes: 98 additions & 0 deletions src/halide/halide_to_sexpr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#!/usr/bin/python3

import ast
import re

REWRITE_FILES = ["out-c.txt", "out-nc.txt"]

OUTPUT_FILE = "rules.txt"

BAD_STR = "BAD"

def ast_to_sexpr(node):
if type(node) == ast.Expr:
return ast_to_sexpr(node.value)
match type(node):
case ast.Constant:
assert type(node.value) == int
return f"(Lit {node.value})"
case ast.Name:
if node.id == "false":
return "(Lit 0)"
elif node.id == "true":
return "(Lit 1)"
return f"(Var {node.id})"
case ast.Call:
call_to_op_str = {
"max": "Max",
"min": "Min",
"select": "Select",
}

func = call_to_op_str.get(node.func.id, BAD_STR)

return f"({func} {' '.join(map(ast_to_sexpr, node.args))})"
case ast.UnaryOp:
ast_to_op_str = {
ast.USub: "Neg",
ast.Not: "Not",
}
op = ast_to_op_str[type(node.op)]
return f"({op} {ast_to_sexpr(node.operand)})"
case ast.BoolOp:
boolop_to_op_str = {
ast.And: "And",
ast.Or: "Or",
}
op = boolop_to_op_str[type(node.op)]
return f"({op} {" ".join(map(ast_to_sexpr, node.values))})"
case ast.BinOp:
ast_to_op_str = {
ast.Add: "Add",
ast.Sub: "Sub",
ast.Mult: "Mul",
ast.Div: "Div",
}
op = ast_to_op_str[type(node.op)]
return f"({op} {ast_to_sexpr(node.left)} {ast_to_sexpr(node.right)})"
case ast.Compare:
ast_to_cmp_str = {
ast.Eq: "Eq",
ast.NotEq: "Neq",
ast.Lt: "Lt",
ast.LtE: "Leq",
# these are not in the Enumo subset of Halide eval
# ast.Gt: "Gt",
# ast.GtE: "Ge",
}
cmp = ast_to_cmp_str.get(type(node.ops[0]), BAD_STR)
return f"({cmp} {ast_to_sexpr(node.left)} {ast_to_sexpr(node.comparators[0])})"
# default case
case _:
raise Exception(f"unknown node: {type(node)}")

if __name__ == "__main__":
with open(OUTPUT_FILE, "w+") as out:
for file in REWRITE_FILES:
with open(file, "r") as f:
lines = f.readlines()
total_rules = len(lines)
added_rules_count = 0
for line in lines:
line = line.replace("&&", "and").replace("||", "or").replace("!", "not ").replace("not =", "!=")
# split on "==>" or "if"
parts = list(map(lambda expr: ast_to_sexpr(ast.parse(re.sub(r"^\s+", "", expr)).body[0]),
re.split(r"==>|if ", line)))

rule = ";".join(parts)

if len(parts) != 2 and len(parts) != 3:
raise Exception(f"bad length: {len(parts)}")
elif BAD_STR not in rule:
added_rules_count += 1
out.write(f"{rule}\n")

print(f"Added {added_rules_count} / {total_rules} rules from {file}")



Loading

0 comments on commit bbac1ec

Please sign in to comment.