[project @ Arch-1:robey@lag.net--2003-public%secsh--dev--1.0--patch-123]

clean up authentication
add new exception "BadAuthenticationType", which is raised when auth fails
because your auth type (password or public-key) isn't valid on the server.

used this as an excuse to clean up auth_password and auth_publickey so their
'event' arg is optional, and if missing, they block until auth is finished,
raising an exception on error.

also, don't close the session on failed auth -- the server may let you try
again.

added some test cases for failed auth.
This commit is contained in:
Robey Pointer 2004-12-11 03:43:18 +00:00
parent 73a0df1df3
commit 767d739299
5 changed files with 177 additions and 42 deletions

View File

@ -77,6 +77,7 @@ DSSKey = dsskey.DSSKey
SSHException = ssh_exception.SSHException
Message = message.Message
PasswordRequiredException = ssh_exception.PasswordRequiredException
BadAuthenticationType = ssh_exception.BadAuthenticationType
SFTP = sftp_client.SFTP
SFTPClient = sftp_client.SFTPClient
SFTPServer = sftp_server.SFTPServer
@ -105,6 +106,7 @@ __all__ = [ 'Transport',
'Message',
'SSHException',
'PasswordRequiredException',
'BadAuthenticationType',
'SFTP',
'SFTPHandle',
'SFTPClient',

View File

@ -23,6 +23,8 @@ L{Transport} is a subclass of L{BaseTransport} that handles authentication.
This separation keeps either class file from being too unwieldy.
"""
import threading
# this helps freezing utils
import encodings.utf_8
@ -30,7 +32,7 @@ from common import *
import util
from transport import BaseTransport
from message import Message
from ssh_exception import SSHException
from ssh_exception import SSHException, BadAuthenticationType
class Transport (BaseTransport):
@ -102,12 +104,19 @@ class Transport (BaseTransport):
else:
return self.username
def auth_publickey(self, username, key, event):
def auth_publickey(self, username, key, event=None):
"""
Authenticate to the server using a private key. The key is used to
sign data from the server, so it must include the private part. The
given L{event} is triggered on success or failure. On success,
L{is_authenticated} will return C{True}.
sign data from the server, so it must include the private part.
If an C{event} is passed in, this method will return immediately, and
the event will be triggered once authentication succeeds or fails. On
success, L{is_authenticated} will return C{True}. On failure, you may
use L{get_exception} to get more detailed error information.
Since 1.1, if no event is passed, this method will block until the
authentication succeeds or fails. On failure, an exception is raised.
Otherwise, the method simply returns.
@param username: the username to authenticate as.
@type username: string
@ -116,26 +125,46 @@ class Transport (BaseTransport):
@param event: an event to trigger when the authentication attempt is
complete (whether it was successful or not)
@type event: threading.Event
@raise BadAuthenticationType: if public-key authentication isn't
allowed by the server for this user (and no event was passed in).
@raise SSHException: if the authentication failed (and no event was
passed in).
"""
if (not self.active) or (not self.initial_kex_done):
# we should never try to authenticate unless we're on a secure link
raise SSHException('No existing session')
if event is None:
my_event = threading.Event()
else:
my_event = event
self.lock.acquire()
try:
self.lock.acquire()
self.auth_event = event
self.auth_event = my_event
self.auth_method = 'publickey'
self.username = username
self.private_key = key
self._request_auth()
finally:
self.lock.release()
if event is not None:
# caller wants to wait for event themselves
return
self._wait_for_response(my_event)
def auth_password(self, username, password, event):
def auth_password(self, username, password, event=None):
"""
Authenticate to the server using a password. The username and password
are sent over an encrypted link, and the given L{event} is triggered on
success or failure. On success, L{is_authenticated} will return
C{True}.
are sent over an encrypted link.
If an C{event} is passed in, this method will return immediately, and
the event will be triggered once authentication succeeds or fails. On
success, L{is_authenticated} will return C{True}. On failure, you may
use L{get_exception} to get more detailed error information.
Since 1.1, if no event is passed, this method will block until the
authentication succeeds or fails. On failure, an exception is raised.
Otherwise, the method simply returns.
@param username: the username to authenticate as.
@type username: string
@ -144,19 +173,32 @@ class Transport (BaseTransport):
@param event: an event to trigger when the authentication attempt is
complete (whether it was successful or not)
@type event: threading.Event
@raise BadAuthenticationType: if password authentication isn't
allowed by the server for this user (and no event was passed in).
@raise SSHException: if the authentication failed (and no event was
passed in).
"""
if (not self.active) or (not self.initial_kex_done):
# we should never try to send the password unless we're on a secure link
raise SSHException('No existing session')
if event is None:
my_event = threading.Event()
else:
my_event = event
self.lock.acquire()
try:
self.lock.acquire()
self.auth_event = event
self.auth_event = my_event
self.auth_method = 'password'
self.username = username
self.password = password
self._request_auth()
finally:
self.lock.release()
if event is not None:
# caller wants to wait for event themselves
return
self._wait_for_response(my_event)
### internals...
@ -198,6 +240,22 @@ class Transport (BaseTransport):
m.add_string(str(key))
return str(m)
def _wait_for_response(self, event):
while True:
event.wait(0.1)
if not self.active:
e = self.get_exception()
if e is None:
e = SSHException('Authentication failed.')
raise e
if event.isSet():
break
if not self.is_authenticated():
e = self.get_exception()
if e is None:
e = SSHException('Authentication failed.')
raise e
def _parse_service_request(self, m):
service = m.get_string()
if self.server_mode and (service == 'ssh-userauth'):
@ -264,12 +322,12 @@ class Transport (BaseTransport):
result = self.server_object.check_auth_none(username)
elif method == 'password':
changereq = m.get_boolean()
password = m.get_string().decode('UTF-8')
password = m.get_string().decode('UTF-8', 'replace')
if changereq:
# always treated as failure, since we don't support changing passwords, but collect
# the list of valid auth types from the callback anyway
self._log(DEBUG, 'Auth request to change passwords (rejected)')
newpassword = m.get_string().decode('UTF-8')
newpassword = m.get_string().decode('UTF-8', 'replace')
result = AUTH_FAILED
else:
result = self.server_object.check_auth_password(username, password)
@ -339,13 +397,13 @@ class Transport (BaseTransport):
if partial:
self._log(INFO, 'Authentication continues...')
self._log(DEBUG, 'Methods: ' + str(partial))
# FIXME - do something
# FIXME: multi-part auth not supported
pass
if self.auth_method not in authlist:
self.saved_exception = BadAuthenticationType('Bad authentication type', authlist)
self._log(INFO, 'Authentication failed.')
self.authenticated = False
# FIXME: i don't think we need to close() necessarily here
self.username = None
self.close()
if self.auth_event != None:
self.auth_event.set()

View File

@ -25,12 +25,31 @@ Exceptions defined by paramiko.
class SSHException (Exception):
"""
Exception thrown by failures in SSH2 protocol negotiation or logic errors.
Exception raised by failures in SSH2 protocol negotiation or logic errors.
"""
pass
class PasswordRequiredException (SSHException):
"""
Exception thrown when a password is needed to unlock a private key file.
Exception raised when a password is needed to unlock a private key file.
"""
pass
class BadAuthenticationType (SSHException):
"""
Exception raised when an authentication type (like password) is used, but
the server isn't allowing that type. (It may only allow public-key, for
example.)
@ivar allowed_types: list of allowed authentication types provided by the
server (possible values are: C{"none"}, C{"password"}, and
C{"publickey"}).
@type allowed_types: list
@since: 1.1
"""
allowed_types = []
def __init__(self, explanation, types):
SSHException.__init__(self, explanation)
self.allowed_types = types

View File

@ -681,8 +681,8 @@ class BaseTransport (threading.Thread):
return False
def accept(self, timeout=None):
self.lock.acquire()
try:
self.lock.acquire()
if len(self.server_accepts) > 0:
chan = self.server_accepts.pop(0)
else:
@ -740,8 +740,7 @@ class BaseTransport (threading.Thread):
while 1:
event.wait(0.1)
if not self.active:
e = self.saved_exception
self.saved_exception = None
e = self.get_exception()
if e is not None:
raise e
raise SSHException('Negotiation failed.')
@ -759,27 +758,34 @@ class BaseTransport (threading.Thread):
self._log(DEBUG, 'Host key verified (%s)' % hostkey.get_name())
if (pkey is not None) or (password is not None):
event.clear()
if password is not None:
self._log(DEBUG, 'Attempting password auth...')
self.auth_password(username, password, event)
self.auth_password(username, password)
else:
self._log(DEBUG, 'Attempting pkey auth...')
self.auth_publickey(username, pkey, event)
while 1:
event.wait(0.1)
if not self.active:
e = self.saved_exception
self.saved_exception = None
if e is not None:
raise e
raise SSHException('Authentication failed.')
if event.isSet():
break
if not self.is_authenticated():
raise SSHException('Authentication failed.')
self._log(DEBUG, 'Attempting public-key auth...')
self.auth_publickey(username, pkey)
return
def get_exception(self):
"""
Return any exception that happened during the last server request.
This can be used to fetch more specific error information after using
calls like L{start_client}. The exception (if any) is cleared after
this call.
@return: an exception, or C{None} if there is no stored exception.
@rtype: Exception
@since: 1.1
"""
self.lock.acquire()
try:
e = self.saved_exception
self.saved_exception = None
return e
finally:
self.lock.release()
def set_subsystem_handler(self, name, handler, *larg, **kwarg):
"""

View File

@ -22,14 +22,17 @@
Some unit tests for the ssh2 protocol in Transport.
"""
import unittest, threading
from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey
import sys, unittest, threading
from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, \
SSHException, BadAuthenticationType
from paramiko import AUTH_FAILED, AUTH_SUCCESSFUL
from loop import LoopSocket
class NullServer (ServerInterface):
def get_allowed_auths(self, username):
if username == 'slowdive':
return 'publickey,password'
return 'publickey'
def check_auth_password(self, username, password):
@ -90,4 +93,51 @@ class TransportTest (unittest.TestCase):
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
def test_3_bad_auth_type(self):
"""
verify that we get the right exception when an unsupported auth
type is requested.
"""
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.ultra_debug = True
try:
self.tc.connect(hostkey=public_host_key,
username='unknown', password='error')
self.assert_(False)
except:
etype, evalue, etb = sys.exc_info()
self.assertEquals(BadAuthenticationType, etype)
self.assertEquals(['publickey'], evalue.allowed_types)
def test_4_bad_password(self):
"""
verify that a bad password gets the right exception, and that a retry
with the right password works.
"""
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.ultra_debug = True
self.tc.connect(hostkey=public_host_key)
try:
self.tc.auth_password(username='slowdive', password='error')
self.assert_(False)
except:
etype, evalue, etb = sys.exc_info()
self.assertEquals(SSHException, etype)
self.tc.auth_password(username='slowdive', password='pygmalion')
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())