if open_channel fails, it now raises ChannelException.  added a unit test for that too.  renegotiate_keys will also raise an exception now instead of returning a bool.
This commit is contained in:
Robey Pointer 2006-05-03 19:52:37 -07:00
parent aac434e9b0
commit 581103665b
4 changed files with 88 additions and 41 deletions

View File

@ -69,7 +69,7 @@ from transport import randpool, SecurityOptions, Transport
from client import SSHClient, MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy
from auth_handler import AuthHandler
from channel import Channel, ChannelFile
from ssh_exception import SSHException, PasswordRequiredException, BadAuthenticationType
from ssh_exception import SSHException, PasswordRequiredException, BadAuthenticationType, ChannelException
from server import ServerInterface, SubsystemHandler, InteractiveQuery
from rsakey import RSAKey
from dsskey import DSSKey
@ -94,7 +94,7 @@ for x in (Transport, SecurityOptions, Channel, SFTPServer, SSHException,
SFTP, SFTPClient, SFTPServer, Message, Packetizer, SFTPAttributes,
SFTPHandle, SFTPServerInterface, BufferedFile, Agent, AgentKey,
PKey, BaseSFTP, SFTPFile, ServerInterface, HostKeys, SSHClient,
MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy):
MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy, ChannelException):
x.__module__ = 'paramiko'
from common import AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED, \
@ -119,6 +119,7 @@ __all__ = [ 'Transport',
'SSHException',
'PasswordRequiredException',
'BadAuthenticationType',
'ChannelException',
'SFTP',
'SFTPFile',
'SFTPHandle',

View File

@ -67,3 +67,14 @@ class PartialAuthentication (SSHException):
def __init__(self, types):
SSHException.__init__(self, 'partial authentication')
self.allowed_types = types
class ChannelException (SSHException):
"""
Exception raised when an attempt to open a new L{Channel} fails.
@since: 1.6
"""
def __init__(self, code, text):
SSHException.__init__(self, text)
self.code = code

View File

@ -43,7 +43,7 @@ from paramiko.primes import ModulusPack
from paramiko.rsakey import RSAKey
from paramiko.server import ServerInterface
from paramiko.sftp_client import SFTPClient
from paramiko.ssh_exception import SSHException, BadAuthenticationType
from paramiko.ssh_exception import SSHException, BadAuthenticationType, ChannelException
# these come from PyCrypt
# http://www.amk.ca/python/writing/pycrypt/
@ -558,7 +558,7 @@ class Transport (threading.Thread):
@raise SSHException: if no session is currently active.
@return: public key of the remote server.
@return: public key of the remote server
@rtype: L{PKey <pkey.PKey>}
"""
if (not self.active) or (not self.initial_kex_done):
@ -570,7 +570,7 @@ class Transport (threading.Thread):
Return true if this session is active (open).
@return: True if the session is still active (open); False if the
session is closed.
session is closed
@rtype: bool
"""
return self.active
@ -580,9 +580,11 @@ class Transport (threading.Thread):
Request a new channel to the server, of type C{"session"}. This
is just an alias for C{open_channel('session')}.
@return: a new L{Channel} on success, or C{None} if the request is
rejected or the session ends prematurely.
@return: a new L{Channel}
@rtype: L{Channel}
@raise SSHException: if the request is rejected or the session ends
prematurely
"""
return self.open_channel('session')
@ -594,18 +596,20 @@ class Transport (threading.Thread):
L{connect} or L{start_client}) and authenticating.
@param kind: the kind of channel requested (usually C{"session"},
C{"forwarded-tcpip"} or C{"direct-tcpip"}).
C{"forwarded-tcpip"} or C{"direct-tcpip"})
@type kind: str
@param dest_addr: the destination address of this port forwarding,
if C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"} (ignored
for other channel types).
for other channel types)
@type dest_addr: (str, int)
@param src_addr: the source address of this port forwarding, if
C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"}.
C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"}
@type src_addr: (str, int)
@return: a new L{Channel} on success, or C{None} if the request is
rejected or the session ends prematurely.
@return: a new L{Channel} on success
@rtype: L{Channel}
@raise SSHException: if the request is rejected or the session ends
prematurely
"""
chan = None
if not self.active:
@ -637,19 +641,25 @@ class Transport (threading.Thread):
finally:
self.lock.release()
self._send_user_message(m)
while 1:
while True:
event.wait(0.1);
if not self.active:
return None
e = self.get_exception()
if e is None:
e = SSHException('Unable to open channel.')
raise e
if event.isSet():
break
try:
self.lock.acquire()
if not self.channels.has_key(chanid):
chan = None
try:
if self.channels.has_key(chanid):
return chan
finally:
self.lock.release()
return chan
e = self.get_exception()
if e is None:
e = SSHException('Unable to open channel.')
raise e
def open_sftp_client(self):
"""
@ -689,22 +699,23 @@ class Transport (threading.Thread):
bytes sent or received, but this method gives you the option of forcing
new keys whenever you want. Negotiating new keys causes a pause in
traffic both ways as the two sides swap keys and do computations. This
method returns when the session has switched to new keys, or the
session has died mid-negotiation.
method returns when the session has switched to new keys.
@return: True if the renegotiation was successful, and the link is
using new keys; False if the session dropped during renegotiation.
@rtype: bool
@raise SSHException: if the key renegotiation failed (which causes the
session to end)
"""
self.completion_event = threading.Event()
self._send_kex_init()
while 1:
self.completion_event.wait(0.1);
while True:
self.completion_event.wait(0.1)
if not self.active:
return False
e = self.get_exception()
if e is not None:
raise e
raise SSHException('Negotiation failed.')
if self.completion_event.isSet():
break
return True
return
def set_keepalive(self, interval):
"""
@ -1741,14 +1752,12 @@ class Transport (threading.Thread):
reason = m.get_int()
reason_str = m.get_string()
lang = m.get_string()
if CONNECTION_FAILED_CODE.has_key(reason):
reason_text = CONNECTION_FAILED_CODE[reason]
else:
reason_text = '(unknown code)'
reason_text = CONNECTION_FAILED_CODE.get(reason, '(unknown code)')
self._log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text))
self.lock.acquire()
try:
self.lock.aquire()
if self.channels.has_key(chanid):
self.saved_exception = ChannelException(reason, reason_text)
if self.channel_events.has_key(chanid):
del self.channels[chanid]
if self.channel_events.has_key(chanid):
self.channel_events[chanid].set()

View File

@ -23,9 +23,9 @@ Some unit tests for the ssh2 protocol in Transport.
import sys, time, threading, unittest
import select
from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, \
SSHException, BadAuthenticationType, InteractiveQuery, util
SSHException, BadAuthenticationType, InteractiveQuery, util, ChannelException
from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL
from paramiko import OPEN_SUCCEEDED
from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
from loop import LoopSocket
@ -81,6 +81,8 @@ class NullServer (ServerInterface):
return AUTH_FAILED
def check_channel_request(self, kind, chanid):
if kind == 'bogus':
return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
return OPEN_SUCCEEDED
def check_channel_exec_request(self, channel, command):
@ -189,7 +191,7 @@ class TransportTest (unittest.TestCase):
self.assertEquals(12, self.tc.packetizer.get_mac_size_in())
self.tc.send_ignore(1024)
self.assert_(self.tc.renegotiate_keys())
self.tc.renegotiate_keys()
self.ts.send_ignore(1024)
def test_5_keepalive(self):
@ -408,7 +410,31 @@ class TransportTest (unittest.TestCase):
chan.close()
self.assertEquals('', f.readline())
def test_D_exit_status(self):
def test_D_channel_exception(self):
"""
verify that ChannelException is thrown for a bad open-channel request.
"""
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.ultra_debug = True
self.tc.connect(hostkey=public_host_key)
self.tc.auth_password(username='slowdive', password='pygmalion')
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
try:
chan = self.tc.open_channel('bogus')
self.fail('expected exception')
except ChannelException, x:
self.assert_(x.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED)
def test_E_exit_status(self):
"""
verify that get_exit_status() works.
"""
@ -442,7 +468,7 @@ class TransportTest (unittest.TestCase):
self.assertEquals(23, chan.recv_exit_status())
chan.close()
def test_E_select(self):
def test_F_select(self):
"""
verify that select() on a channel works.
"""
@ -505,7 +531,7 @@ class TransportTest (unittest.TestCase):
chan.close()
def test_F_renegotiate(self):
def test_G_renegotiate(self):
"""
verify that a transport can correctly renegotiate mid-stream.
"""
@ -541,7 +567,7 @@ class TransportTest (unittest.TestCase):
schan.close()
def test_G_compression(self):
def test_H_compression(self):
"""
verify that zlib compression is basically working.
"""