improve HostKeys so that it more correctly emulates a dict, and add a unit test to verify that
This commit is contained in:
Robey Pointer 2006-03-09 00:04:50 -08:00
parent 90a577c775
commit de6315b9c5
2 changed files with 28 additions and 16 deletions

View File

@ -22,13 +22,14 @@ L{HostKeys}
import base64
from Crypto.Hash import SHA, HMAC
import UserDict
from paramiko.common import *
from paramiko.dsskey import DSSKey
from paramiko.rsakey import RSAKey
class HostKeys (object):
class HostKeys (UserDict.DictMixin):
"""
Representation of an openssh-style "known hosts" file. Host keys can be
read from one or more files, and then individual hosts can be looked up to
@ -49,7 +50,7 @@ class HostKeys (object):
@type filename: str
"""
# hostname -> keytype -> PKey
self.keys = {}
self._keys = {}
self.contains_hashes = False
if filename is not None:
self.load(filename)
@ -66,11 +67,11 @@ class HostKeys (object):
@param key: the key to add
@type key: L{PKey}
"""
if not hostname in self.keys:
self.keys[hostname] = {}
if not hostname in self._keys:
self._keys[hostname] = {}
if hostname.startswith('|1|'):
self.contains_hashes = True
self.keys[hostname][keytype] = key
self._keys[hostname][keytype] = key
def load(self, filename):
"""
@ -110,15 +111,15 @@ class HostKeys (object):
@return: keys associated with this host (or C{None})
@rtype: dict(str, L{PKey})
"""
if hostname in self.keys:
return self.keys[hostname]
if hostname in self._keys:
return self._keys[hostname]
if not self.contains_hashes:
return None
for h in self.keys.keys():
for h in self._keys.keys():
if h.startswith('|1|'):
hmac = self.hash_host(hostname, h)
if hmac == h:
return self.keys[h]
return self._keys[h]
return None
def check(self, hostname, key):
@ -146,21 +147,21 @@ class HostKeys (object):
"""
Remove all host keys from the dictionary.
"""
self.keys = {}
self._keys = {}
self.contains_hashes = False
def values(self):
return self.keys.values();
def __getitem__(self, key):
ret = self.lookup(key)
if ret is None:
raise KeyError(key)
return ret
def __len__(self):
return len(self.keys)
def keys(self):
return self._keys.keys()
def values(self):
return self._keys.values();
def hash_host(hostname, salt=None):
"""
Return a "hashed" form of the hostname, as used by openssh when storing

View File

@ -71,3 +71,14 @@ class HostKeysTest (unittest.TestCase):
fp = paramiko.util.hexify(x['ssh-rsa'].get_fingerprint())
self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp)
self.assertTrue(hostdict.check('foo.example.com', key))
def test_3_dict(self):
hostdict = paramiko.HostKeys('hostfile.temp')
self.assert_('secure.example.com' in hostdict)
self.assert_('not.example.com' not in hostdict)
self.assert_(hostdict.has_key('secure.example.com'))
self.assert_(not hostdict.has_key('not.example.com'))
x = hostdict.get('secure.example.com', None)
self.assertTrue(x is not None)
fp = paramiko.util.hexify(x['ssh-rsa'].get_fingerprint())
self.assertEquals('E6684DB30E109B67B70FF1DC5C7F1363', fp)