diff options
author | Michael P. Soulier <msoulier@digitaltorque.ca> | 2009-08-15 22:36:58 -0400 |
---|---|---|
committer | Michael P. Soulier <msoulier@digitaltorque.ca> | 2009-08-16 19:56:06 -0400 |
commit | 62b22fb562eff64a6d6bb6c1a1a3c194d668d9a1 (patch) | |
tree | 8adf96a5c71b15cfa443c974ee3ad6e270c8b9e4 | |
parent | 03e4e748293070ac37fb7fe88abc8b915d84be96 (diff) | |
download | tftpy-62b22fb562eff64a6d6bb6c1a1a3c194d668d9a1.tar.gz |
Did some rework for the state machine in a server context.
Removed the handler framework in favour of a TftpContextServer used
as the session.
-rw-r--r-- | tftpy/TftpClient.py | 20 | ||||
-rw-r--r-- | tftpy/TftpPacketFactory.py | 4 | ||||
-rw-r--r-- | tftpy/TftpPacketTypes.py | 88 | ||||
-rw-r--r-- | tftpy/TftpServer.py | 408 | ||||
-rw-r--r-- | tftpy/TftpShared.py | 6 | ||||
-rw-r--r-- | tftpy/TftpStates.py | 588 |
6 files changed, 521 insertions, 593 deletions
diff --git a/tftpy/TftpClient.py b/tftpy/TftpClient.py index da35d05..89843f1 100644 --- a/tftpy/TftpClient.py +++ b/tftpy/TftpClient.py @@ -52,12 +52,12 @@ class TftpClient(TftpSession): # output? This should be in the sample client, but not in the download # call. if metrics.duration == 0: - logger.info("Duration too short, rate undetermined") + log.info("Duration too short, rate undetermined") else: - logger.info('') - logger.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration)) - logger.info("Average rate: %.2f kbps" % metrics.kbps) - logger.info("Received %d duplicate packets" % metrics.dupcount) + log.info('') + log.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration)) + log.info("Average rate: %.2f kbps" % metrics.kbps) + log.info("Received %d duplicate packets" % metrics.dupcount) def upload(self, filename, input, packethook=None, timeout=SOCK_TIMEOUT): # Open the input file. @@ -80,9 +80,9 @@ class TftpClient(TftpSession): # output? This should be in the sample client, but not in the download # call. if metrics.duration == 0: - logger.info("Duration too short, rate undetermined") + log.info("Duration too short, rate undetermined") else: - logger.info('') - logger.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration)) - logger.info("Average rate: %.2f kbps" % metrics.kbps) - logger.info("Received %d duplicate packets" % metrics.dupcount)
\ No newline at end of file + log.info('') + log.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration)) + log.info("Average rate: %.2f kbps" % metrics.kbps) + log.info("Received %d duplicate packets" % metrics.dupcount) diff --git a/tftpy/TftpPacketFactory.py b/tftpy/TftpPacketFactory.py index 642b4d8..3f287de 100644 --- a/tftpy/TftpPacketFactory.py +++ b/tftpy/TftpPacketFactory.py @@ -19,9 +19,9 @@ class TftpPacketFactory(object): """This method is used to parse an existing datagram into its corresponding TftpPacket object. The buffer is the raw bytes off of the network.""" - logger.debug("parsing a %d byte packet" % len(buffer)) + log.debug("parsing a %d byte packet" % len(buffer)) (opcode,) = struct.unpack("!H", buffer[:2]) - logger.debug("opcode is %d" % opcode) + log.debug("opcode is %d" % opcode) packet = self.__create(opcode) packet.buffer = buffer return packet.decode() diff --git a/tftpy/TftpPacketTypes.py b/tftpy/TftpPacketTypes.py index e269deb..b9328c5 100644 --- a/tftpy/TftpPacketTypes.py +++ b/tftpy/TftpPacketTypes.py @@ -15,7 +15,7 @@ class TftpSession(object): def senderror(self, sock, errorcode, address, port): """This method uses the socket passed, and uses the errorcode, address and port to compose and send an error packet.""" - logger.debug("In senderror, being asked to send error %d to %s:%s" + log.debug("In senderror, being asked to send error %d to %s:%s" % (errorcode, address, port)) errpkt = TftpPacketERR() errpkt.errorcode = errorcode @@ -27,23 +27,23 @@ class TftpPacketWithOptions(object): goal is just to share code here, and not cause diamond inheritance.""" def __init__(self): - self.options = [] + self.options = {} def setoptions(self, options): - logger.debug("in TftpPacketWithOptions.setoptions") - logger.debug("options: " + str(options)) + log.debug("in TftpPacketWithOptions.setoptions") + log.debug("options: " + str(options)) myoptions = {} for key in options: newkey = str(key) myoptions[newkey] = str(options[key]) - logger.debug("populated myoptions with %s = %s" + log.debug("populated myoptions with %s = %s" % (newkey, myoptions[newkey])) - logger.debug("setting options hash to: " + str(myoptions)) + log.debug("setting options hash to: " + str(myoptions)) self._options = myoptions def getoptions(self): - logger.debug("in TftpPacketWithOptions.getoptions") + log.debug("in TftpPacketWithOptions.getoptions") return self._options # Set up getter and setter on options to ensure that they are the proper @@ -59,19 +59,19 @@ class TftpPacketWithOptions(object): format = "!" options = {} - logger.debug("decode_options: buffer is: " + repr(buffer)) - logger.debug("size of buffer is %d bytes" % len(buffer)) + log.debug("decode_options: buffer is: " + repr(buffer)) + log.debug("size of buffer is %d bytes" % len(buffer)) if len(buffer) == 0: - logger.debug("size of buffer is zero, returning empty hash") + log.debug("size of buffer is zero, returning empty hash") return {} # Count the nulls in the buffer. Each one terminates a string. - logger.debug("about to iterate options buffer counting nulls") + log.debug("about to iterate options buffer counting nulls") length = 0 for c in buffer: - #logger.debug("iterating this byte: " + repr(c)) + #log.debug("iterating this byte: " + repr(c)) if ord(c) == 0: - logger.debug("found a null at length %d" % length) + log.debug("found a null at length %d" % length) if length > 0: format += "%dsx" % length length = -1 @@ -79,14 +79,14 @@ class TftpPacketWithOptions(object): raise TftpException, "Invalid options in buffer" length += 1 - logger.debug("about to unpack, format is: %s" % format) + log.debug("about to unpack, format is: %s" % format) mystruct = struct.unpack(format, buffer) tftpassert(len(mystruct) % 2 == 0, "packet with odd number of option/value pairs") for i in range(0, len(mystruct), 2): - logger.debug("setting option %s to %s" % (mystruct[i], mystruct[i+1])) + log.debug("setting option %s to %s" % (mystruct[i], mystruct[i+1])) options[mystruct[i]] = mystruct[i+1] return options @@ -134,10 +134,10 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions): ptype = None if self.opcode == 1: ptype = "RRQ" else: ptype = "WRQ" - logger.debug("Encoding %s packet, filename = %s, mode = %s" + log.debug("Encoding %s packet, filename = %s, mode = %s" % (ptype, self.filename, self.mode)) for key in self.options: - logger.debug(" Option %s = %s" % (key, self.options[key])) + log.debug(" Option %s = %s" % (key, self.options[key])) format = "!H" format += "%dsx" % len(self.filename) @@ -148,7 +148,7 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions): # Add options. options_list = [] if self.options.keys() > 0: - logger.debug("there are options to encode") + log.debug("there are options to encode") for key in self.options: # Populate the option name format += "%dsx" % len(key) @@ -157,9 +157,9 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions): format += "%dsx" % len(str(self.options[key])) options_list.append(str(self.options[key])) - logger.debug("format is %s" % format) - logger.debug("options_list is %s" % options_list) - logger.debug("size of struct is %d" % struct.calcsize(format)) + log.debug("format is %s" % format) + log.debug("options_list is %s" % options_list) + log.debug("size of struct is %d" % struct.calcsize(format)) self.buffer = struct.pack(format, self.opcode, @@ -167,7 +167,7 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions): self.mode, *options_list) - logger.debug("buffer is " + repr(self.buffer)) + log.debug("buffer is " + repr(self.buffer)) return self def decode(self): @@ -177,13 +177,13 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions): nulls = 0 format = "" nulls = length = tlength = 0 - logger.debug("in decode: about to iterate buffer counting nulls") + log.debug("in decode: about to iterate buffer counting nulls") subbuf = self.buffer[2:] for c in subbuf: - logger.debug("iterating this byte: " + repr(c)) + log.debug("iterating this byte: " + repr(c)) if ord(c) == 0: nulls += 1 - logger.debug("found a null at length %d, now have %d" + log.debug("found a null at length %d, now have %d" % (length, nulls)) format += "%dsx" % length length = -1 @@ -193,17 +193,17 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions): length += 1 tlength += 1 - logger.debug("hopefully found end of mode at length %d" % tlength) + log.debug("hopefully found end of mode at length %d" % tlength) # length should now be the end of the mode. tftpassert(nulls == 2, "malformed packet") shortbuf = subbuf[:tlength+1] - logger.debug("about to unpack buffer with format: %s" % format) - logger.debug("unpacking buffer: " + repr(shortbuf)) + log.debug("about to unpack buffer with format: %s" % format) + log.debug("unpacking buffer: " + repr(shortbuf)) mystruct = struct.unpack(format, shortbuf) tftpassert(len(mystruct) == 2, "malformed packet") - logger.debug("setting filename to %s" % mystruct[0]) - logger.debug("setting mode to %s" % mystruct[1]) + log.debug("setting filename to %s" % mystruct[0]) + log.debug("setting mode to %s" % mystruct[1]) self.filename = mystruct[0] self.mode = mystruct[1] @@ -269,7 +269,7 @@ DATA | 03 | Block # | Data | """Encode the DAT packet. This method populates self.buffer, and returns self for easy method chaining.""" if len(self.data) == 0: - logger.debug("Encoding an empty DAT packet") + log.debug("Encoding an empty DAT packet") format = "!HH%ds" % len(self.data) self.buffer = struct.pack(format, self.opcode, @@ -283,12 +283,12 @@ DATA | 03 | Block # | Data | # We know the first 2 bytes are the opcode. The second two are the # block number. (self.blocknumber,) = struct.unpack("!H", self.buffer[2:4]) - logger.debug("decoding DAT packet, block number %d" % self.blocknumber) - logger.debug("should be %d bytes in the packet total" + log.debug("decoding DAT packet, block number %d" % self.blocknumber) + log.debug("should be %d bytes in the packet total" % len(self.buffer)) # Everything else is data. self.data = self.buffer[4:] - logger.debug("found %d bytes of data" + log.debug("found %d bytes of data" % len(self.data)) return self @@ -308,14 +308,14 @@ ACK | 04 | Block # | return 'ACK packet: block %d' % self.blocknumber def encode(self): - logger.debug("encoding ACK: opcode = %d, block = %d" + log.debug("encoding ACK: opcode = %d, block = %d" % (self.opcode, self.blocknumber)) self.buffer = struct.pack("!HH", self.opcode, self.blocknumber) return self def decode(self): self.opcode, self.blocknumber = struct.unpack("!HH", self.buffer) - logger.debug("decoded ACK packet: opcode = %d, block = %d" + log.debug("decoded ACK packet: opcode = %d, block = %d" % (self.opcode, self.blocknumber)) return self @@ -365,7 +365,7 @@ ERROR | 05 | ErrorCode | ErrMsg | 0 | """Encode the DAT packet based on instance variables, populating self.buffer, returning self.""" format = "!HH%dsx" % len(self.errmsgs[self.errorcode]) - logger.debug("encoding ERR packet with format %s" % format) + log.debug("encoding ERR packet with format %s" % format) self.buffer = struct.pack(format, self.opcode, self.errorcode, @@ -375,13 +375,13 @@ ERROR | 05 | ErrorCode | ErrMsg | 0 | def decode(self): "Decode self.buffer, populating instance variables and return self." tftpassert(len(self.buffer) > 4, "malformed ERR packet, too short") - logger.debug("Decoding ERR packet, length %s bytes" % + log.debug("Decoding ERR packet, length %s bytes" % len(self.buffer)) format = "!HH%dsx" % (len(self.buffer) - 5) - logger.debug("Decoding ERR packet with format: %s" % format) + log.debug("Decoding ERR packet with format: %s" % format) self.opcode, self.errorcode, self.errmsg = struct.unpack(format, self.buffer) - logger.error("ERR packet - errorcode: %d, message: %s" + log.error("ERR packet - errorcode: %d, message: %s" % (self.errorcode, self.errmsg)) return self @@ -402,10 +402,10 @@ class TftpPacketOACK(TftpPacket, TftpPacketWithOptions): def encode(self): format = "!H" # opcode options_list = [] - logger.debug("in TftpPacketOACK.encode") + log.debug("in TftpPacketOACK.encode") for key in self.options: - logger.debug("looping on option key %s" % key) - logger.debug("value is %s" % self.options[key]) + log.debug("looping on option key %s" % key) + log.debug("value is %s" % self.options[key]) format += "%dsx" % len(key) format += "%dsx" % len(self.options[key]) options_list.append(key) @@ -429,7 +429,7 @@ class TftpPacketOACK(TftpPacket, TftpPacketWithOptions): # We can accept anything between the min and max values. size = self.options[name] if size >= MIN_BLKSIZE and size <= MAX_BLKSIZE: - logger.debug("negotiated blksize of %d bytes" % size) + log.debug("negotiated blksize of %d bytes" % size) options[blksize] = size else: raise TftpException, "Unsupported option: %s" % name diff --git a/tftpy/TftpServer.py b/tftpy/TftpServer.py index e846979..ad781a2 100644 --- a/tftpy/TftpServer.py +++ b/tftpy/TftpServer.py @@ -1,4 +1,5 @@ import socket, os, re, time, random +import select from TftpShared import * from TftpPacketTypes import * from TftpPacketFactory import * @@ -15,26 +16,27 @@ class TftpServer(TftpSession): self.listenip = None self.listenport = None self.sock = None + # FIXME: What about multiple roots? self.root = os.path.abspath(tftproot) - self.dynfunc = dyn_file_func + self.dyn_file_func = dyn_file_func # A dict of handlers, where each session is keyed by a string like # ip:tid for the remote end. self.handlers = {} if os.path.exists(self.root): - logger.debug("tftproot %s does exist" % self.root) + log.debug("tftproot %s does exist" % self.root) if not os.path.isdir(self.root): raise TftpException, "The tftproot must be a directory." else: - logger.debug("tftproot %s is a directory" % self.root) + log.debug("tftproot %s is a directory" % self.root) if os.access(self.root, os.R_OK): - logger.debug("tftproot %s is readable" % self.root) + log.debug("tftproot %s is readable" % self.root) else: raise TftpException, "The tftproot must be readable" if os.access(self.root, os.W_OK): - logger.debug("tftproot %s is writable" % self.root) + log.debug("tftproot %s is writable" % self.root) else: - logger.warning("The tftproot %s is not writable" % self.root) + log.warning("The tftproot %s is not writable" % self.root) else: raise TftpException, "The tftproot does not exist." @@ -45,14 +47,12 @@ class TftpServer(TftpSession): """Start a server listening on the supplied interface and port. This defaults to INADDR_ANY (all interfaces) and UDP port 69. You can also supply a different socket timeout value, if desired.""" - import select - tftp_factory = TftpPacketFactory() # Don't use new 2.5 ternary operator yet # listenip = listenip if listenip else '0.0.0.0' if not listenip: listenip = '0.0.0.0' - logger.info("Server requested on ip %s, port %s" + log.info("Server requested on ip %s, port %s" % (listenip, listenport)) try: # FIXME - sockets should be non-blocking? @@ -62,388 +62,82 @@ class TftpServer(TftpSession): # Reraise it for now. raise - logger.info("Starting receive loop...") + log.info("Starting receive loop...") while True: # Build the inputlist array of sockets to select() on. inputlist = [] inputlist.append(self.sock) - for key in self.handlers: - inputlist.append(self.handlers[key].sock) + for key in self.sessions: + inputlist.append(self.sessions[key].sock) # Block until some socket has input on it. - logger.debug("Performing select on this inputlist: %s" % inputlist) + log.debug("Performing select on this inputlist: %s" % inputlist) readyinput, readyoutput, readyspecial = select.select(inputlist, [], [], SOCK_TIMEOUT) - #(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE) - #recvpkt = tftp_factory.parse(buffer) - #key = "%s:%s" % (raddress, rport) - deletion_list = [] + # Handle the available data, if any. Maybe we timed-out. for readysock in readyinput: + # Is the traffic on the main server socket? ie. new session? if readysock == self.sock: - logger.debug("Data ready on our main socket") + log.debug("Data ready on our main socket") buffer, (raddress, rport) = self.sock.recvfrom(MAX_BLKSIZE) - logger.debug("Read %d bytes" % len(buffer)) - recvpkt = tftp_factory.parse(buffer) - key = "%s:%s" % (raddress, rport) - if isinstance(recvpkt, TftpPacketRRQ): - logger.debug("RRQ packet from %s:%s" % (raddress, rport)) - if not self.handlers.has_key(key): - try: - logger.debug("New download request, session key = %s" - % key) - self.handlers[key] = TftpServerHandler(key, - 'rrq', - self.root, - listenip, - tftp_factory, - self.dynfunc) - self.handlers[key].handle((recvpkt, raddress, rport)) - except TftpException, err: - logger.error("Fatal exception thrown from handler: %s" - % str(err)) - logger.debug("Deleting handler: %s" % key) - deletion_list.append(key) + log.debug("Read %d bytes" % len(buffer)) - else: - logger.warn("Received RRQ for existing session!") - self.senderror(self.sock, - TftpErrors.IllegalTftpOp, - raddress, - rport) - continue + recvpkt = tftp_factory.parse(buffer) + # FIXME: Is this the best way to do a session key? What + # about symmetric udp? + key = "%s:%s" % (raddress, rport) - elif isinstance(recvpkt, TftpPacketWRQ): - logger.error("Write requests not implemented at this time.") - self.senderror(self.sock, - TftpErrors.IllegalTftpOp, - raddress, - rport) - continue + if not self.sessions.has_key(key): + log.debug("Creating new server context for " + "session key = %s" % key) + self.sessions[key] = TftpContextServer(raddress, + rport, + timeout, + self.root, + self.dyn_file_func) + self.sessions[key].start(buffer) else: - # FIXME - this will have to change if we do symmetric UDP - logger.error("Should only receive RRQ or WRQ packets " - "on main listen port. Received %s" % recvpkt) - self.senderror(self.sock, - TftpErrors.IllegalTftpOp, - raddress, - rport) - continue + log.warn("received traffic on main socket for " + "existing session??") else: - for key in self.handlers: - if readysock == self.handlers[key].sock: - # FIXME - violating DRY principle with above code + # Must find the owner of this traffic. + for key in self.session: + if readysock == self.session[key].sock: try: - self.handlers[key].handle() + self.session[key].cycle() + if self.session[key].state == None: + log.info("Successful transfer.") + deletion_list.append(key) break except TftpException, err: deletion_list.append(key) - if self.handlers[key].state.state == 'fin': - logger.info("Successful transfer.") - break - else: - logger.error("Fatal exception thrown from handler: %s" - % str(err)) + log.error("Fatal exception thrown from " + "handler: %s" % str(err)) else: - logger.error("Can't find the owner for this packet. Discarding.") + log.error("Can't find the owner for this packet. " + "Discarding.") - logger.debug("Looping on all handlers to check for timeouts") + log.debug("Looping on all handlers to check for timeouts") now = time.time() - for key in self.handlers: + for key in self.sessions: try: - self.handlers[key].check_timeout(now) + self.sessions[key].checkTimeout(now) except TftpException, err: - logger.error("Fatal exception thrown from handler: %s" + log.error("Fatal exception thrown from handler: %s" % str(err)) deletion_list.append(key) - logger.debug("Iterating deletion list.") + log.debug("Iterating deletion list.") for key in deletion_list: - if self.handlers.has_key(key): - logger.debug("Deleting handler %s" % key) - del self.handlers[key] + if self.sessions.has_key(key): + log.debug("Deleting handler %s" % key) + del self.sessions[key] deletion_list = [] - -class TftpServerHandler(TftpSession): - """This class implements a handler for a given server session, handling - the work for one download.""" - - def __init__(self, key, state, root, listenip, factory, dyn_file_func): - TftpSession.__init__(self) - logger.info("Starting new handler. Key %s." % key) - self.key = key - self.host, self.port = self.key.split(':') - self.port = int(self.port) - self.listenip = listenip - # Note, correct state here is important as it tells the handler whether it's - # handling a download or an upload. - self.state = state - self.root = root - self.mode = None - self.filename = None - self.sock = False - self.options = { 'blksize': DEF_BLKSIZE } - self.blocknumber = 0 - self.buffer = None - self.fileobj = None - self.timesent = 0 - self.timeouts = 0 - self.tftp_factory = factory - self.dynfunc = dyn_file_func - count = 0 - while not self.sock: - self.sock = self.gensock(listenip) - count += 1 - if count > 10: - raise TftpException, "Failed to bind this handler to any port" - - def check_timeout(self, now): - """This method checks to see if we've timed-out waiting for traffic - from the client.""" - if self.timesent: - if now - self.timesent > SOCK_TIMEOUT: - self.timeout() - - def timeout(self): - """This method handles a timeout condition.""" - logger.debug("Handling timeout for handler %s" % self.key) - self.timeouts += 1 - if self.timeouts > TIMEOUT_RETRIES: - raise TftpException, "Hit max retries, giving up." - - if self.state.state == 'dat' or self.state.state == 'fin': - logger.debug("Timing out on DAT. Need to resend.") - self.send_dat(resend=True) - elif self.state.state == 'oack': - logger.debug("Timing out on OACK. Need to resend.") - self.send_oack() - else: - tftpassert(False, - "Timing out in unsupported state %s" % - self.state.state) - - def gensock(self, listenip): - """This method generates a new UDP socket, whose listening port must - be randomly generated, and not conflict with any already in use. For - now, let the OS do this.""" - random.seed() - port = random.randrange(1025, 65536) - # FIXME - sockets should be non-blocking? - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - logger.debug("Trying a handler socket on port %d" % port) - try: - sock.bind((listenip, port)) - return sock - except socket.error, err: - if err[0] == 98: - logger.warn("Handler %s, port %d was already taken" % (self.key, port)) - return False - else: - raise - - def handle(self, pkttuple=None): - """This method informs a handler instance that it has data waiting on - its socket that it must read and process.""" - recvpkt = raddress = rport = None - if pkttuple: - logger.debug("Handed pkt %s for handler %s" % (recvpkt, self.key)) - recvpkt, raddress, rport = pkttuple - else: - logger.debug("Data ready for handler %s" % self.key) - buffer, (raddress, rport) = self.sock.recvfrom(MAX_BLKSIZE) - logger.debug("Read %d bytes" % len(buffer)) - recvpkt = self.tftp_factory.parse(buffer) - - # FIXME - refactor into another method, this is too big - if isinstance(recvpkt, TftpPacketRRQ): - logger.debug("Handler %s received RRQ packet" % self.key) - logger.debug("Requested file is %s, mode is %s" % (recvpkt.filename, - recvpkt.mode)) - # FIXME - only octet mode is supported at this time. - if recvpkt.mode != 'octet': - self.senderror(self.sock, - TftpErrors.IllegalTftpOp, - raddress, - rport) - raise TftpException, "Unsupported mode: %s" % recvpkt.mode - - # test host/port of client end - if self.host != raddress or self.port != rport: - self.senderror(self.sock, - TftpErrors.UnknownTID, - raddress, - rport) - logger.error("Expected traffic from %s:%s but received it " - "from %s:%s instead." - % (self.host, self.port, raddress, rport)) - self.errors += 1 - return - - if self.state.state == 'rrq': - logger.debug("Received RRQ. Composing response.") - self.filename = self.root + os.sep + recvpkt.filename - logger.debug("The path to the desired file is %s" % - self.filename) - self.filename = os.path.abspath(self.filename) - logger.debug("The absolute path is %s" % self.filename) - # Security check. Make sure it's prefixed by the tftproot. - if self.filename.find(self.root) == 0: - logger.debug("The path appears to be safe: %s" % - self.filename) - else: - logger.error("Insecure path: %s" % self.filename) - self.errors += 1 - self.senderror(self.sock, - TftpErrors.AccessViolation, - raddress, - rport) - raise TftpException, "Insecure path: %s" % self.filename - - # Does the file exist? - if(os.path.exists(self.filename) or not self.dynfunc is None): - logger.debug("File %s exists." % self.filename) - - # Check options. Currently we only support the blksize - # option. - if recvpkt.options.has_key('blksize'): - logger.debug("RRQ includes a blksize option") - blksize = int(recvpkt.options['blksize']) - # Delete the option now that it's handled. - del recvpkt.options['blksize'] - if blksize >= MIN_BLKSIZE and blksize <= MAX_BLKSIZE: - logger.info("Client requested blksize = %d" - % blksize) - self.options['blksize'] = blksize - else: - logger.warning("Client %s requested invalid " - "blocksize %d, responding with default" - % (self.key, blksize)) - self.options['blksize'] = DEF_BLKSIZE - - if recvpkt.options.has_key('tsize'): - logger.info('RRQ includes tsize option') - self.options['tsize'] = os.stat(self.filename).st_size - # Delete the option now that it's handled. - del recvpkt.options['tsize'] - - if len(recvpkt.options.keys()) > 0: - logger.warning("Client %s requested unsupported options: %s" - % (self.key, recvpkt.options)) - - if self.options: - logger.info("Options requested, sending OACK") - self.send_oack() - else: - logger.debug("Client %s requested no options." - % self.key) - self.start_download() - - else: - logger.error("Requested file %s does not exist." % - self.filename) - self.senderror(self.sock, - TftpErrors.FileNotFound, - raddress, - rport) - raise TftpException, "Requested file not found: %s" % self.filename - - else: - # We're receiving an RRQ when we're not expecting one. - logger.error("Received an RRQ in handler %s " - "but we're in state %s" % (self.key, self.state)) - self.errors += 1 - - # Next packet type - elif isinstance(recvpkt, TftpPacketACK): - logger.debug("Received an ACK from the client.") - if recvpkt.blocknumber == 0 and self.state.state == 'oack': - logger.debug("Received ACK with 0 blocknumber, starting download") - self.start_download() - else: - if self.state.state == 'dat' or self.state.state == 'fin': - if self.blocknumber == recvpkt.blocknumber: - logger.debug("Received ACK for block %d" - % recvpkt.blocknumber) - if self.state.state == 'fin': - raise TftpException, "Successful transfer." - else: - self.send_dat() - elif recvpkt.blocknumber < self.blocknumber: - # Don't resend a DAT due to an old ACK. Fixes the - # sorceror's apprentice problem. - logger.warn("Received old ACK for block number %d" - % recvpkt.blocknumber) - else: - logger.warn("Received ACK for block number " - "%d, apparently from the future" - % recvpkt.blocknumber) - else: - logger.error("Received ACK with block number %d " - "while in state %s" - % (recvpkt.blocknumber, - self.state.state)) - - elif isinstance(recvpkt, TftpPacketERR): - logger.error("Received error packet from client: %s" % recvpkt) - self.state.state = 'err' - raise TftpException, "Received error from client" - - # Handle other packet types. - else: - logger.error("Received packet %s while handling a download" - % recvpkt) - self.senderror(self.sock, - TftpErrors.IllegalTftpOp, - self.host, - self.port) - raise TftpException, "Invalid packet received during download" - - def start_download(self): - """This method opens self.filename, stores the resulting file object - in self.fileobj, and calls send_dat().""" - self.state.state = 'dat' - if os.path.exists(self.filename): - self.fileobj = open(self.filename, "rb") - else: - self.fileobj = self.dynfunc(self.filename) - self.send_dat() - - def send_dat(self, resend=False): - """This method reads sends a DAT packet based on what is in self.buffer.""" - if not resend: - blksize = int(self.options['blksize']) - self.buffer = self.fileobj.read(blksize) - logger.debug("Read %d bytes into buffer" % len(self.buffer)) - if len(self.buffer) < blksize: - logger.info("Reached EOF on file %s" % self.filename) - self.state.state = 'fin' - self.blocknumber += 1 - if self.blocknumber > 65535: - logger.debug("Blocknumber rolled over to zero") - self.blocknumber = 0 - else: - logger.warn("Resending block number %d" % self.blocknumber) - dat = TftpPacketDAT() - dat.data = self.buffer - dat.blocknumber = self.blocknumber - logger.debug("Sending DAT packet %d" % self.blocknumber) - self.sock.sendto(dat.encode().buffer, (self.host, self.port)) - self.timesent = time.time() - - # FIXME - should these be factored-out into the session class? - def send_oack(self): - """This method sends an OACK packet based on current params.""" - logger.debug("Composing and sending OACK packet") - oack = TftpPacketOACK() - oack.options = self.options - self.sock.sendto(oack.encode().buffer, - (self.host, self.port)) - self.timesent = time.time() - self.state.state = 'oack' diff --git a/tftpy/TftpShared.py b/tftpy/TftpShared.py index 95172c3..bb95ad4 100644 --- a/tftpy/TftpShared.py +++ b/tftpy/TftpShared.py @@ -17,7 +17,7 @@ DEF_TFTP_PORT = 69 logging.basicConfig() # The logger used by this library. Feel free to clobber it with your own, if you like, as # long as it conforms to Python's logging. -logger = logging.getLogger('tftpy') +log = logging.getLogger('tftpy') def tftpassert(condition, msg): """This function is a simple utility that will check the condition @@ -31,8 +31,8 @@ def setLogLevel(level): """This function is a utility function for setting the internal log level. The log level defaults to logging.NOTSET, so unwanted output to stdout is not created.""" - global logger - logger.setLevel(level) + global log + log.setLevel(level) class TftpErrors(object): """This class is a convenience for defining the common tftp error codes, diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py index 88c4fa1..3d71e16 100644 --- a/tftpy/TftpStates.py +++ b/tftpy/TftpStates.py @@ -22,17 +22,28 @@ class TftpMetrics(object): # Rates self.bps = 0 self.kbps = 0 + # Generic errors + self.errors = 0 def compute(self): # Compute transfer time self.duration = self.end_time - self.start_time - logger.debug("TftpMetrics.compute: duration is %s" % self.duration) + log.debug("TftpMetrics.compute: duration is %s" % self.duration) self.bps = (self.bytes * 8.0) / self.duration self.kbps = self.bps / 1024.0 - logger.debug("TftpMetrics.compute: kbps is %s" % self.kbps) - dupcount = 0 + log.debug("TftpMetrics.compute: kbps is %s" % self.kbps) for key in self.dups: - dupcount += self.dups[key] + self.dupcount += self.dups[key] + + def add_dup(self, blocknumber): + """This method adds a dup for a block number to the metrics.""" + log.debug("Recording a dup for block %d" % blocknumber) + if self.dups.has_key(blocknumber): + self.dups[pkt.blocknumber] += 1 + else: + self.dups[pkt.blocknumber] = 1 + tftpassert(self.dups[pkt.blocknumber] < MAX_DUPS, + "Max duplicates for block %d reached" % blocknumber) ############################################################################### # Context classes @@ -40,16 +51,32 @@ class TftpMetrics(object): class TftpContext(object): """The base class of the contexts.""" - def __init__(self, host, port): + + def __init__(self, host, port, timeout): """Constructor for the base context, setting shared instance variables.""" + self.file_to_transfer = None + self.fileobj = None + self.options = None + self.packethook = None + self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.sock.settimeout(timeout) + self.state = None + self.next_block = 0 self.factory = TftpPacketFactory() + # Note, setting the host will also set self.address, as it's a property. self.host = host self.port = port # The port associated with the TID self.tidport = None # Metrics self.metrics = TftpMetrics() + # Flag when the transfer is pending completion. + self.pending_complete = False + + def checkTimeout(self, now): + # FIXME + pass def start(self): return NotImplementedError, "Abstract method" @@ -69,37 +96,9 @@ class TftpContext(object): host = property(gethost, sethost) - def sendAck(self, blocknumber): - """This method sends an ack packet to the block number specified.""" - logger.info("sending ack to block %d" % blocknumber) - ackpkt = TftpPacketACK() - ackpkt.blocknumber = blocknumber - self.sock.sendto(ackpkt.encode().buffer, (self.host, self.tidport)) - - def sendError(self, errorcode): - """This method uses the socket passed, and uses the errorcode to - compose and send an error packet.""" - logger.debug("In sendError, being asked to send error %d" % errorcode) - errpkt = TftpPacketERR() - errpkt.errorcode = errorcode - self.sock.sendto(errpkt.encode().buffer, (self.host, self.tidport)) - -class TftpContextClient(TftpContext): - """This class represents shared functionality by both the download and - upload client contexts.""" - def __init__(self, host, port, filename, options, packethook, timeout): - TftpContext.__init__(self, host, port) - self.file_to_transfer = filename - self.options = options - self.packethook = packethook - self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.sock.settimeout(timeout) - self.state = None - self.next_block = 0 - def setNextBlock(self, block): if block > 2 ** 16: - logger.debug("block number rollover to 0 again") + log.debug("block number rollover to 0 again") block = 0 self.__eblock = block @@ -111,19 +110,21 @@ class TftpContextClient(TftpContext): def cycle(self): """Here we wait for a response from the server after sending it something, and dispatch appropriate action to that response.""" + # FIXME: This won't work very well in a server context with multiple + # sessions running. for i in range(TIMEOUT_RETRIES): - logger.debug("in cycle, receive attempt %d" % i) + log.debug("in cycle, receive attempt %d" % i) try: (buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE) except socket.timeout, err: - logger.warn("Timeout waiting for traffic, retrying...") + log.warn("Timeout waiting for traffic, retrying...") continue break else: raise TftpException, "Hit max timeouts, giving up." # Ok, we've received a packet. Log it. - logger.debug("Received %d bytes from %s:%s" + log.debug("Received %d bytes from %s:%s" % (len(buffer), raddress, rport)) # Decode it. @@ -131,11 +132,11 @@ class TftpContextClient(TftpContext): # Check for known "connection". if raddress != self.address: - logger.warn("Received traffic from %s, expected host %s. Discarding" + log.warn("Received traffic from %s, expected host %s. Discarding" % (raddress, self.host)) if self.tidport and self.tidport != rport: - logger.warn("Received traffic from %s:%s but we're " + log.warn("Received traffic from %s:%s but we're " "connected to %s:%s. Discarding." % (raddress, rport, self.host, self.tidport)) @@ -150,29 +151,66 @@ class TftpContextClient(TftpContext): # And handle it, possibly changing state. self.state = self.state.handle(recvpkt, raddress, rport) -class TftpContextClientUpload(TftpContextClient): +class TftpContextServer(TftpContext): + """The context for the server.""" + def __init__(self, host, port, timeout, root, dyn_file_func): + TftpContext.__init__(self, + host, + port, + timeout) + # At this point we have no idea if this is a download or an upload. We + # need to let the start state determine that. + self.state = TftpStateServerStart() + self.root = root + self.dyn_file_func = dyn_file_func + + def start(self, buffer): + """Start the state cycle. Note that the server context receives an + initial packet in its start method.""" + log.debug("TftpContextServer.start() - pkt = %s" % pkt) + + self.metrics.start_time = time.time() + log.debug("set metrics.start_time to %s" % self.metrics.start_time) + + pkt = self.factory.parse(buffer) + log.debug("TftpContextServer.start() - factory returned a %s" % pkt) + + # Call handle once with the initial packet. This should put us into + # the download or the upload state. + self.state = self.state.handle(pkt, + self.host, + self.port) + + try: + while self.state: + log.debug("state is %s" % self.state) + self.cycle() + finally: + self.fileobj.close() + +class TftpContextClientUpload(TftpContext): """The upload context for the client during an upload.""" def __init__(self, host, port, filename, input, options, packethook, timeout): - TftpContextClient.__init__(self, - host, - port, - filename, - options, - packethook, - timeout) + TftpContext.__init__(self, + host, + port, + timeout) + self.file_to_transfer = filename + self.options = options + self.packethook = packethook self.fileobj = open(input, "wb") - logger.debug("TftpContextClientUpload.__init__()") - logger.debug("file_to_transfer = %s, options = %s" % + log.debug("TftpContextClientUpload.__init__()") + log.debug("file_to_transfer = %s, options = %s" % (self.file_to_transfer, self.options)) def start(self): - logger.info("sending tftp upload request to %s" % self.host) - logger.info(" filename -> %s" % self.file_to_transfer) - logger.info(" options -> %s" % self.options) + log.info("sending tftp upload request to %s" % self.host) + log.info(" filename -> %s" % self.file_to_transfer) + log.info(" options -> %s" % self.options) self.metrics.start_time = time.time() - logger.debug("set metrics.start_time to %s" % self.metrics.start_time) + log.debug("set metrics.start_time to %s" % self.metrics.start_time) # FIXME: put this in a sendWRQ method? pkt = TftpPacketWRQ() @@ -186,7 +224,7 @@ class TftpContextClientUpload(TftpContextClient): try: while self.state: - logger.debug("state is %s" % self.state) + log.debug("state is %s" % self.state) self.cycle() finally: self.fileobj.close() @@ -194,32 +232,32 @@ class TftpContextClientUpload(TftpContextClient): def end(self): pass -class TftpContextClientDownload(TftpContextClient): +class TftpContextClientDownload(TftpContext): """The download context for the client during a download.""" def __init__(self, host, port, filename, output, options, packethook, timeout): - TftpContextClient.__init__(self, - host, - port, - filename, - options, - packethook, - timeout) + TftpContext.__init__(self, + host, + port, + filename, + options, + packethook, + timeout) # FIXME - need to support alternate return formats than files? # File-like objects would be ideal, ala duck-typing. self.fileobj = open(output, "wb") - logger.debug("TftpContextClientDownload.__init__()") - logger.debug("file_to_transfer = %s, options = %s" % + log.debug("TftpContextClientDownload.__init__()") + log.debug("file_to_transfer = %s, options = %s" % (self.file_to_transfer, self.options)) def start(self): """Initiate the download.""" - logger.info("sending tftp download request to %s" % self.host) - logger.info(" filename -> %s" % self.file_to_transfer) - logger.info(" options -> %s" % self.options) + log.info("sending tftp download request to %s" % self.host) + log.info(" filename -> %s" % self.file_to_transfer) + log.info(" options -> %s" % self.options) self.metrics.start_time = time.time() - logger.debug("set metrics.start_time to %s" % self.metrics.start_time) + log.debug("set metrics.start_time to %s" % self.metrics.start_time) # FIXME: put this in a sendRRQ method? pkt = TftpPacketRRQ() @@ -233,7 +271,7 @@ class TftpContextClientDownload(TftpContextClient): try: while self.state: - logger.debug("state is %s" % self.state) + log.debug("state is %s" % self.state) self.cycle() finally: self.fileobj.close() @@ -241,7 +279,7 @@ class TftpContextClientDownload(TftpContextClient): def end(self): """Finish up the context.""" self.metrics.end_time = time.time() - logger.debug("set metrics.end_time to %s" % self.metrics.end_time) + log.debug("set metrics.end_time to %s" % self.metrics.end_time) self.metrics.compute() @@ -268,235 +306,431 @@ class TftpState(object): options.""" if pkt.options.keys() > 0: if pkt.match_options(self.context.options): - logger.info("Successful negotiation of options") + log.info("Successful negotiation of options") # Set options to OACK options self.context.options = pkt.options for key in self.context.options: - logger.info(" %s = %s" % (key, self.context.options[key])) + log.info(" %s = %s" % (key, self.context.options[key])) else: - logger.error("failed to negotiate options") + log.error("failed to negotiate options") raise TftpException, "Failed to negotiate options" else: raise TftpException, "No options found in OACK" -class TftpStateUpload(TftpState): - """A class holding common code for upload states.""" - def sendDat(self, resend=False): + def returnSupportedOptions(self, options): + """This method takes a requested options list from a client, and + returns the ones that are supported.""" + # We support the options blksize and tsize right now. + # FIXME - put this somewhere else? + accepted_options = {} + for option in options: + if option == 'blksize': + # Make sure it's valid. + if int(options[option]) > MAX_BLKSIZE: + accepted_options[option] = MAX_BLKSIZE + elif option == 'tsize': + log.debug("tsize option is set") + accepted_options['tsize'] = 1 + else: + log.info("Dropping unsupported option '%s'" % option) + return accepted_options + + def serverInitial(self, pkt, raddress, rport): + """This method performs initial setup for a server context transfer, + put here to refactor code out of the TftpStateServerRecvRRQ and + TftpStateServerRecvWRQ classes, since their initial setup is + identical. The method returns a boolean, sendoack, to indicate whether + it is required to send an OACK to the client.""" + options = pkt.options + sendoack = False + if not options: + log.debug("setting default options, blksize") + # FIXME: put default options elsewhere + self.context.options = { 'blksize': DEF_BLKSIZE } + else: + log.debug("options requested: %s" % options) + self.context.options = self.returnSupportedOptions(options) + sendoack = True + + # FIXME - only octet mode is supported at this time. + if pkt.mode != 'octet': + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, \ + "Only octet transfers are supported at this time." + + # test host/port of client end + if self.context.host != raddress or self.context.port != rport: + self.sendError(TftpErrors.UnknownTID) + log.error("Expected traffic from %s:%s but received it " + "from %s:%s instead." + % (self.context.host, + self.context.port, + raddress, + rport)) + # FIXME: increment an error count? + # Return same state, we're still waiting for valid traffic. + return self + + log.debug("requested filename is %s" % pkt.filename) + # There are no os.sep's allowed in the filename. + # FIXME: Should we allow subdirectories? + if pkt.filename.find(os.sep) >= 0: + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "%s found in filename, not permitted" % os.sep + + self.context.file_to_transfer = pkt.filename + + return sendoack + + def sendDAT(self, resend=False): + """This method sends the next DAT packet based on the data in the + context. It returns a boolean indicating whether the transfer is + finished.""" finished = False blocknumber = self.context.next_block if not resend: blksize = int(self.context.options['blksize']) buffer = self.context.fileobj.read(blksize) - logger.debug("Read %d bytes into buffer" % len(buffer)) + log.debug("Read %d bytes into buffer" % len(buffer)) if len(buffer) < blksize: - logger.info("Reached EOF on file %s" % self.context.input) + log.info("Reached EOF on file %s" % self.context.input) finished = True self.context.next_block += 1 self.bytes += len(buffer) else: - logger.warn("Resending block number %d" % blocknumber) + log.warn("Resending block number %d" % blocknumber) dat = TftpPacketDAT() dat.data = buffer dat.blocknumber = blocknumber - logger.debug("Sending DAT packet %d" % blocknumber) + log.debug("Sending DAT packet %d" % blocknumber) self.context.sock.sendto(dat.encode().buffer, (self.context.host, self.context.port)) if self.context.packethook: self.context.packethook(dat) return finished -class TftpStateDownload(TftpState): - """A class holding common code for download states.""" + def sendACK(self, blocknumber=None): + """This method sends an ack packet to the block number specified. If + none is specified, it defaults to the next_block property in the + parent context.""" + if not blocknumber: + blocknumber = self.context.next_block + log.info("sending ack to block %d" % blocknumber) + ackpkt = TftpPacketACK() + ackpkt.blocknumber = blocknumber + self.context.sock.sendto(ackpkt.encode().buffer, + (self.context.host, + self.context.tidport)) + + def sendError(self, errorcode): + """This method uses the socket passed, and uses the errorcode to + compose and send an error packet.""" + log.debug("In sendError, being asked to send error %d" % errorcode) + errpkt = TftpPacketERR() + errpkt.errorcode = errorcode + self.context.sock.sendto(errpkt.encode().buffer, + (self.context.host, + self.context.tidport)) + + def sendOACK(self): + """This method sends an OACK packet with the options from the current + context.""" + log.debug("In sendOACK with options %s" % options) + pkt = TftpPacketOACK() + pkt.options = self.options + self.context.sock.sendto(pkt.encode().buffer, + (self.context.host, + self.context.tidport)) + def handleDat(self, pkt): - """This method handles a DAT packet during a download.""" - logger.info("handling DAT packet - block %d" % pkt.blocknumber) - logger.debug("expecting block %s" % self.context.next_block) + """This method handles a DAT packet during a client download, or a + server upload.""" + log.info("handling DAT packet - block %d" % pkt.blocknumber) + log.debug("expecting block %s" % self.context.next_block) if pkt.blocknumber == self.context.next_block: - logger.debug("good, received block %d in sequence" + log.debug("good, received block %d in sequence" % pkt.blocknumber) - self.context.sendAck(pkt.blocknumber) + self.sendACK() self.context.next_block += 1 - logger.debug("writing %d bytes to output file" + log.debug("writing %d bytes to output file" % len(pkt.data)) self.context.fileobj.write(pkt.data) self.context.metrics.bytes += len(pkt.data) # Check for end-of-file, any less than full data packet. if len(pkt.data) < int(self.context.options['blksize']): - logger.info("end of file detected") + log.info("end of file detected") return None elif pkt.blocknumber < self.context.next_block: - logger.warn("dropping duplicate block %d" % pkt.blocknumber) - if self.context.metrics.dups.has_key(pkt.blocknumber): - self.context.metrics.dups[pkt.blocknumber] += 1 - else: - self.context.metrics.dups[pkt.blocknumber] = 1 - tftpassert(self.context.metrics.dups[pkt.blocknumber] < MAX_DUPS, - "Max duplicates for block %d reached" % pkt.blocknumber) - # FIXME: double-check sorceror's apprentice problem! - logger.debug("ACKing block %d again, just in case" % pkt.blocknumber) - self.context.sendAck(pkt.blocknumber) + log.warn("dropping duplicate block %d" % pkt.blocknumber) + self.context.metrics.add_dup(pkt.blocknumber) + log.debug("ACKing block %d again, just in case" % pkt.blocknumber) + self.sendACK(pkt.blocknumber) else: # FIXME: should we be more tolerant and just discard instead? msg = "Whoa! Received future block %d but expected %d" \ % (pkt.blocknumber, self.context.next_block) - logger.error(msg) + log.error(msg) raise TftpException, msg # Default is to ack - return TftpStateSentACK(self.context) + return TftpStateExpectDAT(self.context) + +class TftpStateServerRecvRRQ(TftpState): + """This class represents the state of the TFTP server when it has just + received an RRQ packet.""" + def handle(self, pkt, raddress, rport): + "Handle an initial RRQ packet as a server." + log.debug("In TftpStateServerRecvRRQ.handle") + sendoack = self.serverInitial(pkt, raddress, rport) + path = self.context.root + os.sep + self.context.file_to_transfer + log.info("Opening file %s for reading" % path) + if os.path.exists(path): + # Note: Open in binary mode for win32 portability, since win32 + # blows. + self.context.fileobj = open(path, "rb") + elif self.dyn_file_func: + log.debug("No such file %s but using dyn_file_func" % path) + self.context.fileobj = \ + self.dyn_file_func(self.context.file_to_transfer) + else: + send.sendError(TftpErrors.FileNotFound) + raise TftpException, "File not found: %s" % path + + # Options negotiation. + if sendoack: + self.sendOACK() + return TftpStateServerOACK(self.context) + else: + log.debug("No requested options, starting send...") + self.context.pending_complete = self.sendDAT() + return TftpStateExpectACK(self.context) + + # Note, we don't have to check any other states in this method, that's + # up to the caller. + +class TftpStateServerRecvWRQ(TftpState): + """This class represents the state of the TFTP server when it has just + received a WRQ packet.""" + def handle(self, pkt, raddress, rport): + "Handle an initial WRQ packet as a server." + log.debug("In TftpStateServerRecvWRQ.handle") + sendoack = self.serverInitial(pkt, raddress, rport) + path = self.context.root + os.sep + self.context.file_to_transfer + log.info("Opening file %s for writing" % path) + if os.path.exists(path): + # FIXME: correct behavior? + log.warn("File %s exists already, overwriting...") + # FIXME: I think we should upload to a temp file and not overwrite the + # existing file until the file is successfully uploaded. + self.context.fileobj = open(path, "wb") + + # Options negotiation. + if sendoack: + log.debug("Sending OACK to client") + self.sendOACK() + else: + log.debug("No requested options, starting send...") + self.sendACK() + # We may have sent an OACK, but we're expecting a DAT as the response + # to either the OACK or an ACK, so lets unconditionally use the + # TftpStateExpectDAT state. + return TftpStateExpectDAT(self.context) + + # Note, we don't have to check any other states in this method, that's + # up to the caller. + +class TftpStateServerStart(TftpState): + """The start state for the server.""" + def handle(self, pkt, raddress, rport): + """Handle a packet we just received.""" + log.debug("In TftpStateServerStart.handle") + if isinstance(pkt, TftpPacketRRQ): + log.debug("handling an RRQ packet") + return TftpStateServerRecvRRQ(self.context).handle(pkt, + raddress, + rport) + elif isinstance(pkt, TftpPacketWRQ): + log.debug("handling a WRQ packet") + return TftpStateServerRecvWRQ(self.context).handle(pkt, + raddress, + rport) + else: + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, \ + "Invalid packet to begin up/download: %s" % pkt + +class TftpStateExpectACK(TftpState): + """This class represents the state of the transfer when a DAT was just + sent, and we are waiting for an ACK from the server. This class is the + same one used by the client during the upload, and the server during the + download.""" + def handle(self, pkt, raddress, rport): + "Handle a packet, hopefully an ACK since we just sent a DAT." + if isinstance(pkt, TftpPacketACK): + log.info("Received ACK for packet %d" % pkt.blocknumber) + # Is this an ack to the one we just sent? + if self.context.next_block == pkt.blocknumber: + if self.context.pending_complete: + log.info("Received ACK to final DAT, we're done.") + return None + else: + log.debug("Good ACK, sending next DAT") + self.context.pending_complete = self.sendDAT() -class TftpStateSentWRQ(TftpStateUpload): + elif pkt.blocknumber < self.context.next_block: + self.context.metrics.add_dup(pkt.blocknumber) + + else: + log.warn("Oooh, time warp. Received ACK to packet we " + "didn't send yet. Discarding.") + self.context.metrics.errors += 1 + return self + elif isinstance(pkt, TftpPacketERR): + log.error("Received ERR packet from peer: %s" % str(pkt)) + raise TftpException, \ + "Received ERR packet from peer: %s" % str(pkt) + else: + log.warn("Discarding unsupported packet: %s" % str(pkt)) + return self + +class TftpStateExpectDAT(TftpState): + """Just sent an ACK packet. Waiting for DAT.""" + def handle(self, pkt, raddress, rport): + """Handle the packet in response to an ACK, which should be a DAT.""" + if isinstance(pkt, TftpPacketDAT): + return self.handleDat(pkt) + + # Every other packet type is a problem. + elif isinstance(recvpkt, TftpPacketACK): + # Umm, we ACK, you don't. + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "Received ACK from peer when expecting DAT" + + elif isinstance(recvpkt, TftpPacketWRQ): + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "Received WRQ from peer when expecting DAT" + + elif isinstance(recvpkt, TftpPacketERR): + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "Received ERR from peer: " + str(recvpkt) + + else: + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "Received unknown packet type from peer: " + str(recvpkt) + +class TftpStateSentWRQ(TftpState): """Just sent an WRQ packet for an upload.""" def handle(self, pkt, raddress, rport): """Handle a packet we just received.""" if not self.context.tidport: self.context.tidport = rport - logger.debug("Set remote port for session to %s" % rport) + log.debug("Set remote port for session to %s" % rport) # If we're going to successfully transfer the file, then we should see # either an OACK for accepted options, or an ACK to ignore options. if isinstance(pkt, TftpPacketOACK): - logger.info("received OACK from server") + log.info("received OACK from server") try: self.handleOACK(pkt) except TftpException, err: - logger.error("failed to negotiate options") - self.context.sendError(TftpErrors.FailedNegotiation) + log.error("failed to negotiate options") + self.sendError(TftpErrors.FailedNegotiation) raise else: - logger.debug("sending first DAT packet") - fin = self.context.sendDat() - if fin: - logger.info("Add done") - return None - else: - logger.debug("Changing state to TftpStateSentDAT") - return TftpStateSentDAT(self.context) + log.debug("sending first DAT packet") + self.context.pending_complete = self.sendDAT() + log.debug("Changing state to TftpStateExpectACK") + return TftpStateExpectACK(self.context) elif isinstance(pkt, TftpPacketACK): - logger.info("received ACK from server") - logger.debug("apparently the server ignored our options") + log.info("received ACK from server") + log.debug("apparently the server ignored our options") # The block number should be zero. if pkt.blocknumber == 0: - logger.debug("ack blocknumber is zero as expected") - logger.debug("sending first DAT packet") - fin = self.context.sendDat() - if fin: - logger.info("Add done") - return None - else: - logger.debug("Changing state to TftpStateSentDAT") - return TftpStateSentDAT(self.context) + log.debug("ack blocknumber is zero as expected") + log.debug("sending first DAT packet") + self.pending_complete = self.context.sendDAT() + log.debug("Changing state to TftpStateExpectACK") + return TftpStateExpectACK(self.context) else: - logger.warn("discarding ACK to block %s" % pkt.blocknumber) - logger.debug("still waiting for valid response from server") + log.warn("discarding ACK to block %s" % pkt.blocknumber) + log.debug("still waiting for valid response from server") return self elif isinstance(pkt, TftpPacketERR): - self.context.sendError(TftpErrors.IllegalTftpOp) + self.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received ERR from server: " + str(pkt) elif isinstance(pkt, TftpPacketRRQ): - self.context.sendError(TftpErrors.IllegalTftpOp) + self.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received RRQ from server while in upload" elif isinstance(pkt, TftpPacketDAT): - self.context.sendError(TftpErrors.IllegalTftpOp) + self.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received DAT from server while in upload" else: - self.context.sendError(TftpErrors.IllegalTftpOp) + self.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received unknown packet type from server: " + str(pkt) # By default, no state change. return self -class TftpStateSentDAT(TftpStateUpload): - """This class represents the state of the transfer when a DAT was just - sent, and we are waiting for an ACK from the server. This class is the - same one used by the client during the upload, and the server during the - download.""" - -class TftpStateSentRRQ(TftpStateDownload): +class TftpStateSentRRQ(TftpState): """Just sent an RRQ packet.""" def handle(self, pkt, raddress, rport): """Handle the packet in response to an RRQ to the server.""" if not self.context.tidport: self.context.tidport = rport - logger.debug("Set remote port for session to %s" % rport) + log.debug("Set remote port for session to %s" % rport) # Now check the packet type and dispatch it properly. if isinstance(pkt, TftpPacketOACK): - logger.info("received OACK from server") + log.info("received OACK from server") try: self.handleOACK(pkt) except TftpException, err: - logger.error("failed to negotiate options: %s" % str(err)) - self.context.sendError(TftpErrors.FailedNegotiation) + log.error("failed to negotiate options: %s" % str(err)) + self.sendError(TftpErrors.FailedNegotiation) raise else: - logger.debug("sending ACK to OACK") + log.debug("sending ACK to OACK") - self.context.sendAck(blocknumber=0) + self.sendACK(blocknumber=0) - logger.debug("Changing state to TftpStateSentACK") - return TftpStateSentACK(self.context) + log.debug("Changing state to TftpStateExpectDAT") + return TftpStateExpectDAT(self.context) elif isinstance(pkt, TftpPacketDAT): # If there are any options set, then the server didn't honour any # of them. - logger.info("received DAT from server") + log.info("received DAT from server") if self.context.options: - logger.info("server ignored options, falling back to defaults") + log.info("server ignored options, falling back to defaults") self.context.options = { 'blksize': DEF_BLKSIZE } return self.handleDat(pkt) # Every other packet type is a problem. elif isinstance(recvpkt, TftpPacketACK): # Umm, we ACK, the server doesn't. - self.context.sendError(TftpErrors.IllegalTftpOp) + self.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received ACK from server while in download" elif isinstance(recvpkt, TftpPacketWRQ): - self.context.sendError(TftpErrors.IllegalTftpOp) + self.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received WRQ from server while in download" elif isinstance(recvpkt, TftpPacketERR): - self.context.sendError(TftpErrors.IllegalTftpOp) + self.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received ERR from server: " + str(recvpkt) else: - self.context.sendError(TftpErrors.IllegalTftpOp) + self.sendError(TftpErrors.IllegalTftpOp) raise TftpException, "Received unknown packet type from server: " + str(recvpkt) # By default, no state change. return self - -class TftpStateSentACK(TftpStateDownload): - """Just sent an ACK packet. Waiting for DAT.""" - def handle(self, pkt, raddress, rport): - """Handle the packet in response to an ACK, which should be a DAT.""" - if isinstance(pkt, TftpPacketDAT): - return self.handleDat(pkt) - - # Every other packet type is a problem. - elif isinstance(recvpkt, TftpPacketACK): - # Umm, we ACK, the server doesn't. - self.context.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received ACK from server while in download" - - elif isinstance(recvpkt, TftpPacketWRQ): - self.context.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received WRQ from server while in download" - - elif isinstance(recvpkt, TftpPacketERR): - self.context.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received ERR from server: " + str(recvpkt) - - else: - self.context.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "Received unknown packet type from server: " + str(recvpkt) |