allow multiple key files to be specified in SSHClient.
suggested by Bernhard Walle.
This commit is contained in:
Robey Pointer 2008-07-06 13:37:06 -07:00
parent 305f5e09a5
commit c2ef48cf18
2 changed files with 31 additions and 8 deletions

View File

@ -255,9 +255,9 @@ class SSHClient (object):
@type password: str @type password: str
@param pkey: an optional private key to use for authentication @param pkey: an optional private key to use for authentication
@type pkey: L{PKey} @type pkey: L{PKey}
@param key_filename: the filename of an optional private key to use @param key_filename: the filename, or list of filenames, of optional
for authentication private key(s) to try for authentication
@type key_filename: str @type key_filename: str or list(str)
@param timeout: an optional timeout (in seconds) for the TCP connect @param timeout: an optional timeout (in seconds) for the TCP connect
@type timeout: float @type timeout: float
@param allow_agent: set to False to disable connecting to the SSH agent @param allow_agent: set to False to disable connecting to the SSH agent
@ -306,7 +306,13 @@ class SSHClient (object):
if username is None: if username is None:
username = getpass.getuser() username = getpass.getuser()
self._auth(username, password, pkey, key_filename, allow_agent, look_for_keys) if key_filename is None:
key_filenames = []
elif isinstance(key_filename, (str, unicode)):
key_filenames = [ key_filename ]
else:
key_filenames = key_filename
self._auth(username, password, pkey, key_filenames, allow_agent, look_for_keys)
def close(self): def close(self):
""" """
@ -382,7 +388,7 @@ class SSHClient (object):
""" """
return self._transport return self._transport
def _auth(self, username, password, pkey, key_filename, allow_agent, look_for_keys): def _auth(self, username, password, pkey, key_filenames, allow_agent, look_for_keys):
""" """
Try, in order: Try, in order:
@ -403,7 +409,7 @@ class SSHClient (object):
except SSHException, e: except SSHException, e:
saved_exception = e saved_exception = e
if key_filename is not None: for key_filename in key_filenames:
for pkey_class in (RSAKey, DSSKey): for pkey_class in (RSAKey, DSSKey):
try: try:
key = pkey_class.from_private_key_file(key_filename, password) key = pkey_class.from_private_key_file(key_filename, password)

View File

@ -149,7 +149,24 @@ class SSHClientTest (unittest.TestCase):
stdout.close() stdout.close()
stderr.close() stderr.close()
def test_3_auto_add_policy(self): def test_3_multiple_key_files(self):
"""
verify that SSHClient accepts and tries multiple key files.
"""
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = paramiko.RSAKey(data=str(host_key))
self.tc = paramiko.SSHClient()
self.tc.get_host_keys().add(self.addr, 'ssh-rsa', public_host_key)
self.tc.connect(self.addr, self.port, username='slowdive', key_filename=[ 'tests/test_rsa.key', 'tests/test_dss.key' ])
self.event.wait(1.0)
self.assert_(self.event.isSet())
self.assert_(self.ts.is_active())
self.assertEquals('slowdive', self.ts.get_username())
self.assertEquals(True, self.ts.is_authenticated())
def test_4_auto_add_policy(self):
""" """
verify that SSHClient's AutoAddPolicy works. verify that SSHClient's AutoAddPolicy works.
""" """
@ -169,7 +186,7 @@ class SSHClientTest (unittest.TestCase):
self.assertEquals(1, len(self.tc.get_host_keys())) self.assertEquals(1, len(self.tc.get_host_keys()))
self.assertEquals(public_host_key, self.tc.get_host_keys()[self.addr]['ssh-rsa']) self.assertEquals(public_host_key, self.tc.get_host_keys()[self.addr]['ssh-rsa'])
def test_4_cleanup(self): def test_5_cleanup(self):
""" """
verify that when an SSHClient is collected, its transport (and the verify that when an SSHClient is collected, its transport (and the
transport's packetizer) is closed. transport's packetizer) is closed.