summaryrefslogtreecommitdiff
path: root/tftpy/TftpClient.py
diff options
context:
space:
mode:
authorMichael P. Soulier <msoulier@digitaltorque.ca>2009-04-08 23:29:43 -0400
committerMichael P. Soulier <msoulier@digitaltorque.ca>2009-04-10 23:07:49 -0400
commite7a63bbbc2752e79b3c6891951b73bb0ebccbb45 (patch)
tree5c5d33d83a2329b70218440a0120cf3e857187fe /tftpy/TftpClient.py
parent41bf3a25e615edc2192c2639be7f4a713e48c5ef (diff)
downloadtftpy-e7a63bbbc2752e79b3c6891951b73bb0ebccbb45.tar.gz
Started overhaul of state machine.
Diffstat (limited to 'tftpy/TftpClient.py')
-rw-r--r--tftpy/TftpClient.py208
1 files changed, 24 insertions, 184 deletions
diff --git a/tftpy/TftpClient.py b/tftpy/TftpClient.py
index a840689..5947367 100644
--- a/tftpy/TftpClient.py
+++ b/tftpy/TftpClient.py
@@ -1,6 +1,7 @@
-import socket, time, types
+import time, types
from TftpShared import *
from TftpPacketFactory import *
+from TftpStates import TftpContextClientDownload
class TftpClient(TftpSession):
"""This class is an implementation of a tftp client. Once instantiated, a
@@ -9,6 +10,7 @@ class TftpClient(TftpSession):
"""This constructor returns an instance of TftpClient, taking the
remote host, the remote port, and the filename to fetch."""
TftpSession.__init__(self)
+ self.context = None
self.host = host
self.iport = port
self.filename = None
@@ -51,192 +53,30 @@ class TftpClient(TftpSession):
object. The timeout parameter may be used to override the default
SOCK_TIMEOUT setting, which is the amount of time that the client will
wait for a receive packet to arrive."""
- # Open the output file.
- # FIXME - need to support alternate return formats than files?
- # File-like objects would be ideal, ala duck-typing.
- self.fileobj = open(output, "wb")
- recvpkt = None
- curblock = 0
- dups = {}
- start_time = time.time()
- self.bytes = 0
-
- self.filename = filename
-
- tftp_factory = TftpPacketFactory()
- self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
- self.sock.settimeout(timeout)
-
- logger.info("Sending tftp download request to %s" % self.host)
- logger.info(" filename -> %s" % filename)
- pkt = TftpPacketRRQ()
- pkt.filename = filename
- pkt.mode = "octet" # FIXME - shouldn't hardcode this
- pkt.options = self.options
- self.sock.sendto(pkt.encode().buffer, (self.host, self.iport))
- self.state.state = 'rrq'
-
- timeouts = 0
- while True:
- try:
- (buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
- except socket.timeout, err:
- timeouts += 1
- if timeouts >= TIMEOUT_RETRIES:
- raise TftpException, "Hit max timeouts, giving up."
- else:
- logger.warn("Timeout waiting for traffic, retrying...")
- continue
-
- recvpkt = tftp_factory.parse(buffer)
-
- logger.debug("Received %d bytes from %s:%s"
- % (len(buffer), raddress, rport))
-
- # Check for known "connection".
- if raddress != self.address:
- logger.warn("Received traffic from %s, expected host %s. Discarding"
- % (raddress, self.host))
- continue
- if self.port and self.port != rport:
- logger.warn("Received traffic from %s:%s but we're "
- "connected to %s:%s. Discarding."
- % (raddress, rport,
- self.host, self.port))
- continue
-
- # If there is a packethook defined, call it. We unconditionally
- # pass all packets, it's up to the client to screen out different
- # kinds of packets. This way, the client is privy to things like
- # negotiated options.
- if packethook:
- packethook(recvpkt)
-
- if not self.port and self.state.state == 'rrq':
- self.port = rport
- logger.debug("Set remote port for session to %s" % rport)
-
- if isinstance(recvpkt, TftpPacketDAT):
- logger.debug("recvpkt.blocknumber = %d" % recvpkt.blocknumber)
- logger.debug("curblock = %d" % curblock)
-
- if self.state.state == 'rrq' and self.options:
- logger.info("no OACK, our options were ignored")
- self.options = { 'blksize': DEF_BLKSIZE }
- self.state.state = 'ack'
-
- expected_block = curblock + 1
- if expected_block > 65535:
- logger.debug("block number rollover to 0 again")
- expected_block = 0
- if recvpkt.blocknumber == expected_block:
- logger.debug("good, received block %d in sequence"
- % recvpkt.blocknumber)
- curblock = expected_block
-
-
- # ACK the packet, and save the data.
- logger.info("sending ACK to block %d" % curblock)
- logger.debug("ip = %s, port = %s" % (self.host, self.port))
- ackpkt = TftpPacketACK()
- ackpkt.blocknumber = curblock
- self.sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
-
- logger.debug("writing %d bytes to output file"
- % len(recvpkt.data))
- self.fileobj.write(recvpkt.data)
- self.bytes += len(recvpkt.data)
- # Check for end-of-file, any less than full data packet.
- if len(recvpkt.data) < int(self.options['blksize']):
- logger.info("end of file detected")
- break
-
- elif recvpkt.blocknumber == curblock:
- logger.warn("dropping duplicate block %d" % curblock)
- if dups.has_key(curblock):
- dups[curblock] += 1
- else:
- dups[curblock] = 1
- tftpassert(dups[curblock] < MAX_DUPS,
- "Max duplicates for block %d reached" % curblock)
- logger.debug("ACKing block %d again, just in case" % curblock)
- ackpkt = TftpPacketACK()
- ackpkt.blocknumber = curblock
- self.sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
-
- else:
- msg = "Whoa! Received block %d but expected %d" % (recvpkt.blocknumber,
- curblock+1)
- logger.error(msg)
- raise TftpException, msg
-
- # Check other packet types.
- elif isinstance(recvpkt, TftpPacketOACK):
- if not self.state.state == 'rrq':
- self.errors += 1
- logger.error("Received OACK in state %s" % self.state.state)
- continue
-
- self.state.state = 'oack'
- logger.info("Received OACK from server.")
- if recvpkt.options.keys() > 0:
- if recvpkt.match_options(self.options):
- logger.info("Successful negotiation of options")
- # Set options to OACK options
- self.options = recvpkt.options
- for key in self.options:
- logger.info(" %s = %s" % (key, self.options[key]))
- logger.debug("sending ACK to OACK")
- ackpkt = TftpPacketACK()
- ackpkt.blocknumber = 0
- self.sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
- self.state.state = 'ack'
- else:
- logger.error("failed to negotiate options")
- self.senderror(self.sock, TftpErrors.FailedNegotiation, self.host, self.port)
- self.state.state = 'err'
- raise TftpException, "Failed to negotiate options"
-
- elif isinstance(recvpkt, TftpPacketACK):
- # Umm, we ACK, the server doesn't.
- self.state.state = 'err'
- self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port)
- tftpassert(False, "Received ACK from server while in download")
-
- elif isinstance(recvpkt, TftpPacketERR):
- self.state.state = 'err'
- self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port)
- tftpassert(False, "Received ERR from server: " + str(recvpkt))
-
- elif isinstance(recvpkt, TftpPacketWRQ):
- self.state.state = 'err'
- self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port)
- tftpassert(False, "Received WRQ from server: " + str(recvpkt))
-
- else:
- self.state.state = 'err'
- self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port)
- tftpassert(False, "Received unknown packet type from server: "
- + str(recvpkt))
-
-
- # end while
- self.fileobj.close()
-
- end_time = time.time()
- duration = end_time - start_time
- if duration == 0:
+ # We're downloading.
+ self.context = TftpContextClientDownload(self.host,
+ self.iport,
+ filename,
+ output,
+ self.options,
+ packethook,
+ timeout)
+ self.context.start()
+ # Download happens here
+ self.context.end()
+
+ metrics = self.context.metrics
+
+ # FIXME: Should we output this? Shouldn't we let the client control
+ # output? This should be in the sample client, but not in the download
+ # call.
+ if metrics.duration == 0:
logger.info("Duration too short, rate undetermined")
else:
logger.info('')
- logger.info("Downloaded %d bytes in %d seconds" % (self.bytes, duration))
- bps = (self.bytes * 8.0) / duration
- kbps = bps / 1024.0
- logger.info("Average rate: %.2f kbps" % kbps)
- dupcount = 0
- for key in dups:
- dupcount += dups[key]
- logger.info("Received %d duplicate packets" % dupcount)
+ logger.info("Downloaded %d bytes in %d seconds" % (metrics.bytes, metrics.duration))
+ logger.info("Average rate: %.2f kbps" % metrics.kbps)
+ logger.info("Received %d duplicate packets" % metrics.dupcount)
def upload(self, filename, input, packethook=None, timeout=SOCK_TIMEOUT):
# Open the input file.