Skip to content

Commit

Permalink
INTERNAL: Add ScramSaslClient
Browse files Browse the repository at this point in the history
  • Loading branch information
namsic committed Jan 15, 2025
1 parent 5c13575 commit 94a7e87
Show file tree
Hide file tree
Showing 4 changed files with 251 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>com.bolyartech.scram_sasl</groupId>
<artifactId>scram_sasl</artifactId>
<version>2.0.2</version>
</dependency>

<!-- TEST -->
<dependency>
Expand Down
49 changes: 49 additions & 0 deletions src/main/java/net/spy/memcached/auth/ScramMechanism.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package net.spy.memcached.auth;

import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

public enum ScramMechanism {
SCRAM_SHA_256("SHA-256", "HmacSHA256");

private static final Map<String, ScramMechanism> MECHANISMS_MAP;

private final String mechanismName;
private final String hashAlgorithm;
private final String macAlgorithm;

static {
Map<String, ScramMechanism> map = new HashMap<>();
for (ScramMechanism mech : values()) {
map.put(mech.mechanismName, mech);
}
MECHANISMS_MAP = Collections.unmodifiableMap(map);
}

private ScramMechanism(String hashAlgorithm, String macAlgorithm) {
this.mechanismName = "SCRAM-" + hashAlgorithm;
this.hashAlgorithm = hashAlgorithm;
this.macAlgorithm = macAlgorithm;
}

public final String mechanismName() {
return this.mechanismName;
}
public String hashAlgorithm() {
return hashAlgorithm;
}

public String macAlgorithm() {
return macAlgorithm;
}

public static ScramMechanism forMechanismName(String mechanismName) {
return MECHANISMS_MAP.get(mechanismName);
}

public static Collection<String> mechanismNames() {
return MECHANISMS_MAP.keySet();
}
}
176 changes: 176 additions & 0 deletions src/main/java/net/spy/memcached/auth/ScramSaslClient.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
package net.spy.memcached.auth;

import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;

import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslClientFactory;
import javax.security.sasl.SaslException;

import com.bolyartech.scram_sasl.client.ScramClientFunctionality;
import com.bolyartech.scram_sasl.client.ScramClientFunctionalityImpl;
import com.bolyartech.scram_sasl.common.ScramException;

public class ScramSaslClient implements SaslClient {

enum State {
SEND_CLIENT_FIRST_MESSAGE,
RECEIVE_SERVER_FIRST_MESSAGE,
RECEIVE_SERVER_FINAL_MESSAGE,
COMPLETE,
FAILED
}

private final ScramMechanism mechanism;
private final CallbackHandler callbackHandler;
private final ScramClientFunctionality scf;
private State state;

public ScramSaslClient(ScramMechanism mechanism, CallbackHandler cbh) {
this.callbackHandler = cbh;
this.mechanism = mechanism;
this.scf = new ScramClientFunctionalityImpl(
mechanism.hashAlgorithm(), mechanism.macAlgorithm());
this.state = State.SEND_CLIENT_FIRST_MESSAGE;
}

@Override
public String getMechanismName() {
return this.mechanism.mechanismName();
}

@Override
public boolean hasInitialResponse() {
return true;
}

@Override
public byte[] evaluateChallenge(byte[] challenge) throws SaslException {
try {
switch (this.state) {
case SEND_CLIENT_FIRST_MESSAGE:
if (challenge != null && challenge.length != 0) {
throw new SaslException("Expected empty challenge");
}

NameCallback nameCallback = new NameCallback("Name: ");

try {
callbackHandler.handle(new Callback[]{nameCallback});
} catch (Throwable e) {
throw new SaslException("User name could not be obtained", e);
}

String username = nameCallback.getName();
byte[] clientFirstMessage = this.scf.prepareFirstMessage(username).getBytes();
this.state = State.RECEIVE_SERVER_FIRST_MESSAGE;
return clientFirstMessage;

case RECEIVE_SERVER_FIRST_MESSAGE:
String serverFirstMessage = new String(challenge, StandardCharsets.UTF_8);

PasswordCallback passwordCallback = new PasswordCallback("Password: ", false);
try {
callbackHandler.handle(new Callback[]{passwordCallback});
} catch (Throwable e) {
throw new SaslException("Password could not be obtained", e);
}

String password = String.valueOf(passwordCallback.getPassword());
byte[] clientFinalMessage = this.scf.prepareFinalMessage(
password, serverFirstMessage).getBytes();
this.state = State.RECEIVE_SERVER_FINAL_MESSAGE;
return clientFinalMessage;

case RECEIVE_SERVER_FINAL_MESSAGE:
String serverFinalMessage = new String(challenge, StandardCharsets.UTF_8);
if (!this.scf.checkServerFinalMessage(serverFinalMessage)) {
throw new SaslException("Sasl authentication using " + this.mechanism +
" failed with error: invalid server final message");
}
this.state = State.COMPLETE;
return new byte[]{};

default:
throw new SaslException("Unexpected challenge in Sasl client state " + this.state);
}
} catch (ScramException e) {
this.state = State.FAILED;
throw new SaslException("ScramException", e);
} catch (SaslException e) {
this.state = State.FAILED;
throw e;
}
}

@Override
public boolean isComplete() {
return this.state == State.COMPLETE;
}

@Override
public byte[] unwrap(byte[] incoming, int offset, int len) throws SaslException {
if (!isComplete()) {
throw new IllegalStateException("Authentication exchange has not completed");
}
return Arrays.copyOfRange(incoming, offset, offset + len);
}

@Override
public byte[] wrap(byte[] outgoing, int offset, int len) throws SaslException {
if (!isComplete()) {
throw new IllegalStateException("Authentication exchange has not completed");
}
return Arrays.copyOfRange(outgoing, offset, offset + len);
}

@Override
public Object getNegotiatedProperty(String propName) {
if (!isComplete()) {
throw new IllegalStateException("Authentication exchange has not completed");
}
return null;
}

@Override
public void dispose() throws SaslException {
}

public static class ScramSaslClientFactory implements SaslClientFactory {
@Override
public SaslClient createSaslClient(String[] mechanisms,
String authorizationId,
String protocol,
String serverName,
Map<String, ?> props,
CallbackHandler cbh) throws SaslException {

ScramMechanism mechanism = null;
for (String mech : mechanisms) {
mechanism = ScramMechanism.forMechanismName(mech);
if (mechanism != null) {
break;
}
}
if (mechanism == null) {
throw new SaslException(String.format("Requested mechanisms '%s' not supported."
+ " Supported mechanisms are '%s'.",
Arrays.asList(mechanisms), ScramMechanism.mechanismNames()));
}

return new ScramSaslClient(mechanism, cbh);
}

@Override
public String[] getMechanismNames(Map<String, ?> props) {
Collection<String> mechanisms = ScramMechanism.mechanismNames();
return mechanisms.toArray(new String[0]);
}
}
}
21 changes: 21 additions & 0 deletions src/main/java/net/spy/memcached/auth/ScramSaslClientProvider.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package net.spy.memcached.auth;

import java.security.Provider;
import java.security.Security;

import net.spy.memcached.auth.ScramSaslClient.ScramSaslClientFactory;

public final class ScramSaslClientProvider extends Provider {

private static final long serialVersionUID = 1L;

@SuppressWarnings("deprecation")
private ScramSaslClientProvider() {
super("SASL/SCRAM Client Provider", 1.0, "SASL/SCRAM Client Provider for Arcus");
put("SaslClientFactory.SCRAM-SHA-256", ScramSaslClientFactory.class.getName());
}

public static void initialize() {
Security.addProvider(new ScramSaslClientProvider());
}
}

0 comments on commit 94a7e87

Please sign in to comment.