From de6315b9c5813ef846a3e97cb815c12505254fc9 Mon Sep 17 00:00:00 2001 From: Robey Pointer Date: Thu, 9 Mar 2006 00:04:50 -0800 Subject: [PATCH] [project @ robey@lag.net-20060309080450-bad95b03d60d3d4f] improve HostKeys so that it more correctly emulates a dict, and add a unit test to verify that --- paramiko/hostkeys.py | 33 +++++++++++++++++---------------- tests/test_hostkeys.py | 11 +++++++++++ 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/paramiko/hostkeys.py b/paramiko/hostkeys.py index 5ef2716..fae25c2 100644 --- a/paramiko/hostkeys.py +++ b/paramiko/hostkeys.py @@ -22,13 +22,14 @@ L{HostKeys} import base64 from Crypto.Hash import SHA, HMAC +import UserDict from paramiko.common import * from paramiko.dsskey import DSSKey from paramiko.rsakey import RSAKey -class HostKeys (object): +class HostKeys (UserDict.DictMixin): """ Representation of an openssh-style "known hosts" file. Host keys can be read from one or more files, and then individual hosts can be looked up to @@ -49,7 +50,7 @@ class HostKeys (object): @type filename: str """ # hostname -> keytype -> PKey - self.keys = {} + self._keys = {} self.contains_hashes = False if filename is not None: self.load(filename) @@ -66,11 +67,11 @@ class HostKeys (object): @param key: the key to add @type key: L{PKey} """ - if not hostname in self.keys: - self.keys[hostname] = {} + if not hostname in self._keys: + self._keys[hostname] = {} if hostname.startswith('|1|'): self.contains_hashes = True - self.keys[hostname][keytype] = key + self._keys[hostname][keytype] = key def load(self, filename): """ @@ -110,15 +111,15 @@ class HostKeys (object): @return: keys associated with this host (or C{None}) @rtype: dict(str, L{PKey}) """ - if hostname in self.keys: - return self.keys[hostname] + if hostname in self._keys: + return self._keys[hostname] if not self.contains_hashes: return None - for h in self.keys.keys(): + for h in self._keys.keys(): if h.startswith('|1|'): hmac = self.hash_host(hostname, h) if hmac == h: - return self.keys[h] + return self._keys[h] return None def check(self, hostname, key): @@ -146,21 +147,21 @@ class HostKeys (object): """ Remove all host keys from the dictionary. """ - self.keys = {} + self._keys = {} self.contains_hashes = False - def values(self): - return self.keys.values(); - def __getitem__(self, key): ret = self.lookup(key) if ret is None: raise KeyError(key) return ret - def __len__(self): - return len(self.keys) - + def keys(self): + return self._keys.keys() + + def values(self): + return self._keys.values(); + def hash_host(hostname, salt=None): """ Return a "hashed" form of the hostname, as used by openssh when storing diff --git a/tests/test_hostkeys.py b/tests/test_hostkeys.py index 1342638..6f8eb57 100644 --- a/tests/test_hostkeys.py +++ b/tests/test_hostkeys.py @@ -71,3 +71,14 @@ class HostKeysTest (unittest.TestCase): fp = paramiko.util.hexify(x['ssh-rsa'].get_fingerprint()) self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp) self.assertTrue(hostdict.check('foo.example.com', key)) + + def test_3_dict(self): + hostdict = paramiko.HostKeys('hostfile.temp') + self.assert_('secure.example.com' in hostdict) + self.assert_('not.example.com' not in hostdict) + self.assert_(hostdict.has_key('secure.example.com')) + self.assert_(not hostdict.has_key('not.example.com')) + x = hostdict.get('secure.example.com', None) + self.assertTrue(x is not None) + fp = paramiko.util.hexify(x['ssh-rsa'].get_fingerprint()) + self.assertEquals('E6684DB30E109B67B70FF1DC5C7F1363', fp)