Skip to content

Commit

Permalink
Make local continuations direct style if possible (#777)
Browse files Browse the repository at this point in the history
This PR tries to remove the overhead of CPS (and trampolining) for
direct style code.

Essentially, for the following code in CPS

```
let cont k(x, ks) = ...
if (cond) {
  k(42, ks')
} else {
  k(43, ks')
}
```

we are now generating the following JS:

```javascript
let x;
if (cond) {
  x = 42
} else {
  x = 43
};
...
```

This only works under the assumption that:

- `k` is not used somewhere in a more first-class manner (under a
different context, as part of a continuation passed to a function, etc.)
- `k` is always called with the currently in-scope metacontinuation
`ks'` (so that we can avoid assigning it -- this turns out to be
essential for the feature interaction with loops)

The interaction with loops was difficult to get right:

- turning recursive functions into direct style loops makes additional
continuations more local
- making a continuation more local turns more functions into loops


As a result, the function `parse_worker` (from the `parsing_dollars`
benchmark) is now a tight loop.

**Before the PR**

```javascript
function parse_worker_0(a_1, ks_7, k_4) {
      const x_0 = i_0.value;
      function k_3(c_0, ks_4) {
        if (c_0 === (36)) {
          return () => parse_worker_0((a_1 + (1)), ks_4, k_4);
        } else if (c_0 === (10)) {
          const x_1 = s_0.value;
          s_0.value = (x_1 + a_1);
          return () => parse_worker_0(0, ks_4, k_4);
        } else {
          return SHIFT(p_0, (k_5, ks_5, k_6) => k_6($effekt.unit, ks_5), ks_4, k_4);
        }
      }
      if ((x_0 > n_0)) {
        return SHIFT(p_0, (k_7, ks_6, k_8) => k_8($effekt.unit, ks_6), ks_7, (v_r_1, ks_8) =>
          $effekt.emptyMatch());
      } else {
        const x_2 = j_0.value;
        if (x_2 === (0)) {
          const x_3 = i_0.value;
          i_0.value = (x_3 + (1));
          const x_4 = i_0.value;
          j_0.value = x_4;
          return () => k_3(10, ks_7);
        } else {
          const x_5 = j_0.value;
          j_0.value = (x_5 - (1));
          return () => k_3(36, ks_7);
        }
      }
    }
```

**After the PR**

```javascript
function parse_worker_0(a_9, ks_293, k_236) {
      parse_worker_0: while (true) {
        const x_11 = i_13.value;
        let c_0 = undefined;
        if ((x_11 > n_17)) {
          return SHIFT(p_32, (k_232, ks_292, k_233) =>
            k_233($effekt.unit, ks_292), ks_293, (v_r_95, ks_294) =>
            $effekt.emptyMatch());
        } else {
          const x_12 = j_0.value;
          if (x_12 === (0)) {
            const x_13 = i_13.value;
            i_13.value = (x_13 + (1));
            const x_14 = i_13.value;
            j_0.value = x_14;
            c_0 = 10;
          } else {
            const x_15 = j_0.value;
            j_0.value = (x_15 - (1));
            c_0 = 36;
          }
        }
        if (c_0 === (36)) {
          /* prepare tail call */
          const a_8 = (a_9 + (1));
          a_9 = a_8;
          continue parse_worker_0;
        } else if (c_0 === (10)) {
          const x_16 = s_4.value;
          s_4.value = (x_16 + a_9);
          /* prepare tail call */
          const a_10 = 0;
          a_9 = a_10;
          continue parse_worker_0;
        } else {
          return SHIFT(p_32, (k_234, ks_295, k_235) =>
            k_235($effekt.unit, ks_295), ks_293, k_236);
        }
      }
    }
```

Additionally, here are a few preliminary benchmark results:


![image](https://github.com/user-attachments/assets/6d3bd21a-fdcf-4149-8e0a-78ba48c4c6d2)
  • Loading branch information
b-studios authored Jan 14, 2025
1 parent 957517a commit df1adc5
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,39 @@ object Normalizer { normal =>

case Stmt.Val(id, tpe, binding, body) =>

def normalizeVal(id: Id, tpe: ValueType, binding: Stmt, body: Stmt): Stmt = binding match {
// def barendregt(stmt: Stmt): Stmt = new Renamer().apply(stmt)

def normalizeVal(id: Id, tpe: ValueType, binding: Stmt, body: Stmt): Stmt = normalize(binding) match {

// [[ val x = sc match { case id(ps) => body2 }; body ]] = sc match { case id(ps) => val x = body2; body }
case Stmt.Match(sc, List((id2, BlockLit(tparams2, cparams2, vparams2, bparams2, body2))), None) =>
Stmt.Match(sc, List((id2, BlockLit(tparams2, cparams2, vparams2, bparams2,
normalizeVal(id, tpe, body2, body)))), None)

// These rewrites do not seem to contribute a lot given their complexity...
// vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv

// [[ val x = if (cond) { thn } else { els }; body ]] = if (cond) { [[ val x = thn; body ]] } else { [[ val x = els; body ]] }
// case normalized @ Stmt.If(cond, thn, els) if body.size <= 2 =>
// // since we duplicate the body, we need to freshen the names
// val normalizedThn = barendregt(normalizeVal(id, tpe, thn, body))
// val normalizedEls = barendregt(normalizeVal(id, tpe, els, body))
//
// Stmt.If(cond, normalizedThn, normalizedEls)
//
// case Stmt.Match(sc, clauses, default)
// // necessary since otherwise we lose Nothing-boxing
// // vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
// if body.size <= 2 && (clauses.size + default.size) >= 1 =>
// val normalizedClauses = clauses map {
// case (id2, BlockLit(tparams2, cparams2, vparams2, bparams2, body2)) =>
// (id2, BlockLit(tparams2, cparams2, vparams2, bparams2, barendregt(normalizeVal(id, tpe, body2, body))): BlockLit)
// }
// val normalizedDefault = default map { stmt => barendregt(normalizeVal(id, tpe, stmt, body)) }
// Stmt.Match(sc, normalizedClauses, normalizedDefault)

// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

// [[ val x = return e; s ]] = let x = [[ e ]]; [[ s ]]
case Stmt.Return(expr2) =>
Stmt.Let(id, tpe, expr2, normalize(body)(using C.bind(id, expr2)))
Expand Down Expand Up @@ -245,7 +271,7 @@ object Normalizer { normal =>
case normalizedBody => Stmt.Val(id, tpe, other, normalizedBody)
}
}
normalizeVal(id, tpe, normalize(binding), body)
normalizeVal(id, tpe, binding, body)


// "Congruences"
Expand Down
185 changes: 142 additions & 43 deletions effekt/shared/src/main/scala/effekt/generator/js/TransformerCps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import effekt.context.Context
import effekt.context.assertions.*
import effekt.cps.*
import effekt.core.{ DeclarationContext, Id }
import effekt.cps.Variables.{ all, free }

import scala.collection.mutable

Expand All @@ -21,45 +22,31 @@ object TransformerCps extends Transformer {
val DEALLOC = Variable(JSName("DEALLOC"))
val TRAMPOLINE = Variable(JSName("TRAMPOLINE"))

class Used(var used: Boolean)
case class DefInfo(id: Id, vparams: List[Id], bparams: List[Id], ks: Id, k: Id, used: Used)

object DefInfo {
def unapply(b: cps.Block)(using C: TransformerContext): Option[(Id, List[Id], List[Id], Id, Id, Used)] = b match {
case cps.Block.BlockVar(id) => C.definitions.get(id) match {
case Some(DefInfo(id, vparams, bparams, ks, k, used)) => Some((id, vparams, bparams, ks, k, used))
case None => None
}
case _ => None
}
}
class RecursiveUsage(var jumped: Boolean)
case class RecursiveDefInfo(id: Id, vparams: List[Id], bparams: List[Id], ks: Id, k: Id, used: RecursiveUsage)
case class ContinuationInfo(k: Id, param: Id, ks: Id)

case class TransformerContext(
requiresThunk: Boolean,
// known definitions of expressions (used to inline into externs)
bindings: Map[Id, js.Expr],
// definitions of externs (used to inline them)
externs: Map[Id, cps.Extern.Def],
// currently, lexically enclosing functions and their parameters (used to determine whether a call is recursive, to rewrite into a loop)
definitions: Map[Id, DefInfo],
// the innermost (in direct style) enclosing functions (used to rewrite a definition to a loop)
recursive: Option[RecursiveDefInfo],
// the direct-style continuation, if available (used in case cps.Stmt.LetCont)
directStyle: Option[ContinuationInfo],
// the current direct-style metacontinuation
metacont: Option[Id],
// substitutions for renaming of metaconts (to avoid rebinding them)
metaconts: Map[Id, Id],
// the original declaration context (used to compile pattern matching)
declarations: DeclarationContext,
// the usual compiler context
errors: Context
)
implicit def autoContext(using C: TransformerContext): Context = C.errors

def lookup(id: Id)(using C: TransformerContext): js.Expr = C.bindings.getOrElse(id, nameRef(id))

def enterDefinition(id: Id, used: Used, block: cps.Block)(using C: TransformerContext): TransformerContext = block match {
case cps.BlockLit(vparams, bparams, ks, k, body) =>
C.copy(definitions = Map(id -> DefInfo(id, vparams, bparams, ks, k, used)))
case _ => C
}

def clearDefinitions(using C: TransformerContext): TransformerContext = C.copy(definitions = Map.empty)

def bindingAll[R](bs: List[(Id, js.Expr)])(body: TransformerContext ?=> R)(using C: TransformerContext): R =
body(using C.copy(bindings = C.bindings ++ bs))

/**
* Entrypoint used by the compiler to compile whole programs
Expand All @@ -78,6 +65,9 @@ object TransformerCps extends Transformer {
false,
Map.empty,
externs.collect { case d: Extern.Def => (d.id, d) }.toMap,
None,
None,
None,
Map.empty,
D, C)

Expand All @@ -100,6 +90,9 @@ object TransformerCps extends Transformer {
false,
Map.empty,
input.externs.collect { case d: Extern.Def => (d.id, d) }.toMap,
None,
None,
None,
Map.empty,
D, C)

Expand Down Expand Up @@ -158,18 +151,18 @@ object TransformerCps extends Transformer {

def toJS(id: Id, b: cps.Block)(using TransformerContext): js.Expr = b match {
case cps.Block.BlockLit(vparams, bparams, ks, k, body) =>
val used = new Used(false)
val used = new RecursiveUsage(false)

val translatedBody = toJS(body)(using enterDefinition(id, used, b)).stmts
val translatedBody = toJS(body)(using recursive(id, used, b)).stmts

if used.used then
if used.jumped then
js.Lambda(vparams.map(nameDef) ++ bparams.map(nameDef) ++ List(nameDef(ks), nameDef(k)),
List(js.While(RawExpr("true"), translatedBody, Some(uniqueName(id)))))
else
js.Lambda(vparams.map(nameDef) ++ bparams.map(nameDef) ++ List(nameDef(ks), nameDef(k)),
translatedBody)

case other => toJS(other)(using clearDefinitions)
case other => toJS(other)
}

def toJS(b: cps.Block)(using TransformerContext): js.Expr = b match {
Expand All @@ -185,17 +178,18 @@ object TransformerCps extends Transformer {
case cps.Implementation(interface, operations) =>
js.Object(operations.map {
case cps.Operation(id, vps, bps, ks, k, body) =>
nameDef(id) -> js.Lambda(vps.map(nameDef) ++ bps.map(nameDef) ++ List(nameDef(ks), nameDef(k)), toJS(body)(using clearDefinitions).stmts)
nameDef(id) -> js.Lambda(vps.map(nameDef) ++ bps.map(nameDef) ++ List(nameDef(ks), nameDef(k)), toJS(body)(using nonrecursive(ks)).stmts)
})
}

def toJS(ks: cps.MetaCont): js.Expr = nameRef(ks.id)
def toJS(ks: cps.MetaCont)(using T: TransformerContext): js.Expr =
nameRef(T.metaconts.getOrElse(ks.id, ks.id))

def toJS(k: cps.Cont)(using T: TransformerContext): js.Expr = k match {
case Cont.ContVar(id) =>
nameRef(id)
case Cont.ContLam(result, ks, body) =>
js.Lambda(List(nameDef(result), nameDef(ks)), toJS(body)(using clearDefinitions).stmts)
js.Lambda(List(nameDef(result), nameDef(ks)), toJS(body)(using nonrecursive(ks)).stmts)
}

def toJS(e: cps.Expr)(using D: TransformerContext): js.Expr = e match {
Expand Down Expand Up @@ -225,9 +219,18 @@ object TransformerCps extends Transformer {
js.Const(nameDef(id), toJS(binding)) :: toJS(body).run(k)
}

case cps.Stmt.LetCont(id, binding, body) =>
// [[ let k(x, ks) = ...; if (...) jump k(42, ks2) else jump k(10, ks3) ]] =
// let x; if (...) { x = 42; ks = ks2 } else { x = 10; ks = ks3 } ...
case cps.Stmt.LetCont(id, Cont.ContLam(param, ks, body), body2) if canBeDirect(id, body2) =>
Binding { k =>
js.Const(nameDef(id), toJS(binding)) :: requiringThunk { toJS(body)(using clearDefinitions) }.run(k)
js.Let(nameDef(param), js.Undefined) ::
toJS(body2)(using withDirectStyle(id, param, ks)).stmts ++
toJS(body)(using directstyle(ks)).run(k)
}

case cps.Stmt.LetCont(id, binding @ Cont.ContLam(result2, ks2, body2), body) =>
Binding { k =>
js.Const(nameDef(id), toJS(binding)(using nonrecursive(ks2))) :: requiringThunk { toJS(body) }.run(k)
}

case cps.Stmt.Match(sc, Nil, None) =>
Expand All @@ -245,19 +248,30 @@ object TransformerCps extends Transformer {
pure(js.Switch(js.Member(scrutinee, `tag`),
clauses.map { case (tag, clause) =>
val (e, binding) = toJS(scrutinee, tag, clause);
(e, binding.stmts)

val stmts = binding.stmts

stmts.last match {
case terminator : (js.Stmt.Return | js.Stmt.Break | js.Stmt.Continue) => (e, stmts)
case other => (e, stmts :+ js.Break())
}
},
default.map { s => toJS(s).stmts }))

case cps.Stmt.Jump(k, arg, ks) if D.directStyle.exists(c => c.k == k) => D.directStyle match {
case Some(ContinuationInfo(k2, param2, ks2)) => pure(js.Assign(nameRef(param2), toJS(arg)))
case None => sys error "Should not happen"
}

case cps.Stmt.Jump(k, arg, ks) =>
pure(js.Return(maybeThunking(js.Call(nameRef(k), toJS(arg), toJS(ks)))))

case cps.Stmt.App(callee @ DefInfo(id, vparams, bparams, ks1, k1, used), vargs, bargs, MetaCont(ks), Cont.ContVar(k)) if ks1 == ks && k1 == k =>
case cps.Stmt.App(Recursive(id, vparams, bparams, ks1, k1, used), vargs, bargs, MetaCont(ks), Cont.ContVar(k)) if sameScope(ks, k, ks1, k1) =>
Binding { k2 =>
val stmts = mutable.ListBuffer.empty[js.Stmt]
stmts.append(js.RawStmt("/* prepare tail call */"))

used.used = true
used.jumped = true

// const x3 = [[ arg ]]; ...
val vtmps = (vparams zip vargs).map { (id, arg) =>
Expand Down Expand Up @@ -336,13 +350,13 @@ object TransformerCps extends Transformer {
}

case cps.Stmt.Reset(prog, ks, k) =>
pure(js.Return(Call(RESET, toJS(prog)(using clearDefinitions), toJS(ks), toJS(k))))
pure(js.Return(Call(RESET, toJS(prog)(using nonrecursive(prog)), toJS(ks), toJS(k))))

case cps.Stmt.Shift(prompt, body, ks, k) =>
pure(js.Return(Call(SHIFT, nameRef(prompt), noThunking { toJS(body)(using clearDefinitions) }, toJS(ks), toJS(k))))
pure(js.Return(Call(SHIFT, nameRef(prompt), noThunking { toJS(body)(using nonrecursive(body)) }, toJS(ks), toJS(k))))

case cps.Stmt.Resume(r, b, ks2, k2) =>
pure(js.Return(js.Call(RESUME, nameRef(r), toJS(b)(using clearDefinitions), toJS(ks2), toJS(k2))))
pure(js.Return(js.Call(RESUME, nameRef(r), toJS(b)(using nonrecursive(b)), toJS(ks2), toJS(k2))))

case cps.Stmt.Hole() =>
pure(js.Return($effekt.call("hole")))
Expand Down Expand Up @@ -373,7 +387,7 @@ object TransformerCps extends Transformer {
// Inlining Externs
// ----------------

def inlineExtern(id: Id, args: List[cps.Pure])(using T: TransformerContext): js.Expr =
private def inlineExtern(id: Id, args: List[cps.Pure])(using T: TransformerContext): js.Expr =
T.externs.get(id) match {
case Some(cps.Extern.Def(id, params, Nil, async,
ExternBody.StringExternBody(featureFlag, Template(strings, templateArgs)))) if !async =>
Expand All @@ -383,11 +397,96 @@ object TransformerCps extends Transformer {
case _ => js.Call(nameRef(id), args.map(toJS))
}

def canInline(extern: cps.Extern): Boolean = extern match {
private def canInline(extern: cps.Extern): Boolean = extern match {
case cps.Extern.Def(_, _, Nil, async, ExternBody.StringExternBody(_, Template(_, _))) => !async
case _ => false
}

private def bindingAll[R](bs: List[(Id, js.Expr)])(body: TransformerContext ?=> R)(using C: TransformerContext): R =
body(using C.copy(bindings = C.bindings ++ bs))

private def lookup(id: Id)(using C: TransformerContext): js.Expr = C.bindings.getOrElse(id, nameRef(id))


// Helpers for Direct-Style Transformation
// ---------------------------------------

/**
 * Decides whether a call site passes the very same continuations as the enclosing
 * function definition.
 *
 * The call-site metacontinuation `ks` is first resolved through the substitution
 * recorded in `C.metaconts` and then compared with the definition's `ks1`; the
 * call-site continuation `k` is compared with the definition's `k1` as is.
 */
private def sameScope(ks: Id, k: Id, ks1: Id, k1: Id)(using C: TransformerContext): Boolean = {
  // a metacontinuation without a recorded substitution stands for itself
  val resolvedMetacont = C.metaconts.getOrElse(ks, ks)
  val sameMetacont = ks1 == resolvedMetacont
  sameMetacont && k1 == k
}

/**
 * Returns a copy of the context in which the continuation `id` (with parameter `param`
 * and metacontinuation `ks`) is registered as the current direct-style continuation.
 */
private def withDirectStyle(id: Id, param: Id, ks: Id)(using C: TransformerContext): TransformerContext = {
  val continuation = ContinuationInfo(id, param, ks)
  C.copy(directStyle = Some(continuation))
}

/**
 * Enters the definition `id`: for a block literal, the context remembers its parameters
 * and continuations as the current recursive definition, drops any direct-style
 * continuation, and makes the literal's `ks` the current metacontinuation.
 * Any other kind of block leaves the context unchanged.
 */
private def recursive(id: Id, used: RecursiveUsage, block: cps.Block)(using C: TransformerContext): TransformerContext = block match {
  case cps.BlockLit(vparams, bparams, ks, k, _) =>
    val info = RecursiveDefInfo(id, vparams, bparams, ks, k, used)
    C.copy(
      recursive = Some(info),
      directStyle = None,
      metacont = Some(ks)
    )
  case _ =>
    C
}

/**
 * Returns a copy of the context without an enclosing recursive definition and without
 * a direct-style continuation; `ks` becomes the current metacontinuation.
 */
private def nonrecursive(ks: Id)(using C: TransformerContext): TransformerContext = {
  val current = Some(ks)
  C.copy(metacont = current, recursive = None, directStyle = None)
}

/**
 * Variant for block literals: clears the recursive definition and the direct-style
 * continuation, and makes the literal's own `ks` the current metacontinuation.
 */
private def nonrecursive(block: cps.BlockLit)(using C: TransformerContext): TransformerContext =
  C.copy(recursive = None, directStyle = None, metacont = Some(block.ks))

// ks | let k1 x1 ks1 = { let k2 x2 ks2 = jump k v ks2 }; ... = jump k v ks
/**
 * Enters the body of a continuation that is compiled in direct style.
 *
 * The continuation's metacontinuation parameter `ks` is mapped (in `metaconts`) to the
 * already-substituted enclosing metacontinuation, and `ks` becomes the current one.
 * Fails with a runtime error if the context has no current metacontinuation.
 */
private def directstyle(ks: Id)(using C: TransformerContext): TransformerContext = {
  val enclosing = C.metacont match {
    case Some(current) => current
    case None => sys.error("Metacontinuation missing...")
  }
  // follow an existing renaming so that `ks` points at the outermost name
  val target = C.metaconts.getOrElse(enclosing, enclosing)
  C.copy(
    metacont = Some(ks),
    metaconts = C.metaconts.updated(ks, target)
  )
}

/**
 * Extractor that matches a block variable referring to the recursive definition
 * currently stored in the context, exposing its parameters, continuations and
 * usage flag.
 */
private object Recursive {
  def unapply(b: cps.Block)(using C: TransformerContext): Option[(Id, List[Id], List[Id], Id, Id, RecursiveUsage)] = b match {
    case cps.Block.BlockVar(id) =>
      C.recursive
        .filter(info => id == info.id)
        .map(info => (id, info.vparams, info.bparams, info.ks, info.k, info.used))
    case _ =>
      None
  }
}

/**
 * Checks whether every use of the continuation `k` inside `stmt` has a shape that allows
 * `k` to be compiled in direct style.
 *
 * A jump to `k` is accepted only if its argument does not mention `k` and it passes the
 * metacontinuation that is current in the context. Wherever `k` could escape (arguments,
 * bindings, calls, control operators), it must not occur free at all. Statements with
 * sub-statements are checked recursively.
 */
private def canBeDirect(k: Id, stmt: Stmt)(using T: TransformerContext): Boolean =
// true iff `k` does not occur free in the given term
def notIn(term: Stmt | Block | Expr | (Id, Clause) | Cont) =
val freeVars = term match {
case s: Stmt => free(s)
case b: Block => free(b)
case p: Expr => free(p)
case (id, Clause(_, body)) => free(body)
case c: Cont => free(c)
}
!freeVars.contains(k)
stmt match {
// jump to `k` itself: the metacontinuation passed must be the current one
case Stmt.Jump(k2, arg, ks2) if k2 == k => notIn(arg) && T.metacont.contains(ks2.id)
// jump to some other continuation: only the argument is inspected
case Stmt.Jump(k2, arg, ks2) => notIn(arg)
// TODO this could be a tailcall!
case Stmt.App(callee, vargs, bargs, ks, k) => notIn(stmt)
case Stmt.Invoke(callee, method, vargs, bargs, ks, k2) => notIn(stmt)
// branching: every branch has to qualify
case Stmt.If(cond, thn, els) => canBeDirect(k, thn) && canBeDirect(k, els)
case Stmt.Match(scrutinee, clauses, default) => clauses.forall {
case (id, Clause(vparams, body)) => canBeDirect(k, body)
} && default.forall(body => canBeDirect(k, body))
case Stmt.LetDef(id, binding, body) => notIn(binding) && canBeDirect(k, body)
case Stmt.LetExpr(id, binding, body) => notIn(binding) && canBeDirect(k, body)
case Stmt.LetCont(id, Cont.ContLam(result, ks2, body), body2) =>
// either the nested continuation `id` qualifies for direct style itself and `k` qualifies
// within its body (checked under the context produced by `directstyle(ks2)`) ...
def willBeDirectItself = canBeDirect(id, body2) && canBeDirect(k, body)(using directstyle(ks2))
// ... or `k` is not free in the nested continuation's body and qualifies in the rest
def notFreeinContinuation = notIn(body) && canBeDirect(k, body2)
willBeDirectItself || notFreeinContinuation
// `k` must not occur anywhere inside a region body
case Stmt.Region(id, ks, body) => notIn(body)
case Stmt.Alloc(id, init, region, body) => notIn(init) && canBeDirect(k, body)
case Stmt.Var(id, init, ks2, body) => notIn(init) && canBeDirect(k, body)
case Stmt.Dealloc(ref, body) => canBeDirect(k, body)
case Stmt.Get(ref, id, body) => canBeDirect(k, body)
case Stmt.Put(ref, value, body) => notIn(value) && canBeDirect(k, body)
// control operators: `k` must not occur free in the statement at all
case Stmt.Reset(prog, ks, k) => notIn(stmt)
case Stmt.Shift(prompt, body, ks, k) => notIn(stmt)
case Stmt.Resume(resumption, body, ks, k) => notIn(stmt)
case Stmt.Hole() => true
}



// Thunking
// --------
Expand Down

0 comments on commit df1adc5

Please sign in to comment.