if open_channel fails, it now raises ChannelException.  added a unit test for that too.  renegotiate_keys will also raise an exception now instead of returning a bool.
This commit is contained in:
Robey Pointer 2006-05-03 19:52:37 -07:00
parent aac434e9b0
commit 581103665b
4 changed files with 88 additions and 41 deletions

View File

@ -69,7 +69,7 @@ from transport import randpool, SecurityOptions, Transport
from client import SSHClient, MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy from client import SSHClient, MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy
from auth_handler import AuthHandler from auth_handler import AuthHandler
from channel import Channel, ChannelFile from channel import Channel, ChannelFile
from ssh_exception import SSHException, PasswordRequiredException, BadAuthenticationType from ssh_exception import SSHException, PasswordRequiredException, BadAuthenticationType, ChannelException
from server import ServerInterface, SubsystemHandler, InteractiveQuery from server import ServerInterface, SubsystemHandler, InteractiveQuery
from rsakey import RSAKey from rsakey import RSAKey
from dsskey import DSSKey from dsskey import DSSKey
@ -94,7 +94,7 @@ for x in (Transport, SecurityOptions, Channel, SFTPServer, SSHException,
SFTP, SFTPClient, SFTPServer, Message, Packetizer, SFTPAttributes, SFTP, SFTPClient, SFTPServer, Message, Packetizer, SFTPAttributes,
SFTPHandle, SFTPServerInterface, BufferedFile, Agent, AgentKey, SFTPHandle, SFTPServerInterface, BufferedFile, Agent, AgentKey,
PKey, BaseSFTP, SFTPFile, ServerInterface, HostKeys, SSHClient, PKey, BaseSFTP, SFTPFile, ServerInterface, HostKeys, SSHClient,
MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy): MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy, ChannelException):
x.__module__ = 'paramiko' x.__module__ = 'paramiko'
from common import AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED, \ from common import AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED, \
@ -119,6 +119,7 @@ __all__ = [ 'Transport',
'SSHException', 'SSHException',
'PasswordRequiredException', 'PasswordRequiredException',
'BadAuthenticationType', 'BadAuthenticationType',
'ChannelException',
'SFTP', 'SFTP',
'SFTPFile', 'SFTPFile',
'SFTPHandle', 'SFTPHandle',

View File

@ -67,3 +67,14 @@ class PartialAuthentication (SSHException):
def __init__(self, types): def __init__(self, types):
SSHException.__init__(self, 'partial authentication') SSHException.__init__(self, 'partial authentication')
self.allowed_types = types self.allowed_types = types
class ChannelException (SSHException):
"""
Exception raised when an attempt to open a new L{Channel} fails.
@since: 1.6
"""
def __init__(self, code, text):
SSHException.__init__(self, text)
self.code = code

View File

@ -43,7 +43,7 @@ from paramiko.primes import ModulusPack
from paramiko.rsakey import RSAKey from paramiko.rsakey import RSAKey
from paramiko.server import ServerInterface from paramiko.server import ServerInterface
from paramiko.sftp_client import SFTPClient from paramiko.sftp_client import SFTPClient
from paramiko.ssh_exception import SSHException, BadAuthenticationType from paramiko.ssh_exception import SSHException, BadAuthenticationType, ChannelException
# these come from PyCrypt # these come from PyCrypt
# http://www.amk.ca/python/writing/pycrypt/ # http://www.amk.ca/python/writing/pycrypt/
@ -558,7 +558,7 @@ class Transport (threading.Thread):
@raise SSHException: if no session is currently active. @raise SSHException: if no session is currently active.
@return: public key of the remote server. @return: public key of the remote server
@rtype: L{PKey <pkey.PKey>} @rtype: L{PKey <pkey.PKey>}
""" """
if (not self.active) or (not self.initial_kex_done): if (not self.active) or (not self.initial_kex_done):
@ -570,7 +570,7 @@ class Transport (threading.Thread):
Return true if this session is active (open). Return true if this session is active (open).
@return: True if the session is still active (open); False if the @return: True if the session is still active (open); False if the
session is closed. session is closed
@rtype: bool @rtype: bool
""" """
return self.active return self.active
@ -580,9 +580,11 @@ class Transport (threading.Thread):
Request a new channel to the server, of type C{"session"}. This Request a new channel to the server, of type C{"session"}. This
is just an alias for C{open_channel('session')}. is just an alias for C{open_channel('session')}.
@return: a new L{Channel} on success, or C{None} if the request is @return: a new L{Channel}
rejected or the session ends prematurely.
@rtype: L{Channel} @rtype: L{Channel}
@raise SSHException: if the request is rejected or the session ends
prematurely
""" """
return self.open_channel('session') return self.open_channel('session')
@ -594,18 +596,20 @@ class Transport (threading.Thread):
L{connect} or L{start_client}) and authenticating. L{connect} or L{start_client}) and authenticating.
@param kind: the kind of channel requested (usually C{"session"}, @param kind: the kind of channel requested (usually C{"session"},
C{"forwarded-tcpip"} or C{"direct-tcpip"}). C{"forwarded-tcpip"} or C{"direct-tcpip"})
@type kind: str @type kind: str
@param dest_addr: the destination address of this port forwarding, @param dest_addr: the destination address of this port forwarding,
if C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"} (ignored if C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"} (ignored
for other channel types). for other channel types)
@type dest_addr: (str, int) @type dest_addr: (str, int)
@param src_addr: the source address of this port forwarding, if @param src_addr: the source address of this port forwarding, if
C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"}. C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"}
@type src_addr: (str, int) @type src_addr: (str, int)
@return: a new L{Channel} on success, or C{None} if the request is @return: a new L{Channel} on success
rejected or the session ends prematurely.
@rtype: L{Channel} @rtype: L{Channel}
@raise SSHException: if the request is rejected or the session ends
prematurely
""" """
chan = None chan = None
if not self.active: if not self.active:
@ -637,19 +641,25 @@ class Transport (threading.Thread):
finally: finally:
self.lock.release() self.lock.release()
self._send_user_message(m) self._send_user_message(m)
while 1: while True:
event.wait(0.1); event.wait(0.1);
if not self.active: if not self.active:
return None e = self.get_exception()
if e is None:
e = SSHException('Unable to open channel.')
raise e
if event.isSet(): if event.isSet():
break break
self.lock.acquire()
try: try:
self.lock.acquire() if self.channels.has_key(chanid):
if not self.channels.has_key(chanid): return chan
chan = None
finally: finally:
self.lock.release() self.lock.release()
return chan e = self.get_exception()
if e is None:
e = SSHException('Unable to open channel.')
raise e
def open_sftp_client(self): def open_sftp_client(self):
""" """
@ -689,22 +699,23 @@ class Transport (threading.Thread):
bytes sent or received, but this method gives you the option of forcing bytes sent or received, but this method gives you the option of forcing
new keys whenever you want. Negotiating new keys causes a pause in new keys whenever you want. Negotiating new keys causes a pause in
traffic both ways as the two sides swap keys and do computations. This traffic both ways as the two sides swap keys and do computations. This
method returns when the session has switched to new keys, or the method returns when the session has switched to new keys.
session has died mid-negotiation.
@return: True if the renegotiation was successful, and the link is @raise SSHException: if the key renegotiation failed (which causes the
using new keys; False if the session dropped during renegotiation. session to end)
@rtype: bool
""" """
self.completion_event = threading.Event() self.completion_event = threading.Event()
self._send_kex_init() self._send_kex_init()
while 1: while True:
self.completion_event.wait(0.1); self.completion_event.wait(0.1)
if not self.active: if not self.active:
return False e = self.get_exception()
if e is not None:
raise e
raise SSHException('Negotiation failed.')
if self.completion_event.isSet(): if self.completion_event.isSet():
break break
return True return
def set_keepalive(self, interval): def set_keepalive(self, interval):
""" """
@ -1017,7 +1028,7 @@ class Transport (threading.Thread):
except SSHException, ignored: except SSHException, ignored:
# attempt failed; just raise the original exception # attempt failed; just raise the original exception
raise x raise x
return None return None
def auth_publickey(self, username, key, event=None): def auth_publickey(self, username, key, event=None):
""" """
@ -1741,14 +1752,12 @@ class Transport (threading.Thread):
reason = m.get_int() reason = m.get_int()
reason_str = m.get_string() reason_str = m.get_string()
lang = m.get_string() lang = m.get_string()
if CONNECTION_FAILED_CODE.has_key(reason): reason_text = CONNECTION_FAILED_CODE.get(reason, '(unknown code)')
reason_text = CONNECTION_FAILED_CODE[reason]
else:
reason_text = '(unknown code)'
self._log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text)) self._log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text))
self.lock.acquire()
try: try:
self.lock.aquire() self.saved_exception = ChannelException(reason, reason_text)
if self.channels.has_key(chanid): if self.channel_events.has_key(chanid):
del self.channels[chanid] del self.channels[chanid]
if self.channel_events.has_key(chanid): if self.channel_events.has_key(chanid):
self.channel_events[chanid].set() self.channel_events[chanid].set()

View File

@ -23,9 +23,9 @@ Some unit tests for the ssh2 protocol in Transport.
import sys, time, threading, unittest import sys, time, threading, unittest
import select import select
from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, \ from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, \
SSHException, BadAuthenticationType, InteractiveQuery, util SSHException, BadAuthenticationType, InteractiveQuery, util, ChannelException
from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL
from paramiko import OPEN_SUCCEEDED from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
from loop import LoopSocket from loop import LoopSocket
@ -81,6 +81,8 @@ class NullServer (ServerInterface):
return AUTH_FAILED return AUTH_FAILED
def check_channel_request(self, kind, chanid): def check_channel_request(self, kind, chanid):
if kind == 'bogus':
return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
return OPEN_SUCCEEDED return OPEN_SUCCEEDED
def check_channel_exec_request(self, channel, command): def check_channel_exec_request(self, channel, command):
@ -189,7 +191,7 @@ class TransportTest (unittest.TestCase):
self.assertEquals(12, self.tc.packetizer.get_mac_size_in()) self.assertEquals(12, self.tc.packetizer.get_mac_size_in())
self.tc.send_ignore(1024) self.tc.send_ignore(1024)
self.assert_(self.tc.renegotiate_keys()) self.tc.renegotiate_keys()
self.ts.send_ignore(1024) self.ts.send_ignore(1024)
def test_5_keepalive(self): def test_5_keepalive(self):
@ -408,7 +410,31 @@ class TransportTest (unittest.TestCase):
chan.close() chan.close()
self.assertEquals('', f.readline()) self.assertEquals('', f.readline())
def test_D_exit_status(self): def test_D_channel_exception(self):
"""
verify that ChannelException is thrown for a bad open-channel request.
"""
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.ultra_debug = True
self.tc.connect(hostkey=public_host_key)
self.tc.auth_password(username='slowdive', password='pygmalion')
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
try:
chan = self.tc.open_channel('bogus')
self.fail('expected exception')
except ChannelException, x:
self.assert_(x.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED)
def test_E_exit_status(self):
""" """
verify that get_exit_status() works. verify that get_exit_status() works.
""" """
@ -442,7 +468,7 @@ class TransportTest (unittest.TestCase):
self.assertEquals(23, chan.recv_exit_status()) self.assertEquals(23, chan.recv_exit_status())
chan.close() chan.close()
def test_E_select(self): def test_F_select(self):
""" """
verify that select() on a channel works. verify that select() on a channel works.
""" """
@ -505,7 +531,7 @@ class TransportTest (unittest.TestCase):
chan.close() chan.close()
def test_F_renegotiate(self): def test_G_renegotiate(self):
""" """
verify that a transport can correctly renegotiate mid-stream. verify that a transport can correctly renegotiate mid-stream.
""" """
@ -541,7 +567,7 @@ class TransportTest (unittest.TestCase):
schan.close() schan.close()
def test_G_compression(self): def test_H_compression(self):
""" """
verify that zlib compression is basically working. verify that zlib compression is basically working.
""" """