From 581103665b82f50d71aacb12881f9fd0b3fcca88 Mon Sep 17 00:00:00 2001 From: Robey Pointer Date: Wed, 3 May 2006 19:52:37 -0700 Subject: [PATCH] [project @ robey@lag.net-20060504025237-a015ee747d9a2e75] if open_channel fails, it now raises ChannelException. added a unit test for that too. renegotiate_keys will also raise an exception now instead of returning a bool. --- paramiko/__init__.py | 5 +-- paramiko/ssh_exception.py | 11 ++++++ paramiko/transport.py | 73 ++++++++++++++++++++++----------------- tests/test_transport.py | 40 +++++++++++++++++---- 4 files changed, 88 insertions(+), 41 deletions(-) diff --git a/paramiko/__init__.py b/paramiko/__init__.py index c341d2b..e9d504a 100644 --- a/paramiko/__init__.py +++ b/paramiko/__init__.py @@ -69,7 +69,7 @@ from transport import randpool, SecurityOptions, Transport from client import SSHClient, MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy from auth_handler import AuthHandler from channel import Channel, ChannelFile -from ssh_exception import SSHException, PasswordRequiredException, BadAuthenticationType +from ssh_exception import SSHException, PasswordRequiredException, BadAuthenticationType, ChannelException from server import ServerInterface, SubsystemHandler, InteractiveQuery from rsakey import RSAKey from dsskey import DSSKey @@ -94,7 +94,7 @@ for x in (Transport, SecurityOptions, Channel, SFTPServer, SSHException, SFTP, SFTPClient, SFTPServer, Message, Packetizer, SFTPAttributes, SFTPHandle, SFTPServerInterface, BufferedFile, Agent, AgentKey, PKey, BaseSFTP, SFTPFile, ServerInterface, HostKeys, SSHClient, - MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy): + MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy, ChannelException): x.__module__ = 'paramiko' from common import AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED, \ @@ -119,6 +119,7 @@ __all__ = [ 'Transport', 'SSHException', 'PasswordRequiredException', 'BadAuthenticationType', + 'ChannelException', 'SFTP', 'SFTPFile', 'SFTPHandle', diff --git a/paramiko/ssh_exception.py b/paramiko/ssh_exception.py index 3aa4860..99eaa64 100644 --- a/paramiko/ssh_exception.py +++ b/paramiko/ssh_exception.py @@ -67,3 +67,14 @@ class PartialAuthentication (SSHException): def __init__(self, types): SSHException.__init__(self, 'partial authentication') self.allowed_types = types + + +class ChannelException (SSHException): + """ + Exception raised when an attempt to open a new L{Channel} fails. + + @since: 1.6 + """ + def __init__(self, code, text): + SSHException.__init__(self, text) + self.code = code diff --git a/paramiko/transport.py b/paramiko/transport.py index 6fe7218..31a5423 100644 --- a/paramiko/transport.py +++ b/paramiko/transport.py @@ -43,7 +43,7 @@ from paramiko.primes import ModulusPack from paramiko.rsakey import RSAKey from paramiko.server import ServerInterface from paramiko.sftp_client import SFTPClient -from paramiko.ssh_exception import SSHException, BadAuthenticationType +from paramiko.ssh_exception import SSHException, BadAuthenticationType, ChannelException # these come from PyCrypt # http://www.amk.ca/python/writing/pycrypt/ @@ -558,7 +558,7 @@ class Transport (threading.Thread): @raise SSHException: if no session is currently active. - @return: public key of the remote server. + @return: public key of the remote server @rtype: L{PKey } """ if (not self.active) or (not self.initial_kex_done): @@ -570,7 +570,7 @@ class Transport (threading.Thread): Return true if this session is active (open). @return: True if the session is still active (open); False if the - session is closed. + session is closed @rtype: bool """ return self.active @@ -580,9 +580,11 @@ class Transport (threading.Thread): Request a new channel to the server, of type C{"session"}. This is just an alias for C{open_channel('session')}. - @return: a new L{Channel} on success, or C{None} if the request is - rejected or the session ends prematurely. + @return: a new L{Channel} @rtype: L{Channel} + + @raise SSHException: if the request is rejected or the session ends + prematurely """ return self.open_channel('session') @@ -594,18 +596,20 @@ class Transport (threading.Thread): L{connect} or L{start_client}) and authenticating. @param kind: the kind of channel requested (usually C{"session"}, - C{"forwarded-tcpip"} or C{"direct-tcpip"}). + C{"forwarded-tcpip"} or C{"direct-tcpip"}) @type kind: str @param dest_addr: the destination address of this port forwarding, if C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"} (ignored - for other channel types). + for other channel types) @type dest_addr: (str, int) @param src_addr: the source address of this port forwarding, if - C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"}. + C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"} @type src_addr: (str, int) - @return: a new L{Channel} on success, or C{None} if the request is - rejected or the session ends prematurely. + @return: a new L{Channel} on success @rtype: L{Channel} + + @raise SSHException: if the request is rejected or the session ends + prematurely """ chan = None if not self.active: @@ -637,19 +641,25 @@ class Transport (threading.Thread): finally: self.lock.release() self._send_user_message(m) - while 1: + while True: event.wait(0.1); if not self.active: - return None + e = self.get_exception() + if e is None: + e = SSHException('Unable to open channel.') + raise e if event.isSet(): break + self.lock.acquire() try: - self.lock.acquire() - if not self.channels.has_key(chanid): - chan = None + if self.channels.has_key(chanid): + return chan finally: self.lock.release() - return chan + e = self.get_exception() + if e is None: + e = SSHException('Unable to open channel.') + raise e def open_sftp_client(self): """ @@ -689,22 +699,23 @@ class Transport (threading.Thread): bytes sent or received, but this method gives you the option of forcing new keys whenever you want. Negotiating new keys causes a pause in traffic both ways as the two sides swap keys and do computations. This - method returns when the session has switched to new keys, or the - session has died mid-negotiation. + method returns when the session has switched to new keys. - @return: True if the renegotiation was successful, and the link is - using new keys; False if the session dropped during renegotiation. - @rtype: bool + @raise SSHException: if the key renegotiation failed (which causes the + session to end) """ self.completion_event = threading.Event() self._send_kex_init() - while 1: - self.completion_event.wait(0.1); + while True: + self.completion_event.wait(0.1) if not self.active: - return False + e = self.get_exception() + if e is not None: + raise e + raise SSHException('Negotiation failed.') if self.completion_event.isSet(): break - return True + return def set_keepalive(self, interval): """ @@ -1017,7 +1028,7 @@ class Transport (threading.Thread): except SSHException, ignored: # attempt failed; just raise the original exception raise x - return None + return None def auth_publickey(self, username, key, event=None): """ @@ -1741,14 +1752,12 @@ class Transport (threading.Thread): reason = m.get_int() reason_str = m.get_string() lang = m.get_string() - if CONNECTION_FAILED_CODE.has_key(reason): - reason_text = CONNECTION_FAILED_CODE[reason] - else: - reason_text = '(unknown code)' + reason_text = CONNECTION_FAILED_CODE.get(reason, '(unknown code)') self._log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text)) + self.lock.acquire() try: - self.lock.aquire() - if self.channels.has_key(chanid): + self.saved_exception = ChannelException(reason, reason_text) + if self.channel_events.has_key(chanid): del self.channels[chanid] if self.channel_events.has_key(chanid): self.channel_events[chanid].set() diff --git a/tests/test_transport.py b/tests/test_transport.py index 5fcc786..b2e8b6f 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -23,9 +23,9 @@ Some unit tests for the ssh2 protocol in Transport. import sys, time, threading, unittest import select from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, \ - SSHException, BadAuthenticationType, InteractiveQuery, util + SSHException, BadAuthenticationType, InteractiveQuery, util, ChannelException from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL -from paramiko import OPEN_SUCCEEDED +from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED from loop import LoopSocket @@ -81,6 +81,8 @@ class NullServer (ServerInterface): return AUTH_FAILED def check_channel_request(self, kind, chanid): + if kind == 'bogus': + return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED return OPEN_SUCCEEDED def check_channel_exec_request(self, channel, command): @@ -189,7 +191,7 @@ class TransportTest (unittest.TestCase): self.assertEquals(12, self.tc.packetizer.get_mac_size_in()) self.tc.send_ignore(1024) - self.assert_(self.tc.renegotiate_keys()) + self.tc.renegotiate_keys() self.ts.send_ignore(1024) def test_5_keepalive(self): @@ -408,7 +410,31 @@ class TransportTest (unittest.TestCase): chan.close() self.assertEquals('', f.readline()) - def test_D_exit_status(self): + def test_D_channel_exception(self): + """ + verify that ChannelException is thrown for a bad open-channel request. + """ + host_key = RSAKey.from_private_key_file('tests/test_rsa.key') + public_host_key = RSAKey(data=str(host_key)) + self.ts.add_server_key(host_key) + event = threading.Event() + server = NullServer() + self.assert_(not event.isSet()) + self.ts.start_server(event, server) + self.tc.ultra_debug = True + self.tc.connect(hostkey=public_host_key) + self.tc.auth_password(username='slowdive', password='pygmalion') + event.wait(1.0) + self.assert_(event.isSet()) + self.assert_(self.ts.is_active()) + + try: + chan = self.tc.open_channel('bogus') + self.fail('expected exception') + except ChannelException, x: + self.assert_(x.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED) + + def test_E_exit_status(self): """ verify that get_exit_status() works. """ @@ -442,7 +468,7 @@ class TransportTest (unittest.TestCase): self.assertEquals(23, chan.recv_exit_status()) chan.close() - def test_E_select(self): + def test_F_select(self): """ verify that select() on a channel works. """ @@ -505,7 +531,7 @@ class TransportTest (unittest.TestCase): chan.close() - def test_F_renegotiate(self): + def test_G_renegotiate(self): """ verify that a transport can correctly renegotiate mid-stream. """ @@ -541,7 +567,7 @@ class TransportTest (unittest.TestCase): schan.close() - def test_G_compression(self): + def test_H_compression(self): """ verify that zlib compression is basically working. """