diff --git a/paramiko/client.py b/paramiko/client.py index 567c833..d3898dc 100644 --- a/paramiko/client.py +++ b/paramiko/client.py @@ -61,10 +61,7 @@ class AutoAddPolicy (MissingHostKeyPolicy): """ def missing_host_key(self, client, hostname, key): - if not client._host_keys.has_key(hostname): - client._host_keys[hostname] = {} - client._host_keys[hostname][key.get_name()] = key - our_server_key = server_key + client._host_keys.add(hostname, key.get_name(), key) if client._host_keys_filename is not None: client.save_host_keys(client._host_keys_filename) client._log(DEBUG, 'Adding %s host key for %s: %s' % @@ -97,6 +94,8 @@ class SSHClient (object): You may pass in explicit overrides for authentication and server host key checking. The default mechanism is to try to use local key files or an SSH agent (if one is running). + + @since: 1.6 """ def __init__(self): @@ -177,7 +176,24 @@ class SSHClient (object): f.write('%s %s %s\n' % (hostname, keytype, key.get_base64())) f.close() + def get_host_keys(self): + """ + Get the local L{HostKeys} object. This can be used to examine the + local host keys or change them. + + @return: the local host keys + @rtype: L{HostKeys} + """ + return self._host_keys + def set_log_channel(self, channel): + """ + Set the channel for logging. The default is C{"paramiko.transport"} + but it can be set to anything you want. + + @param name: new channel name for logging + @type name: str + """ self._log_channel = channel def set_missing_host_key_policy(self, policy): @@ -232,7 +248,7 @@ class SSHClient (object): @raise SSHException: if there was an error authenticating or verifying the server's host key """ - t = Transport((hostname, port)) + t = self._transport = Transport((hostname, port)) if self._log_channel is not None: t.set_log_channel(self._log_channel) t.start_client() @@ -247,6 +263,8 @@ class SSHClient (object): if our_server_key is None: # will raise exception if the key is rejected; let that fall out self._policy.missing_host_key(self, hostname, server_key) + # if this continues, assume the key is ok + our_server_key = server_key our_server_key_hex = hexify(our_server_key.get_fingerprint()) @@ -254,7 +272,6 @@ class SSHClient (object): raise SSHException('Host key for server %s does not match! (%s != %s)' % (hostname, our_server_key_kex, server_key_hex)) - self._transport = t if username is None: username = getpass.getuser() self._auth(username, password, pkey, key_filename) @@ -281,20 +298,42 @@ class SSHClient (object): @raise SSHException: if the server fails to execute the command """ chan = self._transport.open_session() - if not chan.exec_command(command): - raise SSHException('Command execution failed.') + chan.exec_command(command) stdin = chan.makefile('wb') stdout = chan.makefile('rb') stderr = chan.makefile_stderr('rb') return stdin, stdout, stderr - def invoke_shell(self): - pass - #FIXME + def invoke_shell(self, term='vt100', width=80, height=24): + """ + Start an interactive shell session on the SSH server. A new L{Channel} + is opened and connected to a pseudo-terminal using the requested + terminal type and size. + + @param term: the terminal type to emulate (for example, C{"vt100"}) + @type term: str + @param width: the width (in characters) of the terminal window + @type width: int + @param height: the height (in characters) of the terminal window + @type height: int + @return: a new channel connected to the remote shell + @rtype: L{Channel} + + @raise SSHException: if the server fails to invoke a shell + """ + chan = self._transport.open_session() + chan.get_pty(term, width, height) + chan.invoke_shell() + return chan def open_sftp(self): - pass - # FIXME + """ + Open an SFTP session on the SSH server. + + @return: a new SFTP session object + @rtype: L{SFTPClient} + """ + return self._transport.open_sftp_client() def _auth(self, username, password, pkey, key_filename): """ @@ -318,7 +357,7 @@ class SSHClient (object): saved_exception = e if key_filename is not None: - for pkey_class in (paramiko.RSAKey, paramiko.DSSKey): + for pkey_class in (RSAKey, DSSKey): try: key = pkey_class.from_private_key_file(key_filename, password) self._log(DEBUG, 'Trying key %s from %s' % (hexify(key.get_fingerprint()), key_filename)) @@ -335,8 +374,8 @@ class SSHClient (object): except SSHException, e: saved_exception = e - for pkey_class, filename in ((paramiko.RSAKey, 'id_rsa'), - (paramiko.DSSKey, 'id_dsa')): + for pkey_class, filename in ((RSAKey, 'id_rsa'), + (DSSKey, 'id_dsa')): filename = os.path.expanduser('~/.ssh/' + filename) try: key = pkey_class.from_private_key_file(filename, password) @@ -345,10 +384,12 @@ class SSHClient (object): return except SSHException, e: saved_exception = e + except IOError, e: + saved_exception = e if password is not None: try: - transport.auth_password(username, password) + self._transport.auth_password(username, password) return except SSHException, e: saved_exception = e diff --git a/test.py b/test.py index 62f6b00..00767eb 100755 --- a/test.py +++ b/test.py @@ -39,6 +39,7 @@ from test_packetizer import PacketizerTest from test_transport import TransportTest from test_sftp import SFTPTest from test_sftp_big import BigSFTPTest +from test_client import SSHClientTest default_host = 'localhost' default_user = os.environ.get('USER', 'nobody') @@ -100,6 +101,7 @@ suite.addTest(unittest.makeSuite(KexTest)) suite.addTest(unittest.makeSuite(PacketizerTest)) if options.use_transport: suite.addTest(unittest.makeSuite(TransportTest)) +suite.addTest(unittest.makeSuite(SSHClientTest)) if options.use_sftp: suite.addTest(unittest.makeSuite(SFTPTest)) if options.use_big_file: diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..32f1a30 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,127 @@ +# Copyright (C) 2003-2005 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Some unit tests for SSHClient. +""" + +import socket +import threading +import unittest +import paramiko + + +class NullServer (paramiko.ServerInterface): + + def get_allowed_auths(self, username): + if username == 'slowdive': + return 'publickey,password' + return 'publickey' + + def check_auth_password(self, username, password): + if (username == 'slowdive') and (password == 'pygmalion'): + return paramiko.AUTH_SUCCESSFUL + return paramiko.AUTH_FAILED + + def check_channel_request(self, kind, chanid): + return paramiko.OPEN_SUCCEEDED + + def check_channel_exec_request(self, channel, command): + if command != 'yes': + return False + return True + + +class SSHClientTest (unittest.TestCase): + + def setUp(self): + self.sockl = socket.socket() + self.sockl.bind(('localhost', 0)) + self.sockl.listen(1) + self.addr, self.port = self.sockl.getsockname() + self.event = threading.Event() + thread = threading.Thread(target=self._run) + thread.start() + + def tearDown(self): + self.tc.close() + self.ts.close() + self.socks.close() + self.sockl.close() + + def _run(self): + self.socks, addr = self.sockl.accept() + self.ts = paramiko.Transport(self.socks) + host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') + self.ts.add_server_key(host_key) + server = NullServer() + self.ts.start_server(self.event, server) + + + def test_1_client(self): + """ + verify that the SSHClient stuff works too. + """ + host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') + public_host_key = paramiko.RSAKey(data=str(host_key)) + + self.tc = paramiko.SSHClient() + self.tc.get_host_keys().add(self.addr, 'ssh-rsa', public_host_key) + self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion') + + self.event.wait(1.0) + self.assert_(self.event.isSet()) + self.assert_(self.ts.is_active()) + self.assertEquals('slowdive', self.ts.get_username()) + self.assertEquals(True, self.ts.is_authenticated()) + + stdin, stdout, stderr = self.tc.exec_command('yes') + schan = self.ts.accept(1.0) + + schan.send('Hello there.\n') + schan.send_stderr('This is on stderr.\n') + schan.close() + + self.assertEquals('Hello there.\n', stdout.readline()) + self.assertEquals('', stdout.readline()) + self.assertEquals('This is on stderr.\n', stderr.readline()) + self.assertEquals('', stderr.readline()) + + stdin.close() + stdout.close() + stderr.close() + + def test_2_auto_add_policy(self): + """ + verify that SSHClient's AutoAddPolicy works. + """ + host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') + public_host_key = paramiko.RSAKey(data=str(host_key)) + + self.tc = paramiko.SSHClient() + self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + self.assertEquals(0, len(self.tc.get_host_keys())) + self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion') + + self.event.wait(1.0) + self.assert_(self.event.isSet()) + self.assert_(self.ts.is_active()) + self.assertEquals('slowdive', self.ts.get_username()) + self.assertEquals(True, self.ts.is_authenticated()) + self.assertEquals(1, len(self.tc.get_host_keys())) + self.assertEquals(public_host_key, self.tc.get_host_keys()[self.addr]['ssh-rsa'])