summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorOmer Katz <omer.drow@gmail.com>2017-07-14 08:09:06 +0300
committerOmer Katz <omer.drow@gmail.com>2017-07-14 08:09:06 +0300
commite1ee66b24c10615f25dce1e2293fbec9cc00a1f4 (patch)
tree92a4c23a5dd9c755702409f1670e4284c9e12a41
parent942b3aa9cf76d08951fe7baae163b595300cdfe5 (diff)
downloadpy-amqp-fix-string-mechanisms.tar.gz
Convert mechanisms to bytes if it is a stringfix-string-mechanisms
-rw-r--r--amqp/connection.py4
-rw-r--r--t/unit/test_connection.py17
2 files changed, 20 insertions, 1 deletions
diff --git a/amqp/connection.py b/amqp/connection.py
index f0c4540..2ed89fb 100644
--- a/amqp/connection.py
+++ b/amqp/connection.py
@@ -33,7 +33,7 @@ from .exceptions import (
ConnectionForced, ConnectionError, error_for_code,
RecoverableConnectionError, RecoverableChannelError,
)
-from .five import array, items, monotonic, range, values
+from .five import array, items, monotonic, range, values, string
from .method_framing import frame_handler, frame_writer
from .transport import Transport
@@ -344,6 +344,8 @@ class Connection(AbstractChannel):
self.version_major = version_major
self.version_minor = version_minor
self.server_properties = server_properties
+ if isinstance(mechanisms, string):
+ mechanisms = mechanisms.encode('utf-8')
self.mechanisms = mechanisms.split(b' ')
self.locales = locales.split(' ')
AMQP_LOGGER.debug(
diff --git a/t/unit/test_connection.py b/t/unit/test_connection.py
index 3a94475..621bb72 100644
--- a/t/unit/test_connection.py
+++ b/t/unit/test_connection.py
@@ -104,6 +104,23 @@ class test_Connection:
),
)
+
+ def test_on_start_string_mechanisms(self):
+ self.conn._on_start(3, 4, {'foo': 'bar'}, 'x y z AMQPLAIN PLAIN',
+ 'en_US en_GB')
+ assert self.conn.version_major == 3
+ assert self.conn.version_minor == 4
+ assert self.conn.server_properties == {'foo': 'bar'}
+ assert self.conn.mechanisms == [b'x', b'y', b'z',
+ b'AMQPLAIN', b'PLAIN']
+ assert self.conn.locales == ['en_US', 'en_GB']
+ self.conn.send_method.assert_called_with(
+ spec.Connection.StartOk, 'FsSs', (
+ self.conn.client_properties, b'AMQPLAIN',
+ self.conn.authentication[0].start(self.conn), self.conn.locale,
+ ),
+ )
+
def test_missing_credentials(self):
with pytest.raises(ValueError):
self.conn = Connection(userid=None, password=None)