summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSilas Parker <silas@srp.me.uk>2021-12-19 05:44:28 +0000
committerGitHub <noreply@github.com>2021-12-19 11:44:28 +0600
commit0f5e23c861b7fbe01fb138835e640aff2d7d5cce (patch)
tree4bce4720f63bacc455289caed682efdf76f74d93
parent3969118adbb133f5e0cd71a4d430e62b9a87b884 (diff)
downloadlibrabbitmq-0f5e23c861b7fbe01fb138835e640aff2d7d5cce.tar.gz
Add connect_timeout parameter to Connection (#161)
* Add connect_timeout parameter to Connection * Add connect_timeout parameter to Connection - Add unit tests and update Changelog
-rw-r--r--Changelog6
-rw-r--r--Modules/_librabbitmq/connection.c17
-rw-r--r--Modules/_librabbitmq/connection.h3
-rw-r--r--librabbitmq/__init__.py3
-rw-r--r--tests/test_functional.py34
5 files changed, 59 insertions, 4 deletions
diff --git a/Changelog b/Changelog
index 5464523..a33344d 100644
--- a/Changelog
+++ b/Changelog
@@ -7,6 +7,12 @@
.. contents::
:local:
+Next Release
+============
+
+- Add support for ``Connection.connect_timeout`` parameter
+
+
.. _version-2.0.0:
2.0.0
diff --git a/Modules/_librabbitmq/connection.c b/Modules/_librabbitmq/connection.c
index c9e7b35..daac3d2 100644
--- a/Modules/_librabbitmq/connection.c
+++ b/Modules/_librabbitmq/connection.c
@@ -1052,6 +1052,7 @@ PyRabbitMQ_ConnectionType_init(PyRabbitMQ_Connection *self,
"frame_max",
"heartbeat",
"client_properties",
+ "connect_timeout",
NULL
};
char *hostname;
@@ -1063,11 +1064,13 @@ PyRabbitMQ_ConnectionType_init(PyRabbitMQ_Connection *self,
int frame_max = 131072;
int heartbeat = 0;
int port = 5672;
+ int connect_timeout = 0;
PyObject *client_properties = NULL;
- if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ssssiiiiO", kwlist,
+ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ssssiiiiOi", kwlist,
&hostname, &userid, &password, &virtual_host, &port,
- &channel_max, &frame_max, &heartbeat, &client_properties)) {
+ &channel_max, &frame_max, &heartbeat, &client_properties,
+ &connect_timeout)) {
return -1;
}
@@ -1089,6 +1092,7 @@ PyRabbitMQ_ConnectionType_init(PyRabbitMQ_Connection *self,
self->channel_max = channel_max;
self->frame_max = frame_max;
self->heartbeat = heartbeat;
+ self->connect_timeout = connect_timeout;
self->weakreflist = NULL;
self->callbacks = PyDict_New();
if (self->callbacks == NULL) return -1;
@@ -1127,6 +1131,7 @@ PyRabbitMQ_Connection_connect(PyRabbitMQ_Connection *self)
amqp_rpc_reply_t reply;
amqp_pool_t pool;
amqp_table_t properties;
+ struct timeval timeout = {0, 0};
pyobject_array_t pyobj_array = {0};
@@ -1144,7 +1149,13 @@ PyRabbitMQ_Connection_connect(PyRabbitMQ_Connection *self)
goto error;
}
Py_BEGIN_ALLOW_THREADS;
- status = amqp_socket_open(socket, self->hostname, self->port);
+ if (self->connect_timeout <= 0) {
+ status = amqp_socket_open(socket, self->hostname, self->port);
+ } else {
+ timeout.tv_sec = self->connect_timeout;
+ status = amqp_socket_open_noblock(socket, self->hostname, self->port, &timeout);
+ }
+
Py_END_ALLOW_THREADS;
if (PyRabbitMQ_HandleAMQStatus(status, "Error opening socket")) {
goto error;
diff --git a/Modules/_librabbitmq/connection.h b/Modules/_librabbitmq/connection.h
index 9595085..2d40d7d 100644
--- a/Modules/_librabbitmq/connection.h
+++ b/Modules/_librabbitmq/connection.h
@@ -151,6 +151,7 @@ typedef struct {
int frame_max;
int channel_max;
int heartbeat;
+ int connect_timeout;
int sockfd;
int connected;
@@ -267,6 +268,8 @@ static PyMemberDef PyRabbitMQ_ConnectionType_members[] = {
offsetof(PyRabbitMQ_Connection, frame_max), READONLY, NULL},
{"callbacks", T_OBJECT_EX,
offsetof(PyRabbitMQ_Connection, callbacks), READONLY, NULL},
+ {"connect_timeout", T_INT,
+ offsetof(PyRabbitMQ_Connection, connect_timeout), READONLY, NULL},
{NULL, 0, 0, 0, NULL} /* Sentinel */
};
diff --git a/librabbitmq/__init__.py b/librabbitmq/__init__.py
index a1d7e3b..ee7a06c 100644
--- a/librabbitmq/__init__.py
+++ b/librabbitmq/__init__.py
@@ -191,7 +191,7 @@ class Connection(_librabbitmq.Connection):
def __init__(self, host='localhost', userid='guest', password='guest',
virtual_host='/', port=5672, channel_max=0xffff,
frame_max=131072, heartbeat=0, lazy=False,
- client_properties=None, **kwargs):
+ client_properties=None, connect_timeout=None, **kwargs):
if ':' in host:
host, port = host.split(':')
super(Connection, self).__init__(
@@ -199,6 +199,7 @@ class Connection(_librabbitmq.Connection):
virtual_host=virtual_host, channel_max=channel_max,
frame_max=frame_max, heartbeat=heartbeat,
client_properties=client_properties,
+ connect_timeout=0 if connect_timeout is None else int(connect_timeout),
)
self.channels = {}
self._used_channel_ids = array('H')
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 7b8b17f..7352d0a 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -6,11 +6,45 @@ from six.moves import xrange
import socket
import unittest
from array import array
+import time
from librabbitmq import Message, Connection, ConnectionError, ChannelError
TEST_QUEUE = 'pyrabbit.testq'
+class test_Connection(unittest.TestCase):
+ def test_connection_defaults(self):
+ """Test making a connection with the default settings."""
+ with Connection() as connection:
+ self.assertGreaterEqual(connection.fileno(), 0)
+
+ def test_connection_invalid_host(self):
+ """Test connection to an invalid host fails."""
+ # Will fail quickly as OS will reject it.
+ with self.assertRaises(ConnectionError):
+ Connection(host="255.255.255.255")
+
+ def test_connection_invalid_port(self):
+ """Test connection to an invalid port fails."""
+ # Will fail quickly as OS will reject it.
+ with self.assertRaises(ConnectionError):
+ Connection(port=0)
+
+ def test_connection_timeout(self):
+ """Test connection timeout."""
+ # Can't rely on local OS being configured to ignore SYN packets
+ # (OS would normally reply with RST to closed port). To test the
+ # timeout, need to connect to something that is either slow, or
+ # never responds.
+ start_time = time.time()
+ with self.assertRaises(ConnectionError):
+ Connection(host="google.com", port=81, connect_timeout=3)
+ took_time = time.time() - start_time
+ # Allow some leaway to avoid spurious test failures.
+ self.assertGreaterEqual(took_time, 2)
+ self.assertLessEqual(took_time, 4)
+
+
class test_Channel(unittest.TestCase):
def setUp(self):