diff --git a/paramiko/hostkeys.py b/paramiko/hostkeys.py index 0a9c10b..316acc6 100644 --- a/paramiko/hostkeys.py +++ b/paramiko/hostkeys.py @@ -188,18 +188,43 @@ class HostKeys (UserDict.DictMixin): @return: keys associated with this host (or C{None}) @rtype: dict(str, L{PKey}) """ - ret = {} - valid = False + class SubDict (UserDict.DictMixin): + def __init__(self, hostname, entries, hostkeys): + self._hostname = hostname + self._entries = entries + self._hostkeys = hostkeys + + def __getitem__(self, key): + for e in self._entries: + if e.key.get_name() == key: + return e.key + raise KeyError(key) + + def __setitem__(self, key, val): + for e in self._entries: + if e.key is None: + continue + if e.key.get_name() == key: + # replace + e.key = val + break + else: + # add a new one + e = HostKeyEntry([hostname], val) + self._entries.append(e) + self._hostkeys._entries.append(e) + + def keys(self): + return [e.key.get_name() for e in self._entries if e.key is not None] + + entries = [] for e in self._entries: for h in e.hostnames: if (h.startswith('|1|') and (self.hash_host(hostname, h) == h)) or (h == hostname): - valid = True - if e.key is None: - continue - ret[e.key.get_name()] = e.key - if not valid: + entries.append(e) + if len(entries) == 0: return None - return ret + return SubDict(hostname, entries, self) def check(self, hostname, key): """ diff --git a/tests/test_hostkeys.py b/tests/test_hostkeys.py index e9580dd..2430357 100644 --- a/tests/test_hostkeys.py +++ b/tests/test_hostkeys.py @@ -105,13 +105,12 @@ class HostKeysTest (unittest.TestCase): 'ssh-dss': key_dss } hostdict['fake.example.com'] = {} - # this line will have no effect, but at least shouldn't crash: hostdict['fake.example.com']['ssh-rsa'] = key self.assertEquals(3, len(hostdict)) self.assertEquals(2, len(hostdict.values()[0])) self.assertEquals(1, len(hostdict.values()[1])) - self.assertEquals(0, len(hostdict.values()[2])) + self.assertEquals(1, len(hostdict.values()[2])) fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp) fp = hexlify(hostdict['secure.example.com']['ssh-dss'].get_fingerprint()).upper()