diff options
author | Kamal Swamidoss <kswamidoss@gmail.com> | 2012-12-03 21:31:00 -0500 |
---|---|---|
committer | Michael P. Soulier <msoulier@digitaltorque.ca> | 2012-12-03 21:31:00 -0500 |
commit | 7818aaf6f1f0e264ac0bfa0b1e40dc5d1f250855 (patch) | |
tree | cbdc00cab105d095314bc7569b48ec4f426562f0 | |
parent | a5a42a3cc4949538fb83b55822d2df7398eb0771 (diff) | |
download | tftpy-kamal.tar.gz |
Issue 32: Initial code from Kamalkamal
-rw-r--r-- | tftpy/TftpContexts.py | 17 | ||||
-rw-r--r-- | tftpy/TftpServer.py | 11 | ||||
-rw-r--r-- | tftpy/TftpStates.py | 8 |
3 files changed, 24 insertions, 12 deletions
diff --git a/tftpy/TftpContexts.py b/tftpy/TftpContexts.py index c3a1bd4..e061cad 100644 --- a/tftpy/TftpContexts.py +++ b/tftpy/TftpContexts.py @@ -67,7 +67,7 @@ class TftpMetrics(object): class TftpContext(object): """The base class of the contexts.""" - def __init__(self, host, port, timeout, dyn_file_func=None): + def __init__(self, host, port, timeout, dyn_file_func=None, server_callback=None): """Constructor for the base context, setting shared instance variables.""" self.file_to_transfer = None @@ -95,6 +95,7 @@ class TftpContext(object): # The last packet we sent, if applicable, to make resending easy. self.last_pkt = None self.dyn_file_func = dyn_file_func + self.server_callback = server_callback # Count the number of retry attempts. self.retry_count = 0 @@ -194,18 +195,21 @@ class TftpContext(object): class TftpContextServer(TftpContext): """The context for the server.""" - def __init__(self, host, port, timeout, root, dyn_file_func=None): + def __init__(self, host, port, timeout, root, dyn_file_func=None, server_callback=None): TftpContext.__init__(self, host, port, timeout, - dyn_file_func + dyn_file_func, + server_callback ) # At this point we have no idea if this is a download or an upload. We # need to let the start state determine that. self.state = TftpStateServerStart(self) self.root = root self.dyn_file_func = dyn_file_func + self.server_callback = server_callback + self.recv_wrq = None def __str__(self): return "%s:%s %s" % (self.host, self.port, self.state) @@ -226,13 +230,14 @@ class TftpContextServer(TftpContext): # Call handle once with the initial packet. This should put us into # the download or the upload state. - self.state = self.state.handle(pkt, - self.host, - self.port) + self.state, self.recv_wrq = self.state.handle(pkt, + self.host, + self.port) def end(self): """Finish up the context.""" TftpContext.end(self) + self.server_callback(self.file_to_transfer, self.recv_wrq) self.metrics.end_time = time.time() log.debug("Set metrics.end_time to %s" % self.metrics.end_time) self.metrics.compute() diff --git a/tftpy/TftpServer.py b/tftpy/TftpServer.py index 364227c..d9e61a1 100644 --- a/tftpy/TftpServer.py +++ b/tftpy/TftpServer.py @@ -28,6 +28,7 @@ class TftpServer(TftpSession): # A dict of sessions, where each session is keyed by a string like # ip:tid for the remote end. self.sessions = {} + self.server_callback = None if os.path.exists(self.root): log.debug("tftproot %s does exist" % self.root) @@ -49,11 +50,16 @@ class TftpServer(TftpSession): def listen(self, listenip="", listenport=DEF_TFTP_PORT, + server_callback=None, timeout=SOCK_TIMEOUT): """Start a server listening on the supplied interface and port. This defaults to INADDR_ANY (all interfaces) and UDP port 69. You can also - supply a different socket timeout value, if desired.""" + supply a different socket timeout value, if desired. The + server_callback is a callable that will be called when the transfer is + complete, being passed + tftp_factory = TftpPacketFactory() + self.server_callback = server_callback # Don't use new 2.5 ternary operator yet # listenip = listenip if listenip else '0.0.0.0' @@ -105,7 +111,8 @@ class TftpServer(TftpSession): rport, timeout, self.root, - self.dyn_file_func) + self.dyn_file_func, + self.server_callback) try: self.sessions[key].start(buffer) except TftpException, err: diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py index c106220..2bc5fcd 100644 --- a/tftpy/TftpStates.py +++ b/tftpy/TftpStates.py @@ -333,14 +333,14 @@ class TftpStateServerStart(TftpState): log.debug("In TftpStateServerStart.handle") if isinstance(pkt, TftpPacketRRQ): log.debug("Handling an RRQ packet") - return TftpStateServerRecvRRQ(self.context).handle(pkt, + return ( TftpStateServerRecvRRQ(self.context).handle(pkt, raddress, - rport) + rport), False ) elif isinstance(pkt, TftpPacketWRQ): log.debug("Handling a WRQ packet") - return TftpStateServerRecvWRQ(self.context).handle(pkt, + return ( TftpStateServerRecvWRQ(self.context).handle(pkt, raddress, - rport) + rport), True ) else: self.sendError(TftpErrors.IllegalTftpOp) raise TftpException, \ |