From 5dd1f23b77a4d1937fc309efa73d208278ab8de4 Mon Sep 17 00:00:00 2001 From: Richard Ipsum Date: Mon, 11 May 2015 16:31:47 +0100 Subject: Use protocol to validate incoming requests Change-Id: I16680439b131e63d30eeff91814a1af643af6246 --- distbuild/initiator_connection.py | 46 +++++++++++++++++++++++---------------- distbuild/protocol.py | 26 ++++++++++++++++++---- 2 files changed, 49 insertions(+), 23 deletions(-) diff --git a/distbuild/initiator_connection.py b/distbuild/initiator_connection.py index b3e17e98..d48ad214 100644 --- a/distbuild/initiator_connection.py +++ b/distbuild/initiator_connection.py @@ -19,6 +19,11 @@ import logging import distbuild +PROTOCOL_VERSION_MISMATCH_RESPONSE = ( + 'Protocol version mismatch between server and initiator: ' + 'distbuild network uses distbuild protocol version %s, ' + 'but client uses version %s.' +) class InitiatorDisconnect(object): @@ -50,7 +55,7 @@ class InitiatorConnection(distbuild.StateMachine): state machines, and vice versa. ''' - + _idgen = distbuild.IdentifierGenerator('InitiatorConnection') _route_map = distbuild.RouteMap() @@ -122,25 +127,28 @@ class InitiatorConnection(distbuild.StateMachine): 'build-cancel': self._handle_build_cancel, 'build-status': self._handle_build_status, } - try: - if event.msg.get('protocol_version') == distbuild.protocol.VERSION: - msg_handler[event.msg['type']](event) - else: - response = ( - 'Protocol version mismatch between server & initiator: ' - 'distbuild network uses distbuild protocol version %s, ' - 'but client uses version %s.' % - (distbuild.protocol.VERSION, - event.msg.get('protocol_version'))) - self._refuse_build_request(event.msg, response) - except (KeyError, ValueError) as ex: - response = ( - 'Invalid build-request message. Check you are using a ' - 'supported version of Morph. This distbuild network uses ' - 'protocol version %i.' % distbuild.protocol.VERSION) + + protocol_version = event.msg.get('protocol_version') + msg_type = event.msg.get('type') + + if (protocol_version == distbuild.protocol.VERSION + and msg_type in msg_handler + and distbuild.protocol.is_valid_message(event.msg)): + try: + msg_handler[msg_type](event) + except Exception: + logging.exception('Error handling msg: %s', event.msg) + else: + response = 'Bad request' + + if (protocol_version is not None + and protocol_version != distbuild.protocol.VERSION): + # Provide hint to possible cause of bad request + response += ('\n' + PROTOCOL_VERSION_MISMATCH_RESPONSE % + (distbuild.protocol.VERSION, protocol_version)) + + logging.info('Invalid message from initiator: %s', event.msg) self._refuse_build_request(event.msg, response) - logging.info('Invalid message from initiator: %s: exception %r', - event.msg, ex) def _refuse_build_request(self, build_request_message, reason): '''Send an error message back to the initiator. diff --git a/distbuild/protocol.py b/distbuild/protocol.py index 9aab6a6d..44552ae1 100644 --- a/distbuild/protocol.py +++ b/distbuild/protocol.py @@ -129,13 +129,13 @@ _optional_fields = { } -def message(message_type, **kwargs): - known_types = _required_fields.keys() - assert message_type in known_types - +def _validate(message_type, **kwargs): required_fields = _required_fields[message_type] optional_fields = _optional_fields.get(message_type, []) + known_types = _required_fields.keys() + assert message_type in known_types + for name in required_fields: assert name in kwargs, 'field %s is required' % name @@ -143,7 +143,25 @@ def message(message_type, **kwargs): assert (name in required_fields or name in optional_fields), \ 'field %s is not allowed' % name +def message(message_type, **kwargs): + _validate(message_type, **kwargs) + msg = dict(kwargs) msg['type'] = message_type return msg +def is_valid_message(msg): + + if 'type' not in msg: + return False + + msg_type = msg['type'] + del msg['type'] + + try: + _validate(msg_type, **msg) + return True + except AssertionError: + return False + finally: + msg['type'] = msg_type -- cgit v1.2.1