improve HostKeys so that it more correctly emulates a dict, and add a unit test to verify that
This commit is contained in:
parent
90a577c775
commit
de6315b9c5
|
@ -22,13 +22,14 @@ L{HostKeys}
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
from Crypto.Hash import SHA, HMAC
|
from Crypto.Hash import SHA, HMAC
|
||||||
|
import UserDict
|
||||||
|
|
||||||
from paramiko.common import *
|
from paramiko.common import *
|
||||||
from paramiko.dsskey import DSSKey
|
from paramiko.dsskey import DSSKey
|
||||||
from paramiko.rsakey import RSAKey
|
from paramiko.rsakey import RSAKey
|
||||||
|
|
||||||
|
|
||||||
class HostKeys (object):
|
class HostKeys (UserDict.DictMixin):
|
||||||
"""
|
"""
|
||||||
Representation of an openssh-style "known hosts" file. Host keys can be
|
Representation of an openssh-style "known hosts" file. Host keys can be
|
||||||
read from one or more files, and then individual hosts can be looked up to
|
read from one or more files, and then individual hosts can be looked up to
|
||||||
|
@ -49,7 +50,7 @@ class HostKeys (object):
|
||||||
@type filename: str
|
@type filename: str
|
||||||
"""
|
"""
|
||||||
# hostname -> keytype -> PKey
|
# hostname -> keytype -> PKey
|
||||||
self.keys = {}
|
self._keys = {}
|
||||||
self.contains_hashes = False
|
self.contains_hashes = False
|
||||||
if filename is not None:
|
if filename is not None:
|
||||||
self.load(filename)
|
self.load(filename)
|
||||||
|
@ -66,11 +67,11 @@ class HostKeys (object):
|
||||||
@param key: the key to add
|
@param key: the key to add
|
||||||
@type key: L{PKey}
|
@type key: L{PKey}
|
||||||
"""
|
"""
|
||||||
if not hostname in self.keys:
|
if not hostname in self._keys:
|
||||||
self.keys[hostname] = {}
|
self._keys[hostname] = {}
|
||||||
if hostname.startswith('|1|'):
|
if hostname.startswith('|1|'):
|
||||||
self.contains_hashes = True
|
self.contains_hashes = True
|
||||||
self.keys[hostname][keytype] = key
|
self._keys[hostname][keytype] = key
|
||||||
|
|
||||||
def load(self, filename):
|
def load(self, filename):
|
||||||
"""
|
"""
|
||||||
|
@ -110,15 +111,15 @@ class HostKeys (object):
|
||||||
@return: keys associated with this host (or C{None})
|
@return: keys associated with this host (or C{None})
|
||||||
@rtype: dict(str, L{PKey})
|
@rtype: dict(str, L{PKey})
|
||||||
"""
|
"""
|
||||||
if hostname in self.keys:
|
if hostname in self._keys:
|
||||||
return self.keys[hostname]
|
return self._keys[hostname]
|
||||||
if not self.contains_hashes:
|
if not self.contains_hashes:
|
||||||
return None
|
return None
|
||||||
for h in self.keys.keys():
|
for h in self._keys.keys():
|
||||||
if h.startswith('|1|'):
|
if h.startswith('|1|'):
|
||||||
hmac = self.hash_host(hostname, h)
|
hmac = self.hash_host(hostname, h)
|
||||||
if hmac == h:
|
if hmac == h:
|
||||||
return self.keys[h]
|
return self._keys[h]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def check(self, hostname, key):
|
def check(self, hostname, key):
|
||||||
|
@ -146,21 +147,21 @@ class HostKeys (object):
|
||||||
"""
|
"""
|
||||||
Remove all host keys from the dictionary.
|
Remove all host keys from the dictionary.
|
||||||
"""
|
"""
|
||||||
self.keys = {}
|
self._keys = {}
|
||||||
self.contains_hashes = False
|
self.contains_hashes = False
|
||||||
|
|
||||||
def values(self):
|
|
||||||
return self.keys.values();
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
ret = self.lookup(key)
|
ret = self.lookup(key)
|
||||||
if ret is None:
|
if ret is None:
|
||||||
raise KeyError(key)
|
raise KeyError(key)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def __len__(self):
|
def keys(self):
|
||||||
return len(self.keys)
|
return self._keys.keys()
|
||||||
|
|
||||||
|
def values(self):
|
||||||
|
return self._keys.values();
|
||||||
|
|
||||||
def hash_host(hostname, salt=None):
|
def hash_host(hostname, salt=None):
|
||||||
"""
|
"""
|
||||||
Return a "hashed" form of the hostname, as used by openssh when storing
|
Return a "hashed" form of the hostname, as used by openssh when storing
|
||||||
|
|
|
@ -71,3 +71,14 @@ class HostKeysTest (unittest.TestCase):
|
||||||
fp = paramiko.util.hexify(x['ssh-rsa'].get_fingerprint())
|
fp = paramiko.util.hexify(x['ssh-rsa'].get_fingerprint())
|
||||||
self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp)
|
self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp)
|
||||||
self.assertTrue(hostdict.check('foo.example.com', key))
|
self.assertTrue(hostdict.check('foo.example.com', key))
|
||||||
|
|
||||||
|
def test_3_dict(self):
|
||||||
|
hostdict = paramiko.HostKeys('hostfile.temp')
|
||||||
|
self.assert_('secure.example.com' in hostdict)
|
||||||
|
self.assert_('not.example.com' not in hostdict)
|
||||||
|
self.assert_(hostdict.has_key('secure.example.com'))
|
||||||
|
self.assert_(not hostdict.has_key('not.example.com'))
|
||||||
|
x = hostdict.get('secure.example.com', None)
|
||||||
|
self.assertTrue(x is not None)
|
||||||
|
fp = paramiko.util.hexify(x['ssh-rsa'].get_fingerprint())
|
||||||
|
self.assertEquals('E6684DB30E109B67B70FF1DC5C7F1363', fp)
|
||||||
|
|
Loading…
Reference in New Issue