variant of a patch from warren young to preserve the order of host entries from the 'known_hosts' file and preserve knowlege of which lines had multiple hostnames on them
This commit is contained in:
parent
17a93bce4c
commit
6821b6e8e8
|
@ -30,6 +30,62 @@ from paramiko.odict import odict
|
||||||
from paramiko.rsakey import RSAKey
|
from paramiko.rsakey import RSAKey
|
||||||
|
|
||||||
|
|
||||||
|
class HostKeyEntry:
|
||||||
|
"""
|
||||||
|
Representation of a line in an OpenSSH-style "known hosts" file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hostnames=None, key=None):
|
||||||
|
self.valid = (hostnames is not None) and (key is not None)
|
||||||
|
self.hostnames = hostnames
|
||||||
|
self.key = key
|
||||||
|
|
||||||
|
def from_line(cls, line):
|
||||||
|
"""
|
||||||
|
Parses the given line of text to find the names for the host,
|
||||||
|
the type of key given on this line in the known_hosts file, and
|
||||||
|
the key data.
|
||||||
|
|
||||||
|
Lines are expected to not have leading or training whitespace.
|
||||||
|
We don't bother to check for comments or empty lines. All of
|
||||||
|
that should be taken care of before sending the line to us.
|
||||||
|
|
||||||
|
@param line: a line from an OpenSSH known_hosts file
|
||||||
|
@type line: str
|
||||||
|
"""
|
||||||
|
fields = line.split(' ')
|
||||||
|
if len(fields) != 3:
|
||||||
|
# Bad number of fields
|
||||||
|
return None
|
||||||
|
|
||||||
|
names, keytype, key = fields
|
||||||
|
names = names.split(',')
|
||||||
|
|
||||||
|
# Decide what kind of key we're looking at and create an object
|
||||||
|
# to hold it accordingly.
|
||||||
|
if keytype == 'ssh-rsa':
|
||||||
|
key = RSAKey(data=base64.decodestring(key))
|
||||||
|
elif keytype == 'ssh-dss':
|
||||||
|
key = DSSKey(data=base64.decodestring(key))
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return cls(names, key)
|
||||||
|
from_line = classmethod(from_line)
|
||||||
|
|
||||||
|
def to_line(self):
|
||||||
|
"""
|
||||||
|
Returns a string in OpenSSH known_hosts file format, or None if
|
||||||
|
the object is not in a valid state. A trailing newline is
|
||||||
|
included.
|
||||||
|
"""
|
||||||
|
if self.valid:
|
||||||
|
return '%s %s %s\n' % (','.join(self.hostnames), self.key.get_name(),
|
||||||
|
self.key.get_base64())
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class HostKeys (UserDict.DictMixin):
|
class HostKeys (UserDict.DictMixin):
|
||||||
"""
|
"""
|
||||||
Representation of an openssh-style "known hosts" file. Host keys can be
|
Representation of an openssh-style "known hosts" file. Host keys can be
|
||||||
|
@ -50,9 +106,8 @@ class HostKeys (UserDict.DictMixin):
|
||||||
@param filename: filename to load host keys from, or C{None}
|
@param filename: filename to load host keys from, or C{None}
|
||||||
@type filename: str
|
@type filename: str
|
||||||
"""
|
"""
|
||||||
# hostname -> keytype -> PKey
|
# emulate a dict of { hostname: { keytype: PKey } }
|
||||||
self._keys = odict()
|
self._entries = []
|
||||||
self.contains_hashes = False
|
|
||||||
if filename is not None:
|
if filename is not None:
|
||||||
self.load(filename)
|
self.load(filename)
|
||||||
|
|
||||||
|
@ -61,18 +116,18 @@ class HostKeys (UserDict.DictMixin):
|
||||||
Add a host key entry to the table. Any existing entry for a
|
Add a host key entry to the table. Any existing entry for a
|
||||||
C{(hostname, keytype)} pair will be replaced.
|
C{(hostname, keytype)} pair will be replaced.
|
||||||
|
|
||||||
@param hostname:
|
@param hostname: the hostname (or IP) to add
|
||||||
@type hostname: str
|
@type hostname: str
|
||||||
@param keytype: key type (C{"ssh-rsa"} or C{"ssh-dss"})
|
@param keytype: key type (C{"ssh-rsa"} or C{"ssh-dss"})
|
||||||
@type keytype: str
|
@type keytype: str
|
||||||
@param key: the key to add
|
@param key: the key to add
|
||||||
@type key: L{PKey}
|
@type key: L{PKey}
|
||||||
"""
|
"""
|
||||||
if not hostname in self._keys:
|
for e in self._entries:
|
||||||
self._keys[hostname] = odict()
|
if (hostname in e.hostnames) and (e.key.get_name() == keytype):
|
||||||
if hostname.startswith('|1|'):
|
e.key = key
|
||||||
self.contains_hashes = True
|
return
|
||||||
self._keys[hostname][keytype] = key
|
self._entries.append(HostKeyEntry([hostname], key))
|
||||||
|
|
||||||
def load(self, filename):
|
def load(self, filename):
|
||||||
"""
|
"""
|
||||||
|
@ -95,16 +150,9 @@ class HostKeys (UserDict.DictMixin):
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if (len(line) == 0) or (line[0] == '#'):
|
if (len(line) == 0) or (line[0] == '#'):
|
||||||
continue
|
continue
|
||||||
keylist = line.split(' ')
|
e = HostKeyEntry.from_line(line)
|
||||||
if len(keylist) != 3:
|
if e is not None:
|
||||||
# don't understand this line
|
self._entries.append(e)
|
||||||
continue
|
|
||||||
hostlist, keytype, key = keylist
|
|
||||||
for host in hostlist.split(','):
|
|
||||||
if keytype == 'ssh-rsa':
|
|
||||||
self.add(host, keytype, RSAKey(data=base64.decodestring(key)))
|
|
||||||
elif keytype == 'ssh-dss':
|
|
||||||
self.add(host, keytype, DSSKey(data=base64.decodestring(key)))
|
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
def save(self, filename):
|
def save(self, filename):
|
||||||
|
@ -122,9 +170,10 @@ class HostKeys (UserDict.DictMixin):
|
||||||
@since: 1.6.1
|
@since: 1.6.1
|
||||||
"""
|
"""
|
||||||
f = open(filename, 'w')
|
f = open(filename, 'w')
|
||||||
for hostname, d in self._keys.iteritems():
|
for e in self._entries:
|
||||||
for keytype, key in d.iteritems():
|
line = e.to_line()
|
||||||
f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
|
if line:
|
||||||
|
f.write(line)
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
def lookup(self, hostname):
|
def lookup(self, hostname):
|
||||||
|
@ -133,21 +182,21 @@ class HostKeys (UserDict.DictMixin):
|
||||||
C{None} is returned. Otherwise a dictionary of keytype to key is
|
C{None} is returned. Otherwise a dictionary of keytype to key is
|
||||||
returned.
|
returned.
|
||||||
|
|
||||||
@param hostname: the hostname to lookup
|
@param hostname: the hostname (or IP) to lookup
|
||||||
@type hostname: str
|
@type hostname: str
|
||||||
@return: keys associated with this host (or C{None})
|
@return: keys associated with this host (or C{None})
|
||||||
@rtype: dict(str, L{PKey})
|
@rtype: dict(str, L{PKey})
|
||||||
"""
|
"""
|
||||||
if hostname in self._keys:
|
ret = {}
|
||||||
return self._keys[hostname]
|
for e in self._entries:
|
||||||
if not self.contains_hashes:
|
for h in e.hostnames:
|
||||||
|
if h.startswith('|1|') and (self.hash_host(hostname, h) == h):
|
||||||
|
ret[e.key.get_name()] = e.key
|
||||||
|
elif h == hostname:
|
||||||
|
ret[e.key.get_name()] = e.key
|
||||||
|
if len(ret) == 0:
|
||||||
return None
|
return None
|
||||||
for h in self._keys.keys():
|
return ret
|
||||||
if h.startswith('|1|'):
|
|
||||||
hmac = self.hash_host(hostname, h)
|
|
||||||
if hmac == h:
|
|
||||||
return self._keys[h]
|
|
||||||
return None
|
|
||||||
|
|
||||||
def check(self, hostname, key):
|
def check(self, hostname, key):
|
||||||
"""
|
"""
|
||||||
|
@ -174,8 +223,7 @@ class HostKeys (UserDict.DictMixin):
|
||||||
"""
|
"""
|
||||||
Remove all host keys from the dictionary.
|
Remove all host keys from the dictionary.
|
||||||
"""
|
"""
|
||||||
self._keys = {}
|
self._entries = []
|
||||||
self.contains_hashes = False
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
ret = self.lookup(key)
|
ret = self.lookup(key)
|
||||||
|
@ -188,10 +236,17 @@ class HostKeys (UserDict.DictMixin):
|
||||||
self._keys[key] = value
|
self._keys[key] = value
|
||||||
|
|
||||||
def keys(self):
|
def keys(self):
|
||||||
return self._keys.keys()
|
ret = []
|
||||||
|
for e in self._entries:
|
||||||
|
for h in e.hostnames:
|
||||||
|
ret.append(h)
|
||||||
|
return ret
|
||||||
|
|
||||||
def values(self):
|
def values(self):
|
||||||
return self._keys.values();
|
ret = []
|
||||||
|
for e in self._entries:
|
||||||
|
ret.append({e.key.get_name() : e.key})
|
||||||
|
return ret
|
||||||
|
|
||||||
def hash_host(hostname, salt=None):
|
def hash_host(hostname, salt=None):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue