variant of a patch from warren young to preserve the order of host entries from the 'known_hosts' file and preserve knowlege of which lines had multiple hostnames on them
This commit is contained in:
parent
17a93bce4c
commit
6821b6e8e8
|
@ -30,6 +30,62 @@ from paramiko.odict import odict
|
|||
from paramiko.rsakey import RSAKey
|
||||
|
||||
|
||||
class HostKeyEntry:
|
||||
"""
|
||||
Representation of a line in an OpenSSH-style "known hosts" file.
|
||||
"""
|
||||
|
||||
def __init__(self, hostnames=None, key=None):
|
||||
self.valid = (hostnames is not None) and (key is not None)
|
||||
self.hostnames = hostnames
|
||||
self.key = key
|
||||
|
||||
def from_line(cls, line):
|
||||
"""
|
||||
Parses the given line of text to find the names for the host,
|
||||
the type of key given on this line in the known_hosts file, and
|
||||
the key data.
|
||||
|
||||
Lines are expected to not have leading or training whitespace.
|
||||
We don't bother to check for comments or empty lines. All of
|
||||
that should be taken care of before sending the line to us.
|
||||
|
||||
@param line: a line from an OpenSSH known_hosts file
|
||||
@type line: str
|
||||
"""
|
||||
fields = line.split(' ')
|
||||
if len(fields) != 3:
|
||||
# Bad number of fields
|
||||
return None
|
||||
|
||||
names, keytype, key = fields
|
||||
names = names.split(',')
|
||||
|
||||
# Decide what kind of key we're looking at and create an object
|
||||
# to hold it accordingly.
|
||||
if keytype == 'ssh-rsa':
|
||||
key = RSAKey(data=base64.decodestring(key))
|
||||
elif keytype == 'ssh-dss':
|
||||
key = DSSKey(data=base64.decodestring(key))
|
||||
else:
|
||||
return None
|
||||
|
||||
return cls(names, key)
|
||||
from_line = classmethod(from_line)
|
||||
|
||||
def to_line(self):
|
||||
"""
|
||||
Returns a string in OpenSSH known_hosts file format, or None if
|
||||
the object is not in a valid state. A trailing newline is
|
||||
included.
|
||||
"""
|
||||
if self.valid:
|
||||
return '%s %s %s\n' % (','.join(self.hostnames), self.key.get_name(),
|
||||
self.key.get_base64())
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class HostKeys (UserDict.DictMixin):
|
||||
"""
|
||||
Representation of an openssh-style "known hosts" file. Host keys can be
|
||||
|
@ -50,9 +106,8 @@ class HostKeys (UserDict.DictMixin):
|
|||
@param filename: filename to load host keys from, or C{None}
|
||||
@type filename: str
|
||||
"""
|
||||
# hostname -> keytype -> PKey
|
||||
self._keys = odict()
|
||||
self.contains_hashes = False
|
||||
# emulate a dict of { hostname: { keytype: PKey } }
|
||||
self._entries = []
|
||||
if filename is not None:
|
||||
self.load(filename)
|
||||
|
||||
|
@ -61,18 +116,18 @@ class HostKeys (UserDict.DictMixin):
|
|||
Add a host key entry to the table. Any existing entry for a
|
||||
C{(hostname, keytype)} pair will be replaced.
|
||||
|
||||
@param hostname:
|
||||
@param hostname: the hostname (or IP) to add
|
||||
@type hostname: str
|
||||
@param keytype: key type (C{"ssh-rsa"} or C{"ssh-dss"})
|
||||
@type keytype: str
|
||||
@param key: the key to add
|
||||
@type key: L{PKey}
|
||||
"""
|
||||
if not hostname in self._keys:
|
||||
self._keys[hostname] = odict()
|
||||
if hostname.startswith('|1|'):
|
||||
self.contains_hashes = True
|
||||
self._keys[hostname][keytype] = key
|
||||
for e in self._entries:
|
||||
if (hostname in e.hostnames) and (e.key.get_name() == keytype):
|
||||
e.key = key
|
||||
return
|
||||
self._entries.append(HostKeyEntry([hostname], key))
|
||||
|
||||
def load(self, filename):
|
||||
"""
|
||||
|
@ -95,16 +150,9 @@ class HostKeys (UserDict.DictMixin):
|
|||
line = line.strip()
|
||||
if (len(line) == 0) or (line[0] == '#'):
|
||||
continue
|
||||
keylist = line.split(' ')
|
||||
if len(keylist) != 3:
|
||||
# don't understand this line
|
||||
continue
|
||||
hostlist, keytype, key = keylist
|
||||
for host in hostlist.split(','):
|
||||
if keytype == 'ssh-rsa':
|
||||
self.add(host, keytype, RSAKey(data=base64.decodestring(key)))
|
||||
elif keytype == 'ssh-dss':
|
||||
self.add(host, keytype, DSSKey(data=base64.decodestring(key)))
|
||||
e = HostKeyEntry.from_line(line)
|
||||
if e is not None:
|
||||
self._entries.append(e)
|
||||
f.close()
|
||||
|
||||
def save(self, filename):
|
||||
|
@ -122,9 +170,10 @@ class HostKeys (UserDict.DictMixin):
|
|||
@since: 1.6.1
|
||||
"""
|
||||
f = open(filename, 'w')
|
||||
for hostname, d in self._keys.iteritems():
|
||||
for keytype, key in d.iteritems():
|
||||
f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
|
||||
for e in self._entries:
|
||||
line = e.to_line()
|
||||
if line:
|
||||
f.write(line)
|
||||
f.close()
|
||||
|
||||
def lookup(self, hostname):
|
||||
|
@ -133,21 +182,21 @@ class HostKeys (UserDict.DictMixin):
|
|||
C{None} is returned. Otherwise a dictionary of keytype to key is
|
||||
returned.
|
||||
|
||||
@param hostname: the hostname to lookup
|
||||
@param hostname: the hostname (or IP) to lookup
|
||||
@type hostname: str
|
||||
@return: keys associated with this host (or C{None})
|
||||
@rtype: dict(str, L{PKey})
|
||||
"""
|
||||
if hostname in self._keys:
|
||||
return self._keys[hostname]
|
||||
if not self.contains_hashes:
|
||||
return None
|
||||
for h in self._keys.keys():
|
||||
if h.startswith('|1|'):
|
||||
hmac = self.hash_host(hostname, h)
|
||||
if hmac == h:
|
||||
return self._keys[h]
|
||||
ret = {}
|
||||
for e in self._entries:
|
||||
for h in e.hostnames:
|
||||
if h.startswith('|1|') and (self.hash_host(hostname, h) == h):
|
||||
ret[e.key.get_name()] = e.key
|
||||
elif h == hostname:
|
||||
ret[e.key.get_name()] = e.key
|
||||
if len(ret) == 0:
|
||||
return None
|
||||
return ret
|
||||
|
||||
def check(self, hostname, key):
|
||||
"""
|
||||
|
@ -174,8 +223,7 @@ class HostKeys (UserDict.DictMixin):
|
|||
"""
|
||||
Remove all host keys from the dictionary.
|
||||
"""
|
||||
self._keys = {}
|
||||
self.contains_hashes = False
|
||||
self._entries = []
|
||||
|
||||
def __getitem__(self, key):
|
||||
ret = self.lookup(key)
|
||||
|
@ -188,10 +236,17 @@ class HostKeys (UserDict.DictMixin):
|
|||
self._keys[key] = value
|
||||
|
||||
def keys(self):
|
||||
return self._keys.keys()
|
||||
ret = []
|
||||
for e in self._entries:
|
||||
for h in e.hostnames:
|
||||
ret.append(h)
|
||||
return ret
|
||||
|
||||
def values(self):
|
||||
return self._keys.values();
|
||||
ret = []
|
||||
for e in self._entries:
|
||||
ret.append({e.key.get_name() : e.key})
|
||||
return ret
|
||||
|
||||
def hash_host(hostname, salt=None):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue