From 2b8738d3ce4cdee85c02569927427676465b5e3c Mon Sep 17 00:00:00 2001 From: Robey Pointer Date: Mon, 28 Aug 2006 16:48:34 -0700 Subject: [PATCH] [project @ robey@lag.net-20060828234834-51542dc36057b361] fix __setitem__ to do the right thing --- paramiko/hostkeys.py | 22 +++++++++++++++------- tests/test_hostkeys.py | 25 +++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/paramiko/hostkeys.py b/paramiko/hostkeys.py index c8450b9..3ad0fe5 100644 --- a/paramiko/hostkeys.py +++ b/paramiko/hostkeys.py @@ -81,8 +81,7 @@ class HostKeyEntry: if self.valid: return '%s %s %s\n' % (','.join(self.hostnames), self.key.get_name(), self.key.get_base64()) - else: - return None + return None class HostKeys (UserDict.DictMixin): @@ -230,21 +229,30 @@ class HostKeys (UserDict.DictMixin): raise KeyError(key) return ret - def __setitem__(self, key, value): + def __setitem__(self, hostname, entry): # don't use this please. - self._keys[key] = value + for key_type in entry.keys(): + found = False + for e in self._entries: + if (hostname in e.hostnames) and (e.key.get_name() == key_type): + # replace + e.key = entry[key_type] + found = True + if not found: + self._entries.append(HostKeyEntry([hostname], entry[key_type])) def keys(self): ret = [] for e in self._entries: for h in e.hostnames: - ret.append(h) + if h not in ret: + ret.append(h) return ret def values(self): ret = [] - for e in self._entries: - ret.append({e.key.get_name() : e.key}) + for k in self.keys(): + ret.append(self.lookup(k)) return ret def hash_host(hostname, salt=None): diff --git a/tests/test_hostkeys.py b/tests/test_hostkeys.py index dc43e79..b7d9ae9 100644 --- a/tests/test_hostkeys.py +++ b/tests/test_hostkeys.py @@ -41,6 +41,16 @@ AAAAB3NzaC1yc2EAAAABIwAAAIEA8bP1ZA7DCZDB9J0s50l31MBGQ3GQ/Fc7SX6gkpXkwcZryoi4k\ NFhHu5LvHcZPdxXV1D+uTMfGS1eyd2Yz/DoNWXNAl8TI0cAsW5ymME3bQ4J/k1IKxCtz/bAlAqFgK\ oc+EolMziDYqWIATtW0rYTJvzGAzTmMj80/QpsFH+Pc2M=""" +keyblob_dss = """\ +AAAAB3NzaC1kc3MAAACBAOeBpgNnfRzr/twmAQRu2XwWAp3CFtrVnug6s6fgwj/oLjYbVtjAy6pl/\ +h0EKCWx2rf1IetyNsTxWrniA9I6HeDj65X1FyDkg6g8tvCnaNB8Xp/UUhuzHuGsMIipRxBxw9LF60\ +8EqZcj1E3ytktoW5B5OcjrkEoz3xG7C+rpIjYvAAAAFQDwz4UnmsGiSNu5iqjn3uTzwUpshwAAAIE\ +AkxfFeY8P2wZpDjX0MimZl5wkoFQDL25cPzGBuB4OnB8NoUk/yjAHIIpEShw8V+LzouMK5CTJQo5+\ +Ngw3qIch/WgRmMHy4kBq1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg\ +4Ok10+XFDxlqZo8Y+wAAACARmR7CCPjodxASvRbIyzaVpZoJ/Z6x7dAumV+ysrV1BVYd0lYukmnjO\ +1kKBWApqpH1ve9XDQYN8zgxM4b16L21kpoWQnZtXrY3GZ4/it9kUgyB7+NwacIBlXa8cMDL7Q/69o\ +0d54U0X/NeX5QxuYR6OMJlrkQB7oiW/P/1mwjQgE=""" + class HostKeysTest (unittest.TestCase): @@ -86,3 +96,18 @@ class HostKeysTest (unittest.TestCase): i += 1 self.assertEquals(2, i) + def test_4_dict_set(self): + hostdict = paramiko.HostKeys('hostfile.temp') + key = paramiko.RSAKey(data=base64.decodestring(keyblob)) + key_dss = paramiko.DSSKey(data=base64.decodestring(keyblob_dss)) + hostdict['secure.example.com'] = { + 'ssh-rsa': key, + 'ssh-dss': key_dss + } + self.assertEquals(2, len(hostdict)) + self.assertEquals(2, len(hostdict.values()[0])) + self.assertEquals(1, len(hostdict.values()[1])) + fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() + self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp) + fp = hexlify(hostdict['secure.example.com']['ssh-dss'].get_fingerprint()).upper() + self.assertEquals('4478F0B9A23CC5182009FF755BC1D26C', fp)