From 39c4be1c0b80442cf3611de78c4ffbc3f7589adc Mon Sep 17 00:00:00 2001 From: Anvar Kiekbaev Date: Wed, 16 Dec 2020 11:40:03 +0300 Subject: [PATCH] NODE-2150: Native curve25519 implementation (#3156) --- build.sbt | 2 + curve25519-test/build.sbt | 8 + .../wavesplatform/curve25519/test/App.scala | 334 ++++++++++++++++++ .../scala/com/wavesplatform/lang/v1/CTX.scala | 10 +- .../v1/evaluator/ctx/impl/CryptoContext.scala | 25 +- .../ctx/impl/waves/WavesContext.scala | 10 +- .../scala/com/wavesplatform/Exporter.scala | 2 +- .../com/wavesplatform/crypto/package.scala | 6 +- .../com/wavesplatform/database/package.scala | 30 +- .../scala/com/wavesplatform/state/Diff.scala | 9 +- .../state/diffs/BlockDiffer.scala | 6 +- .../transaction/smart/BlockchainContext.scala | 43 ++- .../diffs/BlockDifferDetailedDiffTest.scala | 3 +- project/Dependencies.scala | 29 +- 14 files changed, 443 insertions(+), 74 deletions(-) create mode 100644 curve25519-test/build.sbt create mode 100644 curve25519-test/src/main/scala/com/wavesplatform/curve25519/test/App.scala diff --git a/build.sbt b/build.sbt index 4cb715af15d..7f97efee815 100644 --- a/build.sbt +++ b/build.sbt @@ -81,6 +81,8 @@ lazy val `node-it` = project.dependsOn(node, `grpc-server`) lazy val `node-generator` = project.dependsOn(node, `node` % "compile") lazy val benchmark = project.dependsOn(node % "compile;test->test") +lazy val `curve25519-test` = project.dependsOn(node) + lazy val root = (project in file(".")) .aggregate( `lang-js`, diff --git a/curve25519-test/build.sbt b/curve25519-test/build.sbt new file mode 100644 index 00000000000..04f342b8983 --- /dev/null +++ b/curve25519-test/build.sbt @@ -0,0 +1,8 @@ +libraryDependencies ++= Seq( + "com.typesafe.scala-logging" %% "scala-logging" % "3.9.2", + Dependencies.googleGuava, + Dependencies.monixModule("reactive").value, + Dependencies.curve25519 +) ++ Dependencies.logDeps + +enablePlugins(JavaAppPackaging) diff --git a/curve25519-test/src/main/scala/com/wavesplatform/curve25519/test/App.scala b/curve25519-test/src/main/scala/com/wavesplatform/curve25519/test/App.scala new file mode 100644 index 00000000000..849e84bdd57 --- /dev/null +++ b/curve25519-test/src/main/scala/com/wavesplatform/curve25519/test/App.scala @@ -0,0 +1,334 @@ +package com.wavesplatform.curve25519.test + +import java.io._ +import java.util +import java.util.concurrent._ +import java.util.concurrent.atomic.AtomicLong + +import com.google.common.io.{BaseEncoding, CountingOutputStream} +import com.google.common.primitives.{Bytes, Ints, Longs} +import com.typesafe.scalalogging.StrictLogging +import monix.execution.Scheduler +import org.whispersystems.curve25519.{Curve25519Provider, JavaCurve25519Provider, NativeCurve25519Provider} + +import scala.annotation.tailrec +import scala.concurrent.duration._ +import scala.reflect.ClassTag + +object App extends StrictLogging { + + class Writer(queue: BlockingQueue[Seq[CheckResult]], latch: CountDownLatch, out: DataOutputStream) extends Runnable { + override def run(): Unit = while (latch.getCount > 0) { + for (r <- queue.take()) { + out.writeInt(r.seedNr) + out.writeInt(r.messageNr) + out.write(r.signature) + } + out.flush() + } + } + + class Dispatcher(startWith: Int, modulus: Int, itemCount: Option[Int] = None) { + @volatile var maxSeedNr = startWith + private[this] val iterator = + Iterator + .from(startWith) + .flatMap { maxSeedNr => + (0 until modulus * (maxSeedNr + 1)).iterator.map { msgLength => + Input(maxSeedNr, maxSeedNr - msgLength / modulus, msgLength) + } + } + .zipWithIndex + .takeWhile { case (_, seqNr) => itemCount.forall(_ > seqNr) } + .map(_._1) + + def nextBatch(batchSize: Int): Seq[Input] = iterator.synchronized { + val batch = iterator.take(batchSize).toSeq + batch.lastOption.foreach(i => maxSeedNr = i.maxSeedNr) + batch + } + } + + class SignatureDataReader(in: DataInputStream) { + def nextBatch(batchSize: Int): Seq[(Int, Int, Array[Byte])] = in.synchronized { + val buffer = Seq.newBuilder[(Int, Int, Array[Byte])] + var canContinue = true + var counter = 0 + while (canContinue && counter < batchSize) try { + val seedNr = in.readInt() + val msgLength = in.readInt() + val signature = new Array[Byte](64) + in.readFully(signature) + buffer += ((seedNr, msgLength, signature)) + counter += 1 + } catch { + case _: EOFException => + canContinue = false + } + buffer.result() + } + } + + def mkMessageTemplate(randomSeed: Long): Array[Byte] = { + val seedSeq = Longs.toByteArray(randomSeed) + val remainder = MaxMessageLength % Longs.BYTES + Array.fill(MaxMessageLength / Longs.BYTES + (if (remainder > 0) 1 else 0))(seedSeq).flatten + } + + def mkMsg(seqNr: Int, messageTemplate: Array[Byte]): Array[Byte] = { + val length = seqNr % MaxMessageLength + 1 + val prefix = Ints.toByteArray(seqNr / MaxMessageLength).reverse + val result = new Array[Byte](length) + + System.arraycopy(prefix, 0, result, 0, math.min(Ints.BYTES, length)) + if (length > Ints.BYTES) { + System.arraycopy(messageTemplate, 0, result, Ints.BYTES, length - Ints.BYTES) + } + + result + } + + abstract class Worker( + latch: CountDownLatch, + queue: Option[BlockingQueue[Seq[CheckResult]]], + counter: AtomicLong, + randomSeed: Long, + nextBatch: () => Seq[Input] + ) extends Runnable { + private val messageTemplate = mkMessageTemplate(randomSeed) + + def sign(seed: Array[Byte], message: Array[Byte]): Array[Byte] + + @tailrec + private def checkTask(batch: Seq[Input]): Unit = + if (batch.isEmpty) { + logger.info("No more values to check") + } else { + queue.foreach( + _.put(batch.map { input => + val signarure = sign(mkAccountSeed(randomSeed)(input.seedNr), mkMsg(input.msgLength, messageTemplate)) + CheckResult(input.seedNr, input.msgLength, signarure) + }) + ) + + counter.updateAndGet(_ + batch.size) + checkTask(nextBatch()) + } + + override def run(): Unit = { + checkTask(nextBatch()) + latch.countDown() + } + } + + def provider[A <: Curve25519Provider: ClassTag]: A = { + val ctor = implicitly[ClassTag[A]].runtimeClass.getDeclaredConstructor() + ctor.setAccessible(true) + ctor.newInstance().asInstanceOf[A] + } + + private val multiplier = 0x5DEECE66DL + private val addend = 0xBL + private val mask = (1L << 48) - 1 + private val MaxMessageLength = 150 * 1024 + + private val NativeSignerJavaVerifier = 1 << 0 + private val JavaSignerNativeVerifier = 1 << 1 + private val NativeSignerJavaVerifierAlteredMessage = 1 << 2 + private val JavaSignerNativeVerifierAlteredMessage = 1 << 3 + + def mkAccountSeed(randomSeed: Long)(seqNr: Int): Array[Byte] = { + val nv = (seqNr * multiplier + addend) & mask + val value = nv << 32 | nv + Bytes.concat( + Longs.toByteArray(value ^ (randomSeed & 0xF000F000F000F000L)), + Longs.toByteArray(value ^ (randomSeed & 0x0F000F000F000F00L)), + Longs.toByteArray(value ^ (randomSeed & 0x00F000F000F000F0L)), + Longs.toByteArray(value ^ (randomSeed & 0x000F000F000F000FL)) + ) + } + + private def signAndCheck( + seed: Array[Byte], + message: Array[Byte], + nativeProvider: Curve25519Provider, + javaProvider: Curve25519Provider + ): Array[Byte] = { + val privateKey = nativeProvider.generatePrivateKey(seed) + val publicKey = javaProvider.generatePublicKey(privateKey) + val random = new Array[Byte](64) + ThreadLocalRandom.current().nextBytes(random) + + val nativeSignature = nativeProvider.calculateSignature(random, privateKey, message) + val javaSignature = javaProvider.calculateSignature(random, privateKey, message) + var result = 0 + + if (!util.Arrays.equals(nativeSignature, javaSignature)) {} + + if (!javaProvider.verifySignature(publicKey, message, nativeSignature)) { + logger.error(s"NSJV: pk=${toHex(publicKey)},msg=${toHex(message)},ns=${toHex(nativeSignature)}") + result |= NativeSignerJavaVerifier + } + + if (!nativeProvider.verifySignature(publicKey, message, javaSignature)) { + logger.error(s"JSNV: pk=${toHex(publicKey)},msg=${toHex(message)},js=${toHex(javaSignature)}") + result |= JavaSignerNativeVerifier + } + + val alteredMessage = util.Arrays.copyOf(message, message.length) + + if (alteredMessage(0) == 0) alteredMessage(0) = 1.toByte else alteredMessage(0) = 0 + + if (nativeProvider.verifySignature(publicKey, alteredMessage, javaSignature)) { + logger.error(s"MMNJ: pk=${toHex(publicKey)},msg=${toHex(alteredMessage)},js=${toHex(javaSignature)}") + result |= JavaSignerNativeVerifierAlteredMessage + } + + if (javaProvider.verifySignature(publicKey, alteredMessage, nativeSignature)) { + logger.error(s"MMJN: pk=${toHex(publicKey)},msg=${toHex(alteredMessage)},ns=${toHex(nativeSignature)}") + result |= NativeSignerJavaVerifierAlteredMessage + } + + if (result != 0) { + logger.error(s"""MISMATCH ($result): + |seed=${toHex(seed)}, message=${toHex(message)} + | sk=${toHex(privateKey)} + | pk=${toHex(publicKey)} + |native_sig=${toHex(nativeSignature)} + | java_sig=${toHex(javaSignature)}""".stripMargin) + } + + nativeSignature + } + + private def signMessage(seed: Array[Byte], message: Array[Byte], provider: Curve25519Provider): Array[Byte] = { + val privateKey = provider.generatePrivateKey(seed) + val random = new Array[Byte](64) + ThreadLocalRandom.current().nextBytes(random) + + provider.calculateSignature(random, privateKey, message) + } + + case class CheckResult(seedNr: Int, messageNr: Int, signature: Array[Byte]) + + case class Input(maxSeedNr: Int, seedNr: Int, msgLength: Int) + + def iter(modulus: Int, startWith: Int = 0): Iterator[Input] = + Iterator.from(startWith).flatMap { maxSeedNr => + (0 until modulus * (maxSeedNr + 1)).iterator.map { msgLength => + Input(maxSeedNr, maxSeedNr - msgLength / modulus, msgLength) + } + } + + private val codec = BaseEncoding.base16().lowerCase() + + private def toHex(bytes: Array[Byte]): String = codec.encode(bytes) + + private val HexPattern = "0x([0-9A-Fa-f]+)[Ll]*".r + + private val workerCount: Int = Runtime.getRuntime.availableProcessors() + + def main(args: Array[String]): Unit = { + val nativeProvider = provider[NativeCurve25519Provider] + val javaProvider = provider[JavaCurve25519Provider] + val latch = new CountDownLatch(workerCount) + + args(0).toLowerCase match { + case "verify" => + val in = new DataInputStream(new BufferedInputStream(new FileInputStream(args(1)), 1024 * 1024)) + val randomSeed = in.readLong() + val msgTemplate = mkMessageTemplate(randomSeed) + val reader = new SignatureDataReader(in) + + @tailrec def verifySignatures(): Unit = { + val batch = reader.nextBatch(1000) + if (batch.isEmpty) { + logger.info(s"No more signatures to verify") + } else { + for ((seedNr, msgLength, signature) <- batch) { + val seed = mkAccountSeed(randomSeed)(seedNr) + val message = mkMsg(msgLength, msgTemplate) + if (!nativeProvider.verifySignature(nativeProvider.generatePublicKey(nativeProvider.generatePrivateKey(seed)), message, signature)) { + logger.error(s"Mismatch: $seedNr, $msgLength") + } + } + verifySignatures() + } + } + + (1 to workerCount) foreach { i => + logger.info(s"Starting worker $i") + val t = new Thread({ () => + verifySignatures() + latch.countDown() + }) + t.setDaemon(true) + t.setName(s"worker-$i") + t.start() + } + + latch.await() + logger.info("All signatures are valid") + + case action @ ("check" | "generate") => + val randomSeed = args(1) match { + case HexPattern(n) => java.lang.Long.parseLong(n, 16) + case other => other.toLong + } + + val startWith = args(2).toInt + val modulus = args(3).toInt + + val count = args.lift(5).map(_.toInt) + logger.info(s"Item count: $count") + val dispatcher = new Dispatcher(startWith, modulus, count) + val counter = new AtomicLong() + + val maybeOut = args.lift(4).map { outputFileName => + val queue = new LinkedBlockingQueue[Seq[CheckResult]]() + val out = new CountingOutputStream(new BufferedOutputStream(new FileOutputStream(new File(outputFileName)), 100000)) + val dataStream = new DataOutputStream(out) + dataStream.writeLong(randomSeed) + dataStream.flush() + val writer = new Thread(new Writer(queue, latch, dataStream)) + writer.setName("writer") + writer.start() + + (out, writer, queue) + } + + def newWorker() = action match { + case "check" => + new Worker(latch, maybeOut.map(_._3), counter, randomSeed, () => dispatcher.nextBatch(1000)) { + override def sign(seed: Array[Byte], message: Array[Byte]): Array[Byte] = signAndCheck(seed, message, nativeProvider, javaProvider) + } + case "generate" => + new Worker(latch, maybeOut.map(_._3), counter, randomSeed, () => dispatcher.nextBatch(1000)) { + override def sign(seed: Array[Byte], message: Array[Byte]): Array[Byte] = signMessage(seed, message, nativeProvider) + } + } + + (1 to workerCount) foreach { i => + logger.info(s"Starting worker $i") + val t = new Thread(newWorker()) + t.setDaemon(true) + t.setName(s"worker-$i") + t.start() + } + + Scheduler.global.scheduleWithFixedDelay(1.minute, 1.minute)( + logger.info( + s"Max seed nr: ${dispatcher.maxSeedNr}. Checked ${counter.get()} values${maybeOut.fold("")(c => s", written ${c._1.getCount} bytes")}" + ) + ) + latch.await() + + maybeOut.foreach { + case (out, writer, _) => + writer.join() + out.flush() + out.close() + } + } + } +} diff --git a/lang/shared/src/main/scala/com/wavesplatform/lang/v1/CTX.scala b/lang/shared/src/main/scala/com/wavesplatform/lang/v1/CTX.scala index 32c4f7cc98e..24b2c2159c3 100644 --- a/lang/shared/src/main/scala/com/wavesplatform/lang/v1/CTX.scala +++ b/lang/shared/src/main/scala/com/wavesplatform/lang/v1/CTX.scala @@ -20,18 +20,20 @@ case class CTX[C[_[_]]]( @(JSExport @field) vars: Map[String, (FINAL, ContextfulVal[C])], @(JSExport @field) functions: Array[BaseFunction[C]] ) { - lazy val typeDefs = types.map(t => t.name -> t).toMap + lazy val typeDefs = types.view.map(t => t.name -> t).toMap + lazy val functionMap = functions.view.map(f => f.header -> f).toMap def evaluationContext[F[_]: Monad](env: C[F]): EvaluationContext[C, F] = { - if (functions.map(_.header).distinct.length != functions.length) { + + if (functionMap.size != functions.length) { val dups = functions.groupBy(_.header).filter(_._2.length != 1) throw new Exception(s"Duplicate runtime functions names: $dups") } EvaluationContext( env, typeDefs, - vars.view.mapValues(v => LazyVal.fromEval(v._2(env))).toMap, - functions.map(f => f.header -> f).toMap + vars.map { case (k, v) => k -> LazyVal.fromEval(v._2(env)) }, + functionMap ) } diff --git a/lang/shared/src/main/scala/com/wavesplatform/lang/v1/evaluator/ctx/impl/CryptoContext.scala b/lang/shared/src/main/scala/com/wavesplatform/lang/v1/evaluator/ctx/impl/CryptoContext.scala index 55477ad9163..8511df0bfb6 100644 --- a/lang/shared/src/main/scala/com/wavesplatform/lang/v1/evaluator/ctx/impl/CryptoContext.scala +++ b/lang/shared/src/main/scala/com/wavesplatform/lang/v1/evaluator/ctx/impl/CryptoContext.scala @@ -16,6 +16,8 @@ import com.wavesplatform.lang.v1.evaluator.ctx.impl.crypto.RSA.DigestAlgorithm import com.wavesplatform.lang.v1.evaluator.ctx.{BaseFunction, EvaluationContext, NativeFunction} import com.wavesplatform.lang.v1.{BaseGlobal, CTX} +import scala.collection.mutable + object CryptoContext { private val rsaTypeNames = List("NoAlg", "Md5", "Sha1", "Sha224", "Sha256", "Sha384", "Sha512", "Sha3224", "Sha3256", "Sha3384", "Sha3512") @@ -25,7 +27,7 @@ object CryptoContext { } private def digestAlgorithmType(v: StdLibVersion) = - UNION.create(rsaHashAlgs(v), (if(v > V3) { Some("RsaDigestAlgs") } else { None })) + UNION.create(rsaHashAlgs(v), if (v > V3) Some("RsaDigestAlgs") else None) private val rsaHashLib = { import com.wavesplatform.lang.v1.evaluator.ctx.impl.crypto.RSA._ @@ -39,7 +41,14 @@ object CryptoContext { private def digestAlgValue(tpe: CASETYPEREF): ContextfulVal[NoContext] = ContextfulVal.pure(CaseObj(tpe, Map.empty)) - def build(global: BaseGlobal, version: StdLibVersion): CTX[NoContext] = { + def build(global: BaseGlobal, version: StdLibVersion): CTX[NoContext] = + ctxCache.getOrElse((global, version), ctxCache.synchronized { + ctxCache.getOrElseUpdate((global, version), buildNew(global, version)) + }) + + private val ctxCache = mutable.AnyRefMap.empty[(BaseGlobal, StdLibVersion), CTX[NoContext]] + + private def buildNew(global: BaseGlobal, version: StdLibVersion): CTX[NoContext] = { def lgen( lim: Array[Int], name: ((Int, Int)) => (String, Short), @@ -86,8 +95,9 @@ object CryptoContext { (n => (s"${name}_${n._1}Kb", (internalName + n._2).toShort)), costs, (n => { - case CONST_BYTESTR(msg: ByteStr) :: _ => Either.cond(msg.size <= n * 1024, (), s"Invalid message size = ${msg.size} bytes, must be not greater than $n KB") - case xs => notImplemented[Id, Unit](s"${name}_${n}Kb(bytes: ByteVector)", xs) + case CONST_BYTESTR(msg: ByteStr) :: _ => + Either.cond(msg.size <= n * 1024, (), s"Invalid message size = ${msg.size} bytes, must be not greater than $n KB") + case xs => notImplemented[Id, Unit](s"${name}_${n}Kb(bytes: ByteVector)", xs) }), BYTESTR, ("bytes", BYTESTR) @@ -459,13 +469,13 @@ object CryptoContext { val rsaVarNames = List("NOALG", "MD5", "SHA1", "SHA224", "SHA256", "SHA384", "SHA512", "SHA3224", "SHA3256", "SHA3384", "SHA3512") val v4RsaDig = rsaHashAlgs(V4) - val v4Types = v4RsaDig :+ digestAlgorithmType(V4) + val v4Types = v4RsaDig :+ digestAlgorithmType(V4) val v4Vars: Map[String, (FINAL, ContextfulVal[NoContext])] = rsaVarNames.zip(v4RsaDig.map(t => (t, digestAlgValue(t)))).toMap val v3RsaDig = rsaHashAlgs(V3) - val v3Types = v3RsaDig :+ digestAlgorithmType(V3) + val v3Types = v3RsaDig :+ digestAlgorithmType(V3) val v3Vars: Map[String, (FINAL, ContextfulVal[NoContext])] = rsaVarNames.zip(v3RsaDig.map(t => (t, digestAlgValue(t)))).toMap @@ -482,7 +492,8 @@ object CryptoContext { Array( bls12Groth16VerifyF, bn256Groth16VerifyF, - createMerkleRootF, ecrecover,// new in V4 + createMerkleRootF, + ecrecover, // new in V4 rsaVerifyF, toBase16StringF(checkLength = true), fromBase16StringF(checkLength = true) // from V3 diff --git a/lang/shared/src/main/scala/com/wavesplatform/lang/v1/evaluator/ctx/impl/waves/WavesContext.scala b/lang/shared/src/main/scala/com/wavesplatform/lang/v1/evaluator/ctx/impl/waves/WavesContext.scala index a589d31dd74..39d8be3baad 100644 --- a/lang/shared/src/main/scala/com/wavesplatform/lang/v1/evaluator/ctx/impl/waves/WavesContext.scala +++ b/lang/shared/src/main/scala/com/wavesplatform/lang/v1/evaluator/ctx/impl/waves/WavesContext.scala @@ -1,7 +1,6 @@ package com.wavesplatform.lang.v1.evaluator.ctx.impl.waves import cats.implicits._ -import com.wavesplatform.common.utils.EitherExt2 import com.wavesplatform.lang.directives.values._ import com.wavesplatform.lang.directives.{DirectiveDictionary, DirectiveSet} import com.wavesplatform.lang.v1.CTX @@ -48,16 +47,11 @@ object WavesContext { private val variableCtxCache: Map[DirectiveSet, CTX[Environment]] = allDirectives - .filter(_.isRight) - .map(_.explicitGet()) - .map(ds => (ds, variableCtx(ds))) + .collect { case Right(ds) => (ds, variableCtx(ds)) } .toMap private def variableCtx(ds: DirectiveSet): CTX[Environment] = { - val isTokenContext = ds.scriptType match { - case Account => false - case Asset => true - } + val isTokenContext = ds.scriptType == Asset val proofsEnabled = !isTokenContext val version = ds.stdLibVersion CTX( diff --git a/node/src/main/scala/com/wavesplatform/Exporter.scala b/node/src/main/scala/com/wavesplatform/Exporter.scala index 85a736ee33b..6fb2e5a98c6 100644 --- a/node/src/main/scala/com/wavesplatform/Exporter.scala +++ b/node/src/main/scala/com/wavesplatform/Exporter.scala @@ -55,7 +55,7 @@ object Exporter extends ScorexLogging { IO.createOutputStream(outputFilename) match { case Success(output) => var exportedBytes = 0L - val bos = new BufferedOutputStream(output) + val bos = new BufferedOutputStream(output, 10 * 1024 * 1024) val start = System.currentTimeMillis() exportedBytes += IO.writeHeader(bos, format) (2 to height).foreach { h => diff --git a/node/src/main/scala/com/wavesplatform/crypto/package.scala b/node/src/main/scala/com/wavesplatform/crypto/package.scala index 49b8b5aba4f..d3b72db7699 100644 --- a/node/src/main/scala/com/wavesplatform/crypto/package.scala +++ b/node/src/main/scala/com/wavesplatform/crypto/package.scala @@ -13,7 +13,7 @@ import scorex.crypto.signatures.{Curve25519, Signature, PrivateKey => SPrivateKe import scala.util.Try -package object crypto { +package object crypto extends ScorexLogging { // Constants val SignatureLength: Int = Curve25519.SignatureLength val KeyLength: Int = Curve25519.KeyLength @@ -24,7 +24,9 @@ package object crypto { val constructor = classOf[OpportunisticCurve25519Provider].getDeclaredConstructors.head .asInstanceOf[Constructor[OpportunisticCurve25519Provider]] constructor.setAccessible(true) - constructor.newInstance() + val p = constructor.newInstance() + log.info(s"Native provider used: ${p.isNative}") + p } // Digests diff --git a/node/src/main/scala/com/wavesplatform/database/package.scala b/node/src/main/scala/com/wavesplatform/database/package.scala index 732ad4f5085..b3c9c448aa7 100644 --- a/node/src/main/scala/com/wavesplatform/database/package.scala +++ b/node/src/main/scala/com/wavesplatform/database/package.scala @@ -361,11 +361,11 @@ package object database extends ScorexLogging { } def readStateHash(bs: Array[Byte]): StateHash = { - val ndi = newDataInput(bs) + val ndi = newDataInput(bs) val sectionsCount = ndi.readByte() val sections = (0 until sectionsCount).map { _ => val sectionId = ndi.readByte() - val value = ndi.readByteStr(DigestLength) + val value = ndi.readByteStr(DigestLength) SectionId(sectionId) -> value } val totalHash = ndi.readByteStr(DigestLength) @@ -374,11 +374,12 @@ package object database extends ScorexLogging { def writeStateHash(sh: StateHash): Array[Byte] = { val sorted = sh.sectionHashes.toSeq.sortBy(_._1) - val ndo = newDataOutput(crypto.DigestLength + 1 + sorted.length * (1 + crypto.DigestLength)) + val ndo = newDataOutput(crypto.DigestLength + 1 + sorted.length * (1 + crypto.DigestLength)) ndo.writeByte(sorted.length) - sorted.foreach { case (sectionId, value) => - ndo.writeByte(sectionId.id.toByte) - ndo.writeByteStr(value.ensuring(_.arr.length == DigestLength)) + sorted.foreach { + case (sectionId, value) => + ndo.writeByte(sectionId.id.toByte) + ndo.writeByteStr(value.ensuring(_.arr.length == DigestLength)) } ndo.writeByteStr(sh.totalHash.ensuring(_.arr.length == DigestLength)) ndo.toByteArray @@ -555,18 +556,19 @@ package object database extends ScorexLogging { } def loadTransactions(height: Height, db: ReadOnlyDB): Option[Seq[(Transaction, Boolean)]] = - for { - meta <- db.get(Keys.blockMetaAt(height)) - } yield (0 until meta.transactionCount).toList.flatMap { n => - db.get(Keys.transactionAt(height, TxNum(n.toShort))) + if (height < 1 || db.get(Keys.height) < height) None + else { + val transactions = Seq.newBuilder[(Transaction, Boolean)] + db.iterateOver(KeyTags.NthTransactionInfoAtHeight.prefixBytes ++ Ints.toByteArray(height)) { e => + transactions += readTransaction(e.getValue) + } + Some(transactions.result()) } def loadBlock(height: Height, db: ReadOnlyDB): Option[Block] = for { - meta <- db.get(Keys.blockMetaAt(height)) - txs = (0 until meta.transactionCount).toList.flatMap { n => - db.get(Keys.transactionAt(height, TxNum(n.toShort))) - } + meta <- db.get(Keys.blockMetaAt(height)) + txs <- loadTransactions(height, db) block <- createBlock(meta.header, meta.signature, txs.map(_._1)).toOption } yield block diff --git a/node/src/main/scala/com/wavesplatform/state/Diff.scala b/node/src/main/scala/com/wavesplatform/state/Diff.scala index e67d1661ff2..d057249defe 100755 --- a/node/src/main/scala/com/wavesplatform/state/Diff.scala +++ b/node/src/main/scala/com/wavesplatform/state/Diff.scala @@ -13,8 +13,7 @@ import com.wavesplatform.transaction.Asset.IssuedAsset import com.wavesplatform.transaction.{Asset, Transaction} import play.api.libs.json._ -import scala.collection.mutable -import scala.collection.mutable.LinkedHashMap +import scala.collection.immutable.VectorMap case class LeaseBalance(in: Long, out: Long) @@ -179,7 +178,7 @@ object Diff { scriptsRun: Int = 0 ): Diff = Diff( - transactions = mutable.LinkedHashMap(), + transactions = VectorMap.empty, portfolios = portfolios, issuedAssets = issuedAssets, updatedAssets = updatedAssets, @@ -213,7 +212,7 @@ object Diff { ): Diff = Diff( // should be changed to VectorMap after 2.13 https://github.com/scala/scala/pull/6854 - transactions = LinkedHashMap(toDiffTxData(tx, portfolios, accountData)), + transactions = VectorMap(toDiffTxData(tx, portfolios, accountData)), portfolios = portfolios, issuedAssets = issuedAssets, updatedAssets = updatedAssets, @@ -238,7 +237,7 @@ object Diff { val empty = new Diff( - LinkedHashMap(), + VectorMap.empty, Map.empty, Map.empty, Map.empty, diff --git a/node/src/main/scala/com/wavesplatform/state/diffs/BlockDiffer.scala b/node/src/main/scala/com/wavesplatform/state/diffs/BlockDiffer.scala index ba64016a383..28cf0088d1f 100644 --- a/node/src/main/scala/com/wavesplatform/state/diffs/BlockDiffer.scala +++ b/node/src/main/scala/com/wavesplatform/state/diffs/BlockDiffer.scala @@ -17,7 +17,7 @@ import com.wavesplatform.transaction.{Asset, Transaction} import com.wavesplatform.utils.ScorexLogging object BlockDiffer extends ScorexLogging { - final case class DetailedDiff(parentDiff: Diff, transactionDiffs: Seq[Diff]) + final case class DetailedDiff(parentDiff: Diff, transactionDiffs: List[Diff]) final case class Result(diff: Diff, carry: Long, totalFee: Long, constraint: MiningConstraint, detailedDiff: DetailedDiff) case class Fraction(dividend: Int, divider: Int) { @@ -151,7 +151,7 @@ object BlockDiffer extends ScorexLogging { val hasSponsorship = currentBlockHeight >= Sponsorship.sponsoredFeesSwitchHeight(blockchain) txs - .foldLeft(TracedResult(Result(initDiff, 0L, 0L, initConstraint, DetailedDiff(initDiff, Seq.empty)).asRight[ValidationError])) { + .foldLeft(TracedResult(Result(initDiff, 0L, 0L, initConstraint, DetailedDiff(initDiff, Nil)).asRight[ValidationError])) { case (acc @ TracedResult(Left(_), _), _) => acc case (TracedResult(Right(Result(currDiff, carryFee, currTotalFee, currConstraint, DetailedDiff(parentDiff, txDiffs))), _), tx) => val currBlockchain = CompositeBlockchain(blockchain, Some(currDiff)) @@ -181,7 +181,7 @@ object BlockDiffer extends ScorexLogging { carryFee + carry, totalWavesFee, updatedConstraint, - DetailedDiff(parentDiff.combine(minerDiff), txDiffs :+ thisTxDiff) + DetailedDiff(parentDiff.combine(minerDiff), thisTxDiff :: txDiffs) ) ) } diff --git a/node/src/main/scala/com/wavesplatform/transaction/smart/BlockchainContext.scala b/node/src/main/scala/com/wavesplatform/transaction/smart/BlockchainContext.scala index 32153ef3b7f..2ca24f8fab0 100644 --- a/node/src/main/scala/com/wavesplatform/transaction/smart/BlockchainContext.scala +++ b/node/src/main/scala/com/wavesplatform/transaction/smart/BlockchainContext.scala @@ -1,10 +1,13 @@ package com.wavesplatform.transaction.smart +import java.util + import cats.Id import cats.implicits._ import com.wavesplatform.common.state.ByteStr import com.wavesplatform.lang.directives.DirectiveSet import com.wavesplatform.lang.directives.values.{ContentType, ScriptType, StdLibVersion} +import com.wavesplatform.lang.v1.CTX import com.wavesplatform.lang.v1.evaluator.ctx.EvaluationContext import com.wavesplatform.lang.v1.evaluator.ctx.impl.waves.WavesContext import com.wavesplatform.lang.v1.evaluator.ctx.impl.{CryptoContext, PureContext} @@ -16,26 +19,36 @@ import monix.eval.Coeval object BlockchainContext { type In = WavesEnvironment.In - def build(version: StdLibVersion, - nByte: Byte, - in: Coeval[Environment.InputEntity], - h: Coeval[Int], - blockchain: Blockchain, - isTokenContext: Boolean, - isContract: Boolean, - address: Environment.Tthis, - txId: ByteStr): Either[ExecutionError, EvaluationContext[Environment, Id]] = { + + private[this] val cache = new util.HashMap[(StdLibVersion, DirectiveSet), CTX[Environment]]() + + def build( + version: StdLibVersion, + nByte: Byte, + in: Coeval[Environment.InputEntity], + h: Coeval[Int], + blockchain: Blockchain, + isTokenContext: Boolean, + isContract: Boolean, + address: Environment.Tthis, + txId: ByteStr + ): Either[ExecutionError, EvaluationContext[Environment, Id]] = { DirectiveSet( version, ScriptType.isAssetScript(isTokenContext), ContentType.isDApp(isContract) ).map { ds => - val ctx = - PureContext.build(version).withEnvironment[Environment] |+| - CryptoContext.build(Global, version).withEnvironment[Environment] |+| - WavesContext.build(ds) - - ctx.evaluationContext(new WavesEnvironment(nByte, in, h, blockchain, address, ds, txId)) + cache + .synchronized( + cache.computeIfAbsent( + (version, ds), { _ => + PureContext.build(version).withEnvironment[Environment] |+| + CryptoContext.build(Global, version).withEnvironment[Environment] |+| + WavesContext.build(ds) + } + ) + ) + .evaluationContext(new WavesEnvironment(nByte, in, h, blockchain, address, ds, txId)) } } } diff --git a/node/src/test/scala/com/wavesplatform/state/diffs/BlockDifferDetailedDiffTest.scala b/node/src/test/scala/com/wavesplatform/state/diffs/BlockDifferDetailedDiffTest.scala index b5d297bf5db..f888541bb25 100644 --- a/node/src/test/scala/com/wavesplatform/state/diffs/BlockDifferDetailedDiffTest.scala +++ b/node/src/test/scala/com/wavesplatform/state/diffs/BlockDifferDetailedDiffTest.scala @@ -85,7 +85,8 @@ class BlockDifferDetailedDiffTest extends FreeSpec with Matchers with PropertyCh forAll(genesisTransfersBlockGen) { case (addr1, addr2, amt1, amt2, b) => assertDetailedDiff(Seq.empty, b) { - case (_, DetailedDiff(_, transactionDiffs)) => + case (_, DetailedDiff(_, td)) => + val transactionDiffs = td.reverse transactionDiffs.head.portfolios(addr1).balance shouldBe ENOUGH_AMT transactionDiffs(1).portfolios(addr1).balance shouldBe -(amt1 + transactionFee) transactionDiffs(1).portfolios(addr2).balance shouldBe amt1 diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 2275d07aa14..2c4c8b35f77 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -14,22 +14,23 @@ object Dependencies { private def bouncyCastle(module: String) = "org.bouncycastle" % s"$module-jdk15on" % "1.59" private def catsModule(module: String, version: String = "2.1.0") = Def.setting("org.typelevel" %%% s"cats-$module" % version) - private def monixModule(module: String) = Def.setting("io.monix" %%% s"monix-$module" % "3.2.1") + def monixModule(module: String) = Def.setting("io.monix" %%% s"monix-$module" % "3.3.0") private val kindProjector = compilerPlugin("org.typelevel" % "kind-projector" % "0.11.0" cross CrossVersion.full) - val akkaHttp = akkaHttpModule("akka-http") - private val jacksonModuleScala = jacksonModule("module", "module-scala").withCrossVersion(CrossVersion.Binary()) - private val googleGuava = "com.google.guava" % "guava" % "27.0.1-jre" - private val kamonCore = kamonModule("core") - private val machinist = "org.typelevel" %% "machinist" % "0.6.8" - val logback = "ch.qos.logback" % "logback-classic" % "1.2.3" - val janino = "org.codehaus.janino" % "janino" % "3.0.12" - val asyncHttpClient = "org.asynchttpclient" % "async-http-client" % "2.7.0" + val akkaHttp = akkaHttpModule("akka-http") + val jacksonModuleScala = jacksonModule("module", "module-scala").withCrossVersion(CrossVersion.Binary()) + val googleGuava = "com.google.guava" % "guava" % "27.0.1-jre" + val kamonCore = kamonModule("core") + val machinist = "org.typelevel" %% "machinist" % "0.6.8" + val logback = "ch.qos.logback" % "logback-classic" % "1.2.3" + val janino = "org.codehaus.janino" % "janino" % "3.0.12" + val asyncHttpClient = "org.asynchttpclient" % "async-http-client" % "2.7.0" + val curve25519 = "com.wavesplatform" % "curve25519-java" % "0.6.4" - private val catsEffect = catsModule("effect", "2.1.3") - private val catsCore = catsModule("core") - private val shapeless = Def.setting("com.chuusai" %%% "shapeless" % "2.3.3") + val catsEffect = catsModule("effect", "2.1.3") + val catsCore = catsModule("core") + val shapeless = Def.setting("com.chuusai" %%% "shapeless" % "2.3.3") val scalaTest = "org.scalatest" %% "scalatest" % "3.0.8" % Test @@ -96,7 +97,7 @@ object Dependencies { ("org.typelevel" %% "cats-mtl-core" % "0.7.1").exclude("org.scalacheck", "scalacheck_2.13"), "ch.obermuhlner" % "big-math" % "2.1.0", ("org.scorexfoundation" %% "scrypto" % "2.1.8").exclude("org.whispersystems", "curve25519-java"), - "com.wavesplatform" % "curve25519-java" % "0.6.3", + curve25519, bouncyCastle("bcpkix"), bouncyCastle("bcprov"), kindProjector, @@ -164,7 +165,7 @@ object Dependencies { ) private[this] val protoSchemasLib = - "com.wavesplatform" % "protobuf-schemas" % "1.2.9-N2265-SNAPSHOT" classifier "proto" intransitive() + "com.wavesplatform" % "protobuf-schemas" % "1.2.9-N2265-SNAPSHOT" classifier "proto" intransitive () lazy val scalapbRuntime = Def.setting { val version = scalapb.compiler.Version.scalapbVersion