Skip to content

Commit

Permalink
Generate loops in JS for obviously tail-calling functions
Browse files Browse the repository at this point in the history
  • Loading branch information
b-studios committed Jan 13, 2025
1 parent 3cbf9cf commit 2aeed28
Showing 1 changed file with 95 additions and 12 deletions.
107 changes: 95 additions & 12 deletions effekt/shared/src/main/scala/effekt/generator/js/TransformerCps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ package js

import effekt.context.Context
import effekt.context.assertions.*
import effekt.cps.*
import effekt.core.{ DeclarationContext, Id }
import effekt.cps.{ Block, * }
import effekt.core.{ Block, DeclarationContext, Id }

import scala.collection.mutable

object TransformerCps extends Transformer {

Expand All @@ -19,17 +21,43 @@ object TransformerCps extends Transformer {
val DEALLOC = Variable(JSName("DEALLOC"))
val TRAMPOLINE = Variable(JSName("TRAMPOLINE"))

class Used(var used: Boolean)
case class DefInfo(id: Id, vparams: List[Id], bparams: List[Id], ks: Id, k: Id, used: Used)

object DefInfo {
def unapply(b: cps.Block)(using C: TransformerContext): Option[(Id, List[Id], List[Id], Id, Id, Used)] = b match {
case cps.Block.BlockVar(id) => C.definitions.get(id) match {
case Some(DefInfo(id, vparams, bparams, ks, k, used)) => Some((id, vparams, bparams, ks, k, used))
case None => None
}
case _ => None
}
}

case class TransformerContext(
requiresThunk: Boolean,
bindings: Map[Id, js.Expr],
// definitions of externs (used to inline them)
externs: Map[Id, cps.Extern.Def],
declarations: DeclarationContext, // to be refactored
// currently, lexically enclosing functions and their parameters (used to determine whether a call is recursive, to rewrite into a loop)
definitions: Map[Id, DefInfo],
// the original declaration context (used to compile pattern matching)
declarations: DeclarationContext,
// the usual compiler context
errors: Context
)
implicit def autoContext(using C: TransformerContext): Context = C.errors

def lookup(id: Id)(using C: TransformerContext): js.Expr = C.bindings.getOrElse(id, nameRef(id))

def enterDefinition(id: Id, used: Used, block: cps.Block)(using C: TransformerContext): TransformerContext = block match {
case cps.BlockLit(vparams, bparams, ks, k, body) =>
C.copy(definitions = Map(id -> DefInfo(id, vparams, bparams, ks, k, used)))
case _ => C
}

def clearDefinitions(using C: TransformerContext): TransformerContext = C.copy(definitions = Map.empty)

def bindingAll[R](bs: List[(Id, js.Expr)])(body: TransformerContext ?=> R)(using C: TransformerContext): R =
body(using C.copy(bindings = C.bindings ++ bs))

Expand All @@ -50,6 +78,7 @@ object TransformerCps extends Transformer {
false,
Map.empty,
externs.collect { case d: Extern.Def => (d.id, d) }.toMap,
Map.empty,
D, C)

val name = JSName(jsModuleName(module.path))
Expand All @@ -71,14 +100,15 @@ object TransformerCps extends Transformer {
false,
Map.empty,
input.externs.collect { case d: Extern.Def => (d.id, d) }.toMap,
Map.empty,
D, C)

input.definitions.map(toJS)


def toJS(d: cps.ToplevelDefinition)(using TransformerContext): js.Stmt = d match {
case cps.ToplevelDefinition.Def(id, block) =>
js.Const(nameDef(id), requiringThunk { toJS(block) })
js.Const(nameDef(id), requiringThunk { toJS(id, block) })
case cps.ToplevelDefinition.Val(id, ks, k, binding) =>
js.Const(nameDef(id), Call(RUN_TOPLEVEL, js.Lambda(List(nameDef(ks), nameDef(k)), toJS(binding).stmts)))
case cps.ToplevelDefinition.Let(id, binding) =>
Expand Down Expand Up @@ -126,6 +156,23 @@ object TransformerCps extends Transformer {
Nil
}

def toJS(id: Id, b: cps.Block)(using TransformerContext): js.Expr = b match {
case cps.Block.BlockLit(vparams, bparams, ks, k, body) =>
val used = new Used(false)

// for now we add while everywhere, this should be done selectively...
val translatedBody = toJS(body)(using enterDefinition(id, used, b)).stmts

if used.used then
js.Lambda(vparams.map(nameDef) ++ bparams.map(nameDef) ++ List(nameDef(ks), nameDef(k)),
List(js.While(RawExpr("true"), translatedBody, Some(uniqueName(id)))))
else
js.Lambda(vparams.map(nameDef) ++ bparams.map(nameDef) ++ List(nameDef(ks), nameDef(k)),
translatedBody)

case other => toJS(other)(using clearDefinitions)
}

def toJS(b: cps.Block)(using TransformerContext): js.Expr = b match {
case cps.BlockVar(v) => nameRef(v)
case cps.Unbox(e) => toJS(e)
Expand All @@ -139,7 +186,7 @@ object TransformerCps extends Transformer {
case cps.Implementation(interface, operations) =>
js.Object(operations.map {
case cps.Operation(id, vps, bps, ks, k, body) =>
nameDef(id) -> js.Lambda(vps.map(nameDef) ++ bps.map(nameDef) ++ List(nameDef(ks), nameDef(k)), toJS(body).stmts)
nameDef(id) -> js.Lambda(vps.map(nameDef) ++ bps.map(nameDef) ++ List(nameDef(ks), nameDef(k)), toJS(body)(using clearDefinitions).stmts)
})
}

Expand All @@ -149,7 +196,7 @@ object TransformerCps extends Transformer {
case Cont.ContVar(id) =>
nameRef(id)
case Cont.ContLam(result, ks, body) =>
js.Lambda(List(nameDef(result), nameDef(ks)), toJS(body).stmts)
js.Lambda(List(nameDef(result), nameDef(ks)), toJS(body)(using clearDefinitions).stmts)
}

def toJS(e: cps.Expr)(using D: TransformerContext): js.Expr = e match {
Expand All @@ -165,9 +212,10 @@ object TransformerCps extends Transformer {
}

def toJS(s: cps.Stmt)(using D: TransformerContext): Binding[js.Stmt] = s match {

case cps.Stmt.LetDef(id, block, body) =>
Binding { k =>
js.Const(nameDef(id), requiringThunk { toJS(block) }) :: toJS(body).run(k)
js.Const(nameDef(id), requiringThunk { toJS(id, block) }) :: toJS(body).run(k)
}

case cps.Stmt.If(cond, thn, els) =>
Expand All @@ -180,7 +228,7 @@ object TransformerCps extends Transformer {

case cps.Stmt.LetCont(id, binding, body) =>
Binding { k =>
js.Const(nameDef(id), toJS(binding)) :: requiringThunk { toJS(body) }.run(k)
js.Const(nameDef(id), toJS(binding)) :: requiringThunk { toJS(body)(using clearDefinitions) }.run(k)
}

case cps.Stmt.Match(sc, Nil, None) =>
Expand All @@ -205,6 +253,40 @@ object TransformerCps extends Transformer {
case cps.Stmt.Jump(k, arg, ks) =>
pure(js.Return(maybeThunking(js.Call(nameRef(k), toJS(arg), toJS(ks)))))

case cps.Stmt.App(callee @ DefInfo(id, vparams, bparams, ks1, k1, used), vargs, bargs, MetaCont(ks), Cont.ContVar(k)) if ks1 == ks && k1 == k =>
Binding { k2 =>
val stmts = mutable.ListBuffer.empty[js.Stmt]
stmts.append(js.RawStmt("/* prepare tail call */"))

used.used = true

// const x3 = [[ arg ]]; ...
val vtmps = (vparams zip vargs).map { (id, arg) =>
val tmp = Id(id)
stmts.append(js.Const(nameDef(tmp), toJS(arg)))
tmp
}
val btmps = (bparams zip bargs).map { (id, arg) =>
val tmp = Id(id)
stmts.append(js.Const(nameDef(tmp), toJS(arg)))
tmp
}

// x = x3;
(vparams zip vtmps).foreach {
(param, tmp) => stmts.append(js.Assign(nameRef(param), nameRef(tmp)))
}
(bparams zip btmps).foreach {
(param, tmp) => stmts.append(js.Assign(nameRef(param), nameRef(tmp)))
}

// continue f;
val jump = js.Continue(Some(uniqueName(id)));

stmts.appendAll(k2(jump))
stmts.toList
}

case cps.Stmt.App(callee, vargs, bargs, ks, k) =>
pure(js.Return(maybeThunking(js.Call(toJS(callee), vargs.map(toJS) ++ bargs.map(toJS) ++ List(toJS(ks),
requiringThunk { toJS(k) })))))
Expand Down Expand Up @@ -255,13 +337,13 @@ object TransformerCps extends Transformer {
}

case cps.Stmt.Reset(prog, ks, k) =>
pure(js.Return(Call(RESET, toJS(prog), toJS(ks), toJS(k))))
pure(js.Return(Call(RESET, toJS(prog)(using clearDefinitions), toJS(ks), toJS(k))))

case cps.Stmt.Shift(prompt, body, ks, k) =>
pure(js.Return(Call(SHIFT, nameRef(prompt), noThunking { toJS(body) }, toJS(ks), toJS(k))))
pure(js.Return(Call(SHIFT, nameRef(prompt), noThunking { toJS(body)(using clearDefinitions) }, toJS(ks), toJS(k))))

case cps.Stmt.Resume(r, b, ks2, k2) =>
pure(js.Return(js.Call(RESUME, nameRef(r), toJS(b), toJS(ks2), toJS(k2))))
pure(js.Return(js.Call(RESUME, nameRef(r), toJS(b)(using clearDefinitions), toJS(ks2), toJS(k2))))

case cps.Stmt.Hole() =>
pure(js.Return($effekt.call("hole")))
Expand All @@ -284,7 +366,8 @@ object TransformerCps extends Transformer {
}

def toJS(d: cps.Def)(using T: TransformerContext): js.Stmt = d match {
case cps.Def(id, block) => js.Const(nameDef(id), requiringThunk { toJS(block) })
case cps.Def(id, block) =>
js.Const(nameDef(id), requiringThunk { toJS(id, block) })
}


Expand Down

0 comments on commit 2aeed28

Please sign in to comment.