Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weโ€™ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

INTERNAL: Add ScramSaslClient #869

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>com.bolyartech.scram_sasl</groupId>
<artifactId>scram_sasl</artifactId>
<version>2.0.2</version>
</dependency>

<!-- TEST -->
<dependency>
Expand Down
49 changes: 49 additions & 0 deletions src/main/java/net/spy/memcached/auth/ScramMechanism.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package net.spy.memcached.auth;

import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

public enum ScramMechanism {
SCRAM_SHA_256("SHA-256", "HmacSHA256");

private static final Map<String, ScramMechanism> MECHANISMS_MAP;

private final String mechanismName;
private final String hashAlgorithm;
private final String macAlgorithm;

static {
Map<String, ScramMechanism> map = new HashMap<>();
for (ScramMechanism mech : values()) {
map.put(mech.mechanismName, mech);
}
MECHANISMS_MAP = Collections.unmodifiableMap(map);
}

private ScramMechanism(String hashAlgorithm, String macAlgorithm) {
this.mechanismName = "SCRAM-" + hashAlgorithm;
this.hashAlgorithm = hashAlgorithm;
this.macAlgorithm = macAlgorithm;
}

public final String mechanismName() {
return this.mechanismName;
}
public String hashAlgorithm() {
return hashAlgorithm;
}

public String macAlgorithm() {
return macAlgorithm;
}

public static ScramMechanism forMechanismName(String mechanismName) {
return MECHANISMS_MAP.get(mechanismName);
}

public static Collection<String> mechanismNames() {
return MECHANISMS_MAP.keySet();
}
}
179 changes: 179 additions & 0 deletions src/main/java/net/spy/memcached/auth/ScramSaslClient.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
package net.spy.memcached.auth;

import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;

import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslClientFactory;
import javax.security.sasl.SaslException;

import com.bolyartech.scram_sasl.client.ScramClientFunctionality;
import com.bolyartech.scram_sasl.client.ScramClientFunctionalityImpl;
import com.bolyartech.scram_sasl.common.ScramException;

public class ScramSaslClient implements SaslClient {

enum State {
SEND_CLIENT_FIRST_MESSAGE,
RECEIVE_SERVER_FIRST_MESSAGE,
RECEIVE_SERVER_FINAL_MESSAGE,
COMPLETE,
FAILED
}

private final ScramMechanism mechanism;
private final CallbackHandler callbackHandler;
private final ScramClientFunctionality scf;
private State state;

public ScramSaslClient(ScramMechanism mechanism, CallbackHandler cbh) {
this.callbackHandler = cbh;
this.mechanism = mechanism;
this.scf = new ScramClientFunctionalityImpl(
mechanism.hashAlgorithm(), mechanism.macAlgorithm());
this.state = State.SEND_CLIENT_FIRST_MESSAGE;
}

@Override
public String getMechanismName() {
return this.mechanism.mechanismName();
}

@Override
public boolean hasInitialResponse() {
return true;
}

@Override
public byte[] evaluateChallenge(byte[] challenge) throws SaslException {
try {
switch (this.state) {
case SEND_CLIENT_FIRST_MESSAGE:
if (challenge != null && challenge.length != 0) {
throw new SaslException("Expected empty challenge");
}

NameCallback nameCallback = new NameCallback("Name: ");

try {
callbackHandler.handle(new Callback[]{nameCallback});
uhm0311 marked this conversation as resolved.
Show resolved Hide resolved
} catch (Throwable e) {
throw new SaslException("User name could not be obtained", e);
}

String username = nameCallback.getName();
byte[] clientFirstMessage = this.scf.prepareFirstMessage(username).getBytes();
this.state = State.RECEIVE_SERVER_FIRST_MESSAGE;
return clientFirstMessage;

case RECEIVE_SERVER_FIRST_MESSAGE:
String serverFirstMessage = new String(challenge, StandardCharsets.UTF_8);

PasswordCallback passwordCallback = new PasswordCallback("Password: ", false);
try {
callbackHandler.handle(new Callback[]{passwordCallback});
} catch (Throwable e) {
throw new SaslException("Password could not be obtained", e);
}

String password = String.valueOf(passwordCallback.getPassword());
byte[] clientFinalMessage = this.scf.prepareFinalMessage(
password, serverFirstMessage).getBytes();
if (clientFinalMessage == null) {
throw new SaslException("clientFinalMessage should not be null");
}
this.state = State.RECEIVE_SERVER_FINAL_MESSAGE;
return clientFinalMessage;

case RECEIVE_SERVER_FINAL_MESSAGE:
String serverFinalMessage = new String(challenge, StandardCharsets.UTF_8);
if (!this.scf.checkServerFinalMessage(serverFinalMessage)) {
throw new SaslException("Sasl authentication using " + this.mechanism +
" failed with error: invalid server final message");
}
this.state = State.COMPLETE;
return new byte[]{};

default:
throw new SaslException("Unexpected challenge in Sasl client state " + this.state);
}
} catch (ScramException e) {
this.state = State.FAILED;
throw new SaslException("ScramException", e);
} catch (SaslException e) {
this.state = State.FAILED;
throw e;
}
}

@Override
public boolean isComplete() {
return this.state == State.COMPLETE;
}

@Override
public byte[] unwrap(byte[] incoming, int offset, int len) throws SaslException {
if (!isComplete()) {
throw new IllegalStateException("Authentication exchange has not completed");
}
return Arrays.copyOfRange(incoming, offset, offset + len);
}

@Override
public byte[] wrap(byte[] outgoing, int offset, int len) throws SaslException {
if (!isComplete()) {
throw new IllegalStateException("Authentication exchange has not completed");
}
return Arrays.copyOfRange(outgoing, offset, offset + len);
}

@Override
public Object getNegotiatedProperty(String propName) {
if (!isComplete()) {
throw new IllegalStateException("Authentication exchange has not completed");
}
return null;
}

@Override
public void dispose() throws SaslException {
}

public static class ScramSaslClientFactory implements SaslClientFactory {
@Override
public SaslClient createSaslClient(String[] mechanisms,
String authorizationId,
String protocol,
String serverName,
Map<String, ?> props,
CallbackHandler cbh) throws SaslException {

ScramMechanism mechanism = null;
for (String mech : mechanisms) {
mechanism = ScramMechanism.forMechanismName(mech);
if (mechanism != null) {
break;
}
}
if (mechanism == null) {
throw new SaslException(String.format("Requested mechanisms '%s' not supported."
+ " Supported mechanisms are '%s'.",
Arrays.asList(mechanisms), ScramMechanism.mechanismNames()));
}

return new ScramSaslClient(mechanism, cbh);
}

@Override
public String[] getMechanismNames(Map<String, ?> props) {
Collection<String> mechanisms = ScramMechanism.mechanismNames();
return mechanisms.toArray(new String[0]);
}
}
}
21 changes: 21 additions & 0 deletions src/main/java/net/spy/memcached/auth/ScramSaslClientProvider.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package net.spy.memcached.auth;

import java.security.Provider;
import java.security.Security;

import net.spy.memcached.auth.ScramSaslClient.ScramSaslClientFactory;

public final class ScramSaslClientProvider extends Provider {

private static final long serialVersionUID = 1L;

@SuppressWarnings("deprecation")
private ScramSaslClientProvider() {
super("SASL/SCRAM Client Provider", 1.0, "SASL/SCRAM Client Provider for Arcus");
put("SaslClientFactory.SCRAM-SHA-256", ScramSaslClientFactory.class.getName());
}

public static void initialize() {
Security.addProvider(new ScramSaslClientProvider());
}
}