bug 426925: lookup host keys correctly when they have a different port.

This commit is contained in:
Robey Pointer 2009-11-01 21:28:47 -08:00
parent c628faa102
commit 71e872e23a
2 changed files with 29 additions and 23 deletions

View File

@ -36,6 +36,8 @@ from paramiko.ssh_exception import SSHException, BadHostKeyException
from paramiko.transport import Transport
SSH_PORT = 22
class MissingHostKeyPolicy (object):
"""
Interface for defining the policy that L{SSHClient} should use when the
@ -223,7 +225,7 @@ class SSHClient (object):
"""
self._policy = policy
def connect(self, hostname, port=22, username=None, password=None, pkey=None,
def connect(self, hostname, port=SSH_PORT, username=None, password=None, pkey=None,
key_filename=None, timeout=None, allow_agent=True, look_for_keys=True):
"""
Connect to an SSH server and authenticate to it. The server's host key
@ -297,12 +299,16 @@ class SSHClient (object):
server_key = t.get_remote_server_key()
keytype = server_key.get_name()
our_server_key = self._system_host_keys.get(hostname, {}).get(keytype, None)
if port == SSH_PORT:
server_hostkey_name = hostname
else:
server_hostkey_name = "[%s]:%d" % (hostname, port)
our_server_key = self._system_host_keys.get(server_hostkey_name, {}).get(keytype, None)
if our_server_key is None:
our_server_key = self._host_keys.get(hostname, {}).get(keytype, None)
our_server_key = self._host_keys.get(server_hostkey_name, {}).get(keytype, None)
if our_server_key is None:
# will raise exception if the key is rejected; let that fall out
self._policy.missing_host_key(self, hostname, server_key)
self._policy.missing_host_key(self, server_hostkey_name, server_key)
# if the callback returns, assume the key is ok
our_server_key = server_key

View File

@ -91,7 +91,7 @@ class SSHClientTest (unittest.TestCase):
public_host_key = paramiko.RSAKey(data=str(host_key))
self.tc = paramiko.SSHClient()
self.tc.get_host_keys().add(self.addr, 'ssh-rsa', public_host_key)
self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key)
self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')
self.event.wait(1.0)
@ -124,7 +124,7 @@ class SSHClientTest (unittest.TestCase):
public_host_key = paramiko.RSAKey(data=str(host_key))
self.tc = paramiko.SSHClient()
self.tc.get_host_keys().add(self.addr, 'ssh-rsa', public_host_key)
self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key)
self.tc.connect(self.addr, self.port, username='slowdive', key_filename='tests/test_dss.key')
self.event.wait(1.0)
@ -157,7 +157,7 @@ class SSHClientTest (unittest.TestCase):
public_host_key = paramiko.RSAKey(data=str(host_key))
self.tc = paramiko.SSHClient()
self.tc.get_host_keys().add(self.addr, 'ssh-rsa', public_host_key)
self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key)
self.tc.connect(self.addr, self.port, username='slowdive', key_filename=[ 'tests/test_rsa.key', 'tests/test_dss.key' ])
self.event.wait(1.0)
@ -184,7 +184,7 @@ class SSHClientTest (unittest.TestCase):
self.assertEquals('slowdive', self.ts.get_username())
self.assertEquals(True, self.ts.is_authenticated())
self.assertEquals(1, len(self.tc.get_host_keys()))
self.assertEquals(public_host_key, self.tc.get_host_keys()[self.addr]['ssh-rsa'])
self.assertEquals(public_host_key, self.tc.get_host_keys()['[%s]:%d' % (self.addr, self.port)]['ssh-rsa'])
def test_5_cleanup(self):
"""