Skip to content

Commit

Permalink
Remove Lens to allow for easy transition to Scala3 (wavesplatform#3970)
Browse files Browse the repository at this point in the history
  • Loading branch information
vladimirlogachev authored Sep 27, 2024
1 parent 1fa1513 commit d51a17f
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import com.wavesplatform.lang.v1.compiler.Types.*
import com.wavesplatform.lang.v1.evaluator.ctx.FunctionTypeSignature
import com.wavesplatform.lang.v1.parser.Expressions.Pos
import com.wavesplatform.lang.v1.parser.Expressions.Pos.AnyPos
import shapeless.*

case class CompilerContext(
predefTypes: Map[String, FINAL],
Expand Down Expand Up @@ -60,9 +59,5 @@ object CompilerContext {
y.provideRuntimeTypeOnCastError
)

val types: Lens[CompilerContext, Map[String, FINAL]] = lens[CompilerContext] >> Symbol("predefTypes")
val vars: Lens[CompilerContext, VariableTypes] = lens[CompilerContext] >> Symbol("varDefs")
val functions: Lens[CompilerContext, FunctionTypes] = lens[CompilerContext] >> Symbol("functionDefs")

val empty = CompilerContext(Map(), Map(), Map(), true)
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import com.wavesplatform.lang.contract.DApp.*
import com.wavesplatform.lang.contract.meta.{MetaMapper, V1 as MetaV1, V2 as MetaV2}
import com.wavesplatform.lang.directives.values.{StdLibVersion, V3, V6}
import com.wavesplatform.lang.v1.compiler.CompilationError.{AlreadyDefined, Generic, UnionNotAllowedForCallableArgs, WrongArgumentType}
import com.wavesplatform.lang.v1.compiler.CompilerContext.{VariableInfo, vars}
import com.wavesplatform.lang.v1.compiler.CompilerContext.VariableInfo
import com.wavesplatform.lang.v1.compiler.ContractCompiler.*
import com.wavesplatform.lang.v1.compiler.ScriptResultSource.FreeCall
import com.wavesplatform.lang.v1.compiler.Terms.EXPR
Expand Down Expand Up @@ -87,7 +87,7 @@ class ContractCompiler(version: StdLibVersion) extends ExpressionCompiler(versio
.getOrElse(List.empty)
unionInCallableErrs <- checkCallableUnions(af, annotationsWithErr._1.toList.flatten)
compiledBody <- local {
modify[Id, CompilerContext, CompilationError](vars.modify(_)(_ ++ annotationBindings)).flatMap(_ =>
modify[Id, CompilerContext, CompilationError](ctx => ctx.copy(varDefs = ctx.varDefs ++ annotationBindings)).flatMap(_ =>
compileFunc(af.f.position, af.f, saveExprContext, annotationBindings.map(_._1), allowIllFormedStrings)
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ class ExpressionCompiler(val version: StdLibVersion) {
.handleError()
compiledFuncBody <- local {
val newArgs: VariableTypes = argTypesWithErr._1.getOrElse(List.empty).toMap
modify[Id, CompilerContext, CompilationError](vars.modify(_)(_ ++ newArgs))
modify[Id, CompilerContext, CompilationError](ctx1 => ctx1.copy(varDefs = ctx1.varDefs ++ newArgs))
.flatMap(_ => compileExprWithCtx(func.expr, saveExprContext, allowIllFormedStrings))
}

Expand All @@ -368,10 +368,10 @@ class ExpressionCompiler(val version: StdLibVersion) {
}

protected def updateCtx(letName: String, letType: Types.FINAL, p: Pos): CompileM[Unit] =
modify[Id, CompilerContext, CompilationError](vars.modify(_)(_ + (letName -> VariableInfo(p, letType))))
modify[Id, CompilerContext, CompilationError](ctx => ctx.copy(varDefs = ctx.varDefs + (letName -> VariableInfo(p, letType))))

protected def updateCtx(funcName: String, typeSig: FunctionTypeSignature, p: Pos): CompileM[Unit] =
modify[Id, CompilerContext, CompilationError](functions.modify(_)(_ + (funcName -> FunctionInfo(p, List(typeSig)))))
modify[Id, CompilerContext, CompilationError](ctx => ctx.copy(functionDefs = ctx.functionDefs + (funcName -> FunctionInfo(p, List(typeSig)))))

private def compileLetBlock(
p: Pos,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,27 +1,18 @@
package com.wavesplatform.lang.v1.estimator.v2


import com.wavesplatform.lang.v1.FunctionHeader
import com.wavesplatform.lang.v1.compiler.Terms.FUNC
import com.wavesplatform.lang.v1.estimator.EstimationError
import com.wavesplatform.lang.v1.estimator.v2.EstimatorContext.EvalM
import com.wavesplatform.lang.v1.task.TaskM
import shapeless.{Lens, lens}

private[v2] case class EstimatorContext(
letDefs: Map[String, (Boolean, EvalM[Long])],
predefFuncs: Map[FunctionHeader, Long],
userFuncs: Map[FunctionHeader, FUNC] = Map.empty,
overlappedRefs: Map[String, (Boolean, EvalM[Long])] = Map.empty
letDefs: Map[String, (Boolean, EvalM[Long])],
predefFuncs: Map[FunctionHeader, Long],
userFuncs: Map[FunctionHeader, FUNC] = Map.empty,
overlappedRefs: Map[String, (Boolean, EvalM[Long])] = Map.empty
)

private[v2] object EstimatorContext {
type EvalM[A] = TaskM[EstimatorContext, EstimationError, A]

object Lenses {
val lets: Lens[EstimatorContext, Map[String, (Boolean, EvalM[Long])]] = lens[EstimatorContext] >> Symbol("letDefs")
val userFuncs: Lens[EstimatorContext, Map[FunctionHeader, FUNC]] = lens[EstimatorContext] >> Symbol("userFuncs")
val predefFuncs: Lens[EstimatorContext, Map[FunctionHeader, Long]] = lens[EstimatorContext] >> Symbol("predefFuncs")
val overlappedRefs: Lens[EstimatorContext, Map[String, (Boolean, EvalM[Long])]] = lens[EstimatorContext] >> Symbol("overlappedRefs")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import com.wavesplatform.lang.v1.FunctionHeader
import com.wavesplatform.lang.v1.compiler.Terms._
import com.wavesplatform.lang.v1.estimator.{EstimationError, ScriptEstimator}
import com.wavesplatform.lang.v1.estimator.v2.EstimatorContext.EvalM
import com.wavesplatform.lang.v1.estimator.v2.EstimatorContext.Lenses._
import com.wavesplatform.lang.v1.task.imports._
import monix.eval.Coeval

Expand Down Expand Up @@ -45,7 +44,7 @@ object ScriptEstimatorV2 extends ScriptEstimator {
local {
val letResult = (false, evalExpr(let.value))
for {
_ <- update(lets.modify(_)(_.updated(let.name, letResult)))
_ <- update(ctx => ctx.copy(letDefs = ctx.letDefs.updated(let.name, letResult)))
r <- evalExpr(inner)
} yield r + 5
}
Expand All @@ -61,31 +60,31 @@ object ScriptEstimatorV2 extends ScriptEstimator {
local {
for {
_ <- checkFuncCtx(func)
_ <- update(userFuncs.modify(_)(_ + (FunctionHeader.User(func.name) -> func)))
_ <- update(ctx => ctx.copy(userFuncs = ctx.userFuncs + (FunctionHeader.User(func.name) -> func)))
r <- evalExpr(inner)
} yield r + 5
}

private def checkFuncCtx(func: FUNC): EvalM[Unit] =
local {
for {
_ <- update(lets.modify(_)(_ ++ func.args.map((_, (true, const(0)))).toMap))
_ <- update(ctx => ctx.copy(letDefs = ctx.letDefs ++ func.args.map((_, (true, const(0)))).toMap))
_ <- evalExpr(func.body)
} yield ()
}

private def evalRef(key: String): EvalM[Long] =
for {
ctx <- get[Id, EstimatorContext, EstimationError]
r <- lets.get(ctx).get(key) match {
r <- ctx.letDefs.get(key) match {
case Some((false, lzy)) => setRefEvaluated(key, lzy)
case Some((true, _)) => const(0)
case None => raiseError[Id, EstimatorContext, EstimationError, Long](s"A definition of '$key' not found")
}
} yield r + 2

private def setRefEvaluated(key: String, lzy: EvalM[Long]): EvalM[Long] =
update(lets.modify(_)(_.updated(key, (true, lzy))))
update(ctx => ctx.copy(letDefs = ctx.letDefs.updated(key, (true, lzy))))
.flatMap(_ => lzy)

private def evalGetter(expr: EXPR): EvalM[Long] =
Expand All @@ -94,26 +93,36 @@ object ScriptEstimatorV2 extends ScriptEstimator {
private def evalFuncCall(header: FunctionHeader, args: List[EXPR]): EvalM[Long] =
for {
ctx <- get[Id, EstimatorContext, EstimationError]
bodyComplexity <- predefFuncs
.get(ctx)
bodyComplexity <- ctx.predefFuncs
.get(header)
.map(bodyComplexity => evalFuncArgs(args).map(_ + bodyComplexity))
.orElse(userFuncs.get(ctx).get(header).map(evalUserFuncCall(_, args)))
.orElse(ctx.userFuncs.get(header).map(evalUserFuncCall(_, args)))
.getOrElse(raiseError[Id, EstimatorContext, EstimationError, Long](s"function '$header' not found"))
} yield bodyComplexity

private def evalUserFuncCall(func: FUNC, args: List[EXPR]): EvalM[Long] =
for {
argsComplexity <- evalFuncArgs(args)
ctx <- get[Id, EstimatorContext, EstimationError]
_ <- update(lets.modify(_)(_ ++ ctx.overlappedRefs))
_ <- update(ctx1 => ctx1.copy(letDefs = ctx1.letDefs ++ ctx.overlappedRefs))
overlapped = func.args.flatMap(arg => ctx.letDefs.get(arg).map((arg, _))).toMap
ctxArgs = func.args.map((_, (false, const(1)))).toMap
_ <- update((lets ~ overlappedRefs).modify(_) { case (l, or) => (l ++ ctxArgs, or ++ overlapped) })
_ <- update(ctx1 =>
ctx1.copy(
letDefs = ctx1.letDefs ++ ctxArgs,
overlappedRefs = ctx1.overlappedRefs ++ overlapped
)
)

bodyComplexity <- evalExpr(func.body).map(_ + func.args.size * 5)
evaluatedCtx <- get[Id, EstimatorContext, EstimationError]
overlappedChanges = overlapped.map { case ref @ (name, _) => evaluatedCtx.letDefs.get(name).map((name, _)).getOrElse(ref) }
_ <- update((lets ~ overlappedRefs).modify(_) { case (l, or) => (l -- ctxArgs.keys ++ overlapped, or ++ overlappedChanges) })
_ <- update(ctx1 =>
ctx1.copy(
letDefs = ctx1.letDefs -- ctxArgs.keys ++ overlapped,
overlappedRefs = ctx1.overlappedRefs ++ overlappedChanges
)
)
} yield bodyComplexity + argsComplexity

private def evalFuncArgs(args: List[EXPR]): EvalM[Long] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import com.wavesplatform.lang.v1.estimator.EstimationError
import com.wavesplatform.lang.v1.estimator.v3.EstimatorContext.EvalM
import com.wavesplatform.lang.v1.task.TaskM
import monix.eval.Coeval
import shapeless.{Lens, lens}

private[v3] case class EstimatorContext(
funcs: Map[FunctionHeader, (Coeval[Long], Set[String])],
Expand All @@ -18,9 +17,4 @@ private[v3] case class EstimatorContext(

private[v3] object EstimatorContext {
type EvalM[A] = TaskM[EstimatorContext, EstimationError, A]

object Lenses {
val funcs: Lens[EstimatorContext, Map[FunctionHeader, (Coeval[Long], Set[String])]] = lens[EstimatorContext] >> Symbol("funcs")
val usedRefs: Lens[EstimatorContext, Set[String]] = lens[EstimatorContext] >> Symbol("usedRefs")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import com.wavesplatform.lang.v1.FunctionHeader
import com.wavesplatform.lang.v1.FunctionHeader.User
import com.wavesplatform.lang.v1.compiler.Terms.*
import com.wavesplatform.lang.v1.estimator.v3.EstimatorContext.EvalM
import com.wavesplatform.lang.v1.estimator.v3.EstimatorContext.Lenses.*
import com.wavesplatform.lang.v1.estimator.{EstimationError, ScriptEstimator}
import com.wavesplatform.lang.v1.task.imports.*
import monix.eval.Coeval
Expand Down Expand Up @@ -82,7 +81,7 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean, letFixes:
letCosts <- usedRefs.toSeq.traverse { ref =>
local {
for {
_ <- update(funcs.set(_)(startCtx.funcs))
_ <- update(ctx1 => ctx1.copy(funcs = startCtx.funcs))
cost <- ctx.globalLetEvals.getOrElse(ref, zero)
} yield cost
}
Expand All @@ -100,22 +99,18 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean, letFixes:
}

private def beforeNextExprEval(let: LET, eval: EvalM[Long]): EvalM[Unit] =
update(ctx =>
usedRefs
.modify(ctx)(_ - let.name)
.copy(refsCosts = ctx.refsCosts + (let.name -> local(eval)))
)
update(ctx => ctx.copy(usedRefs = ctx.usedRefs - let.name, refsCosts = ctx.refsCosts + (let.name -> local(eval))))

private def afterNextExprEval(let: LET, startCtx: EstimatorContext): EvalM[Unit] =
update(ctx =>
usedRefs
.modify(ctx)(r => if (startCtx.usedRefs.contains(let.name)) r + let.name else r - let.name)
.copy(refsCosts =
ctx.copy(
usedRefs = if (startCtx.usedRefs.contains(let.name)) ctx.usedRefs + let.name else ctx.usedRefs - let.name,
refsCosts =
if (startCtx.refsCosts.contains(let.name))
ctx.refsCosts + (let.name -> startCtx.refsCosts(let.name))
else
ctx.refsCosts - let.name
)
)
)

private def evalFuncBlock(func: FUNC, nextExpr: EXPR, activeFuncArgs: Set[String], globalDeclarationsMode: Boolean): EvalM[Long] =
Expand All @@ -142,14 +137,12 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean, letFixes:
_ <- set[Id, EstimatorContext, EstimationError](ctx.copy(globalFunctionsCosts = ctx.globalFunctionsCosts + (name -> totalCost)))
} yield ()

private def handleUsedRefs(name: String, cost: Long, ctx: EstimatorContext, refsUsedInBody: Set[String]): EvalM[Unit] =
update(
(funcs ~ usedRefs).modify(_) { case (funcs, _) =>
(
funcs + (User(name) -> (Coeval.now(cost), refsUsedInBody)),
ctx.usedRefs
)
}
private def handleUsedRefs(name: String, cost: Long, startCtx: EstimatorContext, refsUsedInBody: Set[String]): EvalM[Unit] =
update(ctx =>
ctx.copy(
funcs = ctx.funcs + (User(name) -> (Coeval.now(cost), refsUsedInBody)),
usedRefs = startCtx.usedRefs
)
)

private def evalIF(cond: EXPR, ifTrue: EXPR, ifFalse: EXPR, activeFuncArgs: Set[String]): EvalM[Long] =
Expand All @@ -165,7 +158,7 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean, letFixes:
if (activeFuncArgs.contains(key) && letFixes)
const(overheadCost)
else
update(usedRefs.modify(_)(_ + key)).map(_ => overheadCost)
update(ctx => ctx.copy(usedRefs = ctx.usedRefs + key)).map(_ => overheadCost)

private def evalGetter(expr: EXPR, activeFuncArgs: Set[String]): EvalM[Long] =
evalExpr(expr, activeFuncArgs).flatMap(sum(_, overheadCost))
Expand All @@ -187,18 +180,15 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean, letFixes:
} yield result

private def setFuncToCtx(header: FunctionHeader, bodyCost: Coeval[Long], bodyUsedRefs: Set[EstimationError]): EvalM[Unit] =
update(
(funcs ~ usedRefs).modify(_) { case (funcs, usedRefs) =>
(
funcs + (header -> (bodyCost, Set())),
usedRefs ++ bodyUsedRefs
)
}
update(ctx =>
ctx.copy(
funcs = ctx.funcs + (header -> (bodyCost, Set())),
usedRefs = ctx.usedRefs ++ bodyUsedRefs
)
)

private def getFuncCost(header: FunctionHeader, ctx: EstimatorContext): EvalM[(Coeval[Long], Set[EstimationError])] =
funcs
.get(ctx)
ctx.funcs
.get(header)
.map(const)
.getOrElse(
Expand All @@ -217,9 +207,9 @@ case class ScriptEstimatorV3(fixOverflow: Boolean, overhead: Boolean, letFixes:
): EvalM[Long] =
for {
startCtx <- get[Id, EstimatorContext, EstimationError]
_ <- ctxFuncsOpt.fold(doNothing.void)(ctxFuncs => update(funcs.set(_)(ctxFuncs)))
_ <- ctxFuncsOpt.fold(doNothing.void)(ctxFuncs => update(ctx => ctx.copy(funcs = ctxFuncs)))
cost <- evalExpr(expr, activeFuncArgs)
_ <- update(funcs.set(_)(startCtx.funcs))
_ <- update(ctx => ctx.copy(funcs = startCtx.funcs))
} yield cost

private def withUsedRefs[A](eval: EvalM[A]): EvalM[(A, Set[String])] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import com.wavesplatform.lang.v1.compiler.Terms.*
import com.wavesplatform.lang.v1.compiler.Types.{CASETYPEREF, NOTHING}
import com.wavesplatform.lang.v1.evaluator.ContextfulNativeFunction.{Extended, Simple}
import com.wavesplatform.lang.v1.evaluator.ctx.*
import com.wavesplatform.lang.v1.evaluator.ctx.EnabledLogEvaluationContext.Lenses
import com.wavesplatform.lang.v1.task.imports.*
import com.wavesplatform.lang.v1.traits.Environment
import com.wavesplatform.lang.*
Expand All @@ -30,16 +29,16 @@ object EvaluatorV1 {
}

class EvaluatorV1[F[_]: Monad, C[_[_]]](implicit ev: Monad[EvalF[F, *]], ev2: Monad[CoevalF[F, *]]) {
private val lenses = new Lenses[F, C]
import lenses.*

private def evalLetBlock(let: LET, inner: EXPR): EvalM[F, C, (EvaluationContext[C, F], EVALUATED)] =
for {
ctx <- get[F, EnabledLogEvaluationContext[C, F], ExecutionError]
blockEvaluation = evalExpr(let.value)
lazyBlock = LazyVal(blockEvaluation.ter(ctx), ctx.l(let.name))
result <- local {
modify[F, EnabledLogEvaluationContext[C, F], ExecutionError](lets.modify(_)(_.updated(let.name, lazyBlock)))
modify[F, EnabledLogEvaluationContext[C, F], ExecutionError](ctx1 =>
ctx1.copy(ec = ctx1.ec.copy(letDefs = ctx1.ec.letDefs.updated(let.name, lazyBlock)))
)
.flatMap(_ => evalExprWithCtx(inner))
}
} yield result
Expand All @@ -49,15 +48,17 @@ class EvaluatorV1[F[_]: Monad, C[_[_]]](implicit ev: Monad[EvalF[F, *]], ev2: Mo
val function = UserFunction(func.name, 0, NOTHING, func.args.map(n => (n, NOTHING))*)(func.body)
.asInstanceOf[UserFunction[C]]
local {
modify[F, EnabledLogEvaluationContext[C, F], ExecutionError](funcs.modify(_)(_.updated(funcHeader, function)))
modify[F, EnabledLogEvaluationContext[C, F], ExecutionError](ctx =>
ctx.copy(ec = ctx.ec.copy(functions = ctx.ec.functions.updated(funcHeader, function)))
)
.flatMap(_ => evalExprWithCtx(inner))
}
}

private def evalRef(key: String): EvalM[F, C, (EvaluationContext[C, F], EVALUATED)] =
for {
ctx <- get[F, EnabledLogEvaluationContext[C, F], ExecutionError]
r <- lets.get(ctx).get(key) match {
r <- ctx.ec.letDefs.get(key) match {
case Some(lzy) => liftTER[F, C, EVALUATED](lzy.value)
case None => raiseError[F, EnabledLogEvaluationContext[C, F], ExecutionError, EVALUATED](s"A definition of '$key' not found")
}
Expand All @@ -83,8 +84,7 @@ class EvaluatorV1[F[_]: Monad, C[_[_]]](implicit ev: Monad[EvalF[F, *]], ev2: Mo
private def evalFunctionCall(header: FunctionHeader, args: List[EXPR]): EvalM[F, C, (EvaluationContext[C, F], EVALUATED)] =
for {
ctx <- get[F, EnabledLogEvaluationContext[C, F], ExecutionError]
result <- funcs
.get(ctx)
result <- ctx.ec.functions
.get(header)
.map {
case func: UserFunction[C] =>
Expand All @@ -94,7 +94,7 @@ class EvaluatorV1[F[_]: Monad, C[_[_]]](implicit ev: Monad[EvalF[F, *]], ev2: Mo
}
local {
val newState: EvalM[F, C, Unit] =
set[F, EnabledLogEvaluationContext[C, F], ExecutionError](lets.set(ctx)(letDefsWithArgs)).map(_.pure[F])
set[F, EnabledLogEvaluationContext[C, F], ExecutionError](ctx.copy(ec = ctx.ec.copy(letDefs = letDefsWithArgs))).map(_.pure[F])
Monad[EvalM[F, C, *]].flatMap(newState)(_ => evalExpr(func.ev(ctx.ec.environment, args)))
}
}: EvalM[F, C, EVALUATED]
Expand All @@ -118,7 +118,7 @@ class EvaluatorV1[F[_]: Monad, C[_[_]]](implicit ev: Monad[EvalF[F, *]], ev2: Mo
// no such function, try data constructor
header match {
case FunctionHeader.User(typeName, _) =>
types.get(ctx).get(typeName).collect { case t @ CASETYPEREF(_, fields, _) =>
ctx.ec.typeDefs.get(typeName).collect { case t @ CASETYPEREF(_, fields, _) =>
args
.traverse[EvalM[F, C, *], EVALUATED](evalExpr)
.map(values => CaseObj(t, fields.map(_._1).zip(values).toMap): EVALUATED)
Expand Down
Loading

0 comments on commit d51a17f

Please sign in to comment.