From 94a7e87d9d5ca3e2fc73e38f1fe7fc8e9fdd4f99 Mon Sep 17 00:00:00 2001 From: Namjae Kim Date: Fri, 10 Jan 2025 11:49:00 +0900 Subject: [PATCH] INTERNAL: Add ScramSaslClient --- pom.xml | 5 + .../spy/memcached/auth/ScramMechanism.java | 49 +++++ .../spy/memcached/auth/ScramSaslClient.java | 176 ++++++++++++++++++ .../auth/ScramSaslClientProvider.java | 21 +++ 4 files changed, 251 insertions(+) create mode 100644 src/main/java/net/spy/memcached/auth/ScramMechanism.java create mode 100644 src/main/java/net/spy/memcached/auth/ScramSaslClient.java create mode 100644 src/main/java/net/spy/memcached/auth/ScramSaslClientProvider.java diff --git a/pom.xml b/pom.xml index 22122c096..8dc15b2a6 100644 --- a/pom.xml +++ b/pom.xml @@ -118,6 +118,11 @@ + + com.bolyartech.scram_sasl + scram_sasl + 2.0.2 + diff --git a/src/main/java/net/spy/memcached/auth/ScramMechanism.java b/src/main/java/net/spy/memcached/auth/ScramMechanism.java new file mode 100644 index 000000000..f26fdb95f --- /dev/null +++ b/src/main/java/net/spy/memcached/auth/ScramMechanism.java @@ -0,0 +1,49 @@ +package net.spy.memcached.auth; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +public enum ScramMechanism { + SCRAM_SHA_256("SHA-256", "HmacSHA256"); + + private static final Map MECHANISMS_MAP; + + private final String mechanismName; + private final String hashAlgorithm; + private final String macAlgorithm; + + static { + Map map = new HashMap<>(); + for (ScramMechanism mech : values()) { + map.put(mech.mechanismName, mech); + } + MECHANISMS_MAP = Collections.unmodifiableMap(map); + } + + private ScramMechanism(String hashAlgorithm, String macAlgorithm) { + this.mechanismName = "SCRAM-" + hashAlgorithm; + this.hashAlgorithm = hashAlgorithm; + this.macAlgorithm = macAlgorithm; + } + + public final String mechanismName() { + return this.mechanismName; + } + public String hashAlgorithm() { + return hashAlgorithm; + } + + public String macAlgorithm() { + return macAlgorithm; + } + + public static ScramMechanism forMechanismName(String mechanismName) { + return MECHANISMS_MAP.get(mechanismName); + } + + public static Collection mechanismNames() { + return MECHANISMS_MAP.keySet(); + } +} diff --git a/src/main/java/net/spy/memcached/auth/ScramSaslClient.java b/src/main/java/net/spy/memcached/auth/ScramSaslClient.java new file mode 100644 index 000000000..bee67f7f8 --- /dev/null +++ b/src/main/java/net/spy/memcached/auth/ScramSaslClient.java @@ -0,0 +1,176 @@ +package net.spy.memcached.auth; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collection; +import java.util.Map; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslClientFactory; +import javax.security.sasl.SaslException; + +import com.bolyartech.scram_sasl.client.ScramClientFunctionality; +import com.bolyartech.scram_sasl.client.ScramClientFunctionalityImpl; +import com.bolyartech.scram_sasl.common.ScramException; + +public class ScramSaslClient implements SaslClient { + + enum State { + SEND_CLIENT_FIRST_MESSAGE, + RECEIVE_SERVER_FIRST_MESSAGE, + RECEIVE_SERVER_FINAL_MESSAGE, + COMPLETE, + FAILED + } + + private final ScramMechanism mechanism; + private final CallbackHandler callbackHandler; + private final ScramClientFunctionality scf; + private State state; + + public ScramSaslClient(ScramMechanism mechanism, CallbackHandler cbh) { + this.callbackHandler = cbh; + this.mechanism = mechanism; + this.scf = new ScramClientFunctionalityImpl( + mechanism.hashAlgorithm(), mechanism.macAlgorithm()); + this.state = State.SEND_CLIENT_FIRST_MESSAGE; + } + + @Override + public String getMechanismName() { + return this.mechanism.mechanismName(); + } + + @Override + public boolean hasInitialResponse() { + return true; + } + + @Override + public byte[] evaluateChallenge(byte[] challenge) throws SaslException { + try { + switch (this.state) { + case SEND_CLIENT_FIRST_MESSAGE: + if (challenge != null && challenge.length != 0) { + throw new SaslException("Expected empty challenge"); + } + + NameCallback nameCallback = new NameCallback("Name: "); + + try { + callbackHandler.handle(new Callback[]{nameCallback}); + } catch (Throwable e) { + throw new SaslException("User name could not be obtained", e); + } + + String username = nameCallback.getName(); + byte[] clientFirstMessage = this.scf.prepareFirstMessage(username).getBytes(); + this.state = State.RECEIVE_SERVER_FIRST_MESSAGE; + return clientFirstMessage; + + case RECEIVE_SERVER_FIRST_MESSAGE: + String serverFirstMessage = new String(challenge, StandardCharsets.UTF_8); + + PasswordCallback passwordCallback = new PasswordCallback("Password: ", false); + try { + callbackHandler.handle(new Callback[]{passwordCallback}); + } catch (Throwable e) { + throw new SaslException("Password could not be obtained", e); + } + + String password = String.valueOf(passwordCallback.getPassword()); + byte[] clientFinalMessage = this.scf.prepareFinalMessage( + password, serverFirstMessage).getBytes(); + this.state = State.RECEIVE_SERVER_FINAL_MESSAGE; + return clientFinalMessage; + + case RECEIVE_SERVER_FINAL_MESSAGE: + String serverFinalMessage = new String(challenge, StandardCharsets.UTF_8); + if (!this.scf.checkServerFinalMessage(serverFinalMessage)) { + throw new SaslException("Sasl authentication using " + this.mechanism + + " failed with error: invalid server final message"); + } + this.state = State.COMPLETE; + return new byte[]{}; + + default: + throw new SaslException("Unexpected challenge in Sasl client state " + this.state); + } + } catch (ScramException e) { + this.state = State.FAILED; + throw new SaslException("ScramException", e); + } catch (SaslException e) { + this.state = State.FAILED; + throw e; + } + } + + @Override + public boolean isComplete() { + return this.state == State.COMPLETE; + } + + @Override + public byte[] unwrap(byte[] incoming, int offset, int len) throws SaslException { + if (!isComplete()) { + throw new IllegalStateException("Authentication exchange has not completed"); + } + return Arrays.copyOfRange(incoming, offset, offset + len); + } + + @Override + public byte[] wrap(byte[] outgoing, int offset, int len) throws SaslException { + if (!isComplete()) { + throw new IllegalStateException("Authentication exchange has not completed"); + } + return Arrays.copyOfRange(outgoing, offset, offset + len); + } + + @Override + public Object getNegotiatedProperty(String propName) { + if (!isComplete()) { + throw new IllegalStateException("Authentication exchange has not completed"); + } + return null; + } + + @Override + public void dispose() throws SaslException { + } + + public static class ScramSaslClientFactory implements SaslClientFactory { + @Override + public SaslClient createSaslClient(String[] mechanisms, + String authorizationId, + String protocol, + String serverName, + Map props, + CallbackHandler cbh) throws SaslException { + + ScramMechanism mechanism = null; + for (String mech : mechanisms) { + mechanism = ScramMechanism.forMechanismName(mech); + if (mechanism != null) { + break; + } + } + if (mechanism == null) { + throw new SaslException(String.format("Requested mechanisms '%s' not supported." + + " Supported mechanisms are '%s'.", + Arrays.asList(mechanisms), ScramMechanism.mechanismNames())); + } + + return new ScramSaslClient(mechanism, cbh); + } + + @Override + public String[] getMechanismNames(Map props) { + Collection mechanisms = ScramMechanism.mechanismNames(); + return mechanisms.toArray(new String[0]); + } + } +} diff --git a/src/main/java/net/spy/memcached/auth/ScramSaslClientProvider.java b/src/main/java/net/spy/memcached/auth/ScramSaslClientProvider.java new file mode 100644 index 000000000..6439d3549 --- /dev/null +++ b/src/main/java/net/spy/memcached/auth/ScramSaslClientProvider.java @@ -0,0 +1,21 @@ +package net.spy.memcached.auth; + +import java.security.Provider; +import java.security.Security; + +import net.spy.memcached.auth.ScramSaslClient.ScramSaslClientFactory; + +public final class ScramSaslClientProvider extends Provider { + + private static final long serialVersionUID = 1L; + + @SuppressWarnings("deprecation") + private ScramSaslClientProvider() { + super("SASL/SCRAM Client Provider", 1.0, "SASL/SCRAM Client Provider for Arcus"); + put("SaslClientFactory.SCRAM-SHA-256", ScramSaslClientFactory.class.getName()); + } + + public static void initialize() { + Security.addProvider(new ScramSaslClientProvider()); + } +}