Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/replace environments by substitution #787

Merged
merged 16 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion effekt/shared/src/main/scala/effekt/core/Tree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ enum Stmt extends Tree {
case Alloc(id: Id, init: Pure, region: Id, body: Stmt)

// creates a fresh state handler to model local (backtrackable) state.
// [[capture]] is a binding occurence.
// [[capture]] is a binding occurrence.
// e.g. state(init) { [x]{x: Ref} => ... }
case Var(id: Id, init: Pure, capture: Id, body: Stmt)
case Get(id: Id, annotatedCapt: Captures, annotatedTpe: ValueType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,10 @@ object Normalizer { normal =>

// TODO for `New` we should track how often each operation is used, not the object itself
// to decide inlining.
private def shouldInline(b: BlockLit, boundBy: Option[BlockVar])(using C: Context): Boolean = boundBy match {
private def shouldInline(b: BlockLit, boundBy: Option[BlockVar], blockArgs: List[Block])(using C: Context): Boolean = boundBy match {
case Some(id) if isRecursive(id.id) => false
case Some(id) => isOnce(id.id) || b.body.size <= C.maxInlineSize
case None => true
case _ => blockArgs.exists { b => b.isInstanceOf[BlockLit] } // higher-order function with known arg
}

private def active(e: Expr)(using Context): Expr =
Expand Down Expand Up @@ -171,7 +171,7 @@ object Normalizer { normal =>
// -------
case Stmt.App(b, targs, vargs, bargs) =>
active(b) match {
case NormalizedBlock.Known(b: BlockLit, boundBy) if shouldInline(b, boundBy) =>
case NormalizedBlock.Known(b: BlockLit, boundBy) if shouldInline(b, boundBy, bargs) =>
reduce(b, targs, vargs.map(normalize), bargs.map(normalize))
case normalized =>
Stmt.App(normalized.shared, targs, vargs.map(normalize), bargs.map(normalize))
Expand All @@ -181,7 +181,7 @@ object Normalizer { normal =>
active(b) match {
case n @ NormalizedBlock.Known(Block.New(impl), boundBy) =>
selectOperation(impl, method) match {
case b: BlockLit if shouldInline(b, boundBy) => reduce(b, targs, vargs.map(normalize), bargs.map(normalize))
case b: BlockLit if shouldInline(b, boundBy, bargs) => reduce(b, targs, vargs.map(normalize), bargs.map(normalize))
case _ => Stmt.Invoke(n.shared, method, methodTpe, targs, vargs.map(normalize), bargs.map(normalize))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,14 @@ object Optimizer extends Phase[CoreTransformed, CoreTransformed] {
def normalize(m: ModuleDecl) = {
val anfed = BindSubexpressions.transform(m)
val normalized = Normalizer.normalize(Set(mainSymbol), anfed, Context.config.maxInlineSize().toInt)
Deadcode.remove(mainSymbol, normalized)
val live = Deadcode.remove(mainSymbol, normalized)
RemoveTailResumptions(live)
}

// (3) normalize once and remove beta redexes
// (3) normalize a few times (since tail resumptions might only surface after normalization and leave dead Resets)
tree = Context.timed("normalize-1", source.name) { normalize(tree) }

// (4) optimize continuation capture in the tail-resumptive case
tree = Context.timed("tail-resumptions", source.name) { RemoveTailResumptions(tree) }

// (5) normalize again to clean up and remove new redexes
tree = Context.timed("normalize-2", source.name) { normalize(tree) }
tree = Context.timed("normalize-3", source.name) { normalize(tree) }

tree
}
179 changes: 179 additions & 0 deletions effekt/shared/src/main/scala/effekt/cps/Tree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package cps
import core.{ Id, ValueType, BlockType, Captures }
import effekt.source.FeatureFlag
import effekt.util.messages.ErrorReporter
import effekt.util.messages.INTERNAL_ERROR


sealed trait Tree extends Product {
Expand Down Expand Up @@ -231,3 +232,181 @@ object Variables {
case Cont.ContLam(result, ks, body) => free(body) -- value(result) -- meta(ks)
}
}


object substitutions {

case class Substitution(
values: Map[Id, Pure] = Map.empty,
blocks: Map[Id, Block] = Map.empty,
conts: Map[Id, Cont] = Map.empty,
metaconts: Map[Id, MetaCont] = Map.empty
) {
def shadowValues(shadowed: IterableOnce[Id]): Substitution = copy(values = values -- shadowed)
def shadowBlocks(shadowed: IterableOnce[Id]): Substitution = copy(blocks = blocks -- shadowed)
def shadowConts(shadowed: IterableOnce[Id]): Substitution = copy(conts = conts -- shadowed)
def shadowMetaconts(shadowed: IterableOnce[Id]): Substitution = copy(metaconts = metaconts -- shadowed)

def shadowParams(vparams: Seq[Id], bparams: Seq[Id]): Substitution =
copy(values = values -- vparams, blocks = blocks -- bparams)
}

def substitute(expression: Expr)(using Substitution): Expr = expression match {
case DirectApp(id, vargs, bargs) =>
DirectApp(id, vargs.map(substitute), bargs.map(substitute))
case p: Pure => substitute(p)
}

def substitute(pure: Pure)(using subst: Substitution): Pure = pure match {
case ValueVar(id) if subst.values.isDefinedAt(id) => subst.values(id)
case ValueVar(id) => ValueVar(id)
case Literal(value) => Literal(value)
case Make(tpe, tag, vargs) => Make(tpe, tag, vargs.map(substitute))
case PureApp(id, vargs) => PureApp(id, vargs.map(substitute))
case Box(b) => Box(substitute(b))
}

def substitute(block: Block)(using subst: Substitution): Block = block match {
case BlockVar(id) if subst.blocks.isDefinedAt(id) => subst.blocks(id)
case BlockVar(id) => BlockVar(id)
case b: BlockLit => substitute(b)
case Unbox(pure) => Unbox(substitute(pure))
case New(impl) => New(substitute(impl))
}

def substitute(b: BlockLit)(using subst: Substitution): BlockLit = b match {
case BlockLit(vparams, bparams, ks, k, body) =>
BlockLit(vparams, bparams, ks, k,
substitute(body)(using subst
.shadowParams(vparams, bparams)
.shadowMetaconts(List(ks))
.shadowConts(List(k))))
}

def substitute(stmt: Stmt)(using subst: Substitution): Stmt = stmt match {
case Jump(k, arg, ks) =>
Jump(
substituteAsContVar(k),
substitute(arg),
substitute(ks))

case App(callee, vargs, bargs, ks, k) =>
App(
substitute(callee),
vargs.map(substitute),
bargs.map(substitute),
substitute(ks),
substitute(k))

case Invoke(callee, method, vargs, bargs, ks, k) =>
Invoke(
substitute(callee),
method,
vargs.map(substitute),
bargs.map(substitute),
substitute(ks),
substitute(k))

case If(cond, thn, els) =>
If(substitute(cond), substitute(thn), substitute(els))

case Match(scrutinee, clauses, default) =>
Match(
substitute(scrutinee),
clauses.map { case (id, cl) => (id, substitute(cl)) },
default.map(substitute))

case LetDef(id, binding, body) =>
LetDef(id, substitute(binding),
substitute(body)(using subst.shadowBlocks(List(id))))

case LetExpr(id, binding, body) =>
LetExpr(id, substitute(binding),
substitute(body)(using subst.shadowValues(List(id))))

case LetCont(id, binding, body) =>
LetCont(id, substitute(binding),
substitute(body)(using subst.shadowConts(List(id))))

case Region(id, ks, body) =>
Region(id, substitute(ks),
substitute(body)(using subst.shadowBlocks(List(id))))

case Alloc(id, init, region, body) =>
Alloc(id, substitute(init), substituteAsBlockVar(region),
substitute(body)(using subst.shadowBlocks(List(id))))

case Var(id, init, ks, body) =>
Var(id, substitute(init), substitute(ks),
substitute(body)(using subst.shadowBlocks(List(id))))

case Dealloc(ref, body) =>
Dealloc(substituteAsBlockVar(ref), substitute(body))

case Get(ref, id, body) =>
Get(substituteAsBlockVar(ref), id,
substitute(body)(using subst.shadowValues(List(id))))

case Put(ref, value, body) =>
Put(substituteAsBlockVar(ref), substitute(value), substitute(body))

case Reset(prog, ks, k) =>
Reset(substitute(prog), substitute(ks), substitute(k))

case Shift(prompt, body, ks, k) =>
Shift(substituteAsBlockVar(prompt), substitute(body), substitute(ks), substitute(k))

case Resume(r, body, ks, k) =>
Resume(substituteAsBlockVar(r), substitute(body), substitute(ks), substitute(k))

case h: Hole => h
}

def substitute(impl: Implementation)(using Substitution): Implementation = impl match {
case Implementation(interface, operations) =>
Implementation(interface, operations.map(substitute))
}

def substitute(op: Operation)(using subst: Substitution): Operation = op match {
case Operation(name, vparams, bparams, ks, k, body) =>
Operation(name, vparams, bparams, ks, k,
substitute(body)(using subst
.shadowParams(vparams, bparams)
.shadowMetaconts(List(ks))
.shadowConts(List(k))))
}

def substitute(clause: Clause)(using subst: Substitution): Clause = clause match {
case Clause(vparams, body) =>
Clause(vparams, substitute(body)(using subst.shadowValues(vparams)))
}

def substitute(k: Cont)(using subst: Substitution): Cont = k match {
case Cont.ContVar(id) if subst.conts.isDefinedAt(id) => subst.conts(id)
case Cont.ContVar(id) => Cont.ContVar(id)
case lam @ Cont.ContLam(result, ks, body) => substitute(lam)
}

def substitute(k: Cont.ContLam)(using subst: Substitution): Cont.ContLam = k match {
case Cont.ContLam(result, ks, body) =>
Cont.ContLam(result, ks,
substitute(body)(using subst
.shadowValues(List(result))
.shadowMetaconts(List(ks))))
}

def substitute(ks: MetaCont)(using subst: Substitution): MetaCont =
subst.metaconts.getOrElse(ks.id, ks)

def substituteAsBlockVar(id: Id)(using subst: Substitution): Id =
subst.blocks.get(id) map {
case BlockVar(x) => x
case _ => INTERNAL_ERROR("References should always be variables")
} getOrElse id

def substituteAsContVar(id: Id)(using subst: Substitution): Id =
subst.conts.get(id) map {
case Cont.ContVar(x) => x
case _ => INTERNAL_ERROR("Continuation references should always be variables")
} getOrElse id
}
Loading
Loading