diff --git a/paramiko/__init__.py b/paramiko/__init__.py index f7ddbc1..2c430d7 100644 --- a/paramiko/__init__.py +++ b/paramiko/__init__.py @@ -77,6 +77,7 @@ DSSKey = dsskey.DSSKey SSHException = ssh_exception.SSHException Message = message.Message PasswordRequiredException = ssh_exception.PasswordRequiredException +BadAuthenticationType = ssh_exception.BadAuthenticationType SFTP = sftp_client.SFTP SFTPClient = sftp_client.SFTPClient SFTPServer = sftp_server.SFTPServer @@ -105,6 +106,7 @@ __all__ = [ 'Transport', 'Message', 'SSHException', 'PasswordRequiredException', + 'BadAuthenticationType', 'SFTP', 'SFTPHandle', 'SFTPClient', diff --git a/paramiko/auth_transport.py b/paramiko/auth_transport.py index 2a6a1a2..45ef8ce 100644 --- a/paramiko/auth_transport.py +++ b/paramiko/auth_transport.py @@ -23,6 +23,8 @@ L{Transport} is a subclass of L{BaseTransport} that handles authentication. This separation keeps either class file from being too unwieldy. """ +import threading + # this helps freezing utils import encodings.utf_8 @@ -30,7 +32,7 @@ from common import * import util from transport import BaseTransport from message import Message -from ssh_exception import SSHException +from ssh_exception import SSHException, BadAuthenticationType class Transport (BaseTransport): @@ -102,12 +104,19 @@ class Transport (BaseTransport): else: return self.username - def auth_publickey(self, username, key, event): + def auth_publickey(self, username, key, event=None): """ Authenticate to the server using a private key. The key is used to - sign data from the server, so it must include the private part. The - given L{event} is triggered on success or failure. On success, - L{is_authenticated} will return C{True}. + sign data from the server, so it must include the private part. + + If an C{event} is passed in, this method will return immediately, and + the event will be triggered once authentication succeeds or fails. On + success, L{is_authenticated} will return C{True}. On failure, you may + use L{get_exception} to get more detailed error information. + + Since 1.1, if no event is passed, this method will block until the + authentication succeeds or fails. On failure, an exception is raised. + Otherwise, the method simply returns. @param username: the username to authenticate as. @type username: string @@ -116,26 +125,46 @@ class Transport (BaseTransport): @param event: an event to trigger when the authentication attempt is complete (whether it was successful or not) @type event: threading.Event + + @raise BadAuthenticationType: if public-key authentication isn't + allowed by the server for this user (and no event was passed in). + @raise SSHException: if the authentication failed (and no event was + passed in). """ if (not self.active) or (not self.initial_kex_done): # we should never try to authenticate unless we're on a secure link raise SSHException('No existing session') + if event is None: + my_event = threading.Event() + else: + my_event = event + self.lock.acquire() try: - self.lock.acquire() - self.auth_event = event + self.auth_event = my_event self.auth_method = 'publickey' self.username = username self.private_key = key self._request_auth() finally: self.lock.release() + if event is not None: + # caller wants to wait for event themselves + return + self._wait_for_response(my_event) - def auth_password(self, username, password, event): + def auth_password(self, username, password, event=None): """ Authenticate to the server using a password. The username and password - are sent over an encrypted link, and the given L{event} is triggered on - success or failure. On success, L{is_authenticated} will return - C{True}. + are sent over an encrypted link. + + If an C{event} is passed in, this method will return immediately, and + the event will be triggered once authentication succeeds or fails. On + success, L{is_authenticated} will return C{True}. On failure, you may + use L{get_exception} to get more detailed error information. + + Since 1.1, if no event is passed, this method will block until the + authentication succeeds or fails. On failure, an exception is raised. + Otherwise, the method simply returns. @param username: the username to authenticate as. @type username: string @@ -144,19 +173,32 @@ class Transport (BaseTransport): @param event: an event to trigger when the authentication attempt is complete (whether it was successful or not) @type event: threading.Event + + @raise BadAuthenticationType: if password authentication isn't + allowed by the server for this user (and no event was passed in). + @raise SSHException: if the authentication failed (and no event was + passed in). """ if (not self.active) or (not self.initial_kex_done): # we should never try to send the password unless we're on a secure link raise SSHException('No existing session') + if event is None: + my_event = threading.Event() + else: + my_event = event + self.lock.acquire() try: - self.lock.acquire() - self.auth_event = event + self.auth_event = my_event self.auth_method = 'password' self.username = username self.password = password self._request_auth() finally: self.lock.release() + if event is not None: + # caller wants to wait for event themselves + return + self._wait_for_response(my_event) ### internals... @@ -198,6 +240,22 @@ class Transport (BaseTransport): m.add_string(str(key)) return str(m) + def _wait_for_response(self, event): + while True: + event.wait(0.1) + if not self.active: + e = self.get_exception() + if e is None: + e = SSHException('Authentication failed.') + raise e + if event.isSet(): + break + if not self.is_authenticated(): + e = self.get_exception() + if e is None: + e = SSHException('Authentication failed.') + raise e + def _parse_service_request(self, m): service = m.get_string() if self.server_mode and (service == 'ssh-userauth'): @@ -264,12 +322,12 @@ class Transport (BaseTransport): result = self.server_object.check_auth_none(username) elif method == 'password': changereq = m.get_boolean() - password = m.get_string().decode('UTF-8') + password = m.get_string().decode('UTF-8', 'replace') if changereq: # always treated as failure, since we don't support changing passwords, but collect # the list of valid auth types from the callback anyway self._log(DEBUG, 'Auth request to change passwords (rejected)') - newpassword = m.get_string().decode('UTF-8') + newpassword = m.get_string().decode('UTF-8', 'replace') result = AUTH_FAILED else: result = self.server_object.check_auth_password(username, password) @@ -339,13 +397,13 @@ class Transport (BaseTransport): if partial: self._log(INFO, 'Authentication continues...') self._log(DEBUG, 'Methods: ' + str(partial)) - # FIXME - do something + # FIXME: multi-part auth not supported pass + if self.auth_method not in authlist: + self.saved_exception = BadAuthenticationType('Bad authentication type', authlist) self._log(INFO, 'Authentication failed.') self.authenticated = False - # FIXME: i don't think we need to close() necessarily here self.username = None - self.close() if self.auth_event != None: self.auth_event.set() diff --git a/paramiko/ssh_exception.py b/paramiko/ssh_exception.py index 6321821..1f9173e 100644 --- a/paramiko/ssh_exception.py +++ b/paramiko/ssh_exception.py @@ -25,12 +25,31 @@ Exceptions defined by paramiko. class SSHException (Exception): """ - Exception thrown by failures in SSH2 protocol negotiation or logic errors. + Exception raised by failures in SSH2 protocol negotiation or logic errors. """ pass class PasswordRequiredException (SSHException): """ - Exception thrown when a password is needed to unlock a private key file. + Exception raised when a password is needed to unlock a private key file. """ pass + +class BadAuthenticationType (SSHException): + """ + Exception raised when an authentication type (like password) is used, but + the server isn't allowing that type. (It may only allow public-key, for + example.) + + @ivar allowed_types: list of allowed authentication types provided by the + server (possible values are: C{"none"}, C{"password"}, and + C{"publickey"}). + @type allowed_types: list + + @since: 1.1 + """ + allowed_types = [] + + def __init__(self, explanation, types): + SSHException.__init__(self, explanation) + self.allowed_types = types diff --git a/paramiko/transport.py b/paramiko/transport.py index 3e50ffc..7e45741 100644 --- a/paramiko/transport.py +++ b/paramiko/transport.py @@ -681,8 +681,8 @@ class BaseTransport (threading.Thread): return False def accept(self, timeout=None): + self.lock.acquire() try: - self.lock.acquire() if len(self.server_accepts) > 0: chan = self.server_accepts.pop(0) else: @@ -740,8 +740,7 @@ class BaseTransport (threading.Thread): while 1: event.wait(0.1) if not self.active: - e = self.saved_exception - self.saved_exception = None + e = self.get_exception() if e is not None: raise e raise SSHException('Negotiation failed.') @@ -759,27 +758,34 @@ class BaseTransport (threading.Thread): self._log(DEBUG, 'Host key verified (%s)' % hostkey.get_name()) if (pkey is not None) or (password is not None): - event.clear() if password is not None: self._log(DEBUG, 'Attempting password auth...') - self.auth_password(username, password, event) + self.auth_password(username, password) else: - self._log(DEBUG, 'Attempting pkey auth...') - self.auth_publickey(username, pkey, event) - while 1: - event.wait(0.1) - if not self.active: - e = self.saved_exception - self.saved_exception = None - if e is not None: - raise e - raise SSHException('Authentication failed.') - if event.isSet(): - break - if not self.is_authenticated(): - raise SSHException('Authentication failed.') + self._log(DEBUG, 'Attempting public-key auth...') + self.auth_publickey(username, pkey) return + + def get_exception(self): + """ + Return any exception that happened during the last server request. + This can be used to fetch more specific error information after using + calls like L{start_client}. The exception (if any) is cleared after + this call. + + @return: an exception, or C{None} if there is no stored exception. + @rtype: Exception + + @since: 1.1 + """ + self.lock.acquire() + try: + e = self.saved_exception + self.saved_exception = None + return e + finally: + self.lock.release() def set_subsystem_handler(self, name, handler, *larg, **kwarg): """ diff --git a/tests/test_transport.py b/tests/test_transport.py index 93dc8b7..b55160a 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -22,14 +22,17 @@ Some unit tests for the ssh2 protocol in Transport. """ -import unittest, threading -from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey +import sys, unittest, threading +from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, \ + SSHException, BadAuthenticationType from paramiko import AUTH_FAILED, AUTH_SUCCESSFUL from loop import LoopSocket class NullServer (ServerInterface): def get_allowed_auths(self, username): + if username == 'slowdive': + return 'publickey,password' return 'publickey' def check_auth_password(self, username, password): @@ -90,4 +93,51 @@ class TransportTest (unittest.TestCase): self.assert_(event.isSet()) self.assert_(self.ts.is_active()) - + def test_3_bad_auth_type(self): + """ + verify that we get the right exception when an unsupported auth + type is requested. + """ + host_key = RSAKey.from_private_key_file('tests/test_rsa.key') + public_host_key = RSAKey(data=str(host_key)) + self.ts.add_server_key(host_key) + event = threading.Event() + server = NullServer() + self.assert_(not event.isSet()) + self.ts.start_server(event, server) + self.tc.ultra_debug = True + try: + self.tc.connect(hostkey=public_host_key, + username='unknown', password='error') + self.assert_(False) + except: + etype, evalue, etb = sys.exc_info() + self.assertEquals(BadAuthenticationType, etype) + self.assertEquals(['publickey'], evalue.allowed_types) + + def test_4_bad_password(self): + """ + verify that a bad password gets the right exception, and that a retry + with the right password works. + """ + host_key = RSAKey.from_private_key_file('tests/test_rsa.key') + public_host_key = RSAKey(data=str(host_key)) + self.ts.add_server_key(host_key) + event = threading.Event() + server = NullServer() + self.assert_(not event.isSet()) + self.ts.start_server(event, server) + self.tc.ultra_debug = True + self.tc.connect(hostkey=public_host_key) + try: + self.tc.auth_password(username='slowdive', password='error') + self.assert_(False) + except: + etype, evalue, etb = sys.exc_info() + self.assertEquals(SSHException, etype) + self.tc.auth_password(username='slowdive', password='pygmalion') + event.wait(1.0) + self.assert_(event.isSet()) + self.assert_(self.ts.is_active()) + +