improve HostKeys so that it more correctly emulates a dict, and add a unit test to verify that
This commit is contained in:
parent
90a577c775
commit
de6315b9c5
|
@ -22,13 +22,14 @@ L{HostKeys}
|
|||
|
||||
import base64
|
||||
from Crypto.Hash import SHA, HMAC
|
||||
import UserDict
|
||||
|
||||
from paramiko.common import *
|
||||
from paramiko.dsskey import DSSKey
|
||||
from paramiko.rsakey import RSAKey
|
||||
|
||||
|
||||
class HostKeys (object):
|
||||
class HostKeys (UserDict.DictMixin):
|
||||
"""
|
||||
Representation of an openssh-style "known hosts" file. Host keys can be
|
||||
read from one or more files, and then individual hosts can be looked up to
|
||||
|
@ -49,7 +50,7 @@ class HostKeys (object):
|
|||
@type filename: str
|
||||
"""
|
||||
# hostname -> keytype -> PKey
|
||||
self.keys = {}
|
||||
self._keys = {}
|
||||
self.contains_hashes = False
|
||||
if filename is not None:
|
||||
self.load(filename)
|
||||
|
@ -66,11 +67,11 @@ class HostKeys (object):
|
|||
@param key: the key to add
|
||||
@type key: L{PKey}
|
||||
"""
|
||||
if not hostname in self.keys:
|
||||
self.keys[hostname] = {}
|
||||
if not hostname in self._keys:
|
||||
self._keys[hostname] = {}
|
||||
if hostname.startswith('|1|'):
|
||||
self.contains_hashes = True
|
||||
self.keys[hostname][keytype] = key
|
||||
self._keys[hostname][keytype] = key
|
||||
|
||||
def load(self, filename):
|
||||
"""
|
||||
|
@ -110,15 +111,15 @@ class HostKeys (object):
|
|||
@return: keys associated with this host (or C{None})
|
||||
@rtype: dict(str, L{PKey})
|
||||
"""
|
||||
if hostname in self.keys:
|
||||
return self.keys[hostname]
|
||||
if hostname in self._keys:
|
||||
return self._keys[hostname]
|
||||
if not self.contains_hashes:
|
||||
return None
|
||||
for h in self.keys.keys():
|
||||
for h in self._keys.keys():
|
||||
if h.startswith('|1|'):
|
||||
hmac = self.hash_host(hostname, h)
|
||||
if hmac == h:
|
||||
return self.keys[h]
|
||||
return self._keys[h]
|
||||
return None
|
||||
|
||||
def check(self, hostname, key):
|
||||
|
@ -146,20 +147,20 @@ class HostKeys (object):
|
|||
"""
|
||||
Remove all host keys from the dictionary.
|
||||
"""
|
||||
self.keys = {}
|
||||
self._keys = {}
|
||||
self.contains_hashes = False
|
||||
|
||||
def values(self):
|
||||
return self.keys.values();
|
||||
|
||||
def __getitem__(self, key):
|
||||
ret = self.lookup(key)
|
||||
if ret is None:
|
||||
raise KeyError(key)
|
||||
return ret
|
||||
|
||||
def __len__(self):
|
||||
return len(self.keys)
|
||||
def keys(self):
|
||||
return self._keys.keys()
|
||||
|
||||
def values(self):
|
||||
return self._keys.values();
|
||||
|
||||
def hash_host(hostname, salt=None):
|
||||
"""
|
||||
|
|
|
@ -71,3 +71,14 @@ class HostKeysTest (unittest.TestCase):
|
|||
fp = paramiko.util.hexify(x['ssh-rsa'].get_fingerprint())
|
||||
self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp)
|
||||
self.assertTrue(hostdict.check('foo.example.com', key))
|
||||
|
||||
def test_3_dict(self):
|
||||
hostdict = paramiko.HostKeys('hostfile.temp')
|
||||
self.assert_('secure.example.com' in hostdict)
|
||||
self.assert_('not.example.com' not in hostdict)
|
||||
self.assert_(hostdict.has_key('secure.example.com'))
|
||||
self.assert_(not hostdict.has_key('not.example.com'))
|
||||
x = hostdict.get('secure.example.com', None)
|
||||
self.assertTrue(x is not None)
|
||||
fp = paramiko.util.hexify(x['ssh-rsa'].get_fingerprint())
|
||||
self.assertEquals('E6684DB30E109B67B70FF1DC5C7F1363', fp)
|
||||
|
|
Loading…
Reference in New Issue