diff --git a/effekt/shared/src/main/scala/effekt/generator/js/TransformerCps.scala b/effekt/shared/src/main/scala/effekt/generator/js/TransformerCps.scala index 7f1d7e09b..202b55b5b 100644 --- a/effekt/shared/src/main/scala/effekt/generator/js/TransformerCps.scala +++ b/effekt/shared/src/main/scala/effekt/generator/js/TransformerCps.scala @@ -4,8 +4,10 @@ package js import effekt.context.Context import effekt.context.assertions.* -import effekt.cps.* -import effekt.core.{ DeclarationContext, Id } +import effekt.cps.{ Block, * } +import effekt.core.{ Block, DeclarationContext, Id } + +import scala.collection.mutable object TransformerCps extends Transformer { @@ -19,17 +21,43 @@ object TransformerCps extends Transformer { val DEALLOC = Variable(JSName("DEALLOC")) val TRAMPOLINE = Variable(JSName("TRAMPOLINE")) + class Used(var used: Boolean) + case class DefInfo(id: Id, vparams: List[Id], bparams: List[Id], ks: Id, k: Id, used: Used) + + object DefInfo { + def unapply(b: cps.Block)(using C: TransformerContext): Option[(Id, List[Id], List[Id], Id, Id, Used)] = b match { + case cps.Block.BlockVar(id) => C.definitions.get(id) match { + case Some(DefInfo(id, vparams, bparams, ks, k, used)) => Some((id, vparams, bparams, ks, k, used)) + case None => None + } + case _ => None + } + } + case class TransformerContext( requiresThunk: Boolean, bindings: Map[Id, js.Expr], + // definitions of externs (used to inline them) externs: Map[Id, cps.Extern.Def], - declarations: DeclarationContext, // to be refactored + // currently, lexically enclosing functions and their parameters (used to determine whether a call is recursive, to rewrite into a loop) + definitions: Map[Id, DefInfo], + // the original declaration context (used to compile pattern matching) + declarations: DeclarationContext, + // the usual compiler context errors: Context ) implicit def autoContext(using C: TransformerContext): Context = C.errors def lookup(id: Id)(using C: TransformerContext): js.Expr = C.bindings.getOrElse(id, nameRef(id)) + def enterDefinition(id: Id, used: Used, block: cps.Block)(using C: TransformerContext): TransformerContext = block match { + case cps.BlockLit(vparams, bparams, ks, k, body) => + C.copy(definitions = Map(id -> DefInfo(id, vparams, bparams, ks, k, used))) + case _ => C + } + + def clearDefinitions(using C: TransformerContext): TransformerContext = C.copy(definitions = Map.empty) + def bindingAll[R](bs: List[(Id, js.Expr)])(body: TransformerContext ?=> R)(using C: TransformerContext): R = body(using C.copy(bindings = C.bindings ++ bs)) @@ -50,6 +78,7 @@ object TransformerCps extends Transformer { false, Map.empty, externs.collect { case d: Extern.Def => (d.id, d) }.toMap, + Map.empty, D, C) val name = JSName(jsModuleName(module.path)) @@ -71,6 +100,7 @@ object TransformerCps extends Transformer { false, Map.empty, input.externs.collect { case d: Extern.Def => (d.id, d) }.toMap, + Map.empty, D, C) input.definitions.map(toJS) @@ -78,7 +108,7 @@ object TransformerCps extends Transformer { def toJS(d: cps.ToplevelDefinition)(using TransformerContext): js.Stmt = d match { case cps.ToplevelDefinition.Def(id, block) => - js.Const(nameDef(id), requiringThunk { toJS(block) }) + js.Const(nameDef(id), requiringThunk { toJS(id, block) }) case cps.ToplevelDefinition.Val(id, ks, k, binding) => js.Const(nameDef(id), Call(RUN_TOPLEVEL, js.Lambda(List(nameDef(ks), nameDef(k)), toJS(binding).stmts))) case cps.ToplevelDefinition.Let(id, binding) => @@ -126,6 +156,23 @@ object TransformerCps extends Transformer { Nil } + def toJS(id: Id, b: cps.Block)(using TransformerContext): js.Expr = b match { + case cps.Block.BlockLit(vparams, bparams, ks, k, body) => + val used = new Used(false) + + // for now we add while everywhere, this should be done selectively... + val translatedBody = toJS(body)(using enterDefinition(id, used, b)).stmts + + if used.used then + js.Lambda(vparams.map(nameDef) ++ bparams.map(nameDef) ++ List(nameDef(ks), nameDef(k)), + List(js.While(RawExpr("true"), translatedBody, Some(uniqueName(id))))) + else + js.Lambda(vparams.map(nameDef) ++ bparams.map(nameDef) ++ List(nameDef(ks), nameDef(k)), + translatedBody) + + case other => toJS(other)(using clearDefinitions) + } + def toJS(b: cps.Block)(using TransformerContext): js.Expr = b match { case cps.BlockVar(v) => nameRef(v) case cps.Unbox(e) => toJS(e) @@ -139,7 +186,7 @@ object TransformerCps extends Transformer { case cps.Implementation(interface, operations) => js.Object(operations.map { case cps.Operation(id, vps, bps, ks, k, body) => - nameDef(id) -> js.Lambda(vps.map(nameDef) ++ bps.map(nameDef) ++ List(nameDef(ks), nameDef(k)), toJS(body).stmts) + nameDef(id) -> js.Lambda(vps.map(nameDef) ++ bps.map(nameDef) ++ List(nameDef(ks), nameDef(k)), toJS(body)(using clearDefinitions).stmts) }) } @@ -149,7 +196,7 @@ object TransformerCps extends Transformer { case Cont.ContVar(id) => nameRef(id) case Cont.ContLam(result, ks, body) => - js.Lambda(List(nameDef(result), nameDef(ks)), toJS(body).stmts) + js.Lambda(List(nameDef(result), nameDef(ks)), toJS(body)(using clearDefinitions).stmts) } def toJS(e: cps.Expr)(using D: TransformerContext): js.Expr = e match { @@ -165,9 +212,10 @@ object TransformerCps extends Transformer { } def toJS(s: cps.Stmt)(using D: TransformerContext): Binding[js.Stmt] = s match { + case cps.Stmt.LetDef(id, block, body) => Binding { k => - js.Const(nameDef(id), requiringThunk { toJS(block) }) :: toJS(body).run(k) + js.Const(nameDef(id), requiringThunk { toJS(id, block) }) :: toJS(body).run(k) } case cps.Stmt.If(cond, thn, els) => @@ -180,7 +228,7 @@ object TransformerCps extends Transformer { case cps.Stmt.LetCont(id, binding, body) => Binding { k => - js.Const(nameDef(id), toJS(binding)) :: requiringThunk { toJS(body) }.run(k) + js.Const(nameDef(id), toJS(binding)) :: requiringThunk { toJS(body)(using clearDefinitions) }.run(k) } case cps.Stmt.Match(sc, Nil, None) => @@ -205,6 +253,40 @@ object TransformerCps extends Transformer { case cps.Stmt.Jump(k, arg, ks) => pure(js.Return(maybeThunking(js.Call(nameRef(k), toJS(arg), toJS(ks))))) + case cps.Stmt.App(callee @ DefInfo(id, vparams, bparams, ks1, k1, used), vargs, bargs, MetaCont(ks), Cont.ContVar(k)) if ks1 == ks && k1 == k => + Binding { k2 => + val stmts = mutable.ListBuffer.empty[js.Stmt] + stmts.append(js.RawStmt("/* prepare tail call */")) + + used.used = true + + // const x3 = [[ arg ]]; ... + val vtmps = (vparams zip vargs).map { (id, arg) => + val tmp = Id(id) + stmts.append(js.Const(nameDef(tmp), toJS(arg))) + tmp + } + val btmps = (bparams zip bargs).map { (id, arg) => + val tmp = Id(id) + stmts.append(js.Const(nameDef(tmp), toJS(arg))) + tmp + } + + // x = x3; + (vparams zip vtmps).foreach { + (param, tmp) => stmts.append(js.Assign(nameRef(param), nameRef(tmp))) + } + (bparams zip btmps).foreach { + (param, tmp) => stmts.append(js.Assign(nameRef(param), nameRef(tmp))) + } + + // continue f; + val jump = js.Continue(Some(uniqueName(id))); + + stmts.appendAll(k2(jump)) + stmts.toList + } + case cps.Stmt.App(callee, vargs, bargs, ks, k) => pure(js.Return(maybeThunking(js.Call(toJS(callee), vargs.map(toJS) ++ bargs.map(toJS) ++ List(toJS(ks), requiringThunk { toJS(k) }))))) @@ -255,13 +337,13 @@ object TransformerCps extends Transformer { } case cps.Stmt.Reset(prog, ks, k) => - pure(js.Return(Call(RESET, toJS(prog), toJS(ks), toJS(k)))) + pure(js.Return(Call(RESET, toJS(prog)(using clearDefinitions), toJS(ks), toJS(k)))) case cps.Stmt.Shift(prompt, body, ks, k) => - pure(js.Return(Call(SHIFT, nameRef(prompt), noThunking { toJS(body) }, toJS(ks), toJS(k)))) + pure(js.Return(Call(SHIFT, nameRef(prompt), noThunking { toJS(body)(using clearDefinitions) }, toJS(ks), toJS(k)))) case cps.Stmt.Resume(r, b, ks2, k2) => - pure(js.Return(js.Call(RESUME, nameRef(r), toJS(b), toJS(ks2), toJS(k2)))) + pure(js.Return(js.Call(RESUME, nameRef(r), toJS(b)(using clearDefinitions), toJS(ks2), toJS(k2)))) case cps.Stmt.Hole() => pure(js.Return($effekt.call("hole"))) @@ -284,7 +366,8 @@ object TransformerCps extends Transformer { } def toJS(d: cps.Def)(using T: TransformerContext): js.Stmt = d match { - case cps.Def(id, block) => js.Const(nameDef(id), requiringThunk { toJS(block) }) + case cps.Def(id, block) => + js.Const(nameDef(id), requiringThunk { toJS(id, block) }) }