split auth tests into their own file, and clean up the remaining transport
tests a bit (use existing refactoring).
This commit is contained in:
Robey Pointer 2008-01-23 17:38:49 -08:00
parent 4031ae9292
commit 31544301f5
3 changed files with 247 additions and 176 deletions

View File

@ -39,6 +39,7 @@ from test_hostkeys import HostKeysTest
from test_pkey import KeyTest from test_pkey import KeyTest
from test_kex import KexTest from test_kex import KexTest
from test_packetizer import PacketizerTest from test_packetizer import PacketizerTest
from test_auth import AuthTest
from test_transport import TransportTest from test_transport import TransportTest
from test_sftp import SFTPTest from test_sftp import SFTPTest
from test_sftp_big import BigSFTPTest from test_sftp_big import BigSFTPTest
@ -125,6 +126,7 @@ def main():
suite.addTest(unittest.makeSuite(KexTest)) suite.addTest(unittest.makeSuite(KexTest))
suite.addTest(unittest.makeSuite(PacketizerTest)) suite.addTest(unittest.makeSuite(PacketizerTest))
if options.use_transport: if options.use_transport:
suite.addTest(unittest.makeSuite(AuthTest))
suite.addTest(unittest.makeSuite(TransportTest)) suite.addTest(unittest.makeSuite(TransportTest))
suite.addTest(unittest.makeSuite(SSHClientTest)) suite.addTest(unittest.makeSuite(SSHClientTest))
if options.use_sftp: if options.use_sftp:

212
tests/test_auth.py Normal file
View File

@ -0,0 +1,212 @@
# Copyright (C) 2008 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
# Paramiko is free software; you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 2.1 of the License, or (at your option)
# any later version.
#
# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
"""
Some unit tests for authenticating over a Transport.
"""
import sys
import threading
import unittest
from paramiko import Transport, ServerInterface, RSAKey, DSSKey, \
SSHException, BadAuthenticationType, InteractiveQuery, ChannelException
from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL
from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
from loop import LoopSocket
class NullServer (ServerInterface):
paranoid_did_password = False
paranoid_did_public_key = False
paranoid_key = DSSKey.from_private_key_file('tests/test_dss.key')
def get_allowed_auths(self, username):
if username == 'slowdive':
return 'publickey,password'
if username == 'paranoid':
if not self.paranoid_did_password and not self.paranoid_did_public_key:
return 'publickey,password'
elif self.paranoid_did_password:
return 'publickey'
else:
return 'password'
if username == 'commie':
return 'keyboard-interactive'
return 'publickey'
def check_auth_password(self, username, password):
if (username == 'slowdive') and (password == 'pygmalion'):
return AUTH_SUCCESSFUL
if (username == 'paranoid') and (password == 'paranoid'):
# 2-part auth (even openssh doesn't support this)
self.paranoid_did_password = True
if self.paranoid_did_public_key:
return AUTH_SUCCESSFUL
return AUTH_PARTIALLY_SUCCESSFUL
if (username == 'utf8') and (password == u'\u2022'.encode('utf-8')):
return AUTH_SUCCESSFUL
return AUTH_FAILED
def check_auth_publickey(self, username, key):
if (username == 'paranoid') and (key == self.paranoid_key):
# 2-part auth
self.paranoid_did_public_key = True
if self.paranoid_did_password:
return AUTH_SUCCESSFUL
return AUTH_PARTIALLY_SUCCESSFUL
return AUTH_FAILED
def check_auth_interactive(self, username, submethods):
if username == 'commie':
self.username = username
return InteractiveQuery('password', 'Please enter a password.', ('Password', False))
return AUTH_FAILED
def check_auth_interactive_response(self, responses):
if self.username == 'commie':
if (len(responses) == 1) and (responses[0] == 'cat'):
return AUTH_SUCCESSFUL
return AUTH_FAILED
class AuthTest (unittest.TestCase):
def setUp(self):
self.socks = LoopSocket()
self.sockc = LoopSocket()
self.sockc.link(self.socks)
self.tc = Transport(self.sockc)
self.ts = Transport(self.socks)
def tearDown(self):
self.tc.close()
self.ts.close()
self.socks.close()
self.sockc.close()
def test_1_bad_auth_type(self):
"""
verify that we get the right exception when an unsupported auth
type is requested.
"""
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
try:
self.tc.connect(hostkey=public_host_key,
username='unknown', password='error')
self.assert_(False)
except:
etype, evalue, etb = sys.exc_info()
self.assertEquals(BadAuthenticationType, etype)
self.assertEquals(['publickey'], evalue.allowed_types)
def test_2_bad_password(self):
"""
verify that a bad password gets the right exception, and that a retry
with the right password works.
"""
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.connect(hostkey=public_host_key)
try:
self.tc.auth_password(username='slowdive', password='error')
self.assert_(False)
except:
etype, evalue, etb = sys.exc_info()
self.assert_(issubclass(etype, SSHException))
self.tc.auth_password(username='slowdive', password='pygmalion')
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
def test_3_multipart_auth(self):
"""
verify that multipart auth works.
"""
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.connect(hostkey=public_host_key)
remain = self.tc.auth_password(username='paranoid', password='paranoid')
self.assertEquals(['publickey'], remain)
key = DSSKey.from_private_key_file('tests/test_dss.key')
remain = self.tc.auth_publickey(username='paranoid', key=key)
self.assertEquals([], remain)
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
def test_4_interactive_auth(self):
"""
verify keyboard-interactive auth works.
"""
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.connect(hostkey=public_host_key)
def handler(title, instructions, prompts):
self.got_title = title
self.got_instructions = instructions
self.got_prompts = prompts
return ['cat']
remain = self.tc.auth_interactive('commie', handler)
self.assertEquals(self.got_title, 'password')
self.assertEquals(self.got_prompts, [('Password', False)])
self.assertEquals([], remain)
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
def test_5_interactive_auth_fallback(self):
"""
verify that a password auth attempt will fallback to "interactive"
if password auth isn't supported but interactive is.
"""
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.connect(hostkey=public_host_key)
remain = self.tc.auth_password('commie', 'cat')
self.assertEquals([], remain)
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())

View File

@ -63,6 +63,8 @@ class NullServer (ServerInterface):
if self.paranoid_did_public_key: if self.paranoid_did_public_key:
return AUTH_SUCCESSFUL return AUTH_SUCCESSFUL
return AUTH_PARTIALLY_SUCCESSFUL return AUTH_PARTIALLY_SUCCESSFUL
if (username == 'utf8') and (password == u'\u2022'.encode('utf-8')):
return AUTH_SUCCESSFUL
return AUTH_FAILED return AUTH_FAILED
def check_auth_publickey(self, username, key): def check_auth_publickey(self, username, key):
@ -139,16 +141,22 @@ class TransportTest (unittest.TestCase):
self.socks.close() self.socks.close()
self.sockc.close() self.sockc.close()
def setup_test_server(self): def setup_test_server(self, client_options=None, server_options=None):
host_key = RSAKey.from_private_key_file('tests/test_rsa.key') host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key)) public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key) self.ts.add_server_key(host_key)
if client_options is not None:
client_options(self.tc.get_security_options())
if server_options is not None:
server_options(self.ts.get_security_options())
event = threading.Event() event = threading.Event()
self.server = NullServer() self.server = NullServer()
self.assert_(not event.isSet()) self.assert_(not event.isSet())
self.ts.start_server(event, self.server) self.ts.start_server(event, self.server)
self.tc.connect(hostkey=public_host_key) self.tc.connect(hostkey=public_host_key,
self.tc.auth_password(username='slowdive', password='pygmalion') username='slowdive', password='pygmalion')
event.wait(1.0) event.wait(1.0)
self.assert_(event.isSet()) self.assert_(event.isSet())
self.assert_(self.ts.is_active()) self.assert_(self.ts.is_active())
@ -210,21 +218,10 @@ class TransportTest (unittest.TestCase):
verify that the client can demand odd handshake settings, and can verify that the client can demand odd handshake settings, and can
renegotiate keys in mid-stream. renegotiate keys in mid-stream.
""" """
host_key = RSAKey.from_private_key_file('tests/test_rsa.key') def force_algorithms(options):
public_host_key = RSAKey(data=str(host_key)) options.ciphers = ('aes256-cbc',)
self.ts.add_server_key(host_key) options.digests = ('hmac-md5-96',)
event = threading.Event() self.setup_test_server(client_options=force_algorithms)
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
options = self.tc.get_security_options()
options.ciphers = ('aes256-cbc',)
options.digests = ('hmac-md5-96',)
self.tc.connect(hostkey=public_host_key,
username='slowdive', password='pygmalion')
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
self.assertEquals('aes256-cbc', self.tc.local_cipher) self.assertEquals('aes256-cbc', self.tc.local_cipher)
self.assertEquals('aes256-cbc', self.tc.remote_cipher) self.assertEquals('aes256-cbc', self.tc.remote_cipher)
self.assertEquals(12, self.tc.packetizer.get_mac_size_out()) self.assertEquals(12, self.tc.packetizer.get_mac_size_out())
@ -238,142 +235,13 @@ class TransportTest (unittest.TestCase):
""" """
verify that the keepalive will be sent. verify that the keepalive will be sent.
""" """
self.tc.set_hexdump(True) self.setup_test_server()
self.assertEquals(None, getattr(self.server, '_global_request', None))
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.connect(hostkey=public_host_key,
username='slowdive', password='pygmalion')
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
self.assertEquals(None, getattr(server, '_global_request', None))
self.tc.set_keepalive(1) self.tc.set_keepalive(1)
time.sleep(2) time.sleep(2)
self.assertEquals('keepalive@lag.net', server._global_request) self.assertEquals('keepalive@lag.net', self.server._global_request)
def test_6_bad_auth_type(self): def test_6_exec_command(self):
"""
verify that we get the right exception when an unsupported auth
type is requested.
"""
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
try:
self.tc.connect(hostkey=public_host_key,
username='unknown', password='error')
self.assert_(False)
except:
etype, evalue, etb = sys.exc_info()
self.assertEquals(BadAuthenticationType, etype)
self.assertEquals(['publickey'], evalue.allowed_types)
def test_7_bad_password(self):
"""
verify that a bad password gets the right exception, and that a retry
with the right password works.
"""
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.ultra_debug = True
self.tc.connect(hostkey=public_host_key)
try:
self.tc.auth_password(username='slowdive', password='error')
self.assert_(False)
except:
etype, evalue, etb = sys.exc_info()
self.assert_(issubclass(etype, SSHException))
self.tc.auth_password(username='slowdive', password='pygmalion')
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
def test_8_multipart_auth(self):
"""
verify that multipart auth works.
"""
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.ultra_debug = True
self.tc.connect(hostkey=public_host_key)
remain = self.tc.auth_password(username='paranoid', password='paranoid')
self.assertEquals(['publickey'], remain)
key = DSSKey.from_private_key_file('tests/test_dss.key')
remain = self.tc.auth_publickey(username='paranoid', key=key)
self.assertEquals([], remain)
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
def test_9_interactive_auth(self):
"""
verify keyboard-interactive auth works.
"""
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.ultra_debug = True
self.tc.connect(hostkey=public_host_key)
def handler(title, instructions, prompts):
self.got_title = title
self.got_instructions = instructions
self.got_prompts = prompts
return ['cat']
remain = self.tc.auth_interactive('commie', handler)
self.assertEquals(self.got_title, 'password')
self.assertEquals(self.got_prompts, [('Password', False)])
self.assertEquals([], remain)
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
def test_A_interactive_auth_fallback(self):
"""
verify that a password auth attempt will fallback to "interactive"
if password auth isn't supported but interactive is.
"""
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.ultra_debug = True
self.tc.connect(hostkey=public_host_key)
remain = self.tc.auth_password('commie', 'cat')
self.assertEquals([], remain)
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
def test_B_exec_command(self):
""" """
verify that exec_command() does something reasonable. verify that exec_command() does something reasonable.
""" """
@ -415,7 +283,7 @@ class TransportTest (unittest.TestCase):
self.assertEquals('This is on stderr.\n', f.readline()) self.assertEquals('This is on stderr.\n', f.readline())
self.assertEquals('', f.readline()) self.assertEquals('', f.readline())
def test_C_invoke_shell(self): def test_7_invoke_shell(self):
""" """
verify that invoke_shell() does something reasonable. verify that invoke_shell() does something reasonable.
""" """
@ -429,7 +297,7 @@ class TransportTest (unittest.TestCase):
chan.close() chan.close()
self.assertEquals('', f.readline()) self.assertEquals('', f.readline())
def test_D_channel_exception(self): def test_8_channel_exception(self):
""" """
verify that ChannelException is thrown for a bad open-channel request. verify that ChannelException is thrown for a bad open-channel request.
""" """
@ -440,7 +308,7 @@ class TransportTest (unittest.TestCase):
except ChannelException, x: except ChannelException, x:
self.assert_(x.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED) self.assert_(x.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED)
def test_E_exit_status(self): def test_9_exit_status(self):
""" """
verify that get_exit_status() works. verify that get_exit_status() works.
""" """
@ -462,7 +330,7 @@ class TransportTest (unittest.TestCase):
self.assertEquals(23, chan.recv_exit_status()) self.assertEquals(23, chan.recv_exit_status())
chan.close() chan.close()
def test_F_select(self): def test_A_select(self):
""" """
verify that select() on a channel works. verify that select() on a channel works.
""" """
@ -517,7 +385,7 @@ class TransportTest (unittest.TestCase):
# ...and now is closed. # ...and now is closed.
self.assertEquals(True, p._closed) self.assertEquals(True, p._closed)
def test_G_renegotiate(self): def test_B_renegotiate(self):
""" """
verify that a transport can correctly renegotiate mid-stream. verify that a transport can correctly renegotiate mid-stream.
""" """
@ -541,24 +409,13 @@ class TransportTest (unittest.TestCase):
schan.close() schan.close()
def test_H_compression(self): def test_C_compression(self):
""" """
verify that zlib compression is basically working. verify that zlib compression is basically working.
""" """
host_key = RSAKey.from_private_key_file('tests/test_rsa.key') def force_compression(o):
public_host_key = RSAKey(data=str(host_key)) o.compression = ('zlib',)
self.ts.add_server_key(host_key) self.setup_test_server(force_compression, force_compression)
self.ts.get_security_options().compression = ('zlib',)
self.tc.get_security_options().compression = ('zlib',)
event = threading.Event()
server = NullServer()
self.ts.start_server(event, server)
self.tc.connect(hostkey=public_host_key,
username='slowdive', password='pygmalion')
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
chan = self.tc.open_session() chan = self.tc.open_session()
chan.exec_command('yes') chan.exec_command('yes')
schan = self.ts.accept(1.0) schan = self.ts.accept(1.0)
@ -573,7 +430,7 @@ class TransportTest (unittest.TestCase):
chan.close() chan.close()
schan.close() schan.close()
def test_I_x11(self): def test_D_x11(self):
""" """
verify that an x11 port can be requested and opened. verify that an x11 port can be requested and opened.
""" """
@ -607,7 +464,7 @@ class TransportTest (unittest.TestCase):
chan.close() chan.close()
schan.close() schan.close()
def test_J_reverse_port_forwarding(self): def test_E_reverse_port_forwarding(self):
""" """
verify that a client can ask the server to open a reverse port for verify that a client can ask the server to open a reverse port for
forwarding. forwarding.
@ -643,7 +500,7 @@ class TransportTest (unittest.TestCase):
self.tc.cancel_port_forward('', port) self.tc.cancel_port_forward('', port)
self.assertTrue(self.server._listen is None) self.assertTrue(self.server._listen is None)
def test_K_port_forwarding(self): def test_F_port_forwarding(self):
""" """
verify that a client can forward new connections from a locally- verify that a client can forward new connections from a locally-
forwarded port. forwarded port.
@ -672,7 +529,7 @@ class TransportTest (unittest.TestCase):
self.assertEquals('Hello!\n', cs.recv(7)) self.assertEquals('Hello!\n', cs.recv(7))
cs.close() cs.close()
def test_L_stderr_select(self): def test_G_stderr_select(self):
""" """
verify that select() on a channel works even if only stderr is verify that select() on a channel works even if only stderr is
receiving data. receiving data.
@ -711,7 +568,7 @@ class TransportTest (unittest.TestCase):
schan.close() schan.close()
chan.close() chan.close()
def test_M_send_ready(self): def test_H_send_ready(self):
""" """
verify that send_ready() indicates when a send would not block. verify that send_ready() indicates when a send would not block.
""" """