[project @ Arch-1:robey@lag.net--2003-public%secsh--dev--1.0--patch-123]
clean up authentication add new exception "BadAuthenticationType", which is raised when auth fails because your auth type (password or public-key) isn't valid on the server. used this as an excuse to clean up auth_password and auth_publickey so their 'event' arg is optional, and if missing, they block until auth is finished, raising an exception on error. also, don't close the session on failed auth -- the server may let you try again. added some test cases for failed auth.
This commit is contained in:
parent
73a0df1df3
commit
767d739299
|
@ -77,6 +77,7 @@ DSSKey = dsskey.DSSKey
|
|||
SSHException = ssh_exception.SSHException
|
||||
Message = message.Message
|
||||
PasswordRequiredException = ssh_exception.PasswordRequiredException
|
||||
BadAuthenticationType = ssh_exception.BadAuthenticationType
|
||||
SFTP = sftp_client.SFTP
|
||||
SFTPClient = sftp_client.SFTPClient
|
||||
SFTPServer = sftp_server.SFTPServer
|
||||
|
@ -105,6 +106,7 @@ __all__ = [ 'Transport',
|
|||
'Message',
|
||||
'SSHException',
|
||||
'PasswordRequiredException',
|
||||
'BadAuthenticationType',
|
||||
'SFTP',
|
||||
'SFTPHandle',
|
||||
'SFTPClient',
|
||||
|
|
|
@ -23,6 +23,8 @@ L{Transport} is a subclass of L{BaseTransport} that handles authentication.
|
|||
This separation keeps either class file from being too unwieldy.
|
||||
"""
|
||||
|
||||
import threading
|
||||
|
||||
# this helps freezing utils
|
||||
import encodings.utf_8
|
||||
|
||||
|
@ -30,7 +32,7 @@ from common import *
|
|||
import util
|
||||
from transport import BaseTransport
|
||||
from message import Message
|
||||
from ssh_exception import SSHException
|
||||
from ssh_exception import SSHException, BadAuthenticationType
|
||||
|
||||
|
||||
class Transport (BaseTransport):
|
||||
|
@ -102,12 +104,19 @@ class Transport (BaseTransport):
|
|||
else:
|
||||
return self.username
|
||||
|
||||
def auth_publickey(self, username, key, event):
|
||||
def auth_publickey(self, username, key, event=None):
|
||||
"""
|
||||
Authenticate to the server using a private key. The key is used to
|
||||
sign data from the server, so it must include the private part. The
|
||||
given L{event} is triggered on success or failure. On success,
|
||||
L{is_authenticated} will return C{True}.
|
||||
sign data from the server, so it must include the private part.
|
||||
|
||||
If an C{event} is passed in, this method will return immediately, and
|
||||
the event will be triggered once authentication succeeds or fails. On
|
||||
success, L{is_authenticated} will return C{True}. On failure, you may
|
||||
use L{get_exception} to get more detailed error information.
|
||||
|
||||
Since 1.1, if no event is passed, this method will block until the
|
||||
authentication succeeds or fails. On failure, an exception is raised.
|
||||
Otherwise, the method simply returns.
|
||||
|
||||
@param username: the username to authenticate as.
|
||||
@type username: string
|
||||
|
@ -116,26 +125,46 @@ class Transport (BaseTransport):
|
|||
@param event: an event to trigger when the authentication attempt is
|
||||
complete (whether it was successful or not)
|
||||
@type event: threading.Event
|
||||
|
||||
@raise BadAuthenticationType: if public-key authentication isn't
|
||||
allowed by the server for this user (and no event was passed in).
|
||||
@raise SSHException: if the authentication failed (and no event was
|
||||
passed in).
|
||||
"""
|
||||
if (not self.active) or (not self.initial_kex_done):
|
||||
# we should never try to authenticate unless we're on a secure link
|
||||
raise SSHException('No existing session')
|
||||
try:
|
||||
if event is None:
|
||||
my_event = threading.Event()
|
||||
else:
|
||||
my_event = event
|
||||
self.lock.acquire()
|
||||
self.auth_event = event
|
||||
try:
|
||||
self.auth_event = my_event
|
||||
self.auth_method = 'publickey'
|
||||
self.username = username
|
||||
self.private_key = key
|
||||
self._request_auth()
|
||||
finally:
|
||||
self.lock.release()
|
||||
if event is not None:
|
||||
# caller wants to wait for event themselves
|
||||
return
|
||||
self._wait_for_response(my_event)
|
||||
|
||||
def auth_password(self, username, password, event):
|
||||
def auth_password(self, username, password, event=None):
|
||||
"""
|
||||
Authenticate to the server using a password. The username and password
|
||||
are sent over an encrypted link, and the given L{event} is triggered on
|
||||
success or failure. On success, L{is_authenticated} will return
|
||||
C{True}.
|
||||
are sent over an encrypted link.
|
||||
|
||||
If an C{event} is passed in, this method will return immediately, and
|
||||
the event will be triggered once authentication succeeds or fails. On
|
||||
success, L{is_authenticated} will return C{True}. On failure, you may
|
||||
use L{get_exception} to get more detailed error information.
|
||||
|
||||
Since 1.1, if no event is passed, this method will block until the
|
||||
authentication succeeds or fails. On failure, an exception is raised.
|
||||
Otherwise, the method simply returns.
|
||||
|
||||
@param username: the username to authenticate as.
|
||||
@type username: string
|
||||
|
@ -144,19 +173,32 @@ class Transport (BaseTransport):
|
|||
@param event: an event to trigger when the authentication attempt is
|
||||
complete (whether it was successful or not)
|
||||
@type event: threading.Event
|
||||
|
||||
@raise BadAuthenticationType: if password authentication isn't
|
||||
allowed by the server for this user (and no event was passed in).
|
||||
@raise SSHException: if the authentication failed (and no event was
|
||||
passed in).
|
||||
"""
|
||||
if (not self.active) or (not self.initial_kex_done):
|
||||
# we should never try to send the password unless we're on a secure link
|
||||
raise SSHException('No existing session')
|
||||
try:
|
||||
if event is None:
|
||||
my_event = threading.Event()
|
||||
else:
|
||||
my_event = event
|
||||
self.lock.acquire()
|
||||
self.auth_event = event
|
||||
try:
|
||||
self.auth_event = my_event
|
||||
self.auth_method = 'password'
|
||||
self.username = username
|
||||
self.password = password
|
||||
self._request_auth()
|
||||
finally:
|
||||
self.lock.release()
|
||||
if event is not None:
|
||||
# caller wants to wait for event themselves
|
||||
return
|
||||
self._wait_for_response(my_event)
|
||||
|
||||
|
||||
### internals...
|
||||
|
@ -198,6 +240,22 @@ class Transport (BaseTransport):
|
|||
m.add_string(str(key))
|
||||
return str(m)
|
||||
|
||||
def _wait_for_response(self, event):
|
||||
while True:
|
||||
event.wait(0.1)
|
||||
if not self.active:
|
||||
e = self.get_exception()
|
||||
if e is None:
|
||||
e = SSHException('Authentication failed.')
|
||||
raise e
|
||||
if event.isSet():
|
||||
break
|
||||
if not self.is_authenticated():
|
||||
e = self.get_exception()
|
||||
if e is None:
|
||||
e = SSHException('Authentication failed.')
|
||||
raise e
|
||||
|
||||
def _parse_service_request(self, m):
|
||||
service = m.get_string()
|
||||
if self.server_mode and (service == 'ssh-userauth'):
|
||||
|
@ -264,12 +322,12 @@ class Transport (BaseTransport):
|
|||
result = self.server_object.check_auth_none(username)
|
||||
elif method == 'password':
|
||||
changereq = m.get_boolean()
|
||||
password = m.get_string().decode('UTF-8')
|
||||
password = m.get_string().decode('UTF-8', 'replace')
|
||||
if changereq:
|
||||
# always treated as failure, since we don't support changing passwords, but collect
|
||||
# the list of valid auth types from the callback anyway
|
||||
self._log(DEBUG, 'Auth request to change passwords (rejected)')
|
||||
newpassword = m.get_string().decode('UTF-8')
|
||||
newpassword = m.get_string().decode('UTF-8', 'replace')
|
||||
result = AUTH_FAILED
|
||||
else:
|
||||
result = self.server_object.check_auth_password(username, password)
|
||||
|
@ -339,13 +397,13 @@ class Transport (BaseTransport):
|
|||
if partial:
|
||||
self._log(INFO, 'Authentication continues...')
|
||||
self._log(DEBUG, 'Methods: ' + str(partial))
|
||||
# FIXME - do something
|
||||
# FIXME: multi-part auth not supported
|
||||
pass
|
||||
if self.auth_method not in authlist:
|
||||
self.saved_exception = BadAuthenticationType('Bad authentication type', authlist)
|
||||
self._log(INFO, 'Authentication failed.')
|
||||
self.authenticated = False
|
||||
# FIXME: i don't think we need to close() necessarily here
|
||||
self.username = None
|
||||
self.close()
|
||||
if self.auth_event != None:
|
||||
self.auth_event.set()
|
||||
|
||||
|
|
|
@ -25,12 +25,31 @@ Exceptions defined by paramiko.
|
|||
|
||||
class SSHException (Exception):
|
||||
"""
|
||||
Exception thrown by failures in SSH2 protocol negotiation or logic errors.
|
||||
Exception raised by failures in SSH2 protocol negotiation or logic errors.
|
||||
"""
|
||||
pass
|
||||
|
||||
class PasswordRequiredException (SSHException):
|
||||
"""
|
||||
Exception thrown when a password is needed to unlock a private key file.
|
||||
Exception raised when a password is needed to unlock a private key file.
|
||||
"""
|
||||
pass
|
||||
|
||||
class BadAuthenticationType (SSHException):
|
||||
"""
|
||||
Exception raised when an authentication type (like password) is used, but
|
||||
the server isn't allowing that type. (It may only allow public-key, for
|
||||
example.)
|
||||
|
||||
@ivar allowed_types: list of allowed authentication types provided by the
|
||||
server (possible values are: C{"none"}, C{"password"}, and
|
||||
C{"publickey"}).
|
||||
@type allowed_types: list
|
||||
|
||||
@since: 1.1
|
||||
"""
|
||||
allowed_types = []
|
||||
|
||||
def __init__(self, explanation, types):
|
||||
SSHException.__init__(self, explanation)
|
||||
self.allowed_types = types
|
||||
|
|
|
@ -681,8 +681,8 @@ class BaseTransport (threading.Thread):
|
|||
return False
|
||||
|
||||
def accept(self, timeout=None):
|
||||
try:
|
||||
self.lock.acquire()
|
||||
try:
|
||||
if len(self.server_accepts) > 0:
|
||||
chan = self.server_accepts.pop(0)
|
||||
else:
|
||||
|
@ -740,8 +740,7 @@ class BaseTransport (threading.Thread):
|
|||
while 1:
|
||||
event.wait(0.1)
|
||||
if not self.active:
|
||||
e = self.saved_exception
|
||||
self.saved_exception = None
|
||||
e = self.get_exception()
|
||||
if e is not None:
|
||||
raise e
|
||||
raise SSHException('Negotiation failed.')
|
||||
|
@ -759,28 +758,35 @@ class BaseTransport (threading.Thread):
|
|||
self._log(DEBUG, 'Host key verified (%s)' % hostkey.get_name())
|
||||
|
||||
if (pkey is not None) or (password is not None):
|
||||
event.clear()
|
||||
if password is not None:
|
||||
self._log(DEBUG, 'Attempting password auth...')
|
||||
self.auth_password(username, password, event)
|
||||
self.auth_password(username, password)
|
||||
else:
|
||||
self._log(DEBUG, 'Attempting pkey auth...')
|
||||
self.auth_publickey(username, pkey, event)
|
||||
while 1:
|
||||
event.wait(0.1)
|
||||
if not self.active:
|
||||
e = self.saved_exception
|
||||
self.saved_exception = None
|
||||
if e is not None:
|
||||
raise e
|
||||
raise SSHException('Authentication failed.')
|
||||
if event.isSet():
|
||||
break
|
||||
if not self.is_authenticated():
|
||||
raise SSHException('Authentication failed.')
|
||||
self._log(DEBUG, 'Attempting public-key auth...')
|
||||
self.auth_publickey(username, pkey)
|
||||
|
||||
return
|
||||
|
||||
def get_exception(self):
|
||||
"""
|
||||
Return any exception that happened during the last server request.
|
||||
This can be used to fetch more specific error information after using
|
||||
calls like L{start_client}. The exception (if any) is cleared after
|
||||
this call.
|
||||
|
||||
@return: an exception, or C{None} if there is no stored exception.
|
||||
@rtype: Exception
|
||||
|
||||
@since: 1.1
|
||||
"""
|
||||
self.lock.acquire()
|
||||
try:
|
||||
e = self.saved_exception
|
||||
self.saved_exception = None
|
||||
return e
|
||||
finally:
|
||||
self.lock.release()
|
||||
|
||||
def set_subsystem_handler(self, name, handler, *larg, **kwarg):
|
||||
"""
|
||||
Set the handler class for a subsystem in server mode. If a reqeuest
|
||||
|
|
|
@ -22,14 +22,17 @@
|
|||
Some unit tests for the ssh2 protocol in Transport.
|
||||
"""
|
||||
|
||||
import unittest, threading
|
||||
from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey
|
||||
import sys, unittest, threading
|
||||
from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, \
|
||||
SSHException, BadAuthenticationType
|
||||
from paramiko import AUTH_FAILED, AUTH_SUCCESSFUL
|
||||
from loop import LoopSocket
|
||||
|
||||
|
||||
class NullServer (ServerInterface):
|
||||
def get_allowed_auths(self, username):
|
||||
if username == 'slowdive':
|
||||
return 'publickey,password'
|
||||
return 'publickey'
|
||||
|
||||
def check_auth_password(self, username, password):
|
||||
|
@ -90,4 +93,51 @@ class TransportTest (unittest.TestCase):
|
|||
self.assert_(event.isSet())
|
||||
self.assert_(self.ts.is_active())
|
||||
|
||||
def test_3_bad_auth_type(self):
|
||||
"""
|
||||
verify that we get the right exception when an unsupported auth
|
||||
type is requested.
|
||||
"""
|
||||
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
|
||||
public_host_key = RSAKey(data=str(host_key))
|
||||
self.ts.add_server_key(host_key)
|
||||
event = threading.Event()
|
||||
server = NullServer()
|
||||
self.assert_(not event.isSet())
|
||||
self.ts.start_server(event, server)
|
||||
self.tc.ultra_debug = True
|
||||
try:
|
||||
self.tc.connect(hostkey=public_host_key,
|
||||
username='unknown', password='error')
|
||||
self.assert_(False)
|
||||
except:
|
||||
etype, evalue, etb = sys.exc_info()
|
||||
self.assertEquals(BadAuthenticationType, etype)
|
||||
self.assertEquals(['publickey'], evalue.allowed_types)
|
||||
|
||||
def test_4_bad_password(self):
|
||||
"""
|
||||
verify that a bad password gets the right exception, and that a retry
|
||||
with the right password works.
|
||||
"""
|
||||
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
|
||||
public_host_key = RSAKey(data=str(host_key))
|
||||
self.ts.add_server_key(host_key)
|
||||
event = threading.Event()
|
||||
server = NullServer()
|
||||
self.assert_(not event.isSet())
|
||||
self.ts.start_server(event, server)
|
||||
self.tc.ultra_debug = True
|
||||
self.tc.connect(hostkey=public_host_key)
|
||||
try:
|
||||
self.tc.auth_password(username='slowdive', password='error')
|
||||
self.assert_(False)
|
||||
except:
|
||||
etype, evalue, etb = sys.exc_info()
|
||||
self.assertEquals(SSHException, etype)
|
||||
self.tc.auth_password(username='slowdive', password='pygmalion')
|
||||
event.wait(1.0)
|
||||
self.assert_(event.isSet())
|
||||
self.assert_(self.ts.is_active())
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue