variant of a patch from warren young to preserve the order of host entries from the 'known_hosts' file and preserve knowlege of which lines had multiple hostnames on them
This commit is contained in:
Robey Pointer 2006-07-26 19:55:19 -07:00
parent 17a93bce4c
commit 6821b6e8e8
1 changed files with 91 additions and 36 deletions

View File

@ -30,6 +30,62 @@ from paramiko.odict import odict
from paramiko.rsakey import RSAKey from paramiko.rsakey import RSAKey
class HostKeyEntry:
"""
Representation of a line in an OpenSSH-style "known hosts" file.
"""
def __init__(self, hostnames=None, key=None):
self.valid = (hostnames is not None) and (key is not None)
self.hostnames = hostnames
self.key = key
def from_line(cls, line):
"""
Parses the given line of text to find the names for the host,
the type of key given on this line in the known_hosts file, and
the key data.
Lines are expected to not have leading or training whitespace.
We don't bother to check for comments or empty lines. All of
that should be taken care of before sending the line to us.
@param line: a line from an OpenSSH known_hosts file
@type line: str
"""
fields = line.split(' ')
if len(fields) != 3:
# Bad number of fields
return None
names, keytype, key = fields
names = names.split(',')
# Decide what kind of key we're looking at and create an object
# to hold it accordingly.
if keytype == 'ssh-rsa':
key = RSAKey(data=base64.decodestring(key))
elif keytype == 'ssh-dss':
key = DSSKey(data=base64.decodestring(key))
else:
return None
return cls(names, key)
from_line = classmethod(from_line)
def to_line(self):
"""
Returns a string in OpenSSH known_hosts file format, or None if
the object is not in a valid state. A trailing newline is
included.
"""
if self.valid:
return '%s %s %s\n' % (','.join(self.hostnames), self.key.get_name(),
self.key.get_base64())
else:
return None
class HostKeys (UserDict.DictMixin): class HostKeys (UserDict.DictMixin):
""" """
Representation of an openssh-style "known hosts" file. Host keys can be Representation of an openssh-style "known hosts" file. Host keys can be
@ -50,9 +106,8 @@ class HostKeys (UserDict.DictMixin):
@param filename: filename to load host keys from, or C{None} @param filename: filename to load host keys from, or C{None}
@type filename: str @type filename: str
""" """
# hostname -> keytype -> PKey # emulate a dict of { hostname: { keytype: PKey } }
self._keys = odict() self._entries = []
self.contains_hashes = False
if filename is not None: if filename is not None:
self.load(filename) self.load(filename)
@ -61,18 +116,18 @@ class HostKeys (UserDict.DictMixin):
Add a host key entry to the table. Any existing entry for a Add a host key entry to the table. Any existing entry for a
C{(hostname, keytype)} pair will be replaced. C{(hostname, keytype)} pair will be replaced.
@param hostname: @param hostname: the hostname (or IP) to add
@type hostname: str @type hostname: str
@param keytype: key type (C{"ssh-rsa"} or C{"ssh-dss"}) @param keytype: key type (C{"ssh-rsa"} or C{"ssh-dss"})
@type keytype: str @type keytype: str
@param key: the key to add @param key: the key to add
@type key: L{PKey} @type key: L{PKey}
""" """
if not hostname in self._keys: for e in self._entries:
self._keys[hostname] = odict() if (hostname in e.hostnames) and (e.key.get_name() == keytype):
if hostname.startswith('|1|'): e.key = key
self.contains_hashes = True return
self._keys[hostname][keytype] = key self._entries.append(HostKeyEntry([hostname], key))
def load(self, filename): def load(self, filename):
""" """
@ -95,16 +150,9 @@ class HostKeys (UserDict.DictMixin):
line = line.strip() line = line.strip()
if (len(line) == 0) or (line[0] == '#'): if (len(line) == 0) or (line[0] == '#'):
continue continue
keylist = line.split(' ') e = HostKeyEntry.from_line(line)
if len(keylist) != 3: if e is not None:
# don't understand this line self._entries.append(e)
continue
hostlist, keytype, key = keylist
for host in hostlist.split(','):
if keytype == 'ssh-rsa':
self.add(host, keytype, RSAKey(data=base64.decodestring(key)))
elif keytype == 'ssh-dss':
self.add(host, keytype, DSSKey(data=base64.decodestring(key)))
f.close() f.close()
def save(self, filename): def save(self, filename):
@ -122,9 +170,10 @@ class HostKeys (UserDict.DictMixin):
@since: 1.6.1 @since: 1.6.1
""" """
f = open(filename, 'w') f = open(filename, 'w')
for hostname, d in self._keys.iteritems(): for e in self._entries:
for keytype, key in d.iteritems(): line = e.to_line()
f.write('%s %s %s\n' % (hostname, keytype, key.get_base64())) if line:
f.write(line)
f.close() f.close()
def lookup(self, hostname): def lookup(self, hostname):
@ -133,21 +182,21 @@ class HostKeys (UserDict.DictMixin):
C{None} is returned. Otherwise a dictionary of keytype to key is C{None} is returned. Otherwise a dictionary of keytype to key is
returned. returned.
@param hostname: the hostname to lookup @param hostname: the hostname (or IP) to lookup
@type hostname: str @type hostname: str
@return: keys associated with this host (or C{None}) @return: keys associated with this host (or C{None})
@rtype: dict(str, L{PKey}) @rtype: dict(str, L{PKey})
""" """
if hostname in self._keys: ret = {}
return self._keys[hostname] for e in self._entries:
if not self.contains_hashes: for h in e.hostnames:
if h.startswith('|1|') and (self.hash_host(hostname, h) == h):
ret[e.key.get_name()] = e.key
elif h == hostname:
ret[e.key.get_name()] = e.key
if len(ret) == 0:
return None return None
for h in self._keys.keys(): return ret
if h.startswith('|1|'):
hmac = self.hash_host(hostname, h)
if hmac == h:
return self._keys[h]
return None
def check(self, hostname, key): def check(self, hostname, key):
""" """
@ -174,8 +223,7 @@ class HostKeys (UserDict.DictMixin):
""" """
Remove all host keys from the dictionary. Remove all host keys from the dictionary.
""" """
self._keys = {} self._entries = []
self.contains_hashes = False
def __getitem__(self, key): def __getitem__(self, key):
ret = self.lookup(key) ret = self.lookup(key)
@ -188,10 +236,17 @@ class HostKeys (UserDict.DictMixin):
self._keys[key] = value self._keys[key] = value
def keys(self): def keys(self):
return self._keys.keys() ret = []
for e in self._entries:
for h in e.hostnames:
ret.append(h)
return ret
def values(self): def values(self):
return self._keys.values(); ret = []
for e in self._entries:
ret.append({e.key.get_name() : e.key})
return ret
def hash_host(hostname, salt=None): def hash_host(hostname, salt=None):
""" """