From e2d83576224825d6a727d7675da6f712d354ad50 Mon Sep 17 00:00:00 2001 From: Robey Pointer Date: Sat, 11 Nov 2006 21:56:17 -0800 Subject: [PATCH] [project @ robey@lag.net-20061112055617-098a150cf051bffa] try a trick that should let 'hostkeys[hostname][keytype] = key' work for HostKeys objects again. --- paramiko/hostkeys.py | 41 +++++++++++++++++++++++++++++++++-------- tests/test_hostkeys.py | 3 +-- 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/paramiko/hostkeys.py b/paramiko/hostkeys.py index 0a9c10b..316acc6 100644 --- a/paramiko/hostkeys.py +++ b/paramiko/hostkeys.py @@ -188,18 +188,43 @@ class HostKeys (UserDict.DictMixin): @return: keys associated with this host (or C{None}) @rtype: dict(str, L{PKey}) """ - ret = {} - valid = False + class SubDict (UserDict.DictMixin): + def __init__(self, hostname, entries, hostkeys): + self._hostname = hostname + self._entries = entries + self._hostkeys = hostkeys + + def __getitem__(self, key): + for e in self._entries: + if e.key.get_name() == key: + return e.key + raise KeyError(key) + + def __setitem__(self, key, val): + for e in self._entries: + if e.key is None: + continue + if e.key.get_name() == key: + # replace + e.key = val + break + else: + # add a new one + e = HostKeyEntry([hostname], val) + self._entries.append(e) + self._hostkeys._entries.append(e) + + def keys(self): + return [e.key.get_name() for e in self._entries if e.key is not None] + + entries = [] for e in self._entries: for h in e.hostnames: if (h.startswith('|1|') and (self.hash_host(hostname, h) == h)) or (h == hostname): - valid = True - if e.key is None: - continue - ret[e.key.get_name()] = e.key - if not valid: + entries.append(e) + if len(entries) == 0: return None - return ret + return SubDict(hostname, entries, self) def check(self, hostname, key): """ diff --git a/tests/test_hostkeys.py b/tests/test_hostkeys.py index e9580dd..2430357 100644 --- a/tests/test_hostkeys.py +++ b/tests/test_hostkeys.py @@ -105,13 +105,12 @@ class HostKeysTest (unittest.TestCase): 'ssh-dss': key_dss } hostdict['fake.example.com'] = {} - # this line will have no effect, but at least shouldn't crash: hostdict['fake.example.com']['ssh-rsa'] = key self.assertEquals(3, len(hostdict)) self.assertEquals(2, len(hostdict.values()[0])) self.assertEquals(1, len(hostdict.values()[1])) - self.assertEquals(0, len(hostdict.values()[2])) + self.assertEquals(1, len(hostdict.values()[2])) fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp) fp = hexlify(hostdict['secure.example.com']['ssh-dss'].get_fingerprint()).upper()