summaryrefslogtreecommitdiff
path: root/qpid/java
diff options
context:
space:
mode:
authorRobert Godfrey <rgodfrey@apache.org>2014-07-21 18:20:30 +0000
committerRobert Godfrey <rgodfrey@apache.org>2014-07-21 18:20:30 +0000
commit6312bc11c12f4afaa76c14c8fd986e86f9385625 (patch)
treee85d7fd84fcbea69a3d396e8054aae9a869ea8b2 /qpid/java
parente4e12129abf3e332c59eedc1348a402941c3f33e (diff)
downloadqpid-python-6312bc11c12f4afaa76c14c8fd986e86f9385625.tar.gz
QPID-5884 : Add additional SASL mechanisms to the JMS AMQP 1.0 client
git-svn-id: https://svn.apache.org/repos/asf/qpid/trunk@1612369 13f79535-47bb-0310-9956-ffa450edef68
Diffstat (limited to 'qpid/java')
-rw-r--r--qpid/java/amqp-1-0-client/src/main/java/org/apache/qpid/amqp_1_0/client/TCPTransportProvier.java20
-rw-r--r--qpid/java/amqp-1-0-common/src/main/java/org/apache/qpid/amqp_1_0/transport/ConnectionEndpoint.java656
2 files changed, 608 insertions, 68 deletions
diff --git a/qpid/java/amqp-1-0-client/src/main/java/org/apache/qpid/amqp_1_0/client/TCPTransportProvier.java b/qpid/java/amqp-1-0-client/src/main/java/org/apache/qpid/amqp_1_0/client/TCPTransportProvier.java
index f4a21ea359..ee515c33ef 100644
--- a/qpid/java/amqp-1-0-client/src/main/java/org/apache/qpid/amqp_1_0/client/TCPTransportProvier.java
+++ b/qpid/java/amqp-1-0-client/src/main/java/org/apache/qpid/amqp_1_0/client/TCPTransportProvier.java
@@ -20,15 +20,6 @@
*/
package org.apache.qpid.amqp_1_0.client;
-import org.apache.qpid.amqp_1_0.framing.ConnectionHandler;
-import org.apache.qpid.amqp_1_0.framing.ExceptionHandler;
-import org.apache.qpid.amqp_1_0.transport.ConnectionEndpoint;
-import org.apache.qpid.amqp_1_0.type.FrameBody;
-import org.apache.qpid.amqp_1_0.type.SaslFrameBody;
-
-import javax.net.ssl.SSLContext;
-import javax.net.ssl.SSLSocket;
-import javax.net.ssl.SSLSocketFactory;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
@@ -36,6 +27,16 @@ import java.net.Socket;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLSocket;
+import javax.net.ssl.SSLSocketFactory;
+
+import org.apache.qpid.amqp_1_0.framing.ConnectionHandler;
+import org.apache.qpid.amqp_1_0.framing.ExceptionHandler;
+import org.apache.qpid.amqp_1_0.transport.ConnectionEndpoint;
+import org.apache.qpid.amqp_1_0.type.FrameBody;
+import org.apache.qpid.amqp_1_0.type.SaslFrameBody;
+
class TCPTransportProvier implements TransportProvider
{
private final String _transport;
@@ -70,6 +71,7 @@ class TCPTransportProvier implements TransportProvider
SSLSocket sslSocket = (SSLSocket) socketFactory.createSocket(address, port);
+ conn.setExternalPrincipal(sslSocket.getSession().getLocalPrincipal());
s=sslSocket;
}
else
diff --git a/qpid/java/amqp-1-0-common/src/main/java/org/apache/qpid/amqp_1_0/transport/ConnectionEndpoint.java b/qpid/java/amqp-1-0-common/src/main/java/org/apache/qpid/amqp_1_0/transport/ConnectionEndpoint.java
index 5d6bc67373..873f4e8f53 100644
--- a/qpid/java/amqp-1-0-common/src/main/java/org/apache/qpid/amqp_1_0/transport/ConnectionEndpoint.java
+++ b/qpid/java/amqp-1-0-common/src/main/java/org/apache/qpid/amqp_1_0/transport/ConnectionEndpoint.java
@@ -21,21 +21,30 @@
package org.apache.qpid.amqp_1_0.transport;
+import java.io.IOException;
+import java.io.UnsupportedEncodingException;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
+import java.security.InvalidKeyException;
+import java.security.MessageDigest;
+import java.security.NoSuchAlgorithmException;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
+import java.util.UUID;
import java.util.concurrent.TimeoutException;
import java.util.logging.Level;
import java.util.logging.Logger;
+import javax.crypto.Mac;
+import javax.crypto.spec.SecretKeySpec;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
+import javax.xml.bind.DatatypeConverter;
import org.apache.qpid.amqp_1_0.codec.DescribedTypeConstructorRegistry;
import org.apache.qpid.amqp_1_0.codec.ValueWriter;
@@ -82,7 +91,7 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour
private static final short DEFAULT_CHANNEL_MAX = Integer.getInteger("amqp.channel_max", 255).shortValue();
private static final int DEFAULT_MAX_FRAME = Integer.getInteger("amqp.max_frame_size", 1 << 15);
- private static final long DEFAULT_SYNC_TIMEOUT = Long.getLong("amqp.connection_sync_timeout",5000l);
+ private static final long DEFAULT_SYNC_TIMEOUT = Long.getLong("amqp.connection_sync_timeout", 5000l);
private ConnectionState _state = ConnectionState.UNOPENED;
@@ -137,6 +146,8 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour
private long _syncTimeout = DEFAULT_SYNC_TIMEOUT;
private String _localHostname;
+ private boolean _secure;
+ private Principal _externalPrincipal;
public ConnectionEndpoint(Container container, SaslServerProvider cbs)
{
@@ -157,7 +168,7 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour
public void setPrincipal(Principal user)
{
- if(_user == null)
+ if (_user == null)
{
_user = user;
_requiresSASLClient = user != null;
@@ -199,7 +210,7 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour
_frameOutputHandler = frameOutputHandler;
}
- public void setProperties(Map<Symbol,Object> properties)
+ public void setProperties(Map<Symbol, Object> properties)
{
_properties = properties;
}
@@ -260,12 +271,12 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour
{
Open open = new Open();
- if(_receivingSessions == null)
+ if (_receivingSessions == null)
{
- _receivingSessions = new SessionEndpoint[channelMax+1];
- _sendingSessions = new SessionEndpoint[channelMax+1];
+ _receivingSessions = new SessionEndpoint[channelMax + 1];
+ _sendingSessions = new SessionEndpoint[channelMax + 1];
}
- if(channelMax < _channelMax)
+ if (channelMax < _channelMax)
{
_channelMax = channelMax;
}
@@ -273,7 +284,7 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour
open.setContainerId(_container.getId());
open.setMaxFrameSize(getDesiredMaxFrameSize());
open.setHostname(getRemoteHostname());
- if(_properties != null)
+ if (_properties != null)
{
open.setProperties(_properties);
}
@@ -322,7 +333,7 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour
error.setDescription("Frame received on channel " + channel + " which is not known as a begun session.");
this.handleError(error);
}
-
+
return session;
}
@@ -335,18 +346,18 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour
? open.getChannelMax().shortValue()
: _channelMax;
- if(_receivingSessions == null)
+ if (_receivingSessions == null)
{
- _receivingSessions = new SessionEndpoint[_channelMax+1];
- _sendingSessions = new SessionEndpoint[_channelMax+1];
+ _receivingSessions = new SessionEndpoint[_channelMax + 1];
+ _sendingSessions = new SessionEndpoint[_channelMax + 1];
}
UnsignedInteger remoteDesiredMaxFrameSize =
open.getMaxFrameSize() == null ? UnsignedInteger.valueOf(DEFAULT_MAX_FRAME) : open.getMaxFrameSize();
_maxFrameSize = (remoteDesiredMaxFrameSize.compareTo(_desiredMaxFrameSize) < 0
- ? remoteDesiredMaxFrameSize
- : _desiredMaxFrameSize).intValue();
+ ? remoteDesiredMaxFrameSize
+ : _desiredMaxFrameSize).intValue();
_remoteContainerId = open.getContainerId();
_localHostname = open.getHostname();
@@ -411,7 +422,7 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour
switch (_state)
{
case UNOPENED:
- sendOpen((short)0,0);
+ sendOpen((short) 0, 0);
sendClose(close);
_state = ConnectionState.CLOSED;
break;
@@ -433,7 +444,7 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour
if (!_closedForInput)
{
_closedForInput = true;
- switch(_state)
+ switch (_state)
{
case UNOPENED:
case AWAITING_OPEN:
@@ -448,7 +459,7 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour
default:
}
- if(_receivingSessions != null)
+ if (_receivingSessions != null)
{
for (int i = 0; i < _receivingSessions.length; i++)
{
@@ -508,7 +519,7 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour
endpoint.setReceivingChannel(channel);
endpoint.setNextIncomingId(begin.getNextOutgoingId());
endpoint.setOutgoingSessionCredit(begin.getIncomingWindow());
-
+
if (endpoint.getState() == SessionState.END_SENT)
{
_sendingSessions[myChannelId] = null;
@@ -728,19 +739,24 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour
Close close = new Close();
close.setError(error);
send((short) 0, close);
-
+
this.setClosedForOutput(true);
}
}
+ public void setExternalPrincipal(final Principal externalPrincipal)
+ {
+ _externalPrincipal = externalPrincipal;
+ }
+
public static interface FrameReceiptLogger
{
boolean isEnabled();
+
void received(final SocketAddress remoteAddress, short channel, Object frame);
}
-
private FrameReceiptLogger _logger =
new FrameReceiptLogger()
{
@@ -911,52 +927,82 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour
}
}
+ private final AmqpSaslClient[] _supportedSaslClientMechanisms =
+ new AmqpSaslClient[]{new ScramSHA256SaslClient(), new ScramSHA1SaslClient(), new ExternalSaslClient(),
+ new CramMD5SaslClient(), new CramMD5HashedSaslClient(), new PlainSaslClient(), new AnonymousSaslClient()};
+
+ private AmqpSaslClient _saslClient;
+
public void receiveSaslMechanisms(final SaslMechanisms saslMechanisms)
{
SaslInit init = new SaslInit();
init.setHostname(_remoteHostname);
Set<Symbol> mechanisms = new HashSet<Symbol>(Arrays.asList(saslMechanisms.getSaslServerMechanisms()));
- if (mechanisms.contains(SASL_PLAIN) && _password != null)
- {
-
- init.setMechanism(SASL_PLAIN);
-
- byte[] usernameBytes = _user.getName().getBytes(Charset.forName("UTF-8"));
- byte[] passwordBytes = _password.getBytes(Charset.forName("UTF-8"));
- byte[] initResponse = new byte[usernameBytes.length + passwordBytes.length + 2];
- System.arraycopy(usernameBytes, 0, initResponse, 1, usernameBytes.length);
- System.arraycopy(passwordBytes, 0, initResponse, usernameBytes.length + 2, passwordBytes.length);
- init.setInitialResponse(new Binary(initResponse));
- }
- else if (mechanisms.contains(SASL_ANONYMOUS))
+ for (AmqpSaslClient saslClient : _supportedSaslClientMechanisms)
{
- init.setMechanism(SASL_ANONYMOUS);
+ if (mechanisms.contains(saslClient.getMechanismName()) && saslClient.canSupportMechanism())
+ {
+ _saslClient = saslClient;
+ break;
+ }
}
- else if (mechanisms.contains(SASL_EXTERNAL))
+ if (_saslClient != null)
{
- init.setMechanism(SASL_EXTERNAL);
+ try
+ {
+
+ init.setMechanism(_saslClient.getMechanismName());
+ if (_saslClient.hasInitialResponse())
+ {
+ init.setInitialResponse(new Binary(_saslClient.getResponse(new byte[0])));
+ }
+ _saslFrameOutput.send(new SASLFrame(init), null);
+
+ }
+ catch (SaslException e)
+ {
+ closeSaslWithFailure();
+
+ }
}
else
{
- synchronized (getLock())
- {
- _saslComplete = true;
- _authenticated = false;
- getLock().notifyAll();
- }
- setClosedForInput(true);
- _saslFrameOutput.close();
+ closeSaslWithFailure();
+
+ }
+ }
- return;
+ public void closeSaslWithFailure()
+ {
+ synchronized (getLock())
+ {
+ _saslComplete = true;
+ _authenticated = false;
+ getLock().notifyAll();
}
- _saslFrameOutput.send(new SASLFrame(init), null);
+ setClosedForInput(true);
+ _saslFrameOutput.close();
}
public void receiveSaslChallenge(final SaslChallenge saslChallenge)
{
- //To change body of implemented methods use File | Settings | File Templates.
+ try
+ {
+ ByteBuffer challenge = saslChallenge.getChallenge().asByteBuffer();
+ final byte[] challengeBytes = new byte[challenge.remaining()];
+ challenge.get(challengeBytes);
+ byte[] responseBytes = _saslClient.getResponse(challengeBytes);
+ SaslResponse response = new SaslResponse();
+ response.setResponse(new Binary(responseBytes));
+ _saslFrameOutput.send(new SASLFrame(response), null);
+ }
+ catch (SaslException e)
+ {
+ closeSaslWithFailure();
+ }
+
}
public void receiveSaslResponse(final SaslResponse saslResponse)
@@ -1036,14 +1082,7 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour
}
else
{
- synchronized (getLock())
- {
- _saslComplete = true;
- _authenticated = false;
- getLock().notifyAll();
- }
- setClosedForInput(true);
- _saslFrameOutput.close();
+ closeSaslWithFailure();
}
}
@@ -1151,16 +1190,16 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour
{
long endTime = System.currentTimeMillis() + timeout;
- synchronized(getLock())
+ synchronized (getLock())
{
- while(!predicate.isSatisfied())
+ while (!predicate.isSatisfied())
{
getLock().wait(timeout);
- if(!predicate.isSatisfied())
+ if (!predicate.isSatisfied())
{
timeout = endTime - System.currentTimeMillis();
- if(timeout <= 0l)
+ if (timeout <= 0l)
{
throw new TimeoutException();
}
@@ -1169,4 +1208,503 @@ public class ConnectionEndpoint implements DescribedTypeConstructorRegistry.Sour
}
}
+
+ private interface AmqpSaslClient
+ {
+ boolean canSupportMechanism();
+
+ Symbol getMechanismName();
+
+ boolean hasInitialResponse();
+
+ byte[] getResponse(byte[] challenge) throws SaslException;
+ }
+
+ private class AnonymousSaslClient implements AmqpSaslClient
+ {
+
+ @Override
+ public boolean canSupportMechanism()
+ {
+ return true;
+ }
+
+ @Override
+ public Symbol getMechanismName()
+ {
+ return Symbol.valueOf("ANONYMOUS");
+ }
+
+ @Override
+ public boolean hasInitialResponse()
+ {
+ return false;
+ }
+
+ @Override
+ public byte[] getResponse(final byte[] challenge)
+ {
+ return new byte[0];
+ }
+ }
+
+ private class ExternalSaslClient implements AmqpSaslClient
+ {
+
+ @Override
+ public boolean canSupportMechanism()
+ {
+ return ConnectionEndpoint.this._externalPrincipal != null;
+ }
+
+ @Override
+ public Symbol getMechanismName()
+ {
+ return Symbol.valueOf("EXTERNAL");
+ }
+
+ @Override
+ public boolean hasInitialResponse()
+ {
+ return false;
+ }
+
+ @Override
+ public byte[] getResponse(final byte[] challenge)
+ {
+ return new byte[0];
+ }
+ }
+
+ private class PlainSaslClient implements AmqpSaslClient
+ {
+
+ private boolean _initResponseSent;
+
+ @Override
+ public boolean canSupportMechanism()
+ {
+ return ConnectionEndpoint.this._user != null
+ && ConnectionEndpoint.this._password != null;
+ }
+
+ @Override
+ public Symbol getMechanismName()
+ {
+ return Symbol.valueOf("PLAIN");
+ }
+
+ @Override
+ public boolean hasInitialResponse()
+ {
+ return true;
+ }
+
+ @Override
+ public byte[] getResponse(final byte[] challenge)
+ {
+ if (_initResponseSent)
+ {
+ return new byte[0];
+ }
+ else
+ {
+ _initResponseSent = true;
+ byte[] usernameBytes = _user.getName().getBytes(Charset.forName("UTF-8"));
+ byte[] passwordBytes = _password.getBytes(Charset.forName("UTF-8"));
+ byte[] initResponse = new byte[usernameBytes.length + passwordBytes.length + 2];
+ System.arraycopy(usernameBytes, 0, initResponse, 1, usernameBytes.length);
+ System.arraycopy(passwordBytes, 0, initResponse, usernameBytes.length + 2, passwordBytes.length);
+ return initResponse;
+ }
+ }
+ }
+
+
+ private static final Charset ASCII = Charset.forName("ASCII");
+
+ abstract static private class AbstractScramSaslClient implements AmqpSaslClient
+ {
+
+
+ private static final byte[] INT_1 = new byte[]{0, 0, 0, 1};
+ private static final String GS2_HEADER = "n,,";
+
+ private final String _digestName;
+ private final String _hmacName;
+ private final ConnectionEndpoint _endpoint;
+
+ private String _username;
+ private final String _clientNonce = UUID.randomUUID().toString();
+ private String _serverNonce;
+ private byte[] _salt;
+ private int _iterationCount;
+ private String _clientFirstMessageBare;
+ private byte[] _serverSignature;
+
+ enum State
+ {
+ INITIAL,
+ CLIENT_FIRST_SENT,
+ CLIENT_PROOF_SENT,
+ COMPLETE
+ }
+
+ public final Symbol _mechanism;
+
+
+ private State _state = State.INITIAL;
+
+ public AbstractScramSaslClient(ConnectionEndpoint endpoint,
+ final Symbol mechanism,
+ final String digestName,
+ final String hmacName)
+ {
+ _endpoint = endpoint;
+ _mechanism = mechanism;
+ _digestName = digestName;
+ _hmacName = hmacName;
+
+ }
+
+
+ @Override
+ public boolean canSupportMechanism()
+ {
+ return _endpoint._user != null
+ && _endpoint._password != null;
+ }
+
+
+ @Override
+ public Symbol getMechanismName()
+ {
+ return _mechanism;
+ }
+
+ @Override
+ public boolean hasInitialResponse()
+ {
+ return true;
+ }
+
+ @Override
+ public byte[] getResponse(final byte[] challenge) throws SaslException
+ {
+ byte[] response;
+ switch (_state)
+ {
+ case INITIAL:
+ response = initialResponse();
+ _state = State.CLIENT_FIRST_SENT;
+ break;
+ case CLIENT_FIRST_SENT:
+ response = calculateClientProof(challenge);
+ _state = State.CLIENT_PROOF_SENT;
+ break;
+ case CLIENT_PROOF_SENT:
+ evaluateOutcome(challenge);
+ response = null;
+ _state = State.COMPLETE;
+ break;
+ default:
+ throw new SaslException("No challenge expected in state " + _state);
+ }
+ return response;
+ }
+
+ private void evaluateOutcome(final byte[] challenge) throws SaslException
+ {
+ String serverFinalMessage = new String(challenge, ASCII);
+ String[] parts = serverFinalMessage.split(",");
+ if (!parts[0].startsWith("v="))
+ {
+ throw new SaslException("Server final message did not contain verifier");
+ }
+ byte[] serverSignature = DatatypeConverter.parseBase64Binary(parts[0].substring(2));
+ if (!Arrays.equals(_serverSignature, serverSignature))
+ {
+ throw new SaslException("Server signature did not match");
+ }
+ }
+
+ private byte[] calculateClientProof(final byte[] challenge) throws SaslException
+ {
+ try
+ {
+ String serverFirstMessage = new String(challenge, ASCII);
+ String[] parts = serverFirstMessage.split(",");
+ if (parts.length < 3)
+ {
+ throw new SaslException("Server challenge '" + serverFirstMessage + "' cannot be parsed");
+ }
+ else if (parts[0].startsWith("m="))
+ {
+ throw new SaslException("Server requires mandatory extension which is not supported: " + parts[0]);
+ }
+ else if (!parts[0].startsWith("r="))
+ {
+ throw new SaslException("Server challenge '"
+ + serverFirstMessage
+ + "' cannot be parsed, cannot find nonce");
+ }
+ String nonce = parts[0].substring(2);
+ if (!nonce.startsWith(_clientNonce))
+ {
+ throw new SaslException("Server challenge did not use correct client nonce");
+ }
+ _serverNonce = nonce;
+ if (!parts[1].startsWith("s="))
+ {
+ throw new SaslException("Server challenge '"
+ + serverFirstMessage
+ + "' cannot be parsed, cannot find salt");
+ }
+ String base64Salt = parts[1].substring(2);
+ _salt = DatatypeConverter.parseBase64Binary(base64Salt);
+ if (!parts[2].startsWith("i="))
+ {
+ throw new SaslException("Server challenge '"
+ + serverFirstMessage
+ + "' cannot be parsed, cannot find iteration count");
+ }
+ String iterCountString = parts[2].substring(2);
+ _iterationCount = Integer.parseInt(iterCountString);
+ if (_iterationCount <= 0)
+ {
+ throw new SaslException("Iteration count " + _iterationCount + " is not a positive integer");
+ }
+
+ byte[] passwordBytes = saslPrep(_endpoint._password).getBytes("UTF-8");
+
+ byte[] saltedPassword = generateSaltedPassword(passwordBytes);
+
+
+ String clientFinalMessageWithoutProof =
+ "c=" + DatatypeConverter.printBase64Binary(GS2_HEADER.getBytes(ASCII))
+ + ",r=" + _serverNonce;
+
+ String authMessage =
+ _clientFirstMessageBare + "," + serverFirstMessage + "," + clientFinalMessageWithoutProof;
+
+ byte[] clientKey = computeHmac(saltedPassword, "Client Key");
+ byte[] storedKey = MessageDigest.getInstance(_digestName).digest(clientKey);
+
+ byte[] clientSignature = computeHmac(storedKey, authMessage);
+
+ byte[] clientProof = clientKey.clone();
+ for (int i = 0; i < clientProof.length; i++)
+ {
+ clientProof[i] ^= clientSignature[i];
+ }
+ byte[] serverKey = computeHmac(saltedPassword, "Server Key");
+ _serverSignature = computeHmac(serverKey, authMessage);
+
+ String finalMessageWithProof = clientFinalMessageWithoutProof
+ + ",p=" + DatatypeConverter.printBase64Binary(clientProof);
+ return finalMessageWithProof.getBytes();
+ }
+ catch (IllegalArgumentException | IOException | NoSuchAlgorithmException e)
+ {
+ throw new SaslException(e.getMessage(), e);
+ }
+ }
+
+ private byte[] computeHmac(final byte[] key, final String string)
+ throws SaslException, UnsupportedEncodingException
+ {
+ Mac mac = createHmac(key);
+ mac.update(string.getBytes(ASCII));
+ return mac.doFinal();
+ }
+
+ private byte[] generateSaltedPassword(final byte[] passwordBytes) throws SaslException
+ {
+ Mac mac = createHmac(passwordBytes);
+
+ mac.update(_salt);
+ mac.update(INT_1);
+ byte[] result = mac.doFinal();
+
+ byte[] previous = null;
+ for (int i = 1; i < _iterationCount; i++)
+ {
+ mac.update(previous != null ? previous : result);
+ previous = mac.doFinal();
+ for (int x = 0; x < result.length; x++)
+ {
+ result[x] ^= previous[x];
+ }
+ }
+
+ return result;
+ }
+
+ private Mac createHmac(final byte[] keyBytes)
+ throws SaslException
+ {
+ try
+ {
+ SecretKeySpec key = new SecretKeySpec(keyBytes, _hmacName);
+ Mac mac = Mac.getInstance(_hmacName);
+ mac.init(key);
+ return mac;
+ }
+ catch (NoSuchAlgorithmException | InvalidKeyException e)
+ {
+ throw new SaslException(e.getMessage(), e);
+ }
+ }
+
+
+ private byte[] initialResponse() throws SaslException
+ {
+ StringBuffer buf = new StringBuffer("n=");
+ _username = _endpoint.getUser().getName();
+ buf.append(saslPrep(_username));
+ buf.append(",r=");
+ buf.append(_clientNonce);
+ _clientFirstMessageBare = buf.toString();
+ return (GS2_HEADER + _clientFirstMessageBare).getBytes(ASCII);
+ }
+
+ private String saslPrep(String name) throws SaslException
+ {
+ // TODO - a real implementation of SaslPrep
+
+ if (!ASCII.newEncoder().canEncode(name))
+ {
+ throw new SaslException("Can only encode names and passwords which are restricted to ASCII characters");
+ }
+
+ name = name.replace("=", "=3D");
+ name = name.replace(",", "=2C");
+ return name;
+ }
+
+ public boolean isComplete()
+ {
+ return _state == State.COMPLETE;
+ }
+
+ }
+
+ private final class ScramSHA1SaslClient extends AbstractScramSaslClient
+ {
+
+ public ScramSHA1SaslClient()
+ {
+ super(ConnectionEndpoint.this, Symbol.valueOf("SCRAM-SHA-1"), "SHA-1", "HmacSHA1");
+ }
+ }
+
+
+ private final class ScramSHA256SaslClient extends AbstractScramSaslClient
+ {
+
+ public ScramSHA256SaslClient()
+ {
+ super(ConnectionEndpoint.this, Symbol.valueOf("SCRAM-SHA-256"), "SHA-256", "HmacSHA256");
+ }
+ }
+
+ private class CramMD5SaslClient implements AmqpSaslClient
+ {
+
+ @Override
+ public boolean canSupportMechanism()
+ {
+ return ConnectionEndpoint.this._user != null
+ && ConnectionEndpoint.this._password != null;
+ }
+
+ @Override
+ public Symbol getMechanismName()
+ {
+ return Symbol.valueOf("CRAM-MD5");
+ }
+
+ @Override
+ public boolean hasInitialResponse()
+ {
+ return false;
+ }
+
+ @Override
+ public byte[] getResponse(final byte[] challenge) throws SaslException
+ {
+
+ try
+ {
+ SecretKeySpec key = new SecretKeySpec(getSharedSecretBytes(), "HmacMD5");
+
+ Mac mac = Mac.getInstance("HmacMD5");
+ mac.init(key);
+
+ mac.update(challenge);
+ byte[] result = mac.doFinal();
+
+ StringBuilder responseBeforeBase64 = new StringBuilder(ConnectionEndpoint.this.getUser().getName());
+ responseBeforeBase64.append(" ");
+ for (byte b : result)
+ {
+ responseBeforeBase64.append(String.format("%02x", b));
+ }
+
+ return responseBeforeBase64.toString().getBytes(ASCII);
+ }
+ catch (NoSuchAlgorithmException | InvalidKeyException e)
+ {
+ throw new SaslException(e.getMessage(), e);
+ }
+ }
+
+ public byte[] getSharedSecretBytes() throws SaslException
+ {
+ return ConnectionEndpoint.this._password.getBytes(ASCII);
+ }
+ }
+
+ private final class CramMD5HashedSaslClient extends CramMD5SaslClient
+ {
+
+ @Override
+ public Symbol getMechanismName()
+ {
+ return Symbol.valueOf("CRAM-MD5-HASHED");
+ }
+
+ public byte[] getSharedSecretBytes() throws SaslException
+ {
+
+ try
+ {
+
+ byte[] data = ConnectionEndpoint.this._password.getBytes("utf-8");
+ MessageDigest md = MessageDigest.getInstance("MD5");
+ for (byte b : data)
+ {
+ md.update(b);
+ }
+
+ byte[] digest = md.digest();
+
+ char[] hash = new char[digest.length];
+
+ int index = 0;
+ for (byte b : digest)
+ {
+ hash[index++] = (char) b;
+ }
+
+ return new String(hash).getBytes("utf-8");
+ }
+ catch (NoSuchAlgorithmException | UnsupportedEncodingException e)
+ {
+ throw new SaslException(e.getMessage(), e);
+ }
+
+ }
+
+ }
}