variant of a patch from warren young to preserve the order of host entries from the 'known_hosts' file and preserve knowlege of which lines had multiple hostnames on them
This commit is contained in:
Robey Pointer 2006-07-26 19:55:19 -07:00
parent 17a93bce4c
commit 6821b6e8e8
1 changed files with 91 additions and 36 deletions

View File

@ -30,6 +30,62 @@ from paramiko.odict import odict
from paramiko.rsakey import RSAKey
class HostKeyEntry:
"""
Representation of a line in an OpenSSH-style "known hosts" file.
"""
def __init__(self, hostnames=None, key=None):
self.valid = (hostnames is not None) and (key is not None)
self.hostnames = hostnames
self.key = key
def from_line(cls, line):
"""
Parses the given line of text to find the names for the host,
the type of key given on this line in the known_hosts file, and
the key data.
Lines are expected to not have leading or training whitespace.
We don't bother to check for comments or empty lines. All of
that should be taken care of before sending the line to us.
@param line: a line from an OpenSSH known_hosts file
@type line: str
"""
fields = line.split(' ')
if len(fields) != 3:
# Bad number of fields
return None
names, keytype, key = fields
names = names.split(',')
# Decide what kind of key we're looking at and create an object
# to hold it accordingly.
if keytype == 'ssh-rsa':
key = RSAKey(data=base64.decodestring(key))
elif keytype == 'ssh-dss':
key = DSSKey(data=base64.decodestring(key))
else:
return None
return cls(names, key)
from_line = classmethod(from_line)
def to_line(self):
"""
Returns a string in OpenSSH known_hosts file format, or None if
the object is not in a valid state. A trailing newline is
included.
"""
if self.valid:
return '%s %s %s\n' % (','.join(self.hostnames), self.key.get_name(),
self.key.get_base64())
else:
return None
class HostKeys (UserDict.DictMixin):
"""
Representation of an openssh-style "known hosts" file. Host keys can be
@ -50,9 +106,8 @@ class HostKeys (UserDict.DictMixin):
@param filename: filename to load host keys from, or C{None}
@type filename: str
"""
# hostname -> keytype -> PKey
self._keys = odict()
self.contains_hashes = False
# emulate a dict of { hostname: { keytype: PKey } }
self._entries = []
if filename is not None:
self.load(filename)
@ -61,18 +116,18 @@ class HostKeys (UserDict.DictMixin):
Add a host key entry to the table. Any existing entry for a
C{(hostname, keytype)} pair will be replaced.
@param hostname:
@param hostname: the hostname (or IP) to add
@type hostname: str
@param keytype: key type (C{"ssh-rsa"} or C{"ssh-dss"})
@type keytype: str
@param key: the key to add
@type key: L{PKey}
"""
if not hostname in self._keys:
self._keys[hostname] = odict()
if hostname.startswith('|1|'):
self.contains_hashes = True
self._keys[hostname][keytype] = key
for e in self._entries:
if (hostname in e.hostnames) and (e.key.get_name() == keytype):
e.key = key
return
self._entries.append(HostKeyEntry([hostname], key))
def load(self, filename):
"""
@ -95,16 +150,9 @@ class HostKeys (UserDict.DictMixin):
line = line.strip()
if (len(line) == 0) or (line[0] == '#'):
continue
keylist = line.split(' ')
if len(keylist) != 3:
# don't understand this line
continue
hostlist, keytype, key = keylist
for host in hostlist.split(','):
if keytype == 'ssh-rsa':
self.add(host, keytype, RSAKey(data=base64.decodestring(key)))
elif keytype == 'ssh-dss':
self.add(host, keytype, DSSKey(data=base64.decodestring(key)))
e = HostKeyEntry.from_line(line)
if e is not None:
self._entries.append(e)
f.close()
def save(self, filename):
@ -122,9 +170,10 @@ class HostKeys (UserDict.DictMixin):
@since: 1.6.1
"""
f = open(filename, 'w')
for hostname, d in self._keys.iteritems():
for keytype, key in d.iteritems():
f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
for e in self._entries:
line = e.to_line()
if line:
f.write(line)
f.close()
def lookup(self, hostname):
@ -133,21 +182,21 @@ class HostKeys (UserDict.DictMixin):
C{None} is returned. Otherwise a dictionary of keytype to key is
returned.
@param hostname: the hostname to lookup
@param hostname: the hostname (or IP) to lookup
@type hostname: str
@return: keys associated with this host (or C{None})
@rtype: dict(str, L{PKey})
"""
if hostname in self._keys:
return self._keys[hostname]
if not self.contains_hashes:
ret = {}
for e in self._entries:
for h in e.hostnames:
if h.startswith('|1|') and (self.hash_host(hostname, h) == h):
ret[e.key.get_name()] = e.key
elif h == hostname:
ret[e.key.get_name()] = e.key
if len(ret) == 0:
return None
for h in self._keys.keys():
if h.startswith('|1|'):
hmac = self.hash_host(hostname, h)
if hmac == h:
return self._keys[h]
return None
return ret
def check(self, hostname, key):
"""
@ -174,8 +223,7 @@ class HostKeys (UserDict.DictMixin):
"""
Remove all host keys from the dictionary.
"""
self._keys = {}
self.contains_hashes = False
self._entries = []
def __getitem__(self, key):
ret = self.lookup(key)
@ -188,10 +236,17 @@ class HostKeys (UserDict.DictMixin):
self._keys[key] = value
def keys(self):
return self._keys.keys()
ret = []
for e in self._entries:
for h in e.hostnames:
ret.append(h)
return ret
def values(self):
return self._keys.values();
ret = []
for e in self._entries:
ret.append({e.key.get_name() : e.key})
return ret
def hash_host(hostname, salt=None):
"""