diff --git a/effekt/js/src/main/scala/effekt/Backend.scala b/effekt/js/src/main/scala/effekt/Backend.scala index aee886ce5..ad34093be 100644 --- a/effekt/js/src/main/scala/effekt/Backend.scala +++ b/effekt/js/src/main/scala/effekt/Backend.scala @@ -3,4 +3,5 @@ package effekt class Backend { val compiler = generator.js.JavaScriptWeb() val runner = () + val name = "js-web" } diff --git a/effekt/jvm/src/test/scala/effekt/core/OptimizerTests.scala b/effekt/jvm/src/test/scala/effekt/core/OptimizerTests.scala index dc91897d9..fe5f2381a 100644 --- a/effekt/jvm/src/test/scala/effekt/core/OptimizerTests.scala +++ b/effekt/jvm/src/test/scala/effekt/core/OptimizerTests.scala @@ -36,7 +36,7 @@ class OptimizerTests extends CoreTests { def normalize(input: String, expected: String)(using munit.Location) = assertTransformsTo(input, expected) { tree => val anfed = BindSubexpressions.transform(tree) - val normalized = Normalizer.normalize(Set(mainSymbol), anfed, 50) + val normalized = Normalizer.normalize(Set(mainSymbol), anfed, 50, false) Deadcode.remove(mainSymbol, normalized) } diff --git a/effekt/jvm/src/test/scala/effekt/core/VMTests.scala b/effekt/jvm/src/test/scala/effekt/core/VMTests.scala index 8d9a1f1b2..5bd042abe 100644 --- a/effekt/jvm/src/test/scala/effekt/core/VMTests.scala +++ b/effekt/jvm/src/test/scala/effekt/core/VMTests.scala @@ -533,7 +533,7 @@ class VMTests extends munit.FunSuite { dynamicDispatches = 0, patternMatches = 400, branches = 1487, - pushedFrames = 1352, + pushedFrames = 1185, poppedFrames = 1409, allocations = 54, closures = 0, @@ -549,7 +549,7 @@ class VMTests extends munit.FunSuite { dynamicDispatches = 0, patternMatches = 0, branches = 210, - pushedFrames = 379, + pushedFrames = 378, poppedFrames = 377, allocations = 0, closures = 0, @@ -613,7 +613,7 @@ class VMTests extends munit.FunSuite { dynamicDispatches = 0, patternMatches = 4, branches = 701, - pushedFrames = 874, + pushedFrames = 702, poppedFrames = 880, allocations = 4, closures = 0, @@ -660,13 +660,13 @@ class VMTests extends munit.FunSuite { examplesDir / "casestudies" / "scheduler.effekt.md" -> Some(Summary( staticDispatches = 60, - dynamicDispatches = 8, + dynamicDispatches = 7, patternMatches = 95, branches = 41, - pushedFrames = 106, + pushedFrames = 105, poppedFrames = 106, allocations = 73, - closures = 8, + closures = 7, variableReads = 29, variableWrites = 18, resets = 1, @@ -695,15 +695,15 @@ class VMTests extends munit.FunSuite { dynamicDispatches = 783, patternMatches = 13502, branches = 14892, - pushedFrames = 28523, - poppedFrames = 28499, + pushedFrames = 28210, + poppedFrames = 28186, allocations = 7923, closures = 521, variableReads = 6742, variableWrites = 1901, - resets = 806, - shifts = 855, - resumes = 839 + resets = 778, + shifts = 229, + resumes = 213 )), examplesDir / "casestudies" / "anf.effekt.md" -> Some(Summary( @@ -711,15 +711,15 @@ class VMTests extends munit.FunSuite { dynamicDispatches = 443, patternMatches = 7272, branches = 8110, - pushedFrames = 16275, - poppedFrames = 16260, + pushedFrames = 16101, + poppedFrames = 16088, allocations = 4317, closures = 358, variableReads = 4080, variableWrites = 1343, - resets = 481, - shifts = 660, - resumes = 644 + resets = 458, + shifts = 322, + resumes = 306 )), examplesDir / "casestudies" / "inference.effekt.md" -> Some(Summary( @@ -727,8 +727,8 @@ class VMTests extends munit.FunSuite { dynamicDispatches = 3201452, patternMatches = 1474290, branches = 303298, - pushedFrames = 7574480, - poppedFrames = 6709185, + pushedFrames = 7574476, + poppedFrames = 6709181, allocations = 4626007, closures = 865541, variableReads = 2908620, @@ -741,11 +741,11 @@ class VMTests extends munit.FunSuite { examplesDir / "pos" / "raytracer.effekt" -> Some(Summary( staticDispatches = 79696, dynamicDispatches = 0, - patternMatches = 1014772, + patternMatches = 795964, branches = 71995, pushedFrames = 223269, poppedFrames = 223269, - allocations = 127533, + allocations = 103221, closures = 0, variableReads = 77886, variableWrites = 26904, @@ -761,26 +761,26 @@ class VMTests extends munit.FunSuite { dynamicDispatches = 0, patternMatches = 0, branches = 11, - pushedFrames = 102, - poppedFrames = 102, + pushedFrames = 92, + poppedFrames = 92, allocations = 0, closures = 0, variableReads = 61, variableWrites = 30, - resets = 1, - shifts = 10, - resumes = 10 + resets = 0, + shifts = 0, + resumes = 0 )), examplesDir / "benchmarks" / "other" / "church_exponentiation.effekt" -> Some(Summary( staticDispatches = 7, - dynamicDispatches = 1062912, + dynamicDispatches = 797188, patternMatches = 0, branches = 5, pushedFrames = 531467, poppedFrames = 531467, allocations = 0, - closures = 265750, + closures = 26, variableReads = 0, variableWrites = 0, resets = 0, diff --git a/effekt/shared/src/main/scala/effekt/core/PolymorphismBoxing.scala b/effekt/shared/src/main/scala/effekt/core/PolymorphismBoxing.scala index 47daa1f58..1a4804924 100644 --- a/effekt/shared/src/main/scala/effekt/core/PolymorphismBoxing.scala +++ b/effekt/shared/src/main/scala/effekt/core/PolymorphismBoxing.scala @@ -298,8 +298,9 @@ object PolymorphismBoxing extends Phase[CoreTransformed, CoreTransformed] { Stmt.Alloc(id, transform(init), region, transform(body)) case Stmt.Var(id, init, cap, body) => Stmt.Var(id, transform(init), cap, transform(body)) - case Stmt.Reset(body) => - Stmt.Reset(transform(body)) + case Stmt.Reset(BlockLit(tps, cps, vps, prompt :: Nil, body)) => + Stmt.Reset(BlockLit(tps, cps, vps, prompt :: Nil, coerce(transform(body), stmt.tpe))) + case Stmt.Reset(body) => ??? case Stmt.Shift(prompt, body) => Stmt.Shift(prompt, transform(body)) case Stmt.Resume(k, body) => diff --git a/effekt/shared/src/main/scala/effekt/core/Tree.scala b/effekt/shared/src/main/scala/effekt/core/Tree.scala index bde32e984..491afb6cc 100644 --- a/effekt/shared/src/main/scala/effekt/core/Tree.scala +++ b/effekt/shared/src/main/scala/effekt/core/Tree.scala @@ -272,7 +272,7 @@ enum Stmt extends Tree { case Alloc(id: Id, init: Pure, region: Id, body: Stmt) // creates a fresh state handler to model local (backtrackable) state. - // [[capture]] is a binding occurence. + // [[capture]] is a binding occurrence. // e.g. state(init) { [x]{x: Ref} => ... } case Var(id: Id, init: Pure, capture: Id, body: Stmt) case Get(id: Id, annotatedCapt: Captures, annotatedTpe: ValueType) @@ -280,7 +280,7 @@ enum Stmt extends Tree { // binds a fresh prompt as [[id]] in [[body]] and delimits the scope of captured continuations // Reset({ [cap]{p: Prompt[answer] at cap} => stmt: answer}): answer - case Reset(body: BlockLit) + case Reset(body: Block.BlockLit) // captures the continuation up to the given prompt // Invariant, it always has the shape: @@ -703,7 +703,7 @@ object substitutions { case Match(scrutinee, clauses, default) => Match(substitute(scrutinee), clauses.map { - case (id, b) => (id, substitute(b).asInstanceOf[BlockLit]) + case (id, b) => (id, substitute(b)) }, default.map(substitute)) case Alloc(id, init, region, body) => @@ -719,11 +719,12 @@ object substitutions { case Put(id, capt, value) => Put(substituteAsVar(id), substitute(capt), substitute(value)) + // We annotate the answer type here since it needs to be the union of body.tpe and all shifts case Reset(body) => - Reset(substitute(body).asInstanceOf[BlockLit]) + Reset(substitute(body)) case Shift(prompt, body) => - val after = substitute(body).asInstanceOf[BlockLit] + val after = substitute(body) Shift(substitute(prompt).asInstanceOf[BlockVar], after) case Resume(k, body) => diff --git a/effekt/shared/src/main/scala/effekt/core/Type.scala b/effekt/shared/src/main/scala/effekt/core/Type.scala index 7f942e0b1..7a99d8539 100644 --- a/effekt/shared/src/main/scala/effekt/core/Type.scala +++ b/effekt/shared/src/main/scala/effekt/core/Type.scala @@ -97,19 +97,21 @@ object Type { /** * Function types are the only type constructor that we have subtyping on. */ - def merge(tpe1: ValueType, tpe2: ValueType, covariant: Boolean): ValueType = (tpe1, tpe2) match { - case (ValueType.Boxed(btpe1, capt1), ValueType.Boxed(btpe2, capt2)) => + def merge(tpe1: ValueType, tpe2: ValueType, covariant: Boolean): ValueType = (tpe1, tpe2, covariant) match { + case (tpe1, tpe2, covariant) if tpe1 == tpe2 => tpe1 + case (ValueType.Boxed(btpe1, capt1), ValueType.Boxed(btpe2, capt2), covariant) => ValueType.Boxed(merge(btpe1, btpe2, covariant), merge(capt1, capt2, covariant)) - case (tpe1, tpe2) if covariant => - if (isSubtype(tpe1, tpe2)) tpe2 else tpe1 - case (tpe1, tpe2) if !covariant => - if (isSubtype(tpe1, tpe2)) tpe1 else tpe2 + case (TBottom, tpe2, true) => tpe2 + case (tpe1, TBottom, true) => tpe1 + case (TTop, tpe2, true) => TTop + case (tpe1, TTop, true) => TTop + case (TBottom, tpe2, false) => TBottom + case (tpe1, TBottom, false) => TBottom + case (TTop, tpe2, false) => tpe2 + case (tpe1, TTop, false) => tpe1 + // TODO this swallows a lot of bugs that we NEED to fix case _ => tpe1 - } - private def isSubtype(tpe1: ValueType, tpe2: ValueType): Boolean = (tpe1, tpe2) match { - case (tpe1, TTop) => true - case (TBottom, tpe1) => true - case _ => false // conservative :) + // sys error s"Cannot compare ${tpe1} ${tpe2} in ${covariant}" // conservative :) } def merge(tpe1: BlockType, tpe2: BlockType, covariant: Boolean): BlockType = (tpe1, tpe2) match { @@ -129,12 +131,12 @@ object Type { assert(targs.size == tparams.size, "Wrong number of type arguments") assert(cargs.size == cparams.size, "Wrong number of capture arguments") - val vsubst = (tparams zip targs).toMap + val tsubst = (tparams zip targs).toMap val csubst = (cparams zip cargs).toMap BlockType.Function(Nil, Nil, - vparams.map { tpe => substitute(tpe, vsubst, Map.empty) }, - bparams.map { tpe => substitute(tpe, vsubst, Map.empty) }, - substitute(result, vsubst, csubst)) + vparams.map { tpe => substitute(tpe, tsubst, Map.empty) }, + bparams.map { tpe => substitute(tpe, tsubst, Map.empty) }, + substitute(result, tsubst, csubst)) } def substitute(capt: Captures, csubst: Map[Id, Captures]): Captures = capt.flatMap { @@ -206,7 +208,11 @@ object Type { case Stmt.Var(id, init, cap, body) => body.tpe case Stmt.Get(id, capt, tpe) => tpe case Stmt.Put(id, capt, value) => TUnit - case Stmt.Reset(body) => body.returnType + case Stmt.Reset(BlockLit(_, _, _, prompt :: Nil, body)) => prompt.tpe match { + case TPrompt(tpe) => tpe + case _ => ??? + } + case Stmt.Reset(body) => ??? case Stmt.Shift(prompt, body) => body.bparams match { case core.BlockParam(id, BlockType.Interface(ResumeSymbol, List(result, answer)), captures) :: Nil => result case _ => ??? diff --git a/effekt/shared/src/main/scala/effekt/core/optimizer/Normalizer.scala b/effekt/shared/src/main/scala/effekt/core/optimizer/Normalizer.scala index ebaa75899..bcbe2ca66 100644 --- a/effekt/shared/src/main/scala/effekt/core/optimizer/Normalizer.scala +++ b/effekt/shared/src/main/scala/effekt/core/optimizer/Normalizer.scala @@ -35,7 +35,8 @@ object Normalizer { normal => exprs: Map[Id, Expr], decls: DeclarationContext, // for field selection usage: mutable.Map[Id, Usage], // mutable in order to add new information after renaming - maxInlineSize: Int // to control inlining and avoid code bloat + maxInlineSize: Int, // to control inlining and avoid code bloat + preserveBoxing: Boolean // for LLVM, prevents some optimizations ) { def bind(id: Id, expr: Expr): Context = copy(exprs = exprs + (id -> expr)) def bind(id: Id, block: Block): Context = copy(blocks = blocks + (id -> block)) @@ -65,14 +66,14 @@ object Normalizer { normal => case None => false } - def normalize(entrypoints: Set[Id], m: ModuleDecl, maxInlineSize: Int): ModuleDecl = { + def normalize(entrypoints: Set[Id], m: ModuleDecl, maxInlineSize: Int, preserveBoxing: Boolean): ModuleDecl = { // usage information is used to detect recursive functions (and not inline them) val usage = Reachable(entrypoints, m) val defs = m.definitions.collect { case Toplevel.Def(id, block) => id -> block }.toMap - val context = Context(defs, Map.empty, DeclarationContext(m.declarations, m.externs), mutable.Map.from(usage), maxInlineSize) + val context = Context(defs, Map.empty, DeclarationContext(m.declarations, m.externs), mutable.Map.from(usage), maxInlineSize, preserveBoxing) val (normalizedDefs, _) = normalizeToplevel(m.definitions)(using context) m.copy(definitions = normalizedDefs) @@ -138,10 +139,10 @@ object Normalizer { normal => // TODO for `New` we should track how often each operation is used, not the object itself // to decide inlining. - private def shouldInline(b: BlockLit, boundBy: Option[BlockVar])(using C: Context): Boolean = boundBy match { + private def shouldInline(b: BlockLit, boundBy: Option[BlockVar], blockArgs: List[Block])(using C: Context): Boolean = boundBy match { case Some(id) if isRecursive(id.id) => false case Some(id) => isOnce(id.id) || b.body.size <= C.maxInlineSize - case None => true + case _ => blockArgs.exists { b => b.isInstanceOf[BlockLit] } // higher-order function with known arg } private def active(e: Expr)(using Context): Expr = @@ -171,7 +172,7 @@ object Normalizer { normal => // ------- case Stmt.App(b, targs, vargs, bargs) => active(b) match { - case NormalizedBlock.Known(b: BlockLit, boundBy) if shouldInline(b, boundBy) => + case NormalizedBlock.Known(b: BlockLit, boundBy) if shouldInline(b, boundBy, bargs) => reduce(b, targs, vargs.map(normalize), bargs.map(normalize)) case normalized => Stmt.App(normalized.shared, targs, vargs.map(normalize), bargs.map(normalize)) @@ -181,7 +182,7 @@ object Normalizer { normal => active(b) match { case n @ NormalizedBlock.Known(Block.New(impl), boundBy) => selectOperation(impl, method) match { - case b: BlockLit if shouldInline(b, boundBy) => reduce(b, targs, vargs.map(normalize), bargs.map(normalize)) + case b: BlockLit if shouldInline(b, boundBy, bargs) => reduce(b, targs, vargs.map(normalize), bargs.map(normalize)) case _ => Stmt.Invoke(n.shared, method, methodTpe, targs, vargs.map(normalize), bargs.map(normalize)) } @@ -213,6 +214,14 @@ object Normalizer { normal => def normalizeVal(id: Id, tpe: ValueType, binding: Stmt, body: Stmt): Stmt = normalize(binding) match { + // [[ val x = ABORT; body ]] = ABORT + case abort if !C.preserveBoxing && abort.tpe == Type.TBottom => + abort + + case abort @ Stmt.Shift(p, BlockLit(tparams, cparams, vparams, List(k), body)) + if !C.preserveBoxing && !Variables.free(body).containsBlock(k.id) => + abort + // [[ val x = sc match { case id(ps) => body2 }; body ]] = sc match { case id(ps) => val x = body2; body } case Stmt.Match(sc, List((id2, BlockLit(tparams2, cparams2, vparams2, bparams2, body2))), None) => Stmt.Match(sc, List((id2, BlockLit(tparams2, cparams2, vparams2, bparams2, diff --git a/effekt/shared/src/main/scala/effekt/core/optimizer/Optimizer.scala b/effekt/shared/src/main/scala/effekt/core/optimizer/Optimizer.scala index 8d0d6d029..3dbc7238a 100644 --- a/effekt/shared/src/main/scala/effekt/core/optimizer/Optimizer.scala +++ b/effekt/shared/src/main/scala/effekt/core/optimizer/Optimizer.scala @@ -21,6 +21,8 @@ object Optimizer extends Phase[CoreTransformed, CoreTransformed] { def optimize(source: Source, mainSymbol: symbols.Symbol, core: ModuleDecl)(using Context): ModuleDecl = + val isLLVM = Context.config.backend().name == "llvm" + var tree = core // (1) first thing we do is simply remove unused definitions (this speeds up all following analysis and rewrites) @@ -37,18 +39,16 @@ object Optimizer extends Phase[CoreTransformed, CoreTransformed] { def normalize(m: ModuleDecl) = { val anfed = BindSubexpressions.transform(m) - val normalized = Normalizer.normalize(Set(mainSymbol), anfed, Context.config.maxInlineSize().toInt) - Deadcode.remove(mainSymbol, normalized) + val normalized = Normalizer.normalize(Set(mainSymbol), anfed, Context.config.maxInlineSize().toInt, isLLVM) + val live = Deadcode.remove(mainSymbol, normalized) + val tailRemoved = RemoveTailResumptions(live) + tailRemoved } - // (3) normalize once and remove beta redexes + // (3) normalize a few times (since tail resumptions might only surface after normalization and leave dead Resets) tree = Context.timed("normalize-1", source.name) { normalize(tree) } - - // (4) optimize continuation capture in the tail-resumptive case - tree = Context.timed("tail-resumptions", source.name) { RemoveTailResumptions(tree) } - - // (5) normalize again to clean up and remove new redexes tree = Context.timed("normalize-2", source.name) { normalize(tree) } + tree = Context.timed("normalize-3", source.name) { normalize(tree) } tree } diff --git a/effekt/shared/src/main/scala/effekt/core/optimizer/RemoveTailResumptions.scala b/effekt/shared/src/main/scala/effekt/core/optimizer/RemoveTailResumptions.scala index 2174f7acd..c290ff9bd 100644 --- a/effekt/shared/src/main/scala/effekt/core/optimizer/RemoveTailResumptions.scala +++ b/effekt/shared/src/main/scala/effekt/core/optimizer/RemoveTailResumptions.scala @@ -38,7 +38,7 @@ object RemoveTailResumptions { case Stmt.Get(id, annotatedCapt, annotatedTpe) => false case Stmt.Put(id, annotatedCapt, value) => false case Stmt.Reset(BlockLit(tparams, cparams, vparams, bparams, body)) => tailResumptive(k, body) // is this correct? - case Stmt.Shift(prompt, body) => false + case Stmt.Shift(prompt, body) => stmt.tpe == Type.TBottom case Stmt.Resume(k2, body) => k2.id == k // what if k is free in body? case Stmt.Hole() => true } diff --git a/effekt/shared/src/main/scala/effekt/cps/Transformer.scala b/effekt/shared/src/main/scala/effekt/cps/Transformer.scala index 95ac72ae9..cdf407ad9 100644 --- a/effekt/shared/src/main/scala/effekt/cps/Transformer.scala +++ b/effekt/shared/src/main/scala/effekt/cps/Transformer.scala @@ -79,10 +79,10 @@ object Transformer { }) case core.Stmt.App(callee, targs, vargs, bargs) => - App(transform(callee), vargs.map(transform), bargs.map(transform), MetaCont(ks), k.reify) + App(transform(callee), vargs.map(transform), bargs.map(transform), MetaCont(ks), k.reifyAt(stmt.tpe)) case core.Stmt.Invoke(callee, method, tpe, targs, vargs, bargs) => - Invoke(transform(callee), method, vargs.map(transform), bargs.map(transform), MetaCont(ks), k.reify) + Invoke(transform(callee), method, vargs.map(transform), bargs.map(transform), MetaCont(ks), k.reifyAt(stmt.tpe)) case core.Stmt.If(cond, thn, els) => withJoinpoint(k) { k2 => @@ -119,7 +119,7 @@ object Transformer { val translatedBody: BlockLit = BlockLit(vparams.map { p => p.id }, List(resume.id), ks2, k2, transform(body, ks2, Continuation.Dynamic(k2))) - Shift(prompt.id, translatedBody, MetaCont(ks), k.reify) + Shift(prompt.id, translatedBody, MetaCont(ks), k.reifyAt(stmt.tpe)) case core.Stmt.Shift(prompt, body) => sys error "Shouldn't happen" @@ -127,7 +127,7 @@ object Transformer { val ks2 = Id("ks") val k2 = Id("k") Resume(cont.id, Block.BlockLit(Nil, Nil, ks2, k2, transform(body, ks2, Continuation.Dynamic(k2))), - MetaCont(ks), k.reify) + MetaCont(ks), k.reifyAt(stmt.tpe)) case core.Stmt.Hole() => Hole() @@ -230,8 +230,10 @@ enum Continuation { val ks = Id("ks") cps.Cont.ContLam(hint, ks, k(Pure.ValueVar(hint), ks)) } + + def reifyAt(tpe: core.ValueType): Cont = + if (tpe == core.Type.TBottom) Cont.Abort else reify } object Continuation { def Static(hint: Id)(k: (Pure, Id) => Stmt): Continuation.Static = Continuation.Static(hint, k) - } diff --git a/effekt/shared/src/main/scala/effekt/cps/Tree.scala b/effekt/shared/src/main/scala/effekt/cps/Tree.scala index 8fa44c7dd..517d5db13 100644 --- a/effekt/shared/src/main/scala/effekt/cps/Tree.scala +++ b/effekt/shared/src/main/scala/effekt/cps/Tree.scala @@ -4,6 +4,7 @@ package cps import core.{ Id, ValueType, BlockType, Captures } import effekt.source.FeatureFlag import effekt.util.messages.ErrorReporter +import effekt.util.messages.INTERNAL_ERROR sealed trait Tree extends Product { @@ -145,6 +146,7 @@ case class Clause(vparams: List[Id], body: Stmt) extends Tree enum Cont extends Tree { case ContVar(id: Id) case ContLam(result: Id, ks: Id, body: Stmt) + case Abort } case class MetaCont(id: Id) extends Tree @@ -228,6 +230,186 @@ object Variables { def free(ks: MetaCont): Variables = meta(ks.id) def free(k: Cont): Variables = k match { case Cont.ContVar(id) => cont(id) + case Cont.Abort => Set() case Cont.ContLam(result, ks, body) => free(body) -- value(result) -- meta(ks) } } + + +object substitutions { + + case class Substitution( + values: Map[Id, Pure] = Map.empty, + blocks: Map[Id, Block] = Map.empty, + conts: Map[Id, Cont] = Map.empty, + metaconts: Map[Id, MetaCont] = Map.empty + ) { + def shadowValues(shadowed: IterableOnce[Id]): Substitution = copy(values = values -- shadowed) + def shadowBlocks(shadowed: IterableOnce[Id]): Substitution = copy(blocks = blocks -- shadowed) + def shadowConts(shadowed: IterableOnce[Id]): Substitution = copy(conts = conts -- shadowed) + def shadowMetaconts(shadowed: IterableOnce[Id]): Substitution = copy(metaconts = metaconts -- shadowed) + + def shadowParams(vparams: Seq[Id], bparams: Seq[Id]): Substitution = + copy(values = values -- vparams, blocks = blocks -- bparams) + } + + def substitute(expression: Expr)(using Substitution): Expr = expression match { + case DirectApp(id, vargs, bargs) => + DirectApp(id, vargs.map(substitute), bargs.map(substitute)) + case p: Pure => substitute(p) + } + + def substitute(pure: Pure)(using subst: Substitution): Pure = pure match { + case ValueVar(id) if subst.values.isDefinedAt(id) => subst.values(id) + case ValueVar(id) => ValueVar(id) + case Literal(value) => Literal(value) + case Make(tpe, tag, vargs) => Make(tpe, tag, vargs.map(substitute)) + case PureApp(id, vargs) => PureApp(id, vargs.map(substitute)) + case Box(b) => Box(substitute(b)) + } + + def substitute(block: Block)(using subst: Substitution): Block = block match { + case BlockVar(id) if subst.blocks.isDefinedAt(id) => subst.blocks(id) + case BlockVar(id) => BlockVar(id) + case b: BlockLit => substitute(b) + case Unbox(pure) => Unbox(substitute(pure)) + case New(impl) => New(substitute(impl)) + } + + def substitute(b: BlockLit)(using subst: Substitution): BlockLit = b match { + case BlockLit(vparams, bparams, ks, k, body) => + BlockLit(vparams, bparams, ks, k, + substitute(body)(using subst + .shadowParams(vparams, bparams) + .shadowMetaconts(List(ks)) + .shadowConts(List(k)))) + } + + def substitute(stmt: Stmt)(using subst: Substitution): Stmt = stmt match { + case Jump(k, arg, ks) => + Jump( + substituteAsContVar(k), + substitute(arg), + substitute(ks)) + + case App(callee, vargs, bargs, ks, k) => + App( + substitute(callee), + vargs.map(substitute), + bargs.map(substitute), + substitute(ks), + substitute(k)) + + case Invoke(callee, method, vargs, bargs, ks, k) => + Invoke( + substitute(callee), + method, + vargs.map(substitute), + bargs.map(substitute), + substitute(ks), + substitute(k)) + + case If(cond, thn, els) => + If(substitute(cond), substitute(thn), substitute(els)) + + case Match(scrutinee, clauses, default) => + Match( + substitute(scrutinee), + clauses.map { case (id, cl) => (id, substitute(cl)) }, + default.map(substitute)) + + case LetDef(id, binding, body) => + LetDef(id, substitute(binding), + substitute(body)(using subst.shadowBlocks(List(id)))) + + case LetExpr(id, binding, body) => + LetExpr(id, substitute(binding), + substitute(body)(using subst.shadowValues(List(id)))) + + case LetCont(id, binding, body) => + LetCont(id, substitute(binding), + substitute(body)(using subst.shadowConts(List(id)))) + + case Region(id, ks, body) => + Region(id, substitute(ks), + substitute(body)(using subst.shadowBlocks(List(id)))) + + case Alloc(id, init, region, body) => + Alloc(id, substitute(init), substituteAsBlockVar(region), + substitute(body)(using subst.shadowBlocks(List(id)))) + + case Var(id, init, ks, body) => + Var(id, substitute(init), substitute(ks), + substitute(body)(using subst.shadowBlocks(List(id)))) + + case Dealloc(ref, body) => + Dealloc(substituteAsBlockVar(ref), substitute(body)) + + case Get(ref, id, body) => + Get(substituteAsBlockVar(ref), id, + substitute(body)(using subst.shadowValues(List(id)))) + + case Put(ref, value, body) => + Put(substituteAsBlockVar(ref), substitute(value), substitute(body)) + + case Reset(prog, ks, k) => + Reset(substitute(prog), substitute(ks), substitute(k)) + + case Shift(prompt, body, ks, k) => + Shift(substituteAsBlockVar(prompt), substitute(body), substitute(ks), substitute(k)) + + case Resume(r, body, ks, k) => + Resume(substituteAsBlockVar(r), substitute(body), substitute(ks), substitute(k)) + + case h: Hole => h + } + + def substitute(impl: Implementation)(using Substitution): Implementation = impl match { + case Implementation(interface, operations) => + Implementation(interface, operations.map(substitute)) + } + + def substitute(op: Operation)(using subst: Substitution): Operation = op match { + case Operation(name, vparams, bparams, ks, k, body) => + Operation(name, vparams, bparams, ks, k, + substitute(body)(using subst + .shadowParams(vparams, bparams) + .shadowMetaconts(List(ks)) + .shadowConts(List(k)))) + } + + def substitute(clause: Clause)(using subst: Substitution): Clause = clause match { + case Clause(vparams, body) => + Clause(vparams, substitute(body)(using subst.shadowValues(vparams))) + } + + def substitute(k: Cont)(using subst: Substitution): Cont = k match { + case Cont.ContVar(id) if subst.conts.isDefinedAt(id) => subst.conts(id) + case Cont.ContVar(id) => Cont.ContVar(id) + case lam @ Cont.ContLam(result, ks, body) => substitute(lam) + case Cont.Abort => Cont.Abort + } + + def substitute(k: Cont.ContLam)(using subst: Substitution): Cont.ContLam = k match { + case Cont.ContLam(result, ks, body) => + Cont.ContLam(result, ks, + substitute(body)(using subst + .shadowValues(List(result)) + .shadowMetaconts(List(ks)))) + } + + def substitute(ks: MetaCont)(using subst: Substitution): MetaCont = + subst.metaconts.getOrElse(ks.id, ks) + + def substituteAsBlockVar(id: Id)(using subst: Substitution): Id = + subst.blocks.get(id) map { + case BlockVar(x) => x + case _ => INTERNAL_ERROR("References should always be variables") + } getOrElse id + + def substituteAsContVar(id: Id)(using subst: Substitution): Id = + subst.conts.get(id) map { + case Cont.ContVar(x) => x + case _ => INTERNAL_ERROR("Continuation references should always be variables") + } getOrElse id +} diff --git a/effekt/shared/src/main/scala/effekt/generator/js/TransformerCps.scala b/effekt/shared/src/main/scala/effekt/generator/js/TransformerCps.scala index 438f97668..d47ca5757 100644 --- a/effekt/shared/src/main/scala/effekt/generator/js/TransformerCps.scala +++ b/effekt/shared/src/main/scala/effekt/generator/js/TransformerCps.scala @@ -7,7 +7,7 @@ import effekt.context.assertions.* import effekt.cps.* import effekt.core.{ DeclarationContext, Id } import effekt.cps.Variables.{ all, free } - +import effekt.cps.substitutions.Substitution import scala.collection.mutable object TransformerCps extends Transformer { @@ -23,13 +23,11 @@ object TransformerCps extends Transformer { val TRAMPOLINE = Variable(JSName("TRAMPOLINE")) class RecursiveUsage(var jumped: Boolean) - case class RecursiveDefInfo(id: Id, vparams: List[Id], bparams: List[Id], ks: Id, k: Id, used: RecursiveUsage) + case class RecursiveDefInfo(id: Id, label: Id, vparams: List[Id], bparams: List[Id], ks: Id, k: Id, used: RecursiveUsage) case class ContinuationInfo(k: Id, param: Id, ks: Id) case class TransformerContext( requiresThunk: Boolean, - // known definitions of expressions (used to inline into externs) - bindings: Map[Id, js.Expr], // definitions of externs (used to inline them) externs: Map[Id, cps.Extern.Def], // the innermost (in direct style) enclosing functions (used to rewrite a definition to a loop) @@ -38,8 +36,6 @@ object TransformerCps extends Transformer { directStyle: Option[ContinuationInfo], // the current direct-style metacontinuation metacont: Option[Id], - // substitutions for renaming of metaconts (to avoid rebinding them) - metaconts: Map[Id, Id], // the original declaration context (used to compile pattern matching) declarations: DeclarationContext, // the usual compiler context @@ -63,12 +59,10 @@ object TransformerCps extends Transformer { case cps.ModuleDecl(path, includes, declarations, externs, definitions, _) => given TransformerContext( false, - Map.empty, externs.collect { case d: Extern.Def => (d.id, d) }.toMap, None, None, None, - Map.empty, D, C) val name = JSName(jsModuleName(module.path)) @@ -88,12 +82,10 @@ object TransformerCps extends Transformer { val D = new DeclarationContext(coreModule.declarations, coreModule.externs) given TransformerContext( false, - Map.empty, input.externs.collect { case d: Extern.Def => (d.id, d) }.toMap, None, None, None, - Map.empty, D, C) input.definitions.map(toJS) @@ -152,12 +144,13 @@ object TransformerCps extends Transformer { def toJS(id: Id, b: cps.Block)(using TransformerContext): js.Expr = b match { case cps.Block.BlockLit(vparams, bparams, ks, k, body) => val used = new RecursiveUsage(false) + val label = Id(id) - val translatedBody = toJS(body)(using recursive(id, used, b)).stmts + val translatedBody = toJS(body)(using recursive(id, label, used, b)).stmts if used.jumped then js.Lambda(vparams.map(nameDef) ++ bparams.map(nameDef) ++ List(nameDef(ks), nameDef(k)), - List(js.While(RawExpr("true"), translatedBody, Some(uniqueName(id))))) + List(js.While(RawExpr("true"), translatedBody, Some(uniqueName(label))))) else js.Lambda(vparams.map(nameDef) ++ bparams.map(nameDef) ++ List(nameDef(ks), nameDef(k)), translatedBody) @@ -182,18 +175,18 @@ object TransformerCps extends Transformer { }) } - def toJS(ks: cps.MetaCont)(using T: TransformerContext): js.Expr = - nameRef(T.metaconts.getOrElse(ks.id, ks.id)) + def toJS(ks: cps.MetaCont)(using T: TransformerContext): js.Expr = nameRef(ks.id) def toJS(k: cps.Cont)(using T: TransformerContext): js.Expr = k match { case Cont.ContVar(id) => nameRef(id) case Cont.ContLam(result, ks, body) => js.Lambda(List(nameDef(result), nameDef(ks)), toJS(body)(using nonrecursive(ks)).stmts) + case Cont.Abort => js.Undefined } def toJS(e: cps.Expr)(using D: TransformerContext): js.Expr = e match { - case Pure.ValueVar(id) => lookup(id) + case Pure.ValueVar(id) => nameRef(id) case Pure.Literal(()) => $effekt.field("unit") case Pure.Literal(s: String) => JsString(escape(s)) case literal: Pure.Literal => js.RawExpr(literal.value.toString) @@ -224,8 +217,8 @@ object TransformerCps extends Transformer { case cps.Stmt.LetCont(id, Cont.ContLam(param, ks, body), body2) if canBeDirect(id, body2) => Binding { k => js.Let(nameDef(param), js.Undefined) :: - toJS(body2)(using withDirectStyle(id, param, ks)).stmts ++ - toJS(body)(using directstyle(ks)).run(k) + toJS(body2)(using markDirectStyle(id, param, ks)).stmts ++ + toJS(maintainDirectStyle(ks, body)).run(k) } case cps.Stmt.LetCont(id, binding @ Cont.ContLam(result2, ks2, body2), body) => @@ -266,43 +259,92 @@ object TransformerCps extends Transformer { case cps.Stmt.Jump(k, arg, ks) => pure(js.Return(maybeThunking(js.Call(nameRef(k), toJS(arg), toJS(ks))))) - case cps.Stmt.App(Recursive(id, vparams, bparams, ks1, k1, used), vargs, bargs, MetaCont(ks), Cont.ContVar(k)) if sameScope(ks, k, ks1, k1) => + + case cps.Stmt.App(Recursive(id, label, vparams, bparams, ks1, k1, used), vargs, bargs, MetaCont(ks), k) => Binding { k2 => val stmts = mutable.ListBuffer.empty[js.Stmt] - stmts.append(js.RawStmt("/* prepare tail call */")) + stmts.append(js.RawStmt("/* prepare call */")) used.jumped = true - // const x3 = [[ arg ]]; ... - val vtmps = (vparams zip vargs).map { (id, arg) => - val tmp = Id(id) - stmts.append(js.Const(nameDef(tmp), toJS(arg))) - tmp + // We need to create temporaries for all free variables that appear in arguments + val freeInArgs = (vargs.flatMap(Variables.free) ++ bargs.flatMap(Variables.free)).toSet + // Only compute free vars of continuation if it's not a ContVar + val (isTailCall, freeInK) = k match { + case Cont.ContVar(kid) => (kid == k1 && ks == ks1, Set.empty[Id]) + case _ => (false, Variables.free(k)) + } + val allFreeVars = freeInArgs ++ freeInK + val needsKsTmp = allFreeVars.contains(ks1) + val overlapping = allFreeVars.intersect((vparams ++ bparams).toSet) + + // Create temporaries for parameters that are used in the arguments or continuation + val paramTmps = overlapping.map { param => + val tmp = Id(s"tmp_${param}") + stmts.append(js.Const(nameDef(tmp), nameRef(param))) + param -> tmp + }.toMap + + val tmp_ks = if (needsKsTmp) { + val tmp = Id("tmp_ks") + stmts.append(js.Const(nameDef(tmp), nameRef(ks1))) + Some(tmp) + } else None + + // For non-tail calls, we need a continuation temporary if it's not a simple variable rename + val tmp_k = if (!isTailCall) k match { + case Cont.ContVar(kid) if kid != k1 => + // simple continuation variable, no need for temp binding + None + case _ => + val tmp = Id("tmp_k") + stmts.append(js.Const(nameDef(tmp), nameRef(k1))) + Some(tmp) + } else { + // For tail calls, only create temporary if k1 appears in arguments + if (freeInArgs.contains(k1)) { + val tmp = Id("tmp_k") + stmts.append(js.Const(nameDef(tmp), nameRef(k1))) + Some(tmp) + } else None } - val btmps = (bparams zip bargs).map { (id, arg) => - val tmp = Id(id) - stmts.append(js.Const(nameDef(tmp), toJS(arg))) - tmp + + // Prepare the substitution + val subst = Substitution( + values = paramTmps.map { case (p, t) => p -> Pure.ValueVar(t) }, + blocks = Map.empty, + conts = tmp_k.map(t => k1 -> Cont.ContVar(t)).toMap, + metaconts = tmp_ks.map(t => ks1 -> MetaCont(t)).toMap + ) + + // Update the continuation if this is not a tail call + if (!isTailCall) k match { + case Cont.ContVar(kid) if kid != k1 => + // simple variable rename + stmts.append(js.Assign(nameRef(k1), nameRef(kid))) + case _ => + stmts.append(js.Assign(nameRef(k1), toJS(substitutions.substitute(k)(using subst)))) } - // x = x3; - (vparams zip vtmps).foreach { - (param, tmp) => stmts.append(js.Assign(nameRef(param), nameRef(tmp))) + // Assign the substituted arguments + (vparams zip vargs).foreach { (param, arg) => + stmts.append(js.Assign(nameRef(param), toJS(substitutions.substitute(arg)(using subst)))) } - (bparams zip btmps).foreach { - (param, tmp) => stmts.append(js.Assign(nameRef(param), nameRef(tmp))) + (bparams zip bargs).foreach { (param, arg) => + stmts.append(js.Assign(nameRef(param), toJS(substitutions.substitute(arg)(using subst)))) } - // continue f; - val jump = js.Continue(Some(uniqueName(id))); + // Restore metacont if needed + if (needsKsTmp) stmts.append(js.Assign(nameRef(ks1), nameRef(tmp_ks.get))) + val jump = js.Continue(Some(uniqueName(label))) stmts.appendAll(k2(jump)) stmts.toList } case cps.Stmt.App(callee, vargs, bargs, ks, k) => - pure(js.Return(maybeThunking(js.Call(toJS(callee), vargs.map(toJS) ++ bargs.map(toJS) ++ List(toJS(ks), - requiringThunk { toJS(k) }))))) + pure(js.Return(js.Call(toJS(callee), vargs.map(toJS) ++ bargs.map(toJS) ++ List(toJS(ks), + requiringThunk { toJS(k) })))) case cps.Stmt.Invoke(callee, method, vargs, bargs, ks, k) => val args = vargs.map(toJS) ++ bargs.map(toJS) ++ List(toJS(ks), toJS(k)) @@ -350,13 +392,13 @@ object TransformerCps extends Transformer { } case cps.Stmt.Reset(prog, ks, k) => - pure(js.Return(Call(RESET, toJS(prog)(using nonrecursive(prog)), toJS(ks), toJS(k)))) + pure(js.Return(Call(RESET, requiringThunk { toJS(prog)(using nonrecursive(prog)) }, toJS(ks), toJS(k)))) case cps.Stmt.Shift(prompt, body, ks, k) => - pure(js.Return(Call(SHIFT, nameRef(prompt), noThunking { toJS(body)(using nonrecursive(body)) }, toJS(ks), toJS(k)))) + pure(js.Return(Call(SHIFT, nameRef(prompt), requiringThunk { toJS(body)(using nonrecursive(body)) }, toJS(ks), toJS(k)))) case cps.Stmt.Resume(r, b, ks2, k2) => - pure(js.Return(js.Call(RESUME, nameRef(r), toJS(b)(using nonrecursive(b)), toJS(ks2), toJS(k2)))) + pure(js.Return(js.Call(RESUME, nameRef(r), toJS(b)(using nonrecursive(b)), toJS(ks2), requiringThunk { toJS(k2) }))) case cps.Stmt.Hole() => pure(js.Return($effekt.call("hole"))) @@ -391,9 +433,19 @@ object TransformerCps extends Transformer { T.externs.get(id) match { case Some(cps.Extern.Def(id, params, Nil, async, ExternBody.StringExternBody(featureFlag, Template(strings, templateArgs)))) if !async => - bindingAll(params.zip(args.map(toJS))) { - js.RawExpr(strings, templateArgs.map(toJS)) + val subst = substitutions.Substitution( + values = (params zip args).toMap, + blocks = Map.empty, + conts = Map.empty, + metaconts = Map.empty + ) + + // Apply substitution to template arguments + val substitutedArgs = templateArgs.map { arg => + toJS(substitutions.substitute(arg)(using subst)) } + + js.RawExpr(strings, substitutedArgs) case _ => js.Call(nameRef(id), args.map(toJS)) } @@ -402,28 +454,19 @@ object TransformerCps extends Transformer { case _ => false } - private def bindingAll[R](bs: List[(Id, js.Expr)])(body: TransformerContext ?=> R)(using C: TransformerContext): R = - body(using C.copy(bindings = C.bindings ++ bs)) - - private def lookup(id: Id)(using C: TransformerContext): js.Expr = C.bindings.getOrElse(id, nameRef(id)) - - // Helpers for Direct-Style Transformation // --------------------------------------- /** - * Used to determine whether a call with continuations [[ ks ]] (after substitution) and [[ k ]] - * is the same as the original function definition (that is [[ ks1 ]] and [[ k1 ]]. + * Marks continuation `id` to be optimized to direct assignments to `param` instead of return statements. + * This is only valid in the same metacontinuation scope `ks`. */ - private def sameScope(ks: Id, k: Id, ks1: Id, k1: Id)(using C: TransformerContext): Boolean = - ks1 == C.metaconts.getOrElse(ks, ks) && k1 == k - - private def withDirectStyle(id: Id, param: Id, ks: Id)(using C: TransformerContext): TransformerContext = + private def markDirectStyle(id: Id, param: Id, ks: Id)(using C: TransformerContext): TransformerContext = C.copy(directStyle = Some(ContinuationInfo(id, param, ks))) - private def recursive(id: Id, used: RecursiveUsage, block: cps.Block)(using C: TransformerContext): TransformerContext = block match { + private def recursive(id: Id, label: Id, used: RecursiveUsage, block: cps.Block)(using C: TransformerContext): TransformerContext = block match { case cps.BlockLit(vparams, bparams, ks, k, body) => - C.copy(recursive = Some(RecursiveDefInfo(id, vparams, bparams, ks, k, used)), directStyle = None, metacont = Some(ks)) + C.copy(recursive = Some(RecursiveDefInfo(id, label, vparams, bparams, ks, k, used)), directStyle = None, metacont = Some(ks)) case _ => C } @@ -432,17 +475,20 @@ object TransformerCps extends Transformer { private def nonrecursive(block: cps.BlockLit)(using C: TransformerContext): TransformerContext = nonrecursive(block.ks) - // ks | let k1 x1 ks1 = { let k2 x2 ks2 = jump k v ks2 }; ... = jump k v ks - private def directstyle(ks: Id)(using C: TransformerContext): TransformerContext = + /** + * Ensures let-bound continuations can stay in direct style by aligning metacont scopes. + * This is used when the let-bound body jumps to an outer continuation. + * ks | let k1 x1 ks1 = { let k2 x2 ks2 = jump k v ks2 }; ... = jump k v ks + */ + private def maintainDirectStyle(ks: Id, body: Stmt)(using C: TransformerContext): Stmt = { val outer = C.metacont.getOrElse { sys error "Metacontinuation missing..." } - val outerSubstituted = C.metaconts.getOrElse(outer, outer) - val subst = C.metaconts.updated(ks, outerSubstituted) - C.copy(metacont = Some(ks), metaconts = subst) + substitutions.substitute(body)(using Substitution(metaconts = Map(ks -> MetaCont(outer)))) + } private object Recursive { - def unapply(b: cps.Block)(using C: TransformerContext): Option[(Id, List[Id], List[Id], Id, Id, RecursiveUsage)] = b match { + def unapply(b: cps.Block)(using C: TransformerContext): Option[(Id, Id, List[Id], List[Id], Id, Id, RecursiveUsage)] = b match { case cps.Block.BlockVar(id) => C.recursive.collect { - case RecursiveDefInfo(id2, vparams, bparams, ks, k, used) if id == id2 => (id, vparams, bparams, ks, k, used) + case RecursiveDefInfo(id2, label, vparams, bparams, ks, k, used) if id == id2 => (id, label, vparams, bparams, ks, k, used) } case _ => None } @@ -471,7 +517,7 @@ object TransformerCps extends Transformer { case Stmt.LetDef(id, binding, body) => notIn(binding) && canBeDirect(k, body) case Stmt.LetExpr(id, binding, body) => notIn(binding) && canBeDirect(k, body) case Stmt.LetCont(id, Cont.ContLam(result, ks2, body), body2) => - def willBeDirectItself = canBeDirect(id, body2) && canBeDirect(k, body)(using directstyle(ks2)) + def willBeDirectItself = canBeDirect(id, body2) && canBeDirect(k, maintainDirectStyle(ks2, body)) def notFreeinContinuation = notIn(body) && canBeDirect(k, body2) willBeDirectItself || notFreeinContinuation case Stmt.Region(id, ks, body) => notIn(body) diff --git a/effekt/shared/src/main/scala/effekt/machine/Transformer.scala b/effekt/shared/src/main/scala/effekt/machine/Transformer.scala index 354d13e53..092f927df 100644 --- a/effekt/shared/src/main/scala/effekt/machine/Transformer.scala +++ b/effekt/shared/src/main/scala/effekt/machine/Transformer.scala @@ -225,7 +225,8 @@ object Transformer { case core.Reset(core.BlockLit(Nil, cparams, Nil, List(prompt), body)) => noteParameters(List(prompt)) - val variable = Variable(freshName("returned"), transform(body.tpe)) + val answerType = stmt.tpe + val variable = Variable(freshName("returned"), transform(answerType)) val returnClause = Clause(List(variable), Return(List(variable))) Reset(Variable(transform(prompt.id), Type.Prompt()), returnClause, transform(body)) diff --git a/examples/benchmarks/effect_handlers_bench/parsing_dollars.effekt b/examples/benchmarks/effect_handlers_bench/parsing_dollars.effekt index 11a662a89..f41263be9 100644 --- a/examples/benchmarks/effect_handlers_bench/parsing_dollars.effekt +++ b/examples/benchmarks/effect_handlers_bench/parsing_dollars.effekt @@ -70,4 +70,3 @@ def run(n: Int) = sum { catch { feed(n) { parse(0) } } } def main() = benchmark(10){run} -