fix the utf-8 password bug for good (aka bug 177117) and add unit tests
this time.
This commit is contained in:
Robey Pointer 2008-01-23 20:50:17 -08:00
parent 953392c0a1
commit 888aa8d5b7
3 changed files with 71 additions and 58 deletions

View File

@ -200,7 +200,7 @@ class AuthHandler (object):
password = self.password password = self.password
if isinstance(password, unicode): if isinstance(password, unicode):
password = password.encode('UTF-8') password = password.encode('UTF-8')
m.add_string(self.password.encode('UTF-8')) m.add_string(password)
elif self.auth_method == 'publickey': elif self.auth_method == 'publickey':
m.add_boolean(True) m.add_boolean(True)
m.add_string(self.private_key.get_name()) m.add_string(self.private_key.get_name())
@ -283,12 +283,22 @@ class AuthHandler (object):
result = self.transport.server_object.check_auth_none(username) result = self.transport.server_object.check_auth_none(username)
elif method == 'password': elif method == 'password':
changereq = m.get_boolean() changereq = m.get_boolean()
password = m.get_string().decode('UTF-8', 'replace') password = m.get_string()
try:
password = password.decode('UTF-8')
except UnicodeError:
# some clients/servers expect non-utf-8 passwords!
# in this case, just return the raw byte string.
pass
if changereq: if changereq:
# always treated as failure, since we don't support changing passwords, but collect # always treated as failure, since we don't support changing passwords, but collect
# the list of valid auth types from the callback anyway # the list of valid auth types from the callback anyway
self.transport._log(DEBUG, 'Auth request to change passwords (rejected)') self.transport._log(DEBUG, 'Auth request to change passwords (rejected)')
newpassword = m.get_string().decode('UTF-8', 'replace') newpassword = m.get_string()
try:
newpassword = newpassword.decode('UTF-8', 'replace')
except UnicodeError:
pass
result = AUTH_FAILED result = AUTH_FAILED
else: else:
result = self.transport.server_object.check_auth_password(username, password) result = self.transport.server_object.check_auth_password(username, password)

View File

@ -1071,9 +1071,9 @@ class Transport (threading.Thread):
step. Otherwise, in the normal case, an empty list is returned. step. Otherwise, in the normal case, an empty list is returned.
@param username: the username to authenticate as @param username: the username to authenticate as
@type username: string @type username: str
@param password: the password to authenticate with @param password: the password to authenticate with
@type password: string @type password: str or unicode
@param event: an event to trigger when the authentication attempt is @param event: an event to trigger when the authentication attempt is
complete (whether it was successful or not) complete (whether it was successful or not)
@type event: threading.Event @type event: threading.Event

View File

@ -48,6 +48,10 @@ class NullServer (ServerInterface):
return 'password' return 'password'
if username == 'commie': if username == 'commie':
return 'keyboard-interactive' return 'keyboard-interactive'
if username == 'utf8':
return 'password'
if username == 'non-utf8':
return 'password'
return 'publickey' return 'publickey'
def check_auth_password(self, username, password): def check_auth_password(self, username, password):
@ -59,7 +63,9 @@ class NullServer (ServerInterface):
if self.paranoid_did_public_key: if self.paranoid_did_public_key:
return AUTH_SUCCESSFUL return AUTH_SUCCESSFUL
return AUTH_PARTIALLY_SUCCESSFUL return AUTH_PARTIALLY_SUCCESSFUL
if (username == 'utf8') and (password == u'\u2022'.encode('utf-8')): if (username == 'utf8') and (password == u'\u2022'):
return AUTH_SUCCESSFUL
if (username == 'non-utf8') and (password == '\xff'):
return AUTH_SUCCESSFUL return AUTH_SUCCESSFUL
return AUTH_FAILED return AUTH_FAILED
@ -99,21 +105,29 @@ class AuthTest (unittest.TestCase):
self.ts.close() self.ts.close()
self.socks.close() self.socks.close()
self.sockc.close() self.sockc.close()
def start_server(self):
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
self.public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
self.event = threading.Event()
self.server = NullServer()
self.assert_(not self.event.isSet())
self.ts.start_server(self.event, self.server)
def verify_finished(self):
self.event.wait(1.0)
self.assert_(self.event.isSet())
self.assert_(self.ts.is_active())
def test_1_bad_auth_type(self): def test_1_bad_auth_type(self):
""" """
verify that we get the right exception when an unsupported auth verify that we get the right exception when an unsupported auth
type is requested. type is requested.
""" """
host_key = RSAKey.from_private_key_file('tests/test_rsa.key') self.start_server()
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
try: try:
self.tc.connect(hostkey=public_host_key, self.tc.connect(hostkey=self.public_host_key,
username='unknown', password='error') username='unknown', password='error')
self.assert_(False) self.assert_(False)
except: except:
@ -126,14 +140,8 @@ class AuthTest (unittest.TestCase):
verify that a bad password gets the right exception, and that a retry verify that a bad password gets the right exception, and that a retry
with the right password works. with the right password works.
""" """
host_key = RSAKey.from_private_key_file('tests/test_rsa.key') self.start_server()
public_host_key = RSAKey(data=str(host_key)) self.tc.connect(hostkey=self.public_host_key)
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.connect(hostkey=public_host_key)
try: try:
self.tc.auth_password(username='slowdive', password='error') self.tc.auth_password(username='slowdive', password='error')
self.assert_(False) self.assert_(False)
@ -141,43 +149,27 @@ class AuthTest (unittest.TestCase):
etype, evalue, etb = sys.exc_info() etype, evalue, etb = sys.exc_info()
self.assert_(issubclass(etype, SSHException)) self.assert_(issubclass(etype, SSHException))
self.tc.auth_password(username='slowdive', password='pygmalion') self.tc.auth_password(username='slowdive', password='pygmalion')
event.wait(1.0) self.verify_finished()
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
def test_3_multipart_auth(self): def test_3_multipart_auth(self):
""" """
verify that multipart auth works. verify that multipart auth works.
""" """
host_key = RSAKey.from_private_key_file('tests/test_rsa.key') self.start_server()
public_host_key = RSAKey(data=str(host_key)) self.tc.connect(hostkey=self.public_host_key)
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.connect(hostkey=public_host_key)
remain = self.tc.auth_password(username='paranoid', password='paranoid') remain = self.tc.auth_password(username='paranoid', password='paranoid')
self.assertEquals(['publickey'], remain) self.assertEquals(['publickey'], remain)
key = DSSKey.from_private_key_file('tests/test_dss.key') key = DSSKey.from_private_key_file('tests/test_dss.key')
remain = self.tc.auth_publickey(username='paranoid', key=key) remain = self.tc.auth_publickey(username='paranoid', key=key)
self.assertEquals([], remain) self.assertEquals([], remain)
event.wait(1.0) self.verify_finished()
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
def test_4_interactive_auth(self): def test_4_interactive_auth(self):
""" """
verify keyboard-interactive auth works. verify keyboard-interactive auth works.
""" """
host_key = RSAKey.from_private_key_file('tests/test_rsa.key') self.start_server()
public_host_key = RSAKey(data=str(host_key)) self.tc.connect(hostkey=self.public_host_key)
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.connect(hostkey=public_host_key)
def handler(title, instructions, prompts): def handler(title, instructions, prompts):
self.got_title = title self.got_title = title
@ -188,25 +180,36 @@ class AuthTest (unittest.TestCase):
self.assertEquals(self.got_title, 'password') self.assertEquals(self.got_title, 'password')
self.assertEquals(self.got_prompts, [('Password', False)]) self.assertEquals(self.got_prompts, [('Password', False)])
self.assertEquals([], remain) self.assertEquals([], remain)
event.wait(1.0) self.verify_finished()
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
def test_5_interactive_auth_fallback(self): def test_5_interactive_auth_fallback(self):
""" """
verify that a password auth attempt will fallback to "interactive" verify that a password auth attempt will fallback to "interactive"
if password auth isn't supported but interactive is. if password auth isn't supported but interactive is.
""" """
host_key = RSAKey.from_private_key_file('tests/test_rsa.key') self.start_server()
public_host_key = RSAKey(data=str(host_key)) self.tc.connect(hostkey=self.public_host_key)
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.connect(hostkey=public_host_key)
remain = self.tc.auth_password('commie', 'cat') remain = self.tc.auth_password('commie', 'cat')
self.assertEquals([], remain) self.assertEquals([], remain)
event.wait(1.0) self.verify_finished()
self.assert_(event.isSet())
self.assert_(self.ts.is_active()) def test_6_auth_utf8(self):
"""
verify that utf-8 encoding happens in authentication.
"""
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
remain = self.tc.auth_password('utf8', u'\u2022')
self.assertEquals([], remain)
self.verify_finished()
def test_7_auth_non_utf8(self):
"""
verify that non-utf-8 encoded passwords can be used for broken
servers.
"""
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
remain = self.tc.auth_password('non-utf8', '\xff')
self.assertEquals([], remain)
self.verify_finished()