diff options
Diffstat (limited to 'cpp/src/qpid/sys/ssl')
-rw-r--r-- | cpp/src/qpid/sys/ssl/SslHandler.h | 2 | ||||
-rw-r--r-- | cpp/src/qpid/sys/ssl/SslIo.cpp | 22 | ||||
-rw-r--r-- | cpp/src/qpid/sys/ssl/SslIo.h | 18 | ||||
-rw-r--r-- | cpp/src/qpid/sys/ssl/SslSocket.cpp | 163 | ||||
-rw-r--r-- | cpp/src/qpid/sys/ssl/SslSocket.h | 46 |
5 files changed, 137 insertions, 114 deletions
diff --git a/cpp/src/qpid/sys/ssl/SslHandler.h b/cpp/src/qpid/sys/ssl/SslHandler.h index a340109966..400fa317fd 100644 --- a/cpp/src/qpid/sys/ssl/SslHandler.h +++ b/cpp/src/qpid/sys/ssl/SslHandler.h @@ -35,7 +35,7 @@ namespace sys { namespace ssl { class SslIO; -class SslIOBufferBase; +struct SslIOBufferBase; class SslSocket; class SslHandler : public OutputControl { diff --git a/cpp/src/qpid/sys/ssl/SslIo.cpp b/cpp/src/qpid/sys/ssl/SslIo.cpp index a58a137473..4a59819183 100644 --- a/cpp/src/qpid/sys/ssl/SslIo.cpp +++ b/cpp/src/qpid/sys/ssl/SslIo.cpp @@ -68,29 +68,33 @@ __thread int64_t threadMaxReadTimeNs = 2 * 1000000; // start at 2ms * Asynch Acceptor */ -SslAcceptor::SslAcceptor(const SslSocket& s, Callback callback) : +template <class T> +SslAcceptorTmpl<T>::SslAcceptorTmpl(const T& s, Callback callback) : acceptedCallback(callback), - handle(s, boost::bind(&SslAcceptor::readable, this, _1), 0, 0), + handle(s, boost::bind(&SslAcceptorTmpl<T>::readable, this, _1), 0, 0), socket(s) { s.setNonblocking(); ignoreSigpipe(); } -SslAcceptor::~SslAcceptor() +template <class T> +SslAcceptorTmpl<T>::~SslAcceptorTmpl() { handle.stopWatch(); } -void SslAcceptor::start(Poller::shared_ptr poller) { +template <class T> +void SslAcceptorTmpl<T>::start(Poller::shared_ptr poller) { handle.startWatch(poller); } /* * We keep on accepting as long as there is something to accept */ -void SslAcceptor::readable(DispatchHandle& h) { - SslSocket* s; +template <class T> +void SslAcceptorTmpl<T>::readable(DispatchHandle& h) { + Socket* s; do { errno = 0; // TODO: Currently we ignore the peers address, perhaps we should @@ -110,6 +114,10 @@ void SslAcceptor::readable(DispatchHandle& h) { h.rewatch(); } +// Explicitly instantiate the templates we need +template class SslAcceptorTmpl<SslSocket>; +template class SslAcceptorTmpl<SslMuxSocket>; + /* * Asynch Connector */ @@ -117,7 +125,7 @@ void SslAcceptor::readable(DispatchHandle& h) { SslConnector::SslConnector(const SslSocket& s, Poller::shared_ptr poller, std::string hostname, - uint16_t port, + std::string port, ConnectedCallback connCb, FailedCallback failCb) : DispatchHandle(s, diff --git a/cpp/src/qpid/sys/ssl/SslIo.h b/cpp/src/qpid/sys/ssl/SslIo.h index 53ac69d8d6..c980d73831 100644 --- a/cpp/src/qpid/sys/ssl/SslIo.h +++ b/cpp/src/qpid/sys/ssl/SslIo.h @@ -29,26 +29,30 @@ namespace qpid { namespace sys { + +class Socket; + namespace ssl { - + class SslSocket; /* * Asynchronous ssl acceptor: accepts connections then does a callback * with the accepted fd */ -class SslAcceptor { +template <class T> +class SslAcceptorTmpl { public: - typedef boost::function1<void, const SslSocket&> Callback; + typedef boost::function1<void, const Socket&> Callback; private: Callback acceptedCallback; qpid::sys::DispatchHandle handle; - const SslSocket& socket; + const T& socket; public: - SslAcceptor(const SslSocket& s, Callback callback); - ~SslAcceptor(); + SslAcceptorTmpl(const T& s, Callback callback); + ~SslAcceptorTmpl(); void start(qpid::sys::Poller::shared_ptr poller); private: @@ -73,7 +77,7 @@ public: SslConnector(const SslSocket& socket, Poller::shared_ptr poller, std::string hostname, - uint16_t port, + std::string port, ConnectedCallback connCb, FailedCallback failCb = 0); diff --git a/cpp/src/qpid/sys/ssl/SslSocket.cpp b/cpp/src/qpid/sys/ssl/SslSocket.cpp index 01e2658877..30234bb686 100644 --- a/cpp/src/qpid/sys/ssl/SslSocket.cpp +++ b/cpp/src/qpid/sys/ssl/SslSocket.cpp @@ -25,11 +25,13 @@ #include "qpid/Exception.h" #include "qpid/sys/posix/check.h" #include "qpid/sys/posix/PrivatePosix.h" +#include "qpid/log/Statement.h" #include <fcntl.h> #include <sys/types.h> #include <sys/socket.h> #include <sys/errno.h> +#include <poll.h> #include <netinet/in.h> #include <netinet/tcp.h> #include <netdb.h> @@ -50,36 +52,6 @@ namespace sys { namespace ssl { namespace { -std::string getName(int fd, bool local, bool includeService = false) -{ - ::sockaddr_storage name; // big enough for any socket address - ::socklen_t namelen = sizeof(name); - - int result = -1; - if (local) { - result = ::getsockname(fd, (::sockaddr*)&name, &namelen); - } else { - result = ::getpeername(fd, (::sockaddr*)&name, &namelen); - } - - QPID_POSIX_CHECK(result); - - char servName[NI_MAXSERV]; - char dispName[NI_MAXHOST]; - if (includeService) { - if (int rc=::getnameinfo((::sockaddr*)&name, namelen, dispName, sizeof(dispName), - servName, sizeof(servName), - NI_NUMERICHOST | NI_NUMERICSERV) != 0) - throw QPID_POSIX_ERROR(rc); - return std::string(dispName) + ":" + std::string(servName); - - } else { - if (int rc=::getnameinfo((::sockaddr*)&name, namelen, dispName, sizeof(dispName), 0, 0, NI_NUMERICHOST) != 0) - throw QPID_POSIX_ERROR(rc); - return dispName; - } -} - std::string getService(int fd, bool local) { ::sockaddr_storage name; // big enough for any socket address @@ -132,7 +104,7 @@ std::string getDomainFromSubject(std::string subject) } -SslSocket::SslSocket() : IOHandle(new IOHandlePrivate()), socket(0), prototype(0) +SslSocket::SslSocket() : socket(0), prototype(0) { impl->fd = ::socket (PF_INET, SOCK_STREAM, 0); if (impl->fd < 0) throw QPID_POSIX_ERROR(errno); @@ -144,7 +116,7 @@ SslSocket::SslSocket() : IOHandle(new IOHandlePrivate()), socket(0), prototype(0 * returned from accept. Because we use posix accept rather than * PR_Accept, we have to reset the handshake. */ -SslSocket::SslSocket(IOHandlePrivate* ioph, PRFileDesc* model) : IOHandle(ioph), socket(0), prototype(0) +SslSocket::SslSocket(IOHandlePrivate* ioph, PRFileDesc* model) : Socket(ioph), socket(0), prototype(0) { socket = SSL_ImportFD(model, PR_ImportTCPSocket(impl->fd)); NSS_CHECK(SSL_ResetHandshake(socket, true)); @@ -158,7 +130,7 @@ void SslSocket::setNonblocking() const PR_SetSocketOption(socket, &option); } -void SslSocket::connect(const std::string& host, uint16_t port) const +void SslSocket::connect(const std::string& host, const std::string& port) const { std::stringstream namestream; namestream << host << ":" << port; @@ -180,7 +152,7 @@ void SslSocket::connect(const std::string& host, uint16_t port) const PRHostEnt hostEntry; PR_CHECK(PR_GetHostByName(host.data(), hostBuffer, PR_NETDB_BUF_SIZE, &hostEntry)); PRNetAddr address; - int value = PR_EnumerateHostEnt(0, &hostEntry, port, &address); + int value = PR_EnumerateHostEnt(0, &hostEntry, boost::lexical_cast<PRUint16>(port), &address); if (value < 0) { throw Exception(QPID_MSG("Error getting address for host: " << ErrorString())); } else if (value == 0) { @@ -238,6 +210,7 @@ int SslSocket::listen(uint16_t port, int backlog, const std::string& certName, b SslSocket* SslSocket::accept() const { + QPID_LOG(trace, "Accepting SSL connection."); int afd = ::accept(impl->fd, 0, 0); if ( afd >= 0) { return new SslSocket(new IOHandlePrivate(afd), prototype); @@ -248,36 +221,109 @@ SslSocket* SslSocket::accept() const } } -int SslSocket::read(void *buf, size_t count) const -{ - return PR_Read(socket, buf, count); -} +#define SSL_STREAM_MAX_WAIT_ms 20 +#define SSL_STREAM_MAX_RETRIES 2 -int SslSocket::write(const void *buf, size_t count) const -{ - return PR_Write(socket, buf, count); -} +static bool isSslStream(int afd) { + int retries = SSL_STREAM_MAX_RETRIES; + unsigned char buf[5] = {}; -std::string SslSocket::getSockname() const -{ - return getName(impl->fd, true); + do { + struct pollfd fd = {afd, POLLIN, 0}; + + /* + * Note that this is blocking the accept thread, so connections that + * send no data can limit the rate at which we can accept new + * connections. + */ + if (::poll(&fd, 1, SSL_STREAM_MAX_WAIT_ms) > 0) { + errno = 0; + int result = recv(afd, buf, sizeof(buf), MSG_PEEK | MSG_DONTWAIT); + if (result == sizeof(buf)) { + break; + } + if (errno && errno != EAGAIN) { + int err = errno; + ::close(afd); + throw QPID_POSIX_ERROR(err); + } + } + } while (retries-- > 0); + + if (retries < 0) { + return false; + } + + /* + * SSLv2 Client Hello format + * http://www.mozilla.org/projects/security/pki/nss/ssl/draft02.html + * + * Bytes 0-1: RECORD-LENGTH + * Byte 2: MSG-CLIENT-HELLO (1) + * Byte 3: CLIENT-VERSION-MSB + * Byte 4: CLIENT-VERSION-LSB + * + * Allowed versions: + * 2.0 - SSLv2 + * 3.0 - SSLv3 + * 3.1 - TLS 1.0 + * 3.2 - TLS 1.1 + * 3.3 - TLS 1.2 + * + * The version sent in the Client-Hello is the latest version supported by + * the client. NSS may send version 3.x in an SSLv2 header for + * maximum compatibility. + */ + bool isSSL2Handshake = buf[2] == 1 && // MSG-CLIENT-HELLO + ((buf[3] == 3 && buf[4] <= 3) || // SSL 3.0 & TLS 1.0-1.2 (v3.1-3.3) + (buf[3] == 2 && buf[4] == 0)); // SSL 2 + + /* + * SSLv3/TLS Client Hello format + * RFC 2246 + * + * Byte 0: ContentType (handshake - 22) + * Bytes 1-2: ProtocolVersion {major, minor} + * + * Allowed versions: + * 3.0 - SSLv3 + * 3.1 - TLS 1.0 + * 3.2 - TLS 1.1 + * 3.3 - TLS 1.2 + */ + bool isSSL3Handshake = buf[0] == 22 && // handshake + (buf[1] == 3 && buf[2] <= 3); // SSL 3.0 & TLS 1.0-1.2 (v3.1-3.3) + + return isSSL2Handshake || isSSL3Handshake; } -std::string SslSocket::getPeername() const +Socket* SslMuxSocket::accept() const { - return getName(impl->fd, false); + int afd = ::accept(impl->fd, 0, 0); + if (afd >= 0) { + QPID_LOG(trace, "Accepting connection with optional SSL wrapper."); + if (isSslStream(afd)) { + QPID_LOG(trace, "Accepted SSL connection."); + return new SslSocket(new IOHandlePrivate(afd), prototype); + } else { + QPID_LOG(trace, "Accepted Plaintext connection."); + return new Socket(new IOHandlePrivate(afd)); + } + } else if (errno == EAGAIN) { + return 0; + } else { + throw QPID_POSIX_ERROR(errno); + } } -std::string SslSocket::getPeerAddress() const +int SslSocket::read(void *buf, size_t count) const { - if (!connectname.empty()) - return connectname; - return getName(impl->fd, false, true); + return PR_Read(socket, buf, count); } -std::string SslSocket::getLocalAddress() const +int SslSocket::write(const void *buf, size_t count) const { - return getName(impl->fd, true, true); + return PR_Write(socket, buf, count); } uint16_t SslSocket::getLocalPort() const @@ -290,17 +336,6 @@ uint16_t SslSocket::getRemotePort() const return atoi(getService(impl->fd, true).c_str()); } -int SslSocket::getError() const -{ - int result; - socklen_t rSize = sizeof (result); - - if (::getsockopt(impl->fd, SOL_SOCKET, SO_ERROR, &result, &rSize) < 0) - throw QPID_POSIX_ERROR(errno); - - return result; -} - void SslSocket::setTcpNoDelay(bool nodelay) const { if (nodelay) { diff --git a/cpp/src/qpid/sys/ssl/SslSocket.h b/cpp/src/qpid/sys/ssl/SslSocket.h index 25712c98d5..eabadcbe23 100644 --- a/cpp/src/qpid/sys/ssl/SslSocket.h +++ b/cpp/src/qpid/sys/ssl/SslSocket.h @@ -23,6 +23,7 @@ */ #include "qpid/sys/IOHandle.h" +#include "qpid/sys/Socket.h" #include <nspr.h> #include <string> @@ -36,7 +37,7 @@ class Duration; namespace ssl { -class SslSocket : public qpid::sys::IOHandle +class SslSocket : public qpid::sys::Socket { public: /** Create a socket wrapper for descriptor. */ @@ -53,7 +54,7 @@ public: * NSSInit().*/ void setCertName(const std::string& certName); - void connect(const std::string& host, uint16_t port) const; + void connect(const std::string& host, const std::string& port) const; void close() const; @@ -75,45 +76,13 @@ public: int read(void *buf, size_t count) const; int write(const void *buf, size_t count) const; - /** Returns the "socket name" ie the address bound to - * the near end of the socket - */ - std::string getSockname() const; - - /** Returns the "peer name" ie the address bound to - * the remote end of the socket - */ - std::string getPeername() const; - - /** - * Returns an address (host and port) for the remote end of the - * socket - */ - std::string getPeerAddress() const; - /** - * Returns an address (host and port) for the local end of the - * socket - */ - std::string getLocalAddress() const; - - /** - * Returns the full address of the connection: local and remote host and port. - */ - std::string getFullAddress() const { return getLocalAddress()+"-"+getPeerAddress(); } - uint16_t getLocalPort() const; uint16_t getRemotePort() const; - /** - * Returns the error code stored in the socket. This may be used - * to determine the result of a non-blocking connect. - */ - int getError() const; - int getKeyLen() const; std::string getClientAuthId() const; -private: +protected: mutable std::string connectname; mutable PRFileDesc* socket; std::string certname; @@ -126,6 +95,13 @@ private: mutable PRFileDesc* prototype; SslSocket(IOHandlePrivate* ioph, PRFileDesc* model); + friend class SslMuxSocket; +}; + +class SslMuxSocket : public SslSocket +{ +public: + Socket* accept() const; }; }}} |