From 7cd90455a2006353c8b9361fbc6fc86739ebb4a0 Mon Sep 17 00:00:00 2001 From: Paul Brown Date: Tue, 21 Dec 2021 03:38:31 -0600 Subject: improve performance of _get_free_channel_id, fix channel max bug --- librabbitmq/__init__.py | 7 +++++-- tests/test_functional.py | 31 +++++++++++++++++++++++-------- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/librabbitmq/__init__.py b/librabbitmq/__init__.py index ee7a06c..d2210e7 100644 --- a/librabbitmq/__init__.py +++ b/librabbitmq/__init__.py @@ -255,8 +255,11 @@ class Connection(_librabbitmq.Connection): pass def _get_free_channel_id(self): - for channel_id in range(1, self.channel_max): - if channel_id not in self._used_channel_ids: + # Cast to a set for fast lookups, and keep stored as an array for lower memory usage. + used_channel_ids = set(self._used_channel_ids) + + for channel_id in range(1, self.channel_max + 1): + if channel_id not in used_channel_ids: self._used_channel_ids.append(channel_id) return channel_id diff --git a/tests/test_functional.py b/tests/test_functional.py index 7352d0a..63b6bbc 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -44,6 +44,24 @@ class test_Connection(unittest.TestCase): self.assertGreaterEqual(took_time, 2) self.assertLessEqual(took_time, 4) + def test_get_free_channel_id(self): + with Connection() as connection: + assert connection._get_free_channel_id() == 1 + assert connection._get_free_channel_id() == 2 + + def test_get_free_channel_id__channels_full(self): + with Connection() as connection: + for _ in range(connection.channel_max): + connection._get_free_channel_id() + with self.assertRaises(ConnectionError): + connection._get_free_channel_id() + + def test_channel(self): + with Connection() as connection: + self.assertEqual(connection._used_channel_ids, array('H')) + connection.channel() + self.assertEqual(connection._used_channel_ids, array('H', (1,))) + class test_Channel(unittest.TestCase): @@ -157,14 +175,11 @@ class test_Channel(unittest.TestCase): self.connection.drain_events(timeout=0.1) self.assertEqual(len(messages), 1) - def test_get_free_channel_id(self): - self.connection._used_channel_ids = array('H') - assert self.connection._get_free_channel_id() == 1 - - def test_get_free_channel_id__channels_full(self): - self.connection._used_channel_ids = array('H', range(1, self.connection.channel_max)) - with self.assertRaises(ConnectionError): - self.connection._get_free_channel_id() + def test_close(self): + self.assertEqual(self.connection._used_channel_ids, array('H', (1,))) + self.channel.close() + self.channel = None + self.assertEqual(self.connection._used_channel_ids, array('H')) def tearDown(self): if self.channel and self.connection.connected: -- cgit v1.2.1