diff --git a/test.py b/test.py index 1370b90..b7f6c8a 100755 --- a/test.py +++ b/test.py @@ -39,6 +39,7 @@ from test_hostkeys import HostKeysTest from test_pkey import KeyTest from test_kex import KexTest from test_packetizer import PacketizerTest +from test_auth import AuthTest from test_transport import TransportTest from test_sftp import SFTPTest from test_sftp_big import BigSFTPTest @@ -125,6 +126,7 @@ def main(): suite.addTest(unittest.makeSuite(KexTest)) suite.addTest(unittest.makeSuite(PacketizerTest)) if options.use_transport: + suite.addTest(unittest.makeSuite(AuthTest)) suite.addTest(unittest.makeSuite(TransportTest)) suite.addTest(unittest.makeSuite(SSHClientTest)) if options.use_sftp: diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..cc35ee9 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,212 @@ +# Copyright (C) 2008 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Some unit tests for authenticating over a Transport. +""" + +import sys +import threading +import unittest + +from paramiko import Transport, ServerInterface, RSAKey, DSSKey, \ + SSHException, BadAuthenticationType, InteractiveQuery, ChannelException +from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL +from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED +from loop import LoopSocket + + +class NullServer (ServerInterface): + paranoid_did_password = False + paranoid_did_public_key = False + paranoid_key = DSSKey.from_private_key_file('tests/test_dss.key') + + def get_allowed_auths(self, username): + if username == 'slowdive': + return 'publickey,password' + if username == 'paranoid': + if not self.paranoid_did_password and not self.paranoid_did_public_key: + return 'publickey,password' + elif self.paranoid_did_password: + return 'publickey' + else: + return 'password' + if username == 'commie': + return 'keyboard-interactive' + return 'publickey' + + def check_auth_password(self, username, password): + if (username == 'slowdive') and (password == 'pygmalion'): + return AUTH_SUCCESSFUL + if (username == 'paranoid') and (password == 'paranoid'): + # 2-part auth (even openssh doesn't support this) + self.paranoid_did_password = True + if self.paranoid_did_public_key: + return AUTH_SUCCESSFUL + return AUTH_PARTIALLY_SUCCESSFUL + if (username == 'utf8') and (password == u'\u2022'.encode('utf-8')): + return AUTH_SUCCESSFUL + return AUTH_FAILED + + def check_auth_publickey(self, username, key): + if (username == 'paranoid') and (key == self.paranoid_key): + # 2-part auth + self.paranoid_did_public_key = True + if self.paranoid_did_password: + return AUTH_SUCCESSFUL + return AUTH_PARTIALLY_SUCCESSFUL + return AUTH_FAILED + + def check_auth_interactive(self, username, submethods): + if username == 'commie': + self.username = username + return InteractiveQuery('password', 'Please enter a password.', ('Password', False)) + return AUTH_FAILED + + def check_auth_interactive_response(self, responses): + if self.username == 'commie': + if (len(responses) == 1) and (responses[0] == 'cat'): + return AUTH_SUCCESSFUL + return AUTH_FAILED + + +class AuthTest (unittest.TestCase): + + def setUp(self): + self.socks = LoopSocket() + self.sockc = LoopSocket() + self.sockc.link(self.socks) + self.tc = Transport(self.sockc) + self.ts = Transport(self.socks) + + def tearDown(self): + self.tc.close() + self.ts.close() + self.socks.close() + self.sockc.close() + + def test_1_bad_auth_type(self): + """ + verify that we get the right exception when an unsupported auth + type is requested. + """ + host_key = RSAKey.from_private_key_file('tests/test_rsa.key') + public_host_key = RSAKey(data=str(host_key)) + self.ts.add_server_key(host_key) + event = threading.Event() + server = NullServer() + self.assert_(not event.isSet()) + self.ts.start_server(event, server) + try: + self.tc.connect(hostkey=public_host_key, + username='unknown', password='error') + self.assert_(False) + except: + etype, evalue, etb = sys.exc_info() + self.assertEquals(BadAuthenticationType, etype) + self.assertEquals(['publickey'], evalue.allowed_types) + + def test_2_bad_password(self): + """ + verify that a bad password gets the right exception, and that a retry + with the right password works. + """ + host_key = RSAKey.from_private_key_file('tests/test_rsa.key') + public_host_key = RSAKey(data=str(host_key)) + self.ts.add_server_key(host_key) + event = threading.Event() + server = NullServer() + self.assert_(not event.isSet()) + self.ts.start_server(event, server) + self.tc.connect(hostkey=public_host_key) + try: + self.tc.auth_password(username='slowdive', password='error') + self.assert_(False) + except: + etype, evalue, etb = sys.exc_info() + self.assert_(issubclass(etype, SSHException)) + self.tc.auth_password(username='slowdive', password='pygmalion') + event.wait(1.0) + self.assert_(event.isSet()) + self.assert_(self.ts.is_active()) + + def test_3_multipart_auth(self): + """ + verify that multipart auth works. + """ + host_key = RSAKey.from_private_key_file('tests/test_rsa.key') + public_host_key = RSAKey(data=str(host_key)) + self.ts.add_server_key(host_key) + event = threading.Event() + server = NullServer() + self.assert_(not event.isSet()) + self.ts.start_server(event, server) + self.tc.connect(hostkey=public_host_key) + remain = self.tc.auth_password(username='paranoid', password='paranoid') + self.assertEquals(['publickey'], remain) + key = DSSKey.from_private_key_file('tests/test_dss.key') + remain = self.tc.auth_publickey(username='paranoid', key=key) + self.assertEquals([], remain) + event.wait(1.0) + self.assert_(event.isSet()) + self.assert_(self.ts.is_active()) + + def test_4_interactive_auth(self): + """ + verify keyboard-interactive auth works. + """ + host_key = RSAKey.from_private_key_file('tests/test_rsa.key') + public_host_key = RSAKey(data=str(host_key)) + self.ts.add_server_key(host_key) + event = threading.Event() + server = NullServer() + self.assert_(not event.isSet()) + self.ts.start_server(event, server) + self.tc.connect(hostkey=public_host_key) + + def handler(title, instructions, prompts): + self.got_title = title + self.got_instructions = instructions + self.got_prompts = prompts + return ['cat'] + remain = self.tc.auth_interactive('commie', handler) + self.assertEquals(self.got_title, 'password') + self.assertEquals(self.got_prompts, [('Password', False)]) + self.assertEquals([], remain) + event.wait(1.0) + self.assert_(event.isSet()) + self.assert_(self.ts.is_active()) + + def test_5_interactive_auth_fallback(self): + """ + verify that a password auth attempt will fallback to "interactive" + if password auth isn't supported but interactive is. + """ + host_key = RSAKey.from_private_key_file('tests/test_rsa.key') + public_host_key = RSAKey(data=str(host_key)) + self.ts.add_server_key(host_key) + event = threading.Event() + server = NullServer() + self.assert_(not event.isSet()) + self.ts.start_server(event, server) + self.tc.connect(hostkey=public_host_key) + remain = self.tc.auth_password('commie', 'cat') + self.assertEquals([], remain) + event.wait(1.0) + self.assert_(event.isSet()) + self.assert_(self.ts.is_active()) diff --git a/tests/test_transport.py b/tests/test_transport.py index 187fee9..f0dbed3 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -63,6 +63,8 @@ class NullServer (ServerInterface): if self.paranoid_did_public_key: return AUTH_SUCCESSFUL return AUTH_PARTIALLY_SUCCESSFUL + if (username == 'utf8') and (password == u'\u2022'.encode('utf-8')): + return AUTH_SUCCESSFUL return AUTH_FAILED def check_auth_publickey(self, username, key): @@ -139,16 +141,22 @@ class TransportTest (unittest.TestCase): self.socks.close() self.sockc.close() - def setup_test_server(self): + def setup_test_server(self, client_options=None, server_options=None): host_key = RSAKey.from_private_key_file('tests/test_rsa.key') public_host_key = RSAKey(data=str(host_key)) self.ts.add_server_key(host_key) + + if client_options is not None: + client_options(self.tc.get_security_options()) + if server_options is not None: + server_options(self.ts.get_security_options()) + event = threading.Event() self.server = NullServer() self.assert_(not event.isSet()) self.ts.start_server(event, self.server) - self.tc.connect(hostkey=public_host_key) - self.tc.auth_password(username='slowdive', password='pygmalion') + self.tc.connect(hostkey=public_host_key, + username='slowdive', password='pygmalion') event.wait(1.0) self.assert_(event.isSet()) self.assert_(self.ts.is_active()) @@ -210,21 +218,10 @@ class TransportTest (unittest.TestCase): verify that the client can demand odd handshake settings, and can renegotiate keys in mid-stream. """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - options = self.tc.get_security_options() - options.ciphers = ('aes256-cbc',) - options.digests = ('hmac-md5-96',) - self.tc.connect(hostkey=public_host_key, - username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) + def force_algorithms(options): + options.ciphers = ('aes256-cbc',) + options.digests = ('hmac-md5-96',) + self.setup_test_server(client_options=force_algorithms) self.assertEquals('aes256-cbc', self.tc.local_cipher) self.assertEquals('aes256-cbc', self.tc.remote_cipher) self.assertEquals(12, self.tc.packetizer.get_mac_size_out()) @@ -238,142 +235,13 @@ class TransportTest (unittest.TestCase): """ verify that the keepalive will be sent. """ - self.tc.set_hexdump(True) - - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.connect(hostkey=public_host_key, - username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - - self.assertEquals(None, getattr(server, '_global_request', None)) + self.setup_test_server() + self.assertEquals(None, getattr(self.server, '_global_request', None)) self.tc.set_keepalive(1) time.sleep(2) - self.assertEquals('keepalive@lag.net', server._global_request) + self.assertEquals('keepalive@lag.net', self.server._global_request) - def test_6_bad_auth_type(self): - """ - verify that we get the right exception when an unsupported auth - type is requested. - """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - try: - self.tc.connect(hostkey=public_host_key, - username='unknown', password='error') - self.assert_(False) - except: - etype, evalue, etb = sys.exc_info() - self.assertEquals(BadAuthenticationType, etype) - self.assertEquals(['publickey'], evalue.allowed_types) - - def test_7_bad_password(self): - """ - verify that a bad password gets the right exception, and that a retry - with the right password works. - """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - try: - self.tc.auth_password(username='slowdive', password='error') - self.assert_(False) - except: - etype, evalue, etb = sys.exc_info() - self.assert_(issubclass(etype, SSHException)) - self.tc.auth_password(username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - - def test_8_multipart_auth(self): - """ - verify that multipart auth works. - """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - remain = self.tc.auth_password(username='paranoid', password='paranoid') - self.assertEquals(['publickey'], remain) - key = DSSKey.from_private_key_file('tests/test_dss.key') - remain = self.tc.auth_publickey(username='paranoid', key=key) - self.assertEquals([], remain) - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - - def test_9_interactive_auth(self): - """ - verify keyboard-interactive auth works. - """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - - def handler(title, instructions, prompts): - self.got_title = title - self.got_instructions = instructions - self.got_prompts = prompts - return ['cat'] - remain = self.tc.auth_interactive('commie', handler) - self.assertEquals(self.got_title, 'password') - self.assertEquals(self.got_prompts, [('Password', False)]) - self.assertEquals([], remain) - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - - def test_A_interactive_auth_fallback(self): - """ - verify that a password auth attempt will fallback to "interactive" - if password auth isn't supported but interactive is. - """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - remain = self.tc.auth_password('commie', 'cat') - self.assertEquals([], remain) - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - - def test_B_exec_command(self): + def test_6_exec_command(self): """ verify that exec_command() does something reasonable. """ @@ -415,7 +283,7 @@ class TransportTest (unittest.TestCase): self.assertEquals('This is on stderr.\n', f.readline()) self.assertEquals('', f.readline()) - def test_C_invoke_shell(self): + def test_7_invoke_shell(self): """ verify that invoke_shell() does something reasonable. """ @@ -429,7 +297,7 @@ class TransportTest (unittest.TestCase): chan.close() self.assertEquals('', f.readline()) - def test_D_channel_exception(self): + def test_8_channel_exception(self): """ verify that ChannelException is thrown for a bad open-channel request. """ @@ -440,7 +308,7 @@ class TransportTest (unittest.TestCase): except ChannelException, x: self.assert_(x.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED) - def test_E_exit_status(self): + def test_9_exit_status(self): """ verify that get_exit_status() works. """ @@ -462,7 +330,7 @@ class TransportTest (unittest.TestCase): self.assertEquals(23, chan.recv_exit_status()) chan.close() - def test_F_select(self): + def test_A_select(self): """ verify that select() on a channel works. """ @@ -517,7 +385,7 @@ class TransportTest (unittest.TestCase): # ...and now is closed. self.assertEquals(True, p._closed) - def test_G_renegotiate(self): + def test_B_renegotiate(self): """ verify that a transport can correctly renegotiate mid-stream. """ @@ -541,24 +409,13 @@ class TransportTest (unittest.TestCase): schan.close() - def test_H_compression(self): + def test_C_compression(self): """ verify that zlib compression is basically working. """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - self.ts.get_security_options().compression = ('zlib',) - self.tc.get_security_options().compression = ('zlib',) - event = threading.Event() - server = NullServer() - self.ts.start_server(event, server) - self.tc.connect(hostkey=public_host_key, - username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - + def force_compression(o): + o.compression = ('zlib',) + self.setup_test_server(force_compression, force_compression) chan = self.tc.open_session() chan.exec_command('yes') schan = self.ts.accept(1.0) @@ -573,7 +430,7 @@ class TransportTest (unittest.TestCase): chan.close() schan.close() - def test_I_x11(self): + def test_D_x11(self): """ verify that an x11 port can be requested and opened. """ @@ -607,7 +464,7 @@ class TransportTest (unittest.TestCase): chan.close() schan.close() - def test_J_reverse_port_forwarding(self): + def test_E_reverse_port_forwarding(self): """ verify that a client can ask the server to open a reverse port for forwarding. @@ -643,7 +500,7 @@ class TransportTest (unittest.TestCase): self.tc.cancel_port_forward('', port) self.assertTrue(self.server._listen is None) - def test_K_port_forwarding(self): + def test_F_port_forwarding(self): """ verify that a client can forward new connections from a locally- forwarded port. @@ -672,7 +529,7 @@ class TransportTest (unittest.TestCase): self.assertEquals('Hello!\n', cs.recv(7)) cs.close() - def test_L_stderr_select(self): + def test_G_stderr_select(self): """ verify that select() on a channel works even if only stderr is receiving data. @@ -711,7 +568,7 @@ class TransportTest (unittest.TestCase): schan.close() chan.close() - def test_M_send_ready(self): + def test_H_send_ready(self): """ verify that send_ready() indicates when a send would not block. """