[project @ Arch-1:robey@lag.net--2003-public%secsh--dev--1.0--patch-45]

add set_keepalive()
add set_keepalive() to set an automatic keepalive mechanism.  (while waiting
for a packet on a connection, we periodically check if it's time to send a
keepalive packet.)
This commit is contained in:
Robey Pointer 2004-04-07 15:52:07 +00:00
parent 1af6360007
commit 17acfb5d28
1 changed files with 37 additions and 3 deletions

View File

@ -22,7 +22,7 @@
L{BaseTransport} handles the core SSH2 protocol. L{BaseTransport} handles the core SSH2 protocol.
""" """
import sys, os, string, threading, socket, struct import sys, os, string, threading, socket, struct, time
from common import * from common import *
from ssh_exception import SSHException from ssh_exception import SSHException
@ -154,7 +154,7 @@ class BaseTransport (threading.Thread):
# /negotiated crypto parameters # /negotiated crypto parameters
self.expected_packet = 0 self.expected_packet = 0
self.active = False self.active = False
self.initial_kex_done = 0 self.initial_kex_done = False
self.write_lock = threading.Lock() # lock around outbound writes (packet computation) self.write_lock = threading.Lock() # lock around outbound writes (packet computation)
self.lock = threading.Lock() # synchronization (always higher level than write_lock) self.lock = threading.Lock() # synchronization (always higher level than write_lock)
self.channels = { } # (id -> Channel) self.channels = { } # (id -> Channel)
@ -171,6 +171,9 @@ class BaseTransport (threading.Thread):
self.received_packets_overflow = 0 self.received_packets_overflow = 0
# user-defined event callbacks: # user-defined event callbacks:
self.completion_event = None self.completion_event = None
# keepalives:
self.keepalive_interval = 0
self.keepalive_last = time.time()
# server mode: # server mode:
self.server_mode = 0 self.server_mode = 0
self.server_key_dict = { } self.server_key_dict = { }
@ -432,6 +435,8 @@ class BaseTransport (threading.Thread):
@param bytes: the number of random bytes to send in the payload of the @param bytes: the number of random bytes to send in the payload of the
ignored packet -- defaults to a random number from 10 to 41. ignored packet -- defaults to a random number from 10 to 41.
@type bytes: int @type bytes: int
@since: fearow
""" """
m = Message() m = Message()
m.add_byte(chr(MSG_IGNORE)) m.add_byte(chr(MSG_IGNORE))
@ -464,6 +469,19 @@ class BaseTransport (threading.Thread):
break break
return True return True
def set_keepalive(self, interval):
"""
Turn on/off keepalive packets (default is off). If this is set, after
C{interval} seconds without sending any data over the connection, a
"keepalive" packet will be sent (and ignored by the remote host). This
can be useful to keep connections alive over a NAT, for example.
@param interval: seconds to wait before sending a keepalive packet (or
0 to disable keepalives).
@type interval: int
"""
self.keepalive_interval = interval
def global_request(self, kind, data=None, wait=True): def global_request(self, kind, data=None, wait=True):
""" """
Make a global request to the remote host. These are normally Make a global request to the remote host. These are normally
@ -481,6 +499,8 @@ class BaseTransport (threading.Thread):
request was successful (or an empty L{Message} if C{wait} was request was successful (or an empty L{Message} if C{wait} was
C{False}); C{None} if the request was denied. C{False}); C{None} if the request was denied.
@rtype: L{Message} @rtype: L{Message}
@since: fearow
""" """
if wait: if wait:
self.completion_event = threading.Event() self.completion_event = threading.Event()
@ -491,6 +511,7 @@ class BaseTransport (threading.Thread):
if data is not None: if data is not None:
for item in data: for item in data:
m.add(item) m.add(item)
self._log(DEBUG, 'Sending global request "%s"' % kind)
self._send_message(m) self._send_message(m)
if not wait: if not wait:
return True return True
@ -691,6 +712,13 @@ class BaseTransport (threading.Thread):
finally: finally:
self.lock.release() self.lock.release()
def _check_keepalive(self):
if (not self.keepalive_interval) or (not self.initial_kex_done):
return
now = time.time()
if now > self.keepalive_last + self.keepalive_interval:
self.global_request('keepalive@lag.net', wait=False)
def _py22_read_all(self, n): def _py22_read_all(self, n):
out = '' out = ''
while n > 0: while n > 0:
@ -698,6 +726,7 @@ class BaseTransport (threading.Thread):
if self.sock not in r: if self.sock not in r:
if not self.active: if not self.active:
raise EOFError() raise EOFError()
self._check_keepalive()
else: else:
x = self.sock.recv(n) x = self.sock.recv(n)
if len(x) == 0: if len(x) == 0:
@ -720,9 +749,11 @@ class BaseTransport (threading.Thread):
except socket.timeout: except socket.timeout:
if not self.active: if not self.active:
raise EOFError() raise EOFError()
self._check_keepalive()
return out return out
def _write_all(self, out): def _write_all(self, out):
self.keepalive_last = time.time()
while len(out) > 0: while len(out) > 0:
n = self.sock.send(out) n = self.sock.send(out)
if n <= 0: if n <= 0:
@ -1156,7 +1187,7 @@ class BaseTransport (threading.Thread):
self.e = self.f = self.K = self.x = None self.e = self.f = self.K = self.x = None
if not self.initial_kex_done: if not self.initial_kex_done:
# this was the first key exchange # this was the first key exchange
self.initial_kex_done = 1 self.initial_kex_done = True
# send an event? # send an event?
if self.completion_event != None: if self.completion_event != None:
self.completion_event.set() self.completion_event.set()
@ -1169,6 +1200,7 @@ class BaseTransport (threading.Thread):
def _parse_global_request(self, m): def _parse_global_request(self, m):
kind = m.get_string() kind = m.get_string()
self._log(DEBUG, 'Received global request "%s"' % kind)
want_reply = m.get_boolean() want_reply = m.get_boolean()
ok = self.check_global_request(kind, m) ok = self.check_global_request(kind, m)
extra = () extra = ()
@ -1186,11 +1218,13 @@ class BaseTransport (threading.Thread):
self._send_message(msg) self._send_message(msg)
def _parse_request_success(self, m): def _parse_request_success(self, m):
self._log(DEBUG, 'Global request successful.')
self.global_response = m self.global_response = m
if self.completion_event is not None: if self.completion_event is not None:
self.completion_event.set() self.completion_event.set()
def _parse_request_failure(self, m): def _parse_request_failure(self, m):
self._log(DEBUG, 'Global request denied.')
self.global_response = None self.global_response = None
if self.completion_event is not None: if self.completion_event is not None:
self.completion_event.set() self.completion_event.set()