summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Brown <paul90brown@gmail.com>2021-12-21 03:38:31 -0600
committerAsif Saif Uddin <auvipy@gmail.com>2021-12-21 17:46:28 +0600
commit7cd90455a2006353c8b9361fbc6fc86739ebb4a0 (patch)
treee326ed52ef4b334b72a6658c970bc84aaf0aed24
parent0f5e23c861b7fbe01fb138835e640aff2d7d5cce (diff)
downloadlibrabbitmq-7cd90455a2006353c8b9361fbc6fc86739ebb4a0.tar.gz
improve performance of _get_free_channel_id, fix channel max bug
-rw-r--r--librabbitmq/__init__.py7
-rw-r--r--tests/test_functional.py31
2 files changed, 28 insertions, 10 deletions
diff --git a/librabbitmq/__init__.py b/librabbitmq/__init__.py
index ee7a06c..d2210e7 100644
--- a/librabbitmq/__init__.py
+++ b/librabbitmq/__init__.py
@@ -255,8 +255,11 @@ class Connection(_librabbitmq.Connection):
pass
def _get_free_channel_id(self):
- for channel_id in range(1, self.channel_max):
- if channel_id not in self._used_channel_ids:
+ # Cast to a set for fast lookups, and keep stored as an array for lower memory usage.
+ used_channel_ids = set(self._used_channel_ids)
+
+ for channel_id in range(1, self.channel_max + 1):
+ if channel_id not in used_channel_ids:
self._used_channel_ids.append(channel_id)
return channel_id
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 7352d0a..63b6bbc 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -44,6 +44,24 @@ class test_Connection(unittest.TestCase):
self.assertGreaterEqual(took_time, 2)
self.assertLessEqual(took_time, 4)
+ def test_get_free_channel_id(self):
+ with Connection() as connection:
+ assert connection._get_free_channel_id() == 1
+ assert connection._get_free_channel_id() == 2
+
+ def test_get_free_channel_id__channels_full(self):
+ with Connection() as connection:
+ for _ in range(connection.channel_max):
+ connection._get_free_channel_id()
+ with self.assertRaises(ConnectionError):
+ connection._get_free_channel_id()
+
+ def test_channel(self):
+ with Connection() as connection:
+ self.assertEqual(connection._used_channel_ids, array('H'))
+ connection.channel()
+ self.assertEqual(connection._used_channel_ids, array('H', (1,)))
+
class test_Channel(unittest.TestCase):
@@ -157,14 +175,11 @@ class test_Channel(unittest.TestCase):
self.connection.drain_events(timeout=0.1)
self.assertEqual(len(messages), 1)
- def test_get_free_channel_id(self):
- self.connection._used_channel_ids = array('H')
- assert self.connection._get_free_channel_id() == 1
-
- def test_get_free_channel_id__channels_full(self):
- self.connection._used_channel_ids = array('H', range(1, self.connection.channel_max))
- with self.assertRaises(ConnectionError):
- self.connection._get_free_channel_id()
+ def test_close(self):
+ self.assertEqual(self.connection._used_channel_ids, array('H', (1,)))
+ self.channel.close()
+ self.channel = None
+ self.assertEqual(self.connection._used_channel_ids, array('H'))
def tearDown(self):
if self.channel and self.connection.connected: