Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(ATL-6924): add support for other keys #834

Merged
merged 5 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,6 @@ lazy val Dependencies = new {

// We have to exclude bouncycastle since for some reason bitcoinj depends on bouncycastle jdk15to18
// (i.e. JDK 1.5 to 1.8), but we are using JDK 11
val prismCrypto =
"io.iohk.atala" % "prism-crypto-jvm" % versions.prismSdk
val prismIdentity =
"io.iohk.atala" % "prism-identity-jvm" % versions.prismSdk

Expand Down Expand Up @@ -154,7 +152,7 @@ lazy val Dependencies = new {
val sttpDependencies = Seq(sttpCore, sttpCE2)
val tofuDependencies = Seq(tofu, tofuLogging, tofuDerevoTagless)
val prismDependencies =
Seq(prismCrypto, prismIdentity)
Seq(prismIdentity)
val scalapbDependencies = Seq(
"com.thesamet.scalapb" %% "scalapb-runtime" % scalapb.compiler.Version.scalapbVersion % "protobuf",
"com.thesamet.scalapb" %% "scalapb-runtime-grpc" % scalapb.compiler.Version.scalapbVersion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class NodeGrpcServiceImpl(
node_api
.GetDidDocumentResponse(document = didData.maybeData)
.withLastUpdateOperation(
didData.maybeOperation.map(a => ByteString.copyFrom(a.getValue)).getOrElse(ByteString.EMPTY)
didData.maybeOperation.map(a => ByteString.copyFrom(a.bytes.toArray)).getOrElse(ByteString.EMPTY)
)
.withLastSyncedBlockTimestamp(didData.lastSyncedTimeStamp.toProtoTimestamp)
)
Expand Down
133 changes: 133 additions & 0 deletions node/src/main/scala/io/iohk/atala/prism/node/crypto/CryptoUtils.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package io.iohk.atala.prism.node.crypto

import io.iohk.atala.prism.node.models.ProtocolConstants
import org.bouncycastle.jcajce.provider.asymmetric.util.EC5Util
import org.bouncycastle.jce.interfaces.ECPublicKey

import java.security.{KeyFactory, MessageDigest, PublicKey, Security, Signature}
import org.bouncycastle.jce.{ECNamedCurveTable, ECPointUtil}
import org.bouncycastle.jce.provider.BouncyCastleProvider

import java.security.spec.{ECPoint, ECPublicKeySpec}

object CryptoUtils {
trait SecpPublicKey {
private[crypto] def publicKey: PublicKey
def curveName: String = ProtocolConstants.secpCurveName
def compressed: Array[Byte] = publicKey
.asInstanceOf[ECPublicKey]
.getQ
.getEncoded(true)
def x: Array[Byte] = publicKey.asInstanceOf[ECPublicKey].getQ.getAffineXCoord.getEncoded
def y: Array[Byte] = publicKey.asInstanceOf[ECPublicKey].getQ.getAffineYCoord.getEncoded
}

private[crypto] class SecpPublicKeyImpl(pubKey: PublicKey) extends SecpPublicKey {
override private[crypto] def publicKey: PublicKey = pubKey
}

// We define the constructor to SecpKeys private so that the only way to generate
// these keys is by using the methods unsafeToPublicKeyFromByteCoordinates and
// unsafeToPublicKeyFromCompressed.
object SecpPublicKey {

private[crypto] def fromPublicKey(key: PublicKey): SecpPublicKey = new SecpPublicKeyImpl(key)

def checkECDSASignature(msg: Array[Byte], sig: Array[Byte], pubKey: SecpPublicKey): Boolean = {
val ecdsaVerify = Signature.getInstance("SHA256withECDSA", provider)
ecdsaVerify.initVerify(pubKey.publicKey)
ecdsaVerify.update(msg)
ecdsaVerify.verify(sig)
}

def unsafeToSecpPublicKeyFromCompressed(com: Vector[Byte]): SecpPublicKey = {
val params = ECNamedCurveTable.getParameterSpec("secp256k1")
val fact = KeyFactory.getInstance("ECDSA", provider)
val curve = params.getCurve
val ellipticCurve = EC5Util.convertCurve(curve, params.getSeed)
val point = ECPointUtil.decodePoint(ellipticCurve, com.toArray)
val params2 = EC5Util.convertSpec(ellipticCurve, params)
val keySpec = new ECPublicKeySpec(point, params2)
SecpPublicKey.fromPublicKey(fact.generatePublic(keySpec))
}

def unsafeToSecpPublicKeyFromByteCoordinates(x: Array[Byte], y: Array[Byte]): SecpPublicKey = {
def trimLeadingZeroes(arr: Array[Byte], c: String): Array[Byte] = {
val trimmed = arr.dropWhile(_ == 0.toByte)
require(
trimmed.length <= PUBLIC_KEY_COORDINATE_BYTE_SIZE,
s"Expected $c coordinate byte length to be less than or equal ${PUBLIC_KEY_COORDINATE_BYTE_SIZE}, but got ${trimmed.length} bytes"
)
trimmed
}

val xTrimmed = trimLeadingZeroes(x, "x")
val yTrimmed = trimLeadingZeroes(y, "y")
val xInteger = BigInt(1, xTrimmed)
val yInteger = BigInt(1, yTrimmed)
SecpPublicKey.unsafeToSecpPublicKeyFromBigIntegerCoordinates(xInteger, yInteger)
}

def unsafeToSecpPublicKeyFromBigIntegerCoordinates(x: BigInt, y: BigInt): SecpPublicKey = {
val params = ECNamedCurveTable.getParameterSpec("secp256k1")
val fact = KeyFactory.getInstance("ECDSA", provider)
val curve = params.getCurve
val ellipticCurve = EC5Util.convertCurve(curve, params.getSeed)
val point = new ECPoint(x.bigInteger, y.bigInteger)
val params2 = EC5Util.convertSpec(ellipticCurve, params)
val keySpec = new ECPublicKeySpec(point, params2)
SecpPublicKey.fromPublicKey(fact.generatePublic(keySpec))
}
}

private val provider = new BouncyCastleProvider()
private val PUBLIC_KEY_COORDINATE_BYTE_SIZE: Int = 32

Security.addProvider(provider)

trait Sha256Hash {
def bytes: Vector[Byte]
def hexEncoded: String = bytesToHex(bytes)
}

private[crypto] case class Sha256HashImpl(bytes: Vector[Byte]) extends Sha256Hash {
require(bytes.size == 32)
}

object Sha256Hash {

def fromBytes(arr: Array[Byte]): Sha256Hash = Sha256HashImpl(arr.toVector)

def compute(bArray: Array[Byte]): Sha256Hash = {
Sha256HashImpl(
MessageDigest
.getInstance("SHA-256")
.digest(bArray)
.toVector
)
}

def fromHex(hexedBytes: String): Sha256Hash = {
val HEX_STRING_RE = "^[0-9a-fA-F]{64}$".r
if (HEX_STRING_RE.matches(hexedBytes)) Sha256HashImpl(hexToBytes(hexedBytes))
else
throw new IllegalArgumentException(
"The given hex string doesn't correspond to a valid SHA-256 hash encoded as string"
)
}
}

def bytesToHex(bytes: Vector[Byte]): String = {
bytes.map(byte => f"${byte & 0xff}%02x").mkString
}

def hexToBytes(hex: String): Vector[Byte] = {
val HEX_ARRAY = "0123456789abcdef".toCharArray
for {
pair <- hex.grouped(2).toVector
firstIndex = HEX_ARRAY.indexOf(pair(0))
secondIndex = HEX_ARRAY.indexOf(pair(1))
octet = firstIndex << 4 | secondIndex
} yield octet.toByte
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ import com.google.protobuf.ByteString
import io.iohk.atala.prism.protos.models.TimestampInfo
import io.iohk.atala.prism.crypto.EC.{INSTANCE => EC}
import io.iohk.atala.prism.crypto.keys.ECPublicKey
import io.iohk.atala.prism.crypto.ECConfig.{INSTANCE => ECConfig}
import io.iohk.atala.prism.node.models.{DidSuffix, Ledger}
import io.iohk.atala.prism.node.models.{DidSuffix, Ledger, PublicKeyData}
import io.iohk.atala.prism.protos.common_models
import io.iohk.atala.prism.node.models
import io.iohk.atala.prism.node.models.KeyUsage._
Expand Down Expand Up @@ -56,7 +55,7 @@ object ProtoCodecs {
didDataState.keys.map(key =>
toProtoPublicKey(
key.keyId,
toECKeyData(key.key),
toCompressedECKeyData(key.key),
toProtoKeyUsage(key.keyUsage),
toLedgerData(key.addedOn),
key.revokedOn map toLedgerData
Expand All @@ -83,28 +82,26 @@ object ProtoCodecs {

def toProtoPublicKey(
id: String,
ecKeyData: node_models.ECKeyData,
compressedEcKeyData: node_models.CompressedECKeyData,
Comment on lines -86 to +85
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as we now assume we always retrieve a compressed key, we request that the model to convert is a compressed one

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it apply to new key types and old key type as well? would that require a change in spec then?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO I wouldn't require change on the specs. Would just put in the spec that the compressed version is recommended.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this method is applied on data retrieved from the node DB, we store keys in compressed format, meaning that all keys we retrieve will be compressed indistinctly of how the user submitted it. So, no need for spec changes for this part

We do need to tell in the spec that Ed25519 and X25519 keys can only be sent in compressed format (which I understand is fine, please correct if I am wrong @FabioPinheiro ). Secp keys can be sent in any of the formats and the node will work fine

keyUsage: node_models.KeyUsage,
addedOn: node_models.LedgerData,
revokedOn: Option[node_models.LedgerData]
): node_models.PublicKey = {
val withoutRevKey = node_models
.PublicKey()
.withId(id)
.withEcKeyData(ecKeyData)
.withCompressedEcKeyData(compressedEcKeyData)
.withUsage(keyUsage)
.withAddedOn(addedOn)

revokedOn.fold(withoutRevKey)(revTime => withoutRevKey.withRevokedOn(revTime))
}

def toECKeyData(key: ECPublicKey): node_models.ECKeyData = {
val point = key.getCurvePoint
Comment on lines -101 to -102
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

before, the code assumed we had an ECPublicKey which could be decomposed into x and y coordinates. Now we assume it is always a compressed key

def toCompressedECKeyData(key: PublicKeyData): node_models.CompressedECKeyData = {
node_models
.ECKeyData()
.withCurve(ECConfig.getCURVE_NAME)
.withX(ByteString.copyFrom(point.getX.bytes()))
.withY(ByteString.copyFrom(point.getY.bytes()))
.CompressedECKeyData()
.withCurve(key.curveName)
.withData(ByteString.copyFrom(key.compressedKey.toArray))
}

def toProtoKeyUsage(keyUsage: models.KeyUsage): node_models.KeyUsage = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package io.iohk.atala.prism.node.interop

import cats.data.NonEmptyList
import doobie.{Get, Meta, Read, Write}
import io.iohk.atala.prism.crypto.{MerkleRoot, Sha256Digest}
import doobie.implicits.legacy.instant._
import io.iohk.atala.prism.node.crypto.CryptoUtils.Sha256Hash
import io.iohk.atala.prism.protos.models.TimestampInfo
import io.iohk.atala.prism.node.models.{DidSuffix, Ledger, TransactionId}
import io.iohk.atala.prism.node.utils.DoobieImplicits.byteArraySeqMeta
Expand All @@ -12,12 +12,6 @@ import java.time.Instant
import scala.collection.compat.immutable.ArraySeq

object implicits {
implicit val merkleRootMeta: Meta[MerkleRoot] =
Meta[Array[Byte]].timap(arr => new MerkleRoot(Sha256Digest.fromBytes(arr)))(
_.getHash.getValue
)
implicit val merkleRootRead: Read[MerkleRoot] =
Read[Array[Byte]].map(arr => new MerkleRoot(Sha256Digest.fromBytes(arr)))

implicit val didSuffixMeta: Meta[DidSuffix] =
Meta[String].timap { DidSuffix.apply }(_.value)
Expand All @@ -36,12 +30,12 @@ object implicits {
implicit val ledgerRead: Read[Ledger] =
Read[String].map { Ledger.withNameInsensitive }

implicit val Sha256DigestWrite: Write[Sha256Digest] =
Write[Array[Byte]].contramap(_.getValue)
implicit val Sha256DigestRead: Read[Sha256Digest] =
Read[Array[Byte]].map(Sha256Digest.fromBytes)
implicit val Sha256DigestGet: Get[Sha256Digest] =
Get[Array[Byte]].map(Sha256Digest.fromBytes)
implicit val Sha256DigestWrite: Write[Sha256Hash] =
Write[Array[Byte]].contramap(_.bytes.toArray)
implicit val Sha256HashRead: Read[Sha256Hash] =
Read[Array[Byte]].map(Sha256Hash.fromBytes)
implicit val Sha256HashGet: Get[Sha256Hash] =
Get[Array[Byte]].map(Sha256Hash.fromBytes)

implicit val timestampInfoRead: Read[TimestampInfo] =
Read[(Instant, Int, Int)].map { case (abt, absn, osn) =>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package io.iohk.atala.prism.node.models

import io.iohk.atala.prism.crypto.{Sha256, Sha256Digest}
import io.iohk.atala.prism.node.crypto.CryptoUtils.Sha256Hash
import io.iohk.atala.prism.protos.node_models
import io.iohk.atala.prism.node.utils.BytesOps
import tofu.logging.{DictLoggable, LogRenderer}
Expand All @@ -25,19 +25,17 @@ object AtalaObjectId {
}

def apply(value: Vector[Byte]): AtalaObjectId = {
// temporary replace for require(value.length == SHA256Digest.getBYTE_LENGTH)
// rewrite to safe version pls
// will throw an error if something is wrong with the value
val digestUnsafe = Sha256Digest.fromBytes(value.toArray).getValue
new AtalaObjectId(digestUnsafe.toVector)
// This will throw an error if something is wrong with the value
val digestUnsafe = Sha256Hash.fromBytes(value.toArray).bytes
new AtalaObjectId(digestUnsafe)
}

def of(atalaObject: node_models.AtalaObject): AtalaObjectId = {
of(atalaObject.toByteArray)
}

def of(bytes: Array[Byte]): AtalaObjectId = {
val hash = Sha256.compute(bytes)
AtalaObjectId(hash.getValue.toVector)
val hash = Sha256Hash.compute(bytes)
AtalaObjectId(hash.bytes)
}
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
package io.iohk.atala.prism.node.models

import com.google.protobuf.ByteString
import io.iohk.atala.prism.crypto.{Sha256, Sha256Digest}
import io.iohk.atala.prism.node.crypto.CryptoUtils.Sha256Hash
import io.iohk.atala.prism.protos.node_models
import tofu.logging.{DictLoggable, LogRenderer}
import io.iohk.atala.prism.node.utils.BytesOps

import java.util.UUID

class AtalaOperationId private (val digest: Sha256Digest) {
def value: Vector[Byte] = digest.getValue.toVector
class AtalaOperationId private (val digest: Sha256Hash) {
def value: Vector[Byte] = digest.bytes

def hexValue: String = BytesOps.bytesToHex(value)

Expand Down Expand Up @@ -40,22 +40,22 @@ object AtalaOperationId {
}

def of(atalaOperation: node_models.SignedAtalaOperation): AtalaOperationId = {
val hash = Sha256.compute(atalaOperation.toByteArray)
val hash = Sha256Hash.compute(atalaOperation.toByteArray)
new AtalaOperationId(hash)
}

def random(): AtalaOperationId = {
val hash = Sha256.compute(UUID.randomUUID().toString.getBytes())
val hash = Sha256Hash.compute(UUID.randomUUID().toString.getBytes())
new AtalaOperationId(hash)
}

def fromVectorUnsafe(bytes: Vector[Byte]): AtalaOperationId = {
val hash = Sha256Digest.fromBytes(bytes.toArray)
val hash = Sha256Hash.fromBytes(bytes.toArray)
new AtalaOperationId(hash)
}

def fromHexUnsafe(hex: String): AtalaOperationId = {
val hash = Sha256Digest.fromHex(hex)
val hash = Sha256Hash.fromHex(hex)
new AtalaOperationId(hash)
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package io.iohk.atala.prism.node.models

import io.iohk.atala.prism.crypto.Sha256Digest
import io.iohk.atala.prism.node.crypto.CryptoUtils.Sha256Hash

import scala.util.matching.Regex
import scala.util.{Failure, Success, Try}
Expand All @@ -15,7 +15,7 @@ object DidSuffix {

def didFromStringSuffix(in: String): String = "did:prism:" + in

def fromDigest(in: Sha256Digest): DidSuffix = DidSuffix(in.getHexValue)
def fromDigest(in: Sha256Hash): DidSuffix = DidSuffix(in.hexEncoded)

def fromString(string: String): Try[DidSuffix] = {
if (string.nonEmpty && suffixRegex.pattern.matcher(string).matches())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package io.iohk.atala.prism.node.models

import io.iohk.atala.prism.crypto.Sha256
import io.iohk.atala.prism.node.crypto.CryptoUtils.Sha256Hash

import scala.util.matching.Regex
import scala.util.{Failure, Success, Try}
Expand All @@ -26,5 +26,5 @@ object IdType {
)
}

def random: IdType = IdType(value = Sha256.compute(java.util.UUID.randomUUID.toString.getBytes).getHexValue)
def random: IdType = IdType(value = Sha256Hash.compute(java.util.UUID.randomUUID.toString.getBytes).hexEncoded)
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ object ProtocolConstants {
val contextStringCharLimit: Int =
Try(globalConfig.getInt("contextStringCharLimit")).toOption.getOrElse(defaultContextStringCharLength)

val supportedEllipticCurves: Seq[String] = List("secp256k1", "Ed25519", "X25519")
val secpCurveName = "secp256k1"
val ed25519CurveName = "Ed25519"
val x25519CurveName = "X25519"

val supportedEllipticCurves: Seq[String] = List(secpCurveName, ed25519CurveName, x25519CurveName)

}
Loading
Loading