diff --git a/paramiko/hostkeys.py b/paramiko/hostkeys.py index 7cc8d40..bf5f442 100644 --- a/paramiko/hostkeys.py +++ b/paramiko/hostkeys.py @@ -30,6 +30,62 @@ from paramiko.odict import odict from paramiko.rsakey import RSAKey +class HostKeyEntry: + """ + Representation of a line in an OpenSSH-style "known hosts" file. + """ + + def __init__(self, hostnames=None, key=None): + self.valid = (hostnames is not None) and (key is not None) + self.hostnames = hostnames + self.key = key + + def from_line(cls, line): + """ + Parses the given line of text to find the names for the host, + the type of key given on this line in the known_hosts file, and + the key data. + + Lines are expected to not have leading or training whitespace. + We don't bother to check for comments or empty lines. All of + that should be taken care of before sending the line to us. + + @param line: a line from an OpenSSH known_hosts file + @type line: str + """ + fields = line.split(' ') + if len(fields) != 3: + # Bad number of fields + return None + + names, keytype, key = fields + names = names.split(',') + + # Decide what kind of key we're looking at and create an object + # to hold it accordingly. + if keytype == 'ssh-rsa': + key = RSAKey(data=base64.decodestring(key)) + elif keytype == 'ssh-dss': + key = DSSKey(data=base64.decodestring(key)) + else: + return None + + return cls(names, key) + from_line = classmethod(from_line) + + def to_line(self): + """ + Returns a string in OpenSSH known_hosts file format, or None if + the object is not in a valid state. A trailing newline is + included. + """ + if self.valid: + return '%s %s %s\n' % (','.join(self.hostnames), self.key.get_name(), + self.key.get_base64()) + else: + return None + + class HostKeys (UserDict.DictMixin): """ Representation of an openssh-style "known hosts" file. Host keys can be @@ -50,9 +106,8 @@ class HostKeys (UserDict.DictMixin): @param filename: filename to load host keys from, or C{None} @type filename: str """ - # hostname -> keytype -> PKey - self._keys = odict() - self.contains_hashes = False + # emulate a dict of { hostname: { keytype: PKey } } + self._entries = [] if filename is not None: self.load(filename) @@ -61,18 +116,18 @@ class HostKeys (UserDict.DictMixin): Add a host key entry to the table. Any existing entry for a C{(hostname, keytype)} pair will be replaced. - @param hostname: + @param hostname: the hostname (or IP) to add @type hostname: str @param keytype: key type (C{"ssh-rsa"} or C{"ssh-dss"}) @type keytype: str @param key: the key to add @type key: L{PKey} """ - if not hostname in self._keys: - self._keys[hostname] = odict() - if hostname.startswith('|1|'): - self.contains_hashes = True - self._keys[hostname][keytype] = key + for e in self._entries: + if (hostname in e.hostnames) and (e.key.get_name() == keytype): + e.key = key + return + self._entries.append(HostKeyEntry([hostname], key)) def load(self, filename): """ @@ -95,16 +150,9 @@ class HostKeys (UserDict.DictMixin): line = line.strip() if (len(line) == 0) or (line[0] == '#'): continue - keylist = line.split(' ') - if len(keylist) != 3: - # don't understand this line - continue - hostlist, keytype, key = keylist - for host in hostlist.split(','): - if keytype == 'ssh-rsa': - self.add(host, keytype, RSAKey(data=base64.decodestring(key))) - elif keytype == 'ssh-dss': - self.add(host, keytype, DSSKey(data=base64.decodestring(key))) + e = HostKeyEntry.from_line(line) + if e is not None: + self._entries.append(e) f.close() def save(self, filename): @@ -122,9 +170,10 @@ class HostKeys (UserDict.DictMixin): @since: 1.6.1 """ f = open(filename, 'w') - for hostname, d in self._keys.iteritems(): - for keytype, key in d.iteritems(): - f.write('%s %s %s\n' % (hostname, keytype, key.get_base64())) + for e in self._entries: + line = e.to_line() + if line: + f.write(line) f.close() def lookup(self, hostname): @@ -133,21 +182,21 @@ class HostKeys (UserDict.DictMixin): C{None} is returned. Otherwise a dictionary of keytype to key is returned. - @param hostname: the hostname to lookup + @param hostname: the hostname (or IP) to lookup @type hostname: str @return: keys associated with this host (or C{None}) @rtype: dict(str, L{PKey}) """ - if hostname in self._keys: - return self._keys[hostname] - if not self.contains_hashes: + ret = {} + for e in self._entries: + for h in e.hostnames: + if h.startswith('|1|') and (self.hash_host(hostname, h) == h): + ret[e.key.get_name()] = e.key + elif h == hostname: + ret[e.key.get_name()] = e.key + if len(ret) == 0: return None - for h in self._keys.keys(): - if h.startswith('|1|'): - hmac = self.hash_host(hostname, h) - if hmac == h: - return self._keys[h] - return None + return ret def check(self, hostname, key): """ @@ -174,8 +223,7 @@ class HostKeys (UserDict.DictMixin): """ Remove all host keys from the dictionary. """ - self._keys = {} - self.contains_hashes = False + self._entries = [] def __getitem__(self, key): ret = self.lookup(key) @@ -188,10 +236,17 @@ class HostKeys (UserDict.DictMixin): self._keys[key] = value def keys(self): - return self._keys.keys() + ret = [] + for e in self._entries: + for h in e.hostnames: + ret.append(h) + return ret def values(self): - return self._keys.values(); + ret = [] + for e in self._entries: + ret.append({e.key.get_name() : e.key}) + return ret def hash_host(hostname, salt=None): """