allow multiple key files to be specified in SSHClient. suggested by Bernhard Walle.
This commit is contained in:
parent
305f5e09a5
commit
c2ef48cf18
|
@ -255,9 +255,9 @@ class SSHClient (object):
|
||||||
@type password: str
|
@type password: str
|
||||||
@param pkey: an optional private key to use for authentication
|
@param pkey: an optional private key to use for authentication
|
||||||
@type pkey: L{PKey}
|
@type pkey: L{PKey}
|
||||||
@param key_filename: the filename of an optional private key to use
|
@param key_filename: the filename, or list of filenames, of optional
|
||||||
for authentication
|
private key(s) to try for authentication
|
||||||
@type key_filename: str
|
@type key_filename: str or list(str)
|
||||||
@param timeout: an optional timeout (in seconds) for the TCP connect
|
@param timeout: an optional timeout (in seconds) for the TCP connect
|
||||||
@type timeout: float
|
@type timeout: float
|
||||||
@param allow_agent: set to False to disable connecting to the SSH agent
|
@param allow_agent: set to False to disable connecting to the SSH agent
|
||||||
|
@ -306,7 +306,13 @@ class SSHClient (object):
|
||||||
if username is None:
|
if username is None:
|
||||||
username = getpass.getuser()
|
username = getpass.getuser()
|
||||||
|
|
||||||
self._auth(username, password, pkey, key_filename, allow_agent, look_for_keys)
|
if key_filename is None:
|
||||||
|
key_filenames = []
|
||||||
|
elif isinstance(key_filename, (str, unicode)):
|
||||||
|
key_filenames = [ key_filename ]
|
||||||
|
else:
|
||||||
|
key_filenames = key_filename
|
||||||
|
self._auth(username, password, pkey, key_filenames, allow_agent, look_for_keys)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""
|
"""
|
||||||
|
@ -382,7 +388,7 @@ class SSHClient (object):
|
||||||
"""
|
"""
|
||||||
return self._transport
|
return self._transport
|
||||||
|
|
||||||
def _auth(self, username, password, pkey, key_filename, allow_agent, look_for_keys):
|
def _auth(self, username, password, pkey, key_filenames, allow_agent, look_for_keys):
|
||||||
"""
|
"""
|
||||||
Try, in order:
|
Try, in order:
|
||||||
|
|
||||||
|
@ -403,7 +409,7 @@ class SSHClient (object):
|
||||||
except SSHException, e:
|
except SSHException, e:
|
||||||
saved_exception = e
|
saved_exception = e
|
||||||
|
|
||||||
if key_filename is not None:
|
for key_filename in key_filenames:
|
||||||
for pkey_class in (RSAKey, DSSKey):
|
for pkey_class in (RSAKey, DSSKey):
|
||||||
try:
|
try:
|
||||||
key = pkey_class.from_private_key_file(key_filename, password)
|
key = pkey_class.from_private_key_file(key_filename, password)
|
||||||
|
|
|
@ -148,8 +148,25 @@ class SSHClientTest (unittest.TestCase):
|
||||||
stdin.close()
|
stdin.close()
|
||||||
stdout.close()
|
stdout.close()
|
||||||
stderr.close()
|
stderr.close()
|
||||||
|
|
||||||
|
def test_3_multiple_key_files(self):
|
||||||
|
"""
|
||||||
|
verify that SSHClient accepts and tries multiple key files.
|
||||||
|
"""
|
||||||
|
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
|
||||||
|
public_host_key = paramiko.RSAKey(data=str(host_key))
|
||||||
|
|
||||||
|
self.tc = paramiko.SSHClient()
|
||||||
|
self.tc.get_host_keys().add(self.addr, 'ssh-rsa', public_host_key)
|
||||||
|
self.tc.connect(self.addr, self.port, username='slowdive', key_filename=[ 'tests/test_rsa.key', 'tests/test_dss.key' ])
|
||||||
|
|
||||||
def test_3_auto_add_policy(self):
|
self.event.wait(1.0)
|
||||||
|
self.assert_(self.event.isSet())
|
||||||
|
self.assert_(self.ts.is_active())
|
||||||
|
self.assertEquals('slowdive', self.ts.get_username())
|
||||||
|
self.assertEquals(True, self.ts.is_authenticated())
|
||||||
|
|
||||||
|
def test_4_auto_add_policy(self):
|
||||||
"""
|
"""
|
||||||
verify that SSHClient's AutoAddPolicy works.
|
verify that SSHClient's AutoAddPolicy works.
|
||||||
"""
|
"""
|
||||||
|
@ -169,7 +186,7 @@ class SSHClientTest (unittest.TestCase):
|
||||||
self.assertEquals(1, len(self.tc.get_host_keys()))
|
self.assertEquals(1, len(self.tc.get_host_keys()))
|
||||||
self.assertEquals(public_host_key, self.tc.get_host_keys()[self.addr]['ssh-rsa'])
|
self.assertEquals(public_host_key, self.tc.get_host_keys()[self.addr]['ssh-rsa'])
|
||||||
|
|
||||||
def test_4_cleanup(self):
|
def test_5_cleanup(self):
|
||||||
"""
|
"""
|
||||||
verify that when an SSHClient is collected, its transport (and the
|
verify that when an SSHClient is collected, its transport (and the
|
||||||
transport's packetizer) is closed.
|
transport's packetizer) is closed.
|
||||||
|
|
Loading…
Reference in New Issue