Commit
Add Halide derivability experiments (messy)
ninehusky committed Dec 9, 2024
1 parent 33c1871 commit 8da09fd
Showing 12 changed files with 2,153 additions and 8 deletions.
16 changes: 12 additions & 4 deletions src/lib.rs
@@ -30,7 +30,7 @@ pub struct Rules {
pub conditional: Vec<Rule>,
}

pub const MAX_SIZE: usize = 10;
pub const MAX_SIZE: usize = 6;
pub const EXAMPLE_COUNT: usize = 1;

#[macro_export]
@@ -81,7 +81,7 @@ pub trait Chomper {
result
}

fn run_chompy(&mut self, egraph: &mut EGraph) {
fn run_chompy(&mut self, egraph: &mut EGraph) -> Vec<Rule> {
// TODO: i want to use a set here.
let mut found_rules: Vec<Rule> = vec![];
let mut max_eclass_id = 0;
@@ -271,7 +271,8 @@ pub trait Chomper {
}
}

panic!("not all rules were found");
// TODO: check
found_rules
}

fn make_var(&self, var: &str) -> Sexp;
@@ -286,7 +287,10 @@ pub trait Chomper {
assert!(var_name.starts_with("?"));
let var_name = var_name.trim_start_matches("?").to_string();
if !env.contains_key(&var_name) {
panic!("variable not found in env: {}", var_name);
// TODO: check this
// this might be a terrible idea.
//
return Sexp::from_str("(Lit 5)").unwrap();
}
return env[&var_name].clone();
}
@@ -347,6 +351,9 @@ pub trait Chomper {
match &rule.condition {
Some((_, envs)) => {
for env in envs {
println!("here is the env: {:?}", env);
println!("lhs: {}", lhs);
println!("rhs: {}", rhs);
let concretized_lhs = self.concretize_term_conditional(&lhs, env.clone());
let concretized_rhs = self.concretize_term_conditional(&rhs, env.clone());

@@ -363,6 +370,7 @@
"#,
concretized_lhs, concretized_rhs
)
.replace("quote", "\"")
.as_str(),
)
.unwrap();
81 changes: 77 additions & 4 deletions tests/halide.rs → tests/halide/halide.rs
@@ -2,9 +2,11 @@ use chompy::{CVec, Chomper};
use rand::{rngs::StdRng, Rng, SeedableRng};
use ruler::{
enumo::{Sexp, Workload},
HashMap, ValidationResult,
HashMap, HashSet, ValidationResult,
};

use std::str::FromStr;

use z3::ast::Ast;

use chompy::utils::TERM_PLACEHOLDER;
@@ -308,7 +310,7 @@ impl Chomper for HalideChomper {
}

impl HalideChomper {
fn make_env(rng: &mut StdRng) -> HashMap<String, Vec<Option<i64>>> {
pub fn make_env(rng: &mut StdRng) -> HashMap<String, Vec<Option<i64>>> {
let mut env = HashMap::default();
let dummy = HalideChomper { env: env.clone() };
for atom in &dummy.atoms().force() {
@@ -330,6 +332,76 @@ impl HalideChomper {
}
}

pub fn get_vars(sexp: &Sexp) -> HashSet<String> {
let mut vars: HashSet<String> = HashSet::default();
get_vars_internal(sexp, &mut vars);
vars
}

fn get_vars_internal(sexp: &Sexp, vars: &mut HashSet<String>) {
match sexp {
Sexp::Atom(a) => {
if a.starts_with("?") {
// remove the "?"
vars.insert(a[1..].to_string());
}
}
Sexp::List(l) => {
for term in l {
get_vars_internal(term, vars);
}
}
}
}

// fn get_vars(ast: z3::ast::Int) -> Vec<String> {
// let mut vars: HashSet<String> = HashSet::default();
// match ast {
// z3::ast::Int::
// }
// }

// generates a binding from variables -> values that satisfy
// the given condition.
pub fn generate_environment(cond: &Sexp) -> HashMap<String, Sexp> {
let mut cfg = z3::Config::new();
cfg.set_timeout_msec(1000);
let ctx = z3::Context::new(&cfg);
let solver = z3::Solver::new(&ctx);
let constraint = sexp_to_z3(&ctx, cond);
solver.assert(&constraint._eq(&z3::ast::Int::from_i64(&ctx, 1)));
let result = solver.check();
let mut env = HashMap::default();
if result == z3::SatResult::Sat {
let model = solver.get_model();
// TODO: clean this up. this kind of string manipulation is not ideal.
// we should be able to parse the z3 expression for
// constants and extract the values through interpretation of the model.
let model_str = model.unwrap().to_string();
let lines: Vec<&str> = model_str.lines().collect();
for line in lines {
println!("line: {}", line);
// split on "->"
let parts = line.split("->").collect::<Vec<&str>>();
if parts.len() != 2 {
panic!("Unexpected number of parts: {}", parts.len());
}
// remove parens
let var = parts[0].trim().trim_matches(|c| c == '(' || c == ')');
let val = parts[1]
.trim()
.trim_matches(|c| c == '(' || c == ')')
.replace(' ', "");
println!("val: {}", val);
let val_num: i64 = val.parse().unwrap();
let sexp = Sexp::from_str(&format!("(Lit {})", val_num).to_string()).unwrap();
// do we need CVEC_LEN here?
env.insert(var.to_string(), sexp);
}
}
env
}

fn sexp_to_z3<'a>(ctx: &'a z3::Context, sexp: &Sexp) -> z3::ast::Int<'a> {
match sexp {
Sexp::Atom(a) => {
@@ -448,7 +520,8 @@ pub mod tests {

use super::*;

#[test]
// not running for now because it is expensive.
// #[test]
fn run_halide_chomper() {
let env = HalideChomper::make_env(&mut StdRng::seed_from_u64(0));
let mut chomper = HalideChomper { env };
@@ -480,7 +553,7 @@ pub mod tests {
});
egraph.add_arcsort(halide_sort.clone()).unwrap();
egraph.add_arcsort(dummy_sort).unwrap();
init_egraph!(egraph, "./egglog/halide.egg");
init_egraph!(egraph, "../egglog/halide.egg");

chomper.run_chompy(&mut egraph);
}
94 changes: 94 additions & 0 deletions tests/halide/halide_to_sexpr.py
@@ -0,0 +1,94 @@
#!/usr/bin/python3

import ast
import re

REWRITE_FILES = ["out-c.txt"]

OUTPUT_FILE = "rules.txt"

BAD_STR = "BAD"

def ast_to_sexpr(node):
if type(node) == ast.Expr:
return ast_to_sexpr(node.value)
match type(node):
case ast.Constant:
assert type(node.value) == int
return f"(Lit {node.value})"
case ast.Name:
return f"(Var {node.id})"
case ast.Call:
call_to_op_str = {
"max": "Max",
"min": "Min",
"select": "Select",
}

func = call_to_op_str.get(node.func.id, BAD_STR)

return f"({func} {' '.join(map(ast_to_sexpr, node.args))})"
case ast.UnaryOp:
ast_to_op_str = {
ast.USub: "Neg",
ast.Not: "Not",
}
op = ast_to_op_str[type(node.op)]
return f"({op} {ast_to_sexpr(node.operand)})"
case ast.BoolOp:
boolop_to_op_str = {
ast.And: "And",
ast.Or: "Or",
}
op = boolop_to_op_str[type(node.op)]
return f"({op} {" ".join(map(ast_to_sexpr, node.values))})"
case ast.BinOp:
ast_to_op_str = {
ast.Add: "Add",
ast.Sub: "Sub",
ast.Mult: "Mul",
ast.Div: "Div",
}
op = ast_to_op_str[type(node.op)]
return f"({op} {ast_to_sexpr(node.left)} {ast_to_sexpr(node.right)})"
case ast.Compare:
ast_to_cmp_str = {
ast.Eq: "Eq",
ast.NotEq: "Neq",
ast.Lt: "Lt",
ast.LtE: "Leq",
# these are not in the Enumo subset of Halide eval
# ast.Gt: "Gt",
# ast.GtE: "Ge",
}
cmp = ast_to_cmp_str.get(type(node.ops[0]), BAD_STR)
return f"({cmp} {ast_to_sexpr(node.left)} {ast_to_sexpr(node.comparators[0])})"
# default case
case _:
raise Exception(f"unknown node: {type(node)}")

if __name__ == "__main__":
with open(OUTPUT_FILE, "w+") as out:
for file in REWRITE_FILES:
with open(file, "r") as f:
lines = f.readlines()
total_rules = len(lines)
added_rules_count = 0
for line in lines:
line = line.replace("&&", "and").replace("||", "or").replace("!", "not ").replace("not =", "!=")
# split on "==>" or "if"
parts = list(map(lambda expr: ast_to_sexpr(ast.parse(re.sub(r"^\s+", "", expr)).body[0]),
re.split(r"==>|if ", line)))

rule = ";".join(parts)

if len(parts) != 2 and len(parts) != 3:
raise Exception(f"bad length: {len(parts)}")
elif BAD_STR not in rule:
added_rules_count += 1
out.write(f"{rule}\n")

print(f"Added {added_rules_count} / {total_rules} rules from {file}")



