diff options
| author | Rafael H. Schloming <rhs@apache.org> | 2009-12-26 12:42:57 +0000 |
|---|---|---|
| committer | Rafael H. Schloming <rhs@apache.org> | 2009-12-26 12:42:57 +0000 |
| commit | 248f1fe188fe2307b9dcf2c87a83b653eaa1920c (patch) | |
| tree | d5d0959a70218946ff72e107a6c106e32479a398 /python/qpid | |
| parent | 3c83a0e3ec7cf4dc23e83a340b25f5fc1676f937 (diff) | |
| download | qpid-python-248f1fe188fe2307b9dcf2c87a83b653eaa1920c.tar.gz | |
synchronized with trunk except for ruby dir
git-svn-id: https://svn.apache.org/repos/asf/qpid/branches/qpid.rnr@893970 13f79535-47bb-0310-9956-ffa450edef68
Diffstat (limited to 'python/qpid')
42 files changed, 6401 insertions, 1639 deletions
diff --git a/python/qpid/address.py b/python/qpid/address.py new file mode 100644 index 0000000000..6228ac757b --- /dev/null +++ b/python/qpid/address.py @@ -0,0 +1,161 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import re +from lexer import Lexicon, LexError +from parser import Parser, ParseError + +l = Lexicon() + +LBRACE = l.define("LBRACE", r"\{") +RBRACE = l.define("RBRACE", r"\}") +LBRACK = l.define("LBRACK", r"\[") +RBRACK = l.define("RBRACK", r"\]") +COLON = l.define("COLON", r":") +SEMI = l.define("SEMI", r";") +SLASH = l.define("SLASH", r"/") +COMMA = l.define("COMMA", r",") +NUMBER = l.define("NUMBER", r'[+-]?[0-9]*\.?[0-9]+') +ID = l.define("ID", r'[a-zA-Z_](?:[a-zA-Z0-9_-]*[a-zA-Z0-9_])?') +STRING = l.define("STRING", r""""(?:[^\\"]|\\.)*"|'(?:[^\\']|\\.)*'""") +ESC = l.define("ESC", r"\\[^ux]|\\x[0-9a-fA-F][0-9a-fA-F]|\\u[0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F]") +SYM = l.define("SYM", r"[.#*%@$^!+-]") +WSPACE = l.define("WSPACE", r"[ \n\r\t]+") +EOF = l.eof("EOF") + +LEXER = l.compile() + +def lex(st): + return LEXER.lex(st) + +def tok2str(tok): + if tok.type is STRING: + return eval(tok.value) + elif tok.type is ESC: + if tok.value[1] == "x": + return eval('"%s"' % tok.value) + elif tok.value[1] == "u": + return eval('u"%s"' % tok.value) + else: + return tok.value[1] + else: + return tok.value + +def tok2obj(tok): + if tok.type in (STRING, NUMBER): + return eval(tok.value) + else: + return tok.value + +def toks2str(toks): + if toks: + return "".join(map(tok2str, toks)) + else: + return None + +class AddressParser(Parser): + + def __init__(self, tokens): + Parser.__init__(self, [t for t in tokens if t.type is not WSPACE]) + + def parse(self): + result = self.address() + self.eat(EOF) + return result + + def address(self): + name = toks2str(self.eat_until(SLASH, SEMI, EOF)) + + if name is None: + raise ParseError(self.next()) + + if self.matches(SLASH): + self.eat(SLASH) + subject = toks2str(self.eat_until(SEMI, EOF)) + else: + subject = None + + if self.matches(SEMI): + self.eat(SEMI) + options = self.map() + else: + options = None + return name, subject, options + + def map(self): + self.eat(LBRACE) + + result = {} + while True: + if self.matches(ID): + n, v = self.nameval() + result[n] = v + if self.matches(COMMA): + self.eat(COMMA) + elif self.matches(RBRACE): + break + else: + raise ParseError(self.next(), COMMA, RBRACE) + elif self.matches(RBRACE): + break + else: + raise ParseError(self.next(), ID, RBRACE) + + self.eat(RBRACE) + return result + + def nameval(self): + name = self.eat(ID).value + self.eat(COLON) + val = self.value() + return (name, val) + + def value(self): + if self.matches(NUMBER, STRING, ID): + return tok2obj(self.eat()) + elif self.matches(LBRACE): + return self.map() + elif self.matches(LBRACK): + return self.list() + else: + raise ParseError(self.next(), NUMBER, STRING, ID, LBRACE, LBRACK) + + def list(self): + self.eat(LBRACK) + + result = [] + + while True: + if self.matches(RBRACK): + break + else: + result.append(self.value()) + if self.matches(COMMA): + self.eat(COMMA) + elif self.matches(RBRACK): + break + else: + raise ParseError(self.next(), COMMA, RBRACK) + + self.eat(RBRACK) + return result + +def parse(addr): + return AddressParser(lex(addr)).parse() + +__all__ = ["parse", "ParseError"] diff --git a/python/qpid/assembler.py b/python/qpid/assembler.py deleted file mode 100644 index 92bb0aa0f8..0000000000 --- a/python/qpid/assembler.py +++ /dev/null @@ -1,118 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -from codec010 import StringCodec -from framer import * -from logging import getLogger - -log = getLogger("qpid.io.seg") - -class Segment: - - def __init__(self, first, last, type, track, channel, payload): - self.id = None - self.offset = None - self.first = first - self.last = last - self.type = type - self.track = track - self.channel = channel - self.payload = payload - - def decode(self, spec): - segs = spec["segment_type"] - choice = segs.choices[self.type] - return getattr(self, "decode_%s" % choice.name)(spec) - - def decode_control(self, spec): - sc = StringCodec(spec, self.payload) - return sc.read_control() - - def decode_command(self, spec): - sc = StringCodec(spec, self.payload) - hdr, cmd = sc.read_command() - cmd.id = self.id - return hdr, cmd - - def decode_header(self, spec): - sc = StringCodec(spec, self.payload) - values = [] - while len(sc.encoded) > 0: - values.append(sc.read_struct32()) - return values - - def decode_body(self, spec): - return self.payload - - def __str__(self): - return "%s%s %s %s %s %r" % (int(self.first), int(self.last), self.type, - self.track, self.channel, self.payload) - - def __repr__(self): - return str(self) - -class Assembler(Framer): - - def __init__(self, sock, max_payload = Frame.MAX_PAYLOAD): - Framer.__init__(self, sock) - self.max_payload = max_payload - self.fragments = {} - - def read_segment(self): - while True: - frame = self.read_frame() - - key = (frame.channel, frame.track) - seg = self.fragments.get(key) - if seg == None: - seg = Segment(frame.isFirstSegment(), frame.isLastSegment(), - frame.type, frame.track, frame.channel, "") - self.fragments[key] = seg - - seg.payload += frame.payload - - if frame.isLastFrame(): - self.fragments.pop(key) - log.debug("RECV %s", seg) - return seg - - def write_segment(self, segment): - remaining = segment.payload - - first = True - while first or remaining: - payload = remaining[:self.max_payload] - remaining = remaining[self.max_payload:] - - flags = 0 - if first: - flags |= FIRST_FRM - first = False - if not remaining: - flags |= LAST_FRM - if segment.first: - flags |= FIRST_SEG - if segment.last: - flags |= LAST_SEG - - frame = Frame(flags, segment.type, segment.track, segment.channel, - payload) - self.write_frame(frame) - - log.debug("SENT %s", segment) diff --git a/python/qpid/brokertest.py b/python/qpid/brokertest.py new file mode 100644 index 0000000000..83d6c44d84 --- /dev/null +++ b/python/qpid/brokertest.py @@ -0,0 +1,480 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# Support library for tests that start multiple brokers, e.g. cluster +# or federation + +import os, signal, string, tempfile, popen2, socket, threading, time, imp +import qpid, traceback +from qpid import connection, messaging, util +from qpid.compat import format_exc +from qpid.harness import Skipped +from unittest import TestCase +from copy import copy +from threading import Thread, Lock, Condition +from logging import getLogger + +log = getLogger("qpid.brokertest") + +# Values for expected outcome of process at end of test +EXPECT_EXIT_OK=1 # Expect to exit with 0 status before end of test. +EXPECT_EXIT_FAIL=2 # Expect to exit with non-0 status before end of test. +EXPECT_RUNNING=3 # Expect to still be running at end of test + +def is_running(pid): + try: + os.kill(pid, 0) + return True + except: + return False + +class BadProcessStatus(Exception): + pass + +class ExceptionWrapper: + """Proxy object that adds a message to exceptions raised""" + def __init__(self, obj, msg): + self.obj = obj + self.msg = msg + + def __getattr__(self, name): + func = getattr(self.obj, name) + return lambda *args, **kwargs: self._wrap(func, args, kwargs) + + def _wrap(self, func, args, kwargs): + try: + return func(*args, **kwargs) + except Exception, e: + raise Exception("%s: %s" %(self.msg, str(e))) + +def error_line(f): + try: + lines = file(f).readlines() + if len(lines) > 0: return ": %s" % (lines[-1]) + except: pass + return "" + + +class Popen(popen2.Popen3): + """ + Similar to subprocess.Popen but using popen2 classes for portability. + Can set and verify expectation of process status at end of test. + Dumps command line, stdout, stderr to data dir for debugging. + """ + + def __init__(self, cmd, expect=EXPECT_EXIT_OK): + if type(cmd) is type(""): cmd = [cmd] # Make it a list. + self.cmd = [ str(x) for x in cmd ] + popen2.Popen3.__init__(self, self.cmd, True) + self.expect = expect + self.pname = "%s-%d" % (os.path.split(self.cmd[0])[-1], self.pid) + msg = "Process %s" % self.pname + self.stdin = ExceptionWrapper(self.tochild, msg) + self.stdout = ExceptionWrapper(self.fromchild, msg) + self.stderr = ExceptionWrapper(self.childerr, msg) + self.dump(self.cmd_str(), "cmd") + log.debug("Started process %s" % self.pname) + + def dump(self, str, ext): + name = "%s.%s" % (self.pname, ext) + f = file(name, "w") + f.write(str) + f.close() + return name + + def unexpected(self,msg): + self.dump(self.stdout.read(), "out") + err = self.dump(self.stderr.read(), "err") + raise BadProcessStatus("%s %s%s" % (self.pname, msg, error_line(err))) + + def stop(self): # Clean up at end of test. + if self.expect == EXPECT_RUNNING: + try: + self.kill() + except: + self.unexpected("expected running, exit code %d" % self.wait()) + else: + # Give the process some time to exit. + delay = 0.1 + while (self.poll() is None and delay < 1): + time.sleep(delay) + delay *= 2 + if self.returncode is None: # Still haven't stopped + self.kill() + self.unexpected("still running") + elif self.expect == EXPECT_EXIT_OK and self.returncode != 0: + self.unexpected("exit code %d" % self.returncode) + elif self.expect == EXPECT_EXIT_FAIL and self.returncode == 0: + self.unexpected("expected error") + + def communicate(self, input=None): + if input: + self.stdin.write(input) + self.stdin.close() + outerr = (self.stdout.read(), self.stderr.read()) + self.wait() + return outerr + + def is_running(self): return self.poll() is None + + def assert_running(self): + if not self.is_running(): unexpected("Exit code %d" % self.returncode) + + def poll(self): + self.returncode = popen2.Popen3.poll(self) + if (self.returncode == -1): self.returncode = None + return self.returncode + + def wait(self): + self.returncode = popen2.Popen3.wait(self) + return self.returncode + + def send_signal(self, sig): + os.kill(self.pid,sig) + self.wait() + + def terminate(self): self.send_signal(signal.SIGTERM) + def kill(self): self.send_signal(signal.SIGKILL) + + def cmd_str(self): return " ".join([str(s) for s in self.cmd]) + +def checkenv(name): + value = os.getenv(name) + if not value: raise Exception("Environment variable %s is not set" % name) + return value + +class Broker(Popen): + "A broker process. Takes care of start, stop and logging." + _broker_count = 0 + + def __init__(self, test, args=[], name=None, expect=EXPECT_RUNNING): + """Start a broker daemon. name determines the data-dir and log + file names.""" + + self.test = test + self._port = None + cmd = [BrokerTest.qpidd_exec, "--port=0", "--no-module-dir", "--auth=no"] + args + if name: self.name = name + else: + self.name = "broker%d" % Broker._broker_count + Broker._broker_count += 1 + self.log = self.name+".log" + cmd += ["--log-to-file", self.log, "--log-prefix", self.name] + cmd += ["--log-to-stderr=no"] + self.datadir = self.name + cmd += ["--data-dir", self.datadir] + Popen.__init__(self, cmd, expect) + test.cleanup_stop(self) + self.host = "localhost" + log.debug("Started broker %s (%s)" % (self.name, self.pname)) + + def port(self): + # Read port from broker process stdout if not already read. + if (self._port is None): + try: self._port = int(self.stdout.readline()) + except ValueError, e: + raise Exception("Can't get port for broker %s (%s)%s" % + (self.name, self.pname, error_line(self.log))) + return self._port + + def unexpected(self,msg): + raise BadProcessStatus("%s: %s (%s)" % (msg, self.name, self.pname)) + + def connect(self): + """New API connection to the broker.""" + return messaging.Connection.open(self.host, self.port()) + + def connect_old(self): + """Old API connection to the broker.""" + socket = qpid.util.connect(self.host,self.port()) + connection = qpid.connection.Connection (sock=socket) + connection.start() + return connection; + + def declare_queue(self, queue): + c = self.connect_old() + s = c.session(str(qpid.datatypes.uuid4())) + s.queue_declare(queue=queue) + c.close() + + def send_message(self, queue, message): + s = self.connect().session() + s.sender(queue+"; {create:always}").send(message) + s.connection.close() + + def send_messages(self, queue, messages): + s = self.connect().session() + sender = s.sender(queue+"; {create:always}") + for m in messages: sender.send(m) + s.connection.close() + + def get_message(self, queue): + s = self.connect().session() + m = s.receiver(queue+"; {create:always}", capacity=1).fetch(timeout=1) + s.acknowledge() + s.connection.close() + return m + + def get_messages(self, queue, n): + s = self.connect().session() + receiver = s.receiver(queue+"; {create:always}", capacity=n) + m = [receiver.fetch(timeout=1) for i in range(n)] + s.acknowledge() + s.connection.close() + return m + + def host_port(self): return "%s:%s" % (self.host, self.port()) + + +class Cluster: + """A cluster of brokers in a test.""" + + _cluster_count = 0 + + def __init__(self, test, count=0, args=[], expect=EXPECT_RUNNING, wait=True): + self.test = test + self._brokers=[] + self.name = "cluster%d" % Cluster._cluster_count + Cluster._cluster_count += 1 + # Use unique cluster name + self.args = copy(args) + self.args += [ "--cluster-name", "%s-%s:%d" % (self.name, socket.gethostname(), os.getpid()) ] + assert BrokerTest.cluster_lib + self.args += [ "--load-module", BrokerTest.cluster_lib ] + self.start_n(count, expect=expect, wait=wait) + + def start(self, name=None, expect=EXPECT_RUNNING, wait=True, args=[]): + """Add a broker to the cluster. Returns the index of the new broker.""" + if not name: name="%s-%d" % (self.name, len(self._brokers)) + log.debug("Cluster %s starting member %s" % (self.name, name)) + self._brokers.append(self.test.broker(self.args+args, name, expect, wait)) + return self._brokers[-1] + + def start_n(self, count, expect=EXPECT_RUNNING, wait=True, args=[]): + for i in range(count): self.start(expect=expect, wait=wait, args=args) + + # Behave like a list of brokers. + def __len__(self): return len(self._brokers) + def __getitem__(self,index): return self._brokers[index] + def __iter__(self): return self._brokers.__iter__() + +class BrokerTest(TestCase): + """ + Tracks processes started by test and kills at end of test. + Provides a well-known working directory for each test. + """ + + # Environment settings. + qpidd_exec = checkenv("QPIDD_EXEC") + cluster_lib = os.getenv("CLUSTER_LIB") + xml_lib = os.getenv("XML_LIB") + qpidConfig_exec = os.getenv("QPID_CONFIG_EXEC") + qpidRoute_exec = os.getenv("QPID_ROUTE_EXEC") + receiver_exec = os.getenv("RECEIVER_EXEC") + sender_exec = os.getenv("SENDER_EXEC") + store_lib = os.getenv("STORE_LIB") + test_store_lib = os.getenv("TEST_STORE_LIB") + rootdir = os.getcwd() + + def configure(self, config): self.config=config + + def setUp(self): + outdir = self.config.defines.get("OUTDIR") or "brokertest.tmp" + self.dir = os.path.join(self.rootdir, outdir, self.id()) + os.makedirs(self.dir) + os.chdir(self.dir) + self.stopem = [] # things to stop at end of test + + def tearDown(self): + err = [] + for p in self.stopem: + try: p.stop() + except Exception, e: err.append(str(e)) + if err: raise Exception("Unexpected process status:\n "+"\n ".join(err)) + + def cleanup_stop(self, stopable): + """Call thing.stop at end of test""" + self.stopem.append(stopable) + + def popen(self, cmd, expect=EXPECT_EXIT_OK): + """Start a process that will be killed at end of test, in the test dir.""" + os.chdir(self.dir) + p = Popen(cmd, expect) + self.cleanup_stop(p) + return p + + def broker(self, args=[], name=None, expect=EXPECT_RUNNING,wait=True): + """Create and return a broker ready for use""" + b = Broker(self, args=args, name=name, expect=expect) + if (wait): b.connect().close() + return b + + def cluster(self, count=0, args=[], expect=EXPECT_RUNNING, wait=True): + """Create and return a cluster ready for use""" + cluster = Cluster(self, count, args, expect=expect, wait=wait) + return cluster + + def wait(): + """Wait for all brokers in the cluster to be ready""" + for b in _brokers: b.connect().close() + +class RethrownException(Exception): + """Captures the original stack trace to be thrown later""" + def __init__(self, e, msg=""): + Exception.__init__(self, msg+"\n"+format_exc()) + +class StoppableThread(Thread): + """ + Base class for threads that do something in a loop and periodically check + to see if they have been stopped. + """ + def __init__(self): + self.stopped = False + self.error = None + Thread.__init__(self) + + def stop(self): + self.stopped = True + self.join() + if self.error: raise self.error + +class NumberedSender(Thread): + """ + Thread to run a sender client and send numbered messages until stopped. + """ + + def __init__(self, broker, max_depth=None): + """ + max_depth: enable flow control, ensure sent - received <= max_depth. + Requires self.received(n) to be called each time messages are received. + """ + Thread.__init__(self) + self.sender = broker.test.popen( + [broker.test.sender_exec, "--port", broker.port()], expect=EXPECT_RUNNING) + self.condition = Condition() + self.max = max_depth + self.received = 0 + self.stopped = False + self.error = None + + def run(self): + try: + self.sent = 0 + while not self.stopped: + if self.max: + self.condition.acquire() + while not self.stopped and self.sent - self.received > self.max: + self.condition.wait() + self.condition.release() + self.sender.stdin.write(str(self.sent)+"\n") + self.sender.stdin.flush() + self.sent += 1 + except Exception, e: self.error = RethrownException(e, self.sender.pname) + + def notify_received(self, count): + """Called by receiver to enable flow control. count = messages received so far.""" + self.condition.acquire() + self.received = count + self.condition.notify() + self.condition.release() + + def stop(self): + self.condition.acquire() + self.stopped = True + self.condition.notify() + self.condition.release() + self.join() + if self.error: raise self.error + +class NumberedReceiver(Thread): + """ + Thread to run a receiver client and verify it receives + sequentially numbered messages. + """ + def __init__(self, broker, sender = None): + """ + sender: enable flow control. Call sender.received(n) for each message received. + """ + Thread.__init__(self) + self.test = broker.test + self.receiver = self.test.popen( + [self.test.receiver_exec, "--port", broker.port()], expect=EXPECT_RUNNING) + self.stopat = None + self.lock = Lock() + self.error = None + self.sender = sender + + def continue_test(self): + self.lock.acquire() + ret = self.stopat is None or self.received < self.stopat + self.lock.release() + return ret + + def run(self): + try: + self.received = 0 + while self.continue_test(): + m = int(self.receiver.stdout.readline()) + assert(m <= self.received) # Allow for duplicates + if (m == self.received): + self.received += 1 + if self.sender: + self.sender.notify_received(self.received) + except Exception, e: + self.error = RethrownException(e, self.receiver.pname) + + def stop(self, count): + """Returns when received >= count""" + self.lock.acquire() + self.stopat = count + self.lock.release() + self.join() + if self.error: raise self.error + +class ErrorGenerator(StoppableThread): + """ + Thread that continuously generates errors by trying to consume from + a non-existent queue. For cluster regression tests, error handling + caused issues in the past. + """ + + def __init__(self, broker): + StoppableThread.__init__(self) + self.broker=broker + broker.test.cleanup_stop(self) + self.start() + + def run(self): + c = self.broker.connect_old() + try: + while not self.stopped: + try: + c.session(str(qpid.datatypes.uuid4())).message_subscribe( + queue="non-existent-queue") + assert(False) + except qpid.session.SessionException: pass + except: pass # Normal if broker is killed. + +def import_script(path): + """ + Import executable script at path as a module. + Requires some trickery as scripts are not in standard module format + """ + name=os.path.split(path)[1].replace("-","_") + return imp.load_module(name, file(path), path, ("", "r", imp.PY_SOURCE)) diff --git a/python/qpid/client.py b/python/qpid/client.py index 4605710de8..6107a4bc35 100644 --- a/python/qpid/client.py +++ b/python/qpid/client.py @@ -39,11 +39,8 @@ class Client: if spec: self.spec = spec else: - try: - name = os.environ["AMQP_SPEC"] - except KeyError: - raise EnvironmentError("environment variable AMQP_SPEC must be set") - self.spec = load(name) + from qpid_config import amqp_spec_0_9 + self.spec = load(amqp_spec_0_9) self.structs = StructFactory(self.spec) self.sessions = {} diff --git a/python/qpid/codec010.py b/python/qpid/codec010.py index dac023e2bd..682743df19 100644 --- a/python/qpid/codec010.py +++ b/python/qpid/codec010.py @@ -17,25 +17,69 @@ # under the License. # +import datetime from packer import Packer -from datatypes import serial, RangedSet, Struct +from datatypes import serial, timestamp, RangedSet, Struct, UUID +from ops import Compound, PRIMITIVE, COMPOUND class CodecException(Exception): pass +def direct(t): + return lambda x: t + +def map_str(s): + for c in s: + if ord(c) >= 0x80: + return "vbin16" + return "str16" + class Codec(Packer): - def __init__(self, spec): - self.spec = spec + ENCODINGS = { + unicode: direct("str16"), + str: map_str, + buffer: direct("vbin32"), + int: direct("int64"), + long: direct("int64"), + float: direct("double"), + None.__class__: direct("void"), + list: direct("list"), + tuple: direct("list"), + dict: direct("map"), + timestamp: direct("datetime"), + datetime.datetime: direct("datetime"), + UUID: direct("uuid"), + Compound: direct("struct32") + } + + def encoding(self, obj): + enc = self._encoding(obj.__class__, obj) + if enc is None: + raise CodecException("no encoding for %r" % obj) + return PRIMITIVE[enc] + + def _encoding(self, klass, obj): + if self.ENCODINGS.has_key(klass): + return self.ENCODINGS[klass](obj) + for base in klass.__bases__: + result = self._encoding(base, obj) + if result != None: + return result + + def read_primitive(self, type): + return getattr(self, "read_%s" % type.NAME)() + def write_primitive(self, type, v): + getattr(self, "write_%s" % type.NAME)(v) - def write_void(self, v): - assert v == None def read_void(self): return None + def write_void(self, v): + assert v == None - def write_bit(self, b): - if not b: raise ValueError(b) def read_bit(self): return True + def write_bit(self, b): + if not b: raise ValueError(b) def read_uint8(self): return self.unpack("!B") @@ -68,7 +112,7 @@ class Codec(Packer): def read_int16(self): return self.unpack("!h") def write_int16(self, n): - return self.unpack("!h", n) + self.pack("!h", n) def read_uint32(self): @@ -103,9 +147,11 @@ class Codec(Packer): self.pack("!q", n) def read_datetime(self): - return self.read_uint64() - def write_datetime(self, n): - self.write_uint64(n) + return timestamp(self.read_uint64()) + def write_datetime(self, t): + if isinstance(t, datetime.datetime): + t = timestamp(t) + self.write_uint64(t) def read_double(self): return self.unpack("!d") @@ -115,6 +161,8 @@ class Codec(Packer): def read_vbin8(self): return self.read(self.read_uint8()) def write_vbin8(self, b): + if isinstance(b, buffer): + b = str(b) self.write_uint8(len(b)) self.write(b) @@ -128,10 +176,17 @@ class Codec(Packer): def write_str16(self, s): self.write_vbin16(s.encode("utf8")) + def read_str16_latin(self): + return self.read_vbin16().decode("iso-8859-15") + def write_str16_latin(self, s): + self.write_vbin16(s.encode("iso-8859-15")) + def read_vbin16(self): return self.read(self.read_uint16()) def write_vbin16(self, b): + if isinstance(b, buffer): + b = str(b) self.write_uint16(len(b)) self.write(b) @@ -155,23 +210,13 @@ class Codec(Packer): def read_vbin32(self): return self.read(self.read_uint32()) def write_vbin32(self, b): + if isinstance(b, buffer): + b = str(b) self.write_uint32(len(b)) self.write(b) - def write_map(self, m): - sc = StringCodec(self.spec) - if m is not None: - sc.write_uint32(len(m)) - for k, v in m.items(): - type = self.spec.encoding(v.__class__) - if type == None: - raise CodecException("no encoding for %s" % v.__class__) - sc.write_str8(k) - sc.write_uint8(type.code) - type.encode(sc, v) - self.write_vbin32(sc.encoded) def read_map(self): - sc = StringCodec(self.spec, self.read_vbin32()) + sc = StringCodec(self.read_vbin32()) if not sc.encoded: return None count = sc.read_uint32() @@ -179,105 +224,146 @@ class Codec(Packer): while sc.encoded: k = sc.read_str8() code = sc.read_uint8() - type = self.spec.types[code] - v = type.decode(sc) + type = PRIMITIVE[code] + v = sc.read_primitive(type) result[k] = v return result + def write_map(self, m): + sc = StringCodec() + if m is not None: + sc.write_uint32(len(m)) + for k, v in m.items(): + type = self.encoding(v) + sc.write_str8(k) + sc.write_uint8(type.CODE) + sc.write_primitive(type, v) + self.write_vbin32(sc.encoded) + def read_array(self): + sc = StringCodec(self.read_vbin32()) + if not sc.encoded: + return None + type = PRIMITIVE[sc.read_uint8()] + count = sc.read_uint32() + result = [] + while count > 0: + result.append(sc.read_primitive(type)) + count -= 1 + return result def write_array(self, a): - sc = StringCodec(self.spec) + sc = StringCodec() if a is not None: if len(a) > 0: - type = self.spec.encoding(a[0].__class__) + type = self.encoding(a[0]) else: - type = self.spec.encoding(None.__class__) - sc.write_uint8(type.code) + type = self.encoding(None) + sc.write_uint8(type.CODE) sc.write_uint32(len(a)) for o in a: - type.encode(sc, o) + sc.write_primitive(type, o) self.write_vbin32(sc.encoded) - def read_array(self): - sc = StringCodec(self.spec, self.read_vbin32()) + + def read_list(self): + sc = StringCodec(self.read_vbin32()) if not sc.encoded: return None - type = self.spec.types[sc.read_uint8()] count = sc.read_uint32() result = [] while count > 0: - result.append(type.decode(sc)) + type = PRIMITIVE[sc.read_uint8()] + result.append(sc.read_primitive(type)) count -= 1 return result - def write_list(self, l): - sc = StringCodec(self.spec) + sc = StringCodec() if l is not None: sc.write_uint32(len(l)) for o in l: - type = self.spec.encoding(o.__class__) - sc.write_uint8(type.code) - type.encode(sc, o) + type = self.encoding(o) + sc.write_uint8(type.CODE) + sc.write_primitive(type, o) self.write_vbin32(sc.encoded) - def read_list(self): - sc = StringCodec(self.spec, self.read_vbin32()) - if not sc.encoded: - return None - count = sc.read_uint32() - result = [] - while count > 0: - type = self.spec.types[sc.read_uint8()] - result.append(type.decode(sc)) - count -= 1 - return result def read_struct32(self): size = self.read_uint32() code = self.read_uint16() - type = self.spec.structs[code] - fields = type.decode_fields(self) - return Struct(type, **fields) + cls = COMPOUND[code] + op = cls() + self.read_fields(op) + return op def write_struct32(self, value): - sc = StringCodec(self.spec) - sc.write_uint16(value._type.code) - value._type.encode_fields(sc, value) - self.write_vbin32(sc.encoded) - - def read_control(self): - cntrl = self.spec.controls[self.read_uint16()] - return Struct(cntrl, **cntrl.decode_fields(self)) - def write_control(self, ctrl): - type = ctrl._type - self.write_uint16(type.code) - type.encode_fields(self, ctrl) - - def read_command(self): - type = self.spec.commands[self.read_uint16()] - hdr = self.spec["session.header"].decode(self) - cmd = Struct(type, **type.decode_fields(self)) - return hdr, cmd - def write_command(self, hdr, cmd): - self.write_uint16(cmd._type.code) - hdr._type.encode(self, hdr) - cmd._type.encode_fields(self, cmd) + self.write_compound(value) + + def read_compound(self, cls): + size = self.read_size(cls.SIZE) + if cls.CODE is not None: + code = self.read_uint16() + assert code == cls.CODE + op = cls() + self.read_fields(op) + return op + def write_compound(self, op): + sc = StringCodec() + if op.CODE is not None: + sc.write_uint16(op.CODE) + sc.write_fields(op) + self.write_size(op.SIZE, len(sc.encoded)) + self.write(sc.encoded) + + def read_fields(self, op): + flags = 0 + for i in range(op.PACK): + flags |= (self.read_uint8() << 8*i) + + for i in range(len(op.FIELDS)): + f = op.FIELDS[i] + if flags & (0x1 << i): + if COMPOUND.has_key(f.type): + value = self.read_compound(COMPOUND[f.type]) + else: + value = getattr(self, "read_%s" % f.type)() + setattr(op, f.name, value) + def write_fields(self, op): + flags = 0 + for i in range(len(op.FIELDS)): + f = op.FIELDS[i] + value = getattr(op, f.name) + if f.type == "bit": + present = value + else: + present = value != None + if present: + flags |= (0x1 << i) + for i in range(op.PACK): + self.write_uint8((flags >> 8*i) & 0xFF) + for i in range(len(op.FIELDS)): + f = op.FIELDS[i] + if flags & (0x1 << i): + if COMPOUND.has_key(f.type): + enc = self.write_compound + else: + enc = getattr(self, "write_%s" % f.type) + value = getattr(op, f.name) + enc(value) def read_size(self, width): if width > 0: attr = "read_uint%d" % (width*8) return getattr(self, attr)() - def write_size(self, width, n): if width > 0: attr = "write_uint%d" % (width*8) getattr(self, attr)(n) def read_uuid(self): - return self.unpack("16s") - + return UUID(self.unpack("16s")) def write_uuid(self, s): + if isinstance(s, UUID): + s = s.bytes self.pack("16s", s) def read_bin128(self): return self.unpack("16s") - def write_bin128(self, b): self.pack("16s", b) @@ -285,14 +371,13 @@ class Codec(Packer): class StringCodec(Codec): - def __init__(self, spec, encoded = ""): - Codec.__init__(self, spec) + def __init__(self, encoded = ""): self.encoded = encoded - def write(self, s): - self.encoded += s - def read(self, n): result = self.encoded[:n] self.encoded = self.encoded[n:] return result + + def write(self, s): + self.encoded += s diff --git a/python/qpid/compat.py b/python/qpid/compat.py index 26f60fb8aa..c2b668a5e9 100644 --- a/python/qpid/compat.py +++ b/python/qpid/compat.py @@ -17,6 +17,8 @@ # under the License. # +import sys + try: set = set except NameError: @@ -26,3 +28,95 @@ try: from socket import SHUT_RDWR except ImportError: SHUT_RDWR = 2 + +try: + from traceback import format_exc +except ImportError: + import traceback + def format_exc(): + return "".join(traceback.format_exception(*sys.exc_info())) + +if tuple(sys.version_info[0:2]) < (2, 4): + from select import select as old_select + def select(rlist, wlist, xlist, timeout=None): + return old_select(list(rlist), list(wlist), list(xlist), timeout) +else: + from select import select + +class BaseWaiter: + + def wakeup(self): + self._do_write() + + def wait(self, timeout=None): + if timeout is not None: + ready, _, _ = select([self], [], [], timeout) + else: + ready = True + + if ready: + self._do_read() + return True + else: + return False + + def reading(self): + return True + + def readable(self): + self._do_read() + +if sys.platform in ('win32', 'cygwin'): + import socket + + class SockWaiter(BaseWaiter): + + def __init__(self, read_sock, write_sock): + self.read_sock = read_sock + self.write_sock = write_sock + + def _do_write(self): + self.write_sock.send("\0") + + def _do_read(self): + self.read_sock.recv(65536) + + def fileno(self): + return self.read_sock.fileno() + + def __repr__(self): + return "SockWaiter(%r, %r)" % (self.read_sock, self.write_sock) + + def selectable_waiter(): + listener = socket.socket() + listener.bind(('', 0)) + listener.listen(1) + _, port = listener.getsockname() + write_sock = socket.socket() + write_sock.connect(("127.0.0.1", port)) + read_sock, _ = listener.accept() + listener.close() + return SockWaiter(read_sock, write_sock) +else: + import os + + class PipeWaiter(BaseWaiter): + + def __init__(self, read_fd, write_fd): + self.read_fd = read_fd + self.write_fd = write_fd + + def _do_write(self): + os.write(self.write_fd, "\0") + + def _do_read(self): + os.read(self.read_fd, 65536) + + def fileno(self): + return self.read_fd + + def __repr__(self): + return "PipeWaiter(%r, %r)" % (self.read_fd, self.write_fd) + + def selectable_waiter(): + return PipeWaiter(*os.pipe()) diff --git a/python/qpid/concurrency.py b/python/qpid/concurrency.py new file mode 100644 index 0000000000..9837a3f0df --- /dev/null +++ b/python/qpid/concurrency.py @@ -0,0 +1,100 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import compat, inspect, time + +def synchronized(meth): + args, vargs, kwargs, defs = inspect.getargspec(meth) + scope = {} + scope["meth"] = meth + exec """ +def %s%s: + %s + %s._lock.acquire() + try: + return meth%s + finally: + %s._lock.release() +""" % (meth.__name__, inspect.formatargspec(args, vargs, kwargs, defs), + repr(inspect.getdoc(meth)), args[0], + inspect.formatargspec(args, vargs, kwargs, defs, + formatvalue=lambda x: ""), + args[0]) in scope + return scope[meth.__name__] + +class Waiter(object): + + def __init__(self, condition): + self.condition = condition + + def wait(self, predicate, timeout=None): + passed = 0 + start = time.time() + while not predicate(): + if timeout is None: + # XXX: this timed wait thing is not necessary for the fast + # condition from this module, only for the condition impl from + # the threading module + + # using the timed wait prevents keyboard interrupts from being + # blocked while waiting + self.condition.wait(3) + elif passed < timeout: + self.condition.wait(timeout - passed) + else: + return bool(predicate()) + passed = time.time() - start + return True + + def notify(self): + self.condition.notify() + + def notifyAll(self): + self.condition.notifyAll() + +class Condition: + + def __init__(self, lock): + self.lock = lock + self.waiters = [] + self.waiting = [] + + def notify(self): + assert self.lock._is_owned() + if self.waiting: + self.waiting[0].wakeup() + + def notifyAll(self): + assert self.lock._is_owned() + for w in self.waiting: + w.wakeup() + + def wait(self, timeout=None): + assert self.lock._is_owned() + if not self.waiters: + self.waiters.append(compat.selectable_waiter()) + sw = self.waiters.pop(0) + self.waiting.append(sw) + try: + st = self.lock._release_save() + sw.wait(timeout) + finally: + self.lock._acquire_restore(st) + self.waiting.remove(sw) + self.waiters.append(sw) diff --git a/python/qpid/connection.py b/python/qpid/connection.py index ce27a74489..18eeb99de8 100644 --- a/python/qpid/connection.py +++ b/python/qpid/connection.py @@ -20,15 +20,14 @@ import datatypes, session from threading import Thread, Condition, RLock from util import wait, notify -from assembler import Assembler, Segment from codec010 import StringCodec +from framing import * from session import Session -from invoker import Invoker -from spec010 import Control, Command, load -from spec import default +from generator import control_invoker +from spec import SPEC from exceptions import * from logging import getLogger -import delegates +import delegates, socket class ChannelBusy(Exception): pass @@ -44,28 +43,33 @@ def client(*args, **kwargs): def server(*args, **kwargs): return delegates.Server(*args, **kwargs) -class Connection(Assembler): +from framer import Framer - def __init__(self, sock, spec=None, delegate=client, **args): - Assembler.__init__(self, sock) - if spec == None: - spec = load(default()) - self.spec = spec - self.track = self.spec["track"] +class Connection(Framer): + def __init__(self, sock, delegate=client, **args): + Framer.__init__(self, sock) self.lock = RLock() self.attached = {} self.sessions = {} self.condition = Condition() + # XXX: we should combine this into a single comprehensive state + # model (whatever that means) self.opened = False self.failed = False + self.closed = False self.close_code = (None, "connection aborted") self.thread = Thread(target=self.run) self.thread.setDaemon(True) self.channel_max = 65535 + self.user_id = None + + self.op_enc = OpEncoder() + self.seg_enc = SegmentEncoder() + self.frame_enc = FrameEncoder() self.delegate = delegate(self, **args) @@ -79,7 +83,7 @@ class Connection(Assembler): else: ssn = self.sessions.get(name) if ssn is None: - ssn = Session(name, self.spec, delegate=delegate) + ssn = Session(name, delegate=delegate) self.sessions[name] = ssn elif ssn.channel is not None: if force: @@ -107,8 +111,7 @@ class Connection(Assembler): self.lock.release() def __channel(self): - # XXX: ch 0? - for i in xrange(self.channel_max): + for i in xrange(1, self.channel_max): if not self.attached.has_key(i): return i else: @@ -147,15 +150,45 @@ class Connection(Assembler): raise ConnectionFailed(*self.close_code) def run(self): - # XXX: we don't really have a good way to exit this loop without - # getting the other end to kill the socket - while True: + frame_dec = FrameDecoder() + seg_dec = SegmentDecoder() + op_dec = OpDecoder() + + while not self.closed: try: - seg = self.read_segment() - except Closed: + data = self.sock.recv(64*1024) + if self.security_layer_rx and data: + status, data = self.security_layer_rx.decode(data) + if not data: + self.detach_all() + break + except socket.timeout: + if self.aborted(): + self.detach_all() + raise Closed("connection timed out") + else: + continue + except socket.error, e: self.detach_all() - break - self.delegate.received(seg) + raise Closed(e) + frame_dec.write(data) + seg_dec.write(*frame_dec.read()) + op_dec.write(*seg_dec.read()) + for op in op_dec.read(): + self.delegate.received(op) + self.sock.close() + + def write_op(self, op): + self.sock_lock.acquire() + try: + self.op_enc.write(op) + self.seg_enc.write(*self.op_enc.read()) + self.frame_enc.write(*self.seg_enc.read()) + bytes = self.frame_enc.read() + self.write(bytes) + self.flush() + finally: + self.sock_lock.release() def close(self, timeout=None): if not self.opened: return @@ -172,26 +205,17 @@ class Connection(Assembler): log = getLogger("qpid.io.ctl") -class Channel(Invoker): +class Channel(control_invoker()): def __init__(self, connection, id): self.connection = connection self.id = id self.session = None - def resolve_method(self, name): - inst = self.connection.spec.instructions.get(name) - if inst is not None and isinstance(inst, Control): - return self.METHOD, inst - else: - return self.ERROR, None - - def invoke(self, type, args, kwargs): - ctl = type.new(args, kwargs) - sc = StringCodec(self.connection.spec) - sc.write_control(ctl) - self.connection.write_segment(Segment(True, True, type.segment_type, - type.track, self.id, sc.encoded)) + def invoke(self, op, args, kwargs): + ctl = op(*args, **kwargs) + ctl.channel = self.id + self.connection.write_op(ctl) log.debug("SENT %s", ctl) def __str__(self): diff --git a/python/qpid/connection08.py b/python/qpid/connection08.py index 8f2eef4770..d34cfe2847 100644 --- a/python/qpid/connection08.py +++ b/python/qpid/connection08.py @@ -28,6 +28,7 @@ from cStringIO import StringIO from spec import load from codec import EOF from compat import SHUT_RDWR +from exceptions import VersionError class SockIO: @@ -73,6 +74,9 @@ def listen(host, port, predicate = lambda: True): s, a = sock.accept() yield SockIO(s) +class FramingError(Exception): + pass + class Connection: def __init__(self, io, spec): @@ -107,7 +111,16 @@ class Connection: def read_8_0(self): c = self.codec - type = self.spec.constants.byid[c.decode_octet()].name + tid = c.decode_octet() + try: + type = self.spec.constants.byid[tid].name + except KeyError: + if tid == ord('A') and c.unpack("!3s") == "MQP": + _, _, major, minor = c.unpack("4B") + raise VersionError("client: %s-%s, server: %s-%s" % + (self.spec.major, self.spec.minor, major, minor)) + else: + raise FramingError("unknown frame type: %s" % tid) channel = c.decode_short() body = c.decode_longstr() dec = codec.Codec(StringIO(body), self.spec) @@ -122,6 +135,12 @@ class Connection: raise "frame error: expected %r, got %r" % (self.FRAME_END, garbage) return frame + def write_0_9(self, frame): + self.write_8_0(frame) + + def read_0_9(self): + return self.read_8_0() + def write_0_10(self, frame): c = self.codec flags = 0 diff --git a/python/qpid/datatypes.py b/python/qpid/datatypes.py index 7150caded2..61643715e4 100644 --- a/python/qpid/datatypes.py +++ b/python/qpid/datatypes.py @@ -17,7 +17,8 @@ # under the License. # -import threading, struct +import threading, struct, datetime, time +from exceptions import Timeout class Struct: @@ -83,7 +84,7 @@ class Message: def get(self, name): if self.headers: for h in self.headers: - if h._type.name == name: + if h.NAME == name: return h return None @@ -92,7 +93,7 @@ class Message: self.headers = [] idx = 0 while idx < len(self.headers): - if self.headers[idx]._type == header._type: + if self.headers[idx].NAME == header.NAME: self.headers[idx] = header return idx += 1 @@ -101,7 +102,7 @@ class Message: def clear(self, name): idx = 0 while idx < len(self.headers): - if self.headers[idx]._type.name == name: + if self.headers[idx].NAME == name: del self.headers[idx] return idx += 1 @@ -125,19 +126,19 @@ def serial(o): class Serial: def __init__(self, value): - self.value = value & 0xFFFFFFFF + self.value = value & 0xFFFFFFFFL def __hash__(self): return hash(self.value) def __cmp__(self, other): - if other is None: + if other.__class__ not in (int, long, Serial): return 1 other = serial(other) - delta = (self.value - other.value) & 0xFFFFFFFF - neg = delta & 0x80000000 + delta = (self.value - other.value) & 0xFFFFFFFFL + neg = delta & 0x80000000L mag = delta & 0x7FFFFFFF if neg: @@ -149,7 +150,10 @@ class Serial: return Serial(self.value + other) def __sub__(self, other): - return Serial(self.value - other) + if isinstance(other, Serial): + return self.value - other.value + else: + return Serial(self.value - other) def __repr__(self): return "serial(%s)" % self.value @@ -168,7 +172,7 @@ class Range: def __contains__(self, n): return self.lower <= n and n <= self.upper - + def __iter__(self): i = self.lower while i <= self.upper: @@ -229,7 +233,25 @@ class RangedSet: def add(self, lower, upper = None): self.add_range(Range(lower, upper)) - + + def empty(self): + for r in self.ranges: + if r.lower <= r.upper: + return False + return True + + def max(self): + if self.ranges: + return self.ranges[-1].upper + else: + return None + + def min(self): + if self.ranges: + return self.ranges[0].lower + else: + return None + def __iter__(self): return iter(self.ranges) @@ -253,9 +275,12 @@ class Future: def get(self, timeout=None): self._set.wait(timeout) - if self._error != None: - raise self.exception(self._error) - return self.value + if self._set.isSet(): + if self._error != None: + raise self.exception(self._error) + return self.value + else: + raise Timeout() def is_set(self): return self._set.isSet() @@ -289,10 +314,62 @@ class UUID: def __cmp__(self, other): if isinstance(other, UUID): return cmp(self.bytes, other.bytes) - raise NotImplemented() + else: + return -1 def __str__(self): return "%08x-%04x-%04x-%04x-%04x%08x" % struct.unpack("!LHHHHL", self.bytes) def __repr__(self): return "UUID(%r)" % str(self) + + def __hash__(self): + return self.bytes.__hash__() + +class timestamp(float): + + def __new__(cls, obj=None): + if obj is None: + obj = time.time() + elif isinstance(obj, datetime.datetime): + obj = time.mktime(obj.timetuple()) + 1e-6 * obj.microsecond + return super(timestamp, cls).__new__(cls, obj) + + def datetime(self): + return datetime.datetime.fromtimestamp(self) + + def __add__(self, other): + if isinstance(other, datetime.timedelta): + return timestamp(self.datetime() + other) + else: + return timestamp(float(self) + other) + + def __sub__(self, other): + if isinstance(other, datetime.timedelta): + return timestamp(self.datetime() - other) + else: + return timestamp(float(self) - other) + + def __radd__(self, other): + if isinstance(other, datetime.timedelta): + return timestamp(self.datetime() + other) + else: + return timestamp(other + float(self)) + + def __rsub__(self, other): + if isinstance(other, datetime.timedelta): + return timestamp(self.datetime() - other) + else: + return timestamp(other - float(self)) + + def __neg__(self): + return timestamp(-float(self)) + + def __pos__(self): + return self + + def __abs__(self): + return timestamp(abs(float(self))) + + def __repr__(self): + return "timestamp(%r)" % float(self) diff --git a/python/qpid/debug.py b/python/qpid/debug.py new file mode 100644 index 0000000000..b5dbd4d9d9 --- /dev/null +++ b/python/qpid/debug.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import threading, traceback, signal, sys, time + +def stackdump(sig, frm): + code = [] + for threadId, stack in sys._current_frames().items(): + code.append("\n# ThreadID: %s" % threadId) + for filename, lineno, name, line in traceback.extract_stack(stack): + code.append('File: "%s", line %d, in %s' % (filename, lineno, name)) + if line: + code.append(" %s" % (line.strip())) + print "\n".join(code) + +signal.signal(signal.SIGQUIT, stackdump) + +class LoudLock: + + def __init__(self): + self.lock = threading.RLock() + + def acquire(self, blocking=1): + while not self.lock.acquire(blocking=0): + time.sleep(1) + print >> sys.out, "TRYING" + traceback.print_stack(None, None, out) + print >> sys.out, "TRYING" + print >> sys.out, "ACQUIRED" + traceback.print_stack(None, None, out) + print >> sys.out, "ACQUIRED" + return True + + def _is_owned(self): + return self.lock._is_owned() + + def release(self): + self.lock.release() + diff --git a/python/qpid/delegates.py b/python/qpid/delegates.py index bf26553dda..4c41a6241f 100644 --- a/python/qpid/delegates.py +++ b/python/qpid/delegates.py @@ -20,7 +20,17 @@ import os, connection, session from util import notify from datatypes import RangedSet +from exceptions import VersionError, Closed from logging import getLogger +from ops import Control +import sys + +_have_sasl = None +try: + import saslwrapper + _have_sasl = True +except: + pass log = getLogger("qpid.io.ctl") @@ -28,26 +38,22 @@ class Delegate: def __init__(self, connection, delegate=session.client): self.connection = connection - self.spec = connection.spec self.delegate = delegate - self.control = self.spec["track.control"].value - def received(self, seg): - ssn = self.connection.attached.get(seg.channel) + def received(self, op): + ssn = self.connection.attached.get(op.channel) if ssn is None: - ch = connection.Channel(self.connection, seg.channel) + ch = connection.Channel(self.connection, op.channel) else: ch = ssn.channel - if seg.track == self.control: - ctl = seg.decode(self.spec) - log.debug("RECV %s", ctl) - attr = ctl._type.qname.replace(".", "_") - getattr(self, attr)(ch, ctl) + if isinstance(op, Control): + log.debug("RECV %s", op) + getattr(self, op.NAME)(ch, op) elif ssn is None: ch.session_detached() else: - ssn.received(seg) + ssn.received(op) def connection_close(self, ch, close): self.connection.close_code = (close.reply_code, close.reply_text) @@ -59,8 +65,12 @@ class Delegate: def connection_close_ok(self, ch, close_ok): self.connection.opened = False + self.connection.closed = True notify(self.connection.condition) + def connection_heartbeat(self, ch, hrt): + pass + def session_attach(self, ch, a): try: self.connection.attach(a.name, ch, self.delegate, a.force) @@ -119,7 +129,8 @@ class Server(Delegate): def start(self): self.connection.read_header() - self.connection.write_header(self.spec.major, self.spec.minor) + # XXX + self.connection.write_header(0, 10) connection.Channel(self.connection, 0).connection_start(mechanisms=["ANONYMOUS"]) def connection_start_ok(self, ch, start_ok): @@ -135,28 +146,101 @@ class Server(Delegate): class Client(Delegate): + ppid = 0 + try: + ppid = os.getppid() + except: + pass + PROPERTIES = {"product": "qpid python client", "version": "development", - "platform": os.name} + "platform": os.name, + "qpid.client_process": os.path.basename(sys.argv[0]), + "qpid.client_pid": os.getpid(), + "qpid.client_ppid": ppid} - def __init__(self, connection, username="guest", password="guest", mechanism="PLAIN"): + def __init__(self, connection, username=None, password=None, + mechanism=None, heartbeat=None, **kwargs): Delegate.__init__(self, connection) - self.username = username - self.password = password - self.mechanism = mechanism + + ## + ## self.acceptableMechanisms is the list of SASL mechanisms that the client is willing to + ## use. If it's None, then any mechanism is acceptable. + ## + self.acceptableMechanisms = None + if mechanism: + self.acceptableMechanisms = mechanism.split(" ") + self.heartbeat = heartbeat + self.username = username + self.password = password + + if _have_sasl: + self.sasl = saslwrapper.Client() + if username and len(username) > 0: + self.sasl.setAttr("username", str(username)) + if password and len(password) > 0: + self.sasl.setAttr("password", str(password)) + if "service" in kwargs: + self.sasl.setAttr("service", str(kwargs["service"])) + if "host" in kwargs: + self.sasl.setAttr("host", str(kwargs["host"])) + if "min_ssf" in kwargs: + self.sasl.setAttr("minssf", kwargs["min_ssf"]) + if "max_ssf" in kwargs: + self.sasl.setAttr("maxssf", kwargs["max_ssf"]) + self.sasl.init() def start(self): - self.connection.write_header(self.spec.major, self.spec.minor) - self.connection.read_header() + # XXX + cli_major = 0 + cli_minor = 10 + self.connection.write_header(cli_major, cli_minor) + magic, _, _, major, minor = self.connection.read_header() + if not (magic == "AMQP" and major == cli_major and minor == cli_minor): + raise VersionError("client: %s-%s, server: %s-%s" % + (cli_major, cli_minor, major, minor)) def connection_start(self, ch, start): - r = "\0%s\0%s" % (self.username, self.password) - ch.connection_start_ok(client_properties=Client.PROPERTIES, mechanism=self.mechanism, response=r) + mech_list = "" + for mech in start.mechanisms: + if (not self.acceptableMechanisms) or mech in self.acceptableMechanisms: + mech_list += str(mech) + " " + mech = None + initial = None + if _have_sasl: + status, mech, initial = self.sasl.start(mech_list) + if status == False: + raise Closed("SASL error: %s" % self.sasl.getError()) + else: + if self.username and self.password and ("PLAIN" in mech_list): + mech = "PLAIN" + initial = "\0%s\0%s" % (self.username, self.password) + else: + mech = "ANONYMOUS" + if not mech in mech_list: + raise Closed("No acceptable SASL authentication mechanism available") + ch.connection_start_ok(client_properties=Client.PROPERTIES, mechanism=mech, response=initial) + + def connection_secure(self, ch, secure): + resp = None + if _have_sasl: + status, resp = self.sasl.step(secure.challenge) + if status == False: + raise Closed("SASL error: %s" % self.sasl.getError()) + ch.connection_secure_ok(response=resp) def connection_tune(self, ch, tune): - ch.connection_tune_ok() + ch.connection_tune_ok(heartbeat=self.heartbeat) ch.connection_open() + if _have_sasl: + self.connection.user_id = self.sasl.getUserId() + self.connection.security_layer_tx = self.sasl def connection_open_ok(self, ch, open_ok): + if _have_sasl: + self.connection.security_layer_rx = self.sasl self.connection.opened = True notify(self.connection.condition) + + def connection_heartbeat(self, ch, hrt): + ch.connection_heartbeat() diff --git a/python/qpid/disp.py b/python/qpid/disp.py index d697cd0136..1b315c9d98 100644 --- a/python/qpid/disp.py +++ b/python/qpid/disp.py @@ -21,16 +21,115 @@ from time import strftime, gmtime +class Header: + """ """ + NONE = 1 + KMG = 2 + YN = 3 + Y = 4 + TIME_LONG = 5 + TIME_SHORT = 6 + DURATION = 7 + + def __init__(self, text, format=NONE): + self.text = text + self.format = format + + def __repr__(self): + return self.text + + def __str__(self): + return self.text + + def formatted(self, value): + try: + if value == None: + return '' + if self.format == Header.NONE: + return value + if self.format == Header.KMG: + return self.num(value) + if self.format == Header.YN: + if value: + return 'Y' + return 'N' + if self.format == Header.Y: + if value: + return 'Y' + return '' + if self.format == Header.TIME_LONG: + return strftime("%c", gmtime(value / 1000000000)) + if self.format == Header.TIME_SHORT: + return strftime("%X", gmtime(value / 1000000000)) + if self.format == Header.DURATION: + if value < 0: value = 0 + sec = value / 1000000000 + min = sec / 60 + hour = min / 60 + day = hour / 24 + result = "" + if day > 0: + result = "%dd " % day + if hour > 0 or result != "": + result += "%dh " % (hour % 24) + if min > 0 or result != "": + result += "%dm " % (min % 60) + result += "%ds" % (sec % 60) + return result + except: + return "?" + + def numCell(self, value, tag): + fp = float(value) / 1000. + if fp < 10.0: + return "%1.2f%c" % (fp, tag) + if fp < 100.0: + return "%2.1f%c" % (fp, tag) + return "%4d%c" % (value / 1000, tag) + + def num(self, value): + if value < 1000: + return "%4d" % value + if value < 1000000: + return self.numCell(value, 'k') + value /= 1000 + if value < 1000000: + return self.numCell(value, 'm') + value /= 1000 + return self.numCell(value, 'g') + + class Display: """ Display formatting for QPID Management CLI """ - def __init__ (self): - self.tableSpacing = 2 - self.tablePrefix = " " + def __init__(self, spacing=2, prefix=" "): + self.tableSpacing = spacing + self.tablePrefix = prefix self.timestampFormat = "%X" - def table (self, title, heads, rows): - """ Print a formatted table with autosized columns """ + def formattedTable(self, title, heads, rows): + fRows = [] + for row in rows: + fRow = [] + col = 0 + for cell in row: + fRow.append(heads[col].formatted(cell)) + col += 1 + fRows.append(fRow) + headtext = [] + for head in heads: + headtext.append(head.text) + self.table(title, headtext, fRows) + + def table(self, title, heads, rows): + """ Print a table with autosized columns """ + + # Pad the rows to the number of heads + for row in rows: + diff = len(heads) - len(row) + for idx in range(diff): + row.append("") + print title if len (rows) == 0: return @@ -40,7 +139,7 @@ class Display: for head in heads: width = len (head) for row in rows: - cellWidth = len (str (row[col])) + cellWidth = len (unicode (row[col])) if cellWidth > width: width = cellWidth colWidth.append (width + self.tableSpacing) @@ -60,9 +159,9 @@ class Display: line = self.tablePrefix col = 0 for width in colWidth: - line = line + str (row[col]) + line = line + unicode (row[col]) if col < len (heads) - 1: - for i in range (width - len (str (row[col]))): + for i in range (width - len (unicode (row[col]))): line = line + " " col = col + 1 print line @@ -77,3 +176,59 @@ class Display: def timestamp (self, nsec): """ Format a nanosecond-since-the-epoch timestamp for printing """ return strftime (self.timestampFormat, gmtime (nsec / 1000000000)) + + def duration(self, nsec): + if nsec < 0: nsec = 0 + sec = nsec / 1000000000 + min = sec / 60 + hour = min / 60 + day = hour / 24 + result = "" + if day > 0: + result = "%dd " % day + if hour > 0 or result != "": + result += "%dh " % (hour % 24) + if min > 0 or result != "": + result += "%dm " % (min % 60) + result += "%ds" % (sec % 60) + return result + +class Sortable: + """ """ + def __init__(self, row, sortIndex): + self.row = row + self.sortIndex = sortIndex + if sortIndex >= len(row): + raise Exception("sort index exceeds row boundary") + + def __cmp__(self, other): + return cmp(self.row[self.sortIndex], other.row[self.sortIndex]) + + def getRow(self): + return self.row + +class Sorter: + """ """ + def __init__(self, heads, rows, sortCol, limit=0, inc=True): + col = 0 + for head in heads: + if head.text == sortCol: + break + col += 1 + if col == len(heads): + raise Exception("sortCol '%s', not found in headers" % sortCol) + + list = [] + for row in rows: + list.append(Sortable(row, col)) + list.sort(reverse=not inc) + count = 0 + self.sorted = [] + for row in list: + self.sorted.append(row.getRow()) + count += 1 + if count == limit: + break + + def getSorted(self): + return self.sorted diff --git a/python/qpid/driver.py b/python/qpid/driver.py new file mode 100644 index 0000000000..2851c3aad3 --- /dev/null +++ b/python/qpid/driver.py @@ -0,0 +1,859 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import address, compat, connection, socket, struct, sys, time +from concurrency import synchronized +from datatypes import RangedSet, Serial +from exceptions import Timeout, VersionError +from framing import OpEncoder, SegmentEncoder, FrameEncoder, FrameDecoder, SegmentDecoder, OpDecoder +from logging import getLogger +from messaging import get_codec, ConnectError, Message, Pattern, UNLIMITED +from ops import * +from selector import Selector +from threading import Condition, Thread +from util import connect + +log = getLogger("qpid.messaging") +rawlog = getLogger("qpid.messaging.io.raw") +opslog = getLogger("qpid.messaging.io.ops") + +def addr2reply_to(addr): + name, subject, options = address.parse(addr) + return ReplyTo(name, subject) + +def reply_to2addr(reply_to): + if reply_to.routing_key is None: + return reply_to.exchange + elif reply_to.exchange in (None, ""): + return reply_to.routing_key + else: + return "%s/%s" % (reply_to.exchange, reply_to.routing_key) + +class Attachment: + + def __init__(self, target): + self.target = target + +# XXX + +DURABLE_DEFAULT=True + +# XXX + +FILTER_DEFAULTS = { + "topic": Pattern("*") + } + +# XXX +ppid = 0 +try: + ppid = os.getppid() +except: + pass + +CLIENT_PROPERTIES = {"product": "qpid python client", + "version": "development", + "platform": os.name, + "qpid.client_process": os.path.basename(sys.argv[0]), + "qpid.client_pid": os.getpid(), + "qpid.client_ppid": ppid} + +def noop(): pass + +class SessionState: + + def __init__(self, driver, session, name, channel): + self.driver = driver + self.session = session + self.name = name + self.channel = channel + self.detached = False + self.committing = False + self.aborting = False + + # sender state + self.sent = Serial(0) + self.acknowledged = RangedSet() + self.actions = {} + self.min_completion = self.sent + self.max_completion = self.sent + self.results = {} + + # receiver state + self.received = None + self.executed = RangedSet() + + # XXX: need to periodically exchange completion/known_completion + + def write_query(self, query, handler): + id = self.sent + self.write_cmd(query, lambda: handler(self.results.pop(id))) + + def write_cmd(self, cmd, action=noop): + if action != noop: + cmd.sync = True + if self.detached: + raise Exception("detached") + cmd.id = self.sent + self.sent += 1 + self.actions[cmd.id] = action + self.max_completion = cmd.id + self.write_op(cmd) + + def write_op(self, op): + op.channel = self.channel + self.driver.write_op(op) + +# XXX +HEADER="!4s4B" + +EMPTY_DP = DeliveryProperties() +EMPTY_MP = MessageProperties() + +SUBJECT = "qpid.subject" +TO = "qpid.to" + +class Driver: + + def __init__(self, connection): + self.connection = connection + self._lock = self.connection._lock + + self._selector = Selector.default() + self.reset() + + def reset(self): + self._opening = False + self._closing = False + self._connected = False + self._attachments = {} + + self._channel_max = 65536 + self._channels = 0 + self._sessions = {} + + self._socket = None + self._buf = "" + self._hdr = "" + self._op_enc = OpEncoder() + self._seg_enc = SegmentEncoder() + self._frame_enc = FrameEncoder() + self._frame_dec = FrameDecoder() + self._seg_dec = SegmentDecoder() + self._op_dec = OpDecoder() + self._timeout = None + + for ssn in self.connection.sessions.values(): + for m in ssn.acked + ssn.unacked + ssn.incoming: + m._transfer_id = None + for snd in ssn.senders: + snd.linked = False + for rcv in ssn.receivers: + rcv.impending = rcv.received + rcv.linked = False + + @synchronized + def wakeup(self): + self.dispatch() + self._selector.wakeup() + + def start(self): + self._selector.register(self) + + def fileno(self): + return self._socket.fileno() + + @synchronized + def reading(self): + return self._socket is not None + + @synchronized + def writing(self): + return self._socket is not None and self._buf + + @synchronized + def timing(self): + return self._timeout + + @synchronized + def readable(self): + error = None + recoverable = False + try: + data = self._socket.recv(64*1024) + if data: + rawlog.debug("READ: %r", data) + else: + rawlog.debug("ABORTED: %s", self._socket.getpeername()) + error = "connection aborted" + recoverable = True + except socket.error, e: + error = e + recoverable = True + + if not error: + try: + if len(self._hdr) < 8: + r = 8 - len(self._hdr) + self._hdr += data[:r] + data = data[r:] + + if len(self._hdr) == 8: + self.do_header(self._hdr) + + self._frame_dec.write(data) + self._seg_dec.write(*self._frame_dec.read()) + self._op_dec.write(*self._seg_dec.read()) + for op in self._op_dec.read(): + self.assign_id(op) + opslog.debug("RCVD: %r", op) + op.dispatch(self) + except VersionError, e: + error = e + except: + msg = compat.format_exc() + error = msg + + if error: + self._error(error, recoverable) + else: + self.dispatch() + + self.connection._waiter.notifyAll() + + def assign_id(self, op): + if isinstance(op, Command): + sst = self.get_sst(op) + op.id = sst.received + sst.received += 1 + + @synchronized + def writeable(self): + try: + n = self._socket.send(self._buf) + rawlog.debug("SENT: %r", self._buf[:n]) + self._buf = self._buf[n:] + except socket.error, e: + self._error(e, True) + self.connection._waiter.notifyAll() + + @synchronized + def timeout(self): + log.warn("retrying ...") + self.dispatch() + self.connection._waiter.notifyAll() + + def _error(self, err, recoverable): + if self._socket is not None: + self._socket.close() + self.reset() + if recoverable and self.connection.reconnect: + self._timeout = time.time() + 3 + log.warn("recoverable error: %s" % err) + log.warn("sleeping 3 seconds") + else: + self.connection.error = (err,) + + def write_op(self, op): + opslog.debug("SENT: %r", op) + self._op_enc.write(op) + self._seg_enc.write(*self._op_enc.read()) + self._frame_enc.write(*self._seg_enc.read()) + self._buf += self._frame_enc.read() + + def do_header(self, hdr): + cli_major = 0; cli_minor = 10 + magic, _, _, major, minor = struct.unpack(HEADER, hdr) + if major != cli_major or minor != cli_minor: + raise VersionError("client: %s-%s, server: %s-%s" % + (cli_major, cli_minor, major, minor)) + + def do_connection_start(self, start): + # XXX: should we use some sort of callback for this? + r = "\0%s\0%s" % (self.connection.username, self.connection.password) + m = self.connection.mechanism + self.write_op(ConnectionStartOk(client_properties=CLIENT_PROPERTIES, + mechanism=m, response=r)) + + def do_connection_tune(self, tune): + # XXX: is heartbeat protocol specific? + if tune.channel_max is not None: + self.channel_max = tune.channel_max + self.write_op(ConnectionTuneOk(heartbeat=self.connection.heartbeat, + channel_max=self.channel_max)) + self.write_op(ConnectionOpen()) + + def do_connection_open_ok(self, open_ok): + self._connected = True + + def connection_heartbeat(self, hrt): + self.write_op(ConnectionHeartbeat()) + + def do_connection_close(self, close): + self.write_op(ConnectionCloseOk()) + if close.reply_code != close_code.normal: + self.connection.error = (close.reply_code, close.reply_text) + # XXX: should we do a half shutdown on the socket here? + # XXX: we really need to test this, we may end up reporting a + # connection abort after this, if we were to do a shutdown on read + # and stop reading, then we wouldn't report the abort, that's + # probably the right thing to do + + def do_connection_close_ok(self, close_ok): + self._socket.close() + self.reset() + + def do_session_attached(self, atc): + pass + + def do_session_command_point(self, cp): + sst = self.get_sst(cp) + sst.received = cp.command_id + + def do_session_completed(self, sc): + sst = self.get_sst(sc) + for r in sc.commands: + sst.acknowledged.add(r.lower, r.upper) + + if not sc.commands.empty(): + while sst.min_completion in sc.commands: + if sst.actions.has_key(sst.min_completion): + sst.actions.pop(sst.min_completion)() + sst.min_completion += 1 + + def session_known_completed(self, kcmp): + sst = self.get_sst(kcmp) + executed = RangedSet() + for e in sst.executed.ranges: + for ke in kcmp.ranges: + if e.lower in ke and e.upper in ke: + break + else: + executed.add_range(e) + sst.executed = completed + + def do_session_flush(self, sf): + sst = self.get_sst(sf) + if sf.expected: + if sst.received is None: + exp = None + else: + exp = RangedSet(sst.received) + sst.write_op(SessionExpected(exp)) + if sf.confirmed: + sst.write_op(SessionConfirmed(sst.executed)) + if sf.completed: + sst.write_op(SessionCompleted(sst.executed)) + + def do_execution_result(self, er): + sst = self.get_sst(er) + sst.results[er.command_id] = er.value + + def do_execution_exception(self, ex): + sst = self.get_sst(ex) + sst.session.error = (ex,) + + def dispatch(self): + try: + if self._socket is None and self.connection._connected and not self._opening: + self.connect() + elif self._socket is not None and not self.connection._connected and not self._closing: + self.disconnect() + + if self._connected and not self._closing: + for ssn in self.connection.sessions.values(): + self.attach(ssn) + self.process(ssn) + except: + msg = compat.format_exc() + self.connection.error = (msg,) + + def connect(self): + try: + # XXX: should make this non blocking + self._socket = connect(self.connection.host, self.connection.port) + self._timeout = None + except socket.error, e: + if self.connection.reconnect: + self._error(e, True) + return + else: + raise e + self._buf += struct.pack(HEADER, "AMQP", 1, 1, 0, 10) + self._opening = True + + def disconnect(self): + self.write_op(ConnectionClose(close_code.normal)) + self._closing = True + + def attach(self, ssn): + sst = self._attachments.get(ssn) + if sst is None and not ssn.closed: + for i in xrange(0, self.channel_max): + if not self._sessions.has_key(i): + ch = i + break + else: + raise RuntimeError("all channels used") + sst = SessionState(self, ssn, ssn.name, ch) + sst.write_op(SessionAttach(name=ssn.name)) + sst.write_op(SessionCommandPoint(sst.sent, 0)) + sst.outgoing_idx = 0 + sst.acked = [] + if ssn.transactional: + sst.write_cmd(TxSelect()) + self._attachments[ssn] = sst + self._sessions[sst.channel] = sst + + for snd in ssn.senders: + self.link_out(snd) + for rcv in ssn.receivers: + self.link_in(rcv) + + if sst is not None and ssn.closing and not sst.detached: + sst.detached = True + sst.write_op(SessionDetach(name=ssn.name)) + + def get_sst(self, op): + return self._sessions[op.channel] + + def do_session_detached(self, dtc): + sst = self._sessions.pop(dtc.channel) + ssn = sst.session + del self._attachments[ssn] + ssn.closed = True + + def do_session_detach(self, dtc): + sst = self.get_sst(dtc) + sst.write_op(SessionDetached(name=dtc.name)) + self.do_session_detached(dtc) + + def link_out(self, snd): + sst = self._attachments.get(snd.session) + _snd = self._attachments.get(snd) + if _snd is None and not snd.closing and not snd.closed: + _snd = Attachment(snd) + _snd.closing = False + + if snd.target is None: + snd.error = ("target is None",) + snd.closed = True + return + + try: + _snd.name, _snd.subject, _snd.options = address.parse(snd.target) + except address.LexError, e: + snd.error = (e,) + snd.closed = True + return + except address.ParseError, e: + snd.error = (e,) + snd.closed = True + return + + # XXX: subject + if _snd.options is None: + _snd.options = {} + + def do_link(type, subtype): + if type == "topic": + _snd._exchange = _snd.name + _snd._routing_key = _snd.subject + elif type == "queue": + _snd._exchange = "" + _snd._routing_key = _snd.name + + snd.linked = True + + self.resolve_declare(sst, _snd, "sender", do_link) + self._attachments[snd] = _snd + + if snd.linked and snd.closing and not (snd.closed or _snd.closing): + _snd.closing = True + def do_unlink(): + del self._attachments[snd] + snd.closed = True + if _snd.options.get("delete") in ("always", "sender"): + self.delete(sst, _snd.name, do_unlink) + else: + do_unlink() + + def link_in(self, rcv): + sst = self._attachments.get(rcv.session) + _rcv = self._attachments.get(rcv) + if _rcv is None and not rcv.closing and not rcv.closed: + _rcv = Attachment(rcv) + _rcv.canceled = False + _rcv.draining = False + + if rcv.source is None: + rcv.error = ("source is None",) + rcv.closed = True + return + + try: + _rcv.name, _rcv.subject, _rcv.options = address.parse(rcv.source) + except address.LexError, e: + rcv.error = (e,) + rcv.closed = True + return + except address.ParseError, e: + rcv.error = (e,) + rcv.closed = True + return + + # XXX: subject + if _rcv.options is None: + _rcv.options = {} + + def do_link(type, subtype): + if type == "topic": + _rcv._queue = "%s.%s" % (rcv.session.name, rcv.destination) + sst.write_cmd(QueueDeclare(queue=_rcv._queue, durable=DURABLE_DEFAULT, exclusive=True, auto_delete=True)) + filter = _rcv.options.get("filter") + if _rcv.subject is None and filter is None: + f = FILTER_DEFAULTS[subtype] + elif _rcv.subject and filter: + # XXX + raise Exception("can't supply both subject and filter") + elif _rcv.subject: + # XXX + from messaging import Pattern + f = Pattern(_rcv.subject) + else: + f = filter + f._bind(sst, _rcv.name, _rcv._queue) + elif type == "queue": + _rcv._queue = _rcv.name + + sst.write_cmd(MessageSubscribe(queue=_rcv._queue, destination=rcv.destination)) + sst.write_cmd(MessageSetFlowMode(rcv.destination, flow_mode.credit)) + rcv.linked = True + + self.resolve_declare(sst, _rcv, "receiver", do_link) + self._attachments[rcv] = _rcv + + if rcv.linked and rcv.closing and not rcv.closed: + if not _rcv.canceled: + def do_unlink(): + del self._attachments[rcv] + rcv.closed = True + if _rcv.options.get("delete") in ("always", "receiver"): + sst.write_cmd(MessageCancel(rcv.destination)) + self.delete(sst, _rcv.name, do_unlink) + else: + sst.write_cmd(MessageCancel(rcv.destination), do_unlink) + _rcv.canceled = True + + def resolve_declare(self, sst, lnk, dir, action): + def do_resolved(er, qr): + if er.not_found and not qr.queue: + if lnk.options.get("create") in ("always", dir): + err = self.declare(sst, lnk.name, lnk.options, action) + else: + err = ("no such queue: %s" % lnk.name,) + + if err: + tgt = lnk.target + tgt.error = err + del self._attachments[tgt] + tgt.closed = True + return + elif qr.queue: + action("queue", None) + else: + action("topic", er.type) + self.resolve(sst, lnk.name, do_resolved) + + def resolve(self, sst, name, action): + args = [] + def do_result(r): + args.append(r) + def do_action(r): + do_result(r) + action(*args) + sst.write_query(ExchangeQuery(name), do_result) + sst.write_query(QueueQuery(name), do_action) + + def declare(self, sst, name, options, action): + opts = dict(options) + props = dict(opts.pop("node-properties", {})) + durable = props.pop("durable", DURABLE_DEFAULT) + type = props.pop("type", "queue") + xprops = dict(props.pop("x-properties", {})) + + if props: + return ("unrecognized option(s): %s" % "".join(props.keys()),) + + if type == "topic": + cmd = ExchangeDeclare(exchange=name, durable=durable) + elif type == "queue": + cmd = QueueDeclare(queue=name, durable=durable) + bindings = xprops.pop("bindings", []) + else: + return ("unrecognized type, must be topic or queue: %s" % type,) + + for f in cmd.FIELDS: + if f.name != "arguments" and xprops.has_key(f.name): + cmd[f.name] = xprops.pop(f.name) + if xprops: + cmd.arguments = xprops + + if type == "topic": + if cmd.type is None: + cmd.type = "topic" + subtype = cmd.type + else: + subtype = None + + cmds = [cmd] + if type == "queue": + for b in bindings: + try: + n, s, o = address.parse(b) + except address.ParseError, e: + return (e,) + cmds.append(ExchangeBind(name, n, s, o)) + + for c in cmds[:-1]: + sst.write_cmd(c) + def do_action(): + action(type, subtype) + sst.write_cmd(cmds[-1], do_action) + + def delete(self, sst, name, action): + def do_delete(er, qr): + if not er.not_found: + sst.write_cmd(ExchangeDelete(name), action) + elif qr.queue: + sst.write_cmd(QueueDelete(name), action) + else: + action() + self.resolve(sst, name, do_delete) + + def process(self, ssn): + if ssn.closed or ssn.closing: return + + sst = self._attachments[ssn] + + while sst.outgoing_idx < len(ssn.outgoing): + msg = ssn.outgoing[sst.outgoing_idx] + snd = msg._sender + # XXX: should check for sender error here + _snd = self._attachments.get(snd) + if _snd and snd.linked: + self.send(snd, msg) + sst.outgoing_idx += 1 + else: + break + + for rcv in ssn.receivers: + self.process_receiver(rcv) + + if ssn.acked: + messages = [m for m in ssn.acked if m not in sst.acked] + if messages: + # XXX: we're ignoring acks that get lost when disconnected, + # could we deal this via some message-id based purge? + ids = RangedSet(*[m._transfer_id for m in messages if m._transfer_id is not None]) + for range in ids: + sst.executed.add_range(range) + sst.write_op(SessionCompleted(sst.executed)) + def ack_ack(): + for m in messages: + ssn.acked.remove(m) + if not ssn.transactional: + sst.acked.remove(m) + sst.write_cmd(MessageAccept(ids), ack_ack) + sst.acked.extend(messages) + + if ssn.committing and not sst.committing: + def commit_ok(): + del sst.acked[:] + ssn.committing = False + ssn.committed = True + ssn.aborting = False + ssn.aborted = False + sst.write_cmd(TxCommit(), commit_ok) + sst.committing = True + + if ssn.aborting and not sst.aborting: + sst.aborting = True + def do_rb(): + messages = sst.acked + ssn.unacked + ssn.incoming + ids = RangedSet(*[m._transfer_id for m in messages]) + for range in ids: + sst.executed.add_range(range) + sst.write_op(SessionCompleted(sst.executed)) + sst.write_cmd(MessageRelease(ids)) + sst.write_cmd(TxRollback(), do_rb_ok) + + def do_rb_ok(): + del ssn.incoming[:] + del ssn.unacked[:] + del sst.acked[:] + + for rcv in ssn.receivers: + rcv.impending = rcv.received + rcv.returned = rcv.received + # XXX: do we need to update granted here as well? + + for rcv in ssn.receivers: + self.process_receiver(rcv) + + ssn.aborting = False + ssn.aborted = True + ssn.committing = False + ssn.committed = False + sst.aborting = False + + for rcv in ssn.receivers: + sst.write_cmd(MessageStop(rcv.destination)) + sst.write_cmd(ExecutionSync(), do_rb) + + def grant(self, rcv): + sst = self._attachments[rcv.session] + _rcv = self._attachments.get(rcv) + if _rcv is None or not rcv.linked or _rcv.canceled or _rcv.draining: + return + + if rcv.granted is UNLIMITED: + if rcv.impending is UNLIMITED: + delta = 0 + else: + delta = UNLIMITED + elif rcv.impending is UNLIMITED: + delta = -1 + else: + delta = max(rcv.granted, rcv.received) - rcv.impending + + if delta is UNLIMITED: + sst.write_cmd(MessageFlow(rcv.destination, credit_unit.byte, UNLIMITED.value)) + sst.write_cmd(MessageFlow(rcv.destination, credit_unit.message, UNLIMITED.value)) + rcv.impending = UNLIMITED + elif delta > 0: + sst.write_cmd(MessageFlow(rcv.destination, credit_unit.byte, UNLIMITED.value)) + sst.write_cmd(MessageFlow(rcv.destination, credit_unit.message, delta)) + rcv.impending += delta + elif delta < 0 and not rcv.draining: + _rcv.draining = True + def do_stop(): + rcv.impending = rcv.received + _rcv.draining = False + self.grant(rcv) + sst.write_cmd(MessageStop(rcv.destination), do_stop) + + if rcv.draining: + _rcv.draining = True + def do_flush(): + rcv.impending = rcv.received + rcv.granted = rcv.impending + _rcv.draining = False + rcv.draining = False + sst.write_cmd(MessageFlush(rcv.destination), do_flush) + + + def process_receiver(self, rcv): + if rcv.closed: return + self.grant(rcv) + + def send(self, snd, msg): + sst = self._attachments[snd.session] + _snd = self._attachments[snd] + + # XXX: what if subject is specified for a normal queue? + if _snd._routing_key is None: + rk = msg.subject + else: + rk = _snd._routing_key + # XXX: do we need to query to figure out how to create the reply-to interoperably? + if msg.reply_to: + rt = addr2reply_to(msg.reply_to) + else: + rt = None + dp = DeliveryProperties(routing_key=rk) + mp = MessageProperties(message_id=msg.id, + user_id=msg.user_id, + reply_to=rt, + correlation_id=msg.correlation_id, + content_type=msg.content_type, + application_headers=msg.properties) + if msg.subject is not None: + if mp.application_headers is None: + mp.application_headers = {} + mp.application_headers[SUBJECT] = msg.subject + if msg.to is not None: + if mp.application_headers is None: + mp.application_headers = {} + mp.application_headers[TO] = msg.to + if msg.durable: + dp.delivery_mode = delivery_mode.persistent + enc, dec = get_codec(msg.content_type) + body = enc(msg.content) + def msg_acked(): + # XXX: should we log the ack somehow too? + snd.acked += 1 + m = snd.session.outgoing.pop(0) + sst.outgoing_idx -= 1 + assert msg == m + sst.write_cmd(MessageTransfer(destination=_snd._exchange, headers=(dp, mp), + payload=body), msg_acked) + + def do_message_transfer(self, xfr): + sst = self.get_sst(xfr) + ssn = sst.session + + msg = self._decode(xfr) + rcv = ssn.receivers[int(xfr.destination)] + msg._receiver = rcv + if rcv.impending is not UNLIMITED: + assert rcv.received < rcv.impending, "%s, %s" % (rcv.received, rcv.impending) + rcv.received += 1 + log.debug("RECV [%s] %s", ssn, msg) + ssn.incoming.append(msg) + self.connection._waiter.notifyAll() + + def _decode(self, xfr): + dp = EMPTY_DP + mp = EMPTY_MP + + for h in xfr.headers: + if isinstance(h, DeliveryProperties): + dp = h + elif isinstance(h, MessageProperties): + mp = h + + ap = mp.application_headers + enc, dec = get_codec(mp.content_type) + content = dec(xfr.payload) + msg = Message(content) + msg.id = mp.message_id + if ap is not None: + msg.to = ap.get(TO) + msg.subject = ap.get(SUBJECT) + msg.user_id = mp.user_id + if mp.reply_to is not None: + msg.reply_to = reply_to2addr(mp.reply_to) + msg.correlation_id = mp.correlation_id + msg.durable = dp.delivery_mode == delivery_mode.persistent + msg.redelivered = dp.redelivered + msg.properties = mp.application_headers + msg.content_type = mp.content_type + msg._transfer_id = xfr.id + return msg diff --git a/python/qpid/exceptions.py b/python/qpid/exceptions.py index 7eaaf81ed4..2bd80b7ffe 100644 --- a/python/qpid/exceptions.py +++ b/python/qpid/exceptions.py @@ -19,3 +19,4 @@ class Closed(Exception): pass class Timeout(Exception): pass +class VersionError(Exception): pass diff --git a/python/qpid/framer.py b/python/qpid/framer.py index f6363b2291..47f57cf649 100644 --- a/python/qpid/framer.py +++ b/python/qpid/framer.py @@ -26,47 +26,6 @@ from logging import getLogger raw = getLogger("qpid.io.raw") frm = getLogger("qpid.io.frm") -FIRST_SEG = 0x08 -LAST_SEG = 0x04 -FIRST_FRM = 0x02 -LAST_FRM = 0x01 - -class Frame: - - HEADER = "!2BHxBH4x" - MAX_PAYLOAD = 65535 - struct.calcsize(HEADER) - - def __init__(self, flags, type, track, channel, payload): - if len(payload) > Frame.MAX_PAYLOAD: - raise ValueError("max payload size exceeded: %s" % len(payload)) - self.flags = flags - self.type = type - self.track = track - self.channel = channel - self.payload = payload - - def isFirstSegment(self): - return bool(FIRST_SEG & self.flags) - - def isLastSegment(self): - return bool(LAST_SEG & self.flags) - - def isFirstFrame(self): - return bool(FIRST_FRM & self.flags) - - def isLastFrame(self): - return bool(LAST_FRM & self.flags) - - def __str__(self): - return "%s%s%s%s %s %s %s %r" % (int(self.isFirstSegment()), - int(self.isLastSegment()), - int(self.isFirstFrame()), - int(self.isLastFrame()), - self.type, - self.track, - self.channel, - self.payload) - class FramingError(Exception): pass class Framer(Packer): @@ -76,19 +35,29 @@ class Framer(Packer): def __init__(self, sock): self.sock = sock self.sock_lock = RLock() - self._buf = "" + self.tx_buf = "" + self.rx_buf = "" + self.security_layer_tx = None + self.security_layer_rx = None + self.maxbufsize = 65535 def aborted(self): return False def write(self, buf): - self._buf += buf + self.tx_buf += buf def flush(self): self.sock_lock.acquire() try: - self._write(self._buf) - self._buf = "" + if self.security_layer_tx: + status, cipher_buf = self.security_layer_tx.encode(self.tx_buf) + if status == False: + raise Closed(self.security_layer_tx.getError()) + self._write(cipher_buf) + else: + self._write(self.tx_buf) + self.tx_buf = "" frm.debug("FLUSHED") finally: self.sock_lock.release() @@ -105,25 +74,42 @@ class Framer(Packer): raw.debug("SENT %r", buf[:n]) buf = buf[n:] + ## + ## Implementation Note: + ## + ## This function was modified to use the SASL security layer for content + ## decryption. As such, the socket read should read in "self.maxbufsize" + ## instead of "n" (the requested number of octets). However, since this + ## is one of two places in the code where the socket is read, the read + ## size had to be left at "n". This is because this function is + ## apparently only used to read the first 8 octets from a TCP socket. If + ## we read beyond "n" octets, the remaing octets won't be processed and + ## the connection handshake will fail. + ## def read(self, n): - data = "" - while len(data) < n: + while len(self.rx_buf) < n: try: - s = self.sock.recv(n - len(data)) + s = self.sock.recv(n) # NOTE: instead of "n", arg should be "self.maxbufsize" + if self.security_layer_rx: + status, s = self.security_layer_rx.decode(s) + if status == False: + raise Closed(self.security_layer_tx.getError()) except socket.timeout: if self.aborted(): raise Closed() else: continue except socket.error, e: - if data != "": + if self.rx_buf != "": raise e else: raise Closed() if len(s) == 0: raise Closed() - data += s + self.rx_buf += s raw.debug("RECV %r", s) + data = self.rx_buf[0:n] + self.rx_buf = self.rx_buf[n:] return data def read_header(self): @@ -136,24 +122,3 @@ class Framer(Packer): self.flush() finally: self.sock_lock.release() - - def write_frame(self, frame): - self.sock_lock.acquire() - try: - size = len(frame.payload) + struct.calcsize(Frame.HEADER) - track = frame.track & 0x0F - self.pack(Frame.HEADER, frame.flags, frame.type, size, track, frame.channel) - self.write(frame.payload) - if frame.isLastSegment() and frame.isLastFrame(): - self.flush() - frm.debug("SENT %s", frame) - finally: - self.sock_lock.release() - - def read_frame(self): - flags, type, size, track, channel = self.unpack(Frame.HEADER) - if flags & 0xF0: raise FramingError() - payload = self.read(size - struct.calcsize(Frame.HEADER)) - frame = Frame(flags, type, track, channel, payload) - frm.debug("RECV %s", frame) - return frame diff --git a/python/qpid/framing.py b/python/qpid/framing.py new file mode 100644 index 0000000000..0a8f26272c --- /dev/null +++ b/python/qpid/framing.py @@ -0,0 +1,310 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import struct + +FIRST_SEG = 0x08 +LAST_SEG = 0x04 +FIRST_FRM = 0x02 +LAST_FRM = 0x01 + +class Frame: + + HEADER = "!2BHxBH4x" + HEADER_SIZE = struct.calcsize(HEADER) + MAX_PAYLOAD = 65535 - struct.calcsize(HEADER) + + def __init__(self, flags, type, track, channel, payload): + if len(payload) > Frame.MAX_PAYLOAD: + raise ValueError("max payload size exceeded: %s" % len(payload)) + self.flags = flags + self.type = type + self.track = track + self.channel = channel + self.payload = payload + + def isFirstSegment(self): + return bool(FIRST_SEG & self.flags) + + def isLastSegment(self): + return bool(LAST_SEG & self.flags) + + def isFirstFrame(self): + return bool(FIRST_FRM & self.flags) + + def isLastFrame(self): + return bool(LAST_FRM & self.flags) + + def __repr__(self): + return "%s%s%s%s %s %s %s %r" % (int(self.isFirstSegment()), + int(self.isLastSegment()), + int(self.isFirstFrame()), + int(self.isLastFrame()), + self.type, + self.track, + self.channel, + self.payload) + +class Segment: + + def __init__(self, first, last, type, track, channel, payload): + self.id = None + self.offset = None + self.first = first + self.last = last + self.type = type + self.track = track + self.channel = channel + self.payload = payload + + def __repr__(self): + return "%s%s %s %s %s %r" % (int(self.first), int(self.last), self.type, + self.track, self.channel, self.payload) + +class FrameDecoder: + + def __init__(self): + self.input = "" + self.output = [] + self.parse = self.__frame_header + + def write(self, bytes): + self.input += bytes + while True: + next = self.parse() + if next is None: + break + else: + self.parse = next + + def __consume(self, n): + result = self.input[:n] + self.input = self.input[n:] + return result + + def __frame_header(self): + if len(self.input) >= Frame.HEADER_SIZE: + st = self.__consume(Frame.HEADER_SIZE) + self.flags, self.type, self.size, self.track, self.channel = \ + struct.unpack(Frame.HEADER, st) + return self.__frame_body + + def __frame_body(self): + size = self.size - Frame.HEADER_SIZE + if len(self.input) >= size: + payload = self.__consume(size) + frame = Frame(self.flags, self.type, self.track, self.channel, payload) + self.output.append(frame) + return self.__frame_header + + def read(self): + result = self.output + self.output = [] + return result + +class FrameEncoder: + + def __init__(self): + self.output = "" + + def write(self, *frames): + for frame in frames: + size = len(frame.payload) + Frame.HEADER_SIZE + track = frame.track & 0x0F + self.output += struct.pack(Frame.HEADER, frame.flags, frame.type, size, + track, frame.channel) + self.output += frame.payload + + def read(self): + result = self.output + self.output = "" + return result + +class SegmentDecoder: + + def __init__(self): + self.fragments = {} + self.segments = [] + + def write(self, *frames): + for frm in frames: + key = (frm.channel, frm.track) + seg = self.fragments.get(key) + + if seg == None: + seg = Segment(frm.isFirstSegment(), frm.isLastSegment(), + frm.type, frm.track, frm.channel, "") + self.fragments[key] = seg + + seg.payload += frm.payload + + if frm.isLastFrame(): + self.fragments.pop(key) + self.segments.append(seg) + + def read(self): + result = self.segments + self.segments = [] + return result + +class SegmentEncoder: + + def __init__(self, max_payload=Frame.MAX_PAYLOAD): + self.max_payload = max_payload + self.frames = [] + + def write(self, *segments): + for seg in segments: + remaining = seg.payload + + first = True + while first or remaining: + payload = remaining[:self.max_payload] + remaining = remaining[self.max_payload:] + + flags = 0 + if first: + flags |= FIRST_FRM + first = False + if not remaining: + flags |= LAST_FRM + if seg.first: + flags |= FIRST_SEG + if seg.last: + flags |= LAST_SEG + + frm = Frame(flags, seg.type, seg.track, seg.channel, payload) + self.frames.append(frm) + + def read(self): + result = self.frames + self.frames = [] + return result + +from ops import COMMANDS, CONTROLS, COMPOUND, Header, segment_type, track +from spec import SPEC + +from codec010 import StringCodec + +class OpEncoder: + + def __init__(self): + self.segments = [] + + def write(self, *ops): + for op in ops: + if COMMANDS.has_key(op.NAME): + seg_type = segment_type.command + seg_track = track.command + enc = self.encode_command(op) + elif CONTROLS.has_key(op.NAME): + seg_type = segment_type.control + seg_track = track.control + enc = self.encode_compound(op) + else: + raise ValueError(op) + seg = Segment(True, False, seg_type, seg_track, op.channel, enc) + self.segments.append(seg) + if hasattr(op, "headers") and op.headers is not None: + hdrs = "" + for h in op.headers: + hdrs += self.encode_compound(h) + seg = Segment(False, False, segment_type.header, seg_track, op.channel, + hdrs) + self.segments.append(seg) + if hasattr(op, "payload") and op.payload is not None: + self.segments.append(Segment(False, False, segment_type.body, seg_track, + op.channel, op.payload)) + self.segments[-1].last = True + + def encode_command(self, cmd): + sc = StringCodec() + sc.write_uint16(cmd.CODE) + sc.write_compound(Header(sync=cmd.sync)) + sc.write_fields(cmd) + return sc.encoded + + def encode_compound(self, op): + sc = StringCodec() + sc.write_compound(op) + return sc.encoded + + def read(self): + result = self.segments + self.segments = [] + return result + +class OpDecoder: + + def __init__(self): + self.op = None + self.ops = [] + + def write(self, *segments): + for seg in segments: + if seg.first: + if seg.type == segment_type.command: + self.op = self.decode_command(seg.payload) + elif seg.type == segment_type.control: + self.op = self.decode_control(seg.payload) + else: + raise ValueError(seg) + self.op.channel = seg.channel + elif seg.type == segment_type.header: + if self.op.headers is None: + self.op.headers = [] + self.op.headers.extend(self.decode_headers(seg.payload)) + elif seg.type == segment_type.body: + if self.op.payload is None: + self.op.payload = seg.payload + else: + self.op.payload += seg.payload + if seg.last: + self.ops.append(self.op) + self.op = None + + def decode_command(self, encoded): + sc = StringCodec(encoded) + code = sc.read_uint16() + cls = COMMANDS[code] + hdr = sc.read_compound(Header) + cmd = cls() + sc.read_fields(cmd) + cmd.sync = hdr.sync + return cmd + + def decode_control(self, encoded): + sc = StringCodec(encoded) + code = sc.read_uint16() + cls = CONTROLS[code] + ctl = cls() + sc.read_fields(ctl) + return ctl + + def decode_headers(self, encoded): + sc = StringCodec(encoded) + result = [] + while sc.encoded: + result.append(sc.read_struct32()) + return result + + def read(self): + result = self.ops + self.ops = [] + return result diff --git a/python/qpid/generator.py b/python/qpid/generator.py new file mode 100644 index 0000000000..02d11e5005 --- /dev/null +++ b/python/qpid/generator.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sys + +from ops import * + +def METHOD(module, op): + method = lambda self, *args, **kwargs: self.invoke(op, args, kwargs) + if sys.version_info[:2] > (2, 3): + method.__name__ = op.__name__ + method.__doc__ = op.__doc__ + method.__module__ = module + return method + +def generate(module, operations): + dict = {} + + for name, enum in ENUMS.items(): + if isinstance(name, basestring): + dict[name] = enum + + for name, op in COMPOUND.items(): + if isinstance(name, basestring): + dict[name] = METHOD(module, op) + + for name, op in operations.items(): + if isinstance(name, basestring): + dict[name] = METHOD(module, op) + + return dict + +def invoker(name, operations): + return type(name, (), generate(invoker.__module__, operations)) + +def command_invoker(): + return invoker("CommandInvoker", COMMANDS) + +def control_invoker(): + return invoker("ControlInvoker", CONTROLS) diff --git a/python/qpid/harness.py b/python/qpid/harness.py new file mode 100644 index 0000000000..ce48481612 --- /dev/null +++ b/python/qpid/harness.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +class Skipped(Exception): pass diff --git a/python/qpid/invoker.py b/python/qpid/invoker.py deleted file mode 100644 index 635f3ee769..0000000000 --- a/python/qpid/invoker.py +++ /dev/null @@ -1,48 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -import sys - -# TODO: need a better naming for this class now that it does the value -# stuff -class Invoker: - - def METHOD(self, name, resolved): - method = lambda *args, **kwargs: self.invoke(resolved, args, kwargs) - if sys.version_info[:2] > (2, 3): - method.__name__ = resolved.pyname - method.__doc__ = resolved.pydoc - method.__module__ = self.__class__.__module__ - self.__dict__[name] = method - return method - - def VALUE(self, name, resolved): - self.__dict__[name] = resolved - return resolved - - def ERROR(self, name, resolved): - raise AttributeError("%s instance has no attribute '%s'" % - (self.__class__.__name__, name)) - - def resolve_method(self, name): - return ERROR, None - - def __getattr__(self, name): - disp, resolved = self.resolve_method(name) - return disp(name, resolved) diff --git a/python/qpid/lexer.py b/python/qpid/lexer.py new file mode 100644 index 0000000000..87845560eb --- /dev/null +++ b/python/qpid/lexer.py @@ -0,0 +1,112 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import re + +class Type: + + def __init__(self, name, pattern=None): + self.name = name + self.pattern = pattern + + def __repr__(self): + return self.name + +class Lexicon: + + def __init__(self): + self.types = [] + self._eof = None + + def define(self, name, pattern): + t = Type(name, pattern) + self.types.append(t) + return t + + def eof(self, name): + t = Type(name) + self._eof = t + return t + + def compile(self): + types = self.types[:] + joined = "|".join(["(%s)" % t.pattern for t in types]) + rexp = re.compile(joined) + return Lexer(types, self._eof, rexp) + +class Token: + + def __init__(self, type, value, input, position): + self.type = type + self.value = value + self.input = input + self.position = position + + def line_info(self): + return line_info(self.input, self.position) + + def __repr__(self): + if self.value is None: + return repr(self.type) + else: + return "%s(%r)" % (self.type, self.value) + + +class LexError(Exception): + pass + +def line_info(st, pos): + idx = 0 + lineno = 1 + column = 0 + line_pos = 0 + while idx < pos: + if st[idx] == "\n": + lineno += 1 + column = 0 + line_pos = idx + column += 1 + idx += 1 + + end = st.find("\n", line_pos) + if end < 0: + end = len(st) + line = st[line_pos:end] + + return line, lineno, column + +class Lexer: + + def __init__(self, types, eof, rexp): + self.types = types + self.eof = eof + self.rexp = rexp + + def lex(self, st): + pos = 0 + while pos < len(st): + m = self.rexp.match(st, pos) + if m is None: + line, ln, col = line_info(st, pos) + raise LexError("unrecognized characters line:%s,%s: %s" % (ln, col, line)) + else: + idx = m.lastindex + t = Token(self.types[idx - 1], m.group(idx), st, pos) + yield t + pos = m.end() + yield Token(self.eof, None, st, pos) diff --git a/python/qpid/management.py b/python/qpid/management.py index 83c29a78a5..325ab4903d 100644 --- a/python/qpid/management.py +++ b/python/qpid/management.py @@ -17,6 +17,10 @@ # under the License. # +############################################################################### +## This file is being obsoleted by qmf/console.py +############################################################################### + """ Management API for Qpid """ @@ -69,6 +73,57 @@ class mgmtObject (object): for cell in row: setattr (self, cell[0], cell[1]) +class objectId(object): + """ Object that represents QMF object identifiers """ + + def __init__(self, codec, first=0, second=0): + if codec: + self.first = codec.read_uint64() + self.second = codec.read_uint64() + else: + self.first = first + self.second = second + + def __cmp__(self, other): + if other == None: + return 1 + if self.first < other.first: + return -1 + if self.first > other.first: + return 1 + if self.second < other.second: + return -1 + if self.second > other.second: + return 1 + return 0 + + + def index(self): + return (self.first, self.second) + + def getFlags(self): + return (self.first & 0xF000000000000000) >> 60 + + def getSequence(self): + return (self.first & 0x0FFF000000000000) >> 48 + + def getBroker(self): + return (self.first & 0x0000FFFFF0000000) >> 28 + + def getBank(self): + return self.first & 0x000000000FFFFFFF + + def getObject(self): + return self.second + + def isDurable(self): + return self.getSequence() == 0 + + def encode(self, codec): + codec.write_uint64(self.first) + codec.write_uint64(self.second) + + class methodResult: """ Object that contains the result of a method call """ @@ -111,19 +166,23 @@ class managementChannel: ssn.exchange_bind (exchange="amq.direct", queue=self.replyName, binding_key=self.replyName) - ssn.message_subscribe (queue=self.topicName, destination="tdest") - ssn.message_subscribe (queue=self.replyName, destination="rdest") + ssn.message_subscribe (queue=self.topicName, destination="tdest", + accept_mode=ssn.accept_mode.none, + acquire_mode=ssn.acquire_mode.pre_acquired) + ssn.message_subscribe (queue=self.replyName, destination="rdest", + accept_mode=ssn.accept_mode.none, + acquire_mode=ssn.acquire_mode.pre_acquired) ssn.incoming ("tdest").listen (self.topicCb, self.exceptionCb) ssn.incoming ("rdest").listen (self.replyCb) ssn.message_set_flow_mode (destination="tdest", flow_mode=1) - ssn.message_flow (destination="tdest", unit=0, value=0xFFFFFFFF) - ssn.message_flow (destination="tdest", unit=1, value=0xFFFFFFFF) + ssn.message_flow (destination="tdest", unit=0, value=0xFFFFFFFFL) + ssn.message_flow (destination="tdest", unit=1, value=0xFFFFFFFFL) ssn.message_set_flow_mode (destination="rdest", flow_mode=1) - ssn.message_flow (destination="rdest", unit=0, value=0xFFFFFFFF) - ssn.message_flow (destination="rdest", unit=1, value=0xFFFFFFFF) + ssn.message_flow (destination="rdest", unit=0, value=0xFFFFFFFFL) + ssn.message_flow (destination="rdest", unit=1, value=0xFFFFFFFFL) def setBrokerInfo (self, data): self.brokerInfo = data @@ -151,9 +210,6 @@ class managementChannel: if self.enabled: self.qpidChannel.message_transfer (destination=exchange, message=msg) - def accept (self, msg): - self.qpidChannel.message_accept(RangedSet(msg.id)) - def message (self, body, routing_key="broker"): dp = self.qpidChannel.delivery_properties() dp.routing_key = routing_key @@ -178,8 +234,7 @@ class managementClient: #======================================================== # User API - interacts with the class's user #======================================================== - def __init__ (self, amqpSpec, ctrlCb=None, configCb=None, instCb=None, methodCb=None, closeCb=None): - self.spec = amqpSpec + def __init__ (self, unused=None, ctrlCb=None, configCb=None, instCb=None, methodCb=None, closeCb=None): self.ctrlCb = ctrlCb self.configCb = configCb self.instCb = instCb @@ -212,7 +267,7 @@ class managementClient: self.channels.append (mch) self.incOutstanding (mch) - codec = Codec (self.spec) + codec = Codec () self.setHeader (codec, ord ('B')) msg = mch.message(codec.encoded) mch.send ("qpid.management", msg) @@ -229,12 +284,12 @@ class managementClient: def getObjects (self, channel, userSequence, className, bank=0): """ Request immediate content from broker """ - codec = Codec (self.spec) + codec = Codec () self.setHeader (codec, ord ('G'), userSequence) ft = {} ft["_class"] = className codec.write_map (ft) - msg = channel.message(codec.encoded, routing_key="agent.%d" % bank) + msg = channel.message(codec.encoded, routing_key="agent.1.%d" % bank) channel.send ("qpid.management", msg) def syncWaitForStable (self, channel): @@ -297,27 +352,28 @@ class managementClient: #======================================================== def topicCb (self, ch, msg): """ Receive messages via the topic queue of a particular channel. """ - codec = Codec (self.spec, msg.body) - hdr = self.checkHeader (codec) - if hdr == None: - raise ValueError ("outer header invalid"); + codec = Codec (msg.body) + while True: + hdr = self.checkHeader (codec) + if hdr == None: + return - if hdr[0] == 'p': - self.handlePackageInd (ch, codec) - elif hdr[0] == 'q': - self.handleClassInd (ch, codec) - elif hdr[0] == 'h': - self.handleHeartbeat (ch, codec) - else: - self.parse (ch, codec, hdr[0], hdr[1]) - ch.accept(msg) + if hdr[0] == 'p': + self.handlePackageInd (ch, codec) + elif hdr[0] == 'q': + self.handleClassInd (ch, codec) + elif hdr[0] == 'h': + self.handleHeartbeat (ch, codec) + elif hdr[0] == 'e': + self.handleEvent (ch, codec) + else: + self.parse (ch, codec, hdr[0], hdr[1]) def replyCb (self, ch, msg): """ Receive messages via the reply queue of a particular channel. """ - codec = Codec (self.spec, msg.body) + codec = Codec (msg.body) hdr = self.checkHeader (codec) if hdr == None: - ch.accept(msg) return if hdr[0] == 'm': @@ -332,7 +388,6 @@ class managementClient: self.handleClassInd (ch, codec) else: self.parse (ch, codec, hdr[0], hdr[1]) - ch.accept(msg) def exceptCb (self, ch, data): if self.closeCb != None: @@ -345,25 +400,27 @@ class managementClient: """ Compose the header of a management message. """ codec.write_uint8 (ord ('A')) codec.write_uint8 (ord ('M')) - codec.write_uint8 (ord ('1')) + codec.write_uint8 (ord ('2')) codec.write_uint8 (opcode) codec.write_uint32 (seq) def checkHeader (self, codec): - """ Check the header of a management message and extract the opcode and - class. """ - octet = chr (codec.read_uint8 ()) - if octet != 'A': - return None - octet = chr (codec.read_uint8 ()) - if octet != 'M': + """ Check the header of a management message and extract the opcode and class. """ + try: + octet = chr (codec.read_uint8 ()) + if octet != 'A': + return None + octet = chr (codec.read_uint8 ()) + if octet != 'M': + return None + octet = chr (codec.read_uint8 ()) + if octet != '2': + return None + opcode = chr (codec.read_uint8 ()) + seq = codec.read_uint32 () + return (opcode, seq) + except: return None - octet = chr (codec.read_uint8 ()) - if octet != '1': - return None - opcode = chr (codec.read_uint8 ()) - seq = codec.read_uint32 () - return (opcode, seq) def encodeValue (self, codec, value, typecode): """ Encode, into the codec, a value based on its typecode. """ @@ -380,19 +437,19 @@ class managementClient: elif typecode == 6: codec.write_str8 (value) elif typecode == 7: - codec.write_vbin32 (value) + codec.write_str16 (value) elif typecode == 8: # ABSTIME codec.write_uint64 (long (value)) elif typecode == 9: # DELTATIME codec.write_uint64 (long (value)) elif typecode == 10: # REF - codec.write_uint64 (long (value)) + value.encode(codec) elif typecode == 11: # BOOL codec.write_uint8 (int (value)) elif typecode == 12: # FLOAT codec.write_float (float (value)) elif typecode == 13: # DOUBLE - codec.write_double (double (value)) + codec.write_double (float (value)) elif typecode == 14: # UUID codec.write_uuid (value) elif typecode == 15: # FTABLE @@ -421,15 +478,15 @@ class managementClient: elif typecode == 5: data = codec.read_uint8 () elif typecode == 6: - data = str (codec.read_str8 ()) + data = codec.read_str8 () elif typecode == 7: - data = codec.read_vbin32 () + data = codec.read_str16 () elif typecode == 8: # ABSTIME data = codec.read_uint64 () elif typecode == 9: # DELTATIME data = codec.read_uint64 () elif typecode == 10: # REF - data = codec.read_uint64 () + data = objectId(codec) elif typecode == 11: # BOOL data = codec.read_uint8 () elif typecode == 12: # FLOAT @@ -469,12 +526,14 @@ class managementClient: if self.ctrlCb != None: self.ctrlCb (ch.context, self.CTRL_SCHEMA_LOADED, None) ch.ssn.exchange_bind (exchange="qpid.management", - queue=ch.topicName, binding_key="mgmt.#") + queue=ch.topicName, binding_key="console.#") + ch.ssn.exchange_bind (exchange="qpid.management", + queue=ch.topicName, binding_key="schema.#") def handleMethodReply (self, ch, codec, sequence): status = codec.read_uint32 () - sText = str (codec.read_str8 ()) + sText = codec.read_str16 () data = self.seqMgr.release (sequence) if data == None: @@ -510,7 +569,7 @@ class managementClient: def handleCommandComplete (self, ch, codec, seq): code = codec.read_uint32 () - text = str (codec.read_str8 ()) + text = codec.read_str8 () data = (seq, code, text) context = self.seqMgr.release (seq) if context == "outstanding": @@ -530,19 +589,19 @@ class managementClient: self.ctrlCb (ch.context, self.CTRL_BROKER_INFO, ch.brokerInfo) # Send a package request - sendCodec = Codec (self.spec) + sendCodec = Codec () seq = self.seqMgr.reserve ("outstanding") self.setHeader (sendCodec, ord ('P'), seq) smsg = ch.message(sendCodec.encoded) ch.send ("qpid.management", smsg) def handlePackageInd (self, ch, codec): - pname = str (codec.read_str8 ()) + pname = codec.read_str8 () if pname not in self.packages: self.packages[pname] = {} # Send a class request - sendCodec = Codec (self.spec) + sendCodec = Codec () seq = self.seqMgr.reserve ("outstanding") self.setHeader (sendCodec, ord ('Q'), seq) self.incOutstanding (ch) @@ -551,15 +610,18 @@ class managementClient: ch.send ("qpid.management", smsg) def handleClassInd (self, ch, codec): - pname = str (codec.read_str8 ()) - cname = str (codec.read_str8 ()) - hash = codec.read_bin128 () + kind = codec.read_uint8() + if kind != 1: # This API doesn't handle new-style events + return + pname = codec.read_str8() + cname = codec.read_str8() + hash = codec.read_bin128() if pname not in self.packages: return if (cname, hash) not in self.packages[pname]: # Send a schema request - sendCodec = Codec (self.spec) + sendCodec = Codec () seq = self.seqMgr.reserve ("outstanding") self.setHeader (sendCodec, ord ('S'), seq) self.incOutstanding (ch) @@ -574,16 +636,49 @@ class managementClient: if self.ctrlCb != None: self.ctrlCb (ch.context, self.CTRL_HEARTBEAT, timestamp) + def handleEvent (self, ch, codec): + if self.eventCb == None: + return + timestamp = codec.read_uint64() + objId = objectId(codec) + packageName = codec.read_str8() + className = codec.read_str8() + hash = codec.read_bin128() + name = codec.read_str8() + classKey = (packageName, className, hash) + if classKey not in self.schema: + return; + schemaClass = self.schema[classKey] + row = [] + es = schemaClass['E'] + arglist = None + for ename in es: + (edesc, eargs) = es[ename] + if ename == name: + arglist = eargs + if arglist == None: + return + for arg in arglist: + row.append((arg[0], self.decodeValue(codec, arg[1]))) + self.eventCb(ch.context, classKey, objId, name, row) + def parseSchema (self, ch, codec): """ Parse a received schema-description message. """ self.decOutstanding (ch) - packageName = str (codec.read_str8 ()) - className = str (codec.read_str8 ()) + kind = codec.read_uint8() + if kind != 1: # This API doesn't handle new-style events + return + packageName = codec.read_str8 () + className = codec.read_str8 () hash = codec.read_bin128 () + hasSupertype = 0 #codec.read_uint8() configCount = codec.read_uint16 () instCount = codec.read_uint16 () methodCount = codec.read_uint16 () - eventCount = codec.read_uint16 () + if hasSupertype != 0: + supertypePackage = codec.read_str8() + supertypeClass = codec.read_str8() + supertypeHash = codec.read_bin128() if packageName not in self.packages: return @@ -597,22 +692,22 @@ class managementClient: configs = [] insts = [] methods = {} - events = [] configs.append (("id", 4, "", "", 1, 1, None, None, None, None, None)) insts.append (("id", 4, None, None)) for idx in range (configCount): ft = codec.read_map () - name = str (ft["name"]) - type = ft["type"] - access = ft["access"] - index = ft["index"] - unit = None - min = None - max = None - maxlen = None - desc = None + name = str (ft["name"]) + type = ft["type"] + access = ft["access"] + index = ft["index"] + optional = ft["optional"] + unit = None + min = None + max = None + maxlen = None + desc = None for key, value in ft.items (): if key == "unit": @@ -626,7 +721,7 @@ class managementClient: elif key == "desc": desc = str (value) - config = (name, type, unit, desc, access, index, min, max, maxlen) + config = (name, type, unit, desc, access, index, min, max, maxlen, optional) configs.append (config) for idx in range (instCount): @@ -689,11 +784,26 @@ class managementClient: schemaClass['C'] = configs schemaClass['I'] = insts schemaClass['M'] = methods - schemaClass['E'] = events self.schema[classKey] = schemaClass if self.schemaCb != None: - self.schemaCb (ch.context, classKey, configs, insts, methods, events) + self.schemaCb (ch.context, classKey, configs, insts, methods, {}) + + def parsePresenceMasks(self, codec, schemaClass): + """ Generate a list of not-present properties """ + excludeList = [] + bit = 0 + for element in schemaClass['C'][1:]: + if element[9] == 1: + if bit == 0: + mask = codec.read_uint8() + bit = 1 + if (mask & bit) == 0: + excludeList.append(element[0]) + bit = bit * 2 + if bit == 256: + bit = 0 + return excludeList def parseContent (self, ch, cls, codec, seq=0): """ Parse a received content message. """ @@ -702,8 +812,8 @@ class managementClient: if cls == 'I' and self.instCb == None: return - packageName = str (codec.read_str8 ()) - className = str (codec.read_str8 ()) + packageName = codec.read_str8 () + className = codec.read_str8 () hash = codec.read_bin128 () classKey = (packageName, className, hash) @@ -716,21 +826,26 @@ class managementClient: timestamps.append (codec.read_uint64 ()) # Current Time timestamps.append (codec.read_uint64 ()) # Create Time timestamps.append (codec.read_uint64 ()) # Delete Time - + objId = objectId(codec) schemaClass = self.schema[classKey] if cls == 'C' or cls == 'B': - for element in schemaClass['C'][:]: + notPresent = self.parsePresenceMasks(codec, schemaClass) + + if cls == 'C' or cls == 'B': + row.append(("id", objId)) + for element in schemaClass['C'][1:]: tc = element[1] name = element[0] - data = self.decodeValue (codec, tc) - row.append ((name, data)) + if name in notPresent: + row.append((name, None)) + else: + data = self.decodeValue(codec, tc) + row.append((name, data)) if cls == 'I' or cls == 'B': - if cls == 'B': - start = 1 - else: - start = 0 - for element in schemaClass['I'][start:]: + if cls == 'I': + row.append(("id", objId)) + for element in schemaClass['I'][1:]: tc = element[1] name = element[0] data = self.decodeValue (codec, tc) @@ -760,12 +875,15 @@ class managementClient: def method (self, channel, userSequence, objId, classId, methodName, args): """ Invoke a method on an object """ - codec = Codec (self.spec) + codec = Codec () sequence = self.seqMgr.reserve ((userSequence, classId, methodName)) self.setHeader (codec, ord ('M'), sequence) - codec.write_uint64 (objId) # ID of object + objId.encode(codec) + codec.write_str8 (classId[0]) + codec.write_str8 (classId[1]) + codec.write_bin128 (classId[2]) codec.write_str8 (methodName) - bank = (objId & 0x0000FFFFFF000000) >> 24 + bank = "%d.%d" % (objId.getBroker(), objId.getBank()) # Encode args according to schema if classId not in self.schema: @@ -795,5 +913,5 @@ class managementClient: packageName = classId[0] className = classId[1] - msg = channel.message(codec.encoded, "agent." + str(bank)) + msg = channel.message(codec.encoded, "agent." + bank) channel.send ("qpid.management", msg) diff --git a/python/qpid/managementdata.py b/python/qpid/managementdata.py index fc9eb391b7..61cb10c134 100644 --- a/python/qpid/managementdata.py +++ b/python/qpid/managementdata.py @@ -19,11 +19,19 @@ # under the License. # + +############################################################################### +## This file is being obsoleted by qmf/console.py +############################################################################### + import qpid import re import socket import struct import os +import platform +import locale +from qpid.connection import Timeout from qpid.management import managementChannel, managementClient from threading import Lock from disp import Display @@ -40,9 +48,11 @@ class Broker: if not match: raise ValueError("'%s' is not a valid broker url" % (text)) user, password, host, port = match.groups() - self.host = socket.gethostbyname (host) if port: self.port = int(port) else: self.port = 5672 + for addr in socket.getaddrinfo(host, self.port): + if addr[1] == socket.AF_INET: + self.host = addr[4][0] self.username = user or "guest" self.password = password or "guest" @@ -71,14 +81,14 @@ class ManagementData: # def registerObjId (self, objId): - if not objId in self.idBackMap: - self.idBackMap[objId] = self.nextId + if not objId.index() in self.idBackMap: + self.idBackMap[objId.index()] = self.nextId self.idMap[self.nextId] = objId self.nextId += 1 - def displayObjId (self, objId): - if objId in self.idBackMap: - return self.idBackMap[objId] + def displayObjId (self, objIdIndex): + if objIdIndex in self.idBackMap: + return self.idBackMap[objIdIndex] else: return 0 @@ -86,11 +96,16 @@ class ManagementData: if displayId in self.idMap: return self.idMap[displayId] else: - return 0 + return None def displayClassName (self, cls): (packageName, className, hash) = cls - return packageName + "." + className + rev = self.schema[cls][4] + if rev == 0: + suffix = "" + else: + suffix = ".%d" % rev + return packageName + ":" + className + suffix def dataHandler (self, context, className, list, timestamps): """ Callback for configuration and instrumentation data updates """ @@ -102,19 +117,20 @@ class ManagementData: self.tables[className] = {} # Register the ID so a more friendly presentation can be displayed - id = long (list[0][1]) - self.registerObjId (id) + objId = list[0][1] + oidx = objId.index() + self.registerObjId (objId) # If this object hasn't been seen before, create a new object record with # the timestamps and empty lists for configuration and instrumentation data. - if id not in self.tables[className]: - self.tables[className][id] = (timestamps, [], []) + if oidx not in self.tables[className]: + self.tables[className][oidx] = (timestamps, [], []) - (unused, oldConf, oldInst) = self.tables[className][id] + (unused, oldConf, oldInst) = self.tables[className][oidx] # For config updates, simply replace old config list with the new one. if context == 0: #config - self.tables[className][id] = (timestamps, list, oldInst) + self.tables[className][oidx] = (timestamps, list, oldInst) # For instrumentation updates, carry the minimum and maximum values for # "hi-lo" stats forward. @@ -132,7 +148,7 @@ class ManagementData: if oldInst[idx][1] < value: value = oldInst[idx][1] newInst.append ((key, value)) - self.tables[className][id] = (timestamps, oldConf, newInst) + self.tables[className][oidx] = (timestamps, oldConf, newInst) finally: self.lock.release () @@ -190,15 +206,25 @@ class ManagementData: self.lastUnit = None self.methodSeq = 1 self.methodsPending = {} - self.sessionId = "%s.%d" % (os.uname()[1], os.getpid()) + self.sessionId = "%s.%d" % (platform.uname()[1], os.getpid()) self.broker = Broker (host) - self.conn = Connection (connect (self.broker.host, self.broker.port), + sock = connect (self.broker.host, self.broker.port) + oldTimeout = sock.gettimeout() + sock.settimeout(10) + self.conn = Connection (sock, username=self.broker.username, password=self.broker.password) - self.spec = self.conn.spec + def aborted(): + raise Timeout("Waiting for connection to be established with broker") + oldAborted = self.conn.aborted + self.conn.aborted = aborted + self.conn.start () - self.mclient = managementClient (self.spec, self.ctrlHandler, self.configHandler, + sock.settimeout(oldTimeout) + self.conn.aborted = oldAborted + + self.mclient = managementClient ("unused", self.ctrlHandler, self.configHandler, self.instHandler, self.methodReply, self.closeHandler) self.mclient.schemaListener (self.schemaHandler) self.mch = self.mclient.addChannel (self.conn.session(self.sessionId)) @@ -211,11 +237,13 @@ class ManagementData: pass def refName (self, oid): - if oid == 0: + if oid == None: return "NULL" - return str (self.displayObjId (oid)) + return str (self.displayObjId (oid.index())) def valueDisplay (self, classKey, key, value): + if value == None: + return "<NULL>" for kind in range (2): schema = self.schema[classKey][kind] for item in schema: @@ -248,7 +276,7 @@ class ManagementData: else: return "True" elif typecode == 14: - return "%08x-%04x-%04x-%04x-%04x%08x" % struct.unpack ("!LHHHHL", value) + return str (value) elif typecode == 15: return str (value) return "*type-error*" @@ -267,14 +295,21 @@ class ManagementData: return result def getClassKey (self, className): - dotPos = className.find(".") - if dotPos == -1: + delimPos = className.find(":") + if delimPos == -1: + schemaRev = 0 + delim = className.find(".") + if delim != -1: + schemaRev = int(className[delim + 1:]) + name = className[0:delim] + else: + name = className for key in self.schema: - if key[1] == className: + if key[1] == name and self.schema[key][4] == schemaRev: return key else: - package = className[0:dotPos] - name = className[dotPos + 1:] + package = className[0:delimPos] + name = className[delimPos + 1:] schemaRev = 0 delim = name.find(".") if delim != -1: @@ -338,6 +373,12 @@ class ManagementData: return "int32" elif typecode == 19: return "int64" + elif typecode == 20: + return "object" + elif typecode == 21: + return "list" + elif typecode == 22: + return "array" else: raise ValueError ("Invalid type code: %d" % typecode) @@ -437,7 +478,7 @@ class ManagementData: if classKey in self.tables: ids = self.listOfIds(classKey, tokens[1:]) for objId in ids: - (ts, config, inst) = self.tables[classKey][self.rawObjId(objId)] + (ts, config, inst) = self.tables[classKey][self.rawObjId(objId).index()] createTime = self.disp.timestamp (ts[1]) destroyTime = "-" if ts[2] > 0: @@ -445,7 +486,7 @@ class ManagementData: objIndex = self.getObjIndex (classKey, config) row = (objId, createTime, destroyTime, objIndex) rows.append (row) - self.disp.table ("Objects of type %s.%s" % (classKey[0], classKey[1]), + self.disp.table ("Objects of type %s" % self.displayClassName(classKey), ("ID", "Created", "Destroyed", "Index"), rows) finally: @@ -486,33 +527,33 @@ class ManagementData: rows = [] timestamp = None - config = self.tables[classKey][ids[0]][1] + config = self.tables[classKey][ids[0].index()][1] for eIdx in range (len (config)): key = config[eIdx][0] if key != "id": row = ("property", key) for id in ids: if timestamp == None or \ - timestamp < self.tables[classKey][id][0][0]: - timestamp = self.tables[classKey][id][0][0] - (key, value) = self.tables[classKey][id][1][eIdx] + timestamp < self.tables[classKey][id.index()][0][0]: + timestamp = self.tables[classKey][id.index()][0][0] + (key, value) = self.tables[classKey][id.index()][1][eIdx] row = row + (self.valueDisplay (classKey, key, value),) rows.append (row) - inst = self.tables[classKey][ids[0]][2] + inst = self.tables[classKey][ids[0].index()][2] for eIdx in range (len (inst)): key = inst[eIdx][0] if key != "id": row = ("statistic", key) for id in ids: - (key, value) = self.tables[classKey][id][2][eIdx] + (key, value) = self.tables[classKey][id.index()][2][eIdx] row = row + (self.valueDisplay (classKey, key, value),) rows.append (row) titleRow = ("Type", "Element") for id in ids: - titleRow = titleRow + (self.refName (id),) - caption = "Object of type %s.%s:" % (classKey[0], classKey[1]) + titleRow = titleRow + (self.refName(id),) + caption = "Object of type %s:" % self.displayClassName(classKey) if timestamp != None: caption = caption + " (last sample time: " + self.disp.timestamp (timestamp) + ")" self.disp.table (caption, titleRow, rows) @@ -530,15 +571,11 @@ class ManagementData: sorted.sort () for classKey in sorted: tuple = self.schema[classKey] - if tuple[4] == 0: - suffix = "" - else: - suffix = ".%d" % tuple[4] - className = classKey[0] + "." + classKey[1] + suffix - row = (className, len (tuple[0]), len (tuple[1]), len (tuple[2]), len (tuple[3])) + row = (self.displayClassName(classKey), len (tuple[0]), len (tuple[1]), + len (tuple[2])) rows.append (row) self.disp.table ("Classes in Schema:", - ("Class", "Properties", "Statistics", "Methods", "Events"), + ("Class", "Properties", "Statistics", "Methods"), rows) finally: self.lock.release () @@ -563,13 +600,15 @@ class ManagementData: access = self.accessName (config[4]) extra = "" if config[5] == 1: - extra = extra + "index " + extra += "index " if config[6] != None: - extra = extra + "Min: " + str (config[6]) + extra += "Min: " + str(config[6]) + " " if config[7] != None: - extra = extra + "Max: " + str (config[7]) + extra += "Max: " + str(config[7]) + " " if config[8] != None: - extra = extra + "MaxLen: " + str (config[8]) + extra += "MaxLen: " + str(config[8]) + " " + if config[9] == 1: + extra += "optional " rows.append ((name, typename, unit, access, extra, desc)) for config in self.schema[classKey][1]: @@ -581,7 +620,7 @@ class ManagementData: rows.append ((name, typename, unit, "", "", desc)) titles = ("Element", "Type", "Unit", "Access", "Notes", "Description") - self.disp.table ("Schema for class '%s.%s.%d':" % (classKey[0], classKey[1], schemaRev), titles, rows) + self.disp.table ("Schema for class '%s':" % self.displayClassName(classKey), titles, rows) for mname in self.schema[classKey][2]: (mdesc, args) = self.schema[classKey][2][mname] @@ -606,14 +645,14 @@ class ManagementData: titles = ("Argument", "Type", "Direction", "Unit", "Notes", "Description") self.disp.table (caption, titles, rows) - except: + except Exception,e: pass self.lock.release () def getClassForId (self, objId): """ Given an object ID, return the class key for the referenced object """ for classKey in self.tables: - if objId in self.tables[classKey]: + if objId.index() in self.tables[classKey]: return classKey return None @@ -626,7 +665,7 @@ class ManagementData: raise ValueError () if methodName not in self.schema[classKey][2]: - print "Method '%s' not valid for class '%s.%s'" % (methodName, classKey[0], classKey[1]) + print "Method '%s' not valid for class '%s'" % (methodName, self.displayClassName(classKey)) raise ValueError () schemaMethod = self.schema[classKey][2][methodName] @@ -647,7 +686,7 @@ class ManagementData: self.methodSeq = self.methodSeq + 1 self.methodsPending[self.methodSeq] = methodName - except: + except Exception, e: methodOk = False self.lock.release () if methodOk: @@ -659,14 +698,19 @@ class ManagementData: def makeIdRow (self, displayId): if displayId in self.idMap: - rawId = self.idMap[displayId] + objId = self.idMap[displayId] else: return None - return (displayId, - rawId, - (rawId & 0x7FFF000000000000) >> 48, - (rawId & 0x0000FFFFFF000000) >> 24, - (rawId & 0x0000000000FFFFFF)) + if objId.getFlags() == 0: + flags = "" + else: + flags = str(objId.getFlags()) + seq = objId.getSequence() + if seq == 0: + seqText = "<durable>" + else: + seqText = str(seq) + return (displayId, flags, seqText, objId.getBroker(), objId.getBank(), hex(objId.getObject())) def listIds (self, select): rows = [] @@ -683,7 +727,7 @@ class ManagementData: return rows.append(row) self.disp.table("Translation of Display IDs:", - ("DisplayID", "RawID", "BootSequence", "Bank", "Object"), + ("DisplayID", "Flags", "BootSequence", "Broker", "Bank", "Object"), rows) def do_list (self, data): @@ -704,7 +748,11 @@ class ManagementData: self.schemaTable (data) def do_call (self, data): - tokens = data.split () + encTokens = data.split () + try: + tokens = [a.decode(locale.getpreferredencoding()) for a in encArgs] + except: + tokens = encTokens if len (tokens) < 2: print "Not enough arguments supplied" return diff --git a/python/qpid/message.py b/python/qpid/message.py index eb3ef5c03c..4d31da2846 100644 --- a/python/qpid/message.py +++ b/python/qpid/message.py @@ -17,7 +17,6 @@ # under the License. # from connection08 import Method, Request -from sets import Set class Message: diff --git a/python/qpid/messaging.py b/python/qpid/messaging.py new file mode 100644 index 0000000000..4f2c190ce2 --- /dev/null +++ b/python/qpid/messaging.py @@ -0,0 +1,822 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +""" +A candidate high level messaging API for python. + +Areas that still need work: + + - asynchronous send + - asynchronous error notification + - definition of the arguments for L{Session.sender} and L{Session.receiver} + - standard L{Message} properties + - L{Message} content encoding + - protocol negotiation/multiprotocol impl +""" + +from codec010 import StringCodec +from concurrency import synchronized, Waiter, Condition +from datatypes import timestamp, uuid4, Serial +from logging import getLogger +from ops import PRIMITIVE +from threading import Thread, RLock +from util import default + +log = getLogger("qpid.messaging") + +static = staticmethod + +AMQP_PORT = 5672 +AMQPS_PORT = 5671 + +class Constant: + + def __init__(self, name, value=None): + self.name = name + self.value = value + + def __repr__(self): + return self.name + +UNLIMITED = Constant("UNLIMITED", 0xFFFFFFFFL) + +class ConnectionError(Exception): + """ + The base class for all connection related exceptions. + """ + pass + +class ConnectError(ConnectionError): + """ + Exception raised when there is an error connecting to the remote + peer. + """ + pass + +class Connection: + + """ + A Connection manages a group of L{Sessions<Session>} and connects + them with a remote endpoint. + """ + + @static + def open(host, port=None, username="guest", password="guest", + mechanism="PLAIN", heartbeat=None, **options): + """ + Creates an AMQP connection and connects it to the given host and port. + + @type host: str + @param host: the name or ip address of the remote host + @type port: int + @param port: the port number of the remote host + @rtype: Connection + @return: a connected Connection + """ + conn = Connection(host, port, username, password, mechanism, heartbeat, **options) + conn.connect() + return conn + + def __init__(self, host, port=None, username="guest", password="guest", + mechanism="PLAIN", heartbeat=None, **options): + """ + Creates a connection. A newly created connection must be connected + with the Connection.connect() method before it can be used. + + @type host: str + @param host: the name or ip address of the remote host + @type port: int + @param port: the port number of the remote host + @rtype: Connection + @return: a disconnected Connection + """ + self.host = host + self.port = default(port, AMQP_PORT) + self.username = username + self.password = password + self.mechanism = mechanism + self.heartbeat = heartbeat + + self.id = str(uuid4()) + self.session_counter = 0 + self.sessions = {} + self.reconnect = options.get("reconnect", False) + self._connected = False + self._lock = RLock() + self._condition = Condition(self._lock) + self._waiter = Waiter(self._condition) + self._modcount = Serial(0) + self.error = None + from driver import Driver + self._driver = Driver(self) + self._driver.start() + + def _wait(self, predicate, timeout=None): + return self._waiter.wait(predicate, timeout=timeout) + + def _wakeup(self): + self._modcount += 1 + self._driver.wakeup() + + def _check_error(self, exc=ConnectionError): + if self.error: + raise exc(*self.error) + + def _ewait(self, predicate, timeout=None, exc=ConnectionError): + result = self._wait(lambda: self.error or predicate(), timeout) + self._check_error(exc) + return result + + @synchronized + def session(self, name=None, transactional=False): + """ + Creates or retrieves the named session. If the name is omitted or + None, then a unique name is chosen based on a randomly generated + uuid. + + @type name: str + @param name: the session name + @rtype: Session + @return: the named Session + """ + + if name is None: + name = "%s:%s" % (self.id, self.session_counter) + self.session_counter += 1 + else: + name = "%s:%s" % (self.id, name) + + if self.sessions.has_key(name): + return self.sessions[name] + else: + ssn = Session(self, name, transactional) + self.sessions[name] = ssn + self._wakeup() + return ssn + + @synchronized + def _remove_session(self, ssn): + del self.sessions[ssn.name] + + @synchronized + def connect(self): + """ + Connect to the remote endpoint. + """ + self._connected = True + self._wakeup() + self._ewait(lambda: self._driver._connected, exc=ConnectError) + + @synchronized + def disconnect(self): + """ + Disconnect from the remote endpoint. + """ + self._connected = False + self._wakeup() + self._ewait(lambda: not self._driver._connected) + + @synchronized + def connected(self): + """ + Return true if the connection is connected, false otherwise. + """ + return self._connected + + @synchronized + def close(self): + """ + Close the connection and all sessions. + """ + for ssn in self.sessions.values(): + ssn.close() + self.disconnect() + +class Pattern: + """ + The pattern filter matches the supplied wildcard pattern against a + message subject. + """ + + def __init__(self, value): + self.value = value + + # XXX: this should become part of the driver + def _bind(self, sst, exchange, queue): + from qpid.ops import ExchangeBind + sst.write_cmd(ExchangeBind(exchange=exchange, queue=queue, + binding_key=self.value.replace("*", "#"))) + +class SessionError(Exception): + pass + +class Disconnected(SessionError): + """ + Exception raised when an operation is attempted that is illegal when + disconnected. + """ + pass + +class NontransactionalSession(SessionError): + """ + Exception raised when commit or rollback is attempted on a non + transactional session. + """ + pass + +class TransactionAborted(SessionError): + pass + +class Session: + + """ + Sessions provide a linear context for sending and receiving + messages, and manage various Senders and Receivers. + """ + + def __init__(self, connection, name, transactional): + self.connection = connection + self.name = name + + self.transactional = transactional + + self.committing = False + self.committed = True + self.aborting = False + self.aborted = False + + self.senders = [] + self.receivers = [] + self.outgoing = [] + self.incoming = [] + self.unacked = [] + self.acked = [] + # XXX: I hate this name. + self.ack_capacity = UNLIMITED + + self.error = None + self.closing = False + self.closed = False + + self._lock = connection._lock + + def __repr__(self): + return "<Session %s>" % self.name + + def _wait(self, predicate, timeout=None): + return self.connection._wait(predicate, timeout=timeout) + + def _wakeup(self): + self.connection._wakeup() + + def _check_error(self, exc=SessionError): + self.connection._check_error(exc) + if self.error: + raise exc(*self.error) + + def _ewait(self, predicate, timeout=None, exc=SessionError): + result = self.connection._ewait(lambda: self.error or predicate(), timeout, exc) + self._check_error(exc) + return result + + @synchronized + def sender(self, target, **options): + """ + Creates a L{Sender} that may be used to send L{Messages<Message>} + to the specified target. + + @type target: str + @param target: the target to which messages will be sent + @rtype: Sender + @return: a new Sender for the specified target + """ + sender = Sender(self, len(self.senders), target, options) + self.senders.append(sender) + self._wakeup() + # XXX: because of the lack of waiting here we can end up getting + # into the driver loop with messages sent for senders that haven't + # been linked yet, something similar can probably happen for + # receivers + return sender + + @synchronized + def receiver(self, source, **options): + """ + Creates a receiver that may be used to fetch L{Messages<Message>} + from the specified source. + + @type source: str + @param source: the source of L{Messages<Message>} + @rtype: Receiver + @return: a new Receiver for the specified source + """ + receiver = Receiver(self, len(self.receivers), source, options) + self.receivers.append(receiver) + self._wakeup() + return receiver + + @synchronized + def _count(self, predicate): + result = 0 + for msg in self.incoming: + if predicate(msg): + result += 1 + return result + + def _peek(self, predicate): + for msg in self.incoming: + if predicate(msg): + return msg + + def _pop(self, predicate): + i = 0 + while i < len(self.incoming): + msg = self.incoming[i] + if predicate(msg): + del self.incoming[i] + return msg + else: + i += 1 + + @synchronized + def _get(self, predicate, timeout=None): + if self._ewait(lambda: ((self._peek(predicate) is not None) or self.closing), + timeout): + msg = self._pop(predicate) + if msg is not None: + msg._receiver.returned += 1 + self.unacked.append(msg) + log.debug("RETR [%s] %s", self, msg) + return msg + return None + + @synchronized + def next_receiver(self, timeout=None): + if self._ewait(lambda: self.incoming, timeout): + return self.incoming[0]._receiver + else: + raise Empty + + @synchronized + def acknowledge(self, message=None, sync=True): + """ + Acknowledge the given L{Message}. If message is None, then all + unacknowledged messages on the session are acknowledged. + + @type message: Message + @param message: the message to acknowledge or None + @type sync: boolean + @param sync: if true then block until the message(s) are acknowledged + """ + if message is None: + messages = self.unacked[:] + else: + messages = [message] + + for m in messages: + if self.ack_capacity is not UNLIMITED: + if self.ack_capacity <= 0: + # XXX: this is currently a SendError, maybe it should be a SessionError? + raise InsufficientCapacity("ack_capacity = %s" % self.ack_capacity) + self._wakeup() + self._ewait(lambda: len(self.acked) < self.ack_capacity) + self.unacked.remove(m) + self.acked.append(m) + + self._wakeup() + if sync: + self._ewait(lambda: not [m for m in messages if m in self.acked]) + + @synchronized + def commit(self): + """ + Commit outstanding transactional work. This consists of all + message sends and receives since the prior commit or rollback. + """ + if not self.transactional: + raise NontransactionalSession() + self.committing = True + self._wakeup() + self._ewait(lambda: not self.committing) + if self.aborted: + raise TransactionAborted() + assert self.committed + + @synchronized + def rollback(self): + """ + Rollback outstanding transactional work. This consists of all + message sends and receives since the prior commit or rollback. + """ + if not self.transactional: + raise NontransactionalSession() + self.aborting = True + self._wakeup() + self._ewait(lambda: not self.aborting) + assert self.aborted + + @synchronized + def close(self): + """ + Close the session. + """ + # XXX: should be able to express this condition through API calls + self._ewait(lambda: not self.outgoing and not self.acked) + + for link in self.receivers + self.senders: + link.close() + + self.closing = True + self._wakeup() + self._ewait(lambda: self.closed) + self.connection._remove_session(self) + +class SendError(SessionError): + pass + +class InsufficientCapacity(SendError): + pass + +class Sender: + + """ + Sends outgoing messages. + """ + + def __init__(self, session, index, target, options): + self.session = session + self.index = index + self.target = target + self.options = options + self.capacity = options.get("capacity", UNLIMITED) + self.durable = options.get("durable") + self.queued = Serial(0) + self.acked = Serial(0) + self.error = None + self.linked = False + self.closing = False + self.closed = False + self._lock = self.session._lock + + def _wakeup(self): + self.session._wakeup() + + def _check_error(self, exc=SendError): + self.session._check_error(exc) + if self.error: + raise exc(*self.error) + + def _ewait(self, predicate, timeout=None, exc=SendError): + result = self.session._ewait(lambda: self.error or predicate(), timeout, exc) + self._check_error(exc) + return result + + @synchronized + def pending(self): + """ + Returns the number of messages awaiting acknowledgment. + @rtype: int + @return: the number of unacknowledged messages + """ + return self.queued - self.acked + + @synchronized + def send(self, object, sync=True, timeout=None): + """ + Send a message. If the object passed in is of type L{unicode}, + L{str}, L{list}, or L{dict}, it will automatically be wrapped in a + L{Message} and sent. If it is of type L{Message}, it will be sent + directly. If the sender capacity is not L{UNLIMITED} then send + will block until there is available capacity to send the message. + If the timeout parameter is specified, then send will throw an + L{InsufficientCapacity} exception if capacity does not become + available within the specified time. + + @type object: unicode, str, list, dict, Message + @param object: the message or content to send + + @type sync: boolean + @param sync: if true then block until the message is sent + + @type timeout: float + @param timeout: the time to wait for available capacity + """ + + if not self.session.connection._connected or self.session.closing: + raise Disconnected() + + self._ewait(lambda: self.linked) + + if isinstance(object, Message): + message = object + else: + message = Message(object) + + if message.durable is None: + message.durable = self.durable + + if self.capacity is not UNLIMITED: + if self.capacity <= 0: + raise InsufficientCapacity("capacity = %s" % self.capacity) + if not self._ewait(lambda: self.pending() < self.capacity, timeout=timeout): + raise InsufficientCapacity("capacity = %s" % self.capacity) + + # XXX: what if we send the same message to multiple senders? + message._sender = self + self.session.outgoing.append(message) + self.queued += 1 + + self._wakeup() + + if sync: + self.sync() + assert message not in self.session.outgoing + + @synchronized + def sync(self): + mno = self.queued + self._ewait(lambda: self.acked >= mno) + + @synchronized + def close(self): + """ + Close the Sender. + """ + self.closing = True + self._wakeup() + try: + self.session._ewait(lambda: self.closed) + finally: + self.session.senders.remove(self) + +class ReceiveError(SessionError): + pass + +class Empty(ReceiveError): + """ + Exception raised by L{Receiver.fetch} when there is no message + available within the alloted time. + """ + pass + +class Receiver(object): + + """ + Receives incoming messages from a remote source. Messages may be + fetched with L{fetch}. + """ + + def __init__(self, session, index, source, options): + self.session = session + self.index = index + self.destination = str(self.index) + self.source = source + self.options = options + + self.granted = Serial(0) + self.draining = False + self.impending = Serial(0) + self.received = Serial(0) + self.returned = Serial(0) + + self.error = None + self.linked = False + self.closing = False + self.closed = False + self._lock = self.session._lock + self._capacity = 0 + self._set_capacity(options.get("capacity", 0), False) + + @synchronized + def _set_capacity(self, c, wakeup=True): + if c is UNLIMITED: + self._capacity = c.value + else: + self._capacity = c + self._grant() + if wakeup: + self._wakeup() + + def _get_capacity(self): + if self._capacity == UNLIMITED.value: + return UNLIMITED + else: + return self._capacity + + capacity = property(_get_capacity, _set_capacity) + + def _wakeup(self): + self.session._wakeup() + + def _check_error(self, exc=ReceiveError): + self.session._check_error(exc) + if self.error: + raise exc(*self.error) + + def _ewait(self, predicate, timeout=None, exc=ReceiveError): + result = self.session._ewait(lambda: self.error or predicate(), timeout, exc) + self._check_error(exc) + return result + + @synchronized + def pending(self): + """ + Returns the number of messages available to be fetched by the + application. + + @rtype: int + @return: the number of available messages + """ + return self.received - self.returned + + def _pred(self, msg): + return msg._receiver == self + + @synchronized + def fetch(self, timeout=None): + """ + Fetch and return a single message. A timeout of None will block + forever waiting for a message to arrive, a timeout of zero will + return immediately if no messages are available. + + @type timeout: float + @param timeout: the time to wait for a message to be available + """ + + self._ewait(lambda: self.linked) + + if self._capacity == 0: + self.granted = self.returned + 1 + self._wakeup() + self._ewait(lambda: self.impending >= self.granted) + msg = self.session._get(self._pred, timeout=timeout) + if msg is None: + self.draining = True + self._wakeup() + self._ewait(lambda: not self.draining) + self._grant() + self._wakeup() + msg = self.session._get(self._pred, timeout=0) + if msg is None: + raise Empty() + elif self._capacity not in (0, UNLIMITED.value): + self.granted += 1 + self._wakeup() + return msg + + def _grant(self): + if self._capacity == UNLIMITED.value: + self.granted = UNLIMITED + else: + self.granted = self.received + self._capacity + + @synchronized + def close(self): + """ + Close the receiver. + """ + self.closing = True + self._wakeup() + try: + self.session._ewait(lambda: self.closed) + finally: + self.session.receivers.remove(self) + +def codec(name): + type = PRIMITIVE[name] + + def encode(x): + sc = StringCodec() + sc.write_primitive(type, x) + return sc.encoded + + def decode(x): + sc = StringCodec(x) + return sc.read_primitive(type) + + return encode, decode + +# XXX: need to correctly parse the mime type and deal with +# content-encoding header + +TYPE_MAPPINGS={ + dict: "amqp/map", + list: "amqp/list", + unicode: "text/plain; charset=utf8", + unicode: "text/plain", + buffer: None, + str: None, + None.__class__: None + } + +TYPE_CODEC={ + "amqp/map": codec("map"), + "amqp/list": codec("list"), + "text/plain; charset=utf8": (lambda x: x.encode("utf8"), lambda x: x.decode("utf8")), + "text/plain": (lambda x: x.encode("utf8"), lambda x: x.decode("utf8")), + "": (lambda x: x, lambda x: x), + None: (lambda x: x, lambda x: x) + } + +def get_type(content): + return TYPE_MAPPINGS[content.__class__] + +def get_codec(content_type): + return TYPE_CODEC[content_type] + +UNSPECIFIED = object() + +class Message: + + """ + A message consists of a standard set of fields, an application + defined set of properties, and some content. + + @type id: str + @ivar id: the message id + @type user_id: ??? + @ivar user_id: the user-id of the message producer + @type to: ??? + @ivar to: ??? + @type reply_to: ??? + @ivar reply_to: ??? + @type correlation_id: str + @ivar correlation_id: a correlation-id for the message + @type properties: dict + @ivar properties: application specific message properties + @type content_type: str + @ivar content_type: the content-type of the message + @type content: str, unicode, buffer, dict, list + @ivar content: the message content + """ + + def __init__(self, content=None, content_type=UNSPECIFIED, id=None, + subject=None, to=None, user_id=None, reply_to=None, + correlation_id=None, durable=None, properties=None): + """ + Construct a new message with the supplied content. The + content-type of the message will be automatically inferred from + type of the content parameter. + + @type content: str, unicode, buffer, dict, list + @param content: the message content + + @type content_type: str + @param content_type: the content-type of the message + """ + self.id = id + self.subject = subject + self.to = to + self.user_id = user_id + self.reply_to = reply_to + self.correlation_id = correlation_id + self.durable = durable + self.redelivered = False + if properties is None: + self.properties = {} + else: + self.properties = properties + if content_type is UNSPECIFIED: + self.content_type = get_type(content) + else: + self.content_type = content_type + self.content = content + + def __repr__(self): + args = [] + for name in ["id", "subject", "to", "user_id", "reply_to", + "correlation_id"]: + value = self.__dict__[name] + if value is not None: args.append("%s=%r" % (name, value)) + for name in ["durable", "properties"]: + value = self.__dict__[name] + if value: args.append("%s=%r" % (name, value)) + if self.content_type != get_type(self.content): + args.append("content_type=%r" % self.content_type) + if self.content is not None: + if args: + args.append("content=%r" % self.content) + else: + args.append(repr(self.content)) + return "Message(%s)" % ", ".join(args) + +__all__ = ["Connection", "Session", "Sender", "Receiver", "Pattern", "Message", + "ConnectionError", "ConnectError", "SessionError", "Disconnected", + "SendError", "InsufficientCapacity", "ReceiveError", "Empty", + "timestamp", "uuid4", "UNLIMITED", "AMQP_PORT", "AMQPS_PORT"] diff --git a/python/qpid/mimetype.py b/python/qpid/mimetype.py new file mode 100644 index 0000000000..f512996b9f --- /dev/null +++ b/python/qpid/mimetype.py @@ -0,0 +1,106 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import re, rfc822 +from lexer import Lexicon, LexError +from parser import Parser, ParseError + +l = Lexicon() + +LPAREN = l.define("LPAREN", r"\(") +RPAREN = l.define("LPAREN", r"\)") +SLASH = l.define("SLASH", r"/") +SEMI = l.define("SEMI", r";") +EQUAL = l.define("EQUAL", r"=") +TOKEN = l.define("TOKEN", r'[^()<>@,;:\\"/\[\]?= ]+') +STRING = l.define("STRING", r'"(?:[^\\"]|\\.)*"') +WSPACE = l.define("WSPACE", r"[ \n\r\t]+") +EOF = l.eof("EOF") + +LEXER = l.compile() + +def lex(st): + return LEXER.lex(st) + +class MimeTypeParser(Parser): + + def __init__(self, tokens): + Parser.__init__(self, [t for t in tokens if t.type is not WSPACE]) + + def parse(self): + result = self.mimetype() + self.eat(EOF) + return result + + def mimetype(self): + self.remove_comments() + self.reset() + + type = self.eat(TOKEN).value.lower() + self.eat(SLASH) + subtype = self.eat(TOKEN).value.lower() + + params = [] + while True: + if self.matches(SEMI): + params.append(self.parameter()) + else: + break + + return type, subtype, params + + def remove_comments(self): + while True: + self.eat_until(LPAREN, EOF) + if self.matches(LPAREN): + self.remove(*self.comment()) + else: + break + + def comment(self): + start = self.eat(LPAREN) + + while True: + self.eat_until(LPAREN, RPAREN) + if self.matches(LPAREN): + self.comment() + else: + break + + end = self.eat(RPAREN) + return start, end + + def parameter(self): + self.eat(SEMI) + name = self.eat(TOKEN).value + self.eat(EQUAL) + value = self.value() + return name, value + + def value(self): + if self.matches(TOKEN): + return self.eat().value + elif self.matches(STRING): + return rfc822.unquote(self.eat().value) + else: + raise ParseError(self.next(), TOKEN, STRING) + +def parse(addr): + return MimeTypeParser(lex(addr)).parse() + +__all__ = ["parse", "ParseError"] diff --git a/python/qpid/ops.py b/python/qpid/ops.py new file mode 100644 index 0000000000..a8ba826857 --- /dev/null +++ b/python/qpid/ops.py @@ -0,0 +1,280 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import os, mllib, cPickle as pickle +from util import fill + +class Primitive(object): + pass + +class Enum(object): + pass + +class Field: + + def __init__(self, name, type, default=None): + self.name = name + self.type = type + self.default = default + + def __repr__(self): + return "%s: %s" % (self.name, self.type) + +class Compound(object): + + UNENCODED=[] + + def __init__(self, *args, **kwargs): + args = list(args) + for f in self.ARGS: + if args: + a = args.pop(0) + else: + a = kwargs.pop(f.name, f.default) + setattr(self, f.name, a) + if args: + raise TypeError("%s takes at most %s arguments (%s given))" % + (self.__class__.__name__, len(self.ARGS), + len(self.ARGS) + len(args))) + if kwargs: + raise TypeError("got unexpected keyword argument '%s'" % kwargs.keys()[0]) + + def fields(self): + result = {} + for f in self.FIELDS: + result[f.name] = getattr(self, f.name) + return result + + def args(self): + result = {} + for f in self.ARGS: + result[f.name] = getattr(self, f.name) + return result + + def __getitem__(self, attr): + return getattr(self, attr) + + def __setitem__(self, attr, value): + setattr(self, attr, value) + + def dispatch(self, target, *args): + handler = "do_%s" % self.NAME + getattr(target, handler)(self, *args) + + def __repr__(self, extras=()): + return "%s(%s)" % (self.__class__.__name__, + ", ".join(["%s=%r" % (f.name, getattr(self, f.name)) + for f in self.ARGS + if getattr(self, f.name) != f.default])) + +class Command(Compound): + UNENCODED=[Field("channel", "uint16", 0), + Field("id", "sequence-no", None), + Field("sync", "bit", False), + Field("headers", None, None), + Field("payload", None, None)] + +class Control(Compound): + UNENCODED=[Field("channel", "uint16", 0)] + +def pythonize(st): + if st is None: + return None + else: + return str(st.replace("-", "_")) + +def pydoc(op, children=()): + doc = "\n\n".join([fill(p.text(), 0) for p in op.query["doc"]]) + for ch in children: + doc += "\n\n " + pythonize(ch["@name"]) + " -- " + str(ch["@label"]) + ch_descs ="\n\n".join([fill(p.text(), 4) for p in ch.query["doc"]]) + if ch_descs: + doc += "\n\n" + ch_descs + return doc + +def studly(st): + return "".join([p.capitalize() for p in st.split("-")]) + +def klass(nd): + while nd.parent is not None: + if hasattr(nd.parent, "name") and nd.parent.name == "class": + return nd.parent + else: + nd = nd.parent + +def included(nd): + cls = klass(nd) + if cls is None: + return True + else: + return cls["@name"] not in ("file", "stream") + +def num(s): + if s: return int(s, 0) + +def code(nd): + c = num(nd["@code"]) + if c is None: + return None + else: + cls = klass(nd) + if cls is None: + return c + else: + return c | (num(cls["@code"]) << 8) + +def default(f): + if f["@type"] == "bit": + return False + else: + return None + +def make_compound(decl, base): + dict = {} + fields = decl.query["field"] + dict["__doc__"] = pydoc(decl, fields) + dict["NAME"] = pythonize(decl["@name"]) + dict["SIZE"] = num(decl["@size"]) + dict["CODE"] = code(decl) + dict["PACK"] = num(decl["@pack"]) + dict["FIELDS"] = [Field(pythonize(f["@name"]), resolve(f), default(f)) for f in fields] + dict["ARGS"] = dict["FIELDS"] + base.UNENCODED + return str(studly(decl["@name"])), (base,), dict + +def make_restricted(decl): + name = pythonize(decl["@name"]) + dict = {} + choices = decl.query["choice"] + dict["__doc__"] = pydoc(decl, choices) + dict["NAME"] = name + dict["TYPE"] = str(decl.parent["@type"]) + values = [] + for ch in choices: + val = int(ch["@value"], 0) + dict[pythonize(ch["@name"])] = val + values.append(val) + dict["VALUES"] = values + return name, (Enum,), dict + +def make_type(decl): + name = pythonize(decl["@name"]) + dict = {} + dict["__doc__"] = pydoc(decl) + dict["NAME"] = name + dict["CODE"] = code(decl) + return str(studly(decl["@name"])), (Primitive,), dict + +def make_command(decl): + decl.set_attr("name", "%s-%s" % (decl.parent["@name"], decl["@name"])) + decl.set_attr("size", "0") + decl.set_attr("pack", "2") + name, bases, dict = make_compound(decl, Command) + dict["RESULT"] = pythonize(decl["result/@type"]) or pythonize(decl["result/struct/@name"]) + return name, bases, dict + +def make_control(decl): + decl.set_attr("name", "%s-%s" % (decl.parent["@name"], decl["@name"])) + decl.set_attr("size", "0") + decl.set_attr("pack", "2") + return make_compound(decl, Control) + +def make_struct(decl): + return make_compound(decl, Compound) + +def make_enum(decl): + decl.set_attr("name", decl.parent["@name"]) + return make_restricted(decl) + + +vars = globals() + +def make(nd): + return vars["make_%s" % nd.name](nd) + +from qpid_config import amqp_spec as file +pclfile = "%s.ops.pcl" % file + +if os.path.exists(pclfile) and \ + os.path.getmtime(pclfile) > os.path.getmtime(file): + f = open(pclfile, "r") + types = pickle.load(f) + f.close() +else: + spec = mllib.xml_parse(file) + + def qualify(nd, field="@name"): + cls = klass(nd) + if cls is None: + return pythonize(nd[field]) + else: + return pythonize("%s.%s" % (cls["@name"], nd[field])) + + domains = dict([(qualify(d), pythonize(d["@type"])) + for d in spec.query["amqp/domain", included] + \ + spec.query["amqp/class/domain", included]]) + + def resolve(nd): + candidates = qualify(nd, "@type"), pythonize(nd["@type"]) + for c in candidates: + if domains.has_key(c): + while domains.has_key(c): + c = domains[c] + return c + else: + return c + + type_decls = \ + spec.query["amqp/class/command", included] + \ + spec.query["amqp/class/control", included] + \ + spec.query["amqp/class/command/result/struct", included] + \ + spec.query["amqp/class/struct", included] + \ + spec.query["amqp/class/domain/enum", included] + \ + spec.query["amqp/domain/enum", included] + \ + spec.query["amqp/type"] + types = [make(nd) for nd in type_decls] + + if os.access(os.path.dirname(os.path.abspath(pclfile)), os.W_OK): + f = open(pclfile, "w") + pickle.dump(types, f) + f.close() + +ENUMS = {} +PRIMITIVE = {} +COMPOUND = {} +COMMANDS = {} +CONTROLS = {} + +for name, bases, _dict in types: + t = type(name, bases, _dict) + vars[name] = t + + if issubclass(t, Command): + COMMANDS[t.NAME] = t + COMMANDS[t.CODE] = t + elif issubclass(t, Control): + CONTROLS[t.NAME] = t + CONTROLS[t.CODE] = t + elif issubclass(t, Compound): + COMPOUND[t.NAME] = t + if t.CODE is not None: + COMPOUND[t.CODE] = t + elif issubclass(t, Primitive): + PRIMITIVE[t.NAME] = t + PRIMITIVE[t.CODE] = t + elif issubclass(t, Enum): + ENUMS[t.NAME] = t diff --git a/python/qpid/parser.py b/python/qpid/parser.py new file mode 100644 index 0000000000..233f0a8469 --- /dev/null +++ b/python/qpid/parser.py @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +class ParseError(Exception): + + def __init__(self, token, *expected): + line, ln, col = token.line_info() + exp = ", ".join(map(str, expected)) + if len(expected) > 1: + exp = "(%s)" % exp + if expected: + msg = "expecting %s, got %s line:%s,%s:%s" % (exp, token, ln, col, line) + else: + msg = "unexpected token %s line:%s,%s:%s" % (token, ln, col, line) + Exception.__init__(self, msg) + self.token = token + self.expected = expected + +class Parser: + + def __init__(self, tokens): + self.tokens = tokens + self.idx = 0 + + def next(self): + return self.tokens[self.idx] + + def matches(self, *types): + return self.next().type in types + + def eat(self, *types): + if types and not self.matches(*types): + raise ParseError(self.next(), *types) + else: + t = self.next() + self.idx += 1 + return t + + def eat_until(self, *types): + result = [] + while not self.matches(*types): + result.append(self.eat()) + return result + + def remove(self, start, end): + start_idx = self.tokens.index(start) + end_idx = self.tokens.index(end) + 1 + del self.tokens[start_idx:end_idx] + self.idx -= end_idx - start_idx + + def reset(self): + self.idx = 0 diff --git a/python/qpid/peer.py b/python/qpid/peer.py index 0932efeab3..2bc9844351 100644 --- a/python/qpid/peer.py +++ b/python/qpid/peer.py @@ -25,7 +25,7 @@ incoming method frames to a delegate. """ import thread, threading, traceback, socket, sys, logging -from connection08 import EOF, Method, Header, Body, Request, Response +from connection08 import EOF, Method, Header, Body, Request, Response, VersionError from message import Message from queue import Queue, Closed as QueueClosed from content import Content @@ -95,6 +95,8 @@ class Peer: break ch = self.channel(frame.channel) ch.receive(frame, self.work) + except VersionError, e: + self.closed(e) except: self.fatal() @@ -193,11 +195,7 @@ class Channel: self.futures = {} self.control_queue = Queue(0)#used for incoming methods that appas may want to handle themselves - # Use reliable framing if version == 0-9. - if spec.major == 0 and spec.minor == 9: - self.invoker = self.invoke_reliable - else: - self.invoker = self.invoke_method + self.invoker = self.invoke_method self.use_execution_layer = (spec.major == 0 and spec.minor == 10) or (spec.major == 99 and spec.minor == 0) self.synchronous = True @@ -464,6 +462,6 @@ class IncomingCompletion: #TODO: record and manage the ranges properly range = [mark, mark] if (self.mark == -1):#hack until wraparound is implemented - self.channel.execution_complete(cumulative_execution_mark=0xFFFFFFFF, ranged_execution_set=range) + self.channel.execution_complete(cumulative_execution_mark=0xFFFFFFFFL, ranged_execution_set=range) else: self.channel.execution_complete(cumulative_execution_mark=self.mark, ranged_execution_set=range) diff --git a/python/qpid/queue.py b/python/qpid/queue.py index c9f4d1d1d0..63a7684843 100644 --- a/python/qpid/queue.py +++ b/python/qpid/queue.py @@ -63,7 +63,9 @@ class Queue(BaseQueue): if listener is None: if self.thread is not None: self.put(Queue.STOP) - self.thread.join() + # loop and timed join permit keyboard interrupts to work + while self.thread.isAlive(): + self.thread.join(3) self.thread = None self.listener = listener diff --git a/python/qpid/selector.py b/python/qpid/selector.py new file mode 100644 index 0000000000..ca5946c3f9 --- /dev/null +++ b/python/qpid/selector.py @@ -0,0 +1,139 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import atexit, time +from compat import select, set, selectable_waiter +from threading import Thread, Lock + +class Acceptor: + + def __init__(self, sock, handler): + self.sock = sock + self.handler = handler + + def fileno(self): + return self.sock.fileno() + + def reading(self): + return True + + def writing(self): + return False + + def readable(self): + sock, addr = self.sock.accept() + self.handler(sock) + +class Selector: + + lock = Lock() + DEFAULT = None + + @staticmethod + def default(): + Selector.lock.acquire() + try: + if Selector.DEFAULT is None: + sel = Selector() + atexit.register(sel.stop) + sel.start() + Selector.DEFAULT = sel + return Selector.DEFAULT + finally: + Selector.lock.release() + + def __init__(self): + self.selectables = set() + self.reading = set() + self.writing = set() + self.waiter = selectable_waiter() + self.reading.add(self.waiter) + self.stopped = False + self.thread = None + + def wakeup(self): + self.waiter.wakeup() + + def register(self, selectable): + self.selectables.add(selectable) + self.modify(selectable) + + def _update(self, selectable): + if selectable.reading(): + self.reading.add(selectable) + else: + self.reading.discard(selectable) + if selectable.writing(): + self.writing.add(selectable) + else: + self.writing.discard(selectable) + return selectable.timing() + + def modify(self, selectable): + self._update(selectable) + self.wakeup() + + def unregister(self, selectable): + self.reading.discard(selectable) + self.writing.discard(selectable) + self.selectables.discard(selectable) + self.wakeup() + + def start(self): + self.stopped = False + self.thread = Thread(target=self.run) + self.thread.setDaemon(True) + self.thread.start(); + + def run(self): + while not self.stopped: + wakeup = None + for sel in self.selectables.copy(): + t = self._update(sel) + if t is not None: + if wakeup is None: + wakeup = t + else: + wakeup = min(wakeup, t) + + if wakeup is None: + timeout = None + else: + timeout = max(0, wakeup - time.time()) + + rd, wr, ex = select(self.reading, self.writing, (), timeout) + + for sel in wr: + if sel.writing(): + sel.writeable() + + for sel in rd: + if sel.reading(): + sel.readable() + + now = time.time() + for sel in self.selectables.copy(): + w = sel.timing() + if w is not None and now > w: + sel.timeout() + + def stop(self, timeout=None): + self.stopped = True + self.wakeup() + self.thread.join(timeout) + self.thread = None diff --git a/python/qpid/session.py b/python/qpid/session.py index 2f70461ab6..2f1bd81bd4 100644 --- a/python/qpid/session.py +++ b/python/qpid/session.py @@ -18,12 +18,13 @@ # from threading import Condition, RLock, Lock, currentThread -from invoker import Invoker +from spec import SPEC +from generator import command_invoker from datatypes import RangedSet, Struct, Future from codec010 import StringCodec -from assembler import Segment from queue import Queue from datatypes import Message, serial +from ops import Command, MessageTransfer from util import wait, notify from exceptions import * from logging import getLogger @@ -43,12 +44,12 @@ def server(*args): INCOMPLETE = object() -class Session(Invoker): +class Session(command_invoker()): - def __init__(self, name, spec, auto_sync=True, timeout=10, delegate=client): + def __init__(self, name, auto_sync=True, timeout=10, delegate=client): self.name = name - self.spec = spec self.auto_sync = auto_sync + self.need_sync = True self.timeout = timeout self.channel = None self.invoke_lock = Lock() @@ -66,8 +67,6 @@ class Session(Invoker): self.results = {} self.exceptions = [] - self.assembly = None - self.delegate = delegate(self) def incoming(self, destination): @@ -94,7 +93,7 @@ class Session(Invoker): ch = self.channel if ch is not None and currentThread() == ch.connection.thread: raise SessionException("deadlock detected") - if not self.auto_sync: + if self.need_sync: self.execution_sync(sync=True) last = self.sender.next_id - 1 if not wait(self.condition, lambda: @@ -133,82 +132,50 @@ class Session(Invoker): finally: self.lock.release() - def resolve_method(self, name): - cmd = self.spec.instructions.get(name) - if cmd is not None and cmd.track == self.spec["track.command"].value: - return self.METHOD, cmd + def invoke(self, op, args, kwargs): + if issubclass(op, Command): + self.invoke_lock.acquire() + try: + return self.do_invoke(op, args, kwargs) + finally: + self.invoke_lock.release() else: - # XXX - for st in self.spec.structs.values(): - if st.name == name: - return self.METHOD, st - if self.spec.structs_by_name.has_key(name): - return self.METHOD, self.spec.structs_by_name[name] - if self.spec.enums.has_key(name): - return self.VALUE, self.spec.enums[name] - return self.ERROR, None - - def invoke(self, type, args, kwargs): - # XXX - if not hasattr(type, "track"): - return type.new(args, kwargs) - - self.invoke_lock.acquire() - try: - return self.do_invoke(type, args, kwargs) - finally: - self.invoke_lock.release() + return op(*args, **kwargs) - def do_invoke(self, type, args, kwargs): + def do_invoke(self, op, args, kwargs): if self._closing: raise SessionClosed() - if self.channel == None: + ch = self.channel + if ch == None: raise SessionDetached() - if type.segments: - if len(args) == len(type.fields) + 1: + if op == MessageTransfer: + if len(args) == len(op.FIELDS) + 1: message = args[-1] args = args[:-1] else: message = kwargs.pop("message", None) - else: - message = None - - hdr = Struct(self.spec["session.header"]) - hdr.sync = self.auto_sync or kwargs.pop("sync", False) + if message is not None: + kwargs["headers"] = message.headers + kwargs["payload"] = message.body - cmd = type.new(args, kwargs) - sc = StringCodec(self.spec) - sc.write_command(hdr, cmd) + cmd = op(*args, **kwargs) + cmd.sync = self.auto_sync or cmd.sync + self.need_sync = not cmd.sync + cmd.channel = ch.id - seg = Segment(True, (message == None or - (message.headers == None and message.body == None)), - type.segment_type, type.track, self.channel.id, sc.encoded) - - if type.result: + if op.RESULT: result = Future(exception=SessionException) self.results[self.sender.next_id] = result - self.send(seg) - - log.debug("SENT %s %s %s", seg.id, hdr, cmd) - - if message != None: - if message.headers != None: - sc = StringCodec(self.spec) - for st in message.headers: - sc.write_struct32(st) - seg = Segment(False, message.body == None, self.spec["segment_type.header"].value, - type.track, self.channel.id, sc.encoded) - self.send(seg) - if message.body != None: - seg = Segment(False, True, self.spec["segment_type.body"].value, - type.track, self.channel.id, message.body) - self.send(seg) - msg.debug("SENT %s", message) - - if type.result: + self.send(cmd) + + log.debug("SENT %s", cmd) + if op == MessageTransfer: + msg.debug("SENT %s", cmd) + + if op.RESULT: if self.auto_sync: return result.get(self.timeout) else: @@ -216,81 +183,47 @@ class Session(Invoker): elif self.auto_sync: self.sync(self.timeout) - def received(self, seg): - self.receiver.received(seg) - if seg.first: - assert self.assembly == None - self.assembly = [] - self.assembly.append(seg) - if seg.last: - self.dispatch(self.assembly) - self.assembly = None + def received(self, cmd): + self.receiver.received(cmd) + self.dispatch(cmd) - def dispatch(self, assembly): - segments = assembly[:] + def dispatch(self, cmd): + log.debug("RECV %s", cmd) - hdr, cmd = assembly.pop(0).decode(self.spec) - log.debug("RECV %s %s %s", cmd.id, hdr, cmd) - - args = [] - - for st in cmd._type.segments: - if assembly: - seg = assembly[0] - if seg.type == st.segment_type: - args.append(seg.decode(self.spec)) - assembly.pop(0) - continue - args.append(None) - - assert len(assembly) == 0 - - attr = cmd._type.qname.replace(".", "_") - result = getattr(self.delegate, attr)(cmd, *args) - - if cmd._type.result: + result = getattr(self.delegate, cmd.NAME)(cmd) + if result is INCOMPLETE: + return + elif result is not None: self.execution_result(cmd.id, result) - if result is not INCOMPLETE: - for seg in segments: - self.receiver.completed(seg) - # XXX: don't forget to obey sync for manual completion as well - if hdr.sync: - self.channel.session_completed(self.receiver._completed) + self.receiver.completed(cmd) + # XXX: don't forget to obey sync for manual completion as well + if cmd.sync: + self.channel.session_completed(self.receiver._completed) - def send(self, seg): - self.sender.send(seg) - - def __str__(self): - return '<Session: %s, %s>' % (self.name, self.channel) + def send(self, cmd): + self.sender.send(cmd) def __repr__(self): - return str(self) + return '<Session: %s, %s>' % (self.name, self.channel) class Receiver: def __init__(self, session): self.session = session self.next_id = None - self.next_offset = None self._completed = RangedSet() - def received(self, seg): - if self.next_id == None or self.next_offset == None: + def received(self, cmd): + if self.next_id == None: raise Exception("todo") - seg.id = self.next_id - seg.offset = self.next_offset - if seg.last: - self.next_id += 1 - self.next_offset = 0 - else: - self.next_offset += len(seg.payload) + cmd.id = self.next_id + self.next_id += 1 - def completed(self, seg): - if seg.id == None: - raise ValueError("cannot complete unidentified segment") - if seg.last: - self._completed.add(seg.id) + def completed(self, cmd): + if cmd.id == None: + raise ValueError("cannot complete unidentified command") + self._completed.add(cmd.id) def known_completed(self, commands): completed = RangedSet() @@ -307,30 +240,27 @@ class Sender: def __init__(self, session): self.session = session self.next_id = serial(0) - self.next_offset = 0 - self.segments = [] + self.commands = [] self._completed = RangedSet() - def send(self, seg): - seg.id = self.next_id - seg.offset = self.next_offset - if seg.last: - self.next_id += 1 - self.next_offset = 0 - else: - self.next_offset += len(seg.payload) - self.segments.append(seg) + def send(self, cmd): + ch = self.session.channel + if ch is None: + raise SessionDetached() + cmd.id = self.next_id + self.next_id += 1 if self.session.send_id: self.session.send_id = False - self.session.channel.session_command_point(seg.id, seg.offset) - self.session.channel.connection.write_segment(seg) + ch.session_command_point(cmd.id, 0) + self.commands.append(cmd) + ch.connection.write_op(cmd) def completed(self, commands): idx = 0 - while idx < len(self.segments): - seg = self.segments[idx] - if seg.id in commands: - del self.segments[idx] + while idx < len(self.commands): + cmd = self.commands[idx] + if cmd.id in commands: + del self.commands[idx] else: idx += 1 for range in commands.ranges: @@ -344,8 +274,9 @@ class Incoming(Queue): self.destination = destination def start(self): - for unit in self.session.credit_unit.values(): - self.session.message_flow(self.destination, unit, 0xFFFFFFFF) + self.session.message_set_flow_mode(self.destination, self.session.flow_mode.credit) + for unit in self.session.credit_unit.VALUES: + self.session.message_flow(self.destination, unit, 0xFFFFFFFFL) def stop(self): self.session.message_cancel(self.destination) @@ -368,9 +299,9 @@ class Delegate: class Client(Delegate): - def message_transfer(self, cmd, headers, body): - m = Message(body) - m.headers = headers + def message_transfer(self, cmd): + m = Message(cmd.payload) + m.headers = cmd.headers m.id = cmd.id messages = self.session.incoming(cmd.destination) messages.put(m) diff --git a/python/qpid/spec.py b/python/qpid/spec.py index e6d914044c..e9bfef1fa6 100644 --- a/python/qpid/spec.py +++ b/python/qpid/spec.py @@ -29,7 +29,7 @@ class so that the generated code can be reused in a variety of situations. """ -import os, mllib, spec08, spec010 +import os, mllib, spec08 def default(): try: @@ -54,6 +54,8 @@ def load(specfile, *errata): minor = doc["amqp/@minor"] if major == "0" and minor == "10": - return spec010.load(specfile, *errata) + return None else: return spec08.load(specfile, *errata) + +SPEC = load(default()) diff --git a/python/qpid/spec010.py b/python/qpid/spec010.py deleted file mode 100644 index 23966e6176..0000000000 --- a/python/qpid/spec010.py +++ /dev/null @@ -1,691 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -import os, cPickle, datatypes -from codec010 import StringCodec -from util import mtime, fill - -class Node: - - def __init__(self, children): - self.children = children - self.named = {} - self.docs = [] - self.rules = [] - - def register(self): - for ch in self.children: - ch.register(self) - - def resolve(self): - for ch in self.children: - ch.resolve() - - def __getitem__(self, name): - path = name.split(".", 1) - nd = self.named - for step in path: - nd = nd[step] - return nd - - def __iter__(self): - return iter(self.children) - -class Anonymous: - - def __init__(self, children): - self.children = children - - def register(self, node): - for ch in self.children: - ch.register(node) - - def resolve(self): - for ch in self.children: - ch.resolve() - -class Named: - - def __init__(self, name): - self.name = name - self.qname = None - - def register(self, node): - self.spec = node.spec - self.klass = node.klass - node.named[self.name] = self - if node.qname: - self.qname = "%s.%s" % (node.qname, self.name) - else: - self.qname = self.name - - def __str__(self): - return self.qname - - def __repr__(self): - return str(self) - -class Lookup: - - def lookup(self, name): - value = None - if self.klass: - try: - value = self.klass[name] - except KeyError: - pass - if not value: - value = self.spec[name] - return value - -class Coded: - - def __init__(self, code): - self.code = code - -class Constant(Named, Node): - - def __init__(self, name, value, children): - Named.__init__(self, name) - Node.__init__(self, children) - self.value = value - - def register(self, node): - Named.register(self, node) - node.constants.append(self) - Node.register(self) - -class Type(Named, Node): - - def __init__(self, name, children): - Named.__init__(self, name) - Node.__init__(self, children) - - def is_present(self, value): - return value != None - - def register(self, node): - Named.register(self, node) - Node.register(self) - -class Primitive(Coded, Type): - - def __init__(self, name, code, fixed, variable, children): - Coded.__init__(self, code) - Type.__init__(self, name, children) - self.fixed = fixed - self.variable = variable - - def register(self, node): - Type.register(self, node) - if self.code is not None: - self.spec.types[self.code] = self - - def is_present(self, value): - if self.fixed == 0: - return value - else: - return Type.is_present(self, value) - - def encode(self, codec, value): - getattr(codec, "write_%s" % self.name)(value) - - def decode(self, codec): - return getattr(codec, "read_%s" % self.name)() - -class Domain(Type, Lookup): - - def __init__(self, name, type, children): - Type.__init__(self, name, children) - self.type = type - self.choices = {} - - def resolve(self): - self.type = self.lookup(self.type) - Node.resolve(self) - - def encode(self, codec, value): - self.type.encode(codec, value) - - def decode(self, codec): - return self.type.decode(codec) - -class Enum: - - def __init__(self, name): - self.name = name - self._names = () - self._values = () - - def values(self): - return self._values - - def __repr__(self): - return "%s(%s)" % (self.name, ", ".join(self._names)) - -class Choice(Named, Node): - - def __init__(self, name, value, children): - Named.__init__(self, name) - Node.__init__(self, children) - self.value = value - - def register(self, node): - Named.register(self, node) - node.choices[self.value] = self - Node.register(self) - try: - enum = node.spec.enums[node.name] - except KeyError: - enum = Enum(node.name) - node.spec.enums[node.name] = enum - setattr(enum, self.name, self.value) - enum._names += (self.name,) - enum._values += (self.value,) - -class Composite(Type, Coded): - - def __init__(self, name, label, code, size, pack, children): - Coded.__init__(self, code) - Type.__init__(self, name, children) - self.label = label - self.fields = [] - self.size = size - self.pack = pack - - def new(self, args, kwargs): - return datatypes.Struct(self, *args, **kwargs) - - def decode(self, codec): - codec.read_size(self.size) - if self.code is not None: - code = codec.read_uint16() - assert self.code == code - return datatypes.Struct(self, **self.decode_fields(codec)) - - def decode_fields(self, codec): - flags = 0 - for i in range(self.pack): - flags |= (codec.read_uint8() << 8*i) - - result = {} - - for i in range(len(self.fields)): - f = self.fields[i] - if flags & (0x1 << i): - result[f.name] = f.type.decode(codec) - else: - result[f.name] = None - return result - - def encode(self, codec, value): - sc = StringCodec(self.spec) - if self.code is not None: - sc.write_uint16(self.code) - self.encode_fields(sc, value) - codec.write_size(self.size, len(sc.encoded)) - codec.write(sc.encoded) - - def encode_fields(self, codec, values): - flags = 0 - for i in range(len(self.fields)): - f = self.fields[i] - if f.type.is_present(values[f.name]): - flags |= (0x1 << i) - for i in range(self.pack): - codec.write_uint8((flags >> 8*i) & 0xFF) - for i in range(len(self.fields)): - f = self.fields[i] - if flags & (0x1 << i): - f.type.encode(codec, values[f.name]) - - def docstring(self): - docs = [] - if self.label: - docs.append(self.label) - docs += [d.text for d in self.docs] - s = "\n\n".join([fill(t, 2) for t in docs]) - for f in self.fields: - fdocs = [] - if f.label: - fdocs.append(f.label) - else: - fdocs.append("") - fdocs += [d.text for d in f.docs] - s += "\n\n" + "\n\n".join([fill(fdocs[0], 4, f.name)] + - [fill(t, 4) for t in fdocs[1:]]) - return s - - -class Field(Named, Node, Lookup): - - def __init__(self, name, label, type, children): - Named.__init__(self, name) - Node.__init__(self, children) - self.label = label - self.type = type - self.exceptions = [] - - def default(self): - return None - - def register(self, node): - Named.register(self, node) - node.fields.append(self) - Node.register(self) - - def resolve(self): - self.type = self.lookup(self.type) - Node.resolve(self) - - def __str__(self): - return "%s: %s" % (self.qname, self.type.qname) - -class Struct(Composite): - - def register(self, node): - Composite.register(self, node) - if self.code is not None: - self.spec.structs[self.code] = self - self.spec.structs_by_name[self.name] = self - self.pyname = self.name - self.pydoc = self.docstring() - - def __str__(self): - fields = ",\n ".join(["%s: %s" % (f.name, f.type.qname) - for f in self.fields]) - return "%s {\n %s\n}" % (self.qname, fields) - -class Segment: - - def __init__(self): - self.segment_type = None - - def register(self, node): - self.spec = node.spec - self.klass = node.klass - node.segments.append(self) - Node.register(self) - -class Instruction(Composite, Segment): - - def __init__(self, name, label, code, children): - Composite.__init__(self, name, label, code, 0, 2, children) - Segment.__init__(self) - self.track = None - self.handlers = [] - - def __str__(self): - return "%s(%s)" % (self.qname, ", ".join(["%s: %s" % (f.name, f.type.qname) - for f in self.fields])) - - def register(self, node): - Composite.register(self, node) - self.pyname = self.qname.replace(".", "_") - self.pydoc = self.docstring() - self.spec.instructions[self.pyname] = self - -class Control(Instruction): - - def __init__(self, name, code, label, children): - Instruction.__init__(self, name, code, label, children) - self.response = None - - def register(self, node): - Instruction.register(self, node) - node.controls.append(self) - self.spec.controls[self.code] = self - self.segment_type = self.spec["segment_type.control"].value - self.track = self.spec["track.control"].value - -class Command(Instruction): - - def __init__(self, name, label, code, children): - Instruction.__init__(self, name, label, code, children) - self.result = None - self.exceptions = [] - self.segments = [] - - def register(self, node): - Instruction.register(self, node) - node.commands.append(self) - self.spec.commands[self.code] = self - self.segment_type = self.spec["segment_type.command"].value - self.track = self.spec["track.command"].value - -class Header(Segment, Node): - - def __init__(self, children): - Segment.__init__(self) - Node.__init__(self, children) - self.entries = [] - - def register(self, node): - Segment.register(self, node) - self.segment_type = self.spec["segment_type.header"].value - Node.register(self) - -class Entry(Lookup): - - def __init__(self, type): - self.type = type - - def register(self, node): - self.spec = node.spec - self.klass = node.klass - node.entries.append(self) - - def resolve(self): - self.type = self.lookup(self.type) - -class Body(Segment, Node): - - def __init__(self, children): - Segment.__init__(self) - Node.__init__(self, children) - - def register(self, node): - Segment.register(self, node) - self.segment_type = self.spec["segment_type.body"].value - Node.register(self) - - def resolve(self): pass - -class Class(Named, Coded, Node): - - def __init__(self, name, code, children): - Named.__init__(self, name) - Coded.__init__(self, code) - Node.__init__(self, children) - self.controls = [] - self.commands = [] - - def register(self, node): - Named.register(self, node) - self.klass = self - node.classes.append(self) - Node.register(self) - -class Doc: - - def __init__(self, type, title, text): - self.type = type - self.title = title - self.text = text - - def register(self, node): - node.docs.append(self) - - def resolve(self): pass - -class Role(Named, Node): - - def __init__(self, name, children): - Named.__init__(self, name) - Node.__init__(self, children) - - def register(self, node): - Named.register(self, node) - Node.register(self) - -class Rule(Named, Node): - - def __init__(self, name, children): - Named.__init__(self, name) - Node.__init__(self, children) - - def register(self, node): - Named.register(self, node) - node.rules.append(self) - Node.register(self) - -class Exception(Named, Node): - - def __init__(self, name, error_code, children): - Named.__init__(self, name) - Node.__init__(self, children) - self.error_code = error_code - - def register(self, node): - Named.register(self, node) - node.exceptions.append(self) - Node.register(self) - -class Spec(Node): - - ENCODINGS = { - basestring: "vbin16", - int: "int64", - long: "int64", - float: "float", - None.__class__: "void", - list: "list", - tuple: "list", - dict: "map" - } - - def __init__(self, major, minor, port, children): - Node.__init__(self, children) - self.major = major - self.minor = minor - self.port = port - self.constants = [] - self.classes = [] - self.types = {} - self.qname = None - self.spec = self - self.klass = None - self.instructions = {} - self.controls = {} - self.commands = {} - self.structs = {} - self.structs_by_name = {} - self.enums = {} - - def encoding(self, klass): - if Spec.ENCODINGS.has_key(klass): - return self.named[Spec.ENCODINGS[klass]] - for base in klass.__bases__: - result = self.encoding(base) - if result != None: - return result - -class Implement: - - def __init__(self, handle): - self.handle = handle - - def register(self, node): - node.handlers.append(self.handle) - - def resolve(self): pass - -class Response(Node): - - def __init__(self, name, children): - Node.__init__(self, children) - self.name = name - - def register(self, node): - Node.register(self) - -class Result(Node, Lookup): - - def __init__(self, type, children): - self.type = type - Node.__init__(self, children) - - def register(self, node): - node.result = self - self.qname = node.qname - self.klass = node.klass - self.spec = node.spec - Node.register(self) - - def resolve(self): - self.type = self.lookup(self.type) - Node.resolve(self) - -import mllib - -def num(s): - if s: return int(s, 0) - -REPLACE = {" ": "_", "-": "_"} -KEYWORDS = {"global": "global_", - "return": "return_"} - -def id(name): - name = str(name) - for key, val in REPLACE.items(): - name = name.replace(key, val) - try: - name = KEYWORDS[name] - except KeyError: - pass - return name - -class Loader: - - def __init__(self): - self.class_code = 0 - - def code(self, nd): - c = num(nd["@code"]) - if c is None: - return None - else: - return c | (self.class_code << 8) - - def list(self, q): - result = [] - for nd in q: - result.append(nd.dispatch(self)) - return result - - def children(self, n): - return self.list(n.query["#tag"]) - - def data(self, d): - return d.data - - def do_amqp(self, a): - return Spec(num(a["@major"]), num(a["@minor"]), num(a["@port"]), - self.children(a)) - - def do_type(self, t): - return Primitive(id(t["@name"]), self.code(t), num(t["@fixed-width"]), - num(t["@variable-width"]), self.children(t)) - - def do_constant(self, c): - return Constant(id(c["@name"]), num(c["@value"]), self.children(c)) - - def do_domain(self, d): - return Domain(id(d["@name"]), id(d["@type"]), self.children(d)) - - def do_enum(self, e): - return Anonymous(self.children(e)) - - def do_choice(self, c): - return Choice(id(c["@name"]), num(c["@value"]), self.children(c)) - - def do_class(self, c): - code = num(c["@code"]) - self.class_code = code - children = self.children(c) - children += self.list(c.query["command/result/struct"]) - self.class_code = 0 - return Class(id(c["@name"]), code, children) - - def do_doc(self, doc): - text = reduce(lambda x, y: x + y, self.list(doc.children)) - return Doc(doc["@type"], doc["@title"], text) - - def do_xref(self, x): - return x["@ref"] - - def do_role(self, r): - return Role(id(r["@name"]), self.children(r)) - - def do_control(self, c): - return Control(id(c["@name"]), c["@label"], self.code(c), self.children(c)) - - def do_rule(self, r): - return Rule(id(r["@name"]), self.children(r)) - - def do_implement(self, i): - return Implement(id(i["@handle"])) - - def do_response(self, r): - return Response(id(r["@name"]), self.children(r)) - - def do_field(self, f): - return Field(id(f["@name"]), f["@label"], id(f["@type"]), self.children(f)) - - def do_struct(self, s): - return Struct(id(s["@name"]), s["@label"], self.code(s), num(s["@size"]), - num(s["@pack"]), self.children(s)) - - def do_command(self, c): - return Command(id(c["@name"]), c["@label"], self.code(c), self.children(c)) - - def do_segments(self, s): - return Anonymous(self.children(s)) - - def do_header(self, h): - return Header(self.children(h)) - - def do_entry(self, e): - return Entry(id(e["@type"])) - - def do_body(self, b): - return Body(self.children(b)) - - def do_result(self, r): - type = r["@type"] - if not type: - type = r["struct/@name"] - return Result(id(type), self.list(r.query["#tag", lambda x: x.name != "struct"])) - - def do_exception(self, e): - return Exception(id(e["@name"]), id(e["@error-code"]), self.children(e)) - -def load(xml): - fname = xml + ".pcl" - - if os.path.exists(fname) and mtime(fname) > mtime(__file__): - file = open(fname, "r") - s = cPickle.load(file) - file.close() - else: - doc = mllib.xml_parse(xml) - s = doc["amqp"].dispatch(Loader()) - s.register() - s.resolve() - - try: - file = open(fname, "w") - except IOError: - file = None - - if file: - cPickle.dump(s, file) - file.close() - - return s diff --git a/python/qpid/testlib.py b/python/qpid/testlib.py index b5aa59f586..1439b892ea 100644 --- a/python/qpid/testlib.py +++ b/python/qpid/testlib.py @@ -21,191 +21,13 @@ # Support library for qpid python tests. # -import sys, re, unittest, os, random, logging, traceback -import qpid.client, qpid.spec +import unittest, traceback, socket +import qpid.client, qmf.console import Queue -from fnmatch import fnmatch -from getopt import getopt, GetoptError from qpid.content import Content from qpid.message import Message - -#0-10 support -from qpid.connection import Connection -from qpid.spec010 import load -from qpid.util import connect - -def findmodules(root): - """Find potential python modules under directory root""" - found = [] - for dirpath, subdirs, files in os.walk(root): - modpath = dirpath.replace(os.sep, '.') - if not re.match(r'\.svn$', dirpath): # Avoid SVN directories - for f in files: - match = re.match(r'(.+)\.py$', f) - if match and f != '__init__.py': - found.append('.'.join([modpath, match.group(1)])) - return found - -def default(value, default): - if (value == None): return default - else: return value - -class TestRunner: - - SPEC_FOLDER = "../specs" - - """Runs unit tests. - - Parses command line arguments, provides utility functions for tests, - runs the selected test suite. - """ - - def _die(self, message = None): - if message: print message - print """ -run-tests [options] [test*] -The name of a test is package.module.ClassName.testMethod -Options: - -?/-h/--help : this message - -s/--spec <spec.xml> : URL of AMQP XML specification or one of these abbreviations: - 0-8 - use the default 0-8 specification. - 0-9 - use the default 0-9 specification. - -e/--errata <errata.xml> : file containing amqp XML errata - -b/--broker [<user>[/<password>]@]<host>[:<port>] : broker to connect to - -v/--verbose : verbose - lists tests as they are run. - -d/--debug : enable debug logging. - -i/--ignore <test> : ignore the named test. - -I/--ignore-file : file containing patterns to ignore. - -S/--skip-self-test : skips the client self tests in the 'tests folder' - -F/--spec-folder : folder that contains the specs to be loaded - """ - sys.exit(1) - - def setBroker(self, broker): - rex = re.compile(r""" - # [ <user> [ / <password> ] @] <host> [ :<port> ] - ^ (?: ([^/]*) (?: / ([^@]*) )? @)? ([^:]+) (?: :([0-9]+))?$""", re.X) - match = rex.match(broker) - if not match: self._die("'%s' is not a valid broker" % (broker)) - self.user, self.password, self.host, self.port = match.groups() - self.port = int(default(self.port, 5672)) - self.user = default(self.user, "guest") - self.password = default(self.password, "guest") - - def ignoreFile(self, filename): - f = file(filename) - for line in f.readlines(): self.ignore.append(line.strip()) - f.close() - - def use08spec(self): - "True if we are running with the old 0-8 spec." - # NB: AMQP 0-8 identifies itself as 8-0 for historical reasons. - return self.spec.major==8 and self.spec.minor==0 - - def _parseargs(self, args): - # Defaults - self.setBroker("localhost") - self.verbose = 1 - self.ignore = [] - self.specfile = "0-8" - self.errata = [] - self.skip_self_test = False - - try: - opts, self.tests = getopt(args, "s:e:b:h?dvSi:I:F:", - ["help", "spec", "errata=", "server", - "verbose", "skip-self-test", "ignore", - "ignore-file", "spec-folder"]) - except GetoptError, e: - self._die(str(e)) - for opt, value in opts: - if opt in ("-?", "-h", "--help"): self._die() - if opt in ("-s", "--spec"): self.specfile = value - if opt in ("-e", "--errata"): self.errata.append(value) - if opt in ("-b", "--broker"): self.setBroker(value) - if opt in ("-v", "--verbose"): self.verbose = 2 - if opt in ("-d", "--debug"): logging.basicConfig(level=logging.DEBUG) - if opt in ("-i", "--ignore"): self.ignore.append(value) - if opt in ("-I", "--ignore-file"): self.ignoreFile(value) - if opt in ("-S", "--skip-self-test"): self.skip_self_test = True - if opt in ("-F", "--spec-folder"): TestRunner.SPEC_FOLDER = value - # Abbreviations for default settings. - if (self.specfile == "0-10"): - self.spec = load(self.get_spec_file("amqp.0-10.xml")) - elif (self.specfile == "0-10-errata"): - self.spec = load(self.get_spec_file("amqp.0-10-qpid-errata.xml")) - else: - if (self.specfile == "0-8"): - self.specfile = self.get_spec_file("amqp.0-8.xml") - elif (self.specfile == "0-9"): - self.specfile = self.get_spec_file("amqp.0-9.xml") - self.errata.append(self.get_spec_file("amqp-errata.0-9.xml")) - - if (self.specfile == None): - self._die("No XML specification provided") - print "Using specification from:", self.specfile - - self.spec = qpid.spec.load(self.specfile, *self.errata) - - if len(self.tests) == 0: - if not self.skip_self_test: - self.tests=findmodules("tests") - if self.use08spec(): - self.tests+=findmodules("tests_0-8") - elif (self.spec.major == 99 and self.spec.minor == 0): - self.tests+=findmodules("tests_0-10_preview") - elif (self.spec.major == 0 and self.spec.minor == 10): - self.tests+=findmodules("tests_0-10") - else: - self.tests+=findmodules("tests_0-9") - - def testSuite(self): - class IgnoringTestSuite(unittest.TestSuite): - def addTest(self, test): - if isinstance(test, unittest.TestCase): - for pattern in testrunner.ignore: - if fnmatch(test.id(), pattern): - return - unittest.TestSuite.addTest(self, test) - - # Use our IgnoringTestSuite in the test loader. - unittest.TestLoader.suiteClass = IgnoringTestSuite - return unittest.defaultTestLoader.loadTestsFromNames(self.tests) - - def run(self, args=sys.argv[1:]): - self._parseargs(args) - runner = unittest.TextTestRunner(descriptions=False, - verbosity=self.verbose) - result = runner.run(self.testSuite()) - - if (self.ignore): - print "=======================================" - print "NOTE: the following tests were ignored:" - for t in self.ignore: print t - print "=======================================" - - return result.wasSuccessful() - - def connect(self, host=None, port=None, spec=None, user=None, password=None, tune_params=None): - """Connect to the broker, returns a qpid.client.Client""" - host = host or self.host - port = port or self.port - spec = spec or self.spec - user = user or self.user - password = password or self.password - client = qpid.client.Client(host, port, spec) - if self.use08spec(): - client.start({"LOGIN": user, "PASSWORD": password}, tune_params=tune_params) - else: - client.start("\x00" + user + "\x00" + password, mechanism="PLAIN", tune_params=tune_params) - return client - - def get_spec_file(self, fname): - return TestRunner.SPEC_FOLDER + os.sep + fname - -# Global instance for tests to call connect. -testrunner = TestRunner() - +from qpid.harness import Skipped +from qpid.exceptions import VersionError class TestBase(unittest.TestCase): """Base class for Qpid test cases. @@ -219,13 +41,16 @@ class TestBase(unittest.TestCase): resources to clean up later. """ + def configure(self, config): + self.config = config + def setUp(self): self.queues = [] self.exchanges = [] self.client = self.connect() self.channel = self.client.channel(1) self.version = (self.client.spec.major, self.client.spec.minor) - if self.version == (8, 0): + if self.version == (8, 0) or self.version == (0, 9): self.channel.channel_open() else: self.channel.session_open() @@ -245,9 +70,26 @@ class TestBase(unittest.TestCase): else: self.client.close() - def connect(self, *args, **keys): + def connect(self, host=None, port=None, user=None, password=None, tune_params=None): """Create a new connction, return the Client object""" - return testrunner.connect(*args, **keys) + host = host or self.config.broker.host + port = port or self.config.broker.port or 5672 + user = user or "guest" + password = password or "guest" + client = qpid.client.Client(host, port) + try: + if client.spec.major == 8 and client.spec.minor == 0: + client.start({"LOGIN": user, "PASSWORD": password}, tune_params=tune_params) + else: + client.start("\x00" + user + "\x00" + password, mechanism="PLAIN", tune_params=tune_params) + except qpid.client.Closed, e: + if isinstance(e.args[0], VersionError): + raise Skipped(e.args[0]) + else: + raise e + except socket.error, e: + raise Skipped(e) + return client def queue_declare(self, channel=None, *args, **keys): channel = channel or self.channel @@ -271,24 +113,15 @@ class TestBase(unittest.TestCase): def consume(self, queueName): """Consume from named queue returns the Queue object.""" - if testrunner.use08spec(): - reply = self.channel.basic_consume(queue=queueName, no_ack=True) - return self.client.queue(reply.consumer_tag) - else: - if not "uniqueTag" in dir(self): self.uniqueTag = 1 - else: self.uniqueTag += 1 - consumer_tag = "tag" + str(self.uniqueTag) - self.channel.message_subscribe(queue=queueName, destination=consumer_tag) - self.channel.message_flow(destination=consumer_tag, unit=0, value=0xFFFFFFFF) - self.channel.message_flow(destination=consumer_tag, unit=1, value=0xFFFFFFFF) - return self.client.queue(consumer_tag) + reply = self.channel.basic_consume(queue=queueName, no_ack=True) + return self.client.queue(reply.consumer_tag) def subscribe(self, channel=None, **keys): channel = channel or self.channel consumer_tag = keys["destination"] channel.message_subscribe(**keys) - channel.message_flow(destination=consumer_tag, unit=0, value=0xFFFFFFFF) - channel.message_flow(destination=consumer_tag, unit=1, value=0xFFFFFFFF) + channel.message_flow(destination=consumer_tag, unit=0, value=0xFFFFFFFFL) + channel.message_flow(destination=consumer_tag, unit=1, value=0xFFFFFFFFL) def assertEmpty(self, queue): """Assert that the queue is empty""" @@ -302,24 +135,14 @@ class TestBase(unittest.TestCase): Publish to exchange and assert queue.get() returns the same message. """ body = self.uniqueString() - if testrunner.use08spec(): - self.channel.basic_publish( - exchange=exchange, - content=Content(body, properties=properties), - routing_key=routing_key) - else: - self.channel.message_transfer( - destination=exchange, - content=Content(body, properties={'application_headers':properties,'routing_key':routing_key})) + self.channel.basic_publish( + exchange=exchange, + content=Content(body, properties=properties), + routing_key=routing_key) msg = queue.get(timeout=1) - if testrunner.use08spec(): - self.assertEqual(body, msg.content.body) - if (properties): - self.assertEqual(properties, msg.content.properties) - else: - self.assertEqual(body, msg.content.body) - if (properties): - self.assertEqual(properties, msg.content['application_headers']) + self.assertEqual(body, msg.content.body) + if (properties): + self.assertEqual(properties, msg.content.properties) def assertPublishConsume(self, queue="", exchange="", routing_key="", properties=None): """ @@ -329,7 +152,7 @@ class TestBase(unittest.TestCase): self.assertPublishGet(self.consume(queue), exchange, routing_key, properties) def assertChannelException(self, expectedCode, message): - if self.version == (8, 0): #or "transitional" in self.client.spec.file: + if self.version == (8, 0) or self.version == (0, 9): if not isinstance(message, Message): self.fail("expected channel_close method, got %s" % (message)) self.assertEqual("channel", message.method.klass.name) self.assertEqual("close", message.method.name) @@ -346,31 +169,58 @@ class TestBase(unittest.TestCase): self.assertEqual("close", message.method.name) self.assertEqual(expectedCode, message.reply_code) +#0-10 support +from qpid.connection import Connection +from qpid.util import connect, ssl, URL + class TestBase010(unittest.TestCase): """ Base class for Qpid test cases. using the final 0-10 spec """ + def configure(self, config): + self.config = config + self.broker = config.broker + self.defines = self.config.defines + def setUp(self): - spec = testrunner.spec - self.conn = Connection(connect(testrunner.host, testrunner.port), spec, - username=testrunner.user, password=testrunner.password) - self.conn.start(timeout=10) + self.conn = self.connect() self.session = self.conn.session("test-session", timeout=10) + self.qmf = None + + def startQmf(self, handler=None): + self.qmf = qmf.console.Session(handler) + self.qmf_broker = self.qmf.addBroker(str(self.broker)) def connect(self, host=None, port=None): - spec = testrunner.spec - conn = Connection(connect(host or testrunner.host, port or testrunner.port), spec) - conn.start(timeout=10) + url = self.broker + if url.scheme == URL.AMQPS: + default_port = 5671 + else: + default_port = 5672 + try: + sock = connect(host or url.host, port or url.port or default_port) + except socket.error, e: + raise Skipped(e) + if url.scheme == URL.AMQPS: + sock = ssl(sock) + conn = Connection(sock, username=url.user or "guest", + password=url.password or "guest") + try: + conn.start(timeout=10) + except VersionError, e: + raise Skipped(e) return conn def tearDown(self): if not self.session.error(): self.session.close(timeout=10) self.conn.close(timeout=10) + if self.qmf: + self.qmf.delBroker(self.qmf_broker) def subscribe(self, session=None, **keys): session = session or self.session consumer_tag = keys["destination"] session.message_subscribe(**keys) - session.message_flow(destination=consumer_tag, unit=0, value=0xFFFFFFFF) - session.message_flow(destination=consumer_tag, unit=1, value=0xFFFFFFFF) + session.message_flow(destination=consumer_tag, unit=0, value=0xFFFFFFFFL) + session.message_flow(destination=consumer_tag, unit=1, value=0xFFFFFFFFL) diff --git a/python/qpid/tests/__init__.py b/python/qpid/tests/__init__.py new file mode 100644 index 0000000000..2f0fcfdf67 --- /dev/null +++ b/python/qpid/tests/__init__.py @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +class Test: + + def __init__(self, name): + self.name = name + + def configure(self, config): + self.config = config + +import address, framing, mimetype, messaging diff --git a/python/qpid/tests/address.py b/python/qpid/tests/address.py new file mode 100644 index 0000000000..7c101eee5e --- /dev/null +++ b/python/qpid/tests/address.py @@ -0,0 +1,199 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from qpid.tests import Test +from qpid.address import lex, parse, ParseError, EOF, ID, NUMBER, SYM, WSPACE +from parser import ParserBase + +class AddressTests(ParserBase, Test): + + EXCLUDE = (WSPACE, EOF) + + def do_lex(self, st): + return lex(st) + + def do_parse(self, st): + return parse(st) + + def valid(self, addr, name=None, subject=None, options=None): + ParserBase.valid(self, addr, (name, subject, options)) + + def testDashInId1(self): + self.lex("foo-bar", ID) + + def testDashInId2(self): + self.lex("foo-3", ID) + + def testDashAlone1(self): + self.lex("foo - bar", ID, SYM, ID) + + def testDashAlone2(self): + self.lex("foo - 3", ID, SYM, NUMBER) + + def testLeadingDash(self): + self.lex("-foo", SYM, ID) + + def testTrailingDash(self): + self.lex("foo-", ID, SYM) + + def testNegativeNum(self): + self.lex("-3", NUMBER) + + def testHash(self): + self.valid("foo/bar.#", "foo", "bar.#") + + def testStar(self): + self.valid("foo/bar.*", "foo", "bar.*") + + def testColon(self): + self.valid("foo.bar/baz.qux:moo:arf", "foo.bar", "baz.qux:moo:arf") + + def testOptions(self): + self.valid("foo.bar/baz.qux:moo:arf; {key: value}", + "foo.bar", "baz.qux:moo:arf", {"key": "value"}) + + def testOptionsTrailingComma(self): + self.valid("name/subject; {key: value,}", "name", "subject", + {"key": "value"}) + + def testSemiSubject(self): + self.valid("foo.bar/'baz.qux;moo:arf'; {key: value}", + "foo.bar", "baz.qux;moo:arf", {"key": "value"}) + + def testCommaSubject(self): + self.valid("foo.bar/baz.qux.{moo,arf}", "foo.bar", "baz.qux.{moo,arf}") + + def testCommaSubjectOptions(self): + self.valid("foo.bar/baz.qux.{moo,arf}; {key: value}", "foo.bar", + "baz.qux.{moo,arf}", {"key": "value"}) + + def testUnbalanced(self): + self.valid("foo.bar/baz.qux.{moo,arf; {key: value}", "foo.bar", + "baz.qux.{moo,arf", {"key": "value"}) + + def testSlashQuote(self): + self.valid("foo.bar\\/baz.qux.{moo,arf; {key: value}", + "foo.bar/baz.qux.{moo,arf", + None, {"key": "value"}) + + def testSlashHexEsc1(self): + self.valid("foo.bar\\x00baz.qux.{moo,arf; {key: value}", + "foo.bar\x00baz.qux.{moo,arf", + None, {"key": "value"}) + + def testSlashHexEsc2(self): + self.valid("foo.bar\\xffbaz.qux.{moo,arf; {key: value}", + "foo.bar\xffbaz.qux.{moo,arf", + None, {"key": "value"}) + + def testSlashHexEsc3(self): + self.valid("foo.bar\\xFFbaz.qux.{moo,arf; {key: value}", + "foo.bar\xFFbaz.qux.{moo,arf", + None, {"key": "value"}) + + def testSlashUnicode1(self): + self.valid("foo.bar\\u1234baz.qux.{moo,arf; {key: value}", + u"foo.bar\u1234baz.qux.{moo,arf", None, {"key": "value"}) + + def testSlashUnicode2(self): + self.valid("foo.bar\\u0000baz.qux.{moo,arf; {key: value}", + u"foo.bar\u0000baz.qux.{moo,arf", None, {"key": "value"}) + + def testSlashUnicode3(self): + self.valid("foo.bar\\uffffbaz.qux.{moo,arf; {key: value}", + u"foo.bar\uffffbaz.qux.{moo,arf", None, {"key": "value"}) + + def testSlashUnicode4(self): + self.valid("foo.bar\\uFFFFbaz.qux.{moo,arf; {key: value}", + u"foo.bar\uFFFFbaz.qux.{moo,arf", None, {"key": "value"}) + + def testNoName(self): + self.invalid("; {key: value}", + "unexpected token SEMI(';') line:1,0:; {key: value}") + + def testEmpty(self): + self.invalid("", "unexpected token EOF line:1,0:") + + def testNoNameSlash(self): + self.invalid("/asdf; {key: value}", + "unexpected token SLASH('/') line:1,0:/asdf; {key: value}") + + def testBadOptions1(self): + self.invalid("name/subject; {", + "expecting (ID, RBRACE), got EOF line:1,15:name/subject; {") + + def testBadOptions2(self): + self.invalid("name/subject; { 3", + "expecting (ID, RBRACE), got NUMBER('3') " + "line:1,16:name/subject; { 3") + + def testBadOptions3(self): + self.invalid("name/subject; { key:", + "expecting (NUMBER, STRING, ID, LBRACE, LBRACK), got EOF " + "line:1,20:name/subject; { key:") + + def testBadOptions4(self): + self.invalid("name/subject; { key: value", + "expecting (COMMA, RBRACE), got EOF " + "line:1,26:name/subject; { key: value") + + def testBadOptions5(self): + self.invalid("name/subject; { key: value asdf", + "expecting (COMMA, RBRACE), got ID('asdf') " + "line:1,27:name/subject; { key: value asdf") + + def testBadOptions6(self): + self.invalid("name/subject; { key: value,", + "expecting (ID, RBRACE), got EOF " + "line:1,27:name/subject; { key: value,") + + def testBadOptions7(self): + self.invalid("name/subject; { key: value } asdf", + "expecting EOF, got ID('asdf') " + "line:1,29:name/subject; { key: value } asdf") + + def testList1(self): + self.valid("name/subject; { key: [] }", "name", "subject", {"key": []}) + + def testList2(self): + self.valid("name/subject; { key: ['one'] }", "name", "subject", {"key": ['one']}) + + def testList3(self): + self.valid("name/subject; { key: [1, 2, 3] }", "name", "subject", + {"key": [1, 2, 3]}) + + def testList4(self): + self.valid("name/subject; { key: [1, [2, 3], 4] }", "name", "subject", + {"key": [1, [2, 3], 4]}) + + def testBadList1(self): + self.invalid("name/subject; { key: [ }", "expecting (NUMBER, STRING, ID, LBRACE, LBRACK), " + "got RBRACE('}') line:1,23:name/subject; { key: [ }") + + def testBadList2(self): + self.invalid("name/subject; { key: [ 1 }", "expecting (COMMA, RBRACK), " + "got RBRACE('}') line:1,25:name/subject; { key: [ 1 }") + + def testBadList3(self): + self.invalid("name/subject; { key: [ 1 2 }", "expecting (COMMA, RBRACK), " + "got NUMBER('2') line:1,25:name/subject; { key: [ 1 2 }") + + def testBadList4(self): + self.invalid("name/subject; { key: [ 1 2 ] }", "expecting (COMMA, RBRACK), " + "got NUMBER('2') line:1,25:name/subject; { key: [ 1 2 ] }") diff --git a/python/qpid/tests/framing.py b/python/qpid/tests/framing.py new file mode 100644 index 0000000000..0b33df8b9a --- /dev/null +++ b/python/qpid/tests/framing.py @@ -0,0 +1,289 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# setup, usage, teardown, errors(sync), errors(async), stress, soak, +# boundary-conditions, config + +from qpid.tests import Test +from qpid.framing import * + +class Base(Test): + + def cmp_frames(self, frm1, frm2): + assert frm1.flags == frm2.flags, "expected: %r, got %r" % (frm1, frm2) + assert frm1.type == frm2.type, "expected: %r, got %r" % (frm1, frm2) + assert frm1.track == frm2.track, "expected: %r, got %r" % (frm1, frm2) + assert frm1.channel == frm2.channel, "expected: %r, got %r" % (frm1, frm2) + assert frm1.payload == frm2.payload, "expected: %r, got %r" % (frm1, frm2) + + def cmp_segments(self, seg1, seg2): + assert seg1.first == seg2.first, "expected: %r, got %r" % (seg1, seg2) + assert seg1.last == seg2.last, "expected: %r, got %r" % (seg1, seg2) + assert seg1.type == seg2.type, "expected: %r, got %r" % (seg1, seg2) + assert seg1.track == seg2.track, "expected: %r, got %r" % (seg1, seg2) + assert seg1.channel == seg2.channel, "expected: %r, got %r" % (seg1, seg2) + assert seg1.payload == seg2.payload, "expected: %r, got %r" % (seg1, seg2) + + def cmp_list(self, l1, l2): + if l1 is None: + assert l2 is None + return + + assert len(l1) == len(l2) + for v1, v2 in zip(l1, l2): + if isinstance(v1, Compound): + self.cmp_ops(v1, v2) + else: + assert v1 == v2 + + def cmp_ops(self, op1, op2): + if op1 is None: + assert op2 is None + return + + assert op1.__class__ == op2.__class__ + cls = op1.__class__ + assert op1.NAME == op2.NAME + assert op1.CODE == op2.CODE + assert op1.FIELDS == op2.FIELDS + for f in cls.FIELDS: + v1 = getattr(op1, f.name) + v2 = getattr(op2, f.name) + if COMPOUND.has_key(f.type) or f.type == "struct32": + self.cmp_ops(v1, v2) + elif f.type in ("list", "array"): + self.cmp_list(v1, v2) + else: + assert v1 == v2, "expected: %r, got %r" % (v1, v2) + + if issubclass(cls, Command) or issubclass(cls, Control): + assert op1.channel == op2.channel + + if issubclass(cls, Command): + assert op1.sync == op2.sync, "expected: %r, got %r" % (op1.sync, op2.sync) + assert (op1.headers is None and op2.headers is None) or \ + (op1.headers is not None and op2.headers is not None) + if op1.headers is not None: + assert len(op1.headers) == len(op2.headers) + for h1, h2 in zip(op1.headers, op2.headers): + self.cmp_ops(h1, h2) + +class FrameTest(Base): + + def enc_dec(self, frames, encoded=None): + enc = FrameEncoder() + dec = FrameDecoder() + + enc.write(*frames) + bytes = enc.read() + if encoded is not None: + assert bytes == encoded, "expected %r, got %r" % (encoded, bytes) + dec.write(bytes) + dframes = dec.read() + + assert len(frames) == len(dframes) + for f, df, in zip(frames, dframes): + self.cmp_frames(f, df) + + def testEmpty(self): + self.enc_dec([Frame(0, 0, 0, 0, "")], + "\x00\x00\x00\x0c\x00\x00\x00\x00\x00\x00\x00\x00") + + def testSingle(self): + self.enc_dec([Frame(0, 0, 0, 1, "payload")], + "\x00\x00\x00\x13\x00\x00\x00\x01\x00\x00\x00\x00payload") + + def testMaxChannel(self): + self.enc_dec([Frame(0, 0, 0, 65535, "max-channel")], + "\x00\x00\x00\x17\x00\x00\xff\xff\x00\x00\x00\x00max-channel") + + def testMaxType(self): + self.enc_dec([Frame(0, 255, 0, 0, "max-type")], + "\x00\xff\x00\x14\x00\x00\x00\x00\x00\x00\x00\x00max-type") + + def testMaxTrack(self): + self.enc_dec([Frame(0, 0, 15, 0, "max-track")], + "\x00\x00\x00\x15\x00\x0f\x00\x00\x00\x00\x00\x00max-track") + + def testSequence(self): + self.enc_dec([Frame(0, 0, 0, 0, "zero"), + Frame(0, 0, 0, 1, "one"), + Frame(0, 0, 1, 0, "two"), + Frame(0, 0, 1, 1, "three"), + Frame(0, 1, 0, 0, "four"), + Frame(0, 1, 0, 1, "five"), + Frame(0, 1, 1, 0, "six"), + Frame(0, 1, 1, 1, "seven"), + Frame(1, 0, 0, 0, "eight"), + Frame(1, 0, 0, 1, "nine"), + Frame(1, 0, 1, 0, "ten"), + Frame(1, 0, 1, 1, "eleven"), + Frame(1, 1, 0, 0, "twelve"), + Frame(1, 1, 0, 1, "thirteen"), + Frame(1, 1, 1, 0, "fourteen"), + Frame(1, 1, 1, 1, "fifteen")]) + +class SegmentTest(Base): + + def enc_dec(self, segments, frames=None, interleave=None, max_payload=Frame.MAX_PAYLOAD): + enc = SegmentEncoder(max_payload) + dec = SegmentDecoder() + + enc.write(*segments) + frms = enc.read() + if frames is not None: + assert len(frames) == len(frms), "expected %s, got %s" % (frames, frms) + for f1, f2 in zip(frames, frms): + self.cmp_frames(f1, f2) + if interleave is not None: + ilvd = [] + for f in frms: + ilvd.append(f) + if interleave: + ilvd.append(interleave.pop(0)) + ilvd.extend(interleave) + dec.write(*ilvd) + else: + dec.write(*frms) + segs = dec.read() + assert len(segments) == len(segs) + for s1, s2 in zip(segments, segs): + self.cmp_segments(s1, s2) + + def testEmpty(self): + self.enc_dec([Segment(True, True, 0, 0, 0, "")], + [Frame(FIRST_FRM | LAST_FRM | FIRST_SEG | LAST_SEG, 0, 0, 0, + "")]) + + def testSingle(self): + self.enc_dec([Segment(True, True, 0, 0, 0, "payload")], + [Frame(FIRST_FRM | LAST_FRM | FIRST_SEG | LAST_SEG, 0, 0, 0, + "payload")]) + + def testMaxChannel(self): + self.enc_dec([Segment(False, False, 0, 0, 65535, "max-channel")], + [Frame(FIRST_FRM | LAST_FRM, 0, 0, 65535, "max-channel")]) + + def testMaxType(self): + self.enc_dec([Segment(False, False, 255, 0, 0, "max-type")], + [Frame(FIRST_FRM | LAST_FRM, 255, 0, 0, "max-type")]) + + def testMaxTrack(self): + self.enc_dec([Segment(False, False, 0, 15, 0, "max-track")], + [Frame(FIRST_FRM | LAST_FRM, 0, 15, 0, "max-track")]) + + def testSequence(self): + self.enc_dec([Segment(True, False, 0, 0, 0, "one"), + Segment(False, False, 0, 0, 0, "two"), + Segment(False, True, 0, 0, 0, "three")], + [Frame(FIRST_FRM | LAST_FRM | FIRST_SEG, 0, 0, 0, "one"), + Frame(FIRST_FRM | LAST_FRM, 0, 0, 0, "two"), + Frame(FIRST_FRM | LAST_FRM | LAST_SEG, 0, 0, 0, "three")]) + + def testInterleaveChannel(self): + frames = [Frame(0, 0, 0, 0, chr(ord("a") + i)) for i in range(7)] + frames[0].flags |= FIRST_FRM + frames[-1].flags |= LAST_FRM + + ilvd = [Frame(0, 0, 0, 1, chr(ord("a") + i)) for i in range(7)] + + self.enc_dec([Segment(False, False, 0, 0, 0, "abcdefg")], frames, ilvd, max_payload=1) + + def testInterleaveTrack(self): + frames = [Frame(0, 0, 0, 0, "%c%c" % (ord("a") + i, ord("a") + i + 1)) + for i in range(0, 8, 2)] + frames[0].flags |= FIRST_FRM + frames[-1].flags |= LAST_FRM + + ilvd = [Frame(0, 0, 1, 0, "%c%c" % (ord("a") + i, ord("a") + i + 1)) + for i in range(0, 8, 2)] + + self.enc_dec([Segment(False, False, 0, 0, 0, "abcdefgh")], frames, ilvd, max_payload=2) + +from qpid.ops import * + +class OpTest(Base): + + def enc_dec(self, ops): + enc = OpEncoder() + dec = OpDecoder() + enc.write(*ops) + segs = enc.read() + dec.write(*segs) + dops = dec.read() + assert len(ops) == len(dops) + for op1, op2 in zip(ops, dops): + self.cmp_ops(op1, op2) + + def testEmtpyMT(self): + self.enc_dec([MessageTransfer()]) + + def testEmptyMTSync(self): + self.enc_dec([MessageTransfer(sync=True)]) + + def testMT(self): + self.enc_dec([MessageTransfer(destination="asdf")]) + + def testSyncMT(self): + self.enc_dec([MessageTransfer(destination="asdf", sync=True)]) + + def testEmptyPayloadMT(self): + self.enc_dec([MessageTransfer(payload="")]) + + def testPayloadMT(self): + self.enc_dec([MessageTransfer(payload="test payload")]) + + def testHeadersEmptyPayloadMT(self): + self.enc_dec([MessageTransfer(headers=[DeliveryProperties()])]) + + def testHeadersPayloadMT(self): + self.enc_dec([MessageTransfer(headers=[DeliveryProperties()], payload="test payload")]) + + def testMultiHeadersEmptyPayloadMT(self): + self.enc_dec([MessageTransfer(headers=[DeliveryProperties(), MessageProperties()])]) + + def testMultiHeadersPayloadMT(self): + self.enc_dec([MessageTransfer(headers=[MessageProperties(), DeliveryProperties()], payload="test payload")]) + + def testContentTypeHeadersPayloadMT(self): + self.enc_dec([MessageTransfer(headers=[MessageProperties(content_type="text/plain")], payload="test payload")]) + + def testMulti(self): + self.enc_dec([MessageTransfer(), + MessageTransfer(sync=True), + MessageTransfer(destination="one"), + MessageTransfer(destination="two", sync=True), + MessageTransfer(destination="three", payload="test payload")]) + + def testControl(self): + self.enc_dec([SessionAttach(name="asdf")]) + + def testMixed(self): + self.enc_dec([SessionAttach(name="fdsa"), MessageTransfer(destination="test")]) + + def testChannel(self): + self.enc_dec([SessionAttach(name="asdf", channel=3), MessageTransfer(destination="test", channel=1)]) + + def testCompound(self): + self.enc_dec([MessageTransfer(headers=[MessageProperties(reply_to=ReplyTo(exchange="exch", routing_key="rk"))])]) + + def testListCompound(self): + self.enc_dec([ExecutionResult(value=RecoverResult(in_doubt=[Xid(global_id="one"), + Xid(global_id="two"), + Xid(global_id="three")]))]) diff --git a/python/qpid/tests/messaging.py b/python/qpid/tests/messaging.py new file mode 100644 index 0000000000..f2a270192e --- /dev/null +++ b/python/qpid/tests/messaging.py @@ -0,0 +1,929 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# setup, usage, teardown, errors(sync), errors(async), stress, soak, +# boundary-conditions, config + +import time +from qpid import compat +from qpid.tests import Test +from qpid.harness import Skipped +from qpid.messaging import Connection, ConnectError, Disconnected, Empty, \ + InsufficientCapacity, Message, ReceiveError, SendError, SessionError, \ + UNLIMITED, uuid4 +from Queue import Queue, Empty as QueueEmpty + +class Base(Test): + + def setup_connection(self): + return None + + def setup_session(self): + return None + + def setup_sender(self): + return None + + def setup_receiver(self): + return None + + def setup(self): + self.test_id = uuid4() + self.broker = self.config.broker + try: + self.conn = self.setup_connection() + except ConnectError, e: + raise Skipped(e) + self.ssn = self.setup_session() + self.snd = self.setup_sender() + if self.snd is not None: + self.snd.durable = self.durable() + self.rcv = self.setup_receiver() + + def teardown(self): + if self.conn is not None and self.conn.connected(): + self.conn.close() + + def content(self, base, count = None): + if count is None: + return "%s[%s]" % (base, self.test_id) + else: + return "%s[%s, %s]" % (base, count, self.test_id) + + def ping(self, ssn): + PING_Q = 'ping-queue; {create: always, delete: always}' + # send a message + sender = ssn.sender(PING_Q, durable=self.durable()) + content = self.content("ping") + sender.send(content) + receiver = ssn.receiver(PING_Q) + msg = receiver.fetch(0) + ssn.acknowledge() + assert msg.content == content, "expected %r, got %r" % (content, msg.content) + + def drain(self, rcv, limit=None, timeout=0, expected=None): + contents = [] + try: + while limit is None or len(contents) < limit: + contents.append(rcv.fetch(timeout=timeout).content) + except Empty: + pass + if expected is not None: + assert expected == contents, "expected %s, got %s" % (expected, contents) + return contents + + def assertEmpty(self, rcv): + contents = self.drain(rcv) + assert len(contents) == 0, "%s is supposed to be empty: %s" % (rcv, contents) + + def assertPending(self, rcv, expected): + p = rcv.pending() + assert p == expected, "expected %s, got %s" % (expected, p) + + def sleep(self): + time.sleep(self.delay()) + + def delay(self): + return float(self.config.defines.get("delay", "2")) + + def get_bool(self, name): + return self.config.defines.get(name, "false").lower() in ("true", "yes", "1") + + def durable(self): + return self.get_bool("durable") + + def reconnect(self): + return self.get_bool("reconnect") + +class SetupTests(Base): + + def testOpen(self): + # XXX: need to flesh out URL support/syntax + self.conn = Connection.open(self.broker.host, self.broker.port, + reconnect=self.reconnect()) + self.ping(self.conn.session()) + + def testConnect(self): + # XXX: need to flesh out URL support/syntax + self.conn = Connection(self.broker.host, self.broker.port, + reconnect=self.reconnect()) + self.conn.connect() + self.ping(self.conn.session()) + + def testConnectError(self): + try: + self.conn = Connection.open("localhost", 0) + assert False, "connect succeeded" + except ConnectError, e: + # XXX: should verify that e includes appropriate diagnostic info + pass + +class ConnectionTests(Base): + + def setup_connection(self): + return Connection.open(self.broker.host, self.broker.port, + reconnect=self.reconnect()) + + def testSessionAnon(self): + ssn1 = self.conn.session() + ssn2 = self.conn.session() + self.ping(ssn1) + self.ping(ssn2) + assert ssn1 is not ssn2 + + def testSessionNamed(self): + ssn1 = self.conn.session("one") + ssn2 = self.conn.session("two") + self.ping(ssn1) + self.ping(ssn2) + assert ssn1 is not ssn2 + assert ssn1 is self.conn.session("one") + assert ssn2 is self.conn.session("two") + + def testDisconnect(self): + ssn = self.conn.session() + self.ping(ssn) + self.conn.disconnect() + try: + self.ping(ssn) + assert False, "ping succeeded" + except Disconnected: + # this is the expected failure when pinging on a disconnected + # connection + pass + self.conn.connect() + self.ping(ssn) + + def testClose(self): + self.conn.close() + assert not self.conn.connected() + +ACK_QC = 'test-ack-queue; {create: always}' +ACK_QD = 'test-ack-queue; {delete: always}' + +class SessionTests(Base): + + def setup_connection(self): + return Connection.open(self.broker.host, self.broker.port, + reconnect=self.reconnect()) + + def setup_session(self): + return self.conn.session() + + def testSender(self): + snd = self.ssn.sender('test-snd-queue; {create: sender, delete: receiver}', + durable=self.durable()) + snd2 = self.ssn.sender(snd.target, durable=self.durable()) + assert snd is not snd2 + snd2.close() + + content = self.content("testSender") + snd.send(content) + rcv = self.ssn.receiver(snd.target) + msg = rcv.fetch(0) + assert msg.content == content + self.ssn.acknowledge(msg) + + def testReceiver(self): + rcv = self.ssn.receiver('test-rcv-queue; {create: always}') + rcv2 = self.ssn.receiver(rcv.source) + assert rcv is not rcv2 + rcv2.close() + + content = self.content("testReceiver") + snd = self.ssn.sender(rcv.source, durable=self.durable()) + snd.send(content) + msg = rcv.fetch(0) + assert msg.content == content + self.ssn.acknowledge(msg) + snd2 = self.ssn.receiver('test-rcv-queue; {delete: always}') + + def testNextReceiver(self): + ADDR = 'test-next-rcv-queue; {create: always, delete: always}' + rcv1 = self.ssn.receiver(ADDR, capacity=UNLIMITED) + rcv2 = self.ssn.receiver(ADDR, capacity=UNLIMITED) + rcv3 = self.ssn.receiver(ADDR, capacity=UNLIMITED) + + snd = self.ssn.sender(ADDR) + + msgs = [] + for i in range(10): + content = self.content("testNextReceiver", i) + snd.send(content) + msgs.append(content) + + fetched = [] + try: + while True: + rcv = self.ssn.next_receiver(timeout=self.delay()) + assert rcv in (rcv1, rcv2, rcv3) + assert rcv.pending() > 0 + fetched.append(rcv.fetch().content) + except Empty: + pass + assert msgs == fetched, "expecting %s, got %s" % (msgs, fetched) + self.ssn.acknowledge() + + # XXX, we need a convenient way to assert that required queues are + # empty on setup, and possibly also to drain queues on teardown + def ackTest(self, acker, ack_capacity=None): + # send a bunch of messages + snd = self.ssn.sender(ACK_QC, durable=self.durable()) + contents = [self.content("ackTest", i) for i in range(15)] + for c in contents: + snd.send(c) + + # drain the queue, verify the messages are there and then close + # without acking + rcv = self.ssn.receiver(ACK_QC) + self.drain(rcv, expected=contents) + self.ssn.close() + + # drain the queue again, verify that they are all the messages + # were requeued, and ack this time before closing + self.ssn = self.conn.session() + if ack_capacity is not None: + self.ssn.ack_capacity = ack_capacity + rcv = self.ssn.receiver(ACK_QC) + self.drain(rcv, expected=contents) + acker(self.ssn) + self.ssn.close() + + # drain the queue a final time and verify that the messages were + # dequeued + self.ssn = self.conn.session() + rcv = self.ssn.receiver(ACK_QD) + self.assertEmpty(rcv) + + def testAcknowledge(self): + self.ackTest(lambda ssn: ssn.acknowledge()) + + def testAcknowledgeAsync(self): + self.ackTest(lambda ssn: ssn.acknowledge(sync=False)) + + def testAcknowledgeAsyncAckCap0(self): + try: + try: + self.ackTest(lambda ssn: ssn.acknowledge(sync=False), 0) + assert False, "acknowledge shouldn't succeed with ack_capacity of zero" + except InsufficientCapacity: + pass + finally: + self.ssn.ack_capacity = UNLIMITED + self.drain(self.ssn.receiver(ACK_QD)) + self.ssn.acknowledge() + + def testAcknowledgeAsyncAckCap1(self): + self.ackTest(lambda ssn: ssn.acknowledge(sync=False), 1) + + def testAcknowledgeAsyncAckCap5(self): + self.ackTest(lambda ssn: ssn.acknowledge(sync=False), 5) + + def testAcknowledgeAsyncAckCapUNLIMITED(self): + self.ackTest(lambda ssn: ssn.acknowledge(sync=False), UNLIMITED) + + def send(self, ssn, queue, base, count=1): + snd = ssn.sender(queue, durable=self.durable()) + contents = [] + for i in range(count): + c = self.content(base, i) + snd.send(c) + contents.append(c) + snd.close() + return contents + + def txTest(self, commit): + TX_Q = 'test-tx-queue; {create: sender, delete: receiver}' + TX_Q_COPY = 'test-tx-queue-copy; {create: always, delete: always}' + txssn = self.conn.session(transactional=True) + contents = self.send(self.ssn, TX_Q, "txTest", 3) + txrcv = txssn.receiver(TX_Q) + txsnd = txssn.sender(TX_Q_COPY, durable=self.durable()) + rcv = self.ssn.receiver(txrcv.source) + copy_rcv = self.ssn.receiver(txsnd.target) + self.assertEmpty(copy_rcv) + for i in range(3): + m = txrcv.fetch(0) + txsnd.send(m) + self.assertEmpty(copy_rcv) + txssn.acknowledge() + if commit: + txssn.commit() + self.assertEmpty(rcv) + assert contents == self.drain(copy_rcv) + else: + txssn.rollback() + assert contents == self.drain(rcv) + self.assertEmpty(copy_rcv) + self.ssn.acknowledge() + + def testCommit(self): + self.txTest(True) + + def testRollback(self): + self.txTest(False) + + def txTestSend(self, commit): + TX_SEND_Q = 'test-tx-send-queue; {create: sender, delete: receiver}' + txssn = self.conn.session(transactional=True) + contents = self.send(txssn, TX_SEND_Q, "txTestSend", 3) + rcv = self.ssn.receiver(TX_SEND_Q) + self.assertEmpty(rcv) + + if commit: + txssn.commit() + assert contents == self.drain(rcv) + self.ssn.acknowledge() + else: + txssn.rollback() + self.assertEmpty(rcv) + txssn.commit() + self.assertEmpty(rcv) + + def testCommitSend(self): + self.txTestSend(True) + + def testRollbackSend(self): + self.txTestSend(False) + + def txTestAck(self, commit): + TX_ACK_QC = 'test-tx-ack-queue; {create: always}' + TX_ACK_QD = 'test-tx-ack-queue; {delete: always}' + txssn = self.conn.session(transactional=True) + txrcv = txssn.receiver(TX_ACK_QC) + self.assertEmpty(txrcv) + contents = self.send(self.ssn, TX_ACK_QC, "txTestAck", 3) + assert contents == self.drain(txrcv) + + if commit: + txssn.acknowledge() + else: + txssn.rollback() + drained = self.drain(txrcv) + assert contents == drained, "expected %s, got %s" % (contents, drained) + txssn.acknowledge() + txssn.rollback() + assert contents == self.drain(txrcv) + txssn.commit() # commit without ack + self.assertEmpty(txrcv) + + txssn.close() + + txssn = self.conn.session(transactional=True) + txrcv = txssn.receiver(TX_ACK_QC) + assert contents == self.drain(txrcv) + txssn.acknowledge() + txssn.commit() + rcv = self.ssn.receiver(TX_ACK_QD) + self.assertEmpty(rcv) + txssn.close() + self.assertEmpty(rcv) + + def testCommitAck(self): + self.txTestAck(True) + + def testRollbackAck(self): + self.txTestAck(False) + + def testClose(self): + self.ssn.close() + try: + self.ping(self.ssn) + assert False, "ping succeeded" + except Disconnected: + pass + +RECEIVER_Q = 'test-receiver-queue; {create: always, delete: always}' + +class ReceiverTests(Base): + + def setup_connection(self): + return Connection.open(self.broker.host, self.broker.port, + reconnect=self.reconnect()) + + def setup_session(self): + return self.conn.session() + + def setup_sender(self): + return self.ssn.sender(RECEIVER_Q) + + def setup_receiver(self): + return self.ssn.receiver(RECEIVER_Q) + + def send(self, base, count = None): + content = self.content(base, count) + self.snd.send(content) + return content + + def testFetch(self): + try: + msg = self.rcv.fetch(0) + assert False, "unexpected message: %s" % msg + except Empty: + pass + try: + start = time.time() + msg = self.rcv.fetch(self.delay()) + assert False, "unexpected message: %s" % msg + except Empty: + elapsed = time.time() - start + assert elapsed >= self.delay() + + one = self.send("testFetch", 1) + two = self.send("testFetch", 2) + three = self.send("testFetch", 3) + msg = self.rcv.fetch(0) + assert msg.content == one + msg = self.rcv.fetch(self.delay()) + assert msg.content == two + msg = self.rcv.fetch() + assert msg.content == three + self.ssn.acknowledge() + + def testCapacityIncrease(self): + content = self.send("testCapacityIncrease") + self.sleep() + assert self.rcv.pending() == 0 + self.rcv.capacity = UNLIMITED + self.sleep() + assert self.rcv.pending() == 1 + msg = self.rcv.fetch(0) + assert msg.content == content + assert self.rcv.pending() == 0 + self.ssn.acknowledge() + + def testCapacityDecrease(self): + self.rcv.capacity = UNLIMITED + one = self.send("testCapacityDecrease", 1) + self.sleep() + assert self.rcv.pending() == 1 + msg = self.rcv.fetch(0) + assert msg.content == one + + self.rcv.capacity = 0 + + two = self.send("testCapacityDecrease", 2) + self.sleep() + assert self.rcv.pending() == 0 + msg = self.rcv.fetch(0) + assert msg.content == two + + self.ssn.acknowledge() + + def testCapacity(self): + self.rcv.capacity = 5 + self.assertPending(self.rcv, 0) + + for i in range(15): + self.send("testCapacity", i) + self.sleep() + self.assertPending(self.rcv, 5) + + self.drain(self.rcv, limit = 5) + self.sleep() + self.assertPending(self.rcv, 5) + + drained = self.drain(self.rcv) + assert len(drained) == 10, "%s, %s" % (len(drained), drained) + self.assertPending(self.rcv, 0) + + self.ssn.acknowledge() + + def testCapacityUNLIMITED(self): + self.rcv.capacity = UNLIMITED + self.assertPending(self.rcv, 0) + + for i in range(10): + self.send("testCapacityUNLIMITED", i) + self.sleep() + self.assertPending(self.rcv, 10) + + self.drain(self.rcv) + self.assertPending(self.rcv, 0) + + self.ssn.acknowledge() + + def testPending(self): + self.rcv.capacity = UNLIMITED + assert self.rcv.pending() == 0 + + for i in range(3): + self.send("testPending", i) + self.sleep() + assert self.rcv.pending() == 3 + + for i in range(3, 10): + self.send("testPending", i) + self.sleep() + assert self.rcv.pending() == 10 + + self.drain(self.rcv, limit=3) + assert self.rcv.pending() == 7 + + self.drain(self.rcv) + assert self.rcv.pending() == 0 + + self.ssn.acknowledge() + + # XXX: need testClose + +class AddressTests(Base): + + def setup_connection(self): + return Connection.open(self.broker.host, self.broker.port, + reconnect=self.reconnect()) + + def setup_session(self): + return self.conn.session() + + def testBadOption(self): + snd = self.ssn.sender("test-bad-option; {create: always, node-properties: {this-property-does-not-exist: 3}}") + try: + snd.send("ping") + except SendError, e: + assert "unrecognized option" in str(e) + + def testCreateQueue(self): + snd = self.ssn.sender("test-create-queue; {create: always, delete: always, " + "node-properties: {type: queue, durable: False, " + "x-properties: {auto_delete: true}}}") + content = self.content("testCreateQueue") + snd.send(content) + rcv = self.ssn.receiver("test-create-queue") + self.drain(rcv, expected=[content]) + + def createExchangeTest(self, props=""): + addr = """test-create-exchange; { + create: always, + delete: always, + node-properties: { + type: topic, + durable: False, + x-properties: {auto_delete: true, %s} + } + }""" % props + snd = self.ssn.sender(addr) + snd.send("ping") + rcv1 = self.ssn.receiver("test-create-exchange/first") + rcv2 = self.ssn.receiver("test-create-exchange/first") + rcv3 = self.ssn.receiver("test-create-exchange/second") + for r in (rcv1, rcv2, rcv3): + try: + r.fetch(0) + assert False + except Empty: + pass + msg1 = Message(self.content("testCreateExchange", 1), subject="first") + msg2 = Message(self.content("testCreateExchange", 2), subject="second") + snd.send(msg1) + snd.send(msg2) + self.drain(rcv1, expected=[msg1.content]) + self.drain(rcv2, expected=[msg1.content]) + self.drain(rcv3, expected=[msg2.content]) + + def testCreateExchange(self): + self.createExchangeTest() + + def testCreateExchangeDirect(self): + self.createExchangeTest("type: direct") + + def testCreateExchangeTopic(self): + self.createExchangeTest("type: topic") + + def testDeleteBySender(self): + snd = self.ssn.sender("test-delete; {create: always}") + snd.send("ping") + snd.close() + snd = self.ssn.sender("test-delete; {delete: always}") + snd.send("ping") + snd.close() + snd = self.ssn.sender("test-delete") + try: + snd.send("ping") + except SendError, e: + assert "no such queue" in str(e) + + def testDeleteByReceiver(self): + rcv = self.ssn.receiver("test-delete; {create: always, delete: always}") + try: + rcv.fetch(0) + except Empty: + pass + rcv.close() + + try: + self.ssn.receiver("test-delete") + except SendError, e: + assert "no such queue" in str(e) + + def testDeleteSpecial(self): + snd = self.ssn.sender("amq.topic; {delete: always}") + snd.send("asdf") + try: + snd.close() + except SessionError, e: + assert "Cannot delete default exchange" in str(e) + # XXX: need to figure out close after error + self.conn._remove_session(self.ssn) + + def testBindings(self): + snd = self.ssn.sender(""" +test-bindings-queue; { + create: always, + delete: always, + node-properties: { + x-properties: { + bindings: ["amq.topic/a.#", "amq.direct/b", "amq.topic/c.*"] + } + } +} +""") + snd.send("one") + snd_a = self.ssn.sender("amq.topic/a.foo") + snd_b = self.ssn.sender("amq.direct/b") + snd_c = self.ssn.sender("amq.topic/c.bar") + snd_a.send("two") + snd_b.send("three") + snd_c.send("four") + rcv = self.ssn.receiver("test-bindings-queue") + self.drain(rcv, expected=["one", "two", "three", "four"]) + +NOSUCH_Q = "this-queue-should-not-exist" +UNPARSEABLE_ADDR = "name/subject; {bad options" +UNLEXABLE_ADDR = "\0x0\0x1\0x2\0x3" + +class AddressErrorTests(Base): + + def setup_connection(self): + return Connection.open(self.broker.host, self.broker.port, + reconnect=self.reconnect()) + + def setup_session(self): + return self.conn.session() + + def sendErrorTest(self, addr, exc, check=lambda e: True): + snd = self.ssn.sender(addr, durable=self.durable()) + try: + snd.send("hello") + assert False, "send succeeded" + except exc, e: + assert check(e), "unexpected error: %s" % compat.format_exc(e) + snd.close() + + def fetchErrorTest(self, addr, exc, check=lambda e: True): + rcv = self.ssn.receiver(addr) + try: + rcv.fetch(timeout=0) + assert False, "fetch succeeded" + except exc, e: + assert check(e), "unexpected error: %s" % compat.format_exc(e) + rcv.close() + + def testNoneTarget(self): + # XXX: should have specific exception for this + self.sendErrorTest(None, SendError) + + def testNoneSource(self): + # XXX: should have specific exception for this + self.fetchErrorTest(None, ReceiveError) + + def testNoTarget(self): + # XXX: should have specific exception for this + self.sendErrorTest(NOSUCH_Q, SendError, lambda e: NOSUCH_Q in str(e)) + + def testNoSource(self): + # XXX: should have specific exception for this + self.fetchErrorTest(NOSUCH_Q, ReceiveError, lambda e: NOSUCH_Q in str(e)) + + def testUnparseableTarget(self): + # XXX: should have specific exception for this + self.sendErrorTest(UNPARSEABLE_ADDR, SendError, + lambda e: "expecting COLON" in str(e)) + + def testUnparseableSource(self): + # XXX: should have specific exception for this + self.fetchErrorTest(UNPARSEABLE_ADDR, ReceiveError, + lambda e: "expecting COLON" in str(e)) + + def testUnlexableTarget(self): + # XXX: should have specific exception for this + self.sendErrorTest(UNLEXABLE_ADDR, SendError, + lambda e: "unrecognized characters" in str(e)) + + def testUnlexableSource(self): + # XXX: should have specific exception for this + self.fetchErrorTest(UNLEXABLE_ADDR, ReceiveError, + lambda e: "unrecognized characters" in str(e)) + +SENDER_Q = 'test-sender-q; {create: always, delete: always}' + +class SenderTests(Base): + + def setup_connection(self): + return Connection.open(self.broker.host, self.broker.port, + reconnect=self.reconnect()) + + def setup_session(self): + return self.conn.session() + + def setup_sender(self): + return self.ssn.sender(SENDER_Q) + + def setup_receiver(self): + return self.ssn.receiver(SENDER_Q) + + def checkContent(self, content): + self.snd.send(content) + msg = self.rcv.fetch(0) + assert msg.content == content + + out = Message(content) + self.snd.send(out) + echo = self.rcv.fetch(0) + assert out.content == echo.content + assert echo.content == msg.content + self.ssn.acknowledge() + + def testSendString(self): + self.checkContent(self.content("testSendString")) + + def testSendList(self): + self.checkContent(["testSendList", 1, 3.14, self.test_id]) + + def testSendMap(self): + self.checkContent({"testSendMap": self.test_id, "pie": "blueberry", "pi": 3.14}) + + def asyncTest(self, capacity): + self.snd.capacity = capacity + msgs = [self.content("asyncTest", i) for i in range(15)] + for m in msgs: + self.snd.send(m, sync=False) + drained = self.drain(self.rcv, timeout=self.delay()) + assert msgs == drained, "expected %s, got %s" % (msgs, drained) + self.ssn.acknowledge() + + def testSendAsyncCapacity0(self): + try: + self.asyncTest(0) + assert False, "send shouldn't succeed with zero capacity" + except InsufficientCapacity: + # this is expected + pass + + def testSendAsyncCapacity1(self): + self.asyncTest(1) + + def testSendAsyncCapacity5(self): + self.asyncTest(5) + + def testSendAsyncCapacityUNLIMITED(self): + self.asyncTest(UNLIMITED) + + def testCapacityTimeout(self): + self.snd.capacity = 1 + msgs = [] + caught = False + while len(msgs) < 100: + m = self.content("testCapacity", len(msgs)) + try: + self.snd.send(m, sync=False, timeout=0) + msgs.append(m) + except InsufficientCapacity: + caught = True + break + self.snd.sync() + self.drain(self.rcv, expected=msgs) + self.ssn.acknowledge() + assert caught, "did not exceed capacity" + +class MessageTests(Base): + + def testCreateString(self): + m = Message("string") + assert m.content == "string" + assert m.content_type is None + + def testCreateUnicode(self): + m = Message(u"unicode") + assert m.content == u"unicode" + assert m.content_type == "text/plain" + + def testCreateMap(self): + m = Message({}) + assert m.content == {} + assert m.content_type == "amqp/map" + + def testCreateList(self): + m = Message([]) + assert m.content == [] + assert m.content_type == "amqp/list" + + def testContentTypeOverride(self): + m = Message() + m.content_type = "text/html; charset=utf8" + m.content = u"<html/>" + assert m.content_type == "text/html; charset=utf8" + +ECHO_Q = 'test-message-echo-queue; {create: always, delete: always}' + +class MessageEchoTests(Base): + + def setup_connection(self): + return Connection.open(self.broker.host, self.broker.port, + reconnect=self.reconnect()) + + def setup_session(self): + return self.conn.session() + + def setup_sender(self): + return self.ssn.sender(ECHO_Q) + + def setup_receiver(self): + return self.ssn.receiver(ECHO_Q) + + def check(self, msg): + self.snd.send(msg) + echo = self.rcv.fetch(0) + + assert msg.id == echo.id + assert msg.subject == echo.subject + assert msg.user_id == echo.user_id + assert msg.to == echo.to + assert msg.reply_to == echo.reply_to + assert msg.correlation_id == echo.correlation_id + assert msg.properties == echo.properties + assert msg.content_type == echo.content_type + assert msg.content == echo.content, "%s, %s" % (msg, echo) + + self.ssn.acknowledge(echo) + + def testStringContent(self): + self.check(Message("string")) + + def testUnicodeContent(self): + self.check(Message(u"unicode")) + + + TEST_MAP = {"key1": "string", + "key2": u"unicode", + "key3": 3, + "key4": -3, + "key5": 3.14, + "key6": -3.14, + "key7": ["one", 2, 3.14], + "key8": [], + "key9": {"sub-key0": 3}} + + def testMapContent(self): + self.check(Message(MessageEchoTests.TEST_MAP)) + + def testListContent(self): + self.check(Message([])) + self.check(Message([1, 2, 3])) + self.check(Message(["one", 2, 3.14, {"four": 4}])) + + def testProperties(self): + msg = Message() + msg.to = "to-address" + msg.subject = "subject" + msg.correlation_id = str(self.test_id) + msg.properties = MessageEchoTests.TEST_MAP + msg.reply_to = "reply-address" + self.check(msg) + +class TestTestsXXX(Test): + + def testFoo(self): + print "this test has output" + + def testBar(self): + print "this test "*8 + print "has"*10 + print "a"*75 + print "lot of"*10 + print "output"*10 + + def testQux(self): + import sys + sys.stdout.write("this test has output with no newline") + + def testQuxFail(self): + import sys + sys.stdout.write("this test has output with no newline") + fdsa diff --git a/python/qpid/tests/mimetype.py b/python/qpid/tests/mimetype.py new file mode 100644 index 0000000000..22760316f0 --- /dev/null +++ b/python/qpid/tests/mimetype.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from qpid.tests import Test +from qpid.mimetype import lex, parse, ParseError, EOF, WSPACE +from parser import ParserBase + +class MimeTypeTests(ParserBase, Test): + + EXCLUDE = (WSPACE, EOF) + + def do_lex(self, st): + return lex(st) + + def do_parse(self, st): + return parse(st) + + def valid(self, addr, type=None, subtype=None, parameters=None): + ParserBase.valid(self, addr, (type, subtype, parameters)) + + def testTypeOnly(self): + self.invalid("type", "expecting SLASH, got EOF line:1,4:type") + + def testTypeSubtype(self): + self.valid("type/subtype", "type", "subtype", []) + + def testTypeSubtypeParam(self): + self.valid("type/subtype ; name=value", + "type", "subtype", [("name", "value")]) + + def testTypeSubtypeParamComment(self): + self.valid("type/subtype ; name(This is a comment.)=value", + "type", "subtype", [("name", "value")]) + + def testMultipleParams(self): + self.valid("type/subtype ; name1=value1 ; name2=value2", + "type", "subtype", [("name1", "value1"), ("name2", "value2")]) + + def testCaseInsensitivity(self): + self.valid("Type/Subtype", "type", "subtype", []) diff --git a/python/qpid/tests/parser.py b/python/qpid/tests/parser.py new file mode 100644 index 0000000000..a4865cc9fe --- /dev/null +++ b/python/qpid/tests/parser.py @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from qpid.parser import ParseError + +class ParserBase: + + def lex(self, addr, *types): + toks = [t.type for t in self.do_lex(addr) if t.type not in self.EXCLUDE] + assert list(types) == toks, "expected %s, got %s" % (types, toks) + + def valid(self, addr, expected): + got = self.do_parse(addr) + assert expected == got, "expected %s, got %s" % (expected, got) + + def invalid(self, addr, error=None): + try: + p = self.do_parse(addr) + assert False, "invalid address parsed: %s" % p + except ParseError, e: + assert error == str(e), "expected %r, got %r" % (error, str(e)) diff --git a/python/qpid/util.py b/python/qpid/util.py index 1140cbe5ef..3409d777f9 100644 --- a/python/qpid/util.py +++ b/python/qpid/util.py @@ -17,7 +17,26 @@ # under the License. # -import os, socket, time, textwrap +import os, socket, time, textwrap, re + +try: + from ssl import wrap_socket as ssl +except ImportError: + from socket import ssl as wrap_socket + class ssl: + + def __init__(self, sock): + self.sock = sock + self.ssl = wrap_socket(sock) + + def recv(self, n): + return self.ssl.read(n) + + def send(self, s): + return self.ssl.write(s) + + def close(self): + self.sock.close() def connect(host, port): sock = socket.socket() @@ -32,8 +51,8 @@ def listen(host, port, predicate = lambda: True, bound = lambda: None): sock = socket.socket() sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.bind((host, port)) - bound() sock.listen(5) + bound() while predicate(): s, a = sock.accept() yield s @@ -48,7 +67,9 @@ def wait(condition, predicate, timeout=None): start = time.time() while not predicate(): if timeout is None: - condition.wait() + # using the timed wait prevents keyboard interrupts from being + # blocked while waiting + condition.wait(3) elif passed < timeout: condition.wait(timeout - passed) else: @@ -76,3 +97,46 @@ def fill(text, indent, heading = None): init = sub w = textwrap.TextWrapper(initial_indent = init, subsequent_indent = sub) return w.fill(" ".join(text.split())) + +class URL: + + RE = re.compile(r""" + # [ <scheme>:// ] [ <user> [ / <password> ] @] <host> [ :<port> ] + ^ (?: ([^:/@]+)://)? (?: ([^:/@]+) (?: / ([^:/@]+) )? @)? ([^@:/]+) (?: :([0-9]+))?$ +""", re.X) + + AMQPS = "amqps" + AMQP = "amqp" + + def __init__(self, s): + match = URL.RE.match(s) + if match is None: + raise ValueError(s) + self.scheme, self.user, self.password, self.host, port = match.groups() + if port is None: + self.port = None + else: + self.port = int(port) + + def __repr__(self): + return "URL(%r)" % str(self) + + def __str__(self): + s = "" + if self.scheme: + s += "%s://" % self.scheme + if self.user: + s += self.user + if self.password: + s += "/%s" % self.password + s += "@" + s += self.host + if self.port: + s += ":%s" % self.port + return s + +def default(value, default): + if value is None: + return default + else: + return value |
