improve HostKeys so that it more correctly emulates a dict, and add a unit test to verify that
This commit is contained in:
Robey Pointer 2006-03-09 00:04:50 -08:00
parent 90a577c775
commit de6315b9c5
2 changed files with 28 additions and 16 deletions

View File

@ -22,13 +22,14 @@ L{HostKeys}
import base64 import base64
from Crypto.Hash import SHA, HMAC from Crypto.Hash import SHA, HMAC
import UserDict
from paramiko.common import * from paramiko.common import *
from paramiko.dsskey import DSSKey from paramiko.dsskey import DSSKey
from paramiko.rsakey import RSAKey from paramiko.rsakey import RSAKey
class HostKeys (object): class HostKeys (UserDict.DictMixin):
""" """
Representation of an openssh-style "known hosts" file. Host keys can be Representation of an openssh-style "known hosts" file. Host keys can be
read from one or more files, and then individual hosts can be looked up to read from one or more files, and then individual hosts can be looked up to
@ -49,7 +50,7 @@ class HostKeys (object):
@type filename: str @type filename: str
""" """
# hostname -> keytype -> PKey # hostname -> keytype -> PKey
self.keys = {} self._keys = {}
self.contains_hashes = False self.contains_hashes = False
if filename is not None: if filename is not None:
self.load(filename) self.load(filename)
@ -66,11 +67,11 @@ class HostKeys (object):
@param key: the key to add @param key: the key to add
@type key: L{PKey} @type key: L{PKey}
""" """
if not hostname in self.keys: if not hostname in self._keys:
self.keys[hostname] = {} self._keys[hostname] = {}
if hostname.startswith('|1|'): if hostname.startswith('|1|'):
self.contains_hashes = True self.contains_hashes = True
self.keys[hostname][keytype] = key self._keys[hostname][keytype] = key
def load(self, filename): def load(self, filename):
""" """
@ -110,15 +111,15 @@ class HostKeys (object):
@return: keys associated with this host (or C{None}) @return: keys associated with this host (or C{None})
@rtype: dict(str, L{PKey}) @rtype: dict(str, L{PKey})
""" """
if hostname in self.keys: if hostname in self._keys:
return self.keys[hostname] return self._keys[hostname]
if not self.contains_hashes: if not self.contains_hashes:
return None return None
for h in self.keys.keys(): for h in self._keys.keys():
if h.startswith('|1|'): if h.startswith('|1|'):
hmac = self.hash_host(hostname, h) hmac = self.hash_host(hostname, h)
if hmac == h: if hmac == h:
return self.keys[h] return self._keys[h]
return None return None
def check(self, hostname, key): def check(self, hostname, key):
@ -146,21 +147,21 @@ class HostKeys (object):
""" """
Remove all host keys from the dictionary. Remove all host keys from the dictionary.
""" """
self.keys = {} self._keys = {}
self.contains_hashes = False self.contains_hashes = False
def values(self):
return self.keys.values();
def __getitem__(self, key): def __getitem__(self, key):
ret = self.lookup(key) ret = self.lookup(key)
if ret is None: if ret is None:
raise KeyError(key) raise KeyError(key)
return ret return ret
def __len__(self): def keys(self):
return len(self.keys) return self._keys.keys()
def values(self):
return self._keys.values();
def hash_host(hostname, salt=None): def hash_host(hostname, salt=None):
""" """
Return a "hashed" form of the hostname, as used by openssh when storing Return a "hashed" form of the hostname, as used by openssh when storing

View File

@ -71,3 +71,14 @@ class HostKeysTest (unittest.TestCase):
fp = paramiko.util.hexify(x['ssh-rsa'].get_fingerprint()) fp = paramiko.util.hexify(x['ssh-rsa'].get_fingerprint())
self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp) self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp)
self.assertTrue(hostdict.check('foo.example.com', key)) self.assertTrue(hostdict.check('foo.example.com', key))
def test_3_dict(self):
hostdict = paramiko.HostKeys('hostfile.temp')
self.assert_('secure.example.com' in hostdict)
self.assert_('not.example.com' not in hostdict)
self.assert_(hostdict.has_key('secure.example.com'))
self.assert_(not hostdict.has_key('not.example.com'))
x = hostdict.get('secure.example.com', None)
self.assertTrue(x is not None)
fp = paramiko.util.hexify(x['ssh-rsa'].get_fingerprint())
self.assertEquals('E6684DB30E109B67B70FF1DC5C7F1363', fp)