From 71e872e23ad343370869d1041b2debf32e676265 Mon Sep 17 00:00:00 2001 From: Robey Pointer Date: Sun, 1 Nov 2009 21:28:47 -0800 Subject: [PATCH] bug 426925: lookup host keys correctly when they have a different port. --- paramiko/client.py | 14 ++++++++++---- tests/test_client.py | 38 +++++++++++++++++++------------------- 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/paramiko/client.py b/paramiko/client.py index ce14f66..023b405 100644 --- a/paramiko/client.py +++ b/paramiko/client.py @@ -36,6 +36,8 @@ from paramiko.ssh_exception import SSHException, BadHostKeyException from paramiko.transport import Transport +SSH_PORT = 22 + class MissingHostKeyPolicy (object): """ Interface for defining the policy that L{SSHClient} should use when the @@ -223,7 +225,7 @@ class SSHClient (object): """ self._policy = policy - def connect(self, hostname, port=22, username=None, password=None, pkey=None, + def connect(self, hostname, port=SSH_PORT, username=None, password=None, pkey=None, key_filename=None, timeout=None, allow_agent=True, look_for_keys=True): """ Connect to an SSH server and authenticate to it. The server's host key @@ -297,12 +299,16 @@ class SSHClient (object): server_key = t.get_remote_server_key() keytype = server_key.get_name() - our_server_key = self._system_host_keys.get(hostname, {}).get(keytype, None) + if port == SSH_PORT: + server_hostkey_name = hostname + else: + server_hostkey_name = "[%s]:%d" % (hostname, port) + our_server_key = self._system_host_keys.get(server_hostkey_name, {}).get(keytype, None) if our_server_key is None: - our_server_key = self._host_keys.get(hostname, {}).get(keytype, None) + our_server_key = self._host_keys.get(server_hostkey_name, {}).get(keytype, None) if our_server_key is None: # will raise exception if the key is rejected; let that fall out - self._policy.missing_host_key(self, hostname, server_key) + self._policy.missing_host_key(self, server_hostkey_name, server_key) # if the callback returns, assume the key is ok our_server_key = server_key diff --git a/tests/test_client.py b/tests/test_client.py index f7a1724..2f9b9a7 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -31,7 +31,7 @@ import paramiko class NullServer (paramiko.ServerInterface): - + def get_allowed_auths(self, username): if username == 'slowdive': return 'publickey,password' @@ -46,7 +46,7 @@ class NullServer (paramiko.ServerInterface): if (key.get_name() == 'ssh-dss') and (hexlify(key.get_fingerprint()) == '4478f0b9a23cc5182009ff755bc1d26c'): return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_FAILED - + def check_channel_request(self, kind, chanid): return paramiko.OPEN_SUCCEEDED @@ -81,17 +81,17 @@ class SSHClientTest (unittest.TestCase): self.ts.add_server_key(host_key) server = NullServer() self.ts.start_server(self.event, server) - - + + def test_1_client(self): """ verify that the SSHClient stuff works too. """ host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') public_host_key = paramiko.RSAKey(data=str(host_key)) - + self.tc = paramiko.SSHClient() - self.tc.get_host_keys().add(self.addr, 'ssh-rsa', public_host_key) + self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key) self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion') self.event.wait(1.0) @@ -111,20 +111,20 @@ class SSHClientTest (unittest.TestCase): self.assertEquals('', stdout.readline()) self.assertEquals('This is on stderr.\n', stderr.readline()) self.assertEquals('', stderr.readline()) - + stdin.close() stdout.close() stderr.close() - + def test_2_client_dsa(self): """ verify that SSHClient works with a DSA key. """ host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') public_host_key = paramiko.RSAKey(data=str(host_key)) - + self.tc = paramiko.SSHClient() - self.tc.get_host_keys().add(self.addr, 'ssh-rsa', public_host_key) + self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key) self.tc.connect(self.addr, self.port, username='slowdive', key_filename='tests/test_dss.key') self.event.wait(1.0) @@ -144,20 +144,20 @@ class SSHClientTest (unittest.TestCase): self.assertEquals('', stdout.readline()) self.assertEquals('This is on stderr.\n', stderr.readline()) self.assertEquals('', stderr.readline()) - + stdin.close() stdout.close() stderr.close() - + def test_3_multiple_key_files(self): """ verify that SSHClient accepts and tries multiple key files. """ host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') public_host_key = paramiko.RSAKey(data=str(host_key)) - + self.tc = paramiko.SSHClient() - self.tc.get_host_keys().add(self.addr, 'ssh-rsa', public_host_key) + self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key) self.tc.connect(self.addr, self.port, username='slowdive', key_filename=[ 'tests/test_rsa.key', 'tests/test_dss.key' ]) self.event.wait(1.0) @@ -172,7 +172,7 @@ class SSHClientTest (unittest.TestCase): """ host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') public_host_key = paramiko.RSAKey(data=str(host_key)) - + self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEquals(0, len(self.tc.get_host_keys())) @@ -184,7 +184,7 @@ class SSHClientTest (unittest.TestCase): self.assertEquals('slowdive', self.ts.get_username()) self.assertEquals(True, self.ts.is_authenticated()) self.assertEquals(1, len(self.tc.get_host_keys())) - self.assertEquals(public_host_key, self.tc.get_host_keys()[self.addr]['ssh-rsa']) + self.assertEquals(public_host_key, self.tc.get_host_keys()['[%s]:%d' % (self.addr, self.port)]['ssh-rsa']) def test_5_cleanup(self): """ @@ -193,7 +193,7 @@ class SSHClientTest (unittest.TestCase): """ host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') public_host_key = paramiko.RSAKey(data=str(host_key)) - + self.tc = paramiko.SSHClient() self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.assertEquals(0, len(self.tc.get_host_keys())) @@ -202,7 +202,7 @@ class SSHClientTest (unittest.TestCase): self.event.wait(1.0) self.assert_(self.event.isSet()) self.assert_(self.ts.is_active()) - + p = weakref.ref(self.tc._transport.packetizer) self.assert_(p() is not None) del self.tc @@ -211,4 +211,4 @@ class SSHClientTest (unittest.TestCase): while (time.time() - st < 5.0) and (p() is not None): time.sleep(0.1) self.assert_(p() is None) - +