From 305f5e09a58d4c80740faf564dfdca875d98245f Mon Sep 17 00:00:00 2001 From: Robey Pointer Date: Tue, 3 Jun 2008 22:39:06 -0700 Subject: [PATCH] [project @ robey@lag.net-20080604053906-vz5toqvlp5miqy1x] merge deadlog bugfix from dwayne litzenberger. --- paramiko/channel.py | 8 +- paramiko/sftp_client.py | 6 +- paramiko/sftp_server.py | 6 +- tests/test_transport.py | 160 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 176 insertions(+), 4 deletions(-) diff --git a/paramiko/channel.py b/paramiko/channel.py index c3fa9fe..f13394d 100644 --- a/paramiko/channel.py +++ b/paramiko/channel.py @@ -693,9 +693,11 @@ class Channel (object): m.add_byte(chr(MSG_CHANNEL_DATA)) m.add_int(self.remote_chanid) m.add_string(s[:size]) - self.transport._send_user_message(m) finally: self.lock.release() + # Note: We release self.lock before calling _send_user_message. + # Otherwise, we can deadlock during re-keying. + self.transport._send_user_message(m) return size def send_stderr(self, s): @@ -729,9 +731,11 @@ class Channel (object): m.add_int(self.remote_chanid) m.add_int(1) m.add_string(s[:size]) - self.transport._send_user_message(m) finally: self.lock.release() + # Note: We release self.lock before calling _send_user_message. + # Otherwise, we can deadlock during re-keying. + self.transport._send_user_message(m) return size def sendall(self, s): diff --git a/paramiko/sftp_client.py b/paramiko/sftp_client.py index 581d0a3..ed56789 100644 --- a/paramiko/sftp_client.py +++ b/paramiko/sftp_client.py @@ -106,7 +106,11 @@ class SFTPClient (BaseSFTP): from_transport = classmethod(from_transport) def _log(self, level, msg): - super(SFTPClient, self)._log(level, "[chan " + self.sock.get_name() + "] " + msg) + if issubclass(type(msg), list): + for m in msg: + super(SFTPClient, self)._log(level, "[chan " + self.sock.get_name() + "] " + m) + else: + super(SFTPClient, self)._log(level, "[chan " + self.sock.get_name() + "] " + msg) def close(self): """ diff --git a/paramiko/sftp_server.py b/paramiko/sftp_server.py index 995fc31..099ac12 100644 --- a/paramiko/sftp_server.py +++ b/paramiko/sftp_server.py @@ -75,7 +75,11 @@ class SFTPServer (BaseSFTP, SubsystemHandler): self.server = sftp_si(server, *largs, **kwargs) def _log(self, level, msg): - super(SFTPServer, self)._log(level, "[chan " + self.sock.get_name() + "] " + msg) + if issubclass(type(msg), list): + for m in msg: + super(SFTPServer, self)._log(level, "[chan " + self.sock.get_name() + "] " + m) + else: + super(SFTPServer, self)._log(level, "[chan " + self.sock.get_name() + "] " + msg) def start_subsystem(self, name, transport, channel): self.sock = channel diff --git a/tests/test_transport.py b/tests/test_transport.py index 293f160..4b52c4f 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -27,11 +27,14 @@ import sys import time import threading import unittest +import random from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, \ SSHException, BadAuthenticationType, InteractiveQuery, ChannelException from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED +from paramiko.common import MSG_KEXINIT, MSG_CHANNEL_WINDOW_ADJUST +from paramiko.message import Message from loop import LoopSocket @@ -564,3 +567,160 @@ class TransportTest (unittest.TestCase): schan.close() chan.close() self.assertEquals(chan.send_ready(), True) + + def test_I_rekey_deadlock(self): + """ + Regression test for deadlock when in-transit messages are received after MSG_KEXINIT is sent + + Note: When this test fails, it may leak threads. + """ + + # Test for an obscure deadlocking bug that can occur if we receive + # certain messages while initiating a key exchange. + # + # The deadlock occurs as follows: + # + # In the main thread: + # 1. The user's program calls Channel.send(), which sends + # MSG_CHANNEL_DATA to the remote host. + # 2. Packetizer discovers that REKEY_BYTES has been exceeded, and + # sets the __need_rekey flag. + # + # In the Transport thread: + # 3. Packetizer notices that the __need_rekey flag is set, and raises + # NeedRekeyException. + # 4. In response to NeedRekeyException, the transport thread sends + # MSG_KEXINIT to the remote host. + # + # On the remote host (using any SSH implementation): + # 5. The MSG_CHANNEL_DATA is received, and MSG_CHANNEL_WINDOW_ADJUST is sent. + # 6. The MSG_KEXINIT is received, and a corresponding MSG_KEXINIT is sent. + # + # In the main thread: + # 7. The user's program calls Channel.send(). + # 8. Channel.send acquires Channel.lock, then calls Transport._send_user_message(). + # 9. Transport._send_user_message waits for Transport.clear_to_send + # to be set (i.e., it waits for re-keying to complete). + # Channel.lock is still held. + # + # In the Transport thread: + # 10. MSG_CHANNEL_WINDOW_ADJUST is received; Channel._window_adjust + # is called to handle it. + # 11. Channel._window_adjust tries to acquire Channel.lock, but it + # blocks because the lock is already held by the main thread. + # + # The result is that the Transport thread never processes the remote + # host's MSG_KEXINIT packet, because it becomes deadlocked while + # handling the preceding MSG_CHANNEL_WINDOW_ADJUST message. + + # We set up two separate threads for sending and receiving packets, + # while the main thread acts as a watchdog timer. If the timer + # expires, a deadlock is assumed. + + class SendThread(threading.Thread): + def __init__(self, chan, iterations, done_event): + threading.Thread.__init__(self, None, None, self.__class__.__name__) + self.setDaemon(True) + self.chan = chan + self.iterations = iterations + self.done_event = done_event + self.watchdog_event = threading.Event() + self.last = None + + def run(self): + try: + for i in xrange(1, 1+self.iterations): + if self.done_event.isSet(): + break + self.watchdog_event.set() + #print i, "SEND" + self.chan.send("x" * 2048) + finally: + self.done_event.set() + self.watchdog_event.set() + + class ReceiveThread(threading.Thread): + def __init__(self, chan, done_event): + threading.Thread.__init__(self, None, None, self.__class__.__name__) + self.setDaemon(True) + self.chan = chan + self.done_event = done_event + self.watchdog_event = threading.Event() + + def run(self): + try: + while not self.done_event.isSet(): + if self.chan.recv_ready(): + chan.recv(65536) + self.watchdog_event.set() + else: + if random.randint(0, 1): + time.sleep(random.randint(0, 500) / 1000.0) + finally: + self.done_event.set() + self.watchdog_event.set() + + self.setup_test_server() + self.ts.packetizer.REKEY_BYTES = 2048 + + chan = self.tc.open_session() + chan.exec_command('yes') + schan = self.ts.accept(1.0) + + # Monkey patch the client's Transport._handler_table so that the client + # sends MSG_CHANNEL_WINDOW_ADJUST whenever it receives an initial + # MSG_KEXINIT. This is used to simulate the effect of network latency + # on a real MSG_CHANNEL_WINDOW_ADJUST message. + self.tc._handler_table = self.tc._handler_table.copy() # copy per-class dictionary + _negotiate_keys = self.tc._handler_table[MSG_KEXINIT] + def _negotiate_keys_wrapper(self, m): + if self.local_kex_init is None: # Remote side sent KEXINIT + # Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it + # before responding to the incoming MSG_KEXINIT. + m2 = Message() + m2.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST)) + m2.add_int(chan.remote_chanid) + m2.add_int(1) # bytes to add + self._send_message(m2) + return _negotiate_keys(self, m) + self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper + + # Parameters for the test + iterations = 500 # The deadlock does not happen every time, but it + # should after many iterations. + timeout = 5 + + # This event is set when the test is completed + done_event = threading.Event() + + # Start the sending thread + st = SendThread(schan, iterations, done_event) + st.start() + + # Start the receiving thread + rt = ReceiveThread(chan, done_event) + rt.start() + + # Act as a watchdog timer, checking + deadlocked = False + while not deadlocked and not done_event.isSet(): + for event in (st.watchdog_event, rt.watchdog_event): + event.wait(timeout) + if done_event.isSet(): + break + if not event.isSet(): + deadlocked = True + break + event.clear() + + # Tell the threads to stop (if they haven't already stopped). Note + # that if one or more threads are deadlocked, they might hang around + # forever (until the process exits). + done_event.set() + + # Assertion: We must not have detected a timeout. + self.assertFalse(deadlocked) + + # Close the channels + schan.close() + chan.close()