From 04262381af79e65832eaa5dc9c3c4d458c75e952 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonathan=20Brachtha=CC=88user?= Date: Tue, 14 Jan 2025 21:05:34 +0100 Subject: [PATCH] Make local continuations direct style if possible --- .../effekt/core/optimizer/Normalizer.scala | 30 ++++- .../effekt/generator/js/TransformerCps.scala | 127 +++++++++++++++--- 2 files changed, 139 insertions(+), 18 deletions(-) diff --git a/effekt/shared/src/main/scala/effekt/core/optimizer/Normalizer.scala b/effekt/shared/src/main/scala/effekt/core/optimizer/Normalizer.scala index 9855b9446..c984aec6a 100644 --- a/effekt/shared/src/main/scala/effekt/core/optimizer/Normalizer.scala +++ b/effekt/shared/src/main/scala/effekt/core/optimizer/Normalizer.scala @@ -209,13 +209,39 @@ object Normalizer { normal => case Stmt.Val(id, tpe, binding, body) => - def normalizeVal(id: Id, tpe: ValueType, binding: Stmt, body: Stmt): Stmt = binding match { + def barendregdt(stmt: Stmt): Stmt = new Renamer().apply(stmt) + + def normalizeVal(id: Id, tpe: ValueType, binding: Stmt, body: Stmt): Stmt = normalize(binding) match { // [[ val x = sc match { case id(ps) => body2 }; body ]] = sc match { case id(ps) => val x = body2; body } case Stmt.Match(sc, List((id2, BlockLit(tparams2, cparams2, vparams2, bparams2, body2))), None) => Stmt.Match(sc, List((id2, BlockLit(tparams2, cparams2, vparams2, bparams2, normalizeVal(id, tpe, body2, body)))), None) + // These rewrites do not seem to contribute a lot given their complexity... + // vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv + + // [[ val x = if (cond) { thn } else { els }; body ]] = if (cond) { [[ val x = thn; body ]] } else { [[ val x = els; body ]] } +// case normalized @ Stmt.If(cond, thn, els) if body.size <= 2 => +// // since we duplicate the body, we need to freshen the names +// val normalizedThn = barendregdt(normalizeVal(id, tpe, thn, body)) +// val normalizedEls = barendregdt(normalizeVal(id, tpe, els, body)) +// +// Stmt.If(cond, normalizedThn, normalizedEls) +// +// case Stmt.Match(sc, clauses, default) +// // necessary since otherwise we loose Nothing-boxing +// // vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv +// if body.size <= 2 && (clauses.size + default.size) >= 1 => +// val normalizedClauses = clauses map { +// case (id2, BlockLit(tparams2, cparams2, vparams2, bparams2, body2)) => +// (id2, BlockLit(tparams2, cparams2, vparams2, bparams2, barendregdt(normalizeVal(id, tpe, body2, body))): BlockLit) +// } +// val normalizedDefault = default map { stmt => barendregdt(normalizeVal(id, tpe, stmt, body)) } +// Stmt.Match(sc, normalizedClauses, normalizedDefault) + + // ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + // [[ val x = return e; s ]] = let x = [[ e ]]; [[ s ]] case Stmt.Return(expr2) => Stmt.Let(id, tpe, expr2, normalize(body)(using C.bind(id, expr2))) @@ -245,7 +271,7 @@ object Normalizer { normal => case normalizedBody => Stmt.Val(id, tpe, other, normalizedBody) } } - normalizeVal(id, tpe, normalize(binding), body) + normalizeVal(id, tpe, binding, body) // "Congruences" diff --git a/effekt/shared/src/main/scala/effekt/generator/js/TransformerCps.scala b/effekt/shared/src/main/scala/effekt/generator/js/TransformerCps.scala index 827e1019f..4713c2c0b 100644 --- a/effekt/shared/src/main/scala/effekt/generator/js/TransformerCps.scala +++ b/effekt/shared/src/main/scala/effekt/generator/js/TransformerCps.scala @@ -6,6 +6,7 @@ import effekt.context.Context import effekt.context.assertions.* import effekt.cps.* import effekt.core.{ DeclarationContext, Id } +import effekt.cps.Variables.{ all, free } import scala.collection.mutable @@ -34,6 +35,8 @@ object TransformerCps extends Transformer { } } + case class ContinuationInfo(k: Id, param: Id, ks: Id) + case class TransformerContext( requiresThunk: Boolean, bindings: Map[Id, js.Expr], @@ -41,6 +44,12 @@ object TransformerCps extends Transformer { externs: Map[Id, cps.Extern.Def], // currently, lexically enclosing functions and their parameters (used to determine whether a call is recursive, to rewrite into a loop) definitions: Map[Id, DefInfo], + // the direct-style continuation, if available + directStyle: Option[ContinuationInfo], + // the current (direct-style) metacontinuation + metacont: Option[Id], + // substitutions for renaming of metaconts + metaconts: Map[Id, Id], // the original declaration context (used to compile pattern matching) declarations: DeclarationContext, // the usual compiler context @@ -50,13 +59,23 @@ object TransformerCps extends Transformer { def lookup(id: Id)(using C: TransformerContext): js.Expr = C.bindings.getOrElse(id, nameRef(id)) - def enterDefinition(id: Id, used: Used, block: cps.Block)(using C: TransformerContext): TransformerContext = block match { + def recursive(id: Id, used: Used, block: cps.Block)(using C: TransformerContext): TransformerContext = block match { case cps.BlockLit(vparams, bparams, ks, k, body) => - C.copy(definitions = Map(id -> DefInfo(id, vparams, bparams, ks, k, used))) + C.copy(definitions = Map(id -> DefInfo(id, vparams, bparams, ks, k, used)), directStyle = None, metacont = Some(ks)) case _ => C } - def clearDefinitions(using C: TransformerContext): TransformerContext = C.copy(definitions = Map.empty) + def nonrecursive(ks: Id)(using C: TransformerContext): TransformerContext = + C.copy(definitions = Map.empty, directStyle = None, metacont = Some(ks)) + + def nonrecursive(block: cps.BlockLit)(using C: TransformerContext): TransformerContext = nonrecursive(block.ks) + + // ks | let k1 x1 ks1 = { let k2 x2 ks2 = jump k v ks2 }; ... = jump k v ks + def directstyle(ks: Id)(using C: TransformerContext): TransformerContext = + val outer = C.metacont.getOrElse { sys error "Metacontinuation missing..." } + val outerSubstituted = C.metaconts.getOrElse(outer, outer) + val subst = C.metaconts.updated(ks, outerSubstituted) + C.copy(metacont = Some(ks), metaconts = subst) def bindingAll[R](bs: List[(Id, js.Expr)])(body: TransformerContext ?=> R)(using C: TransformerContext): R = body(using C.copy(bindings = C.bindings ++ bs)) @@ -79,6 +98,9 @@ object TransformerCps extends Transformer { Map.empty, externs.collect { case d: Extern.Def => (d.id, d) }.toMap, Map.empty, + None, + None, + Map.empty, D, C) val name = JSName(jsModuleName(module.path)) @@ -101,6 +123,9 @@ object TransformerCps extends Transformer { Map.empty, input.externs.collect { case d: Extern.Def => (d.id, d) }.toMap, Map.empty, + None, + None, + Map.empty, D, C) input.definitions.map(toJS) @@ -160,7 +185,7 @@ object TransformerCps extends Transformer { case cps.Block.BlockLit(vparams, bparams, ks, k, body) => val used = new Used(false) - val translatedBody = toJS(body)(using enterDefinition(id, used, b)).stmts + val translatedBody = toJS(body)(using recursive(id, used, b)).stmts if used.used then js.Lambda(vparams.map(nameDef) ++ bparams.map(nameDef) ++ List(nameDef(ks), nameDef(k)), @@ -169,7 +194,7 @@ object TransformerCps extends Transformer { js.Lambda(vparams.map(nameDef) ++ bparams.map(nameDef) ++ List(nameDef(ks), nameDef(k)), translatedBody) - case other => toJS(other)(using clearDefinitions) + case other => toJS(other) } def toJS(b: cps.Block)(using TransformerContext): js.Expr = b match { @@ -185,17 +210,18 @@ object TransformerCps extends Transformer { case cps.Implementation(interface, operations) => js.Object(operations.map { case cps.Operation(id, vps, bps, ks, k, body) => - nameDef(id) -> js.Lambda(vps.map(nameDef) ++ bps.map(nameDef) ++ List(nameDef(ks), nameDef(k)), toJS(body)(using clearDefinitions).stmts) + nameDef(id) -> js.Lambda(vps.map(nameDef) ++ bps.map(nameDef) ++ List(nameDef(ks), nameDef(k)), toJS(body)(using nonrecursive(ks)).stmts) }) } - def toJS(ks: cps.MetaCont): js.Expr = nameRef(ks.id) + def toJS(ks: cps.MetaCont)(using T: TransformerContext): js.Expr = + nameRef(T.metaconts.getOrElse(ks.id, ks.id)) def toJS(k: cps.Cont)(using T: TransformerContext): js.Expr = k match { case Cont.ContVar(id) => nameRef(id) case Cont.ContLam(result, ks, body) => - js.Lambda(List(nameDef(result), nameDef(ks)), toJS(body)(using clearDefinitions).stmts) + js.Lambda(List(nameDef(result), nameDef(ks)), toJS(body)(using nonrecursive(ks)).stmts) } def toJS(e: cps.Expr)(using D: TransformerContext): js.Expr = e match { @@ -225,9 +251,22 @@ object TransformerCps extends Transformer { js.Const(nameDef(id), toJS(binding)) :: toJS(body).run(k) } - case cps.Stmt.LetCont(id, binding, body) => + // [[ let k(x, ks) = ...; if (...) jump k(42, ks2) else jump k(10, ks3) ]] = + // let x; if (...) { x = 42; ks = ks2 } else { x = 10; ks = ks3 } ... + case cps.Stmt.LetCont(id, Cont.ContLam(param, ks, body), body2) if canBeDirect(id, body2) => Binding { k => - js.Const(nameDef(id), toJS(binding)) :: requiringThunk { toJS(body)(using clearDefinitions) }.run(k) + val withDirectStyle = D.copy(directStyle = Some(ContinuationInfo(id, param, ks))) + val renamingMetacont = directstyle(ks) + + val translatedIf = toJS(body2)(using withDirectStyle) + val translatedBody = toJS(body)(using renamingMetacont) + + js.Let(nameDef(param), js.Undefined) :: translatedIf.stmts ++ translatedBody.run(k) + } + + case cps.Stmt.LetCont(id, binding @ Cont.ContLam(result2, ks2, body2), body) => + Binding { k => + js.Const(nameDef(id), toJS(binding)(using nonrecursive(ks2))) :: requiringThunk { toJS(body) }.run(k) } case cps.Stmt.Match(sc, Nil, None) => @@ -245,14 +284,27 @@ object TransformerCps extends Transformer { pure(js.Switch(js.Member(scrutinee, `tag`), clauses.map { case (tag, clause) => val (e, binding) = toJS(scrutinee, tag, clause); - (e, binding.stmts) + + val stmts = binding.stmts + + stmts.last match { + case terminator : (js.Stmt.Return | js.Stmt.Break | js.Stmt.Continue) => (e, stmts) + case other => (e, stmts :+ js.Break()) + } }, default.map { s => toJS(s).stmts })) + case cps.Stmt.Jump(k, arg, ks) if D.directStyle.exists(c => c.k == k) => D.directStyle match { + case Some(ContinuationInfo(k2, param2, ks2)) => pure(js.Assign(nameRef(param2), toJS(arg))) + case None => sys error "Should not happen" + } + case cps.Stmt.Jump(k, arg, ks) => pure(js.Return(maybeThunking(js.Call(nameRef(k), toJS(arg), toJS(ks))))) - case cps.Stmt.App(callee @ DefInfo(id, vparams, bparams, ks1, k1, used), vargs, bargs, MetaCont(ks), Cont.ContVar(k)) if ks1 == ks && k1 == k => + case cps.Stmt.App(DefInfo(id, vparams, bparams, ks1, k1, used), vargs, bargs, MetaCont(ks), Cont.ContVar(k)) + // this call is a tailcall if the metacontinuation (after substitution) is the same and the continuation is the same. + if ks1 == D.metaconts.getOrElse(ks, ks) && k1 == k => Binding { k2 => val stmts = mutable.ListBuffer.empty[js.Stmt] stmts.append(js.RawStmt("/* prepare tail call */")) @@ -336,13 +388,13 @@ object TransformerCps extends Transformer { } case cps.Stmt.Reset(prog, ks, k) => - pure(js.Return(Call(RESET, toJS(prog)(using clearDefinitions), toJS(ks), toJS(k)))) + pure(js.Return(Call(RESET, toJS(prog)(using nonrecursive(prog)), toJS(ks), toJS(k)))) case cps.Stmt.Shift(prompt, body, ks, k) => - pure(js.Return(Call(SHIFT, nameRef(prompt), noThunking { toJS(body)(using clearDefinitions) }, toJS(ks), toJS(k)))) + pure(js.Return(Call(SHIFT, nameRef(prompt), noThunking { toJS(body)(using nonrecursive(body)) }, toJS(ks), toJS(k)))) case cps.Stmt.Resume(r, b, ks2, k2) => - pure(js.Return(js.Call(RESUME, nameRef(r), toJS(b)(using clearDefinitions), toJS(ks2), toJS(k2)))) + pure(js.Return(js.Call(RESUME, nameRef(r), toJS(b)(using nonrecursive(b)), toJS(ks2), toJS(k2)))) case cps.Stmt.Hole() => pure(js.Return($effekt.call("hole"))) @@ -383,12 +435,55 @@ object TransformerCps extends Transformer { case _ => js.Call(nameRef(id), args.map(toJS)) } - def canInline(extern: cps.Extern): Boolean = extern match { + private def canInline(extern: cps.Extern): Boolean = extern match { case cps.Extern.Def(_, _, Nil, async, ExternBody.StringExternBody(_, Template(_, _))) => !async case _ => false } + // Predicates for Direct-Style Transformation + // ------------------------------------------ + + private def canBeDirect(k: Id, stmt: Stmt)(using T: TransformerContext): Boolean = + def notIn(term: Stmt | Block | Expr | (Id, Clause) | Cont) = + val freeVars = term match { + case s: Stmt => free(s) + case b: Block => free(b) + case p: Expr => free(p) + case (id, Clause(_, body)) => free(body) + case c: Cont => free(c) + } + !freeVars.contains(k) + stmt match { + case Stmt.Jump(k2, arg, ks2) if k2 == k => notIn(arg) && T.metacont.contains(ks2.id) + case Stmt.Jump(k2, arg, ks2) => notIn(arg) + // TODO this could be a tailcall! + case Stmt.App(callee, vargs, bargs, ks, k) => notIn(stmt) + case Stmt.Invoke(callee, method, vargs, bargs, ks, k2) => notIn(stmt) + case Stmt.If(cond, thn, els) => canBeDirect(k, thn) && canBeDirect(k, els) + case Stmt.Match(scrutinee, clauses, default) => clauses.forall { + case (id, Clause(vparams, body)) => canBeDirect(k, body) + } && default.forall(body => canBeDirect(k, body)) + case Stmt.LetDef(id, binding, body) => notIn(binding) && canBeDirect(k, body) + case Stmt.LetExpr(id, binding, body) => notIn(binding) && canBeDirect(k, body) + case Stmt.LetCont(id, Cont.ContLam(result, ks2, body), body2) => + def willBeDirectItself = canBeDirect(id, body2) && canBeDirect(k, body)(using directstyle(ks2)) + def notFreeinContinuation = notIn(body) && canBeDirect(k, body2) + willBeDirectItself || notFreeinContinuation + case Stmt.Region(id, ks, body) => notIn(body) + case Stmt.Alloc(id, init, region, body) => notIn(init) && canBeDirect(k, body) + case Stmt.Var(id, init, ks2, body) => notIn(init) && canBeDirect(k, body) + case Stmt.Dealloc(ref, body) => canBeDirect(k, body) + case Stmt.Get(ref, id, body) => canBeDirect(k, body) + case Stmt.Put(ref, value, body) => notIn(value) && canBeDirect(k, body) + case Stmt.Reset(prog, ks, k) => notIn(stmt) + case Stmt.Shift(prompt, body, ks, k) => notIn(stmt) + case Stmt.Resume(resumption, body, ks, k) => notIn(stmt) + case Stmt.Hole() => true + } + + + // Thunking // --------