diff options
author | Michael P. Soulier <msoulier@digitaltorque.ca> | 2008-07-30 14:04:01 -0400 |
---|---|---|
committer | Michael P. Soulier <msoulier@digitaltorque.ca> | 2008-07-30 14:04:01 -0400 |
commit | 67302801eba3b0d939c0b5d04b5b6d654ed68101 (patch) | |
tree | a23920b31b3e01c48407568280632ac93039f750 | |
parent | 33b135348f1ea9c0f035c51fa743c87d49257852 (diff) | |
download | tftpy-67302801eba3b0d939c0b5d04b5b6d654ed68101.tar.gz |
Adding upload patch from Lorenz Schori - patch 1897344 in SF tracker
-rw-r--r-- | bin/tftpy_client.py | 31 | ||||
-rw-r--r-- | tftpy/TftpClient.py | 236 |
2 files changed, 235 insertions, 32 deletions
diff --git a/bin/tftpy_client.py b/bin/tftpy_client.py index 961fb78..d4be2f8 100644 --- a/bin/tftpy_client.py +++ b/bin/tftpy_client.py @@ -17,6 +17,9 @@ def main(): parser.add_option('-f', '--filename', help='filename to fetch') + parser.add_option('-u', + '--upload', + help='filename to upload') parser.add_option('-b', '--blocksize', help='udp packet size to use (default: 512)', @@ -24,6 +27,9 @@ def main(): parser.add_option('-o', '--output', help='output file (default: same as requested filename)') + parser.add_option('-i', + '--input', + help='input file (default: same as upload filename)') parser.add_option('-d', '--debug', action='store_true', @@ -35,7 +41,7 @@ def main(): default=False, help="downgrade logging from info to warning") options, args = parser.parse_args() - if not options.host or not options.filename: + if not options.host or (not options.filename and not options.upload): sys.stderr.write("Both the --host and --filename options " "are required.\n") parser.print_help() @@ -47,17 +53,14 @@ def main(): parser.print_help() sys.exit(1) - if not options.output: - options.output = os.path.basename(options.filename) - class Progress(object): def __init__(self, out): self.progress = 0 self.out = out def progresshook(self, pkt): self.progress += len(pkt.data) - self.out("Downloaded %d bytes" % self.progress) - + self.out("Transferred %d bytes" % self.progress) + if options.debug: tftpy.setLogLevel(logging.DEBUG) elif options.quiet: @@ -74,10 +77,18 @@ def main(): tclient = tftpy.TftpClient(options.host, int(options.port), tftp_options) - - tclient.download(options.filename, - options.output, - progresshook) + if(options.filename): + if not options.output: + options.output = os.path.basename(options.filename) + tclient.download(options.filename, + options.output, + progresshook) + elif(options.upload): + if not options.input: + options.input = os.path.basename(options.upload) + tclient.upload(options.upload, + options.input, + progresshook) if __name__ == '__main__': main() diff --git a/tftpy/TftpClient.py b/tftpy/TftpClient.py index c64245b..7c14762 100644 --- a/tftpy/TftpClient.py +++ b/tftpy/TftpClient.py @@ -11,7 +11,13 @@ class TftpClient(TftpSession): TftpSession.__init__(self) self.host = host self.iport = port + self.filename = None self.options = options + self.blocknumber = 0 + self.fileobj = None; + self.timesent = 0 + self.buffer = None; + self.bytes = 0; if self.options.has_key('blksize'): size = self.options['blksize'] tftpassert(types.IntType == type(size), "blksize must be an int") @@ -20,21 +26,20 @@ class TftpClient(TftpSession): else: self.options['blksize'] = DEF_BLKSIZE # Support other options here? timeout time, retries, etc? - # The remote sending port, to identify the connection. self.port = None self.sock = None - + def gethost(self): "Simple getter method for use in a property." return self.__host - + def sethost(self, host): """Setter method that also sets the address property as a result of the host that is set.""" self.__host = host self.address = socket.gethostbyname(host) - + host = property(gethost, sethost) def download(self, filename, output, packethook=None, timeout=SOCK_TIMEOUT): @@ -49,12 +54,14 @@ class TftpClient(TftpSession): # Open the output file. # FIXME - need to support alternate return formats than files? # File-like objects would be ideal, ala duck-typing. - outputfile = open(output, "wb") + self.fileobj = open(output, "wb") recvpkt = None curblock = 0 dups = {} start_time = time.time() - bytes = 0 + self.bytes = 0 + + self.filename = filename tftp_factory = TftpPacketFactory() self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -68,7 +75,7 @@ class TftpClient(TftpSession): pkt.options = self.options self.sock.sendto(pkt.encode().buffer, (self.host, self.iport)) self.state.state = 'rrq' - + timeouts = 0 while True: try: @@ -83,9 +90,9 @@ class TftpClient(TftpSession): recvpkt = tftp_factory.parse(buffer) - logger.debug("Received %d bytes from %s:%s" + logger.debug("Received %d bytes from %s:%s" % (len(buffer), raddress, rport)) - + # Check for known "connection". if raddress != self.address: logger.warn("Received traffic from %s, expected host %s. Discarding" @@ -97,11 +104,11 @@ class TftpClient(TftpSession): % (raddress, rport, self.host, self.port)) continue - + if not self.port and self.state.state == 'rrq': self.port = rport logger.debug("Set remote port for session to %s" % rport) - + if isinstance(recvpkt, TftpPacketDAT): logger.debug("recvpkt.blocknumber = %d" % recvpkt.blocknumber) logger.debug("curblock = %d" % curblock) @@ -110,22 +117,22 @@ class TftpClient(TftpSession): logger.debug("block number rollover to 0 again") expected_block = 0 if recvpkt.blocknumber == expected_block: - logger.debug("good, received block %d in sequence" + logger.debug("good, received block %d in sequence" % recvpkt.blocknumber) curblock = expected_block - + # ACK the packet, and save the data. logger.info("sending ACK to block %d" % curblock) logger.debug("ip = %s, port = %s" % (self.host, self.port)) ackpkt = TftpPacketACK() ackpkt.blocknumber = curblock self.sock.sendto(ackpkt.encode().buffer, (self.host, self.port)) - - logger.debug("writing %d bytes to output file" + + logger.debug("writing %d bytes to output file" % len(recvpkt.data)) - outputfile.write(recvpkt.data) - bytes += len(recvpkt.data) + self.fileobj.write(recvpkt.data) + self.bytes += len(recvpkt.data) # If there is a packethook defined, call it. if packethook: packethook(recvpkt) @@ -148,7 +155,7 @@ class TftpClient(TftpSession): self.sock.sendto(ackpkt.encode().buffer, (self.host, self.port)) else: - msg = "Whoa! Received block %d but expected %d" % (recvpkt.blocknumber, + msg = "Whoa! Received block %d but expected %d" % (recvpkt.blocknumber, curblock+1) logger.error(msg) raise TftpException, msg @@ -159,7 +166,7 @@ class TftpClient(TftpSession): self.errors += 1 logger.error("Received OACK in state %s" % self.state.state) continue - + self.state.state = 'oack' logger.info("Received OACK from server.") if recvpkt.options.keys() > 0: @@ -202,7 +209,7 @@ class TftpClient(TftpSession): # end while - outputfile.close() + self.fileobj.close() end_time = time.time() duration = end_time - start_time @@ -210,11 +217,196 @@ class TftpClient(TftpSession): logger.info("Duration too short, rate undetermined") else: logger.info('') - logger.info("Downloaded %d bytes in %d seconds" % (bytes, duration)) - bps = (bytes * 8.0) / duration + logger.info("Downloaded %d bytes in %d seconds" % (self.bytes, duration)) + bps = (self.bytes * 8.0) / duration kbps = bps / 1024.0 logger.info("Average rate: %.2f kbps" % kbps) dupcount = 0 for key in dups: dupcount += dups[key] logger.info("Received %d duplicate packets" % dupcount) + + def upload(self, filename, input, packethook=None, timeout=SOCK_TIMEOUT): + # Open the input file. + self.fileobj = open(input, "rb") + recvpkt = None + curblock = 0 + start_time = time.time() + self.bytes = 0 + + tftp_factory = TftpPacketFactory() + self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.sock.settimeout(timeout) + + self.filename = filename + + self.send_wrq() + self.state.state = 'wrq' + + timeouts = 0 + while True: + try: + (buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE) + except socket.timeout, err: + timeouts += 1 + if timeouts >= TIMEOUT_RETRIES: + raise TftpException, "Hit max timeouts, giving up." + else: + if self.state.state == 'dat' or self.state.state == 'fin': + logger.debug("Timing out on DAT. Need to resend.") + self.send_dat(packethook,resend=True) + elif self.state.state == 'wrq': + logger.debug("Timing out on WRQ.") + self.send_wrq(resend=True) + else: + tftpassert(False, + "Timing out in unsupported state %s" % + self.state.state) + continue + + recvpkt = tftp_factory.parse(buffer) + + logger.debug("Received %d bytes from %s:%s" + % (len(buffer), raddress, rport)) + + # Check for known "connection". + if raddress != self.address: + logger.warn("Received traffic from %s, expected host %s. Discarding" + % (raddress, self.host)) + continue + if self.port and self.port != rport: + logger.warn("Received traffic from %s:%s but we're " + "connected to %s:%s. Discarding." + % (raddress, rport, + self.host, self.port)) + continue + + if not self.port and self.state.state == 'wrq': + self.port = rport + logger.debug("Set remote port for session to %s" % rport) + + # Next packet type + if isinstance(recvpkt, TftpPacketACK): + logger.debug("Received an ACK from the server.") + # tftp on wrt54gl seems to answer with an ack to a wrq regardless + # if we sent options. + if recvpkt.blocknumber == 0 and self.state.state in ('oack','wrq'): + logger.debug("Received ACK with 0 blocknumber, starting upload") + self.state.state = 'dat' + self.send_dat(packethook) + else: + if self.state.state == 'dat' or self.state.state == 'fin': + if self.blocknumber == recvpkt.blocknumber: + logger.info("Received ACK for block %d" + % recvpkt.blocknumber) + if self.state.state == 'fin': + break + else: + self.send_dat(packethook) + elif recvpkt.blocknumber < self.blocknumber: + # Don't resend a DAT due to an old ACK. Fixes the + # sorceror's apprentice problem. + logger.warn("Received old ACK for block number %d" + % recvpkt.blocknumber) + else: + logger.warn("Received ACK for block number " + "%d, apparently from the future" + % recvpkt.blocknumber) + else: + logger.error("Received ACK with block number %d " + "while in state %s" + % (recvpkt.blocknumber, + self.state.state)) + + # Check other packet types. + elif isinstance(recvpkt, TftpPacketOACK): + if not self.state.state == 'wrq': + self.errors += 1 + logger.error("Received OACK in state %s" % self.state.state) + continue + + self.state.state = 'oack' + logger.info("Received OACK from server.") + if recvpkt.options.keys() > 0: + if recvpkt.match_options(self.options): + logger.info("Successful negotiation of options") + for key in self.options: + logger.info(" %s = %s" % (key, self.options[key])) + logger.debug("sending ACK to OACK") + ackpkt = TftpPacketACK() + ackpkt.blocknumber = 0 + self.sock.sendto(ackpkt.encode().buffer, (self.host, self.port)) + self.state.state = 'dat' + self.send_dat(packethook) + else: + logger.error("failed to negotiate options") + self.senderror(self.sock, TftpErrors.FailedNegotiation, self.host, self.port) + self.state.state = 'err' + raise TftpException, "Failed to negotiate options" + + elif isinstance(recvpkt, TftpPacketERR): + self.state.state = 'err' + self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port) + tftpassert(False, "Received ERR from server: " + str(recvpkt)) + + elif isinstance(recvpkt, TftpPacketWRQ): + self.state.state = 'err' + self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port) + tftpassert(False, "Received WRQ from server: " + str(recvpkt)) + + else: + self.state.state = 'err' + self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port) + tftpassert(False, "Received unknown packet type from server: " + + str(recvpkt)) + + + # end while + self.fileobj.close() + + end_time = time.time() + duration = end_time - start_time + if duration == 0: + logger.info("Duration too short, rate undetermined") + else: + logger.info('') + logger.info("Uploaded %d bytes in %d seconds" % (self.bytes, duration)) + bps = (self.bytes * 8.0) / duration + kbps = bps / 1024.0 + logger.info("Average rate: %.2f kbps" % kbps) + + def send_dat(self, packethook, resend=False): + """This method reads and sends a DAT packet based on what is in self.buffer.""" + if not resend: + blksize = int(self.options['blksize']) + self.buffer = self.fileobj.read(blksize) + logger.debug("Read %d bytes into buffer" % len(self.buffer)) + if len(self.buffer) < blksize: + logger.info("Reached EOF on file %s" % self.filename) + self.state.state = 'fin' + self.blocknumber += 1 + if self.blocknumber > 65535: + logger.debug("Blocknumber rolled over to zero") + self.blocknumber = 0 + self.bytes += len(self.buffer) + else: + logger.warn("Resending block number %d" % self.blocknumber) + dat = TftpPacketDAT() + dat.data = self.buffer + dat.blocknumber = self.blocknumber + logger.debug("Sending DAT packet %d" % self.blocknumber) + self.sock.sendto(dat.encode().buffer, (self.host, self.port)) + self.timesent = time.time() + if packethook: + packethook(dat) + + def send_wrq(self, resend=False): + """This method sends a wrq""" + logger.info("Sending tftp upload request to %s" % self.host) + logger.info(" filename -> %s" % self.filename) + + wrq = TftpPacketWRQ() + wrq.filename = self.filename + wrq.mode = "octet" # FIXME - shouldn't hardcode this + wrq.options = self.options + self.sock.sendto(wrq.encode().buffer, (self.host, self.iport)) |