From 966a317afcfd2fad5084d5b69121ead5f5782f1b Mon Sep 17 00:00:00 2001 From: Sanjay Vasandani Date: Mon, 16 Dec 2024 12:48:10 -0800 Subject: [PATCH] feat: Add utilities for extracting authentication credentials from gRPC calls (#288) --- .../measurement/common/grpc/BUILD.bazel | 2 + .../common/grpc/BearerTokenCallCredentials.kt | 67 ++++++++ .../grpc/ClientCertificateAuthentication.kt | 39 +++++ .../grpc/OpenIdConnectAuthentication.kt | 117 ++++++++++++++ .../common/grpc/testing/BUILD.bazel | 1 + .../common/grpc/testing/OpenIdProvider.kt | 76 +++++++++ .../measurement/common/grpc/BUILD.bazel | 15 ++ .../grpc/OpenIdConnectAuthenticationTest.kt | 152 ++++++++++++++++++ 8 files changed, 469 insertions(+) create mode 100644 src/main/kotlin/org/wfanet/measurement/common/grpc/BearerTokenCallCredentials.kt create mode 100644 src/main/kotlin/org/wfanet/measurement/common/grpc/ClientCertificateAuthentication.kt create mode 100644 src/main/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthentication.kt create mode 100644 src/main/kotlin/org/wfanet/measurement/common/grpc/testing/OpenIdProvider.kt create mode 100644 src/test/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthenticationTest.kt diff --git a/src/main/kotlin/org/wfanet/measurement/common/grpc/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/common/grpc/BUILD.bazel index 1909e010b..48e752098 100644 --- a/src/main/kotlin/org/wfanet/measurement/common/grpc/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/common/grpc/BUILD.bazel @@ -6,6 +6,8 @@ kt_jvm_library( name = "grpc", srcs = glob(["*.kt"]), deps = [ + "//imports/java/com/google/crypto/tink", + "//imports/java/com/google/gson", "//imports/java/com/google/protobuf", "//imports/java/io/grpc:api", "//imports/java/io/grpc:context", diff --git a/src/main/kotlin/org/wfanet/measurement/common/grpc/BearerTokenCallCredentials.kt b/src/main/kotlin/org/wfanet/measurement/common/grpc/BearerTokenCallCredentials.kt new file mode 100644 index 000000000..cc8731337 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/common/grpc/BearerTokenCallCredentials.kt @@ -0,0 +1,67 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.common.grpc + +import io.grpc.CallCredentials +import io.grpc.Metadata +import io.grpc.SecurityLevel +import io.grpc.Status +import java.util.concurrent.Executor + +/** + * [CallCredentials] for a bearer auth token. + * + * @param token the bearer token + * @param requirePrivacy whether to require that the transport's security level is + * [SecurityLevel.PRIVACY_AND_INTEGRITY], e.g. that the transport is encrypted via TLS + */ +class BearerTokenCallCredentials(val token: String, private val requirePrivacy: Boolean = true) : + CallCredentials() { + override fun applyRequestMetadata( + requestInfo: RequestInfo, + appExecutor: Executor, + applier: MetadataApplier, + ) { + if (requirePrivacy && requestInfo.securityLevel != SecurityLevel.PRIVACY_AND_INTEGRITY) { + applier.fail(Status.UNAUTHENTICATED.withDescription("Credentials require private transport")) + } + + val headers = Metadata() + headers.put(AUTHORIZATION_METADATA_KEY, "$AUTH_TYPE $token") + + applier.apply(headers) + } + + companion object { + private val AUTHORIZATION_METADATA_KEY: Metadata.Key = + Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER) + private const val AUTH_TYPE = "Bearer" + + fun fromHeaders( + headers: Metadata, + requirePrivacy: Boolean = true, + ): BearerTokenCallCredentials? { + val authHeader = headers[AUTHORIZATION_METADATA_KEY] ?: return null + if (!authHeader.startsWith(AUTH_TYPE)) { + return null + } + + val token = authHeader.substring(AUTH_TYPE.length).trim() + return BearerTokenCallCredentials(token, requirePrivacy) + } + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/common/grpc/ClientCertificateAuthentication.kt b/src/main/kotlin/org/wfanet/measurement/common/grpc/ClientCertificateAuthentication.kt new file mode 100644 index 000000000..7cf0923f0 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/common/grpc/ClientCertificateAuthentication.kt @@ -0,0 +1,39 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.common.grpc + +import io.grpc.Grpc +import io.grpc.ServerCall +import io.grpc.Status +import java.security.cert.X509Certificate + +object ClientCertificateAuthentication { + /** + * Extracts the TLS client certificate from [call]. + * + * @throws io.grpc.StatusException on failure + */ + fun extractClientCertificate(call: ServerCall): X509Certificate { + val sslSession = + call.attributes[Grpc.TRANSPORT_ATTR_SSL_SESSION] + ?: throw Status.UNAUTHENTICATED.withDescription("No SSL session").asException() + val clientCert = + sslSession.peerCertificates.firstOrNull() + ?: throw Status.UNAUTHENTICATED.withDescription("No client certificate").asException() + return clientCert as X509Certificate + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthentication.kt b/src/main/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthentication.kt new file mode 100644 index 000000000..c92803527 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthentication.kt @@ -0,0 +1,117 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.common.grpc + +import com.google.crypto.tink.KeysetHandle +import com.google.crypto.tink.jwt.JwkSetConverter +import com.google.crypto.tink.jwt.JwtPublicKeyVerify +import com.google.crypto.tink.jwt.JwtSignatureConfig +import com.google.crypto.tink.jwt.JwtValidator +import com.google.gson.JsonObject +import com.google.gson.JsonParser +import io.grpc.Metadata +import io.grpc.Status +import java.io.IOException +import java.security.GeneralSecurityException +import java.time.Clock +import org.wfanet.measurement.common.base64UrlDecode + +/** Utility for extracting OpenID Connect (OIDC) token information from gRPC request headers. */ +class OpenIdConnectAuthentication( + audience: String, + openIdProviderConfigs: Iterable, + clock: Clock = Clock.systemUTC(), +) { + private val jwtValidator = + JwtValidator.newBuilder().setClock(clock).expectAudience(audience).ignoreIssuer().build() + + private val jwksHandleByIssuer: Map = + openIdProviderConfigs.associateBy({ it.issuer }) { + JwkSetConverter.toPublicKeysetHandle(it.jwks) + } + + /** + * Verifies and decodes an OIDC bearer token from [headers]. + * + * The token must be a signed JWT. + * + * @throws io.grpc.StatusException on failure + */ + fun verifyAndDecodeBearerToken(headers: Metadata): VerifiedToken { + val credentials = + BearerTokenCallCredentials.fromHeaders(headers) + ?: throw Status.UNAUTHENTICATED.withDescription("Bearer token not found in headers") + .asException() + + val token: String = credentials.token + val tokenParts = token.split(".") + if (tokenParts.size != 3) { + throw Status.UNAUTHENTICATED.withDescription("Token is not a valid signed JWT").asException() + } + val payload: JsonObject = + try { + JsonParser.parseString(tokenParts[1].base64UrlDecode().toStringUtf8()).asJsonObject + } catch (e: IOException) { + throw Status.UNAUTHENTICATED.withCause(e) + .withDescription("Token is not a valid signed JWT") + .asException() + } + + val issuer = + payload.get(ISSUER_CLAIM)?.asString + ?: throw Status.UNAUTHENTICATED.withDescription("Issuer not found").asException() + val jwksHandle = + jwksHandleByIssuer[issuer] + ?: throw Status.UNAUTHENTICATED.withDescription("Unknown issuer").asException() + + val verifiedJwt = + try { + jwksHandle.getPrimitive(JwtPublicKeyVerify::class.java).verifyAndDecode(token, jwtValidator) + } catch (e: GeneralSecurityException) { + throw Status.UNAUTHENTICATED.withCause(e).withDescription(e.message).asException() + } + + if (!verifiedJwt.hasSubject()) { + throw Status.UNAUTHENTICATED.withDescription("Subject not found").asException() + } + val scopes: Set = + if (verifiedJwt.hasStringClaim(SCOPES_CLAIM)) { + verifiedJwt.getStringClaim(SCOPES_CLAIM).split(" ").toSet() + } else { + emptySet() + } + + return VerifiedToken(issuer, verifiedJwt.subject, scopes) + } + + companion object { + init { + JwtSignatureConfig.register() + } + + private const val ISSUER_CLAIM = "iss" + private const val SCOPES_CLAIM = "scope" + } + + data class VerifiedToken(val issuer: String, val subject: String, val scopes: Set) + + data class OpenIdProviderConfig( + val issuer: String, + /** JSON Web Key Set (JWKS) for the provider. */ + val jwks: String, + ) +} diff --git a/src/main/kotlin/org/wfanet/measurement/common/grpc/testing/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/common/grpc/testing/BUILD.bazel index 4795a2e1b..bf848e98f 100644 --- a/src/main/kotlin/org/wfanet/measurement/common/grpc/testing/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/common/grpc/testing/BUILD.bazel @@ -9,6 +9,7 @@ kt_jvm_library( name = "testing", srcs = glob(["*.kt"]), deps = [ + "//imports/java/com/google/crypto/tink", "//imports/java/io/grpc:api", "//imports/java/io/grpc:core", "//imports/java/io/grpc/inprocess", diff --git a/src/main/kotlin/org/wfanet/measurement/common/grpc/testing/OpenIdProvider.kt b/src/main/kotlin/org/wfanet/measurement/common/grpc/testing/OpenIdProvider.kt new file mode 100644 index 000000000..7ffdc1563 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/common/grpc/testing/OpenIdProvider.kt @@ -0,0 +1,76 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.common.grpc.testing + +import com.google.crypto.tink.KeyTemplate +import com.google.crypto.tink.KeyTemplates +import com.google.crypto.tink.KeysetHandle +import com.google.crypto.tink.jwt.JwkSetConverter +import com.google.crypto.tink.jwt.JwtPublicKeySign +import com.google.crypto.tink.jwt.JwtSignatureConfig +import com.google.crypto.tink.jwt.RawJwt +import java.time.Duration +import java.time.Instant +import org.wfanet.measurement.common.grpc.BearerTokenCallCredentials +import org.wfanet.measurement.common.grpc.OpenIdConnectAuthentication + +/** An ephemeral OpenID provider for testing. */ +class OpenIdProvider(private val issuer: String) { + private val jwkSetHandle = KeysetHandle.generateNew(KEY_TEMPLATE) + + val providerConfig: OpenIdConnectAuthentication.OpenIdProviderConfig by lazy { + val jwks = JwkSetConverter.fromPublicKeysetHandle(jwkSetHandle.publicKeysetHandle) + OpenIdConnectAuthentication.OpenIdProviderConfig(issuer, jwks) + } + + fun generateCredentials( + audience: String, + subject: String, + scopes: Set, + expiration: Instant = Instant.now().plus(Duration.ofMinutes(5)), + ): BearerTokenCallCredentials { + val token = generateSignedToken(audience, subject, scopes, expiration) + return BearerTokenCallCredentials(token, false) + } + + /** Generates a signed and encoded JWT using the specified parameters. */ + private fun generateSignedToken( + audience: String, + subject: String, + scopes: Set, + expiration: Instant, + ): String { + val rawJwt = + RawJwt.newBuilder() + .setAudience(audience) + .setIssuer(issuer) + .setSubject(subject) + .addStringClaim("scope", scopes.joinToString(" ")) + .setExpiration(expiration) + .build() + val signer = jwkSetHandle.getPrimitive(JwtPublicKeySign::class.java) + return signer.signAndEncode(rawJwt) + } + + companion object { + init { + JwtSignatureConfig.register() + } + + private val KEY_TEMPLATE: KeyTemplate = KeyTemplates.get("JWT_ES256") + } +} diff --git a/src/test/kotlin/org/wfanet/measurement/common/grpc/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/common/grpc/BUILD.bazel index 7ec81d3ad..4b59c0358 100644 --- a/src/test/kotlin/org/wfanet/measurement/common/grpc/BUILD.bazel +++ b/src/test/kotlin/org/wfanet/measurement/common/grpc/BUILD.bazel @@ -52,3 +52,18 @@ kt_jvm_test( "//src/main/kotlin/org/wfanet/measurement/common/grpc", ], ) + +kt_jvm_test( + name = "OpenIdConnectAuthenticationTest", + srcs = ["OpenIdConnectAuthenticationTest.kt"], + test_class = "org.wfanet.measurement.common.grpc.OpenIdConnectAuthenticationTest", + deps = [ + "//imports/java/com/google/common/truth", + "//imports/java/org/junit", + "//imports/kotlin/kotlin/test", + "//imports/kotlin/org/mockito/kotlin", + "//src/main/kotlin/org/wfanet/measurement/common/grpc", + "//src/main/kotlin/org/wfanet/measurement/common/grpc/testing", + "//src/main/kotlin/org/wfanet/measurement/common/testing", + ], +) diff --git a/src/test/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthenticationTest.kt b/src/test/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthenticationTest.kt new file mode 100644 index 000000000..168290886 --- /dev/null +++ b/src/test/kotlin/org/wfanet/measurement/common/grpc/OpenIdConnectAuthenticationTest.kt @@ -0,0 +1,152 @@ +/* + * Copyright 2024 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.common.grpc + +import com.google.common.truth.Truth.assertThat +import io.grpc.CallCredentials +import io.grpc.Metadata +import io.grpc.Status +import io.grpc.StatusException +import java.time.Duration +import java.time.Instant +import java.util.concurrent.Executor +import kotlin.test.assertFailsWith +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.mockito.Mockito.mock +import org.wfanet.measurement.common.grpc.testing.OpenIdProvider +import org.wfanet.measurement.common.testing.verifyAndCapture + +@RunWith(JUnit4::class) +class OpenIdConnectAuthenticationTest { + @Test + fun `verifyAndDecodeBearerToken returns VerifiedToken`() { + val issuer = "example.com" + val subject = "user1@example.com" + val audience = "foobar" + val scopes = setOf("foo.bar", "foo.baz") + val openIdProvider = OpenIdProvider(issuer) + val credentials = openIdProvider.generateCredentials(audience, subject, scopes) + val auth = OpenIdConnectAuthentication(audience, listOf(openIdProvider.providerConfig)) + + val token = auth.verifyAndDecodeBearerToken(extractHeaders(credentials)) + + assertThat(token).isEqualTo(OpenIdConnectAuthentication.VerifiedToken(issuer, subject, scopes)) + } + + @Test + fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when token is expired`() { + val issuer = "example.com" + val subject = "user1@example.com" + val audience = "foobar" + val scopes = setOf("foo.bar", "foo.baz") + val openIdProvider = OpenIdProvider(issuer) + val credentials = + openIdProvider.generateCredentials( + audience, + subject, + scopes, + Instant.now().minus(Duration.ofMinutes(5)), + ) + val auth = OpenIdConnectAuthentication(audience, listOf(openIdProvider.providerConfig)) + + val exception = + assertFailsWith { + auth.verifyAndDecodeBearerToken(extractHeaders(credentials)) + } + + assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) + assertThat(exception).hasMessageThat().ignoringCase().contains("expired") + } + + @Test + fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when audience does not match`() { + val issuer = "example.com" + val subject = "user1@example.com" + val audience = "foobar" + val scopes = setOf("foo.bar", "foo.baz") + val openIdProvider = OpenIdProvider(issuer) + val credentials = openIdProvider.generateCredentials("bad-audience", subject, scopes) + val auth = OpenIdConnectAuthentication(audience, listOf(openIdProvider.providerConfig)) + + val exception = + assertFailsWith { + auth.verifyAndDecodeBearerToken(extractHeaders(credentials)) + } + + assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) + assertThat(exception).hasMessageThat().ignoringCase().contains("audience") + } + + @Test + fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when provider not found for issuer`() { + val issuer = "example.com" + val subject = "user1@example.com" + val audience = "foobar" + val scopes = setOf("foo.bar", "foo.baz") + val openIdProvider = OpenIdProvider(issuer) + val credentials = openIdProvider.generateCredentials(audience, subject, scopes) + val auth = OpenIdConnectAuthentication(audience, emptyList()) + + val exception = + assertFailsWith { + auth.verifyAndDecodeBearerToken(extractHeaders(credentials)) + } + + assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) + assertThat(exception).hasMessageThat().ignoringCase().contains("issuer") + } + + @Test + fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when token is not a valid JWT`() { + val audience = "foobar" + val credentials = BearerTokenCallCredentials("foo", false) + val auth = OpenIdConnectAuthentication(audience, emptyList()) + + val exception = + assertFailsWith { + auth.verifyAndDecodeBearerToken(extractHeaders(credentials)) + } + + assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) + assertThat(exception).hasMessageThat().contains("JWT") + } + + @Test + fun `verifyAndDecodeBearerToken throws UNAUTHENTICATED when header not found`() { + val audience = "foobar" + val auth = OpenIdConnectAuthentication(audience, emptyList()) + + val exception = assertFailsWith { auth.verifyAndDecodeBearerToken(Metadata()) } + + assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) + assertThat(exception).hasMessageThat().contains("header") + } + + private fun extractHeaders(credentials: BearerTokenCallCredentials): Metadata { + val applierMock = mock() + credentials.applyRequestMetadata(mock(), DirectExecutor, applierMock) + return verifyAndCapture(applierMock, CallCredentials.MetadataApplier::apply) + } + + private object DirectExecutor : Executor { + override fun execute(command: Runnable) { + command.run() + } + } +}