Skip to content

Commit

Permalink
feat: Add utilities for extracting authentication credentials from gR…
Browse files Browse the repository at this point in the history
…PC calls (#288)
  • Loading branch information
SanjayVas authored Dec 16, 2024
1 parent bfc08ff commit 966a317
Show file tree
Hide file tree
Showing 8 changed files with 469 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ kt_jvm_library(
name = "grpc",
srcs = glob(["*.kt"]),
deps = [
"//imports/java/com/google/crypto/tink",
"//imports/java/com/google/gson",
"//imports/java/com/google/protobuf",
"//imports/java/io/grpc:api",
"//imports/java/io/grpc:context",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright 2024 The Cross-Media Measurement Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.wfanet.measurement.common.grpc

import io.grpc.CallCredentials
import io.grpc.Metadata
import io.grpc.SecurityLevel
import io.grpc.Status
import java.util.concurrent.Executor

/**
* [CallCredentials] for a bearer auth token.
*
* @param token the bearer token
* @param requirePrivacy whether to require that the transport's security level is
* [SecurityLevel.PRIVACY_AND_INTEGRITY], e.g. that the transport is encrypted via TLS
*/
class BearerTokenCallCredentials(val token: String, private val requirePrivacy: Boolean = true) :
CallCredentials() {
override fun applyRequestMetadata(
requestInfo: RequestInfo,
appExecutor: Executor,
applier: MetadataApplier,
) {
if (requirePrivacy && requestInfo.securityLevel != SecurityLevel.PRIVACY_AND_INTEGRITY) {
applier.fail(Status.UNAUTHENTICATED.withDescription("Credentials require private transport"))
}

val headers = Metadata()
headers.put(AUTHORIZATION_METADATA_KEY, "$AUTH_TYPE $token")

applier.apply(headers)
}

companion object {
private val AUTHORIZATION_METADATA_KEY: Metadata.Key<String> =
Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER)
private const val AUTH_TYPE = "Bearer"

fun fromHeaders(
headers: Metadata,
requirePrivacy: Boolean = true,
): BearerTokenCallCredentials? {
val authHeader = headers[AUTHORIZATION_METADATA_KEY] ?: return null
if (!authHeader.startsWith(AUTH_TYPE)) {
return null
}

val token = authHeader.substring(AUTH_TYPE.length).trim()
return BearerTokenCallCredentials(token, requirePrivacy)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright 2024 The Cross-Media Measurement Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.wfanet.measurement.common.grpc

import io.grpc.Grpc
import io.grpc.ServerCall
import io.grpc.Status
import java.security.cert.X509Certificate

object ClientCertificateAuthentication {
/**
* Extracts the TLS client certificate from [call].
*
* @throws io.grpc.StatusException on failure
*/
fun <ReqT, RespT> extractClientCertificate(call: ServerCall<ReqT, RespT>): X509Certificate {
val sslSession =
call.attributes[Grpc.TRANSPORT_ATTR_SSL_SESSION]
?: throw Status.UNAUTHENTICATED.withDescription("No SSL session").asException()
val clientCert =
sslSession.peerCertificates.firstOrNull()
?: throw Status.UNAUTHENTICATED.withDescription("No client certificate").asException()
return clientCert as X509Certificate
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Copyright 2024 The Cross-Media Measurement Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.wfanet.measurement.common.grpc

import com.google.crypto.tink.KeysetHandle
import com.google.crypto.tink.jwt.JwkSetConverter
import com.google.crypto.tink.jwt.JwtPublicKeyVerify
import com.google.crypto.tink.jwt.JwtSignatureConfig
import com.google.crypto.tink.jwt.JwtValidator
import com.google.gson.JsonObject
import com.google.gson.JsonParser
import io.grpc.Metadata
import io.grpc.Status
import java.io.IOException
import java.security.GeneralSecurityException
import java.time.Clock
import org.wfanet.measurement.common.base64UrlDecode

/** Utility for extracting OpenID Connect (OIDC) token information from gRPC request headers. */
class OpenIdConnectAuthentication(
audience: String,
openIdProviderConfigs: Iterable<OpenIdProviderConfig>,
clock: Clock = Clock.systemUTC(),
) {
private val jwtValidator =
JwtValidator.newBuilder().setClock(clock).expectAudience(audience).ignoreIssuer().build()

private val jwksHandleByIssuer: Map<String, KeysetHandle> =
openIdProviderConfigs.associateBy({ it.issuer }) {
JwkSetConverter.toPublicKeysetHandle(it.jwks)
}

/**
* Verifies and decodes an OIDC bearer token from [headers].
*
* The token must be a signed JWT.
*
* @throws io.grpc.StatusException on failure
*/
fun verifyAndDecodeBearerToken(headers: Metadata): VerifiedToken {
val credentials =
BearerTokenCallCredentials.fromHeaders(headers)
?: throw Status.UNAUTHENTICATED.withDescription("Bearer token not found in headers")
.asException()

val token: String = credentials.token
val tokenParts = token.split(".")
if (tokenParts.size != 3) {
throw Status.UNAUTHENTICATED.withDescription("Token is not a valid signed JWT").asException()
}
val payload: JsonObject =
try {
JsonParser.parseString(tokenParts[1].base64UrlDecode().toStringUtf8()).asJsonObject
} catch (e: IOException) {
throw Status.UNAUTHENTICATED.withCause(e)
.withDescription("Token is not a valid signed JWT")
.asException()
}

val issuer =
payload.get(ISSUER_CLAIM)?.asString
?: throw Status.UNAUTHENTICATED.withDescription("Issuer not found").asException()
val jwksHandle =
jwksHandleByIssuer[issuer]
?: throw Status.UNAUTHENTICATED.withDescription("Unknown issuer").asException()

val verifiedJwt =
try {
jwksHandle.getPrimitive(JwtPublicKeyVerify::class.java).verifyAndDecode(token, jwtValidator)
} catch (e: GeneralSecurityException) {
throw Status.UNAUTHENTICATED.withCause(e).withDescription(e.message).asException()
}

if (!verifiedJwt.hasSubject()) {
throw Status.UNAUTHENTICATED.withDescription("Subject not found").asException()
}
val scopes: Set<String> =
if (verifiedJwt.hasStringClaim(SCOPES_CLAIM)) {
verifiedJwt.getStringClaim(SCOPES_CLAIM).split(" ").toSet()
} else {
emptySet()
}

return VerifiedToken(issuer, verifiedJwt.subject, scopes)
}

companion object {
init {
JwtSignatureConfig.register()
}

private const val ISSUER_CLAIM = "iss"
private const val SCOPES_CLAIM = "scope"
}

data class VerifiedToken(val issuer: String, val subject: String, val scopes: Set<String>)

data class OpenIdProviderConfig(
val issuer: String,
/** JSON Web Key Set (JWKS) for the provider. */
val jwks: String,
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ kt_jvm_library(
name = "testing",
srcs = glob(["*.kt"]),
deps = [
"//imports/java/com/google/crypto/tink",
"//imports/java/io/grpc:api",
"//imports/java/io/grpc:core",
"//imports/java/io/grpc/inprocess",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright 2024 The Cross-Media Measurement Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.wfanet.measurement.common.grpc.testing

import com.google.crypto.tink.KeyTemplate
import com.google.crypto.tink.KeyTemplates
import com.google.crypto.tink.KeysetHandle
import com.google.crypto.tink.jwt.JwkSetConverter
import com.google.crypto.tink.jwt.JwtPublicKeySign
import com.google.crypto.tink.jwt.JwtSignatureConfig
import com.google.crypto.tink.jwt.RawJwt
import java.time.Duration
import java.time.Instant
import org.wfanet.measurement.common.grpc.BearerTokenCallCredentials
import org.wfanet.measurement.common.grpc.OpenIdConnectAuthentication

/** An ephemeral OpenID provider for testing. */
class OpenIdProvider(private val issuer: String) {
private val jwkSetHandle = KeysetHandle.generateNew(KEY_TEMPLATE)

val providerConfig: OpenIdConnectAuthentication.OpenIdProviderConfig by lazy {
val jwks = JwkSetConverter.fromPublicKeysetHandle(jwkSetHandle.publicKeysetHandle)
OpenIdConnectAuthentication.OpenIdProviderConfig(issuer, jwks)
}

fun generateCredentials(
audience: String,
subject: String,
scopes: Set<String>,
expiration: Instant = Instant.now().plus(Duration.ofMinutes(5)),
): BearerTokenCallCredentials {
val token = generateSignedToken(audience, subject, scopes, expiration)
return BearerTokenCallCredentials(token, false)
}

/** Generates a signed and encoded JWT using the specified parameters. */
private fun generateSignedToken(
audience: String,
subject: String,
scopes: Set<String>,
expiration: Instant,
): String {
val rawJwt =
RawJwt.newBuilder()
.setAudience(audience)
.setIssuer(issuer)
.setSubject(subject)
.addStringClaim("scope", scopes.joinToString(" "))
.setExpiration(expiration)
.build()
val signer = jwkSetHandle.getPrimitive(JwtPublicKeySign::class.java)
return signer.signAndEncode(rawJwt)
}

companion object {
init {
JwtSignatureConfig.register()
}

private val KEY_TEMPLATE: KeyTemplate = KeyTemplates.get("JWT_ES256")
}
}
15 changes: 15 additions & 0 deletions src/test/kotlin/org/wfanet/measurement/common/grpc/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,18 @@ kt_jvm_test(
"//src/main/kotlin/org/wfanet/measurement/common/grpc",
],
)

kt_jvm_test(
name = "OpenIdConnectAuthenticationTest",
srcs = ["OpenIdConnectAuthenticationTest.kt"],
test_class = "org.wfanet.measurement.common.grpc.OpenIdConnectAuthenticationTest",
deps = [
"//imports/java/com/google/common/truth",
"//imports/java/org/junit",
"//imports/kotlin/kotlin/test",
"//imports/kotlin/org/mockito/kotlin",
"//src/main/kotlin/org/wfanet/measurement/common/grpc",
"//src/main/kotlin/org/wfanet/measurement/common/grpc/testing",
"//src/main/kotlin/org/wfanet/measurement/common/testing",
],
)
Loading

0 comments on commit 966a317

Please sign in to comment.