summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael P. Soulier <msoulier@digitaltorque.ca>2008-07-30 14:04:01 -0400
committerMichael P. Soulier <msoulier@digitaltorque.ca>2008-07-30 14:04:01 -0400
commit67302801eba3b0d939c0b5d04b5b6d654ed68101 (patch)
treea23920b31b3e01c48407568280632ac93039f750
parent33b135348f1ea9c0f035c51fa743c87d49257852 (diff)
downloadtftpy-67302801eba3b0d939c0b5d04b5b6d654ed68101.tar.gz
Adding upload patch from Lorenz Schori - patch 1897344 in SF tracker
-rw-r--r--bin/tftpy_client.py31
-rw-r--r--tftpy/TftpClient.py236
2 files changed, 235 insertions, 32 deletions
diff --git a/bin/tftpy_client.py b/bin/tftpy_client.py
index 961fb78..d4be2f8 100644
--- a/bin/tftpy_client.py
+++ b/bin/tftpy_client.py
@@ -17,6 +17,9 @@ def main():
parser.add_option('-f',
'--filename',
help='filename to fetch')
+ parser.add_option('-u',
+ '--upload',
+ help='filename to upload')
parser.add_option('-b',
'--blocksize',
help='udp packet size to use (default: 512)',
@@ -24,6 +27,9 @@ def main():
parser.add_option('-o',
'--output',
help='output file (default: same as requested filename)')
+ parser.add_option('-i',
+ '--input',
+ help='input file (default: same as upload filename)')
parser.add_option('-d',
'--debug',
action='store_true',
@@ -35,7 +41,7 @@ def main():
default=False,
help="downgrade logging from info to warning")
options, args = parser.parse_args()
- if not options.host or not options.filename:
+ if not options.host or (not options.filename and not options.upload):
sys.stderr.write("Both the --host and --filename options "
"are required.\n")
parser.print_help()
@@ -47,17 +53,14 @@ def main():
parser.print_help()
sys.exit(1)
- if not options.output:
- options.output = os.path.basename(options.filename)
-
class Progress(object):
def __init__(self, out):
self.progress = 0
self.out = out
def progresshook(self, pkt):
self.progress += len(pkt.data)
- self.out("Downloaded %d bytes" % self.progress)
-
+ self.out("Transferred %d bytes" % self.progress)
+
if options.debug:
tftpy.setLogLevel(logging.DEBUG)
elif options.quiet:
@@ -74,10 +77,18 @@ def main():
tclient = tftpy.TftpClient(options.host,
int(options.port),
tftp_options)
-
- tclient.download(options.filename,
- options.output,
- progresshook)
+ if(options.filename):
+ if not options.output:
+ options.output = os.path.basename(options.filename)
+ tclient.download(options.filename,
+ options.output,
+ progresshook)
+ elif(options.upload):
+ if not options.input:
+ options.input = os.path.basename(options.upload)
+ tclient.upload(options.upload,
+ options.input,
+ progresshook)
if __name__ == '__main__':
main()
diff --git a/tftpy/TftpClient.py b/tftpy/TftpClient.py
index c64245b..7c14762 100644
--- a/tftpy/TftpClient.py
+++ b/tftpy/TftpClient.py
@@ -11,7 +11,13 @@ class TftpClient(TftpSession):
TftpSession.__init__(self)
self.host = host
self.iport = port
+ self.filename = None
self.options = options
+ self.blocknumber = 0
+ self.fileobj = None;
+ self.timesent = 0
+ self.buffer = None;
+ self.bytes = 0;
if self.options.has_key('blksize'):
size = self.options['blksize']
tftpassert(types.IntType == type(size), "blksize must be an int")
@@ -20,21 +26,20 @@ class TftpClient(TftpSession):
else:
self.options['blksize'] = DEF_BLKSIZE
# Support other options here? timeout time, retries, etc?
-
# The remote sending port, to identify the connection.
self.port = None
self.sock = None
-
+
def gethost(self):
"Simple getter method for use in a property."
return self.__host
-
+
def sethost(self, host):
"""Setter method that also sets the address property as a result
of the host that is set."""
self.__host = host
self.address = socket.gethostbyname(host)
-
+
host = property(gethost, sethost)
def download(self, filename, output, packethook=None, timeout=SOCK_TIMEOUT):
@@ -49,12 +54,14 @@ class TftpClient(TftpSession):
# Open the output file.
# FIXME - need to support alternate return formats than files?
# File-like objects would be ideal, ala duck-typing.
- outputfile = open(output, "wb")
+ self.fileobj = open(output, "wb")
recvpkt = None
curblock = 0
dups = {}
start_time = time.time()
- bytes = 0
+ self.bytes = 0
+
+ self.filename = filename
tftp_factory = TftpPacketFactory()
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@@ -68,7 +75,7 @@ class TftpClient(TftpSession):
pkt.options = self.options
self.sock.sendto(pkt.encode().buffer, (self.host, self.iport))
self.state.state = 'rrq'
-
+
timeouts = 0
while True:
try:
@@ -83,9 +90,9 @@ class TftpClient(TftpSession):
recvpkt = tftp_factory.parse(buffer)
- logger.debug("Received %d bytes from %s:%s"
+ logger.debug("Received %d bytes from %s:%s"
% (len(buffer), raddress, rport))
-
+
# Check for known "connection".
if raddress != self.address:
logger.warn("Received traffic from %s, expected host %s. Discarding"
@@ -97,11 +104,11 @@ class TftpClient(TftpSession):
% (raddress, rport,
self.host, self.port))
continue
-
+
if not self.port and self.state.state == 'rrq':
self.port = rport
logger.debug("Set remote port for session to %s" % rport)
-
+
if isinstance(recvpkt, TftpPacketDAT):
logger.debug("recvpkt.blocknumber = %d" % recvpkt.blocknumber)
logger.debug("curblock = %d" % curblock)
@@ -110,22 +117,22 @@ class TftpClient(TftpSession):
logger.debug("block number rollover to 0 again")
expected_block = 0
if recvpkt.blocknumber == expected_block:
- logger.debug("good, received block %d in sequence"
+ logger.debug("good, received block %d in sequence"
% recvpkt.blocknumber)
curblock = expected_block
-
+
# ACK the packet, and save the data.
logger.info("sending ACK to block %d" % curblock)
logger.debug("ip = %s, port = %s" % (self.host, self.port))
ackpkt = TftpPacketACK()
ackpkt.blocknumber = curblock
self.sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
-
- logger.debug("writing %d bytes to output file"
+
+ logger.debug("writing %d bytes to output file"
% len(recvpkt.data))
- outputfile.write(recvpkt.data)
- bytes += len(recvpkt.data)
+ self.fileobj.write(recvpkt.data)
+ self.bytes += len(recvpkt.data)
# If there is a packethook defined, call it.
if packethook:
packethook(recvpkt)
@@ -148,7 +155,7 @@ class TftpClient(TftpSession):
self.sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
else:
- msg = "Whoa! Received block %d but expected %d" % (recvpkt.blocknumber,
+ msg = "Whoa! Received block %d but expected %d" % (recvpkt.blocknumber,
curblock+1)
logger.error(msg)
raise TftpException, msg
@@ -159,7 +166,7 @@ class TftpClient(TftpSession):
self.errors += 1
logger.error("Received OACK in state %s" % self.state.state)
continue
-
+
self.state.state = 'oack'
logger.info("Received OACK from server.")
if recvpkt.options.keys() > 0:
@@ -202,7 +209,7 @@ class TftpClient(TftpSession):
# end while
- outputfile.close()
+ self.fileobj.close()
end_time = time.time()
duration = end_time - start_time
@@ -210,11 +217,196 @@ class TftpClient(TftpSession):
logger.info("Duration too short, rate undetermined")
else:
logger.info('')
- logger.info("Downloaded %d bytes in %d seconds" % (bytes, duration))
- bps = (bytes * 8.0) / duration
+ logger.info("Downloaded %d bytes in %d seconds" % (self.bytes, duration))
+ bps = (self.bytes * 8.0) / duration
kbps = bps / 1024.0
logger.info("Average rate: %.2f kbps" % kbps)
dupcount = 0
for key in dups:
dupcount += dups[key]
logger.info("Received %d duplicate packets" % dupcount)
+
+ def upload(self, filename, input, packethook=None, timeout=SOCK_TIMEOUT):
+ # Open the input file.
+ self.fileobj = open(input, "rb")
+ recvpkt = None
+ curblock = 0
+ start_time = time.time()
+ self.bytes = 0
+
+ tftp_factory = TftpPacketFactory()
+ self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ self.sock.settimeout(timeout)
+
+ self.filename = filename
+
+ self.send_wrq()
+ self.state.state = 'wrq'
+
+ timeouts = 0
+ while True:
+ try:
+ (buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
+ except socket.timeout, err:
+ timeouts += 1
+ if timeouts >= TIMEOUT_RETRIES:
+ raise TftpException, "Hit max timeouts, giving up."
+ else:
+ if self.state.state == 'dat' or self.state.state == 'fin':
+ logger.debug("Timing out on DAT. Need to resend.")
+ self.send_dat(packethook,resend=True)
+ elif self.state.state == 'wrq':
+ logger.debug("Timing out on WRQ.")
+ self.send_wrq(resend=True)
+ else:
+ tftpassert(False,
+ "Timing out in unsupported state %s" %
+ self.state.state)
+ continue
+
+ recvpkt = tftp_factory.parse(buffer)
+
+ logger.debug("Received %d bytes from %s:%s"
+ % (len(buffer), raddress, rport))
+
+ # Check for known "connection".
+ if raddress != self.address:
+ logger.warn("Received traffic from %s, expected host %s. Discarding"
+ % (raddress, self.host))
+ continue
+ if self.port and self.port != rport:
+ logger.warn("Received traffic from %s:%s but we're "
+ "connected to %s:%s. Discarding."
+ % (raddress, rport,
+ self.host, self.port))
+ continue
+
+ if not self.port and self.state.state == 'wrq':
+ self.port = rport
+ logger.debug("Set remote port for session to %s" % rport)
+
+ # Next packet type
+ if isinstance(recvpkt, TftpPacketACK):
+ logger.debug("Received an ACK from the server.")
+ # tftp on wrt54gl seems to answer with an ack to a wrq regardless
+ # if we sent options.
+ if recvpkt.blocknumber == 0 and self.state.state in ('oack','wrq'):
+ logger.debug("Received ACK with 0 blocknumber, starting upload")
+ self.state.state = 'dat'
+ self.send_dat(packethook)
+ else:
+ if self.state.state == 'dat' or self.state.state == 'fin':
+ if self.blocknumber == recvpkt.blocknumber:
+ logger.info("Received ACK for block %d"
+ % recvpkt.blocknumber)
+ if self.state.state == 'fin':
+ break
+ else:
+ self.send_dat(packethook)
+ elif recvpkt.blocknumber < self.blocknumber:
+ # Don't resend a DAT due to an old ACK. Fixes the
+ # sorceror's apprentice problem.
+ logger.warn("Received old ACK for block number %d"
+ % recvpkt.blocknumber)
+ else:
+ logger.warn("Received ACK for block number "
+ "%d, apparently from the future"
+ % recvpkt.blocknumber)
+ else:
+ logger.error("Received ACK with block number %d "
+ "while in state %s"
+ % (recvpkt.blocknumber,
+ self.state.state))
+
+ # Check other packet types.
+ elif isinstance(recvpkt, TftpPacketOACK):
+ if not self.state.state == 'wrq':
+ self.errors += 1
+ logger.error("Received OACK in state %s" % self.state.state)
+ continue
+
+ self.state.state = 'oack'
+ logger.info("Received OACK from server.")
+ if recvpkt.options.keys() > 0:
+ if recvpkt.match_options(self.options):
+ logger.info("Successful negotiation of options")
+ for key in self.options:
+ logger.info(" %s = %s" % (key, self.options[key]))
+ logger.debug("sending ACK to OACK")
+ ackpkt = TftpPacketACK()
+ ackpkt.blocknumber = 0
+ self.sock.sendto(ackpkt.encode().buffer, (self.host, self.port))
+ self.state.state = 'dat'
+ self.send_dat(packethook)
+ else:
+ logger.error("failed to negotiate options")
+ self.senderror(self.sock, TftpErrors.FailedNegotiation, self.host, self.port)
+ self.state.state = 'err'
+ raise TftpException, "Failed to negotiate options"
+
+ elif isinstance(recvpkt, TftpPacketERR):
+ self.state.state = 'err'
+ self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port)
+ tftpassert(False, "Received ERR from server: " + str(recvpkt))
+
+ elif isinstance(recvpkt, TftpPacketWRQ):
+ self.state.state = 'err'
+ self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port)
+ tftpassert(False, "Received WRQ from server: " + str(recvpkt))
+
+ else:
+ self.state.state = 'err'
+ self.senderror(self.sock, TftpErrors.IllegalTftpOp, self.host, self.port)
+ tftpassert(False, "Received unknown packet type from server: "
+ + str(recvpkt))
+
+
+ # end while
+ self.fileobj.close()
+
+ end_time = time.time()
+ duration = end_time - start_time
+ if duration == 0:
+ logger.info("Duration too short, rate undetermined")
+ else:
+ logger.info('')
+ logger.info("Uploaded %d bytes in %d seconds" % (self.bytes, duration))
+ bps = (self.bytes * 8.0) / duration
+ kbps = bps / 1024.0
+ logger.info("Average rate: %.2f kbps" % kbps)
+
+ def send_dat(self, packethook, resend=False):
+ """This method reads and sends a DAT packet based on what is in self.buffer."""
+ if not resend:
+ blksize = int(self.options['blksize'])
+ self.buffer = self.fileobj.read(blksize)
+ logger.debug("Read %d bytes into buffer" % len(self.buffer))
+ if len(self.buffer) < blksize:
+ logger.info("Reached EOF on file %s" % self.filename)
+ self.state.state = 'fin'
+ self.blocknumber += 1
+ if self.blocknumber > 65535:
+ logger.debug("Blocknumber rolled over to zero")
+ self.blocknumber = 0
+ self.bytes += len(self.buffer)
+ else:
+ logger.warn("Resending block number %d" % self.blocknumber)
+ dat = TftpPacketDAT()
+ dat.data = self.buffer
+ dat.blocknumber = self.blocknumber
+ logger.debug("Sending DAT packet %d" % self.blocknumber)
+ self.sock.sendto(dat.encode().buffer, (self.host, self.port))
+ self.timesent = time.time()
+ if packethook:
+ packethook(dat)
+
+ def send_wrq(self, resend=False):
+ """This method sends a wrq"""
+ logger.info("Sending tftp upload request to %s" % self.host)
+ logger.info(" filename -> %s" % self.filename)
+
+ wrq = TftpPacketWRQ()
+ wrq.filename = self.filename
+ wrq.mode = "octet" # FIXME - shouldn't hardcode this
+ wrq.options = self.options
+ self.sock.sendto(wrq.encode().buffer, (self.host, self.iport))