diff options
author | Michael P. Soulier <msoulier@digitaltorque.ca> | 2009-04-08 23:29:43 -0400 |
---|---|---|
committer | Michael P. Soulier <msoulier@digitaltorque.ca> | 2009-04-10 23:07:49 -0400 |
commit | e7a63bbbc2752e79b3c6891951b73bb0ebccbb45 (patch) | |
tree | 5c5d33d83a2329b70218440a0120cf3e857187fe /tftpy/TftpClient.py | |
parent | 41bf3a25e615edc2192c2639be7f4a713e48c5ef (diff) | |
download | tftpy-e7a63bbbc2752e79b3c6891951b73bb0ebccbb45.tar.gz |
Started overhaul of state machine.
Diffstat (limited to 'tftpy/TftpClient.py')
-rw-r--r-- | tftpy/TftpClient.py | 208 |
1 files changed, 24 insertions, 184 deletions
diff --git a/tftpy/TftpClient.py b/tftpy/TftpClient.py index a840689..5947367 100644 --- a/tftpy/TftpClient.py +++ b/tftpy/TftpClient.py @@ -1,6 +1,7 @@ -import socket, time, types +import time, types from TftpShared import * from TftpPacketFactory import * +from TftpStates import TftpContextClientDownload class TftpClient(TftpSession): """This class is an implementation of a tftp client. Once instantiated, a @@ -9,6 +10,7 @@ class TftpClient(TftpSession): """This constructor returns an instance of TftpClient, taking the remote host, the remote port, and the filename to fetch.""" TftpSession.__init__(self) + self.context = None self.host = host self.iport = port self.filename = None @@ -51,192 +53,30 @@ class TftpClient(TftpSession): object. The timeout parameter may be used to override the default SOCK_TIMEOUT setting, which is the amount of time that the client will wait for a receive packet to arrive.""" - # Open the output file. - # FIXME - need to support alternate return formats than files? - # File-like objects would be ideal, ala duck-typing. - self.fileobj = open(output, "wb") - recvpkt = None - curblock = 0 - dups = {} - start_time = time.time() - self.bytes = 0 - - self.filename = filename - - tftp_factory = TftpPacketFactory() - self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.sock.settimeout(timeout) - - logger.info("Sending tftp download request to %s" % self.host) - logger.info(" filename -> %s" % filename) - pkt = TftpPacketRRQ() - pkt.filename = filename - pkt.mode = "octet" # FIXME - shouldn't hardcode this - pkt.options = self.options - self.sock.sendto(pkt.encode().buffer, (self.host, self.iport)) - self.state.state = 'rrq' - - timeouts = 0 - while True: - try: - (buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE) - except socket.timeout, err: - timeouts += 1 - if timeouts >= TIMEOUT_RETRIES: - raise TftpException, "Hit max timeouts, giving up." - else: - logger.warn("Timeout waiting for traffic, retrying...") - continue - - recvpkt = tftp_factory.parse(buffer) - - logger.debug("Received %d bytes from %s:%s" - % (len(buffer), raddress, rport)) - - # Check for known "connection". - if raddress != self.address: - logger.warn("Received traffic from %s, expected host %s. Discarding" - % (raddress, self.host)) - continue - if self.port and self.port != rport: - logger.warn("Received traffic from %s:%s but we're " - "connected to %s:%s. Discarding." - % (raddress, rport, - self.host, self.port)) - continue - - # If there is a packethook defined, call it. We unconditionally - # pass all packets, it's up to the client to screen out different - # kinds of packets. This way, the client is privy to things like - # negotiated options. - if packethook: - packethook(recvpkt) - - if not self.port and self.state.state == 'rrq': - self.port = rport - logger.debug("Set remote port for session to %s" % rport) - - if isinstance(recvpkt, TftpPacketDAT): - logger.debug("recvpkt.blocknumber = %d" % recvpkt.blocknumber) - logger.debug("curblock = %d" % curblock) - - if self.state.state == 'rrq' and self.options: - logger.info("no OACK, our options were ignored") - self.options = { 'blksize': DEF_BLKSIZE } - self.state.state = 'ack' - - expected_block = curblock + 1 - if expected_block > 65535: - logger.debug("block number rollover to 0 again") - expected_block = 0 - if recvpkt.blocknumber == expected_block: - logger.debug("good, received block %d in sequence" - % recvpkt.blocknumber) - curblock = expected_block - - - # ACK the packet, and save the data. - logger.info("sending ACK to block %d" % curblock) - logger.debug("ip = %s, port = %s" % (self.host, self.port)) - ackpkt = TftpPacketACK() - ackpkt.blocknumber = curblock - self.sock.sendto(ackpkt.encode().buffer, (self.host, self.port)) - - logger.debug("writing %d bytes to output file" - % len(recvpkt.data)) - self.fileobj.write(recvpkt.data) - self.bytes += len(recvpkt.data) - # Check for end-of-file, any less than full data packet. - if len(recvpkt.data) < int(self.options['blksize']): - logger.info("end of file detected") - break - - elif recvpkt.blocknumber == curblock: - logger.warn("dropping duplicate block %d" % curblock) - if dups.has_key(curblock): - dups[curblock] += 1 - else: - dups[curblock] = 1 - tftpassert(dups[curblock] < MAX_DUPS, - "Max duplicates for block %d reached" % curblock) - logger.debug("ACKing block %d again, just in case" % curblock) - ackpkt = TftpPacketACK() - ackpkt.blocknumber = curblock - self.sock.sendto(ackpkt.encode().buffer, (self.host, self.port)) - - else: - msg = "Whoa! Received block %d but expected %d" % (recvpkt.blocknumber, - curblock+1) - logger.error(msg) - raise TftpException, msg - - # Check other packet types. - elif isinstance(recvpkt, TftpPacketOACK): - if not self.state.state == 'rrq': - self.errors += 1 - logger.error("Received OACK in state %s" % self.state.state) - continue - - self.state.state = 'oack' - logger.info("Received OACK from server.") - if recvpkt.options.keys() > 0: - if recvpkt.match_options(self.options): - logger.info("Successful negotiation of options") - # Set options to OACK options - self.options = recvpkt.options - for key in self.options: - logger.info(" %s = %s" % (key, self.options[key])) - logger.debug("sending ACK to OACK") - ackpkt = TftpPacketACK() - ackpkt.blocknumber = 0 - self.sock.sendto(ackpkt.encode().buffer, (self.host, self.port)) - self.state.state = 'ack' - else: - logger.error("failed to negotiate options") - self.senderror(self.sock, TftpErrors.FailedNegotiation, self.host, self.port) - self.state.state = 'err' - raise TftpException, "Failed to negotiate options" - - elif isinstance(recvpkt, TftpPacketACK): - # Umm, we ACK, the server doesn't. - self.state.state = 'err' - self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port) - tftpassert(False, "Received ACK from server while in download") - - elif isinstance(recvpkt, TftpPacketERR): - self.state.state = 'err' - self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port) - tftpassert(False, "Received ERR from server: " + str(recvpkt)) - - elif isinstance(recvpkt, TftpPacketWRQ): - self.state.state = 'err' - self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port) - tftpassert(False, "Received WRQ from server: " + str(recvpkt)) - - else: - self.state.state = 'err' - self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port) - tftpassert(False, "Received unknown packet type from server: " - + str(recvpkt)) - - - # end while - self.fileobj.close() - - end_time = time.time() - duration = end_time - start_time - if duration == 0: + # We're downloading. + self.context = TftpContextClientDownload(self.host, + self.iport, + filename, + output, + self.options, + packethook, + timeout) + self.context.start() + # Download happens here + self.context.end() + + metrics = self.context.metrics + + # FIXME: Should we output this? Shouldn't we let the client control + # output? This should be in the sample client, but not in the download + # call. + if metrics.duration == 0: logger.info("Duration too short, rate undetermined") else: logger.info('') - logger.info("Downloaded %d bytes in %d seconds" % (self.bytes, duration)) - bps = (self.bytes * 8.0) / duration - kbps = bps / 1024.0 - logger.info("Average rate: %.2f kbps" % kbps) - dupcount = 0 - for key in dups: - dupcount += dups[key] - logger.info("Received %d duplicate packets" % dupcount) + logger.info("Downloaded %d bytes in %d seconds" % (metrics.bytes, metrics.duration)) + logger.info("Average rate: %.2f kbps" % metrics.kbps) + logger.info("Received %d duplicate packets" % metrics.dupcount) def upload(self, filename, input, packethook=None, timeout=SOCK_TIMEOUT): # Open the input file. |