Skip to content

Commit

Permalink
Make local continuations direct style if possible
Browse files Browse the repository at this point in the history
  • Loading branch information
b-studios committed Jan 14, 2025
1 parent 112bc87 commit 0426238
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,39 @@ object Normalizer { normal =>

case Stmt.Val(id, tpe, binding, body) =>

def normalizeVal(id: Id, tpe: ValueType, binding: Stmt, body: Stmt): Stmt = binding match {
def barendregdt(stmt: Stmt): Stmt = new Renamer().apply(stmt)

def normalizeVal(id: Id, tpe: ValueType, binding: Stmt, body: Stmt): Stmt = normalize(binding) match {

// [[ val x = sc match { case id(ps) => body2 }; body ]] = sc match { case id(ps) => val x = body2; body }
case Stmt.Match(sc, List((id2, BlockLit(tparams2, cparams2, vparams2, bparams2, body2))), None) =>
Stmt.Match(sc, List((id2, BlockLit(tparams2, cparams2, vparams2, bparams2,
normalizeVal(id, tpe, body2, body)))), None)

// These rewrites do not seem to contribute a lot given their complexity...
// vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv

// [[ val x = if (cond) { thn } else { els }; body ]] = if (cond) { [[ val x = thn; body ]] } else { [[ val x = els; body ]] }
// case normalized @ Stmt.If(cond, thn, els) if body.size <= 2 =>
// // since we duplicate the body, we need to freshen the names
// val normalizedThn = barendregdt(normalizeVal(id, tpe, thn, body))
// val normalizedEls = barendregdt(normalizeVal(id, tpe, els, body))
//
// Stmt.If(cond, normalizedThn, normalizedEls)
//
// case Stmt.Match(sc, clauses, default)
// // necessary since otherwise we loose Nothing-boxing
// // vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
// if body.size <= 2 && (clauses.size + default.size) >= 1 =>
// val normalizedClauses = clauses map {
// case (id2, BlockLit(tparams2, cparams2, vparams2, bparams2, body2)) =>
// (id2, BlockLit(tparams2, cparams2, vparams2, bparams2, barendregdt(normalizeVal(id, tpe, body2, body))): BlockLit)
// }
// val normalizedDefault = default map { stmt => barendregdt(normalizeVal(id, tpe, stmt, body)) }
// Stmt.Match(sc, normalizedClauses, normalizedDefault)

// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

// [[ val x = return e; s ]] = let x = [[ e ]]; [[ s ]]
case Stmt.Return(expr2) =>
Stmt.Let(id, tpe, expr2, normalize(body)(using C.bind(id, expr2)))
Expand Down Expand Up @@ -245,7 +271,7 @@ object Normalizer { normal =>
case normalizedBody => Stmt.Val(id, tpe, other, normalizedBody)
}
}
normalizeVal(id, tpe, normalize(binding), body)
normalizeVal(id, tpe, binding, body)


// "Congruences"
Expand Down
127 changes: 111 additions & 16 deletions effekt/shared/src/main/scala/effekt/generator/js/TransformerCps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import effekt.context.Context
import effekt.context.assertions.*
import effekt.cps.*
import effekt.core.{ DeclarationContext, Id }
import effekt.cps.Variables.{ all, free }

import scala.collection.mutable

Expand Down Expand Up @@ -34,13 +35,21 @@ object TransformerCps extends Transformer {
}
}

case class ContinuationInfo(k: Id, param: Id, ks: Id)

case class TransformerContext(
requiresThunk: Boolean,
bindings: Map[Id, js.Expr],
// definitions of externs (used to inline them)
externs: Map[Id, cps.Extern.Def],
// currently, lexically enclosing functions and their parameters (used to determine whether a call is recursive, to rewrite into a loop)
definitions: Map[Id, DefInfo],
// the direct-style continuation, if available
directStyle: Option[ContinuationInfo],
// the current (direct-style) metacontinuation
metacont: Option[Id],
// substitutions for renaming of metaconts
metaconts: Map[Id, Id],
// the original declaration context (used to compile pattern matching)
declarations: DeclarationContext,
// the usual compiler context
Expand All @@ -50,13 +59,23 @@ object TransformerCps extends Transformer {

def lookup(id: Id)(using C: TransformerContext): js.Expr = C.bindings.getOrElse(id, nameRef(id))

def enterDefinition(id: Id, used: Used, block: cps.Block)(using C: TransformerContext): TransformerContext = block match {
def recursive(id: Id, used: Used, block: cps.Block)(using C: TransformerContext): TransformerContext = block match {
case cps.BlockLit(vparams, bparams, ks, k, body) =>
C.copy(definitions = Map(id -> DefInfo(id, vparams, bparams, ks, k, used)))
C.copy(definitions = Map(id -> DefInfo(id, vparams, bparams, ks, k, used)), directStyle = None, metacont = Some(ks))
case _ => C
}

def clearDefinitions(using C: TransformerContext): TransformerContext = C.copy(definitions = Map.empty)
def nonrecursive(ks: Id)(using C: TransformerContext): TransformerContext =
C.copy(definitions = Map.empty, directStyle = None, metacont = Some(ks))

def nonrecursive(block: cps.BlockLit)(using C: TransformerContext): TransformerContext = nonrecursive(block.ks)

// ks | let k1 x1 ks1 = { let k2 x2 ks2 = jump k v ks2 }; ... = jump k v ks
def directstyle(ks: Id)(using C: TransformerContext): TransformerContext =
val outer = C.metacont.getOrElse { sys error "Metacontinuation missing..." }
val outerSubstituted = C.metaconts.getOrElse(outer, outer)
val subst = C.metaconts.updated(ks, outerSubstituted)
C.copy(metacont = Some(ks), metaconts = subst)

def bindingAll[R](bs: List[(Id, js.Expr)])(body: TransformerContext ?=> R)(using C: TransformerContext): R =
body(using C.copy(bindings = C.bindings ++ bs))
Expand All @@ -79,6 +98,9 @@ object TransformerCps extends Transformer {
Map.empty,
externs.collect { case d: Extern.Def => (d.id, d) }.toMap,
Map.empty,
None,
None,
Map.empty,
D, C)

val name = JSName(jsModuleName(module.path))
Expand All @@ -101,6 +123,9 @@ object TransformerCps extends Transformer {
Map.empty,
input.externs.collect { case d: Extern.Def => (d.id, d) }.toMap,
Map.empty,
None,
None,
Map.empty,
D, C)

input.definitions.map(toJS)
Expand Down Expand Up @@ -160,7 +185,7 @@ object TransformerCps extends Transformer {
case cps.Block.BlockLit(vparams, bparams, ks, k, body) =>
val used = new Used(false)

val translatedBody = toJS(body)(using enterDefinition(id, used, b)).stmts
val translatedBody = toJS(body)(using recursive(id, used, b)).stmts

if used.used then
js.Lambda(vparams.map(nameDef) ++ bparams.map(nameDef) ++ List(nameDef(ks), nameDef(k)),
Expand All @@ -169,7 +194,7 @@ object TransformerCps extends Transformer {
js.Lambda(vparams.map(nameDef) ++ bparams.map(nameDef) ++ List(nameDef(ks), nameDef(k)),
translatedBody)

case other => toJS(other)(using clearDefinitions)
case other => toJS(other)
}

def toJS(b: cps.Block)(using TransformerContext): js.Expr = b match {
Expand All @@ -185,17 +210,18 @@ object TransformerCps extends Transformer {
case cps.Implementation(interface, operations) =>
js.Object(operations.map {
case cps.Operation(id, vps, bps, ks, k, body) =>
nameDef(id) -> js.Lambda(vps.map(nameDef) ++ bps.map(nameDef) ++ List(nameDef(ks), nameDef(k)), toJS(body)(using clearDefinitions).stmts)
nameDef(id) -> js.Lambda(vps.map(nameDef) ++ bps.map(nameDef) ++ List(nameDef(ks), nameDef(k)), toJS(body)(using nonrecursive(ks)).stmts)
})
}

def toJS(ks: cps.MetaCont): js.Expr = nameRef(ks.id)
def toJS(ks: cps.MetaCont)(using T: TransformerContext): js.Expr =
nameRef(T.metaconts.getOrElse(ks.id, ks.id))

def toJS(k: cps.Cont)(using T: TransformerContext): js.Expr = k match {
case Cont.ContVar(id) =>
nameRef(id)
case Cont.ContLam(result, ks, body) =>
js.Lambda(List(nameDef(result), nameDef(ks)), toJS(body)(using clearDefinitions).stmts)
js.Lambda(List(nameDef(result), nameDef(ks)), toJS(body)(using nonrecursive(ks)).stmts)
}

def toJS(e: cps.Expr)(using D: TransformerContext): js.Expr = e match {
Expand Down Expand Up @@ -225,9 +251,22 @@ object TransformerCps extends Transformer {
js.Const(nameDef(id), toJS(binding)) :: toJS(body).run(k)
}

case cps.Stmt.LetCont(id, binding, body) =>
// [[ let k(x, ks) = ...; if (...) jump k(42, ks2) else jump k(10, ks3) ]] =
// let x; if (...) { x = 42; ks = ks2 } else { x = 10; ks = ks3 } ...
case cps.Stmt.LetCont(id, Cont.ContLam(param, ks, body), body2) if canBeDirect(id, body2) =>
Binding { k =>
js.Const(nameDef(id), toJS(binding)) :: requiringThunk { toJS(body)(using clearDefinitions) }.run(k)
val withDirectStyle = D.copy(directStyle = Some(ContinuationInfo(id, param, ks)))
val renamingMetacont = directstyle(ks)

val translatedIf = toJS(body2)(using withDirectStyle)
val translatedBody = toJS(body)(using renamingMetacont)

js.Let(nameDef(param), js.Undefined) :: translatedIf.stmts ++ translatedBody.run(k)
}

case cps.Stmt.LetCont(id, binding @ Cont.ContLam(result2, ks2, body2), body) =>
Binding { k =>
js.Const(nameDef(id), toJS(binding)(using nonrecursive(ks2))) :: requiringThunk { toJS(body) }.run(k)
}

case cps.Stmt.Match(sc, Nil, None) =>
Expand All @@ -245,14 +284,27 @@ object TransformerCps extends Transformer {
pure(js.Switch(js.Member(scrutinee, `tag`),
clauses.map { case (tag, clause) =>
val (e, binding) = toJS(scrutinee, tag, clause);
(e, binding.stmts)

val stmts = binding.stmts

stmts.last match {
case terminator : (js.Stmt.Return | js.Stmt.Break | js.Stmt.Continue) => (e, stmts)
case other => (e, stmts :+ js.Break())
}
},
default.map { s => toJS(s).stmts }))

case cps.Stmt.Jump(k, arg, ks) if D.directStyle.exists(c => c.k == k) => D.directStyle match {
case Some(ContinuationInfo(k2, param2, ks2)) => pure(js.Assign(nameRef(param2), toJS(arg)))
case None => sys error "Should not happen"
}

case cps.Stmt.Jump(k, arg, ks) =>
pure(js.Return(maybeThunking(js.Call(nameRef(k), toJS(arg), toJS(ks)))))

case cps.Stmt.App(callee @ DefInfo(id, vparams, bparams, ks1, k1, used), vargs, bargs, MetaCont(ks), Cont.ContVar(k)) if ks1 == ks && k1 == k =>
case cps.Stmt.App(DefInfo(id, vparams, bparams, ks1, k1, used), vargs, bargs, MetaCont(ks), Cont.ContVar(k))
// this call is a tailcall if the metacontinuation (after substitution) is the same and the continuation is the same.
if ks1 == D.metaconts.getOrElse(ks, ks) && k1 == k =>
Binding { k2 =>
val stmts = mutable.ListBuffer.empty[js.Stmt]
stmts.append(js.RawStmt("/* prepare tail call */"))
Expand Down Expand Up @@ -336,13 +388,13 @@ object TransformerCps extends Transformer {
}

case cps.Stmt.Reset(prog, ks, k) =>
pure(js.Return(Call(RESET, toJS(prog)(using clearDefinitions), toJS(ks), toJS(k))))
pure(js.Return(Call(RESET, toJS(prog)(using nonrecursive(prog)), toJS(ks), toJS(k))))

case cps.Stmt.Shift(prompt, body, ks, k) =>
pure(js.Return(Call(SHIFT, nameRef(prompt), noThunking { toJS(body)(using clearDefinitions) }, toJS(ks), toJS(k))))
pure(js.Return(Call(SHIFT, nameRef(prompt), noThunking { toJS(body)(using nonrecursive(body)) }, toJS(ks), toJS(k))))

case cps.Stmt.Resume(r, b, ks2, k2) =>
pure(js.Return(js.Call(RESUME, nameRef(r), toJS(b)(using clearDefinitions), toJS(ks2), toJS(k2))))
pure(js.Return(js.Call(RESUME, nameRef(r), toJS(b)(using nonrecursive(b)), toJS(ks2), toJS(k2))))

case cps.Stmt.Hole() =>
pure(js.Return($effekt.call("hole")))
Expand Down Expand Up @@ -383,12 +435,55 @@ object TransformerCps extends Transformer {
case _ => js.Call(nameRef(id), args.map(toJS))
}

def canInline(extern: cps.Extern): Boolean = extern match {
private def canInline(extern: cps.Extern): Boolean = extern match {
case cps.Extern.Def(_, _, Nil, async, ExternBody.StringExternBody(_, Template(_, _))) => !async
case _ => false
}


// Predicates for Direct-Style Transformation
// ------------------------------------------

private def canBeDirect(k: Id, stmt: Stmt)(using T: TransformerContext): Boolean =
def notIn(term: Stmt | Block | Expr | (Id, Clause) | Cont) =
val freeVars = term match {
case s: Stmt => free(s)
case b: Block => free(b)
case p: Expr => free(p)
case (id, Clause(_, body)) => free(body)
case c: Cont => free(c)
}
!freeVars.contains(k)
stmt match {
case Stmt.Jump(k2, arg, ks2) if k2 == k => notIn(arg) && T.metacont.contains(ks2.id)
case Stmt.Jump(k2, arg, ks2) => notIn(arg)
// TODO this could be a tailcall!
case Stmt.App(callee, vargs, bargs, ks, k) => notIn(stmt)
case Stmt.Invoke(callee, method, vargs, bargs, ks, k2) => notIn(stmt)
case Stmt.If(cond, thn, els) => canBeDirect(k, thn) && canBeDirect(k, els)
case Stmt.Match(scrutinee, clauses, default) => clauses.forall {
case (id, Clause(vparams, body)) => canBeDirect(k, body)
} && default.forall(body => canBeDirect(k, body))
case Stmt.LetDef(id, binding, body) => notIn(binding) && canBeDirect(k, body)
case Stmt.LetExpr(id, binding, body) => notIn(binding) && canBeDirect(k, body)
case Stmt.LetCont(id, Cont.ContLam(result, ks2, body), body2) =>
def willBeDirectItself = canBeDirect(id, body2) && canBeDirect(k, body)(using directstyle(ks2))
def notFreeinContinuation = notIn(body) && canBeDirect(k, body2)
willBeDirectItself || notFreeinContinuation
case Stmt.Region(id, ks, body) => notIn(body)
case Stmt.Alloc(id, init, region, body) => notIn(init) && canBeDirect(k, body)
case Stmt.Var(id, init, ks2, body) => notIn(init) && canBeDirect(k, body)
case Stmt.Dealloc(ref, body) => canBeDirect(k, body)
case Stmt.Get(ref, id, body) => canBeDirect(k, body)
case Stmt.Put(ref, value, body) => notIn(value) && canBeDirect(k, body)
case Stmt.Reset(prog, ks, k) => notIn(stmt)
case Stmt.Shift(prompt, body, ks, k) => notIn(stmt)
case Stmt.Resume(resumption, body, ks, k) => notIn(stmt)
case Stmt.Hole() => true
}



// Thunking
// --------

Expand Down

0 comments on commit 0426238

Please sign in to comment.