Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/replace environments by substitution #787

Merged
merged 16 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions effekt/js/src/main/scala/effekt/Backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ package effekt
class Backend {
val compiler = generator.js.JavaScriptWeb()
val runner = ()
val name = "js-web"
}
2 changes: 1 addition & 1 deletion effekt/jvm/src/test/scala/effekt/core/OptimizerTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class OptimizerTests extends CoreTests {
def normalize(input: String, expected: String)(using munit.Location) =
assertTransformsTo(input, expected) { tree =>
val anfed = BindSubexpressions.transform(tree)
val normalized = Normalizer.normalize(Set(mainSymbol), anfed, 50)
val normalized = Normalizer.normalize(Set(mainSymbol), anfed, 50, false)
Deadcode.remove(mainSymbol, normalized)
}

Expand Down
54 changes: 27 additions & 27 deletions effekt/jvm/src/test/scala/effekt/core/VMTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ class VMTests extends munit.FunSuite {
dynamicDispatches = 0,
patternMatches = 400,
branches = 1487,
pushedFrames = 1352,
pushedFrames = 1185,
poppedFrames = 1409,
allocations = 54,
closures = 0,
Expand All @@ -549,7 +549,7 @@ class VMTests extends munit.FunSuite {
dynamicDispatches = 0,
patternMatches = 0,
branches = 210,
pushedFrames = 379,
pushedFrames = 378,
poppedFrames = 377,
allocations = 0,
closures = 0,
Expand Down Expand Up @@ -613,7 +613,7 @@ class VMTests extends munit.FunSuite {
dynamicDispatches = 0,
patternMatches = 4,
branches = 701,
pushedFrames = 874,
pushedFrames = 702,
poppedFrames = 880,
allocations = 4,
closures = 0,
Expand Down Expand Up @@ -660,13 +660,13 @@ class VMTests extends munit.FunSuite {

examplesDir / "casestudies" / "scheduler.effekt.md" -> Some(Summary(
staticDispatches = 60,
dynamicDispatches = 8,
dynamicDispatches = 7,
patternMatches = 95,
branches = 41,
pushedFrames = 106,
pushedFrames = 105,
poppedFrames = 106,
allocations = 73,
closures = 8,
closures = 7,
variableReads = 29,
variableWrites = 18,
resets = 1,
Expand Down Expand Up @@ -695,40 +695,40 @@ class VMTests extends munit.FunSuite {
dynamicDispatches = 783,
patternMatches = 13502,
branches = 14892,
pushedFrames = 28523,
poppedFrames = 28499,
pushedFrames = 28210,
poppedFrames = 28186,
allocations = 7923,
closures = 521,
variableReads = 6742,
variableWrites = 1901,
resets = 806,
shifts = 855,
resumes = 839
resets = 778,
shifts = 229,
resumes = 213
)),

examplesDir / "casestudies" / "anf.effekt.md" -> Some(Summary(
staticDispatches = 4775,
dynamicDispatches = 443,
patternMatches = 7272,
branches = 8110,
pushedFrames = 16275,
poppedFrames = 16260,
pushedFrames = 16101,
poppedFrames = 16088,
allocations = 4317,
closures = 358,
variableReads = 4080,
variableWrites = 1343,
resets = 481,
shifts = 660,
resumes = 644
resets = 458,
shifts = 322,
resumes = 306
)),

examplesDir / "casestudies" / "inference.effekt.md" -> Some(Summary(
staticDispatches = 1457444,
dynamicDispatches = 3201452,
patternMatches = 1474290,
branches = 303298,
pushedFrames = 7574480,
poppedFrames = 6709185,
pushedFrames = 7574476,
poppedFrames = 6709181,
allocations = 4626007,
closures = 865541,
variableReads = 2908620,
Expand All @@ -741,11 +741,11 @@ class VMTests extends munit.FunSuite {
examplesDir / "pos" / "raytracer.effekt" -> Some(Summary(
staticDispatches = 79696,
dynamicDispatches = 0,
patternMatches = 1014772,
patternMatches = 795964,
branches = 71995,
pushedFrames = 223269,
poppedFrames = 223269,
allocations = 127533,
allocations = 103221,
closures = 0,
variableReads = 77886,
variableWrites = 26904,
Expand All @@ -761,26 +761,26 @@ class VMTests extends munit.FunSuite {
dynamicDispatches = 0,
patternMatches = 0,
branches = 11,
pushedFrames = 102,
poppedFrames = 102,
pushedFrames = 92,
poppedFrames = 92,
allocations = 0,
closures = 0,
variableReads = 61,
variableWrites = 30,
resets = 1,
shifts = 10,
resumes = 10
resets = 0,
shifts = 0,
resumes = 0
)),

examplesDir / "benchmarks" / "other" / "church_exponentiation.effekt" -> Some(Summary(
staticDispatches = 7,
dynamicDispatches = 1062912,
dynamicDispatches = 797188,
patternMatches = 0,
branches = 5,
pushedFrames = 531467,
poppedFrames = 531467,
allocations = 0,
closures = 265750,
closures = 26,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is particularly significant...

variableReads = 0,
variableWrites = 0,
resets = 0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,9 @@ object PolymorphismBoxing extends Phase[CoreTransformed, CoreTransformed] {
Stmt.Alloc(id, transform(init), region, transform(body))
case Stmt.Var(id, init, cap, body) =>
Stmt.Var(id, transform(init), cap, transform(body))
case Stmt.Reset(body) =>
Stmt.Reset(transform(body))
case Stmt.Reset(BlockLit(tps, cps, vps, prompt :: Nil, body)) =>
Stmt.Reset(BlockLit(tps, cps, vps, prompt :: Nil, coerce(transform(body), stmt.tpe)))
case Stmt.Reset(body) => ???
case Stmt.Shift(prompt, body) =>
Stmt.Shift(prompt, transform(body))
case Stmt.Resume(k, body) =>
Expand Down
11 changes: 6 additions & 5 deletions effekt/shared/src/main/scala/effekt/core/Tree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -272,15 +272,15 @@ enum Stmt extends Tree {
case Alloc(id: Id, init: Pure, region: Id, body: Stmt)

// creates a fresh state handler to model local (backtrackable) state.
// [[capture]] is a binding occurence.
// [[capture]] is a binding occurrence.
// e.g. state(init) { [x]{x: Ref} => ... }
case Var(id: Id, init: Pure, capture: Id, body: Stmt)
case Get(id: Id, annotatedCapt: Captures, annotatedTpe: ValueType)
case Put(id: Id, annotatedCapt: Captures, value: Pure)

// binds a fresh prompt as [[id]] in [[body]] and delimits the scope of captured continuations
// Reset({ [cap]{p: Prompt[answer] at cap} => stmt: answer}): answer
case Reset(body: BlockLit)
case Reset(body: Block.BlockLit)

// captures the continuation up to the given prompt
// Invariant, it always has the shape:
Expand Down Expand Up @@ -703,7 +703,7 @@ object substitutions {

case Match(scrutinee, clauses, default) =>
Match(substitute(scrutinee), clauses.map {
case (id, b) => (id, substitute(b).asInstanceOf[BlockLit])
case (id, b) => (id, substitute(b))
}, default.map(substitute))

case Alloc(id, init, region, body) =>
Expand All @@ -719,11 +719,12 @@ object substitutions {
case Put(id, capt, value) =>
Put(substituteAsVar(id), substitute(capt), substitute(value))

// We annotate the answer type here since it needs to be the union of body.tpe and all shifts
case Reset(body) =>
Reset(substitute(body).asInstanceOf[BlockLit])
Reset(substitute(body))

case Shift(prompt, body) =>
val after = substitute(body).asInstanceOf[BlockLit]
val after = substitute(body)
Shift(substitute(prompt).asInstanceOf[BlockVar], after)

case Resume(k, body) =>
Expand Down
38 changes: 22 additions & 16 deletions effekt/shared/src/main/scala/effekt/core/Type.scala
Original file line number Diff line number Diff line change
Expand Up @@ -97,19 +97,21 @@ object Type {
/**
* Function types are the only type constructor that we have subtyping on.
*/
def merge(tpe1: ValueType, tpe2: ValueType, covariant: Boolean): ValueType = (tpe1, tpe2) match {
case (ValueType.Boxed(btpe1, capt1), ValueType.Boxed(btpe2, capt2)) =>
def merge(tpe1: ValueType, tpe2: ValueType, covariant: Boolean): ValueType = (tpe1, tpe2, covariant) match {
case (tpe1, tpe2, covariant) if tpe1 == tpe2 => tpe1
case (ValueType.Boxed(btpe1, capt1), ValueType.Boxed(btpe2, capt2), covariant) =>
ValueType.Boxed(merge(btpe1, btpe2, covariant), merge(capt1, capt2, covariant))
case (tpe1, tpe2) if covariant =>
if (isSubtype(tpe1, tpe2)) tpe2 else tpe1
case (tpe1, tpe2) if !covariant =>
if (isSubtype(tpe1, tpe2)) tpe1 else tpe2
case (TBottom, tpe2, true) => tpe2
case (tpe1, TBottom, true) => tpe1
case (TTop, tpe2, true) => TTop
case (tpe1, TTop, true) => TTop
case (TBottom, tpe2, false) => TBottom
case (tpe1, TBottom, false) => TBottom
case (TTop, tpe2, false) => tpe2
case (tpe1, TTop, false) => tpe1
// TODO this swallows a lot of bugs that we NEED to fix
case _ => tpe1
}
private def isSubtype(tpe1: ValueType, tpe2: ValueType): Boolean = (tpe1, tpe2) match {
case (tpe1, TTop) => true
case (TBottom, tpe1) => true
case _ => false // conservative :)
// sys error s"Cannot compare ${tpe1} ${tpe2} in ${covariant}" // conservative :)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you ever want to fix more bugs, comment this in again.

}

def merge(tpe1: BlockType, tpe2: BlockType, covariant: Boolean): BlockType = (tpe1, tpe2) match {
Expand All @@ -129,12 +131,12 @@ object Type {
assert(targs.size == tparams.size, "Wrong number of type arguments")
assert(cargs.size == cparams.size, "Wrong number of capture arguments")

val vsubst = (tparams zip targs).toMap
val tsubst = (tparams zip targs).toMap
val csubst = (cparams zip cargs).toMap
BlockType.Function(Nil, Nil,
vparams.map { tpe => substitute(tpe, vsubst, Map.empty) },
bparams.map { tpe => substitute(tpe, vsubst, Map.empty) },
substitute(result, vsubst, csubst))
vparams.map { tpe => substitute(tpe, tsubst, Map.empty) },
bparams.map { tpe => substitute(tpe, tsubst, Map.empty) },
substitute(result, tsubst, csubst))
}

def substitute(capt: Captures, csubst: Map[Id, Captures]): Captures = capt.flatMap {
Expand Down Expand Up @@ -206,7 +208,11 @@ object Type {
case Stmt.Var(id, init, cap, body) => body.tpe
case Stmt.Get(id, capt, tpe) => tpe
case Stmt.Put(id, capt, value) => TUnit
case Stmt.Reset(body) => body.returnType
case Stmt.Reset(BlockLit(_, _, _, prompt :: Nil, body)) => prompt.tpe match {
case TPrompt(tpe) => tpe
case _ => ???
}
case Stmt.Reset(body) => ???
case Stmt.Shift(prompt, body) => body.bparams match {
case core.BlockParam(id, BlockType.Interface(ResumeSymbol, List(result, answer)), captures) :: Nil => result
case _ => ???
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ object Normalizer { normal =>
exprs: Map[Id, Expr],
decls: DeclarationContext, // for field selection
usage: mutable.Map[Id, Usage], // mutable in order to add new information after renaming
maxInlineSize: Int // to control inlining and avoid code bloat
maxInlineSize: Int, // to control inlining and avoid code bloat
preserveBoxing: Boolean // for LLVM, prevents some optimizations
) {
def bind(id: Id, expr: Expr): Context = copy(exprs = exprs + (id -> expr))
def bind(id: Id, block: Block): Context = copy(blocks = blocks + (id -> block))
Expand Down Expand Up @@ -65,14 +66,14 @@ object Normalizer { normal =>
case None => false
}

def normalize(entrypoints: Set[Id], m: ModuleDecl, maxInlineSize: Int): ModuleDecl = {
def normalize(entrypoints: Set[Id], m: ModuleDecl, maxInlineSize: Int, preserveBoxing: Boolean): ModuleDecl = {
// usage information is used to detect recursive functions (and not inline them)
val usage = Reachable(entrypoints, m)

val defs = m.definitions.collect {
case Toplevel.Def(id, block) => id -> block
}.toMap
val context = Context(defs, Map.empty, DeclarationContext(m.declarations, m.externs), mutable.Map.from(usage), maxInlineSize)
val context = Context(defs, Map.empty, DeclarationContext(m.declarations, m.externs), mutable.Map.from(usage), maxInlineSize, preserveBoxing)

val (normalizedDefs, _) = normalizeToplevel(m.definitions)(using context)
m.copy(definitions = normalizedDefs)
Expand Down Expand Up @@ -138,10 +139,10 @@ object Normalizer { normal =>

// TODO for `New` we should track how often each operation is used, not the object itself
// to decide inlining.
private def shouldInline(b: BlockLit, boundBy: Option[BlockVar])(using C: Context): Boolean = boundBy match {
private def shouldInline(b: BlockLit, boundBy: Option[BlockVar], blockArgs: List[Block])(using C: Context): Boolean = boundBy match {
case Some(id) if isRecursive(id.id) => false
case Some(id) => isOnce(id.id) || b.body.size <= C.maxInlineSize
case None => true
case _ => blockArgs.exists { b => b.isInstanceOf[BlockLit] } // higher-order function with known arg
}

private def active(e: Expr)(using Context): Expr =
Expand Down Expand Up @@ -171,7 +172,7 @@ object Normalizer { normal =>
// -------
case Stmt.App(b, targs, vargs, bargs) =>
active(b) match {
case NormalizedBlock.Known(b: BlockLit, boundBy) if shouldInline(b, boundBy) =>
case NormalizedBlock.Known(b: BlockLit, boundBy) if shouldInline(b, boundBy, bargs) =>
reduce(b, targs, vargs.map(normalize), bargs.map(normalize))
case normalized =>
Stmt.App(normalized.shared, targs, vargs.map(normalize), bargs.map(normalize))
Expand All @@ -181,7 +182,7 @@ object Normalizer { normal =>
active(b) match {
case n @ NormalizedBlock.Known(Block.New(impl), boundBy) =>
selectOperation(impl, method) match {
case b: BlockLit if shouldInline(b, boundBy) => reduce(b, targs, vargs.map(normalize), bargs.map(normalize))
case b: BlockLit if shouldInline(b, boundBy, bargs) => reduce(b, targs, vargs.map(normalize), bargs.map(normalize))
case _ => Stmt.Invoke(n.shared, method, methodTpe, targs, vargs.map(normalize), bargs.map(normalize))
}

Expand Down Expand Up @@ -213,6 +214,14 @@ object Normalizer { normal =>

def normalizeVal(id: Id, tpe: ValueType, binding: Stmt, body: Stmt): Stmt = normalize(binding) match {

// [[ val x = ABORT; body ]] = ABORT
case abort if !C.preserveBoxing && abort.tpe == Type.TBottom =>
abort

case abort @ Stmt.Shift(p, BlockLit(tparams, cparams, vparams, List(k), body))
if !C.preserveBoxing && !Variables.free(body).containsBlock(k.id) =>
abort

// [[ val x = sc match { case id(ps) => body2 }; body ]] = sc match { case id(ps) => val x = body2; body }
case Stmt.Match(sc, List((id2, BlockLit(tparams2, cparams2, vparams2, bparams2, body2))), None) =>
Stmt.Match(sc, List((id2, BlockLit(tparams2, cparams2, vparams2, bparams2,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ object Optimizer extends Phase[CoreTransformed, CoreTransformed] {

def optimize(source: Source, mainSymbol: symbols.Symbol, core: ModuleDecl)(using Context): ModuleDecl =

val isLLVM = Context.config.backend().name == "llvm"
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is REALLY hacky. Another solution would be to first optimize and then do polymorphism boxing. However, this alternative currently (I didn't try to find out why) fails a lot of test.


var tree = core

// (1) first thing we do is simply remove unused definitions (this speeds up all following analysis and rewrites)
Expand All @@ -37,18 +39,16 @@ object Optimizer extends Phase[CoreTransformed, CoreTransformed] {

def normalize(m: ModuleDecl) = {
val anfed = BindSubexpressions.transform(m)
val normalized = Normalizer.normalize(Set(mainSymbol), anfed, Context.config.maxInlineSize().toInt)
Deadcode.remove(mainSymbol, normalized)
val normalized = Normalizer.normalize(Set(mainSymbol), anfed, Context.config.maxInlineSize().toInt, isLLVM)
val live = Deadcode.remove(mainSymbol, normalized)
val tailRemoved = RemoveTailResumptions(live)
tailRemoved
}

// (3) normalize once and remove beta redexes
// (3) normalize a few times (since tail resumptions might only surface after normalization and leave dead Resets)
tree = Context.timed("normalize-1", source.name) { normalize(tree) }

// (4) optimize continuation capture in the tail-resumptive case
tree = Context.timed("tail-resumptions", source.name) { RemoveTailResumptions(tree) }

// (5) normalize again to clean up and remove new redexes
tree = Context.timed("normalize-2", source.name) { normalize(tree) }
tree = Context.timed("normalize-3", source.name) { normalize(tree) }

tree
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ object RemoveTailResumptions {
case Stmt.Get(id, annotatedCapt, annotatedTpe) => false
case Stmt.Put(id, annotatedCapt, value) => false
case Stmt.Reset(BlockLit(tparams, cparams, vparams, bparams, body)) => tailResumptive(k, body) // is this correct?
case Stmt.Shift(prompt, body) => false
case Stmt.Shift(prompt, body) => stmt.tpe == Type.TBottom
case Stmt.Resume(k2, body) => k2.id == k // what if k is free in body?
case Stmt.Hole() => true
}
Expand Down
Loading
Loading