bug 426925: lookup host keys correctly when they have a different port.
This commit is contained in:
parent
c628faa102
commit
71e872e23a
|
@ -36,6 +36,8 @@ from paramiko.ssh_exception import SSHException, BadHostKeyException
|
||||||
from paramiko.transport import Transport
|
from paramiko.transport import Transport
|
||||||
|
|
||||||
|
|
||||||
|
SSH_PORT = 22
|
||||||
|
|
||||||
class MissingHostKeyPolicy (object):
|
class MissingHostKeyPolicy (object):
|
||||||
"""
|
"""
|
||||||
Interface for defining the policy that L{SSHClient} should use when the
|
Interface for defining the policy that L{SSHClient} should use when the
|
||||||
|
@ -223,7 +225,7 @@ class SSHClient (object):
|
||||||
"""
|
"""
|
||||||
self._policy = policy
|
self._policy = policy
|
||||||
|
|
||||||
def connect(self, hostname, port=22, username=None, password=None, pkey=None,
|
def connect(self, hostname, port=SSH_PORT, username=None, password=None, pkey=None,
|
||||||
key_filename=None, timeout=None, allow_agent=True, look_for_keys=True):
|
key_filename=None, timeout=None, allow_agent=True, look_for_keys=True):
|
||||||
"""
|
"""
|
||||||
Connect to an SSH server and authenticate to it. The server's host key
|
Connect to an SSH server and authenticate to it. The server's host key
|
||||||
|
@ -297,12 +299,16 @@ class SSHClient (object):
|
||||||
server_key = t.get_remote_server_key()
|
server_key = t.get_remote_server_key()
|
||||||
keytype = server_key.get_name()
|
keytype = server_key.get_name()
|
||||||
|
|
||||||
our_server_key = self._system_host_keys.get(hostname, {}).get(keytype, None)
|
if port == SSH_PORT:
|
||||||
|
server_hostkey_name = hostname
|
||||||
|
else:
|
||||||
|
server_hostkey_name = "[%s]:%d" % (hostname, port)
|
||||||
|
our_server_key = self._system_host_keys.get(server_hostkey_name, {}).get(keytype, None)
|
||||||
if our_server_key is None:
|
if our_server_key is None:
|
||||||
our_server_key = self._host_keys.get(hostname, {}).get(keytype, None)
|
our_server_key = self._host_keys.get(server_hostkey_name, {}).get(keytype, None)
|
||||||
if our_server_key is None:
|
if our_server_key is None:
|
||||||
# will raise exception if the key is rejected; let that fall out
|
# will raise exception if the key is rejected; let that fall out
|
||||||
self._policy.missing_host_key(self, hostname, server_key)
|
self._policy.missing_host_key(self, server_hostkey_name, server_key)
|
||||||
# if the callback returns, assume the key is ok
|
# if the callback returns, assume the key is ok
|
||||||
our_server_key = server_key
|
our_server_key = server_key
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ import paramiko
|
||||||
|
|
||||||
|
|
||||||
class NullServer (paramiko.ServerInterface):
|
class NullServer (paramiko.ServerInterface):
|
||||||
|
|
||||||
def get_allowed_auths(self, username):
|
def get_allowed_auths(self, username):
|
||||||
if username == 'slowdive':
|
if username == 'slowdive':
|
||||||
return 'publickey,password'
|
return 'publickey,password'
|
||||||
|
@ -46,7 +46,7 @@ class NullServer (paramiko.ServerInterface):
|
||||||
if (key.get_name() == 'ssh-dss') and (hexlify(key.get_fingerprint()) == '4478f0b9a23cc5182009ff755bc1d26c'):
|
if (key.get_name() == 'ssh-dss') and (hexlify(key.get_fingerprint()) == '4478f0b9a23cc5182009ff755bc1d26c'):
|
||||||
return paramiko.AUTH_SUCCESSFUL
|
return paramiko.AUTH_SUCCESSFUL
|
||||||
return paramiko.AUTH_FAILED
|
return paramiko.AUTH_FAILED
|
||||||
|
|
||||||
def check_channel_request(self, kind, chanid):
|
def check_channel_request(self, kind, chanid):
|
||||||
return paramiko.OPEN_SUCCEEDED
|
return paramiko.OPEN_SUCCEEDED
|
||||||
|
|
||||||
|
@ -81,17 +81,17 @@ class SSHClientTest (unittest.TestCase):
|
||||||
self.ts.add_server_key(host_key)
|
self.ts.add_server_key(host_key)
|
||||||
server = NullServer()
|
server = NullServer()
|
||||||
self.ts.start_server(self.event, server)
|
self.ts.start_server(self.event, server)
|
||||||
|
|
||||||
|
|
||||||
def test_1_client(self):
|
def test_1_client(self):
|
||||||
"""
|
"""
|
||||||
verify that the SSHClient stuff works too.
|
verify that the SSHClient stuff works too.
|
||||||
"""
|
"""
|
||||||
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
|
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
|
||||||
public_host_key = paramiko.RSAKey(data=str(host_key))
|
public_host_key = paramiko.RSAKey(data=str(host_key))
|
||||||
|
|
||||||
self.tc = paramiko.SSHClient()
|
self.tc = paramiko.SSHClient()
|
||||||
self.tc.get_host_keys().add(self.addr, 'ssh-rsa', public_host_key)
|
self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key)
|
||||||
self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')
|
self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')
|
||||||
|
|
||||||
self.event.wait(1.0)
|
self.event.wait(1.0)
|
||||||
|
@ -111,20 +111,20 @@ class SSHClientTest (unittest.TestCase):
|
||||||
self.assertEquals('', stdout.readline())
|
self.assertEquals('', stdout.readline())
|
||||||
self.assertEquals('This is on stderr.\n', stderr.readline())
|
self.assertEquals('This is on stderr.\n', stderr.readline())
|
||||||
self.assertEquals('', stderr.readline())
|
self.assertEquals('', stderr.readline())
|
||||||
|
|
||||||
stdin.close()
|
stdin.close()
|
||||||
stdout.close()
|
stdout.close()
|
||||||
stderr.close()
|
stderr.close()
|
||||||
|
|
||||||
def test_2_client_dsa(self):
|
def test_2_client_dsa(self):
|
||||||
"""
|
"""
|
||||||
verify that SSHClient works with a DSA key.
|
verify that SSHClient works with a DSA key.
|
||||||
"""
|
"""
|
||||||
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
|
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
|
||||||
public_host_key = paramiko.RSAKey(data=str(host_key))
|
public_host_key = paramiko.RSAKey(data=str(host_key))
|
||||||
|
|
||||||
self.tc = paramiko.SSHClient()
|
self.tc = paramiko.SSHClient()
|
||||||
self.tc.get_host_keys().add(self.addr, 'ssh-rsa', public_host_key)
|
self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key)
|
||||||
self.tc.connect(self.addr, self.port, username='slowdive', key_filename='tests/test_dss.key')
|
self.tc.connect(self.addr, self.port, username='slowdive', key_filename='tests/test_dss.key')
|
||||||
|
|
||||||
self.event.wait(1.0)
|
self.event.wait(1.0)
|
||||||
|
@ -144,20 +144,20 @@ class SSHClientTest (unittest.TestCase):
|
||||||
self.assertEquals('', stdout.readline())
|
self.assertEquals('', stdout.readline())
|
||||||
self.assertEquals('This is on stderr.\n', stderr.readline())
|
self.assertEquals('This is on stderr.\n', stderr.readline())
|
||||||
self.assertEquals('', stderr.readline())
|
self.assertEquals('', stderr.readline())
|
||||||
|
|
||||||
stdin.close()
|
stdin.close()
|
||||||
stdout.close()
|
stdout.close()
|
||||||
stderr.close()
|
stderr.close()
|
||||||
|
|
||||||
def test_3_multiple_key_files(self):
|
def test_3_multiple_key_files(self):
|
||||||
"""
|
"""
|
||||||
verify that SSHClient accepts and tries multiple key files.
|
verify that SSHClient accepts and tries multiple key files.
|
||||||
"""
|
"""
|
||||||
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
|
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
|
||||||
public_host_key = paramiko.RSAKey(data=str(host_key))
|
public_host_key = paramiko.RSAKey(data=str(host_key))
|
||||||
|
|
||||||
self.tc = paramiko.SSHClient()
|
self.tc = paramiko.SSHClient()
|
||||||
self.tc.get_host_keys().add(self.addr, 'ssh-rsa', public_host_key)
|
self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key)
|
||||||
self.tc.connect(self.addr, self.port, username='slowdive', key_filename=[ 'tests/test_rsa.key', 'tests/test_dss.key' ])
|
self.tc.connect(self.addr, self.port, username='slowdive', key_filename=[ 'tests/test_rsa.key', 'tests/test_dss.key' ])
|
||||||
|
|
||||||
self.event.wait(1.0)
|
self.event.wait(1.0)
|
||||||
|
@ -172,7 +172,7 @@ class SSHClientTest (unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
|
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
|
||||||
public_host_key = paramiko.RSAKey(data=str(host_key))
|
public_host_key = paramiko.RSAKey(data=str(host_key))
|
||||||
|
|
||||||
self.tc = paramiko.SSHClient()
|
self.tc = paramiko.SSHClient()
|
||||||
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||||
self.assertEquals(0, len(self.tc.get_host_keys()))
|
self.assertEquals(0, len(self.tc.get_host_keys()))
|
||||||
|
@ -184,7 +184,7 @@ class SSHClientTest (unittest.TestCase):
|
||||||
self.assertEquals('slowdive', self.ts.get_username())
|
self.assertEquals('slowdive', self.ts.get_username())
|
||||||
self.assertEquals(True, self.ts.is_authenticated())
|
self.assertEquals(True, self.ts.is_authenticated())
|
||||||
self.assertEquals(1, len(self.tc.get_host_keys()))
|
self.assertEquals(1, len(self.tc.get_host_keys()))
|
||||||
self.assertEquals(public_host_key, self.tc.get_host_keys()[self.addr]['ssh-rsa'])
|
self.assertEquals(public_host_key, self.tc.get_host_keys()['[%s]:%d' % (self.addr, self.port)]['ssh-rsa'])
|
||||||
|
|
||||||
def test_5_cleanup(self):
|
def test_5_cleanup(self):
|
||||||
"""
|
"""
|
||||||
|
@ -193,7 +193,7 @@ class SSHClientTest (unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
|
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
|
||||||
public_host_key = paramiko.RSAKey(data=str(host_key))
|
public_host_key = paramiko.RSAKey(data=str(host_key))
|
||||||
|
|
||||||
self.tc = paramiko.SSHClient()
|
self.tc = paramiko.SSHClient()
|
||||||
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||||
self.assertEquals(0, len(self.tc.get_host_keys()))
|
self.assertEquals(0, len(self.tc.get_host_keys()))
|
||||||
|
@ -202,7 +202,7 @@ class SSHClientTest (unittest.TestCase):
|
||||||
self.event.wait(1.0)
|
self.event.wait(1.0)
|
||||||
self.assert_(self.event.isSet())
|
self.assert_(self.event.isSet())
|
||||||
self.assert_(self.ts.is_active())
|
self.assert_(self.ts.is_active())
|
||||||
|
|
||||||
p = weakref.ref(self.tc._transport.packetizer)
|
p = weakref.ref(self.tc._transport.packetizer)
|
||||||
self.assert_(p() is not None)
|
self.assert_(p() is not None)
|
||||||
del self.tc
|
del self.tc
|
||||||
|
@ -211,4 +211,4 @@ class SSHClientTest (unittest.TestCase):
|
||||||
while (time.time() - st < 5.0) and (p() is not None):
|
while (time.time() - st < 5.0) and (p() is not None):
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
self.assert_(p() is None)
|
self.assert_(p() is None)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue