diff --git a/effekt/jvm/src/test/scala/effekt/StdlibTests.scala b/effekt/jvm/src/test/scala/effekt/StdlibTests.scala index e3066302d..d362b2384 100644 --- a/effekt/jvm/src/test/scala/effekt/StdlibTests.scala +++ b/effekt/jvm/src/test/scala/effekt/StdlibTests.scala @@ -41,11 +41,17 @@ class StdlibLLVMTests extends StdlibTests { override def debug = sys.env.get("EFFEKT_DEBUG").nonEmpty override def ignored: List[File] = List( - // Toplevel let-bindings (for ANSI-color-codes in output) not supported - examplesDir / "stdlib" / "test" / "test.effekt", + // segfaults + examplesDir / "stdlib" / "stream" / "fuse_newlines.effekt", + + // valgrind + examplesDir / "stdlib" / "list" / "modifyat.effekt", + examplesDir / "stdlib" / "list" / "updateat.effekt", + // Syscall param write(buf) points to uninitialised byte(s) examplesDir / "stdlib" / "io" / "filesystem" / "files.effekt", examplesDir / "stdlib" / "io" / "filesystem" / "async_file_io.effekt", + // Conditional jump or move depends on uninitialised value(s) examplesDir / "stdlib" / "io" / "filesystem" / "wordcount.effekt", ) diff --git a/effekt/jvm/src/test/scala/effekt/core/OptimizerTests.scala b/effekt/jvm/src/test/scala/effekt/core/OptimizerTests.scala index c6d933961..dc91897d9 100644 --- a/effekt/jvm/src/test/scala/effekt/core/OptimizerTests.scala +++ b/effekt/jvm/src/test/scala/effekt/core/OptimizerTests.scala @@ -1,5 +1,8 @@ package effekt package core + +import effekt.core.optimizer.* + import effekt.symbols class OptimizerTests extends CoreTests { @@ -30,15 +33,11 @@ class OptimizerTests extends CoreTests { Deadcode.remove(Set(mainSymbol), tree) } - def inlineOnce(input: String, expected: String)(using munit.Location) = - assertTransformsTo(input, expected) { tree => - val (result, count) = Inline.once(Set(mainSymbol), tree, 50) - result - } - - def inlineFull(input: String, expected: String)(using munit.Location) = + def normalize(input: String, expected: String)(using munit.Location) = assertTransformsTo(input, expected) { tree => - Inline.full(Set(mainSymbol), tree, 50) + val anfed = BindSubexpressions.transform(tree) + val normalized = Normalizer.normalize(Set(mainSymbol), anfed, 50) + Deadcode.remove(mainSymbol, normalized) } test("toplevel"){ @@ -155,11 +154,10 @@ class OptimizerTests extends CoreTests { |""".stripMargin val expected = - """ def foo = { () => return 42 } - | def main = { () => return 42 } + """ def main = { () => return 42 } |""".stripMargin - inlineOnce(input, expected) + normalize(input, expected) } test("inline with argument"){ @@ -169,11 +167,10 @@ class OptimizerTests extends CoreTests { |""".stripMargin val expected = - """ def foo = { (n: Int) => return n:Int } - | def main = { () => return 42 } + """ def main = { () => return 42 } |""".stripMargin - inlineOnce(input, expected) + normalize(input, expected) } test("inline higher order function"){ @@ -188,17 +185,10 @@ class OptimizerTests extends CoreTests { |""".stripMargin val expected = - """ def foo = { (n: Int) => return n:Int } - | def hof = { (){f : (Int) => Int} => - | (f : (Int) => Int @ {f})(1) - | } - | def main = { () => - | def local(n: Int) = return n:Int - | (local : (Int) => Int @ {})(1) - | } + """ def main = { () => return 1 } |""".stripMargin - inlineOnce(input, expected) + normalize(input, expected) } test("fully inline higher order function"){ @@ -216,7 +206,7 @@ class OptimizerTests extends CoreTests { """ def main = { () => return 1 } |""".stripMargin - inlineFull(input, expected) + normalize(input, expected) } } diff --git a/effekt/shared/src/main/scala/effekt/core/Inline.scala b/effekt/shared/src/main/scala/effekt/core/Inline.scala deleted file mode 100644 index 8fc01284e..000000000 --- a/effekt/shared/src/main/scala/effekt/core/Inline.scala +++ /dev/null @@ -1,284 +0,0 @@ -package effekt -package core - -import effekt.core.Block.BlockLit -import effekt.core.Pure.ValueVar -import effekt.core.normal.* -import effekt.util.messages.INTERNAL_ERROR - -import scala.collection.mutable -import kiama.util.Counter - -/** - * Inlines block definitions. - * - * 1. First computes usage (using [[Reachable.apply]]) - * 2. Top down traversal where we inline definitions - * - * Invariants: - * - the context `defs` always contains the _original_ definitions, not rewritten ones. - * Rewriting them has to be performed at the inline-site. - */ -object Inline { - - case class InlineContext( - // is mutable to update when introducing temporaries; - // they should also be visible after leaving a scope (so mutable.Map and not `var usage`). - usage: mutable.Map[Id, Usage], - defs: Map[Id, Definition], - maxInlineSize: Int, - inlineCount: Counter = Counter(0) - ) { - def ++(other: Map[Id, Definition]): InlineContext = InlineContext(usage, defs ++ other, maxInlineSize, inlineCount) - - def ++(other: List[Definition]): InlineContext = ++(other.map(d => d.id -> d).toMap) - - def ++=(fresh: Map[Id, Usage]): Unit = { usage ++= fresh } - } - - def once(entrypoints: Set[Id], m: ModuleDecl, maxInlineSize: Int): (ModuleDecl, Int) = { - val usage = Reachable(m) ++ entrypoints.map(id => id -> Usage.Many).toMap - val defs = m.definitions.map(d => d.id -> d).toMap - val context = InlineContext(mutable.Map.from(usage), defs, maxInlineSize) - - val (updatedDefs, _) = rewrite(m.definitions)(using context) - (m.copy(definitions = updatedDefs), context.inlineCount.value) - } - - def full(entrypoints: Set[Id], m: ModuleDecl, maxInlineSize: Int): ModuleDecl = - var lastCount = 1 - var tree = m - while (lastCount > 0) { - val (inlined, count) = Inline.once(entrypoints, tree, maxInlineSize) - // (3) drop unused definitions after inlining - tree = Deadcode.remove(entrypoints, inlined) - lastCount = count - } - tree - - def shouldInline(id: Id)(using ctx: InlineContext): Boolean = - ctx.usage.get(id) match { - case None => false - case Some(Usage.Once) => true - case Some(Usage.Recursive) => false // we don't inline recursive functions for the moment - case Some(Usage.Many) => - ctx.defs.get(id).exists { d => - def isSmall = d.size <= ctx.maxInlineSize - def isHigherOrder = d match { - case Definition.Def(id, BlockLit(_, _, _, bparams, _)) => - bparams.exists(p => p.tpe match { - case t: BlockType.Function => true - case t: BlockType.Interface => false - }) - case _ => false - } - isSmall || isHigherOrder - } - } - - def shouldKeep(id: Id)(using ctx: InlineContext): Boolean = - ctx.usage.get(id) match { - case None => false - case Some(Usage.Once) => false - case Some(Usage.Recursive) => true // we don't inline recursive functions for the moment - case Some(Usage.Many) => true - } - - def used(id: Id)(using ctx: InlineContext): Boolean = - ctx.usage.isDefinedAt(id) - - /** - * Rewrites the list of definition and returns: - * 1. the updated list - * 2. definitions - * a. original defnitions: in case we need to dealias elsewhere - * b. the updated definitions, where the rhs might have been dealiased already (see #733) - */ - def rewrite(definitions: List[Definition])(using ctx: InlineContext): (List[Definition], InlineContext) = - given allDefs: InlineContext = ctx ++ definitions - - val filtered = definitions.collect { - case Definition.Def(id, block) => Definition.Def(id, rewrite(block)) - // we drop aliases - case Definition.Let(id, tpe, binding) if !binding.isInstanceOf[ValueVar] => - Definition.Let(id, tpe, rewrite(binding)) - } - (filtered, allDefs ++ filtered) - - def blockDefFor(id: Id)(using ctx: InlineContext): Option[Block] = - ctx.defs.get(id) map { - case Definition.Def(id, block) => block - case Definition.Let(id, _, binding) => INTERNAL_ERROR("Should not happen") - } - - def dealias(b: Block.BlockVar)(using ctx: InlineContext): BlockVar = - ctx.defs.get(b.id) match { - case Some(Definition.Def(id, aliased : Block.BlockVar)) => dealias(aliased) - case _ => b - } - - def dealias(b: Pure.ValueVar)(using ctx: InlineContext): ValueVar = - ctx.defs.get(b.id) match { - case Some(Definition.Let(id, _, aliased : Pure.ValueVar)) => dealias(aliased) - case _ => b - } - - def rewrite(d: Definition)(using InlineContext): Definition = d match { - case Definition.Def(id, block) => Definition.Def(id, rewrite(block)) - case Definition.Let(id, tpe, binding) => Definition.Let(id, tpe, rewrite(binding)) - } - - def rewrite(s: Stmt)(using C: InlineContext): Stmt = s match { - case Stmt.Scope(definitions, body) => - val (defs, ctx) = rewrite(definitions) - scope(defs, rewrite(body)(using ctx)) - - case Stmt.App(b, targs, vargs, bargs) => - app(rewrite(b), targs, vargs.map(rewrite), bargs.map(rewrite)) - - case Stmt.Invoke(b, method, methodTpe, targs, vargs, bargs) => - invoke(rewrite(b), method, methodTpe, targs, vargs.map(rewrite), bargs.map(rewrite)) - - case Stmt.Reset(body) => - rewrite(body) match { - case BlockLit(tparams, cparams, vparams, List(prompt), body) if !used(prompt.id) => body - case b => Stmt.Reset(b) - } - - // congruences - case Stmt.Return(expr) => Return(rewrite(expr)) - case Stmt.Val(id, tpe, binding, body) => valDef(id, tpe, rewrite(binding), rewrite(body)) - case Stmt.If(cond, thn, els) => If(rewrite(cond), rewrite(thn), rewrite(els)) - case Stmt.Match(scrutinee, clauses, default) => - patternMatch(rewrite(scrutinee), clauses.map { case (id, value) => id -> rewrite(value) }, default.map(rewrite)) - case Stmt.Alloc(id, init, region, body) => Alloc(id, rewrite(init), region, rewrite(body)) - case Stmt.Shift(prompt, b @ BlockLit(tparams, cparams, vparams, List(k), body)) if tailResumptive(k.id, body) => - C.inlineCount.next() - rewrite(removeTailResumption(k.id, body)) - - case Stmt.Shift(prompt, body) => Shift(prompt, rewrite(body)) - - - case Stmt.Resume(k, body) => Resume(k, rewrite(body)) - case Stmt.Region(body) => Region(rewrite(body)) - case Stmt.Var(id, init, capture, body) => Stmt.Var(id, rewrite(init), capture, rewrite(body)) - case Stmt.Get(id, capt, tpe) => Stmt.Get(id, capt, tpe) - case Stmt.Put(id, capt, value) => Stmt.Put(id, capt, rewrite(value)) - case Stmt.Hole() => s - } - def rewrite(b: BlockLit)(using InlineContext): BlockLit = - b match { - case BlockLit(tparams, cparams, vparams, bparams, body) => - BlockLit(tparams, cparams, vparams, bparams, rewrite(body)) - } - - def rewrite(b: Block)(using C: InlineContext): Block = b match { - case Block.BlockVar(id, _, _) if shouldInline(id) => - blockDefFor(id) match { - case Some(value) => - C.inlineCount.next() - Renamer.rename(value) - case None => b - } - case b @ Block.BlockVar(id, _, _) => dealias(b) - - // congruences - case b @ Block.BlockLit(tparams, cparams, vparams, bparams, body) => rewrite(b) - case Block.Unbox(pure) => unbox(rewrite(pure)) - case Block.New(impl) => New(rewrite(impl)) - } - - def rewrite(s: Implementation)(using InlineContext): Implementation = - s match { - case Implementation(interface, operations) => Implementation(interface, operations.map { op => - op.copy(body = rewrite(op.body)) - }) - } - - def rewrite(p: Pure)(using InlineContext): Pure = p match { - case Pure.PureApp(b, targs, vargs) => pureApp(rewrite(b), targs, vargs.map(rewrite)) - case Pure.Make(data, tag, vargs) => make(data, tag, vargs.map(rewrite)) - // currently, we don't inline values, but we can dealias them - case x @ Pure.ValueVar(id, annotatedType) => dealias(x) - - // congruences - case Pure.Literal(value, annotatedType) => p - case Pure.Select(target, field, annotatedType) => select(rewrite(target), field, annotatedType) - case Pure.Box(b, annotatedCapture) => box(rewrite(b), annotatedCapture) - } - - def rewrite(e: Expr)(using InlineContext): Expr = e match { - case DirectApp(b, targs, vargs, bargs) => directApp(rewrite(b), targs, vargs.map(rewrite), bargs.map(rewrite)) - - // congruences - case Run(s) => run(rewrite(s)) - case pure: Pure => rewrite(pure) - } - - case class Binding[A](run: (A => Stmt) => Stmt) { - def flatMap[B](rest: A => Binding[B]): Binding[B] = { - Binding(k => run(a => rest(a).run(k))) - } - } - - def pure[A](a: A): Binding[A] = Binding(k => k(a)) - - // A simple syntactic check whether this stmt is tailresumptive in k - def tailResumptive(k: Id, stmt: Stmt): Boolean = - def freeInStmt(stmt: Stmt): Boolean = Variables.free(stmt).containsBlock(k) - def freeInExpr(expr: Expr): Boolean = Variables.free(expr).containsBlock(k) - def freeInDef(definition: Definition): Boolean = Variables.free(definition).containsBlock(k) - - stmt match { - case Stmt.Scope(definitions, body) => definitions.forall { d => !freeInDef(d) } && tailResumptive(k, body) - case Stmt.Return(expr) => false - case Stmt.Val(id, annotatedTpe, binding, body) => tailResumptive(k, body) && !freeInStmt(binding) - case Stmt.App(callee, targs, vargs, bargs) => false - case Stmt.Invoke(callee, method, methodTpe, targs, vargs, bargs) => false - case Stmt.If(cond, thn, els) => !freeInExpr(cond) && tailResumptive(k, thn) && tailResumptive(k, els) - // Interestingly, we introduce a join point making this more difficult to implement properly - case Stmt.Match(scrutinee, clauses, default) => !freeInExpr(scrutinee) && clauses.forall { - case (_, BlockLit(tparams, cparams, vparams, bparams, body)) => tailResumptive(k, body) - } && default.forall { stmt => tailResumptive(k, stmt) } - case Stmt.Region(BlockLit(tparams, cparams, vparams, bparams, body)) => tailResumptive(k, body) - case Stmt.Region(_) => ??? - case Stmt.Alloc(id, init, region, body) => tailResumptive(k, body) && !freeInExpr(init) - case Stmt.Var(id, init, capture, body) => tailResumptive(k, body) && !freeInExpr(init) - case Stmt.Get(id, annotatedCapt, annotatedTpe) => false - case Stmt.Put(id, annotatedCapt, value) => false - case Stmt.Reset(BlockLit(tparams, cparams, vparams, bparams, body)) => tailResumptive(k, body) // is this correct? - case Stmt.Shift(prompt, body) => false - case Stmt.Resume(k2, body) => k2.id == k // what if k is free in body? - case Stmt.Hole() => true - } - - def removeTailResumption(k: Id, stmt: Stmt): Stmt = stmt match { - case Stmt.Scope(definitions, body) => Stmt.Scope(definitions, removeTailResumption(k, body)) - case Stmt.Val(id, annotatedTpe, binding, body) => Stmt.Val(id, annotatedTpe, binding, removeTailResumption(k, body)) - case Stmt.If(cond, thn, els) => Stmt.If(cond, removeTailResumption(k, thn), removeTailResumption(k, els)) - case Stmt.Match(scrutinee, clauses, default) => Stmt.Match(scrutinee, clauses.map { - case (tag, block) => tag -> removeTailResumption(k, block) - }, default.map(removeTailResumption(k, _))) - case Stmt.Region(body : BlockLit) => - Stmt.Region(removeTailResumption(k, body)) - case Stmt.Region(_) => ??? - case Stmt.Alloc(id, init, region, body) => Stmt.Alloc(id, init, region, removeTailResumption(k, body)) - case Stmt.Var(id, init, capture, body) => Stmt.Var(id, init, capture, removeTailResumption(k, body)) - case Stmt.Reset(body) => Stmt.Reset(removeTailResumption(k, body)) - case Stmt.Resume(k2, body) if k2.id == k => body - - case Stmt.Resume(k, body) => stmt - case Stmt.Shift(prompt, body) => stmt - case Stmt.Hole() => stmt - case Stmt.Return(expr) => stmt - case Stmt.App(callee, targs, vargs, bargs) => stmt - case Stmt.Invoke(callee, method, methodTpe, targs, vargs, bargs) => stmt - case Stmt.Get(id, annotatedCapt, annotatedTpe) => stmt - case Stmt.Put(id, annotatedCapt, value) => stmt - } - - def removeTailResumption(k: Id, block: BlockLit): BlockLit = block match { - case BlockLit(tparams, cparams, vparams, bparams, body) => - BlockLit(tparams, cparams, vparams, bparams, removeTailResumption(k, body)) - } -} diff --git a/effekt/shared/src/main/scala/effekt/core/LambdaLifting.scala b/effekt/shared/src/main/scala/effekt/core/LambdaLifting.scala index 28fcee820..27afab523 100644 --- a/effekt/shared/src/main/scala/effekt/core/LambdaLifting.scala +++ b/effekt/shared/src/main/scala/effekt/core/LambdaLifting.scala @@ -7,8 +7,6 @@ import scala.collection.mutable import effekt.core.Variables import effekt.core.Variables.{ all, bound, free } -import effekt.core.normal.scope - class LambdaLifting(m: core.ModuleDecl)(using Context) extends core.Tree.Rewrite { val locals = Locals(m) @@ -61,7 +59,7 @@ class LambdaLifting(m: core.ModuleDecl)(using Context) extends core.Tree.Rewrite override def stmt = { case core.Scope(defs, body) => - scope(defs.flatMap { + MaybeScope(defs.flatMap { // we lift named local definitions to the toplevel case Definition.Def(id, BlockLit(tparams, cparams, vparams, bparams, body)) => lifted.append(Definition.Def(id, diff --git a/effekt/shared/src/main/scala/effekt/core/Optimizer.scala b/effekt/shared/src/main/scala/effekt/core/Optimizer.scala deleted file mode 100644 index a1dd8f80b..000000000 --- a/effekt/shared/src/main/scala/effekt/core/Optimizer.scala +++ /dev/null @@ -1,36 +0,0 @@ -package effekt -package core - -import effekt.PhaseResult.CoreTransformed -import effekt.context.Context - -import kiama.util.Source - -object Optimizer extends Phase[CoreTransformed, CoreTransformed] { - - val phaseName: String = "core-optimizer" - - def run(input: CoreTransformed)(using Context): Option[CoreTransformed] = - input match { - case CoreTransformed(source, tree, mod, core) => - val term = Context.checkMain(mod) - val optimized = optimize(source, term, core) - Some(CoreTransformed(source, tree, mod, optimized)) - } - - def optimize(source: Source, mainSymbol: symbols.Symbol, core: ModuleDecl)(using Context): ModuleDecl = - // (1) first thing we do is simply remove unused definitions (this speeds up all following analysis and rewrites) - val tree = Context.timed("deadcode-elimination", source.name) { Deadcode.remove(mainSymbol, core) } - - if !Context.config.optimize() then return tree; - - // (2) lift static arguments (worker/wrapper) - val lifted = Context.timed("static-argument-transformation", source.name) { - StaticArguments.transform(mainSymbol, tree) - } - - // (3) inline unique block definitions - Context.timed("inliner", source.name) { - Inline.full(Set(mainSymbol), lifted, Context.config.maxInlineSize().toInt) - } -} diff --git a/effekt/shared/src/main/scala/effekt/core/PolymorphismBoxing.scala b/effekt/shared/src/main/scala/effekt/core/PolymorphismBoxing.scala index c2057425c..8030e7dff 100644 --- a/effekt/shared/src/main/scala/effekt/core/PolymorphismBoxing.scala +++ b/effekt/shared/src/main/scala/effekt/core/PolymorphismBoxing.scala @@ -369,7 +369,7 @@ object PolymorphismBoxing extends Phase[CoreTransformed, CoreTransformed] { (BlockType.Function(tparams, cparams, vparams, bparams, boxedResult), boxedResult, result) case _ => Context.abort("Body of a region cannot have interface type") } - val doBoxResult = coercer[Block](tBody.tpe, expectedBodyTpe) + val doBoxResult = coercer[BlockLit](tBody.tpe, expectedBodyTpe) // Create coercer for eagerly unboxing the result again val doUnboxResult = coercer(actualReturnType, expectedReturnType) val resName = TmpValue("boxedResult") diff --git a/effekt/shared/src/main/scala/effekt/core/PrettyPrinter.scala b/effekt/shared/src/main/scala/effekt/core/PrettyPrinter.scala index 96d11ed33..b989fb753 100644 --- a/effekt/shared/src/main/scala/effekt/core/PrettyPrinter.scala +++ b/effekt/shared/src/main/scala/effekt/core/PrettyPrinter.scala @@ -105,7 +105,7 @@ object PrettyPrinter extends ParenPrettyPrinter { case Select(b, field, tpe) => toDoc(b) <> "." <> toDoc(field) case Box(b, capt) => parens("box" <+> toDoc(b)) - case Run(s) => "run" <+> braces(toDoc(s)) + case Run(s) => "run" <+> block(toDoc(s)) } def argsToDoc(targs: List[core.ValueType], vargs: List[core.Pure], bargs: List[core.Block]): Doc = @@ -154,7 +154,7 @@ object PrettyPrinter extends ParenPrettyPrinter { def toDoc(d: Definition): Doc = d match { case Definition.Def(id, BlockLit(tps, cps, vps, bps, body)) => - "def" <+> toDoc(id) <> paramsToDoc(tps, vps, bps) <+> "=" <> nested(toDoc(body)) + "def" <+> toDoc(id) <> paramsToDoc(tps, vps, bps) <+> "=" <+> block(toDoc(body)) case Definition.Def(id, block) => "def" <+> toDoc(id) <+> "=" <+> toDoc(block) case Definition.Let(id, _, binding) => @@ -166,14 +166,14 @@ object PrettyPrinter extends ParenPrettyPrinter { toDoc(definitions) <> emptyline <> toDoc(rest) case Return(e) => - toDoc(e) + "return" <+> toDoc(e) case Val(Wildcard(), _, binding, body) => toDoc(binding) <> ";" <> line <> toDoc(body) case Val(id, tpe, binding, body) => - "val" <+> toDoc(id) <> ":" <+> toDoc(tpe) <+> "=" <+> toDoc(binding) <> ";" <> line <> + "val" <+> toDoc(id) <> ":" <+> toDoc(tpe) <+> "=" <+> block(toDoc(binding)) <> ";" <> line <> toDoc(body) case App(b, targs, vargs, bargs) => diff --git a/effekt/shared/src/main/scala/effekt/core/Renamer.scala b/effekt/shared/src/main/scala/effekt/core/Renamer.scala index f43b9c07b..6d90f3548 100644 --- a/effekt/shared/src/main/scala/effekt/core/Renamer.scala +++ b/effekt/shared/src/main/scala/effekt/core/Renamer.scala @@ -18,6 +18,9 @@ class Renamer(names: Names = Names(Map.empty), prefix: String = "") extends core // list of scopes that map bound symbols to their renamed variants. private var scopes: List[Map[Id, Id]] = List.empty + // Here we track ALL renamings + var renamed: Map[Id, Id] = Map.empty + private var suffix: Int = 0 def freshIdFor(id: Id): Id = @@ -28,7 +31,9 @@ class Renamer(names: Names = Names(Map.empty), prefix: String = "") extends core def withBindings[R](ids: List[Id])(f: => R): R = val before = scopes try { - scopes = ids.map { x => x -> freshIdFor(x) }.toMap :: scopes + val newScope = ids.map { x => x -> freshIdFor(x) }.toMap + scopes = newScope :: scopes + renamed = renamed ++ newScope f } finally { scopes = before } @@ -108,4 +113,8 @@ class Renamer(names: Names = Names(Map.empty), prefix: String = "") extends core object Renamer { def rename(b: Block): Block = Renamer().rewrite(b) + def rename(b: BlockLit): (BlockLit, Map[Id, Id]) = + val renamer = Renamer() + val res = renamer.rewrite(b) + (res, renamer.renamed) } diff --git a/effekt/shared/src/main/scala/effekt/core/Tree.scala b/effekt/shared/src/main/scala/effekt/core/Tree.scala index 24d4cdfe9..fef09cfee 100644 --- a/effekt/shared/src/main/scala/effekt/core/Tree.scala +++ b/effekt/shared/src/main/scala/effekt/core/Tree.scala @@ -6,6 +6,8 @@ import effekt.util.Structural import effekt.util.messages.INTERNAL_ERROR import effekt.util.messages.ErrorReporter +import scala.annotation.tailrec + /** * Tree structure of programs in our internal core representation. * @@ -304,7 +306,7 @@ enum Stmt extends Tree { case Match(scrutinee: Pure, clauses: List[(Id, BlockLit)], default: Option[Stmt]) // (Type-monomorphic?) Regions - case Region(body: Block) + case Region(body: BlockLit) case Alloc(id: Id, init: Pure, region: Id, body: Stmt) // creates a fresh state handler to model local (backtrackable) state. @@ -336,147 +338,17 @@ enum Stmt extends Tree { export Stmt.* /** - * Smart constructors to establish some normal form + * A smart constructor for `stmt.Scope` that only introduces a scope if there are bindings */ -object normal { - - def valDef(id: Id, tpe: ValueType, binding: Stmt, body: Stmt): Stmt = - (binding, body) match { - - // [[ val x = STMT; return x ]] == STMT - case (_, Stmt.Return(Pure.ValueVar(other, _))) if other == id => - binding - - // [[ val x = return EXPR; STMT ]] = [[ let x = EXPR; STMT ]] - // - // This opt is too good for JS: it blows the stack on - // recursive functions that are used to encode while... - // - // The solution to this problem is implemented in core.MakeStackSafe: - // all recursive functions that could blow the stack are trivially wrapped - // again, after optimizing. - case (Stmt.Return(expr), body) => - scope(List(Definition.Let(id, tpe, expr)), body) - - // here we are flattening scopes; be aware that this extends - // life-times of bindings! - // - // { val x = { def...; BODY }; REST } = { def ...; val x = BODY } - case (Stmt.Scope(definitions, binding), body) => - scope(definitions, valDef(id, tpe, binding, body)) - - case _ => Stmt.Val(id, tpe, binding, body) - } - - // { def f=...; { def g=...; BODY } } = { def f=...; def g; BODY } - def scope(definitions: List[Definition], body: Stmt): Stmt = body match { - case Stmt.Scope(others, body) => scope(definitions ++ others, body) - case _ => if (definitions.isEmpty) body else Stmt.Scope(definitions, body) - } - - // TODO perform record selection here, if known - def select(target: Pure, field: Id, annotatedType: ValueType): Pure = - Select(target, field, annotatedType) - - def app(callee: Block, targs: List[ValueType], vargs: List[Pure], bargs: List[Block]): Stmt = - callee match { - case b : Block.BlockLit => reduce(b, targs, vargs, bargs) - case other => Stmt.App(callee, targs, vargs, bargs) - } - - def invoke(callee: Block, method: Id, methodTpe: BlockType, targs: List[ValueType], vargs: List[Pure], bargs: List[Block]): Stmt = - callee match { - case Block.New(impl) => - val Operation(name, tps, cps, vps, bps, body) = - impl.operations.find(op => op.name == method).getOrElse { - INTERNAL_ERROR("Should not happen") - } - reduce(BlockLit(tps, cps, vps, bps, body), targs, vargs, bargs) - case other => Invoke(callee, method, methodTpe, targs, vargs, bargs) - } - - def reset(body: BlockLit): Stmt = body match { - // case BlockLit(tparams, cparams, vparams, List(prompt), - // Stmt.Shift(prompt2, body) if prompt.id == prompt2.id => ??? - case other => Stmt.Reset(body) - } - - def make(tpe: ValueType.Data, tag: Id, vargs: List[Pure]): Pure = - Pure.Make(tpe, tag, vargs) - - def pureApp(callee: Block, targs: List[ValueType], vargs: List[Pure]): Pure = - callee match { - case b : Block.BlockLit => - INTERNAL_ERROR( - """|This should not happen! - |User defined functions always have to be called with App, not PureApp. - |If this error does occur, this means this changed. - |Check `core.Transformer.makeFunctionCall` for details. - |""".stripMargin) - case other => - Pure.PureApp(callee, targs, vargs) - } - - // "match" is a keyword in Scala - def patternMatch(scrutinee: Pure, clauses: List[(Id, BlockLit)], default: Option[Stmt]): Stmt = - scrutinee match { - case Pure.Make(dataType, ctorTag, vargs) => - clauses.collectFirst { case (tag, lit) if tag == ctorTag => lit } - .map(body => app(body, Nil, vargs, Nil)) - .orElse { default }.getOrElse { sys error "Pattern not exhaustive. This should not happen" } - case other => (clauses, default) match { - // Unit-like types: there is only one case and it is just a tag. - // sc match { case Unit() => body } ==> body - case ((id, lit) :: Nil, None) if lit.vparams.isEmpty => lit.body - case _ => Match(scrutinee, clauses, default) - } - } - - - def directApp(callee: Block, targs: List[ValueType], vargs: List[Pure], bargs: List[Block]): Expr = - callee match { - case b : Block.BlockLit => run(reduce(b, targs, vargs, Nil)) - case other => DirectApp(callee, targs, vargs, bargs) - } - - def reduce(b: BlockLit, targs: List[core.ValueType], vargs: List[Pure], bargs: List[Block]): Stmt = { - - // Only bind if not already a variable!!! - var ids: Set[Id] = Set.empty - var bindings: List[Definition.Def] = Nil - var bvars: List[Block.BlockVar] = Nil - - // (1) first bind - bargs foreach { - case x: Block.BlockVar => bvars = bvars :+ x - // introduce a binding - case block => - val id = symbols.TmpBlock("blockBinding") - bindings = bindings :+ Definition.Def(id, block) - bvars = bvars :+ Block.BlockVar(id, block.tpe, block.capt) - ids += id - } - - // (2) substitute - val body = substitutions.substitute(b, targs, vargs, bvars) - - scope(bindings, body) - } - - def run(s: Stmt): Expr = s match { - case Stmt.Return(expr) => expr - case _ => Run(s) - } - - def box(b: Block, capt: Captures): Pure = b match { - case Block.Unbox(pure) => pure - case b => Box(b, capt) - } - - def unbox(p: Pure): Block = p match { - case Pure.Box(b, _) => b - case p => Unbox(p) - } +def MaybeScope(definitions: List[Definition], body: Stmt): Stmt = body match { + // flatten scopes + // { def f = ...; { def g = ...; BODY } } = { def f = ...; def g; BODY } + case Stmt.Scope(others, body) => MaybeScope(definitions ++ others, body) + + // Drop scope if empty + // { ; BODY } = BODY + case _ if definitions.isEmpty => body + case _ => Stmt.Scope(definitions, body) } /** @@ -603,6 +475,55 @@ object Tree { case (p, b) => (p, rewrite(b)) } } + + class RewriteWithContext[Ctx] extends Structural { + def id(using Ctx): PartialFunction[Id, Id] = PartialFunction.empty + def pure(using Ctx): PartialFunction[Pure, Pure] = PartialFunction.empty + def expr(using Ctx): PartialFunction[Expr, Expr] = PartialFunction.empty + def stmt(using Ctx): PartialFunction[Stmt, Stmt] = PartialFunction.empty + def defn(using Ctx): PartialFunction[Definition, Definition] = PartialFunction.empty + def block(using Ctx): PartialFunction[Block, Block] = PartialFunction.empty + def handler(using Ctx): PartialFunction[Implementation, Implementation] = PartialFunction.empty + def param(using Ctx): PartialFunction[Param, Param] = PartialFunction.empty + + def rewrite(x: Id)(using Ctx): Id = if id.isDefinedAt(x) then id(x) else x + def rewrite(p: Pure)(using Ctx): Pure = rewriteStructurally(p, pure) + def rewrite(e: Expr)(using Ctx): Expr = rewriteStructurally(e, expr) + def rewrite(s: Stmt)(using Ctx): Stmt = rewriteStructurally(s, stmt) + def rewrite(b: Block)(using Ctx): Block = rewriteStructurally(b, block) + def rewrite(d: Definition)(using Ctx): Definition = rewriteStructurally(d, defn) + def rewrite(e: Implementation)(using Ctx): Implementation = rewriteStructurally(e, handler) + def rewrite(o: Operation)(using Ctx): Operation = rewriteStructurally(o) + def rewrite(p: Param)(using Ctx): Param = rewriteStructurally(p, param) + def rewrite(p: Param.ValueParam)(using Ctx): Param.ValueParam = rewrite(p: Param).asInstanceOf[Param.ValueParam] + def rewrite(p: Param.BlockParam)(using Ctx): Param.BlockParam = rewrite(p: Param).asInstanceOf[Param.BlockParam] + def rewrite(b: ExternBody)(using Ctx): ExternBody= rewrite(b) + + def rewrite(b: BlockLit)(using Ctx): BlockLit = if block.isDefinedAt(b) then block(b).asInstanceOf else b match { + case BlockLit(tparams, cparams, vparams, bparams, body) => + BlockLit(tparams map rewrite, cparams map rewrite, vparams map rewrite, bparams map rewrite, rewrite(body)) + } + def rewrite(b: BlockVar)(using Ctx): BlockVar = if block.isDefinedAt(b) then block(b).asInstanceOf else b match { + case BlockVar(id, annotatedTpe, annotatedCapt) => BlockVar(rewrite(id), rewrite(annotatedTpe), rewrite(annotatedCapt)) + } + + def rewrite(t: ValueType)(using Ctx): ValueType = rewriteStructurally(t) + def rewrite(t: ValueType.Data)(using Ctx): ValueType.Data = rewriteStructurally(t) + + def rewrite(t: BlockType)(using Ctx): BlockType = rewriteStructurally(t) + def rewrite(t: BlockType.Interface)(using Ctx): BlockType.Interface = rewriteStructurally(t) + def rewrite(capt: Captures)(using Ctx): Captures = capt.map(rewrite) + + def rewrite(m: ModuleDecl)(using Ctx): ModuleDecl = + m match { + case ModuleDecl(path, includes, declarations, externs, definitions, exports) => + ModuleDecl(path, includes, declarations, externs, definitions.map(rewrite), exports) + } + + def rewrite(matchClause: (Id, BlockLit))(using Ctx): (Id, BlockLit) = matchClause match { + case (p, b) => (p, rewrite(b)) + } + } } enum Variable { @@ -838,6 +759,15 @@ object substitutions { case h : Hole => h } + def substitute(b: BlockLit)(using subst: Substitution): BlockLit = b match { + case BlockLit(tparams, cparams, vparams, bparams, body) => + val shadowedTypelevel = subst shadowTypes tparams shadowCaptures cparams + BlockLit(tparams, cparams, + vparams.map(p => substitute(p)(using shadowedTypelevel)), + bparams.map(p => substitute(p)(using shadowedTypelevel)), + substitute(body)(using shadowedTypelevel shadowParams (vparams ++ bparams))) + } + def substituteAsVar(id: Id)(using subst: Substitution): Id = subst.blocks.get(id) map { case BlockVar(x, _, _) => x @@ -848,19 +778,9 @@ object substitutions { block match { case BlockVar(id, tpe, capt) if subst.blocks.isDefinedAt(id) => subst.blocks(id) case BlockVar(id, tpe, capt) => BlockVar(id, substitute(tpe), substitute(capt)) - - case BlockLit(tparams, cparams, vparams, bparams, body) => - val shadowedTypelevel = subst shadowTypes tparams shadowCaptures cparams - BlockLit(tparams, cparams, - vparams.map(p => substitute(p)(using shadowedTypelevel)), - bparams.map(p => substitute(p)(using shadowedTypelevel)), - substitute(body)(using shadowedTypelevel shadowParams (vparams ++ bparams))) - - case Unbox(pure) => - Unbox(substitute(pure)) - - case New(impl) => - New(substitute(impl)) + case b: BlockLit => substitute(b) + case Unbox(pure) => Unbox(substitute(pure)) + case New(impl) => New(substitute(impl)) } def substitute(pure: Pure)(using subst: Substitution): Pure = diff --git a/effekt/shared/src/main/scala/effekt/core/optimizer/BindSubexpressions.scala b/effekt/shared/src/main/scala/effekt/core/optimizer/BindSubexpressions.scala new file mode 100644 index 000000000..52ebfd34e --- /dev/null +++ b/effekt/shared/src/main/scala/effekt/core/optimizer/BindSubexpressions.scala @@ -0,0 +1,199 @@ +package effekt +package core +package optimizer + +// Establishes a normal form in which every subexpression +// is explicitly named and aliasing (val x = y) is removed. +// +// let x = Cons(1, Cons(2, Cons(3, Nil()))) +// +// -> +// let x1 = Nil() +// let x2 = Cons(3, x1) +// let x3 = Cons(2, x2) +// let x = Cons(1, x3) +object BindSubexpressions { + + type Env = Map[Id, Id] + def alias(from: Id, to: Id, env: Env): Env = + env + (from -> env.getOrElse(to, to)) + + def transform(m: ModuleDecl): ModuleDecl = m match { + case ModuleDecl(path, includes, declarations, externs, definitions, exports) => + val (newDefs, env) = transformDefs(definitions)(using Map.empty) + ModuleDecl(path, includes, declarations, externs, newDefs, exports) + } + + def transformDefs(definitions: List[Definition])(using env: Env): (List[Definition], Env) = + var definitionsSoFar = List.empty[Definition] + var envSoFar = env + + definitions.foreach { + case Definition.Def(id, block) => + transform(block)(using envSoFar) match { + case Bind(Block.BlockVar(x, _, _), defs) => + definitionsSoFar ++= defs + envSoFar = alias(id, x, envSoFar) + + case Bind(other, defs) => + definitionsSoFar = definitionsSoFar ++ (defs :+ Definition.Def(id, other)) + } + case Definition.Let(id, tpe, expr) => + transform(expr)(using envSoFar) match { + case Bind(Pure.ValueVar(x, _), defs) => + definitionsSoFar ++= defs + envSoFar = alias(id, x, envSoFar) + case Bind(other, defs) => + definitionsSoFar = definitionsSoFar ++ (defs :+ Definition.Let(id, transform(tpe)(using envSoFar), other)) + } + } + (definitionsSoFar, envSoFar) + + def transform(s: Stmt)(using env: Env): Stmt = s match { + case Stmt.Scope(definitions, body) => + val (newDefs, newEnv) = transformDefs(definitions) + MaybeScope(newDefs, transform(body)(using newEnv)) + + case Stmt.App(callee, targs, vargs, bargs) => delimit { + for { + c <- transform(callee) + vs <- transformExprs(vargs) + bs <- transformBlocks(bargs) + } yield Stmt.App(c, targs.map(transform), vs, bs) + } + + case Stmt.Invoke(callee, method, methodTpe, targs, vargs, bargs) => delimit { + for { + c <- transform(callee) + vs <- transformExprs(vargs) + bs <- transformBlocks(bargs) + } yield Stmt.Invoke(c, method, transform(methodTpe), targs.map(transform), vs, bs) + } + + case Stmt.Return(expr) => transform(expr).run { res => Stmt.Return(res) } + case Stmt.Alloc(id, init, region, body) => transform(init).run { v => Stmt.Alloc(id, v, transform(region), transform(body)) } + case Stmt.Var(id, init, capture, body) => transform(init).run { v => Stmt.Var(id, v, transform(capture), transform(body)) } + case Stmt.Get(id, capt, tpe) => Stmt.Get(id, transform(capt), transform(tpe)) + case Stmt.Put(id, capt, value) => transform(value).run { v => Stmt.Put(id, transform(capt), v) } + + case Stmt.If(cond, thn, els) => transform(cond).run { c => + Stmt.If(c, transform(thn), transform(els)) + } + case Stmt.Match(scrutinee, clauses, default) => transform(scrutinee).run { sc => + Stmt.Match(sc, clauses.map { case (tag, rhs) => (tag, transform(rhs)) }, default.map(transform)) + } + + // Congruences + case Stmt.Region(body) => Stmt.Region(transform(body)) + case Stmt.Val(id, tpe, binding, body) => Stmt.Val(id, transform(tpe), transform(binding), transform(body)) + case Stmt.Reset(body) => Stmt.Reset(transform(body)) + case Stmt.Shift(prompt, body) => Stmt.Shift(transform(prompt), transform(body)) + case Stmt.Resume(k, body) => Stmt.Resume(transform(k), transform(body)) + case Stmt.Hole() => Stmt.Hole() + } + + def transform(b: Block)(using Env): Bind[Block] = b match { + case b: Block.BlockVar => pure(transform(b)) + case b: Block.BlockLit => pure(transform(b)) + case Block.New(impl) => pure(Block.New(transform(impl))) + case Block.Unbox(pure) => transform(pure) { v => bind(Block.Unbox(v)) } + } + + def transform(b: BlockLit)(using Env): BlockLit = b match { + case BlockLit(tparams, cparams, vparams, bparams, body) => + BlockLit(tparams, cparams, vparams.map(transform), bparams.map(transform), transform(body)) + } + + def transform(b: BlockVar)(using Env): BlockVar = b match { + case BlockVar(id, annotatedTpe, annotatedCapt) => + BlockVar(transform(id), transform(annotatedTpe), transform(annotatedCapt)) + } + + def transform(impl: Implementation)(using Env): Implementation = impl match { + case Implementation(interface, operations) => Implementation(transform(interface).asInstanceOf, operations.map { + case Operation(name, tparams, cparams, vparams, bparams, body) => + Operation(name, tparams, cparams, vparams.map(transform), bparams.map(transform), transform(body)) + }) + } + + def transform(p: ValueParam)(using Env): ValueParam = p match { + case ValueParam(id, tpe) => ValueParam(id, transform(tpe)) + } + def transform(p: BlockParam)(using Env): BlockParam = p match { + case BlockParam(id, tpe, capt) => BlockParam(id, transform(tpe), transform(capt)) + } + + def transform(id: Id)(using env: Env): Id = env.getOrElse(id, id) + + def transform(e: Expr)(using Env): Bind[ValueVar | Literal] = e match { + case Pure.ValueVar(id, tpe) => pure(ValueVar(transform(id), transform(tpe))) + case Pure.Literal(value, tpe) => pure(Pure.Literal(value, transform(tpe))) + + case Pure.Make(data, tag, vargs) => transformExprs(vargs) { vs => + bind(Pure.Make(data, tag, vs)) + } + case DirectApp(block, targs, vargs, bargs) => for { + b <- transform(block); + vs <- transformExprs(vargs); + bs <- transformBlocks(bargs); + res <- bind(DirectApp(b, targs.map(transform), vs, bs)) + } yield res + case Pure.PureApp(block, targs, vargs) => for { + b <- transform(block); + vs <- transformExprs(vargs); + res <- bind(Pure.PureApp(b, targs.map(transform), vs)) + } yield res + case Pure.Select(target, field, tpe) => transform(target) { v => bind(Pure.Select(v, field, transform(tpe))) } + case Pure.Box(block, capt) => transform(block) { b => bind(Pure.Box(b, transform(capt))) } + + case Run(s) => bind(Run(transform(s))) + } + + def transformExprs(es: List[Expr])(using Env): Bind[List[ValueVar | Literal]] = traverse(es)(transform) + def transformBlocks(es: List[Block])(using Env): Bind[List[Block]] = traverse(es)(transform) + + // Types + // ----- + // Types mention captures and captures might require renaming after dealiasing + def transform(tpe: ValueType)(using Env): ValueType = tpe match { + case ValueType.Var(name) => ValueType.Var(transform(name)) + case ValueType.Data(name, targs) => ValueType.Data(name, targs.map(transform)) + case ValueType.Boxed(tpe, capt) => ValueType.Boxed(transform(tpe), transform(capt)) + } + def transform(tpe: BlockType)(using Env): BlockType = tpe match { + case BlockType.Function(tparams, cparams, vparams, bparams, result) => + BlockType.Function(tparams, cparams, vparams.map(transform), bparams.map(transform), transform(result)) + case BlockType.Interface(name, targs) => + BlockType.Interface(name, targs.map(transform)) + } + def transform(captures: Captures)(using Env): Captures = captures.map(transform) + + + // Binding Monad + // ------------- + case class Bind[+A](value: A, definitions: List[Definition]) { + def run(f: A => Stmt): Stmt = MaybeScope(definitions, f(value)) + def map[B](f: A => B): Bind[B] = Bind(f(value), definitions) + def flatMap[B](f: A => Bind[B]): Bind[B] = + val Bind(result, other) = f(value) + Bind(result, definitions ++ other) + def apply[B](f: A => Bind[B]): Bind[B] = flatMap(f) + } + def pure[A](value: A): Bind[A] = Bind(value, Nil) + def bind[A](expr: Expr): Bind[ValueVar] = + val id = Id("tmp") + Bind(ValueVar(id, expr.tpe), List(Definition.Let(id, expr.tpe, expr))) + + def bind[A](block: Block): Bind[BlockVar] = + val id = Id("tmp") + Bind(BlockVar(id, block.tpe, block.capt), List(Definition.Def(id, block))) + + def delimit(b: Bind[Stmt]): Stmt = b.run(a => a) + + def traverse[S, T](l: List[S])(f: S => Bind[T]): Bind[List[T]] = + l match { + case Nil => pure(Nil) + case head :: tail => for { x <- f(head); xs <- traverse(tail)(f) } yield x :: xs + } + +} diff --git a/effekt/shared/src/main/scala/effekt/core/Deadcode.scala b/effekt/shared/src/main/scala/effekt/core/optimizer/Deadcode.scala similarity index 68% rename from effekt/shared/src/main/scala/effekt/core/Deadcode.scala rename to effekt/shared/src/main/scala/effekt/core/optimizer/Deadcode.scala index 8765d570a..649e5c28b 100644 --- a/effekt/shared/src/main/scala/effekt/core/Deadcode.scala +++ b/effekt/shared/src/main/scala/effekt/core/optimizer/Deadcode.scala @@ -1,22 +1,23 @@ package effekt package core +package optimizer -import effekt.core.Block.BlockLit -import effekt.core.Pure.ValueVar -import effekt.core.normal.* - -class Deadcode(entrypoints: Set[Id], definitions: Map[Id, Definition]) extends core.Tree.Rewrite { - - val reachable = Reachable(entrypoints, definitions) +class Deadcode(reachable: Map[Id, Usage]) extends core.Tree.Rewrite { override def stmt = { // Remove local unused definitions case Scope(defs, stmt) => - scope(defs.collect { + MaybeScope(defs.collect { case d: Definition.Def if reachable.isDefinedAt(d.id) => rewrite(d) // we only keep non-pure OR reachable let bindings case d: Definition.Let if d.capt.nonEmpty || reachable.isDefinedAt(d.id) => rewrite(d) }, rewrite(stmt)) + + case Reset(body) => + rewrite(body) match { + case BlockLit(tparams, cparams, vparams, List(prompt), body) if !reachable.isDefinedAt(prompt.id) => body + case b => Stmt.Reset(b) + } } override def rewrite(m: ModuleDecl): ModuleDecl = m.copy( @@ -31,7 +32,9 @@ class Deadcode(entrypoints: Set[Id], definitions: Map[Id, Definition]) extends c object Deadcode { def remove(entrypoints: Set[Id], m: ModuleDecl): ModuleDecl = - Deadcode(entrypoints, m.definitions.map(d => d.id -> d).toMap).rewrite(m) + val reachable = Reachable(entrypoints, m) + Deadcode(reachable).rewrite(m) + def remove(entrypoint: Id, m: ModuleDecl): ModuleDecl = remove(Set(entrypoint), m) } diff --git a/effekt/shared/src/main/scala/effekt/core/optimizer/DropBindings.scala b/effekt/shared/src/main/scala/effekt/core/optimizer/DropBindings.scala new file mode 100644 index 000000000..740586c47 --- /dev/null +++ b/effekt/shared/src/main/scala/effekt/core/optimizer/DropBindings.scala @@ -0,0 +1,66 @@ +package effekt +package core +package optimizer + +import context.Context + +/** + * This phase drops (inlines) unique value bindings. + * + * let x = 42 + * let y = x + 1 + * let z = y * 2 + * z + * + * --> + * + * (42 + 1) * 2 + * + * This improves the performance for JS on some benchmarks (mostly match_options, sum_range, and parsing_dollars). + */ +object DropBindings extends Phase[CoreTransformed, CoreTransformed] { + + val phaseName: String = "drop-bindings" + + def run(input: CoreTransformed)(using C: Context): Option[CoreTransformed] = + input match { + case CoreTransformed(source, tree, mod, core) => + val main = C.checkMain(mod) + Some(CoreTransformed(source, tree, mod, apply(Set(main), core))) + } + + def apply(entrypoints: Set[Id], m: ModuleDecl): ModuleDecl = + dropping.rewrite(m)(using DropContext(Reachable(entrypoints, m), Map.empty)) + + private case class DropContext( + usage: Map[Id, Usage], + definitions: Map[Id, Pure] + ) { + def updated(id: Id, p: Pure): DropContext = this.copy(definitions = definitions.updated(id, p)) + } + + private def hasDefinition(id: Id)(using C: DropContext) = C.definitions.isDefinedAt(id) + private def definitionOf(id: Id)(using C: DropContext): Pure = C.definitions(id) + private def usedOnce(id: Id)(using C: DropContext) = C.usage.get(id).contains(Usage.Once) + private def currentContext(using C: DropContext): C.type = C + + private object dropping extends Tree.RewriteWithContext[DropContext] { + + override def pure(using DropContext) = { + case Pure.ValueVar(id, tpe) if usedOnce(id) && hasDefinition(id) => definitionOf(id) + } + + override def stmt(using DropContext) = { + case Stmt.Scope(defs, body) => + var contextSoFar = currentContext + val ds = defs.flatMap { + case Definition.Let(id, tpe, p: Pure) if usedOnce(id) => + val transformed = rewrite(p)(using contextSoFar) + contextSoFar = contextSoFar.updated(id, transformed) + None + case d => Some(rewrite(d)(using contextSoFar)) + } + MaybeScope(ds, rewrite(body)(using contextSoFar)) + } + } +} diff --git a/effekt/shared/src/main/scala/effekt/core/optimizer/Normalizer.scala b/effekt/shared/src/main/scala/effekt/core/optimizer/Normalizer.scala new file mode 100644 index 000000000..2744f255b --- /dev/null +++ b/effekt/shared/src/main/scala/effekt/core/optimizer/Normalizer.scala @@ -0,0 +1,402 @@ +package effekt +package core +package optimizer + +import effekt.util.messages.INTERNAL_ERROR + +import scala.annotation.tailrec +import scala.collection.mutable + +/** + * Removes "cuts", that is it performs a step of computation if enough information + * is available. + * + * def foo(n: Int) = return n + 1 + * + * foo(42) + * + * becomes + * + * def foo(n: Int) = return n + 1 + * return 42 + 1 + * + * removing the overhead of the function call. Under the following conditions, + * cuts are _not_ removed: + * + * - the definition is recursive + * - inlining would exceed the maxInlineSize + * + * If the function is called _exactly once_, it is inlined regardless of the maxInlineSize. + */ +object Normalizer { normal => + + case class Context( + blocks: Map[Id, Block], + exprs: Map[Id, Expr], + decls: DeclarationContext, // for field selection + usage: mutable.Map[Id, Usage], // mutable in order to add new information after renaming + maxInlineSize: Int // to control inlining and avoid code bloat + ) { + def bind(id: Id, expr: Expr): Context = copy(exprs = exprs + (id -> expr)) + def bind(id: Id, block: Block): Context = copy(blocks = blocks + (id -> block)) + } + + private def blockFor(id: Id)(using ctx: Context): Option[Block] = + ctx.blocks.get(id) + + private def exprFor(id: Id)(using ctx: Context): Option[Expr] = + ctx.exprs.get(id) + + private def isRecursive(id: Id)(using ctx: Context): Boolean = + ctx.usage.get(id) match { + case Some(value) => value == Usage.Recursive + // We assume it is recursive, if (for some reason) we do not have information; + // since reducing might diverge, otherwise. + // + // This is, however, a strange case since this means we call a function we deemed unreachable. + // It _can_ happen, for instance, by updating the usage (subtracting) and not deadcode eliminating. + // This is the case for examples/pos/bidirectional/scheduler.effekt + case None => true // sys error s"No info for ${id}" + } + + private def isOnce(id: Id)(using ctx: Context): Boolean = + ctx.usage.get(id) match { + case Some(value) => value == Usage.Once + case None => false + } + + def normalize(entrypoints: Set[Id], m: ModuleDecl, maxInlineSize: Int): ModuleDecl = { + // usage information is used to detect recursive functions (and not inline them) + val usage = Reachable(entrypoints, m) + + val defs = m.definitions.collect { + case Definition.Def(id, block) => id -> block + }.toMap + val context = Context(defs, Map.empty, DeclarationContext(m.declarations, m.externs), mutable.Map.from(usage), maxInlineSize) + + val (normalizedDefs, _) = normalize(m.definitions)(using context) + m.copy(definitions = normalizedDefs) + } + + def normalize(definitions: List[Definition])(using ctx: Context): (List[Definition], Context) = + var contextSoFar = ctx + val defs = definitions.map { + case Definition.Def(id, block) => + val normalized = active(block)(using contextSoFar).dealiased + contextSoFar = contextSoFar.bind(id, normalized) + Definition.Def(id, normalized) + case Definition.Let(id, tpe, expr) => + val normalized = active(expr)(using contextSoFar) + contextSoFar = contextSoFar.bind(id, normalized) + Definition.Let(id, tpe, normalized) + } + (defs, contextSoFar) + + private enum NormalizedBlock { + case Known(b: BlockLit | New | Unbox, boundBy: Option[BlockVar]) + case Unknown(b: BlockVar) + + def dealiased: Block = this match { + case NormalizedBlock.Known(b, boundBy) => b + case NormalizedBlock.Unknown(b) => b + } + def shared: Block = this match { + case NormalizedBlock.Known(b, boundBy) => boundBy.getOrElse(b) + case NormalizedBlock.Unknown(b) => b + } + } + + /** + * This is a bit tricky: depending on the call-site of `active` + * we either want to find a redex (BlockLit | New), maximally dealias (in def bindings), + * discover the outmost Unbox (when boxing again), or preserve some sharing otherwise. + * + * A good testcase to look at for this is: + * examples/pos/capture/regions.effekt + */ + private def active[R](b: Block)(using C: Context): NormalizedBlock = + normalize(b) match { + case b: Block.BlockLit => NormalizedBlock.Known(b, None) + case b @ Block.New(impl) => NormalizedBlock.Known(b, None) + + case x @ Block.BlockVar(id, annotatedTpe, annotatedCapt) => blockFor(id) match { + case Some(b: (BlockLit | New | Unbox)) => NormalizedBlock.Known(b, Some(x)) + case _ => NormalizedBlock.Unknown(x) + } + case Block.Unbox(pure) => active(pure) match { + case Pure.Box(b, annotatedCapture) => active(b) + case other => NormalizedBlock.Known(Block.Unbox(pure), None) + } + } + + // TODO for `New` we should track how often each operation is used, not the object itself + // to decide inlining. + private def shouldInline(b: BlockLit, boundBy: Option[BlockVar])(using C: Context): Boolean = boundBy match { + case Some(id) if isRecursive(id.id) => false + case Some(id) => isOnce(id.id) || b.body.size <= C.maxInlineSize + case None => true + } + + private def active(e: Expr)(using Context): Expr = + normalize(e) match { + case x @ Pure.ValueVar(id, annotatedType) => exprFor(id) match { + case Some(p: Pure.Make) => p + case Some(p: Pure.Literal) => p + case Some(p: Pure.Box) => p + // We only inline non side-effecting expressions + case Some(other) if other.capt.isEmpty => other + case _ => x // stuck + } + case other => other // stuck + } + + def normalize(d: Definition)(using C: Context): Definition = d match { + case Definition.Def(id, block) => Definition.Def(id, normalize(block)) + case Definition.Let(id, tpe, binding) => Definition.Let(id, tpe, normalize(binding)) + } + + def normalize(s: Stmt)(using C: Context): Stmt = s match { + + case Stmt.Scope(definitions, body) => + val (defs, ctx) = normalize(definitions) + normal.Scope(defs, normalize(body)(using ctx)) + + // Redexes + // ------- + case Stmt.App(b, targs, vargs, bargs) => + active(b) match { + case NormalizedBlock.Known(b: BlockLit, boundBy) if shouldInline(b, boundBy) => + reduce(b, targs, vargs.map(normalize), bargs.map(normalize)) + case normalized => + Stmt.App(normalized.shared, targs, vargs.map(normalize), bargs.map(normalize)) + } + + case Stmt.Invoke(b, method, methodTpe, targs, vargs, bargs) => + active(b) match { + case n @ NormalizedBlock.Known(Block.New(impl), boundBy) => + selectOperation(impl, method) match { + case b: BlockLit if shouldInline(b, boundBy) => reduce(b, targs, vargs.map(normalize), bargs.map(normalize)) + case _ => Stmt.Invoke(n.shared, method, methodTpe, targs, vargs.map(normalize), bargs.map(normalize)) + } + + case normalized => + Stmt.Invoke(normalized.shared, method, methodTpe, targs, vargs.map(normalize), bargs.map(normalize)) + } + + case Stmt.Match(scrutinee, clauses, default) => active(scrutinee) match { + case Pure.Make(data, tag, vargs) if clauses.exists { case (id, _) => id == tag } => + val clause: BlockLit = clauses.collectFirst { case (id, cl) if id == tag => cl }.get + normalize(reduce(clause, Nil, vargs.map(normalize), Nil)) + case Pure.Make(data, tag, vargs) if default.isDefined => + normalize(default.get) + case _ => + val normalized = normalize(scrutinee) + Stmt.Match(normalized, clauses.map { case (id, value) => id -> normalize(value) }, default.map(normalize)) + } + + // [[ if (true) stmt1 else stmt2 ]] = [[ stmt1 ]] + case Stmt.If(cond, thn, els) => active(cond) match { + case Pure.Literal(true, annotatedType) => normalize(thn) + case Pure.Literal(false, annotatedType) => normalize(els) + case _ => If(normalize(cond), normalize(thn), normalize(els)) + } + + case Stmt.Val(id, tpe, binding, body) => + + def normalizeVal(id: Id, tpe: ValueType, binding: Stmt, body: Stmt): Stmt = binding match { + + // [[ val x = return e; s ]] = let x = [[ e ]]; [[ s ]] + case Stmt.Return(expr) => + normal.Scope(List(Definition.Let(id, tpe, expr)), + normalize(body)(using C.bind(id, expr))) + + // Commute val and bindings + // [[ val x = { def f = ...; let y = ...; STMT }; STMT ]] = def f = ...; let y = ...; val x = STMT; STMT + case Stmt.Scope(ds, bodyBinding) => + normal.Scope(ds, normalizeVal(id, tpe, bodyBinding, body)) + + // Flatten vals. This should be non-leaking since we use garbage free refcounting. + // [[ val x = { val y = stmt1; stmt2 }; stmt3 ]] = [[ val y = stmt1; val x = stmt2; stmt3 ]] + case Stmt.Val(id2, tpe2, binding2, body2) => + normalizeVal(id2, tpe2, binding2, Stmt.Val(id, tpe, body2, body)) + + + // [[ val x = { var y in r = e; stmt1 }; stmt2 ]] = var y in r = e; [[ val x = stmt1; stmt2 ]] + case Stmt.Alloc(id2, init2, region2, body2) => + Stmt.Alloc(id2, init2, region2, normalizeVal(id, tpe, body2, body)) + + // [[ val x = stmt; return x ]] = [[ stmt ]] + case other => normalize(body) match { + case Stmt.Return(x: ValueVar) if x.id == id => other + case normalizedBody => Stmt.Val(id, tpe, other, normalizedBody) + } + } + normalizeVal(id, tpe, normalize(binding), body) + + + // "Congruences" + // ------------- + + case Stmt.Reset(body) => Stmt.Reset(normalize(body)) + case Stmt.Shift(prompt, body) => Shift(prompt, normalize(body)) + case Stmt.Return(expr) => Return(normalize(expr)) + case Stmt.Alloc(id, init, region, body) => Alloc(id, normalize(init), region, normalize(body)) + case Stmt.Resume(k, body) => Resume(k, normalize(body)) + case Stmt.Region(body) => Region(normalize(body)) + case Stmt.Var(id, init, capture, body) => Stmt.Var(id, normalize(init), capture, normalize(body)) + case Stmt.Get(id, capt, tpe) => Stmt.Get(id, capt, tpe) + case Stmt.Put(id, capt, value) => Stmt.Put(id, capt, normalize(value)) + case Stmt.Hole() => s + } + def normalize(b: BlockLit)(using Context): BlockLit = + b match { + case BlockLit(tparams, cparams, vparams, bparams, body) => + BlockLit(tparams, cparams, vparams, bparams, normalize(body)) + } + + def normalize(b: Block)(using Context): Block = b match { + case b @ Block.BlockVar(id, _, _) => b + case b @ Block.BlockLit(tparams, cparams, vparams, bparams, body) => normalize(b) + + // [[ unbox (box b) ]] = [[ b ]] + case Block.Unbox(pure) => normal.Unbox(normalize(pure)) + case Block.New(impl) => New(normalize(impl)) + } + + def normalize(s: Implementation)(using Context): Implementation = + s match { + case Implementation(interface, operations) => Implementation(interface, operations.map { op => + op.copy(body = normalize(op.body)) + }) + } + + def normalize(p: Pure)(using ctx: Context): Pure = p match { + // [[ Constructor(f = v).f ]] = [[ v ]] + case Pure.Select(target, field, annotatedType) => active(target) match { + case Pure.Make(datatype, tag, fields) => + val constructor = ctx.decls.findConstructor(tag).get + val expr = (constructor.fields zip fields).collectFirst { case (f, expr) if f.id == field => expr }.get + normalize(expr) + case _ => Pure.Select(normalize(target), field, annotatedType) + } + + // [[ box (unbox e) ]] = [[ e ]] + case Pure.Box(b, annotatedCapture) => active(b) match { + case NormalizedBlock.Known(Unbox(p), boundBy) => p + case _ => normal.Box(normalize(b), annotatedCapture) + } + + // congruences + case Pure.PureApp(b, targs, vargs) => Pure.PureApp(normalize(b), targs, vargs.map(normalize)) + case Pure.Make(data, tag, vargs) => Pure.Make(data, tag, vargs.map(normalize)) + case Pure.ValueVar(id, annotatedType) => p + case Pure.Literal(value, annotatedType) => p + } + + def normalize(e: Expr)(using Context): Expr = e match { + case DirectApp(b, targs, vargs, bargs) => DirectApp(normalize(b), targs, vargs.map(normalize), bargs.map(normalize)) + + // [[ run (return e) ]] = [[ e ]] + case Run(s) => normal.Run(normalize(s)) + + case pure: Pure => normalize(pure) + } + + + // Smart Constructors + // ------------------ + @tailrec + private def Scope(definitions: List[Definition], body: Stmt): Stmt = body match { + + // flatten scopes + // { def f = ...; { def g = ...; BODY } } = { def f = ...; def g; BODY } + case Stmt.Scope(others, body) => normal.Scope(definitions ++ others, body) + + // commute bindings + // let x = run { let y = e; s } = let y = e; let x = run { s } + case _ => if (definitions.isEmpty) body else { + var defsSoFar: List[Definition] = Nil + + definitions.foreach { + case Definition.Let(id, tpe, Run(Stmt.Scope(ds, body))) => + defsSoFar = defsSoFar ++ (ds :+ Definition.Let(id, tpe, normal.Run(body))) + case d => defsSoFar = defsSoFar :+ d + } + Stmt.Scope(defsSoFar, body) + } + } + + private def Run(s: Stmt): Expr = s match { + + // run { let x = e; return x } = e + case Stmt.Scope(Definition.Let(id1, _, binding) :: Nil, Stmt.Return(Pure.ValueVar(id2, _))) if id1 == id2 => + binding + + // run { return e } = e + case Stmt.Return(expr) => expr + + case _ => core.Run(s) + } + + // box (unbox p) = p + private def Box(b: Block, capt: Captures): Pure = b match { + case Block.Unbox(pure) => pure + case b => Pure.Box(b, capt) + } + + // unbox (box b) = b + private def Unbox(p: Pure): Block = p match { + case Pure.Box(b, _) => b + case p => Block.Unbox(p) + } + + + // Helpers for beta-reduction + // -------------------------- + + private def reduce(b: BlockLit, targs: List[core.ValueType], vargs: List[Pure], bargs: List[Block])(using C: Context): Stmt = { + // To update usage information + val usage = C.usage + def copyUsage(from: Id, to: Id) = usage.get(from) match { + case Some(info) => usage.update(to, info) + case None => () + } + + // Only bind if not already a variable!!! + var ids: Set[Id] = Set.empty + var bindings: List[Definition.Def] = Nil + var bvars: List[Block.BlockVar] = Nil + + // (1) first bind + (b.bparams zip bargs) foreach { + case (bparam, x: Block.BlockVar) => + + // Update usage: u1 + (u2 - 1) + usage.update(x.id, usage.getOrElse(bparam.id, Usage.Never) + usage.getOrElse(x.id, Usage.Never).decrement) + bvars = bvars :+ x + // introduce a binding + case (bparam, block) => + val id = symbols.TmpBlock("blockBinding") + bindings = bindings :+ Definition.Def(id, block) + bvars = bvars :+ Block.BlockVar(id, block.tpe, block.capt) + copyUsage(bparam.id, id) + ids += id + } + + val (renamedLit: BlockLit, renamedIds) = Renamer.rename(b) + + renamedIds.foreach(copyUsage) + + val newUsage = usage.collect { case (id, usage) if util.show(id) contains "foreach" => (id, usage) } + + // (2) substitute + val body = substitutions.substitute(renamedLit, targs, vargs, bvars) + + normalize(normal.Scope(bindings, body)) + } + + private def selectOperation(impl: Implementation, method: Id): Block.BlockLit = + impl.operations.collectFirst { + case Operation(name, tps, cps, vps, bps, body) if name == method => BlockLit(tps, cps, vps, bps, body): Block.BlockLit + }.getOrElse { INTERNAL_ERROR("Should not happen") } +} diff --git a/effekt/shared/src/main/scala/effekt/core/optimizer/Optimizer.scala b/effekt/shared/src/main/scala/effekt/core/optimizer/Optimizer.scala new file mode 100644 index 000000000..8d0d6d029 --- /dev/null +++ b/effekt/shared/src/main/scala/effekt/core/optimizer/Optimizer.scala @@ -0,0 +1,54 @@ +package effekt +package core +package optimizer + +import effekt.PhaseResult.CoreTransformed +import effekt.context.Context + +import kiama.util.Source + +object Optimizer extends Phase[CoreTransformed, CoreTransformed] { + + val phaseName: String = "core-optimizer" + + def run(input: CoreTransformed)(using Context): Option[CoreTransformed] = + input match { + case CoreTransformed(source, tree, mod, core) => + val term = Context.checkMain(mod) + val optimized = Context.timed("optimize", source.name) { optimize(source, term, core) } + Some(CoreTransformed(source, tree, mod, optimized)) + } + + def optimize(source: Source, mainSymbol: symbols.Symbol, core: ModuleDecl)(using Context): ModuleDecl = + + var tree = core + + // (1) first thing we do is simply remove unused definitions (this speeds up all following analysis and rewrites) + tree = Context.timed("deadcode-elimination", source.name) { + Deadcode.remove(mainSymbol, tree) + } + + if !Context.config.optimize() then return tree; + + // (2) lift static arguments + tree = Context.timed("static-argument-transformation", source.name) { + StaticArguments.transform(mainSymbol, tree) + } + + def normalize(m: ModuleDecl) = { + val anfed = BindSubexpressions.transform(m) + val normalized = Normalizer.normalize(Set(mainSymbol), anfed, Context.config.maxInlineSize().toInt) + Deadcode.remove(mainSymbol, normalized) + } + + // (3) normalize once and remove beta redexes + tree = Context.timed("normalize-1", source.name) { normalize(tree) } + + // (4) optimize continuation capture in the tail-resumptive case + tree = Context.timed("tail-resumptions", source.name) { RemoveTailResumptions(tree) } + + // (5) normalize again to clean up and remove new redexes + tree = Context.timed("normalize-2", source.name) { normalize(tree) } + + tree +} diff --git a/effekt/shared/src/main/scala/effekt/core/Reachable.scala b/effekt/shared/src/main/scala/effekt/core/optimizer/Reachable.scala similarity index 70% rename from effekt/shared/src/main/scala/effekt/core/Reachable.scala rename to effekt/shared/src/main/scala/effekt/core/optimizer/Reachable.scala index 3d51c94f0..f0eb439ed 100644 --- a/effekt/shared/src/main/scala/effekt/core/Reachable.scala +++ b/effekt/shared/src/main/scala/effekt/core/optimizer/Reachable.scala @@ -1,5 +1,6 @@ package effekt package core +package optimizer /** * A simple reachability analysis. @@ -10,35 +11,35 @@ class Reachable( var seen: Set[Id] ) { + private def update(id: Id, u: Usage): Unit = reachable = reachable.updated(id, u) + private def usage(id: Id): Usage = reachable.getOrElse(id, Usage.Never) + def process(d: Definition)(using defs: Map[Id, Definition]): Unit = - if stack.contains(d.id) then - reachable = reachable.updated(d.id, Usage.Recursive) + if stack.contains(d.id) then update(d.id, Usage.Recursive) else d match { case Definition.Def(id, block) => seen = seen + id + val before = stack stack = id :: stack + process(block) stack = before case Definition.Let(id, _, binding) => seen = seen + id + process(binding) } def process(id: Id)(using defs: Map[Id, Definition]): Unit = if (stack.contains(id)) { - reachable = reachable.updated(id, Usage.Recursive) + update(id, Usage.Recursive) return; } - val count = reachable.get(id) match { - case Some(Usage.Once) => Usage.Many - case Some(Usage.Many) => Usage.Many - case Some(Usage.Recursive) => Usage.Recursive - case None => Usage.Once - } - reachable = reachable.updated(id, count) + update(id, usage(id) + Usage.Once) + if (!seen.contains(id)) { defs.get(id).foreach(process) } @@ -57,9 +58,13 @@ class Reachable( definitions.foreach { case d: Definition.Def => currentDefs += d.id -> d // recursive - process(d)(using currentDefs) + // Do NOT process them here, since this would mean the definition is used + // process(d)(using currentDefs) case d: Definition.Let => - process(d)(using currentDefs) + // DO only process if NOT pure + if (d.binding.capt.nonEmpty) { + process(d)(using currentDefs) + } currentDefs += d.id -> d // non-recursive } process(body)(using currentDefs) @@ -111,32 +116,42 @@ class Reachable( def process(i: Implementation)(using defs: Map[Id, Definition]): Unit = i.operations.foreach { op => process(op.body) } - } object Reachable { - def apply(entrypoints: Set[Id], definitions: Map[Id, Definition]): Map[Id, Usage] = { - val analysis = new Reachable(Map.empty, Nil, Set.empty) - entrypoints.foreach(d => analysis.process(d)(using definitions)) - analysis.reachable - } + def apply(entrypoints: Set[Id], m: ModuleDecl): Map[Id, Usage] = { + val definitions = m.definitions.map(d => d.id -> d).toMap + val initialUsage = entrypoints.map { id => id -> Usage.Recursive }.toMap + val analysis = new Reachable(initialUsage, Nil, Set.empty) - def apply(m: ModuleDecl): Map[Id, Usage] = { - val analysis = new Reachable(Map.empty, Nil, Set.empty) - val defs = m.definitions.map(d => d.id -> d).toMap - m.definitions.foreach(d => analysis.process(d)(using defs)) - analysis.reachable - } + entrypoints.foreach(d => analysis.process(d)(using definitions)) - def apply(s: Stmt.Scope): Map[Id, Usage] = { - val analysis = new Reachable(Map.empty, Nil, Set.empty) - analysis.process(s)(using Map.empty) analysis.reachable } } enum Usage { + case Never case Once case Many case Recursive + + def +(other: Usage): Usage = (this, other) match { + case (Usage.Never, other) => other + case (other, Usage.Never) => other + case (other, Usage.Recursive) => Usage.Recursive + case (Usage.Recursive, other) => Usage.Recursive + case (Usage.Once, Usage.Once) => Usage.Many + case (Usage.Many, Usage.Many) => Usage.Many + case (Usage.Many, Usage.Once) => Usage.Many + case (Usage.Once, Usage.Many) => Usage.Many + } + + // -1 + def decrement: Usage = this match { + case Usage.Never => Usage.Never + case Usage.Once => Usage.Never + case Usage.Many => Usage.Many + case Usage.Recursive => Usage.Recursive + } } diff --git a/effekt/shared/src/main/scala/effekt/core/optimizer/RemoveTailResumptions.scala b/effekt/shared/src/main/scala/effekt/core/optimizer/RemoveTailResumptions.scala new file mode 100644 index 000000000..bf6544a70 --- /dev/null +++ b/effekt/shared/src/main/scala/effekt/core/optimizer/RemoveTailResumptions.scala @@ -0,0 +1,73 @@ +package effekt +package core +package optimizer + +object RemoveTailResumptions { + + def apply(m: ModuleDecl): ModuleDecl = removal.rewrite(m) + + object removal extends Tree.Rewrite { + override def stmt: PartialFunction[Stmt, Stmt] = { + case Stmt.Shift(prompt, BlockLit(tparams, cparams, vparams, List(k), body)) if tailResumptive(k.id, body) => + removeTailResumption(k.id, body) + case Stmt.Shift(prompt, body) => Shift(prompt, rewrite(body)) + } + } + + // A simple syntactic check whether this stmt is tailresumptive in k + def tailResumptive(k: Id, stmt: Stmt): Boolean = + def freeInStmt(stmt: Stmt): Boolean = Variables.free(stmt).containsBlock(k) + def freeInExpr(expr: Expr): Boolean = Variables.free(expr).containsBlock(k) + def freeInDef(definition: Definition): Boolean = Variables.free(definition).containsBlock(k) + + stmt match { + case Stmt.Scope(definitions, body) => definitions.forall { d => !freeInDef(d) } && tailResumptive(k, body) + case Stmt.Return(expr) => false + case Stmt.Val(id, annotatedTpe, binding, body) => tailResumptive(k, body) && !freeInStmt(binding) + case Stmt.App(callee, targs, vargs, bargs) => false + case Stmt.Invoke(callee, method, methodTpe, targs, vargs, bargs) => false + case Stmt.If(cond, thn, els) => !freeInExpr(cond) && tailResumptive(k, thn) && tailResumptive(k, els) + // Interestingly, we introduce a join point making this more difficult to implement properly + case Stmt.Match(scrutinee, clauses, default) => !freeInExpr(scrutinee) && clauses.forall { + case (_, BlockLit(tparams, cparams, vparams, bparams, body)) => tailResumptive(k, body) + } && default.forall { stmt => tailResumptive(k, stmt) } + case Stmt.Region(BlockLit(tparams, cparams, vparams, bparams, body)) => tailResumptive(k, body) + case Stmt.Alloc(id, init, region, body) => tailResumptive(k, body) && !freeInExpr(init) + case Stmt.Var(id, init, capture, body) => tailResumptive(k, body) && !freeInExpr(init) + case Stmt.Get(id, annotatedCapt, annotatedTpe) => false + case Stmt.Put(id, annotatedCapt, value) => false + case Stmt.Reset(BlockLit(tparams, cparams, vparams, bparams, body)) => tailResumptive(k, body) // is this correct? + case Stmt.Shift(prompt, body) => false + case Stmt.Resume(k2, body) => k2.id == k // what if k is free in body? + case Stmt.Hole() => true + } + + def removeTailResumption(k: Id, stmt: Stmt): Stmt = stmt match { + case Stmt.Scope(definitions, body) => Stmt.Scope(definitions, removeTailResumption(k, body)) + case Stmt.Val(id, annotatedTpe, binding, body) => Stmt.Val(id, annotatedTpe, binding, removeTailResumption(k, body)) + case Stmt.If(cond, thn, els) => Stmt.If(cond, removeTailResumption(k, thn), removeTailResumption(k, els)) + case Stmt.Match(scrutinee, clauses, default) => Stmt.Match(scrutinee, clauses.map { + case (tag, block) => tag -> removeTailResumption(k, block) + }, default.map(removeTailResumption(k, _))) + case Stmt.Region(body : BlockLit) => + Stmt.Region(removeTailResumption(k, body)) + case Stmt.Alloc(id, init, region, body) => Stmt.Alloc(id, init, region, removeTailResumption(k, body)) + case Stmt.Var(id, init, capture, body) => Stmt.Var(id, init, capture, removeTailResumption(k, body)) + case Stmt.Reset(body) => Stmt.Reset(removeTailResumption(k, body)) + case Stmt.Resume(k2, body) if k2.id == k => body + + case Stmt.Resume(k, body) => stmt + case Stmt.Shift(prompt, body) => stmt + case Stmt.Hole() => stmt + case Stmt.Return(expr) => stmt + case Stmt.App(callee, targs, vargs, bargs) => stmt + case Stmt.Invoke(callee, method, methodTpe, targs, vargs, bargs) => stmt + case Stmt.Get(id, annotatedCapt, annotatedTpe) => stmt + case Stmt.Put(id, annotatedCapt, value) => stmt + } + + def removeTailResumption(k: Id, block: BlockLit): BlockLit = block match { + case BlockLit(tparams, cparams, vparams, bparams, body) => + BlockLit(tparams, cparams, vparams, bparams, removeTailResumption(k, body)) + } +} diff --git a/effekt/shared/src/main/scala/effekt/core/StaticArguments.scala b/effekt/shared/src/main/scala/effekt/core/optimizer/StaticArguments.scala similarity index 80% rename from effekt/shared/src/main/scala/effekt/core/StaticArguments.scala rename to effekt/shared/src/main/scala/effekt/core/optimizer/StaticArguments.scala index 4aabbf42d..3a8e1214b 100644 --- a/effekt/shared/src/main/scala/effekt/core/StaticArguments.scala +++ b/effekt/shared/src/main/scala/effekt/core/optimizer/StaticArguments.scala @@ -1,8 +1,9 @@ package effekt package core +package optimizer -import effekt.core.normal.* import scala.collection.mutable + import effekt.core.Type.returnType /** @@ -34,9 +35,12 @@ object StaticArguments { types.forall(x => x) && (values.exists(x => x) || blocks.exists(x => x)) } - def dropStatic[A](isStatic: List[Boolean], arguments: List[A]): List[A] = + private def dropStatic[A](isStatic: List[Boolean], arguments: List[A]): List[A] = (isStatic zip arguments).collect { case (false, arg) => arg } + private def selectStatic[A](isStatic: List[Boolean], arguments: List[A]): List[A] = + (isStatic zip arguments).collect { case (true, arg) => arg } + /** * Wraps the definition in another function, abstracting arguments along the way. * For example: @@ -58,7 +62,7 @@ object StaticArguments { def wrapDefinition(id: Id, blockLit: BlockLit)(using ctx: StaticArgumentsContext): Definition.Def = val IsStatic(staticT, staticV, staticB) = ctx.statics(id) - assert(staticT.forall(x => x), "Can only apply the worker-wrapper translation, if all type arguments are static.") + assert(staticT.forall(x => x), "Can only apply the static arguments translation, if all type arguments are static.") val workerType = BlockType.Function( dropStatic(staticT, blockLit.tparams), // should always be empty! @@ -69,11 +73,8 @@ object StaticArguments { blockLit.returnType ) - val workerVar: Block.BlockVar = BlockVar(Id(id.name.name + "_worker"), workerType, blockLit.capt) - ctx.workers(id) = workerVar - // fresh params for the wrapper function and its invocation - // note: only freshen params if not static to prevent duplicates + // note: only freshen non-static params to prevent duplicates val freshCparams: List[Id] = (staticB zip blockLit.cparams).map { case (true, param) => param case (false, param) => Id(param) @@ -82,11 +83,17 @@ object StaticArguments { case (true, param) => param case (false, ValueParam(id, tpe)) => ValueParam(Id(id), tpe) } - val freshBparams: List[BlockParam] = (staticB zip blockLit.bparams).map { - case (true, param) => param - case (false, BlockParam(id, tpe, capt)) => BlockParam(Id(id), tpe, capt) + val freshBparams: List[BlockParam] = (staticB zip blockLit.bparams zip freshCparams).map { + case ((true, param), capt) => param + case ((false, BlockParam(id, tpe, _)), capt) => BlockParam(Id(id), tpe, Set(capt)) } + // the worker now closes over the static block arguments (`c` in the example above): + val newCapture = blockLit.capt ++ selectStatic(staticB, freshCparams).toSet + + val workerVar: Block.BlockVar = BlockVar(Id(id.name.name + "_worker"), workerType, newCapture) + ctx.workers(id) = workerVar + Definition.Def(id, BlockLit( blockLit.tparams, freshCparams, @@ -120,22 +127,22 @@ object StaticArguments { def rewrite(s: Stmt)(using C: StaticArgumentsContext): Stmt = s match { case Stmt.Scope(definitions, body) => - scope(rewrite(definitions), rewrite(body)) + MaybeScope(rewrite(definitions), rewrite(body)) case Stmt.App(b, targs, vargs, bargs) => b match { // if arguments are static && recursive call: call worker with reduced arguments case BlockVar(id, annotatedTpe, annotatedCapt) if hasStatics(id) && within(id) => val IsStatic(staticT, staticV, staticB) = C.statics(id) - app(C.workers(id), + Stmt.App(C.workers(id), dropStatic(staticT, targs), dropStatic(staticV, vargs).map(rewrite), dropStatic(staticB, bargs).map(rewrite)) - case _ => app(rewrite(b), targs, vargs.map(rewrite), bargs.map(rewrite)) + case _ => Stmt.App(rewrite(b), targs, vargs.map(rewrite), bargs.map(rewrite)) } case Stmt.Invoke(b, method, methodTpe, targs, vargs, bargs) => - invoke(rewrite(b), method, methodTpe, targs, vargs.map(rewrite), bargs.map(rewrite)) + Stmt.Invoke(rewrite(b), method, methodTpe, targs, vargs.map(rewrite), bargs.map(rewrite)) case Stmt.Reset(body) => rewrite(body) match { @@ -144,10 +151,9 @@ object StaticArguments { // congruences case Stmt.Return(expr) => Return(rewrite(expr)) - case Stmt.Val(id, tpe, binding, body) => valDef(id, tpe, rewrite(binding), rewrite(body)) + case Stmt.Val(id, tpe, binding, body) => Stmt.Val(id, tpe, rewrite(binding), rewrite(body)) case Stmt.If(cond, thn, els) => If(rewrite(cond), rewrite(thn), rewrite(els)) - case Stmt.Match(scrutinee, clauses, default) => - patternMatch(rewrite(scrutinee), clauses.map { case (id, value) => id -> rewrite(value) }, default.map(rewrite)) + case Stmt.Match(scrutinee, clauses, default) => Stmt.Match(rewrite(scrutinee), clauses.map { case (id, value) => id -> rewrite(value) }, default.map(rewrite)) case Stmt.Alloc(id, init, region, body) => Alloc(id, rewrite(init), region, rewrite(body)) case Stmt.Shift(prompt, body) => Shift(prompt, rewrite(body)) case Stmt.Resume(k, body) => Resume(k, rewrite(body)) @@ -155,7 +161,7 @@ object StaticArguments { case Stmt.Var(id, init, capture, body) => Stmt.Var(id, rewrite(init), capture, rewrite(body)) case Stmt.Get(id, capt, tpe) => Stmt.Get(id, capt, tpe) case Stmt.Put(id, capt, value) => Stmt.Put(id, capt, rewrite(value)) - case Stmt.Hole() => s + case Stmt.Hole() => Stmt.Hole() } def rewrite(b: BlockLit)(using StaticArgumentsContext): BlockLit = b match { @@ -168,8 +174,8 @@ object StaticArguments { // congruences case b @ Block.BlockLit(tparams, cparams, vparams, bparams, body) => rewrite(b) - case Block.Unbox(pure) => unbox(rewrite(pure)) - case Block.New(impl) => New(rewrite(impl)) + case Block.Unbox(pure) => Block.Unbox(rewrite(pure)) + case Block.New(impl) => Block.New(rewrite(impl)) } def rewrite(s: Implementation)(using StaticArgumentsContext): Implementation = @@ -180,21 +186,21 @@ object StaticArguments { } def rewrite(p: Pure)(using StaticArgumentsContext): Pure = p match { - case Pure.PureApp(b, targs, vargs) => pureApp(rewrite(b), targs, vargs.map(rewrite)) - case Pure.Make(data, tag, vargs) => make(data, tag, vargs.map(rewrite)) + case Pure.PureApp(b, targs, vargs) => Pure.PureApp(rewrite(b), targs, vargs.map(rewrite)) + case Pure.Make(data, tag, vargs) => Pure.Make(data, tag, vargs.map(rewrite)) case x @ Pure.ValueVar(id, annotatedType) => x // congruences case Pure.Literal(value, annotatedType) => p - case Pure.Select(target, field, annotatedType) => select(rewrite(target), field, annotatedType) - case Pure.Box(b, annotatedCapture) => box(rewrite(b), annotatedCapture) + case Pure.Select(target, field, annotatedType) => Pure.Select(rewrite(target), field, annotatedType) + case Pure.Box(b, annotatedCapture) => Pure.Box(rewrite(b), annotatedCapture) } def rewrite(e: Expr)(using StaticArgumentsContext): Expr = e match { - case DirectApp(b, targs, vargs, bargs) => directApp(rewrite(b), targs, vargs.map(rewrite), bargs.map(rewrite)) + case DirectApp(b, targs, vargs, bargs) => DirectApp(rewrite(b), targs, vargs.map(rewrite), bargs.map(rewrite)) // congruences - case Run(s) => run(rewrite(s)) + case Run(s) => Run(rewrite(s)) case pure: Pure => rewrite(pure) } diff --git a/effekt/shared/src/main/scala/effekt/generator/chez/ChezScheme.scala b/effekt/shared/src/main/scala/effekt/generator/chez/ChezScheme.scala index f83e0cec9..e91f2b568 100644 --- a/effekt/shared/src/main/scala/effekt/generator/chez/ChezScheme.scala +++ b/effekt/shared/src/main/scala/effekt/generator/chez/ChezScheme.scala @@ -3,7 +3,8 @@ package generator package chez import effekt.context.Context -import effekt.symbols.{Module, Symbol} +import effekt.core.optimizer.Optimizer +import effekt.symbols.{ Module, Symbol } import effekt.util.messages.ErrorReporter import kiama.output.PrettyPrinterTypes.Document import kiama.util.Source @@ -52,7 +53,7 @@ trait ChezScheme extends Compiler[String] { // ------------------------ // Source => Core => Chez lazy val Compile = - allToCore(Core) andThen Aggregate andThen core.Optimizer andThen Chez map { case (main, expr) => + allToCore(Core) andThen Aggregate andThen Optimizer andThen Chez map { case (main, expr) => (Map(main -> pretty(expr).layout), main) } diff --git a/effekt/shared/src/main/scala/effekt/generator/chez/Transformer.scala b/effekt/shared/src/main/scala/effekt/generator/chez/Transformer.scala index ab983ef7b..968b9d707 100644 --- a/effekt/shared/src/main/scala/effekt/generator/chez/Transformer.scala +++ b/effekt/shared/src/main/scala/effekt/generator/chez/Transformer.scala @@ -117,9 +117,11 @@ trait Transformer { // currently bidirectional handlers are not supported case Resume(k, Return(expr)) => chez.Call(toChez(k), List(toChez(expr))) + case Resume(k, other) => sys error s"Not supported yet: ${util.show(stmt)}" + case Region(body) => chez.Builtin("with-region", toChez(body)) - case other => chez.Let(Nil, toChez(other)) + case s: Scope => chez.Let(Nil, toChez(s)) } def toChez(decl: core.Declaration): List[chez.Def] = decl match { diff --git a/effekt/shared/src/main/scala/effekt/generator/js/JavaScript.scala b/effekt/shared/src/main/scala/effekt/generator/js/JavaScript.scala index ded31be25..cc97dcbb1 100644 --- a/effekt/shared/src/main/scala/effekt/generator/js/JavaScript.scala +++ b/effekt/shared/src/main/scala/effekt/generator/js/JavaScript.scala @@ -4,6 +4,7 @@ package js import effekt.PhaseResult.CoreTransformed import effekt.context.Context +import effekt.core.optimizer.{ DropBindings, Optimizer } import kiama.output.PrettyPrinterTypes.Document import kiama.util.Source @@ -41,7 +42,7 @@ class JavaScript(additionalFeatureFlags: List[String] = Nil) extends Compiler[St Frontend andThen Middleend } - lazy val Optimized = allToCore(Core) andThen Aggregate andThen core.Optimizer map { + lazy val Optimized = allToCore(Core) andThen Aggregate andThen Optimizer andThen DropBindings map { case input @ CoreTransformed(source, tree, mod, core) => val mainSymbol = Context.checkMain(mod) val mainFile = path(mod) diff --git a/effekt/shared/src/main/scala/effekt/generator/llvm/LLVM.scala b/effekt/shared/src/main/scala/effekt/generator/llvm/LLVM.scala index 774f0d73b..219be5c07 100644 --- a/effekt/shared/src/main/scala/effekt/generator/llvm/LLVM.scala +++ b/effekt/shared/src/main/scala/effekt/generator/llvm/LLVM.scala @@ -3,8 +3,9 @@ package generator package llvm import effekt.context.Context +import effekt.core.optimizer import effekt.machine -import kiama.output.PrettyPrinterTypes.{Document, emptyLinks} +import kiama.output.PrettyPrinterTypes.{ Document, emptyLinks } import kiama.util.Source @@ -38,7 +39,7 @@ class LLVM extends Compiler[String] { // The Compilation Pipeline // ------------------------ // Source => Core => Machine => LLVM - lazy val Compile = allToCore(Core) andThen Aggregate andThen core.PolymorphismBoxing andThen core.Optimizer andThen Machine map { + lazy val Compile = allToCore(Core) andThen Aggregate andThen core.PolymorphismBoxing andThen optimizer.Optimizer andThen Machine map { case (mod, main, prog) => (mod, llvm.Transformer.transform(prog)) } @@ -51,7 +52,7 @@ class LLVM extends Compiler[String] { // ----------------------------------- object steps { // intermediate steps for VSCode - val afterCore = allToCore(Core) andThen Aggregate andThen core.PolymorphismBoxing andThen core.Optimizer + val afterCore = allToCore(Core) andThen Aggregate andThen core.PolymorphismBoxing andThen optimizer.Optimizer val afterMachine = afterCore andThen Machine map { case (mod, main, prog) => prog } val afterLLVM = afterMachine map { case machine.Program(decls, prog) => diff --git a/effekt/shared/src/main/scala/effekt/machine/Transformer.scala b/effekt/shared/src/main/scala/effekt/machine/Transformer.scala index 34cd60ae5..0e010bd25 100644 --- a/effekt/shared/src/main/scala/effekt/machine/Transformer.scala +++ b/effekt/shared/src/main/scala/effekt/machine/Transformer.scala @@ -6,6 +6,8 @@ import effekt.core.{ Block, DeclarationContext, Definition, Id, given } import effekt.symbols.{ Symbol, TermSymbol } import effekt.symbols.builtins.TState import effekt.util.messages.ErrorReporter +import effekt.symbols.ErrorMessageInterpolator + object Transformer { @@ -121,16 +123,13 @@ object Transformer { noteDefinition(id, free, params) } - case Definition.Def(id, core.Unbox(_)) => - // TODO deal with this case - () + case Definition.Def(id, b @ core.Unbox(_)) => + noteParameter(id, b.tpe) case Definition.Let(_, _, _) => () } - - // (2) Actually translate the definitions definitions.foldRight(transform(rest)) { case (core.Definition.Let(id, tpe, binding), rest) => @@ -149,9 +148,10 @@ object Transformer { case (core.Definition.Def(id, core.BlockVar(alias, tpe, _)), rest) => Def(transformLabel(id), Jump(transformLabel(alias)), rest) - case (d @ core.Definition.Def(_, _: core.Unbox), rest) => - // TODO deal with this case by substitution - ErrorReporter.abort(s"block definition: $d") + case (core.Definition.Def(id, core.Unbox(pure)), rest) => + transform(pure).run { boxed => + ForeignCall(Variable(transform(id), Type.Negative()), "unbox", List(boxed), rest) + } } case core.Return(expr) => @@ -168,7 +168,7 @@ object Transformer { transform(vargs, bargs).run { (values, blocks) => callee match { case Block.BlockVar(id, annotatedTpe, annotatedCapt) => - BPC.info.getOrElse(id, sys.error(s"Cannot find block info for ${id}.\n${BPC.info}")) match { + BPC.info.getOrElse(id, sys.error(pp"In ${stmt}. Cannot find block info for ${id}: ${annotatedTpe}.\n${BPC.info}")) match { // Unknown Jump to function case BlockInfo.Parameter(tpe: core.BlockType.Function) => Invoke(Variable(transform(id), transform(tpe)), builtins.Apply, values ++ blocks) @@ -182,13 +182,18 @@ object Transformer { } case Block.Unbox(pure) => - transform(pure).run { callee => Invoke(callee, builtins.Apply, values ++ blocks) } + transform(pure).run { boxedCallee => + val callee = Variable(freshName(boxedCallee.name), Type.Negative()) + + ForeignCall(callee, "unbox", List(boxedCallee), + Invoke(callee, builtins.Apply, values ++ blocks)) + } case Block.New(impl) => ErrorReporter.panic("Applying an object") case Block.BlockLit(tparams, cparams, vparams, bparams, body) => - ErrorReporter.panic("Call to block literal should have been reduced") + ErrorReporter.panic(pp"Call to block literal should have been reduced: ${stmt}") } } @@ -202,7 +207,12 @@ object Transformer { Invoke(Variable(transform(id), transform(tpe)), opTag, values ++ blocks) case Block.Unbox(pure) => - transform(pure).run { callee => Invoke(callee, opTag, values ++ blocks) } + transform(pure).run { boxedCallee => + val callee = Variable(freshName(boxedCallee.name), Type.Negative()) + + ForeignCall(callee, "unbox", List(boxedCallee), + Invoke(callee, opTag, values ++ blocks)) + } case Block.New(impl) => ErrorReporter.panic("Method call to known object should have been reduced") @@ -451,7 +461,12 @@ object Transformer { } case core.Box(block, annot) => - transformBlockArg(block) + transformBlockArg(block).flatMap { unboxed => + Binding { k => + val boxed = Variable(freshName(unboxed.name), Type.Positive()) + ForeignCall(boxed, "box", List(unboxed), k(boxed)) + } + } case _ => ErrorReporter.abort(s"Unsupported expression: $expr") @@ -485,7 +500,7 @@ object Transformer { def transform(tpe: core.ValueType)(using ErrorReporter): Type = tpe match { case core.ValueType.Var(name) => Positive() // assume all value parameters are data - case core.ValueType.Boxed(tpe, capt) => Negative() + case core.ValueType.Boxed(tpe, capt) => Positive() case core.Type.TUnit => builtins.UnitType case core.Type.TInt => Type.Int() case core.Type.TChar => Type.Int() @@ -566,7 +581,7 @@ object Transformer { BPC.globals += (id -> Label(transform(id), Nil)) def getBlockInfo(id: Id)(using BPC: BlocksParamsContext): BlockInfo = - BPC.info.getOrElse(id, sys error s"No block info for ${id}") + BPC.info.getOrElse(id, sys error s"No block info for ${util.show(id)}") def getDefinition(id: Id)(using BPC: BlocksParamsContext): BlockInfo.Definition = getBlockInfo(id) match { case d : BlockInfo.Definition => d diff --git a/effekt/shared/src/main/scala/effekt/util/Debug.scala b/effekt/shared/src/main/scala/effekt/util/Debug.scala index ca9189ef3..f58d5d6ca 100644 --- a/effekt/shared/src/main/scala/effekt/util/Debug.scala +++ b/effekt/shared/src/main/scala/effekt/util/Debug.scala @@ -4,7 +4,7 @@ package util import effekt.symbols.TypePrinter -lazy val showGeneric: PartialFunction[Any, String] = { +val showGeneric: PartialFunction[Any, String] = { case l: List[_] => l.map(show).mkString("List(", ", ", ")") case o: Option[_] => @@ -12,7 +12,7 @@ lazy val showGeneric: PartialFunction[Any, String] = { case other => other.toString } -lazy val show: PartialFunction[Any, String] = +val show: PartialFunction[Any, String] = TypePrinter.show orElse core.PrettyPrinter.show orElse generator.js.PrettyPrinter.show orElse diff --git a/examples/benchmarks/are_we_fast_yet/queens.effekt b/examples/benchmarks/are_we_fast_yet/queens.effekt index 497ef8496..7ab35a427 100644 --- a/examples/benchmarks/are_we_fast_yet/queens.effekt +++ b/examples/benchmarks/are_we_fast_yet/queens.effekt @@ -49,4 +49,3 @@ def run(n: Int) = { } def main() = benchmark(8){run} - diff --git a/examples/benchmarks/effect_handlers_bench/parsing_dollars.effekt b/examples/benchmarks/effect_handlers_bench/parsing_dollars.effekt index 77ce1aaf1..11a662a89 100644 --- a/examples/benchmarks/effect_handlers_bench/parsing_dollars.effekt +++ b/examples/benchmarks/effect_handlers_bench/parsing_dollars.effekt @@ -28,7 +28,7 @@ def parse(a: Int): Unit / {Read, Emit, Stop} = { do Stop() } } - + def sum { action: () => Unit / Emit } = { var s = 0; try { diff --git a/examples/benchmarks/effect_handlers_bench/tree_explore.effekt b/examples/benchmarks/effect_handlers_bench/tree_explore.effekt index 6611050e8..fba8f712e 100644 --- a/examples/benchmarks/effect_handlers_bench/tree_explore.effekt +++ b/examples/benchmarks/effect_handlers_bench/tree_explore.effekt @@ -65,4 +65,3 @@ def run(n: Int) = { } def main() = benchmark(5){run} - diff --git a/libraries/common/io.effekt b/libraries/common/io.effekt index af47633bd..c3ab21adf 100644 --- a/libraries/common/io.effekt +++ b/libraries/common/io.effekt @@ -16,7 +16,8 @@ extern async def spawn(task: Task[Unit]): Unit = js "$effekt.capture(k => { setTimeout(() => k($effekt.unit), 0); return $effekt.run(${task}) })" llvm """ call void @c_yield(%Stack %stack) - call void @run(%Neg ${task}) + %unboxed = call ccc %Neg @unbox(%Pos ${task}) + call void @run(%Neg %unboxed) ret void """ diff --git a/libraries/llvm/rts.ll b/libraries/llvm/rts.ll index f69fd3dac..cd3fb0d8a 100644 --- a/libraries/llvm/rts.ll +++ b/libraries/llvm/rts.ll @@ -114,6 +114,26 @@ declare void @exit(i64) declare void @llvm.assume(i1) +; Boxing (externs functions, hence ccc) +define ccc %Pos @box(%Neg %input) { + %vtable = extractvalue %Neg %input, 0 + %heap_obj = extractvalue %Neg %input, 1 + %vtable_as_int = ptrtoint ptr %vtable to i64 + %pos_result = insertvalue %Pos undef, i64 %vtable_as_int, 0 + %pos_result_with_heap = insertvalue %Pos %pos_result, ptr %heap_obj, 1 + ret %Pos %pos_result_with_heap +} + +define ccc %Neg @unbox(%Pos %input) { + %tag = extractvalue %Pos %input, 0 + %heap_obj = extractvalue %Pos %input, 1 + %vtable = inttoptr i64 %tag to ptr + %neg_result = insertvalue %Neg undef, ptr %vtable, 0 + %neg_result_with_heap = insertvalue %Neg %neg_result, ptr %heap_obj, 1 + ret %Neg %neg_result_with_heap +} + + ; Prompts define private %Prompt @currentPrompt(%Stack %stack) {