Skip to content

Commit

Permalink
Added GrpcServerSpy, FakeIdp, fake cert and key to enable TLS and OID…
Browse files Browse the repository at this point in the history
…C tests.
  • Loading branch information
merlante committed Jul 30, 2024
1 parent bdbaad8 commit 11ad945
Show file tree
Hide file tree
Showing 6 changed files with 494 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -1,27 +1,36 @@
package org.project_kessel.relations.client;

import io.grpc.Metadata;
import org.junit.jupiter.api.AfterAll;
import org.project_kessel.api.relations.v1beta1.CheckRequest;
import org.project_kessel.api.relations.v1beta1.KesselCheckServiceGrpc;
import org.project_kessel.api.relations.v1beta1.KesselLookupServiceGrpc;
import org.project_kessel.api.relations.v1beta1.KesselTupleServiceGrpc;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.project_kessel.relations.client.fake.GrpcServerSpy;

import java.util.HashMap;
import java.util.Hashtable;
import java.util.Optional;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import static io.smallrye.common.constraint.Assert.assertNotNull;
import static org.junit.jupiter.api.Assertions.*;
import static org.project_kessel.relations.client.util.CertUtil.*;

class RelationsGrpcClientsManagerTest {
private static final Metadata.Key<String> authorizationKey = Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER);

@BeforeAll
static void testSetup() {
/* Make sure all client managers shutdown/removed before tests */
RelationsGrpcClientsManager.shutdownAll();
/* Add self-signed cert to keystore, trust manager and SSL context for TLS testing. */
addTestCACertToTrustStore();
}

@AfterEach
Expand All @@ -30,6 +39,12 @@ void testTeardown() {
RelationsGrpcClientsManager.shutdownAll();
}

@AfterAll
static void removeTestSetup() {
/* Remove self-signed cert */
removeTestCACertFromKeystore();
}

@Test
void testManagerReusePatterns() {
var one = RelationsGrpcClientsManager.forInsecureClients("localhost:8080");
Expand Down Expand Up @@ -115,6 +130,37 @@ void testThreadingChaos() {
}
}

@Test
void testManagersHoldIntendedCredentialsInChannel() throws Exception {
Config.AuthenticationConfig authnConfig = dummyNonDisabledAuthenticationConfig();
var manager = RelationsGrpcClientsManager.forInsecureClients("localhost:7000");
var manager2 = RelationsGrpcClientsManager.forInsecureClients("localhost:7001", authnConfig);
var manager3 = RelationsGrpcClientsManager.forSecureClients("localhost:7002");
var manager4 = RelationsGrpcClientsManager.forSecureClients("localhost:7003", authnConfig);

var checkClient = manager.getCheckClient();
var checkClient2 = manager2.getCheckClient();
var checkClient3 = manager3.getCheckClient();
var checkClient4 = manager4.getCheckClient();

var cd1 = GrpcServerSpy.runAgainstTemporaryServerWithDummyServices(7000, () -> checkClient.check(CheckRequest.getDefaultInstance()));
var cd2 = GrpcServerSpy.runAgainstTemporaryServerWithDummyServices(7001, () -> checkClient2.check(CheckRequest.getDefaultInstance()));
var cd3 = GrpcServerSpy.runAgainstTemporaryTlsServerWithDummyServices(7002, () -> checkClient3.check(CheckRequest.getDefaultInstance()));
var cd4 = GrpcServerSpy.runAgainstTemporaryTlsServerWithDummyServices(7003, () -> checkClient4.check(CheckRequest.getDefaultInstance()));

assertNull(cd1.getMetadata().get(authorizationKey));
assertEquals("NONE", cd1.getCall().getSecurityLevel().toString());

assertNotNull(cd2.getMetadata().get(authorizationKey));
assertEquals("NONE", cd2.getCall().getSecurityLevel().toString());

assertNull(cd3.getMetadata().get(authorizationKey));
assertEquals("PRIVACY_AND_INTEGRITY", cd3.getCall().getSecurityLevel().toString());

assertNotNull(cd4.getMetadata().get(authorizationKey));
assertEquals("PRIVACY_AND_INTEGRITY", cd4.getCall().getSecurityLevel().toString());
}

@Test
void testManagerReuseInternal() throws Exception {
RelationsGrpcClientsManager.forInsecureClients("localhost:8080");
Expand Down Expand Up @@ -190,4 +236,43 @@ void testCreateAndShutdownPatternsInternal() throws Exception {
insecureManagersSize = ((HashMap<?,?>)insecureField.get(null)).size();
assertEquals(0, insecureManagersSize);
}

Config.AuthenticationConfig dummyNonDisabledAuthenticationConfig() {
return new Config.AuthenticationConfig() {
@Override
public Config.AuthMode mode() {
return Config.AuthMode.OIDC_CLIENT_CREDENTIALS; // any non-disabled value
}

@Override
public Optional<Config.OIDCClientCredentialsConfig> clientCredentialsConfig() {
return Optional.of(new Config.OIDCClientCredentialsConfig() {
@Override
public String issuer() {
return "http://localhost:8090";
}

@Override
public String clientId() {
return "test";
}

@Override
public String clientSecret() {
return "test";
}

@Override
public Optional<String[]> scope() {
return Optional.empty();
}

@Override
public Optional<String> oidcClientCredentialsMinterImplementation() {
return Optional.empty();
}
});
}
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package org.project_kessel.relations.client.fake;

import com.sun.net.httpserver.HttpExchange;
import com.sun.net.httpserver.HttpHandler;
import com.sun.net.httpserver.HttpServer;

import java.io.IOException;
import java.io.OutputStream;
import java.net.InetSocketAddress;

/**
* Super-fake Idp that supports a hard-coded well-known discovery endpoint and a corresponding fake token endpoint.
* Does not use TLS.
*/
public class FakeIdp {
private final int port;
HttpServer server = null;

public FakeIdp(int port) {
this.port = port;
}

public void start() {
try {
server = HttpServer.create(new InetSocketAddress(port), 0);
} catch (IOException e) {
throw new RuntimeException(e);
}
server.createContext("/.well-known/openid-configuration", new WellKnownHandler());
server.createContext("/token", new TokenHandler());
server.setExecutor(null); // creates a default executor
server.start();
}

public void stop() {
server.stop(0);
}

static class TokenHandler implements HttpHandler {
@Override
public void handle(HttpExchange t) throws IOException {
String response = "{\n" +
" \"iss\": \"http://localhost:8090/\",\n" +
" \"aud\": \"us\",\n" +
" \"sub\": \"usr_123\",\n" +
" \"scope\": \"read write\",\n" +
" \"iat\": 1458785796,\n" +
" \"exp\": 1458872196,\n" +
" \"token_type\": \"Bearer\",\n" +
" \"access_token\": \"blah\"\n" +
"}";
t.getResponseHeaders().set("Content-Type", "application/json; charset=UTF-8");
t.sendResponseHeaders(200, response.length());
OutputStream os = t.getResponseBody();
os.write(response.getBytes());
os.close();
}
}

static class WellKnownHandler implements HttpHandler {
@Override
public void handle(HttpExchange t) throws IOException {
String response = "{\n" +
"\t\"issuer\":\"http://localhost:8090\",\n" +
"\t\"authorization_endpoint\":\"http://localhost:8090/protocol/openid-connect/auth\",\n" +
"\t\"token_endpoint\":\"http://localhost:8090/token\",\n" +
"\t\"introspection_endpoint\":\"http://localhost:8090/token/introspect\",\n" +
"\t\"jwks_uri\":\"http://localhost:8090/certs\",\n" +
"\t\"response_types_supported\":[\"code\",\"none\",\"id_token\",\"token\",\"id_token token\",\"code id_token\",\"code token\",\"code id_token token\"],\n" +
"\t\"token_endpoint_auth_methods_supported\":[\"private_key_jwt\",\"client_secret_basic\",\"client_secret_post\",\"tls_client_auth\",\"client_secret_jwt\"],\n" +
"\t\"subject_types_supported\":[\"public\",\"pairwise\"]\n" +
"}";
t.getResponseHeaders().set("Content-Type", "application/json; charset=UTF-8");
t.sendResponseHeaders(200, response.length());
OutputStream os = t.getResponseBody();
os.write(response.getBytes());
os.close();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
package org.project_kessel.relations.client.fake;

import io.grpc.*;
import io.grpc.stub.StreamObserver;
import org.project_kessel.api.relations.v1beta1.*;

import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.util.Objects;
import java.util.concurrent.TimeUnit;

public class GrpcServerSpy extends Server {
private final Server server;

public GrpcServerSpy(int port, boolean tlsEnabled, ServerInterceptor interceptor, BindableService... services) {
ServerBuilder<?> serverBuilder = ServerBuilder.forPort(port);
if (tlsEnabled) {
URL certsUrl = Thread.currentThread().getContextClassLoader().getResource("certs/test.crt");
URL keyUrl = Thread.currentThread().getContextClassLoader().getResource("certs/test.key");
File certFile = new File(Objects.requireNonNull(certsUrl).getPath());
File keyFile = new File(Objects.requireNonNull(keyUrl).getPath());
serverBuilder.useTransportSecurity(certFile, keyFile);
}
if (interceptor != null) {
serverBuilder.intercept(interceptor);
}
for (BindableService service : services) {
serverBuilder.addService(service);
}
server = serverBuilder.build();
}

public static ServerCallDetails runAgainstTemporaryServerWithDummyServices(int port, Call grpcCallFunction) {
return runAgainstTemporaryServerWithDummyServicesTlsSelect(port, false, grpcCallFunction);
}

public static ServerCallDetails runAgainstTemporaryTlsServerWithDummyServices(int port, Call grpcCallFunction) {
return runAgainstTemporaryServerWithDummyServicesTlsSelect(port, true, grpcCallFunction);
}

private static ServerCallDetails runAgainstTemporaryServerWithDummyServicesTlsSelect(int port, boolean tlsEnabled, Call grpcCallFunction) {
var dummyCheckService = new KesselCheckServiceGrpc.KesselCheckServiceImplBase() {
@Override
public void check(CheckRequest request, StreamObserver<CheckResponse> responseObserver) {
responseObserver.onNext(CheckResponse.getDefaultInstance());
responseObserver.onCompleted();
}
};
var dummyTupleService = new KesselTupleServiceGrpc.KesselTupleServiceImplBase() {
@Override
public void readTuples(ReadTuplesRequest request, StreamObserver<ReadTuplesResponse> responseObserver) {
responseObserver.onNext(ReadTuplesResponse.getDefaultInstance());
responseObserver.onCompleted();
}
};
var dummyLookupService = new KesselLookupServiceGrpc.KesselLookupServiceImplBase() {
@Override
public void lookupSubjects(LookupSubjectsRequest request, StreamObserver<LookupSubjectsResponse> responseObserver) {
responseObserver.onNext(LookupSubjectsResponse.getDefaultInstance());
responseObserver.onCompleted();
}
};

return runAgainstTemporaryServerTlsSelect(port, tlsEnabled, grpcCallFunction, dummyCheckService, dummyTupleService, dummyLookupService);
}

public static ServerCallDetails runAgainstTemporaryServer(int port, Call grpcCallFunction, BindableService... services) {
return runAgainstTemporaryServerTlsSelect(port, false, grpcCallFunction, services);
}

public static ServerCallDetails runAgainstTemporaryTlsServer(int port, Call grpcCallFunction, BindableService... services) {
return runAgainstTemporaryServerTlsSelect(port, true, grpcCallFunction, services);
}

private static ServerCallDetails runAgainstTemporaryServerTlsSelect(int port, boolean tlsEnabled, Call grpcCallFunction, BindableService... services) {
final ServerCallDetails serverCallDetails = new ServerCallDetails();

var spyInterceptor = new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
serverCallDetails.setCall(call);
serverCallDetails.setMetadata(headers);
return next.startCall(call, headers);
}
};

FakeIdp fakeIdp = new FakeIdp(8090);
var serverSpy = new GrpcServerSpy(port, tlsEnabled, spyInterceptor, services);

try {
fakeIdp.start();
serverSpy.start();
grpcCallFunction.call();
serverSpy.shutdown();
fakeIdp.stop();

return serverCallDetails;
} catch (Exception e) {
throw new RuntimeException(e);
} finally {
serverSpy.shutdown();
fakeIdp.stop();
}
}

@Override
public Server start() throws IOException {
server.start();
return this;
}

@Override
public Server shutdown() {
server.shutdown();
return this;
}

@Override
public Server shutdownNow() {
server.shutdownNow();
return this;
}

@Override
public boolean isShutdown() {
return server.isShutdown();
}

@Override
public boolean isTerminated() {
return server.isTerminated();
}

@Override
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
return server.awaitTermination(timeout, unit);
}

@Override
public void awaitTermination() throws InterruptedException {
server.awaitTermination();
}

public interface Call {
void call();
}

public static class ServerCallDetails {
private ServerCall<?,?> call;
private Metadata metadata;

public ServerCallDetails() {
}

public ServerCall<?,?> getCall() {
return call;
}

public Metadata getMetadata() {
return metadata;
}

public void setCall(ServerCall<?,?> call) {
this.call = call;
}

public void setMetadata(Metadata metadata) {
this.metadata = metadata;
}
}
}
Loading

0 comments on commit 11ad945

Please sign in to comment.