try a trick that should let 'hostkeys[hostname][keytype] = key' work for
HostKeys objects again.
This commit is contained in:
Robey Pointer 2006-11-11 21:56:17 -08:00
parent bee3535484
commit e2d8357622
2 changed files with 34 additions and 10 deletions

View File

@ -188,18 +188,43 @@ class HostKeys (UserDict.DictMixin):
@return: keys associated with this host (or C{None}) @return: keys associated with this host (or C{None})
@rtype: dict(str, L{PKey}) @rtype: dict(str, L{PKey})
""" """
ret = {} class SubDict (UserDict.DictMixin):
valid = False def __init__(self, hostname, entries, hostkeys):
self._hostname = hostname
self._entries = entries
self._hostkeys = hostkeys
def __getitem__(self, key):
for e in self._entries:
if e.key.get_name() == key:
return e.key
raise KeyError(key)
def __setitem__(self, key, val):
for e in self._entries:
if e.key is None:
continue
if e.key.get_name() == key:
# replace
e.key = val
break
else:
# add a new one
e = HostKeyEntry([hostname], val)
self._entries.append(e)
self._hostkeys._entries.append(e)
def keys(self):
return [e.key.get_name() for e in self._entries if e.key is not None]
entries = []
for e in self._entries: for e in self._entries:
for h in e.hostnames: for h in e.hostnames:
if (h.startswith('|1|') and (self.hash_host(hostname, h) == h)) or (h == hostname): if (h.startswith('|1|') and (self.hash_host(hostname, h) == h)) or (h == hostname):
valid = True entries.append(e)
if e.key is None: if len(entries) == 0:
continue
ret[e.key.get_name()] = e.key
if not valid:
return None return None
return ret return SubDict(hostname, entries, self)
def check(self, hostname, key): def check(self, hostname, key):
""" """

View File

@ -105,13 +105,12 @@ class HostKeysTest (unittest.TestCase):
'ssh-dss': key_dss 'ssh-dss': key_dss
} }
hostdict['fake.example.com'] = {} hostdict['fake.example.com'] = {}
# this line will have no effect, but at least shouldn't crash:
hostdict['fake.example.com']['ssh-rsa'] = key hostdict['fake.example.com']['ssh-rsa'] = key
self.assertEquals(3, len(hostdict)) self.assertEquals(3, len(hostdict))
self.assertEquals(2, len(hostdict.values()[0])) self.assertEquals(2, len(hostdict.values()[0]))
self.assertEquals(1, len(hostdict.values()[1])) self.assertEquals(1, len(hostdict.values()[1]))
self.assertEquals(0, len(hostdict.values()[2])) self.assertEquals(1, len(hostdict.values()[2]))
fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper()
self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp) self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp)
fp = hexlify(hostdict['secure.example.com']['ssh-dss'].get_fingerprint()).upper() fp = hexlify(hostdict['secure.example.com']['ssh-dss'].get_fingerprint()).upper()