From 0f5e23c861b7fbe01fb138835e640aff2d7d5cce Mon Sep 17 00:00:00 2001 From: Silas Parker Date: Sun, 19 Dec 2021 05:44:28 +0000 Subject: Add connect_timeout parameter to Connection (#161) * Add connect_timeout parameter to Connection * Add connect_timeout parameter to Connection - Add unit tests and update Changelog --- Changelog | 6 ++++++ Modules/_librabbitmq/connection.c | 17 ++++++++++++++--- Modules/_librabbitmq/connection.h | 3 +++ librabbitmq/__init__.py | 3 ++- tests/test_functional.py | 34 ++++++++++++++++++++++++++++++++++ 5 files changed, 59 insertions(+), 4 deletions(-) diff --git a/Changelog b/Changelog index 5464523..a33344d 100644 --- a/Changelog +++ b/Changelog @@ -7,6 +7,12 @@ .. contents:: :local: +Next Release +============ + +- Add support for ``Connection.connect_timeout`` parameter + + .. _version-2.0.0: 2.0.0 diff --git a/Modules/_librabbitmq/connection.c b/Modules/_librabbitmq/connection.c index c9e7b35..daac3d2 100644 --- a/Modules/_librabbitmq/connection.c +++ b/Modules/_librabbitmq/connection.c @@ -1052,6 +1052,7 @@ PyRabbitMQ_ConnectionType_init(PyRabbitMQ_Connection *self, "frame_max", "heartbeat", "client_properties", + "connect_timeout", NULL }; char *hostname; @@ -1063,11 +1064,13 @@ PyRabbitMQ_ConnectionType_init(PyRabbitMQ_Connection *self, int frame_max = 131072; int heartbeat = 0; int port = 5672; + int connect_timeout = 0; PyObject *client_properties = NULL; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ssssiiiiO", kwlist, + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ssssiiiiOi", kwlist, &hostname, &userid, &password, &virtual_host, &port, - &channel_max, &frame_max, &heartbeat, &client_properties)) { + &channel_max, &frame_max, &heartbeat, &client_properties, + &connect_timeout)) { return -1; } @@ -1089,6 +1092,7 @@ PyRabbitMQ_ConnectionType_init(PyRabbitMQ_Connection *self, self->channel_max = channel_max; self->frame_max = frame_max; self->heartbeat = heartbeat; + self->connect_timeout = connect_timeout; self->weakreflist = NULL; self->callbacks = PyDict_New(); if (self->callbacks == NULL) return -1; @@ -1127,6 +1131,7 @@ PyRabbitMQ_Connection_connect(PyRabbitMQ_Connection *self) amqp_rpc_reply_t reply; amqp_pool_t pool; amqp_table_t properties; + struct timeval timeout = {0, 0}; pyobject_array_t pyobj_array = {0}; @@ -1144,7 +1149,13 @@ PyRabbitMQ_Connection_connect(PyRabbitMQ_Connection *self) goto error; } Py_BEGIN_ALLOW_THREADS; - status = amqp_socket_open(socket, self->hostname, self->port); + if (self->connect_timeout <= 0) { + status = amqp_socket_open(socket, self->hostname, self->port); + } else { + timeout.tv_sec = self->connect_timeout; + status = amqp_socket_open_noblock(socket, self->hostname, self->port, &timeout); + } + Py_END_ALLOW_THREADS; if (PyRabbitMQ_HandleAMQStatus(status, "Error opening socket")) { goto error; diff --git a/Modules/_librabbitmq/connection.h b/Modules/_librabbitmq/connection.h index 9595085..2d40d7d 100644 --- a/Modules/_librabbitmq/connection.h +++ b/Modules/_librabbitmq/connection.h @@ -151,6 +151,7 @@ typedef struct { int frame_max; int channel_max; int heartbeat; + int connect_timeout; int sockfd; int connected; @@ -267,6 +268,8 @@ static PyMemberDef PyRabbitMQ_ConnectionType_members[] = { offsetof(PyRabbitMQ_Connection, frame_max), READONLY, NULL}, {"callbacks", T_OBJECT_EX, offsetof(PyRabbitMQ_Connection, callbacks), READONLY, NULL}, + {"connect_timeout", T_INT, + offsetof(PyRabbitMQ_Connection, connect_timeout), READONLY, NULL}, {NULL, 0, 0, 0, NULL} /* Sentinel */ }; diff --git a/librabbitmq/__init__.py b/librabbitmq/__init__.py index a1d7e3b..ee7a06c 100644 --- a/librabbitmq/__init__.py +++ b/librabbitmq/__init__.py @@ -191,7 +191,7 @@ class Connection(_librabbitmq.Connection): def __init__(self, host='localhost', userid='guest', password='guest', virtual_host='/', port=5672, channel_max=0xffff, frame_max=131072, heartbeat=0, lazy=False, - client_properties=None, **kwargs): + client_properties=None, connect_timeout=None, **kwargs): if ':' in host: host, port = host.split(':') super(Connection, self).__init__( @@ -199,6 +199,7 @@ class Connection(_librabbitmq.Connection): virtual_host=virtual_host, channel_max=channel_max, frame_max=frame_max, heartbeat=heartbeat, client_properties=client_properties, + connect_timeout=0 if connect_timeout is None else int(connect_timeout), ) self.channels = {} self._used_channel_ids = array('H') diff --git a/tests/test_functional.py b/tests/test_functional.py index 7b8b17f..7352d0a 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -6,11 +6,45 @@ from six.moves import xrange import socket import unittest from array import array +import time from librabbitmq import Message, Connection, ConnectionError, ChannelError TEST_QUEUE = 'pyrabbit.testq' +class test_Connection(unittest.TestCase): + def test_connection_defaults(self): + """Test making a connection with the default settings.""" + with Connection() as connection: + self.assertGreaterEqual(connection.fileno(), 0) + + def test_connection_invalid_host(self): + """Test connection to an invalid host fails.""" + # Will fail quickly as OS will reject it. + with self.assertRaises(ConnectionError): + Connection(host="255.255.255.255") + + def test_connection_invalid_port(self): + """Test connection to an invalid port fails.""" + # Will fail quickly as OS will reject it. + with self.assertRaises(ConnectionError): + Connection(port=0) + + def test_connection_timeout(self): + """Test connection timeout.""" + # Can't rely on local OS being configured to ignore SYN packets + # (OS would normally reply with RST to closed port). To test the + # timeout, need to connect to something that is either slow, or + # never responds. + start_time = time.time() + with self.assertRaises(ConnectionError): + Connection(host="google.com", port=81, connect_timeout=3) + took_time = time.time() - start_time + # Allow some leaway to avoid spurious test failures. + self.assertGreaterEqual(took_time, 2) + self.assertLessEqual(took_time, 4) + + class test_Channel(unittest.TestCase): def setUp(self): -- cgit v1.2.1