summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael P. Soulier <msoulier@digitaltorque.ca>2009-08-15 22:36:58 -0400
committerMichael P. Soulier <msoulier@digitaltorque.ca>2009-08-16 19:56:06 -0400
commit62b22fb562eff64a6d6bb6c1a1a3c194d668d9a1 (patch)
tree8adf96a5c71b15cfa443c974ee3ad6e270c8b9e4
parent03e4e748293070ac37fb7fe88abc8b915d84be96 (diff)
downloadtftpy-62b22fb562eff64a6d6bb6c1a1a3c194d668d9a1.tar.gz
Did some rework for the state machine in a server context.
Removed the handler framework in favour of a TftpContextServer used as the session.
-rw-r--r--tftpy/TftpClient.py20
-rw-r--r--tftpy/TftpPacketFactory.py4
-rw-r--r--tftpy/TftpPacketTypes.py88
-rw-r--r--tftpy/TftpServer.py408
-rw-r--r--tftpy/TftpShared.py6
-rw-r--r--tftpy/TftpStates.py588
6 files changed, 521 insertions, 593 deletions
diff --git a/tftpy/TftpClient.py b/tftpy/TftpClient.py
index da35d05..89843f1 100644
--- a/tftpy/TftpClient.py
+++ b/tftpy/TftpClient.py
@@ -52,12 +52,12 @@ class TftpClient(TftpSession):
# output? This should be in the sample client, but not in the download
# call.
if metrics.duration == 0:
- logger.info("Duration too short, rate undetermined")
+ log.info("Duration too short, rate undetermined")
else:
- logger.info('')
- logger.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration))
- logger.info("Average rate: %.2f kbps" % metrics.kbps)
- logger.info("Received %d duplicate packets" % metrics.dupcount)
+ log.info('')
+ log.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration))
+ log.info("Average rate: %.2f kbps" % metrics.kbps)
+ log.info("Received %d duplicate packets" % metrics.dupcount)
def upload(self, filename, input, packethook=None, timeout=SOCK_TIMEOUT):
# Open the input file.
@@ -80,9 +80,9 @@ class TftpClient(TftpSession):
# output? This should be in the sample client, but not in the download
# call.
if metrics.duration == 0:
- logger.info("Duration too short, rate undetermined")
+ log.info("Duration too short, rate undetermined")
else:
- logger.info('')
- logger.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration))
- logger.info("Average rate: %.2f kbps" % metrics.kbps)
- logger.info("Received %d duplicate packets" % metrics.dupcount) \ No newline at end of file
+ log.info('')
+ log.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration))
+ log.info("Average rate: %.2f kbps" % metrics.kbps)
+ log.info("Received %d duplicate packets" % metrics.dupcount)
diff --git a/tftpy/TftpPacketFactory.py b/tftpy/TftpPacketFactory.py
index 642b4d8..3f287de 100644
--- a/tftpy/TftpPacketFactory.py
+++ b/tftpy/TftpPacketFactory.py
@@ -19,9 +19,9 @@ class TftpPacketFactory(object):
"""This method is used to parse an existing datagram into its
corresponding TftpPacket object. The buffer is the raw bytes off of
the network."""
- logger.debug("parsing a %d byte packet" % len(buffer))
+ log.debug("parsing a %d byte packet" % len(buffer))
(opcode,) = struct.unpack("!H", buffer[:2])
- logger.debug("opcode is %d" % opcode)
+ log.debug("opcode is %d" % opcode)
packet = self.__create(opcode)
packet.buffer = buffer
return packet.decode()
diff --git a/tftpy/TftpPacketTypes.py b/tftpy/TftpPacketTypes.py
index e269deb..b9328c5 100644
--- a/tftpy/TftpPacketTypes.py
+++ b/tftpy/TftpPacketTypes.py
@@ -15,7 +15,7 @@ class TftpSession(object):
def senderror(self, sock, errorcode, address, port):
"""This method uses the socket passed, and uses the errorcode, address
and port to compose and send an error packet."""
- logger.debug("In senderror, being asked to send error %d to %s:%s"
+ log.debug("In senderror, being asked to send error %d to %s:%s"
% (errorcode, address, port))
errpkt = TftpPacketERR()
errpkt.errorcode = errorcode
@@ -27,23 +27,23 @@ class TftpPacketWithOptions(object):
goal is just to share code here, and not cause diamond inheritance."""
def __init__(self):
- self.options = []
+ self.options = {}
def setoptions(self, options):
- logger.debug("in TftpPacketWithOptions.setoptions")
- logger.debug("options: " + str(options))
+ log.debug("in TftpPacketWithOptions.setoptions")
+ log.debug("options: " + str(options))
myoptions = {}
for key in options:
newkey = str(key)
myoptions[newkey] = str(options[key])
- logger.debug("populated myoptions with %s = %s"
+ log.debug("populated myoptions with %s = %s"
% (newkey, myoptions[newkey]))
- logger.debug("setting options hash to: " + str(myoptions))
+ log.debug("setting options hash to: " + str(myoptions))
self._options = myoptions
def getoptions(self):
- logger.debug("in TftpPacketWithOptions.getoptions")
+ log.debug("in TftpPacketWithOptions.getoptions")
return self._options
# Set up getter and setter on options to ensure that they are the proper
@@ -59,19 +59,19 @@ class TftpPacketWithOptions(object):
format = "!"
options = {}
- logger.debug("decode_options: buffer is: " + repr(buffer))
- logger.debug("size of buffer is %d bytes" % len(buffer))
+ log.debug("decode_options: buffer is: " + repr(buffer))
+ log.debug("size of buffer is %d bytes" % len(buffer))
if len(buffer) == 0:
- logger.debug("size of buffer is zero, returning empty hash")
+ log.debug("size of buffer is zero, returning empty hash")
return {}
# Count the nulls in the buffer. Each one terminates a string.
- logger.debug("about to iterate options buffer counting nulls")
+ log.debug("about to iterate options buffer counting nulls")
length = 0
for c in buffer:
- #logger.debug("iterating this byte: " + repr(c))
+ #log.debug("iterating this byte: " + repr(c))
if ord(c) == 0:
- logger.debug("found a null at length %d" % length)
+ log.debug("found a null at length %d" % length)
if length > 0:
format += "%dsx" % length
length = -1
@@ -79,14 +79,14 @@ class TftpPacketWithOptions(object):
raise TftpException, "Invalid options in buffer"
length += 1
- logger.debug("about to unpack, format is: %s" % format)
+ log.debug("about to unpack, format is: %s" % format)
mystruct = struct.unpack(format, buffer)
tftpassert(len(mystruct) % 2 == 0,
"packet with odd number of option/value pairs")
for i in range(0, len(mystruct), 2):
- logger.debug("setting option %s to %s" % (mystruct[i], mystruct[i+1]))
+ log.debug("setting option %s to %s" % (mystruct[i], mystruct[i+1]))
options[mystruct[i]] = mystruct[i+1]
return options
@@ -134,10 +134,10 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
ptype = None
if self.opcode == 1: ptype = "RRQ"
else: ptype = "WRQ"
- logger.debug("Encoding %s packet, filename = %s, mode = %s"
+ log.debug("Encoding %s packet, filename = %s, mode = %s"
% (ptype, self.filename, self.mode))
for key in self.options:
- logger.debug(" Option %s = %s" % (key, self.options[key]))
+ log.debug(" Option %s = %s" % (key, self.options[key]))
format = "!H"
format += "%dsx" % len(self.filename)
@@ -148,7 +148,7 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
# Add options.
options_list = []
if self.options.keys() > 0:
- logger.debug("there are options to encode")
+ log.debug("there are options to encode")
for key in self.options:
# Populate the option name
format += "%dsx" % len(key)
@@ -157,9 +157,9 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
format += "%dsx" % len(str(self.options[key]))
options_list.append(str(self.options[key]))
- logger.debug("format is %s" % format)
- logger.debug("options_list is %s" % options_list)
- logger.debug("size of struct is %d" % struct.calcsize(format))
+ log.debug("format is %s" % format)
+ log.debug("options_list is %s" % options_list)
+ log.debug("size of struct is %d" % struct.calcsize(format))
self.buffer = struct.pack(format,
self.opcode,
@@ -167,7 +167,7 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
self.mode,
*options_list)
- logger.debug("buffer is " + repr(self.buffer))
+ log.debug("buffer is " + repr(self.buffer))
return self
def decode(self):
@@ -177,13 +177,13 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
nulls = 0
format = ""
nulls = length = tlength = 0
- logger.debug("in decode: about to iterate buffer counting nulls")
+ log.debug("in decode: about to iterate buffer counting nulls")
subbuf = self.buffer[2:]
for c in subbuf:
- logger.debug("iterating this byte: " + repr(c))
+ log.debug("iterating this byte: " + repr(c))
if ord(c) == 0:
nulls += 1
- logger.debug("found a null at length %d, now have %d"
+ log.debug("found a null at length %d, now have %d"
% (length, nulls))
format += "%dsx" % length
length = -1
@@ -193,17 +193,17 @@ class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
length += 1
tlength += 1
- logger.debug("hopefully found end of mode at length %d" % tlength)
+ log.debug("hopefully found end of mode at length %d" % tlength)
# length should now be the end of the mode.
tftpassert(nulls == 2, "malformed packet")
shortbuf = subbuf[:tlength+1]
- logger.debug("about to unpack buffer with format: %s" % format)
- logger.debug("unpacking buffer: " + repr(shortbuf))
+ log.debug("about to unpack buffer with format: %s" % format)
+ log.debug("unpacking buffer: " + repr(shortbuf))
mystruct = struct.unpack(format, shortbuf)
tftpassert(len(mystruct) == 2, "malformed packet")
- logger.debug("setting filename to %s" % mystruct[0])
- logger.debug("setting mode to %s" % mystruct[1])
+ log.debug("setting filename to %s" % mystruct[0])
+ log.debug("setting mode to %s" % mystruct[1])
self.filename = mystruct[0]
self.mode = mystruct[1]
@@ -269,7 +269,7 @@ DATA | 03 | Block # | Data |
"""Encode the DAT packet. This method populates self.buffer, and
returns self for easy method chaining."""
if len(self.data) == 0:
- logger.debug("Encoding an empty DAT packet")
+ log.debug("Encoding an empty DAT packet")
format = "!HH%ds" % len(self.data)
self.buffer = struct.pack(format,
self.opcode,
@@ -283,12 +283,12 @@ DATA | 03 | Block # | Data |
# We know the first 2 bytes are the opcode. The second two are the
# block number.
(self.blocknumber,) = struct.unpack("!H", self.buffer[2:4])
- logger.debug("decoding DAT packet, block number %d" % self.blocknumber)
- logger.debug("should be %d bytes in the packet total"
+ log.debug("decoding DAT packet, block number %d" % self.blocknumber)
+ log.debug("should be %d bytes in the packet total"
% len(self.buffer))
# Everything else is data.
self.data = self.buffer[4:]
- logger.debug("found %d bytes of data"
+ log.debug("found %d bytes of data"
% len(self.data))
return self
@@ -308,14 +308,14 @@ ACK | 04 | Block # |
return 'ACK packet: block %d' % self.blocknumber
def encode(self):
- logger.debug("encoding ACK: opcode = %d, block = %d"
+ log.debug("encoding ACK: opcode = %d, block = %d"
% (self.opcode, self.blocknumber))
self.buffer = struct.pack("!HH", self.opcode, self.blocknumber)
return self
def decode(self):
self.opcode, self.blocknumber = struct.unpack("!HH", self.buffer)
- logger.debug("decoded ACK packet: opcode = %d, block = %d"
+ log.debug("decoded ACK packet: opcode = %d, block = %d"
% (self.opcode, self.blocknumber))
return self
@@ -365,7 +365,7 @@ ERROR | 05 | ErrorCode | ErrMsg | 0 |
"""Encode the DAT packet based on instance variables, populating
self.buffer, returning self."""
format = "!HH%dsx" % len(self.errmsgs[self.errorcode])
- logger.debug("encoding ERR packet with format %s" % format)
+ log.debug("encoding ERR packet with format %s" % format)
self.buffer = struct.pack(format,
self.opcode,
self.errorcode,
@@ -375,13 +375,13 @@ ERROR | 05 | ErrorCode | ErrMsg | 0 |
def decode(self):
"Decode self.buffer, populating instance variables and return self."
tftpassert(len(self.buffer) > 4, "malformed ERR packet, too short")
- logger.debug("Decoding ERR packet, length %s bytes" %
+ log.debug("Decoding ERR packet, length %s bytes" %
len(self.buffer))
format = "!HH%dsx" % (len(self.buffer) - 5)
- logger.debug("Decoding ERR packet with format: %s" % format)
+ log.debug("Decoding ERR packet with format: %s" % format)
self.opcode, self.errorcode, self.errmsg = struct.unpack(format,
self.buffer)
- logger.error("ERR packet - errorcode: %d, message: %s"
+ log.error("ERR packet - errorcode: %d, message: %s"
% (self.errorcode, self.errmsg))
return self
@@ -402,10 +402,10 @@ class TftpPacketOACK(TftpPacket, TftpPacketWithOptions):
def encode(self):
format = "!H" # opcode
options_list = []
- logger.debug("in TftpPacketOACK.encode")
+ log.debug("in TftpPacketOACK.encode")
for key in self.options:
- logger.debug("looping on option key %s" % key)
- logger.debug("value is %s" % self.options[key])
+ log.debug("looping on option key %s" % key)
+ log.debug("value is %s" % self.options[key])
format += "%dsx" % len(key)
format += "%dsx" % len(self.options[key])
options_list.append(key)
@@ -429,7 +429,7 @@ class TftpPacketOACK(TftpPacket, TftpPacketWithOptions):
# We can accept anything between the min and max values.
size = self.options[name]
if size >= MIN_BLKSIZE and size <= MAX_BLKSIZE:
- logger.debug("negotiated blksize of %d bytes" % size)
+ log.debug("negotiated blksize of %d bytes" % size)
options[blksize] = size
else:
raise TftpException, "Unsupported option: %s" % name
diff --git a/tftpy/TftpServer.py b/tftpy/TftpServer.py
index e846979..ad781a2 100644
--- a/tftpy/TftpServer.py
+++ b/tftpy/TftpServer.py
@@ -1,4 +1,5 @@
import socket, os, re, time, random
+import select
from TftpShared import *
from TftpPacketTypes import *
from TftpPacketFactory import *
@@ -15,26 +16,27 @@ class TftpServer(TftpSession):
self.listenip = None
self.listenport = None
self.sock = None
+ # FIXME: What about multiple roots?
self.root = os.path.abspath(tftproot)
- self.dynfunc = dyn_file_func
+ self.dyn_file_func = dyn_file_func
# A dict of handlers, where each session is keyed by a string like
# ip:tid for the remote end.
self.handlers = {}
if os.path.exists(self.root):
- logger.debug("tftproot %s does exist" % self.root)
+ log.debug("tftproot %s does exist" % self.root)
if not os.path.isdir(self.root):
raise TftpException, "The tftproot must be a directory."
else:
- logger.debug("tftproot %s is a directory" % self.root)
+ log.debug("tftproot %s is a directory" % self.root)
if os.access(self.root, os.R_OK):
- logger.debug("tftproot %s is readable" % self.root)
+ log.debug("tftproot %s is readable" % self.root)
else:
raise TftpException, "The tftproot must be readable"
if os.access(self.root, os.W_OK):
- logger.debug("tftproot %s is writable" % self.root)
+ log.debug("tftproot %s is writable" % self.root)
else:
- logger.warning("The tftproot %s is not writable" % self.root)
+ log.warning("The tftproot %s is not writable" % self.root)
else:
raise TftpException, "The tftproot does not exist."
@@ -45,14 +47,12 @@ class TftpServer(TftpSession):
"""Start a server listening on the supplied interface and port. This
defaults to INADDR_ANY (all interfaces) and UDP port 69. You can also
supply a different socket timeout value, if desired."""
- import select
-
tftp_factory = TftpPacketFactory()
# Don't use new 2.5 ternary operator yet
# listenip = listenip if listenip else '0.0.0.0'
if not listenip: listenip = '0.0.0.0'
- logger.info("Server requested on ip %s, port %s"
+ log.info("Server requested on ip %s, port %s"
% (listenip, listenport))
try:
# FIXME - sockets should be non-blocking?
@@ -62,388 +62,82 @@ class TftpServer(TftpSession):
# Reraise it for now.
raise
- logger.info("Starting receive loop...")
+ log.info("Starting receive loop...")
while True:
# Build the inputlist array of sockets to select() on.
inputlist = []
inputlist.append(self.sock)
- for key in self.handlers:
- inputlist.append(self.handlers[key].sock)
+ for key in self.sessions:
+ inputlist.append(self.sessions[key].sock)
# Block until some socket has input on it.
- logger.debug("Performing select on this inputlist: %s" % inputlist)
+ log.debug("Performing select on this inputlist: %s" % inputlist)
readyinput, readyoutput, readyspecial = select.select(inputlist,
[],
[],
SOCK_TIMEOUT)
- #(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
- #recvpkt = tftp_factory.parse(buffer)
- #key = "%s:%s" % (raddress, rport)
-
deletion_list = []
+ # Handle the available data, if any. Maybe we timed-out.
for readysock in readyinput:
+ # Is the traffic on the main server socket? ie. new session?
if readysock == self.sock:
- logger.debug("Data ready on our main socket")
+ log.debug("Data ready on our main socket")
buffer, (raddress, rport) = self.sock.recvfrom(MAX_BLKSIZE)
- logger.debug("Read %d bytes" % len(buffer))
- recvpkt = tftp_factory.parse(buffer)
- key = "%s:%s" % (raddress, rport)
- if isinstance(recvpkt, TftpPacketRRQ):
- logger.debug("RRQ packet from %s:%s" % (raddress, rport))
- if not self.handlers.has_key(key):
- try:
- logger.debug("New download request, session key = %s"
- % key)
- self.handlers[key] = TftpServerHandler(key,
- 'rrq',
- self.root,
- listenip,
- tftp_factory,
- self.dynfunc)
- self.handlers[key].handle((recvpkt, raddress, rport))
- except TftpException, err:
- logger.error("Fatal exception thrown from handler: %s"
- % str(err))
- logger.debug("Deleting handler: %s" % key)
- deletion_list.append(key)
+ log.debug("Read %d bytes" % len(buffer))
- else:
- logger.warn("Received RRQ for existing session!")
- self.senderror(self.sock,
- TftpErrors.IllegalTftpOp,
- raddress,
- rport)
- continue
+ recvpkt = tftp_factory.parse(buffer)
+ # FIXME: Is this the best way to do a session key? What
+ # about symmetric udp?
+ key = "%s:%s" % (raddress, rport)
- elif isinstance(recvpkt, TftpPacketWRQ):
- logger.error("Write requests not implemented at this time.")
- self.senderror(self.sock,
- TftpErrors.IllegalTftpOp,
- raddress,
- rport)
- continue
+ if not self.sessions.has_key(key):
+ log.debug("Creating new server context for "
+ "session key = %s" % key)
+ self.sessions[key] = TftpContextServer(raddress,
+ rport,
+ timeout,
+ self.root,
+ self.dyn_file_func)
+ self.sessions[key].start(buffer)
else:
- # FIXME - this will have to change if we do symmetric UDP
- logger.error("Should only receive RRQ or WRQ packets "
- "on main listen port. Received %s" % recvpkt)
- self.senderror(self.sock,
- TftpErrors.IllegalTftpOp,
- raddress,
- rport)
- continue
+ log.warn("received traffic on main socket for "
+ "existing session??")
else:
- for key in self.handlers:
- if readysock == self.handlers[key].sock:
- # FIXME - violating DRY principle with above code
+ # Must find the owner of this traffic.
+ for key in self.session:
+ if readysock == self.session[key].sock:
try:
- self.handlers[key].handle()
+ self.session[key].cycle()
+ if self.session[key].state == None:
+ log.info("Successful transfer.")
+ deletion_list.append(key)
break
except TftpException, err:
deletion_list.append(key)
- if self.handlers[key].state.state == 'fin':
- logger.info("Successful transfer.")
- break
- else:
- logger.error("Fatal exception thrown from handler: %s"
- % str(err))
+ log.error("Fatal exception thrown from "
+ "handler: %s" % str(err))
else:
- logger.error("Can't find the owner for this packet. Discarding.")
+ log.error("Can't find the owner for this packet. "
+ "Discarding.")
- logger.debug("Looping on all handlers to check for timeouts")
+ log.debug("Looping on all handlers to check for timeouts")
now = time.time()
- for key in self.handlers:
+ for key in self.sessions:
try:
- self.handlers[key].check_timeout(now)
+ self.sessions[key].checkTimeout(now)
except TftpException, err:
- logger.error("Fatal exception thrown from handler: %s"
+ log.error("Fatal exception thrown from handler: %s"
% str(err))
deletion_list.append(key)
- logger.debug("Iterating deletion list.")
+ log.debug("Iterating deletion list.")
for key in deletion_list:
- if self.handlers.has_key(key):
- logger.debug("Deleting handler %s" % key)
- del self.handlers[key]
+ if self.sessions.has_key(key):
+ log.debug("Deleting handler %s" % key)
+ del self.sessions[key]
deletion_list = []
-
-class TftpServerHandler(TftpSession):
- """This class implements a handler for a given server session, handling
- the work for one download."""
-
- def __init__(self, key, state, root, listenip, factory, dyn_file_func):
- TftpSession.__init__(self)
- logger.info("Starting new handler. Key %s." % key)
- self.key = key
- self.host, self.port = self.key.split(':')
- self.port = int(self.port)
- self.listenip = listenip
- # Note, correct state here is important as it tells the handler whether it's
- # handling a download or an upload.
- self.state = state
- self.root = root
- self.mode = None
- self.filename = None
- self.sock = False
- self.options = { 'blksize': DEF_BLKSIZE }
- self.blocknumber = 0
- self.buffer = None
- self.fileobj = None
- self.timesent = 0
- self.timeouts = 0
- self.tftp_factory = factory
- self.dynfunc = dyn_file_func
- count = 0
- while not self.sock:
- self.sock = self.gensock(listenip)
- count += 1
- if count > 10:
- raise TftpException, "Failed to bind this handler to any port"
-
- def check_timeout(self, now):
- """This method checks to see if we've timed-out waiting for traffic
- from the client."""
- if self.timesent:
- if now - self.timesent > SOCK_TIMEOUT:
- self.timeout()
-
- def timeout(self):
- """This method handles a timeout condition."""
- logger.debug("Handling timeout for handler %s" % self.key)
- self.timeouts += 1
- if self.timeouts > TIMEOUT_RETRIES:
- raise TftpException, "Hit max retries, giving up."
-
- if self.state.state == 'dat' or self.state.state == 'fin':
- logger.debug("Timing out on DAT. Need to resend.")
- self.send_dat(resend=True)
- elif self.state.state == 'oack':
- logger.debug("Timing out on OACK. Need to resend.")
- self.send_oack()
- else:
- tftpassert(False,
- "Timing out in unsupported state %s" %
- self.state.state)
-
- def gensock(self, listenip):
- """This method generates a new UDP socket, whose listening port must
- be randomly generated, and not conflict with any already in use. For
- now, let the OS do this."""
- random.seed()
- port = random.randrange(1025, 65536)
- # FIXME - sockets should be non-blocking?
- sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
- logger.debug("Trying a handler socket on port %d" % port)
- try:
- sock.bind((listenip, port))
- return sock
- except socket.error, err:
- if err[0] == 98:
- logger.warn("Handler %s, port %d was already taken" % (self.key, port))
- return False
- else:
- raise
-
- def handle(self, pkttuple=None):
- """This method informs a handler instance that it has data waiting on
- its socket that it must read and process."""
- recvpkt = raddress = rport = None
- if pkttuple:
- logger.debug("Handed pkt %s for handler %s" % (recvpkt, self.key))
- recvpkt, raddress, rport = pkttuple
- else:
- logger.debug("Data ready for handler %s" % self.key)
- buffer, (raddress, rport) = self.sock.recvfrom(MAX_BLKSIZE)
- logger.debug("Read %d bytes" % len(buffer))
- recvpkt = self.tftp_factory.parse(buffer)
-
- # FIXME - refactor into another method, this is too big
- if isinstance(recvpkt, TftpPacketRRQ):
- logger.debug("Handler %s received RRQ packet" % self.key)
- logger.debug("Requested file is %s, mode is %s" % (recvpkt.filename,
- recvpkt.mode))
- # FIXME - only octet mode is supported at this time.
- if recvpkt.mode != 'octet':
- self.senderror(self.sock,
- TftpErrors.IllegalTftpOp,
- raddress,
- rport)
- raise TftpException, "Unsupported mode: %s" % recvpkt.mode
-
- # test host/port of client end
- if self.host != raddress or self.port != rport:
- self.senderror(self.sock,
- TftpErrors.UnknownTID,
- raddress,
- rport)
- logger.error("Expected traffic from %s:%s but received it "
- "from %s:%s instead."
- % (self.host, self.port, raddress, rport))
- self.errors += 1
- return
-
- if self.state.state == 'rrq':
- logger.debug("Received RRQ. Composing response.")
- self.filename = self.root + os.sep + recvpkt.filename
- logger.debug("The path to the desired file is %s" %
- self.filename)
- self.filename = os.path.abspath(self.filename)
- logger.debug("The absolute path is %s" % self.filename)
- # Security check. Make sure it's prefixed by the tftproot.
- if self.filename.find(self.root) == 0:
- logger.debug("The path appears to be safe: %s" %
- self.filename)
- else:
- logger.error("Insecure path: %s" % self.filename)
- self.errors += 1
- self.senderror(self.sock,
- TftpErrors.AccessViolation,
- raddress,
- rport)
- raise TftpException, "Insecure path: %s" % self.filename
-
- # Does the file exist?
- if(os.path.exists(self.filename) or not self.dynfunc is None):
- logger.debug("File %s exists." % self.filename)
-
- # Check options. Currently we only support the blksize
- # option.
- if recvpkt.options.has_key('blksize'):
- logger.debug("RRQ includes a blksize option")
- blksize = int(recvpkt.options['blksize'])
- # Delete the option now that it's handled.
- del recvpkt.options['blksize']
- if blksize >= MIN_BLKSIZE and blksize <= MAX_BLKSIZE:
- logger.info("Client requested blksize = %d"
- % blksize)
- self.options['blksize'] = blksize
- else:
- logger.warning("Client %s requested invalid "
- "blocksize %d, responding with default"
- % (self.key, blksize))
- self.options['blksize'] = DEF_BLKSIZE
-
- if recvpkt.options.has_key('tsize'):
- logger.info('RRQ includes tsize option')
- self.options['tsize'] = os.stat(self.filename).st_size
- # Delete the option now that it's handled.
- del recvpkt.options['tsize']
-
- if len(recvpkt.options.keys()) > 0:
- logger.warning("Client %s requested unsupported options: %s"
- % (self.key, recvpkt.options))
-
- if self.options:
- logger.info("Options requested, sending OACK")
- self.send_oack()
- else:
- logger.debug("Client %s requested no options."
- % self.key)
- self.start_download()
-
- else:
- logger.error("Requested file %s does not exist." %
- self.filename)
- self.senderror(self.sock,
- TftpErrors.FileNotFound,
- raddress,
- rport)
- raise TftpException, "Requested file not found: %s" % self.filename
-
- else:
- # We're receiving an RRQ when we're not expecting one.
- logger.error("Received an RRQ in handler %s "
- "but we're in state %s" % (self.key, self.state))
- self.errors += 1
-
- # Next packet type
- elif isinstance(recvpkt, TftpPacketACK):
- logger.debug("Received an ACK from the client.")
- if recvpkt.blocknumber == 0 and self.state.state == 'oack':
- logger.debug("Received ACK with 0 blocknumber, starting download")
- self.start_download()
- else:
- if self.state.state == 'dat' or self.state.state == 'fin':
- if self.blocknumber == recvpkt.blocknumber:
- logger.debug("Received ACK for block %d"
- % recvpkt.blocknumber)
- if self.state.state == 'fin':
- raise TftpException, "Successful transfer."
- else:
- self.send_dat()
- elif recvpkt.blocknumber < self.blocknumber:
- # Don't resend a DAT due to an old ACK. Fixes the
- # sorceror's apprentice problem.
- logger.warn("Received old ACK for block number %d"
- % recvpkt.blocknumber)
- else:
- logger.warn("Received ACK for block number "
- "%d, apparently from the future"
- % recvpkt.blocknumber)
- else:
- logger.error("Received ACK with block number %d "
- "while in state %s"
- % (recvpkt.blocknumber,
- self.state.state))
-
- elif isinstance(recvpkt, TftpPacketERR):
- logger.error("Received error packet from client: %s" % recvpkt)
- self.state.state = 'err'
- raise TftpException, "Received error from client"
-
- # Handle other packet types.
- else:
- logger.error("Received packet %s while handling a download"
- % recvpkt)
- self.senderror(self.sock,
- TftpErrors.IllegalTftpOp,
- self.host,
- self.port)
- raise TftpException, "Invalid packet received during download"
-
- def start_download(self):
- """This method opens self.filename, stores the resulting file object
- in self.fileobj, and calls send_dat()."""
- self.state.state = 'dat'
- if os.path.exists(self.filename):
- self.fileobj = open(self.filename, "rb")
- else:
- self.fileobj = self.dynfunc(self.filename)
- self.send_dat()
-
- def send_dat(self, resend=False):
- """This method reads sends a DAT packet based on what is in self.buffer."""
- if not resend:
- blksize = int(self.options['blksize'])
- self.buffer = self.fileobj.read(blksize)
- logger.debug("Read %d bytes into buffer" % len(self.buffer))
- if len(self.buffer) < blksize:
- logger.info("Reached EOF on file %s" % self.filename)
- self.state.state = 'fin'
- self.blocknumber += 1
- if self.blocknumber > 65535:
- logger.debug("Blocknumber rolled over to zero")
- self.blocknumber = 0
- else:
- logger.warn("Resending block number %d" % self.blocknumber)
- dat = TftpPacketDAT()
- dat.data = self.buffer
- dat.blocknumber = self.blocknumber
- logger.debug("Sending DAT packet %d" % self.blocknumber)
- self.sock.sendto(dat.encode().buffer, (self.host, self.port))
- self.timesent = time.time()
-
- # FIXME - should these be factored-out into the session class?
- def send_oack(self):
- """This method sends an OACK packet based on current params."""
- logger.debug("Composing and sending OACK packet")
- oack = TftpPacketOACK()
- oack.options = self.options
- self.sock.sendto(oack.encode().buffer,
- (self.host, self.port))
- self.timesent = time.time()
- self.state.state = 'oack'
diff --git a/tftpy/TftpShared.py b/tftpy/TftpShared.py
index 95172c3..bb95ad4 100644
--- a/tftpy/TftpShared.py
+++ b/tftpy/TftpShared.py
@@ -17,7 +17,7 @@ DEF_TFTP_PORT = 69
logging.basicConfig()
# The logger used by this library. Feel free to clobber it with your own, if you like, as
# long as it conforms to Python's logging.
-logger = logging.getLogger('tftpy')
+log = logging.getLogger('tftpy')
def tftpassert(condition, msg):
"""This function is a simple utility that will check the condition
@@ -31,8 +31,8 @@ def setLogLevel(level):
"""This function is a utility function for setting the internal log level.
The log level defaults to logging.NOTSET, so unwanted output to stdout is
not created."""
- global logger
- logger.setLevel(level)
+ global log
+ log.setLevel(level)
class TftpErrors(object):
"""This class is a convenience for defining the common tftp error codes,
diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py
index 88c4fa1..3d71e16 100644
--- a/tftpy/TftpStates.py
+++ b/tftpy/TftpStates.py
@@ -22,17 +22,28 @@ class TftpMetrics(object):
# Rates
self.bps = 0
self.kbps = 0
+ # Generic errors
+ self.errors = 0
def compute(self):
# Compute transfer time
self.duration = self.end_time - self.start_time
- logger.debug("TftpMetrics.compute: duration is %s" % self.duration)
+ log.debug("TftpMetrics.compute: duration is %s" % self.duration)
self.bps = (self.bytes * 8.0) / self.duration
self.kbps = self.bps / 1024.0
- logger.debug("TftpMetrics.compute: kbps is %s" % self.kbps)
- dupcount = 0
+ log.debug("TftpMetrics.compute: kbps is %s" % self.kbps)
for key in self.dups:
- dupcount += self.dups[key]
+ self.dupcount += self.dups[key]
+
+ def add_dup(self, blocknumber):
+ """This method adds a dup for a block number to the metrics."""
+ log.debug("Recording a dup for block %d" % blocknumber)
+ if self.dups.has_key(blocknumber):
+ self.dups[pkt.blocknumber] += 1
+ else:
+ self.dups[pkt.blocknumber] = 1
+ tftpassert(self.dups[pkt.blocknumber] < MAX_DUPS,
+ "Max duplicates for block %d reached" % blocknumber)
###############################################################################
# Context classes
@@ -40,16 +51,32 @@ class TftpMetrics(object):
class TftpContext(object):
"""The base class of the contexts."""
- def __init__(self, host, port):
+
+ def __init__(self, host, port, timeout):
"""Constructor for the base context, setting shared instance
variables."""
+ self.file_to_transfer = None
+ self.fileobj = None
+ self.options = None
+ self.packethook = None
+ self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ self.sock.settimeout(timeout)
+ self.state = None
+ self.next_block = 0
self.factory = TftpPacketFactory()
+ # Note, setting the host will also set self.address, as it's a property.
self.host = host
self.port = port
# The port associated with the TID
self.tidport = None
# Metrics
self.metrics = TftpMetrics()
+ # Flag when the transfer is pending completion.
+ self.pending_complete = False
+
+ def checkTimeout(self, now):
+ # FIXME
+ pass
def start(self):
return NotImplementedError, "Abstract method"
@@ -69,37 +96,9 @@ class TftpContext(object):
host = property(gethost, sethost)
- def sendAck(self, blocknumber):
- """This method sends an ack packet to the block number specified."""
- logger.info("sending ack to block %d" % blocknumber)
- ackpkt = TftpPacketACK()
- ackpkt.blocknumber = blocknumber
- self.sock.sendto(ackpkt.encode().buffer, (self.host, self.tidport))
-
- def sendError(self, errorcode):
- """This method uses the socket passed, and uses the errorcode to
- compose and send an error packet."""
- logger.debug("In sendError, being asked to send error %d" % errorcode)
- errpkt = TftpPacketERR()
- errpkt.errorcode = errorcode
- self.sock.sendto(errpkt.encode().buffer, (self.host, self.tidport))
-
-class TftpContextClient(TftpContext):
- """This class represents shared functionality by both the download and
- upload client contexts."""
- def __init__(self, host, port, filename, options, packethook, timeout):
- TftpContext.__init__(self, host, port)
- self.file_to_transfer = filename
- self.options = options
- self.packethook = packethook
- self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
- self.sock.settimeout(timeout)
- self.state = None
- self.next_block = 0
-
def setNextBlock(self, block):
if block > 2 ** 16:
- logger.debug("block number rollover to 0 again")
+ log.debug("block number rollover to 0 again")
block = 0
self.__eblock = block
@@ -111,19 +110,21 @@ class TftpContextClient(TftpContext):
def cycle(self):
"""Here we wait for a response from the server after sending it
something, and dispatch appropriate action to that response."""
+ # FIXME: This won't work very well in a server context with multiple
+ # sessions running.
for i in range(TIMEOUT_RETRIES):
- logger.debug("in cycle, receive attempt %d" % i)
+ log.debug("in cycle, receive attempt %d" % i)
try:
(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
except socket.timeout, err:
- logger.warn("Timeout waiting for traffic, retrying...")
+ log.warn("Timeout waiting for traffic, retrying...")
continue
break
else:
raise TftpException, "Hit max timeouts, giving up."
# Ok, we've received a packet. Log it.
- logger.debug("Received %d bytes from %s:%s"
+ log.debug("Received %d bytes from %s:%s"
% (len(buffer), raddress, rport))
# Decode it.
@@ -131,11 +132,11 @@ class TftpContextClient(TftpContext):
# Check for known "connection".
if raddress != self.address:
- logger.warn("Received traffic from %s, expected host %s. Discarding"
+ log.warn("Received traffic from %s, expected host %s. Discarding"
% (raddress, self.host))
if self.tidport and self.tidport != rport:
- logger.warn("Received traffic from %s:%s but we're "
+ log.warn("Received traffic from %s:%s but we're "
"connected to %s:%s. Discarding."
% (raddress, rport,
self.host, self.tidport))
@@ -150,29 +151,66 @@ class TftpContextClient(TftpContext):
# And handle it, possibly changing state.
self.state = self.state.handle(recvpkt, raddress, rport)
-class TftpContextClientUpload(TftpContextClient):
+class TftpContextServer(TftpContext):
+ """The context for the server."""
+ def __init__(self, host, port, timeout, root, dyn_file_func):
+ TftpContext.__init__(self,
+ host,
+ port,
+ timeout)
+ # At this point we have no idea if this is a download or an upload. We
+ # need to let the start state determine that.
+ self.state = TftpStateServerStart()
+ self.root = root
+ self.dyn_file_func = dyn_file_func
+
+ def start(self, buffer):
+ """Start the state cycle. Note that the server context receives an
+ initial packet in its start method."""
+ log.debug("TftpContextServer.start() - pkt = %s" % pkt)
+
+ self.metrics.start_time = time.time()
+ log.debug("set metrics.start_time to %s" % self.metrics.start_time)
+
+ pkt = self.factory.parse(buffer)
+ log.debug("TftpContextServer.start() - factory returned a %s" % pkt)
+
+ # Call handle once with the initial packet. This should put us into
+ # the download or the upload state.
+ self.state = self.state.handle(pkt,
+ self.host,
+ self.port)
+
+ try:
+ while self.state:
+ log.debug("state is %s" % self.state)
+ self.cycle()
+ finally:
+ self.fileobj.close()
+
+class TftpContextClientUpload(TftpContext):
"""The upload context for the client during an upload."""
def __init__(self, host, port, filename, input, options, packethook, timeout):
- TftpContextClient.__init__(self,
- host,
- port,
- filename,
- options,
- packethook,
- timeout)
+ TftpContext.__init__(self,
+ host,
+ port,
+ timeout)
+ self.file_to_transfer = filename
+ self.options = options
+ self.packethook = packethook
self.fileobj = open(input, "wb")
- logger.debug("TftpContextClientUpload.__init__()")
- logger.debug("file_to_transfer = %s, options = %s" %
+ log.debug("TftpContextClientUpload.__init__()")
+ log.debug("file_to_transfer = %s, options = %s" %
(self.file_to_transfer, self.options))
def start(self):
- logger.info("sending tftp upload request to %s" % self.host)
- logger.info(" filename -> %s" % self.file_to_transfer)
- logger.info(" options -> %s" % self.options)
+ log.info("sending tftp upload request to %s" % self.host)
+ log.info(" filename -> %s" % self.file_to_transfer)
+ log.info(" options -> %s" % self.options)
self.metrics.start_time = time.time()
- logger.debug("set metrics.start_time to %s" % self.metrics.start_time)
+ log.debug("set metrics.start_time to %s" % self.metrics.start_time)
# FIXME: put this in a sendWRQ method?
pkt = TftpPacketWRQ()
@@ -186,7 +224,7 @@ class TftpContextClientUpload(TftpContextClient):
try:
while self.state:
- logger.debug("state is %s" % self.state)
+ log.debug("state is %s" % self.state)
self.cycle()
finally:
self.fileobj.close()
@@ -194,32 +232,32 @@ class TftpContextClientUpload(TftpContextClient):
def end(self):
pass
-class TftpContextClientDownload(TftpContextClient):
+class TftpContextClientDownload(TftpContext):
"""The download context for the client during a download."""
def __init__(self, host, port, filename, output, options, packethook, timeout):
- TftpContextClient.__init__(self,
- host,
- port,
- filename,
- options,
- packethook,
- timeout)
+ TftpContext.__init__(self,
+ host,
+ port,
+ filename,
+ options,
+ packethook,
+ timeout)
# FIXME - need to support alternate return formats than files?
# File-like objects would be ideal, ala duck-typing.
self.fileobj = open(output, "wb")
- logger.debug("TftpContextClientDownload.__init__()")
- logger.debug("file_to_transfer = %s, options = %s" %
+ log.debug("TftpContextClientDownload.__init__()")
+ log.debug("file_to_transfer = %s, options = %s" %
(self.file_to_transfer, self.options))
def start(self):
"""Initiate the download."""
- logger.info("sending tftp download request to %s" % self.host)
- logger.info(" filename -> %s" % self.file_to_transfer)
- logger.info(" options -> %s" % self.options)
+ log.info("sending tftp download request to %s" % self.host)
+ log.info(" filename -> %s" % self.file_to_transfer)
+ log.info(" options -> %s" % self.options)
self.metrics.start_time = time.time()
- logger.debug("set metrics.start_time to %s" % self.metrics.start_time)
+ log.debug("set metrics.start_time to %s" % self.metrics.start_time)
# FIXME: put this in a sendRRQ method?
pkt = TftpPacketRRQ()
@@ -233,7 +271,7 @@ class TftpContextClientDownload(TftpContextClient):
try:
while self.state:
- logger.debug("state is %s" % self.state)
+ log.debug("state is %s" % self.state)
self.cycle()
finally:
self.fileobj.close()
@@ -241,7 +279,7 @@ class TftpContextClientDownload(TftpContextClient):
def end(self):
"""Finish up the context."""
self.metrics.end_time = time.time()
- logger.debug("set metrics.end_time to %s" % self.metrics.end_time)
+ log.debug("set metrics.end_time to %s" % self.metrics.end_time)
self.metrics.compute()
@@ -268,235 +306,431 @@ class TftpState(object):
options."""
if pkt.options.keys() > 0:
if pkt.match_options(self.context.options):
- logger.info("Successful negotiation of options")
+ log.info("Successful negotiation of options")
# Set options to OACK options
self.context.options = pkt.options
for key in self.context.options:
- logger.info(" %s = %s" % (key, self.context.options[key]))
+ log.info(" %s = %s" % (key, self.context.options[key]))
else:
- logger.error("failed to negotiate options")
+ log.error("failed to negotiate options")
raise TftpException, "Failed to negotiate options"
else:
raise TftpException, "No options found in OACK"
-class TftpStateUpload(TftpState):
- """A class holding common code for upload states."""
- def sendDat(self, resend=False):
+ def returnSupportedOptions(self, options):
+ """This method takes a requested options list from a client, and
+ returns the ones that are supported."""
+ # We support the options blksize and tsize right now.
+ # FIXME - put this somewhere else?
+ accepted_options = {}
+ for option in options:
+ if option == 'blksize':
+ # Make sure it's valid.
+ if int(options[option]) > MAX_BLKSIZE:
+ accepted_options[option] = MAX_BLKSIZE
+ elif option == 'tsize':
+ log.debug("tsize option is set")
+ accepted_options['tsize'] = 1
+ else:
+ log.info("Dropping unsupported option '%s'" % option)
+ return accepted_options
+
+ def serverInitial(self, pkt, raddress, rport):
+ """This method performs initial setup for a server context transfer,
+ put here to refactor code out of the TftpStateServerRecvRRQ and
+ TftpStateServerRecvWRQ classes, since their initial setup is
+ identical. The method returns a boolean, sendoack, to indicate whether
+ it is required to send an OACK to the client."""
+ options = pkt.options
+ sendoack = False
+ if not options:
+ log.debug("setting default options, blksize")
+ # FIXME: put default options elsewhere
+ self.context.options = { 'blksize': DEF_BLKSIZE }
+ else:
+ log.debug("options requested: %s" % options)
+ self.context.options = self.returnSupportedOptions(options)
+ sendoack = True
+
+ # FIXME - only octet mode is supported at this time.
+ if pkt.mode != 'octet':
+ self.sendError(TftpErrors.IllegalTftpOp)
+ raise TftpException, \
+ "Only octet transfers are supported at this time."
+
+ # test host/port of client end
+ if self.context.host != raddress or self.context.port != rport:
+ self.sendError(TftpErrors.UnknownTID)
+ log.error("Expected traffic from %s:%s but received it "
+ "from %s:%s instead."
+ % (self.context.host,
+ self.context.port,
+ raddress,
+ rport))
+ # FIXME: increment an error count?
+ # Return same state, we're still waiting for valid traffic.
+ return self
+
+ log.debug("requested filename is %s" % pkt.filename)
+ # There are no os.sep's allowed in the filename.
+ # FIXME: Should we allow subdirectories?
+ if pkt.filename.find(os.sep) >= 0:
+ self.sendError(TftpErrors.IllegalTftpOp)
+ raise TftpException, "%s found in filename, not permitted" % os.sep
+
+ self.context.file_to_transfer = pkt.filename
+
+ return sendoack
+
+ def sendDAT(self, resend=False):
+ """This method sends the next DAT packet based on the data in the
+ context. It returns a boolean indicating whether the transfer is
+ finished."""
finished = False
blocknumber = self.context.next_block
if not resend:
blksize = int(self.context.options['blksize'])
buffer = self.context.fileobj.read(blksize)
- logger.debug("Read %d bytes into buffer" % len(buffer))
+ log.debug("Read %d bytes into buffer" % len(buffer))
if len(buffer) < blksize:
- logger.info("Reached EOF on file %s" % self.context.input)
+ log.info("Reached EOF on file %s" % self.context.input)
finished = True
self.context.next_block += 1
self.bytes += len(buffer)
else:
- logger.warn("Resending block number %d" % blocknumber)
+ log.warn("Resending block number %d" % blocknumber)
dat = TftpPacketDAT()
dat.data = buffer
dat.blocknumber = blocknumber
- logger.debug("Sending DAT packet %d" % blocknumber)
+ log.debug("Sending DAT packet %d" % blocknumber)
self.context.sock.sendto(dat.encode().buffer,
(self.context.host, self.context.port))
if self.context.packethook:
self.context.packethook(dat)
return finished
-class TftpStateDownload(TftpState):
- """A class holding common code for download states."""
+ def sendACK(self, blocknumber=None):
+ """This method sends an ack packet to the block number specified. If
+ none is specified, it defaults to the next_block property in the
+ parent context."""
+ if not blocknumber:
+ blocknumber = self.context.next_block
+ log.info("sending ack to block %d" % blocknumber)
+ ackpkt = TftpPacketACK()
+ ackpkt.blocknumber = blocknumber
+ self.context.sock.sendto(ackpkt.encode().buffer,
+ (self.context.host,
+ self.context.tidport))
+
+ def sendError(self, errorcode):
+ """This method uses the socket passed, and uses the errorcode to
+ compose and send an error packet."""
+ log.debug("In sendError, being asked to send error %d" % errorcode)
+ errpkt = TftpPacketERR()
+ errpkt.errorcode = errorcode
+ self.context.sock.sendto(errpkt.encode().buffer,
+ (self.context.host,
+ self.context.tidport))
+
+ def sendOACK(self):
+ """This method sends an OACK packet with the options from the current
+ context."""
+ log.debug("In sendOACK with options %s" % options)
+ pkt = TftpPacketOACK()
+ pkt.options = self.options
+ self.context.sock.sendto(pkt.encode().buffer,
+ (self.context.host,
+ self.context.tidport))
+
def handleDat(self, pkt):
- """This method handles a DAT packet during a download."""
- logger.info("handling DAT packet - block %d" % pkt.blocknumber)
- logger.debug("expecting block %s" % self.context.next_block)
+ """This method handles a DAT packet during a client download, or a
+ server upload."""
+ log.info("handling DAT packet - block %d" % pkt.blocknumber)
+ log.debug("expecting block %s" % self.context.next_block)
if pkt.blocknumber == self.context.next_block:
- logger.debug("good, received block %d in sequence"
+ log.debug("good, received block %d in sequence"
% pkt.blocknumber)
- self.context.sendAck(pkt.blocknumber)
+ self.sendACK()
self.context.next_block += 1
- logger.debug("writing %d bytes to output file"
+ log.debug("writing %d bytes to output file"
% len(pkt.data))
self.context.fileobj.write(pkt.data)
self.context.metrics.bytes += len(pkt.data)
# Check for end-of-file, any less than full data packet.
if len(pkt.data) < int(self.context.options['blksize']):
- logger.info("end of file detected")
+ log.info("end of file detected")
return None
elif pkt.blocknumber < self.context.next_block:
- logger.warn("dropping duplicate block %d" % pkt.blocknumber)
- if self.context.metrics.dups.has_key(pkt.blocknumber):
- self.context.metrics.dups[pkt.blocknumber] += 1
- else:
- self.context.metrics.dups[pkt.blocknumber] = 1
- tftpassert(self.context.metrics.dups[pkt.blocknumber] < MAX_DUPS,
- "Max duplicates for block %d reached" % pkt.blocknumber)
- # FIXME: double-check sorceror's apprentice problem!
- logger.debug("ACKing block %d again, just in case" % pkt.blocknumber)
- self.context.sendAck(pkt.blocknumber)
+ log.warn("dropping duplicate block %d" % pkt.blocknumber)
+ self.context.metrics.add_dup(pkt.blocknumber)
+ log.debug("ACKing block %d again, just in case" % pkt.blocknumber)
+ self.sendACK(pkt.blocknumber)
else:
# FIXME: should we be more tolerant and just discard instead?
msg = "Whoa! Received future block %d but expected %d" \
% (pkt.blocknumber, self.context.next_block)
- logger.error(msg)
+ log.error(msg)
raise TftpException, msg
# Default is to ack
- return TftpStateSentACK(self.context)
+ return TftpStateExpectDAT(self.context)
+
+class TftpStateServerRecvRRQ(TftpState):
+ """This class represents the state of the TFTP server when it has just
+ received an RRQ packet."""
+ def handle(self, pkt, raddress, rport):
+ "Handle an initial RRQ packet as a server."
+ log.debug("In TftpStateServerRecvRRQ.handle")
+ sendoack = self.serverInitial(pkt, raddress, rport)
+ path = self.context.root + os.sep + self.context.file_to_transfer
+ log.info("Opening file %s for reading" % path)
+ if os.path.exists(path):
+ # Note: Open in binary mode for win32 portability, since win32
+ # blows.
+ self.context.fileobj = open(path, "rb")
+ elif self.dyn_file_func:
+ log.debug("No such file %s but using dyn_file_func" % path)
+ self.context.fileobj = \
+ self.dyn_file_func(self.context.file_to_transfer)
+ else:
+ send.sendError(TftpErrors.FileNotFound)
+ raise TftpException, "File not found: %s" % path
+
+ # Options negotiation.
+ if sendoack:
+ self.sendOACK()
+ return TftpStateServerOACK(self.context)
+ else:
+ log.debug("No requested options, starting send...")
+ self.context.pending_complete = self.sendDAT()
+ return TftpStateExpectACK(self.context)
+
+ # Note, we don't have to check any other states in this method, that's
+ # up to the caller.
+
+class TftpStateServerRecvWRQ(TftpState):
+ """This class represents the state of the TFTP server when it has just
+ received a WRQ packet."""
+ def handle(self, pkt, raddress, rport):
+ "Handle an initial WRQ packet as a server."
+ log.debug("In TftpStateServerRecvWRQ.handle")
+ sendoack = self.serverInitial(pkt, raddress, rport)
+ path = self.context.root + os.sep + self.context.file_to_transfer
+ log.info("Opening file %s for writing" % path)
+ if os.path.exists(path):
+ # FIXME: correct behavior?
+ log.warn("File %s exists already, overwriting...")
+ # FIXME: I think we should upload to a temp file and not overwrite the
+ # existing file until the file is successfully uploaded.
+ self.context.fileobj = open(path, "wb")
+
+ # Options negotiation.
+ if sendoack:
+ log.debug("Sending OACK to client")
+ self.sendOACK()
+ else:
+ log.debug("No requested options, starting send...")
+ self.sendACK()
+ # We may have sent an OACK, but we're expecting a DAT as the response
+ # to either the OACK or an ACK, so lets unconditionally use the
+ # TftpStateExpectDAT state.
+ return TftpStateExpectDAT(self.context)
+
+ # Note, we don't have to check any other states in this method, that's
+ # up to the caller.
+
+class TftpStateServerStart(TftpState):
+ """The start state for the server."""
+ def handle(self, pkt, raddress, rport):
+ """Handle a packet we just received."""
+ log.debug("In TftpStateServerStart.handle")
+ if isinstance(pkt, TftpPacketRRQ):
+ log.debug("handling an RRQ packet")
+ return TftpStateServerRecvRRQ(self.context).handle(pkt,
+ raddress,
+ rport)
+ elif isinstance(pkt, TftpPacketWRQ):
+ log.debug("handling a WRQ packet")
+ return TftpStateServerRecvWRQ(self.context).handle(pkt,
+ raddress,
+ rport)
+ else:
+ self.sendError(TftpErrors.IllegalTftpOp)
+ raise TftpException, \
+ "Invalid packet to begin up/download: %s" % pkt
+
+class TftpStateExpectACK(TftpState):
+ """This class represents the state of the transfer when a DAT was just
+ sent, and we are waiting for an ACK from the server. This class is the
+ same one used by the client during the upload, and the server during the
+ download."""
+ def handle(self, pkt, raddress, rport):
+ "Handle a packet, hopefully an ACK since we just sent a DAT."
+ if isinstance(pkt, TftpPacketACK):
+ log.info("Received ACK for packet %d" % pkt.blocknumber)
+ # Is this an ack to the one we just sent?
+ if self.context.next_block == pkt.blocknumber:
+ if self.context.pending_complete:
+ log.info("Received ACK to final DAT, we're done.")
+ return None
+ else:
+ log.debug("Good ACK, sending next DAT")
+ self.context.pending_complete = self.sendDAT()
-class TftpStateSentWRQ(TftpStateUpload):
+ elif pkt.blocknumber < self.context.next_block:
+ self.context.metrics.add_dup(pkt.blocknumber)
+
+ else:
+ log.warn("Oooh, time warp. Received ACK to packet we "
+ "didn't send yet. Discarding.")
+ self.context.metrics.errors += 1
+ return self
+ elif isinstance(pkt, TftpPacketERR):
+ log.error("Received ERR packet from peer: %s" % str(pkt))
+ raise TftpException, \
+ "Received ERR packet from peer: %s" % str(pkt)
+ else:
+ log.warn("Discarding unsupported packet: %s" % str(pkt))
+ return self
+
+class TftpStateExpectDAT(TftpState):
+ """Just sent an ACK packet. Waiting for DAT."""
+ def handle(self, pkt, raddress, rport):
+ """Handle the packet in response to an ACK, which should be a DAT."""
+ if isinstance(pkt, TftpPacketDAT):
+ return self.handleDat(pkt)
+
+ # Every other packet type is a problem.
+ elif isinstance(recvpkt, TftpPacketACK):
+ # Umm, we ACK, you don't.
+ self.sendError(TftpErrors.IllegalTftpOp)
+ raise TftpException, "Received ACK from peer when expecting DAT"
+
+ elif isinstance(recvpkt, TftpPacketWRQ):
+ self.sendError(TftpErrors.IllegalTftpOp)
+ raise TftpException, "Received WRQ from peer when expecting DAT"
+
+ elif isinstance(recvpkt, TftpPacketERR):
+ self.sendError(TftpErrors.IllegalTftpOp)
+ raise TftpException, "Received ERR from peer: " + str(recvpkt)
+
+ else:
+ self.sendError(TftpErrors.IllegalTftpOp)
+ raise TftpException, "Received unknown packet type from peer: " + str(recvpkt)
+
+class TftpStateSentWRQ(TftpState):
"""Just sent an WRQ packet for an upload."""
def handle(self, pkt, raddress, rport):
"""Handle a packet we just received."""
if not self.context.tidport:
self.context.tidport = rport
- logger.debug("Set remote port for session to %s" % rport)
+ log.debug("Set remote port for session to %s" % rport)
# If we're going to successfully transfer the file, then we should see
# either an OACK for accepted options, or an ACK to ignore options.
if isinstance(pkt, TftpPacketOACK):
- logger.info("received OACK from server")
+ log.info("received OACK from server")
try:
self.handleOACK(pkt)
except TftpException, err:
- logger.error("failed to negotiate options")
- self.context.sendError(TftpErrors.FailedNegotiation)
+ log.error("failed to negotiate options")
+ self.sendError(TftpErrors.FailedNegotiation)
raise
else:
- logger.debug("sending first DAT packet")
- fin = self.context.sendDat()
- if fin:
- logger.info("Add done")
- return None
- else:
- logger.debug("Changing state to TftpStateSentDAT")
- return TftpStateSentDAT(self.context)
+ log.debug("sending first DAT packet")
+ self.context.pending_complete = self.sendDAT()
+ log.debug("Changing state to TftpStateExpectACK")
+ return TftpStateExpectACK(self.context)
elif isinstance(pkt, TftpPacketACK):
- logger.info("received ACK from server")
- logger.debug("apparently the server ignored our options")
+ log.info("received ACK from server")
+ log.debug("apparently the server ignored our options")
# The block number should be zero.
if pkt.blocknumber == 0:
- logger.debug("ack blocknumber is zero as expected")
- logger.debug("sending first DAT packet")
- fin = self.context.sendDat()
- if fin:
- logger.info("Add done")
- return None
- else:
- logger.debug("Changing state to TftpStateSentDAT")
- return TftpStateSentDAT(self.context)
+ log.debug("ack blocknumber is zero as expected")
+ log.debug("sending first DAT packet")
+ self.pending_complete = self.context.sendDAT()
+ log.debug("Changing state to TftpStateExpectACK")
+ return TftpStateExpectACK(self.context)
else:
- logger.warn("discarding ACK to block %s" % pkt.blocknumber)
- logger.debug("still waiting for valid response from server")
+ log.warn("discarding ACK to block %s" % pkt.blocknumber)
+ log.debug("still waiting for valid response from server")
return self
elif isinstance(pkt, TftpPacketERR):
- self.context.sendError(TftpErrors.IllegalTftpOp)
+ self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received ERR from server: " + str(pkt)
elif isinstance(pkt, TftpPacketRRQ):
- self.context.sendError(TftpErrors.IllegalTftpOp)
+ self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received RRQ from server while in upload"
elif isinstance(pkt, TftpPacketDAT):
- self.context.sendError(TftpErrors.IllegalTftpOp)
+ self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received DAT from server while in upload"
else:
- self.context.sendError(TftpErrors.IllegalTftpOp)
+ self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received unknown packet type from server: " + str(pkt)
# By default, no state change.
return self
-class TftpStateSentDAT(TftpStateUpload):
- """This class represents the state of the transfer when a DAT was just
- sent, and we are waiting for an ACK from the server. This class is the
- same one used by the client during the upload, and the server during the
- download."""
-
-class TftpStateSentRRQ(TftpStateDownload):
+class TftpStateSentRRQ(TftpState):
"""Just sent an RRQ packet."""
def handle(self, pkt, raddress, rport):
"""Handle the packet in response to an RRQ to the server."""
if not self.context.tidport:
self.context.tidport = rport
- logger.debug("Set remote port for session to %s" % rport)
+ log.debug("Set remote port for session to %s" % rport)
# Now check the packet type and dispatch it properly.
if isinstance(pkt, TftpPacketOACK):
- logger.info("received OACK from server")
+ log.info("received OACK from server")
try:
self.handleOACK(pkt)
except TftpException, err:
- logger.error("failed to negotiate options: %s" % str(err))
- self.context.sendError(TftpErrors.FailedNegotiation)
+ log.error("failed to negotiate options: %s" % str(err))
+ self.sendError(TftpErrors.FailedNegotiation)
raise
else:
- logger.debug("sending ACK to OACK")
+ log.debug("sending ACK to OACK")
- self.context.sendAck(blocknumber=0)
+ self.sendACK(blocknumber=0)
- logger.debug("Changing state to TftpStateSentACK")
- return TftpStateSentACK(self.context)
+ log.debug("Changing state to TftpStateExpectDAT")
+ return TftpStateExpectDAT(self.context)
elif isinstance(pkt, TftpPacketDAT):
# If there are any options set, then the server didn't honour any
# of them.
- logger.info("received DAT from server")
+ log.info("received DAT from server")
if self.context.options:
- logger.info("server ignored options, falling back to defaults")
+ log.info("server ignored options, falling back to defaults")
self.context.options = { 'blksize': DEF_BLKSIZE }
return self.handleDat(pkt)
# Every other packet type is a problem.
elif isinstance(recvpkt, TftpPacketACK):
# Umm, we ACK, the server doesn't.
- self.context.sendError(TftpErrors.IllegalTftpOp)
+ self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received ACK from server while in download"
elif isinstance(recvpkt, TftpPacketWRQ):
- self.context.sendError(TftpErrors.IllegalTftpOp)
+ self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received WRQ from server while in download"
elif isinstance(recvpkt, TftpPacketERR):
- self.context.sendError(TftpErrors.IllegalTftpOp)
+ self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received ERR from server: " + str(recvpkt)
else:
- self.context.sendError(TftpErrors.IllegalTftpOp)
+ self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException, "Received unknown packet type from server: " + str(recvpkt)
# By default, no state change.
return self
-
-class TftpStateSentACK(TftpStateDownload):
- """Just sent an ACK packet. Waiting for DAT."""
- def handle(self, pkt, raddress, rport):
- """Handle the packet in response to an ACK, which should be a DAT."""
- if isinstance(pkt, TftpPacketDAT):
- return self.handleDat(pkt)
-
- # Every other packet type is a problem.
- elif isinstance(recvpkt, TftpPacketACK):
- # Umm, we ACK, the server doesn't.
- self.context.sendError(TftpErrors.IllegalTftpOp)
- raise TftpException, "Received ACK from server while in download"
-
- elif isinstance(recvpkt, TftpPacketWRQ):
- self.context.sendError(TftpErrors.IllegalTftpOp)
- raise TftpException, "Received WRQ from server while in download"
-
- elif isinstance(recvpkt, TftpPacketERR):
- self.context.sendError(TftpErrors.IllegalTftpOp)
- raise TftpException, "Received ERR from server: " + str(recvpkt)
-
- else:
- self.context.sendError(TftpErrors.IllegalTftpOp)
- raise TftpException, "Received unknown packet type from server: " + str(recvpkt)