diff --git a/paramiko/hostkeys.py b/paramiko/hostkeys.py index 3ad0fe5..0a9c10b 100644 --- a/paramiko/hostkeys.py +++ b/paramiko/hostkeys.py @@ -82,6 +82,9 @@ class HostKeyEntry: return '%s %s %s\n' % (','.join(self.hostnames), self.key.get_name(), self.key.get_base64()) return None + + def __repr__(self): + return '' % (self.hostnames, self.key) class HostKeys (UserDict.DictMixin): @@ -178,7 +181,7 @@ class HostKeys (UserDict.DictMixin): """ Find a hostkey entry for a given hostname or IP. If no entry is found, C{None} is returned. Otherwise a dictionary of keytype to key is - returned. + returned. The keytype will be either C{"ssh-rsa"} or C{"ssh-dss"}. @param hostname: the hostname (or IP) to lookup @type hostname: str @@ -186,13 +189,15 @@ class HostKeys (UserDict.DictMixin): @rtype: dict(str, L{PKey}) """ ret = {} + valid = False for e in self._entries: for h in e.hostnames: - if h.startswith('|1|') and (self.hash_host(hostname, h) == h): + if (h.startswith('|1|') and (self.hash_host(hostname, h) == h)) or (h == hostname): + valid = True + if e.key is None: + continue ret[e.key.get_name()] = e.key - elif h == hostname: - ret[e.key.get_name()] = e.key - if len(ret) == 0: + if not valid: return None return ret @@ -231,6 +236,9 @@ class HostKeys (UserDict.DictMixin): def __setitem__(self, hostname, entry): # don't use this please. + if len(entry) == 0: + self._entries.append(HostKeyEntry([hostname], None)) + return for key_type in entry.keys(): found = False for e in self._entries: @@ -242,6 +250,7 @@ class HostKeys (UserDict.DictMixin): self._entries.append(HostKeyEntry([hostname], entry[key_type])) def keys(self): + # python 2.4 sets would be nice here. ret = [] for e in self._entries: for h in e.hostnames: diff --git a/tests/test_hostkeys.py b/tests/test_hostkeys.py index b7d9ae9..e9580dd 100644 --- a/tests/test_hostkeys.py +++ b/tests/test_hostkeys.py @@ -104,9 +104,14 @@ class HostKeysTest (unittest.TestCase): 'ssh-rsa': key, 'ssh-dss': key_dss } - self.assertEquals(2, len(hostdict)) + hostdict['fake.example.com'] = {} + # this line will have no effect, but at least shouldn't crash: + hostdict['fake.example.com']['ssh-rsa'] = key + + self.assertEquals(3, len(hostdict)) self.assertEquals(2, len(hostdict.values()[0])) self.assertEquals(1, len(hostdict.values()[1])) + self.assertEquals(0, len(hostdict.values()[2])) fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp) fp = hexlify(hostdict['secure.example.com']['ssh-dss'].get_fingerprint()).upper()