diff --git a/demo.py b/demo.py index dfb7231..9633c5c 100755 --- a/demo.py +++ b/demo.py @@ -23,7 +23,10 @@ def load_host_keys(): for host in hosts: if not keys.has_key(host): keys[host] = {} - keys[host][keytype] = base64.decodestring(key) + if keytype == 'ssh-rsa': + keys[host][keytype] = paramiko.RSAKey(data=base64.decodestring(key)) + elif keytype == 'ssh-dss': + keys[host][keytype] = paramiko.DSSKey(data=base64.decodestring(key)) f.close() return keys @@ -75,7 +78,7 @@ try: print '*** WARNING: Unknown host key!' elif not keys[hostname].has_key(key.get_name()): print '*** WARNING: Unknown host key!' - elif keys[hostname][key.get_name()] != str(key): + elif keys[hostname][key.get_name()] != key: print '*** WARNING: Host key has changed!!!' sys.exit(1) else: diff --git a/demo_simple.py b/demo_simple.py index 0bc46bf..d31b063 100755 --- a/demo_simple.py +++ b/demo_simple.py @@ -23,7 +23,10 @@ def load_host_keys(): for host in hosts: if not keys.has_key(host): keys[host] = {} - keys[host][keytype] = base64.decodestring(key) + if keytype == 'ssh-rsa': + keys[host][keytype] = paramiko.RSAKey(data=base64.decodestring(key)) + elif keytype == 'ssh-dss': + keys[host][keytype] = paramiko.DSSKey(data=base64.decodestring(key)) f.close() return keys @@ -70,7 +73,7 @@ if hkeys.has_key(hostname): # now, connect and use paramiko Transport to negotiate SSH2 across the connection try: t = paramiko.Transport((hostname, port)) - t.connect(username=username, password=password, hostkeytype=hostkeytype, hostkey=hostkey) + t.connect(username=username, password=password, hostkey=hostkey) chan = t.open_session() chan.get_pty() chan.invoke_shell() diff --git a/paramiko/pkey.py b/paramiko/pkey.py index a7dc8a8..a872ae8 100644 --- a/paramiko/pkey.py +++ b/paramiko/pkey.py @@ -143,7 +143,7 @@ class PKey (object): @since: fearow """ - return ''.join(base64.encodestring(str(self)).split('\n')) + return base64.encodestring(str(self)).replace('\n', '') def sign_ssh_data(self, randpool, data): """ diff --git a/paramiko/transport.py b/paramiko/transport.py index 9b171c7..eeac020 100644 --- a/paramiko/transport.py +++ b/paramiko/transport.py @@ -693,7 +693,7 @@ class BaseTransport (threading.Thread): self.lock.release() return chan - def connect(self, hostkeytype=None, hostkey=None, username='', password=None, pkey=None): + def connect(self, hostkey=None, username='', password=None, pkey=None): """ Negotiate an SSH2 session, and optionally verify the server's host key and authenticate using a password or private key. This is a shortcut @@ -712,13 +712,9 @@ class BaseTransport (threading.Thread): succeed, but a subsequent L{open_channel} or L{open_session} call may fail because you haven't authenticated yet. - @param hostkeytype: the type of host key expected from the server - (usually C{"ssh-rsa"} or C{"ssh-dss"}), or C{None} if you don't want - to do host key verification. - @type hostkeytype: str @param hostkey: the host key expected from the server, or C{None} if you don't want to do host key verification. - @type hostkey: str + @type hostkey: L{PKey} @param username: the username to authenticate as. @type username: str @param password: a password to use for authentication, if you want to @@ -733,8 +729,8 @@ class BaseTransport (threading.Thread): @since: doduo """ - if hostkeytype is not None: - self._preferred_keys = [ hostkeytype ] + if hostkey is not None: + self._preferred_keys = [ hostkey.get_name() ] event = threading.Event() self.start_client(event) @@ -750,14 +746,14 @@ class BaseTransport (threading.Thread): break # check host key if we were given one - if (hostkeytype is not None) and (hostkey is not None): + if (hostkey is not None): key = self.get_remote_server_key() - if (key.get_name() != hostkeytype) or (str(key) != hostkey): + if (key.get_name() != hostkey.get_name()) or (str(key) != str(hostkey)): self._log(DEBUG, 'Bad host key from server') - self._log(DEBUG, 'Expected: %s: %s' % (hostkeytype, repr(hostkey))) + self._log(DEBUG, 'Expected: %s: %s' % (hostkey.get_name(), repr(str(hostkey)))) self._log(DEBUG, 'Got : %s: %s' % (key.get_name(), repr(str(key)))) raise SSHException('Bad host key from server') - self._log(DEBUG, 'Host key verified (%s)' % hostkeytype) + self._log(DEBUG, 'Host key verified (%s)' % hostkey.get_name()) if (pkey is not None) or (password is not None): event.clear() @@ -1003,7 +999,7 @@ class BaseTransport (threading.Thread): def _verify_key(self, host_key, sig): key = self._key_info[self.host_key_type](Message(host_key)) - if (key == None) or not key.valid: + if key is None: raise SSHException('Unknown host key type') if not key.verify_ssh_sig(self.H, Message(sig)): raise SSHException('Signature verification (%s) failed. Boo. Robey should debug this.' % self.host_key_type)