diff --git a/paramiko/transport.py b/paramiko/transport.py index 29f59e8..4e2ddbc 100644 --- a/paramiko/transport.py +++ b/paramiko/transport.py @@ -22,7 +22,7 @@ L{BaseTransport} handles the core SSH2 protocol. """ -import sys, os, string, threading, socket, struct +import sys, os, string, threading, socket, struct, time from common import * from ssh_exception import SSHException @@ -154,7 +154,7 @@ class BaseTransport (threading.Thread): # /negotiated crypto parameters self.expected_packet = 0 self.active = False - self.initial_kex_done = 0 + self.initial_kex_done = False self.write_lock = threading.Lock() # lock around outbound writes (packet computation) self.lock = threading.Lock() # synchronization (always higher level than write_lock) self.channels = { } # (id -> Channel) @@ -171,6 +171,9 @@ class BaseTransport (threading.Thread): self.received_packets_overflow = 0 # user-defined event callbacks: self.completion_event = None + # keepalives: + self.keepalive_interval = 0 + self.keepalive_last = time.time() # server mode: self.server_mode = 0 self.server_key_dict = { } @@ -432,6 +435,8 @@ class BaseTransport (threading.Thread): @param bytes: the number of random bytes to send in the payload of the ignored packet -- defaults to a random number from 10 to 41. @type bytes: int + + @since: fearow """ m = Message() m.add_byte(chr(MSG_IGNORE)) @@ -464,6 +469,19 @@ class BaseTransport (threading.Thread): break return True + def set_keepalive(self, interval): + """ + Turn on/off keepalive packets (default is off). If this is set, after + C{interval} seconds without sending any data over the connection, a + "keepalive" packet will be sent (and ignored by the remote host). This + can be useful to keep connections alive over a NAT, for example. + + @param interval: seconds to wait before sending a keepalive packet (or + 0 to disable keepalives). + @type interval: int + """ + self.keepalive_interval = interval + def global_request(self, kind, data=None, wait=True): """ Make a global request to the remote host. These are normally @@ -481,6 +499,8 @@ class BaseTransport (threading.Thread): request was successful (or an empty L{Message} if C{wait} was C{False}); C{None} if the request was denied. @rtype: L{Message} + + @since: fearow """ if wait: self.completion_event = threading.Event() @@ -491,6 +511,7 @@ class BaseTransport (threading.Thread): if data is not None: for item in data: m.add(item) + self._log(DEBUG, 'Sending global request "%s"' % kind) self._send_message(m) if not wait: return True @@ -691,6 +712,13 @@ class BaseTransport (threading.Thread): finally: self.lock.release() + def _check_keepalive(self): + if (not self.keepalive_interval) or (not self.initial_kex_done): + return + now = time.time() + if now > self.keepalive_last + self.keepalive_interval: + self.global_request('keepalive@lag.net', wait=False) + def _py22_read_all(self, n): out = '' while n > 0: @@ -698,6 +726,7 @@ class BaseTransport (threading.Thread): if self.sock not in r: if not self.active: raise EOFError() + self._check_keepalive() else: x = self.sock.recv(n) if len(x) == 0: @@ -720,9 +749,11 @@ class BaseTransport (threading.Thread): except socket.timeout: if not self.active: raise EOFError() + self._check_keepalive() return out def _write_all(self, out): + self.keepalive_last = time.time() while len(out) > 0: n = self.sock.send(out) if n <= 0: @@ -1156,7 +1187,7 @@ class BaseTransport (threading.Thread): self.e = self.f = self.K = self.x = None if not self.initial_kex_done: # this was the first key exchange - self.initial_kex_done = 1 + self.initial_kex_done = True # send an event? if self.completion_event != None: self.completion_event.set() @@ -1169,6 +1200,7 @@ class BaseTransport (threading.Thread): def _parse_global_request(self, m): kind = m.get_string() + self._log(DEBUG, 'Received global request "%s"' % kind) want_reply = m.get_boolean() ok = self.check_global_request(kind, m) extra = () @@ -1186,11 +1218,13 @@ class BaseTransport (threading.Thread): self._send_message(msg) def _parse_request_success(self, m): + self._log(DEBUG, 'Global request successful.') self.global_response = m if self.completion_event is not None: self.completion_event.set() def _parse_request_failure(self, m): + self._log(DEBUG, 'Global request denied.') self.global_response = None if self.completion_event is not None: self.completion_event.set()