add unit tests for SSHClient, and fix a few bugs that uncovered
This commit is contained in:
Robey Pointer 2006-05-07 17:20:07 -07:00
parent de1e072c73
commit 2a03425e27
3 changed files with 187 additions and 17 deletions

View File

@ -61,10 +61,7 @@ class AutoAddPolicy (MissingHostKeyPolicy):
""" """
def missing_host_key(self, client, hostname, key): def missing_host_key(self, client, hostname, key):
if not client._host_keys.has_key(hostname): client._host_keys.add(hostname, key.get_name(), key)
client._host_keys[hostname] = {}
client._host_keys[hostname][key.get_name()] = key
our_server_key = server_key
if client._host_keys_filename is not None: if client._host_keys_filename is not None:
client.save_host_keys(client._host_keys_filename) client.save_host_keys(client._host_keys_filename)
client._log(DEBUG, 'Adding %s host key for %s: %s' % client._log(DEBUG, 'Adding %s host key for %s: %s' %
@ -97,6 +94,8 @@ class SSHClient (object):
You may pass in explicit overrides for authentication and server host key You may pass in explicit overrides for authentication and server host key
checking. The default mechanism is to try to use local key files or an checking. The default mechanism is to try to use local key files or an
SSH agent (if one is running). SSH agent (if one is running).
@since: 1.6
""" """
def __init__(self): def __init__(self):
@ -177,7 +176,24 @@ class SSHClient (object):
f.write('%s %s %s\n' % (hostname, keytype, key.get_base64())) f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
f.close() f.close()
def get_host_keys(self):
"""
Get the local L{HostKeys} object. This can be used to examine the
local host keys or change them.
@return: the local host keys
@rtype: L{HostKeys}
"""
return self._host_keys
def set_log_channel(self, channel): def set_log_channel(self, channel):
"""
Set the channel for logging. The default is C{"paramiko.transport"}
but it can be set to anything you want.
@param name: new channel name for logging
@type name: str
"""
self._log_channel = channel self._log_channel = channel
def set_missing_host_key_policy(self, policy): def set_missing_host_key_policy(self, policy):
@ -232,7 +248,7 @@ class SSHClient (object):
@raise SSHException: if there was an error authenticating or verifying @raise SSHException: if there was an error authenticating or verifying
the server's host key the server's host key
""" """
t = Transport((hostname, port)) t = self._transport = Transport((hostname, port))
if self._log_channel is not None: if self._log_channel is not None:
t.set_log_channel(self._log_channel) t.set_log_channel(self._log_channel)
t.start_client() t.start_client()
@ -247,6 +263,8 @@ class SSHClient (object):
if our_server_key is None: if our_server_key is None:
# will raise exception if the key is rejected; let that fall out # will raise exception if the key is rejected; let that fall out
self._policy.missing_host_key(self, hostname, server_key) self._policy.missing_host_key(self, hostname, server_key)
# if this continues, assume the key is ok
our_server_key = server_key
our_server_key_hex = hexify(our_server_key.get_fingerprint()) our_server_key_hex = hexify(our_server_key.get_fingerprint())
@ -254,7 +272,6 @@ class SSHClient (object):
raise SSHException('Host key for server %s does not match! (%s != %s)' % raise SSHException('Host key for server %s does not match! (%s != %s)' %
(hostname, our_server_key_kex, server_key_hex)) (hostname, our_server_key_kex, server_key_hex))
self._transport = t
if username is None: if username is None:
username = getpass.getuser() username = getpass.getuser()
self._auth(username, password, pkey, key_filename) self._auth(username, password, pkey, key_filename)
@ -281,20 +298,42 @@ class SSHClient (object):
@raise SSHException: if the server fails to execute the command @raise SSHException: if the server fails to execute the command
""" """
chan = self._transport.open_session() chan = self._transport.open_session()
if not chan.exec_command(command): chan.exec_command(command)
raise SSHException('Command execution failed.')
stdin = chan.makefile('wb') stdin = chan.makefile('wb')
stdout = chan.makefile('rb') stdout = chan.makefile('rb')
stderr = chan.makefile_stderr('rb') stderr = chan.makefile_stderr('rb')
return stdin, stdout, stderr return stdin, stdout, stderr
def invoke_shell(self): def invoke_shell(self, term='vt100', width=80, height=24):
pass """
#FIXME Start an interactive shell session on the SSH server. A new L{Channel}
is opened and connected to a pseudo-terminal using the requested
terminal type and size.
@param term: the terminal type to emulate (for example, C{"vt100"})
@type term: str
@param width: the width (in characters) of the terminal window
@type width: int
@param height: the height (in characters) of the terminal window
@type height: int
@return: a new channel connected to the remote shell
@rtype: L{Channel}
@raise SSHException: if the server fails to invoke a shell
"""
chan = self._transport.open_session()
chan.get_pty(term, width, height)
chan.invoke_shell()
return chan
def open_sftp(self): def open_sftp(self):
pass """
# FIXME Open an SFTP session on the SSH server.
@return: a new SFTP session object
@rtype: L{SFTPClient}
"""
return self._transport.open_sftp_client()
def _auth(self, username, password, pkey, key_filename): def _auth(self, username, password, pkey, key_filename):
""" """
@ -318,7 +357,7 @@ class SSHClient (object):
saved_exception = e saved_exception = e
if key_filename is not None: if key_filename is not None:
for pkey_class in (paramiko.RSAKey, paramiko.DSSKey): for pkey_class in (RSAKey, DSSKey):
try: try:
key = pkey_class.from_private_key_file(key_filename, password) key = pkey_class.from_private_key_file(key_filename, password)
self._log(DEBUG, 'Trying key %s from %s' % (hexify(key.get_fingerprint()), key_filename)) self._log(DEBUG, 'Trying key %s from %s' % (hexify(key.get_fingerprint()), key_filename))
@ -335,8 +374,8 @@ class SSHClient (object):
except SSHException, e: except SSHException, e:
saved_exception = e saved_exception = e
for pkey_class, filename in ((paramiko.RSAKey, 'id_rsa'), for pkey_class, filename in ((RSAKey, 'id_rsa'),
(paramiko.DSSKey, 'id_dsa')): (DSSKey, 'id_dsa')):
filename = os.path.expanduser('~/.ssh/' + filename) filename = os.path.expanduser('~/.ssh/' + filename)
try: try:
key = pkey_class.from_private_key_file(filename, password) key = pkey_class.from_private_key_file(filename, password)
@ -345,10 +384,12 @@ class SSHClient (object):
return return
except SSHException, e: except SSHException, e:
saved_exception = e saved_exception = e
except IOError, e:
saved_exception = e
if password is not None: if password is not None:
try: try:
transport.auth_password(username, password) self._transport.auth_password(username, password)
return return
except SSHException, e: except SSHException, e:
saved_exception = e saved_exception = e

View File

@ -39,6 +39,7 @@ from test_packetizer import PacketizerTest
from test_transport import TransportTest from test_transport import TransportTest
from test_sftp import SFTPTest from test_sftp import SFTPTest
from test_sftp_big import BigSFTPTest from test_sftp_big import BigSFTPTest
from test_client import SSHClientTest
default_host = 'localhost' default_host = 'localhost'
default_user = os.environ.get('USER', 'nobody') default_user = os.environ.get('USER', 'nobody')
@ -100,6 +101,7 @@ suite.addTest(unittest.makeSuite(KexTest))
suite.addTest(unittest.makeSuite(PacketizerTest)) suite.addTest(unittest.makeSuite(PacketizerTest))
if options.use_transport: if options.use_transport:
suite.addTest(unittest.makeSuite(TransportTest)) suite.addTest(unittest.makeSuite(TransportTest))
suite.addTest(unittest.makeSuite(SSHClientTest))
if options.use_sftp: if options.use_sftp:
suite.addTest(unittest.makeSuite(SFTPTest)) suite.addTest(unittest.makeSuite(SFTPTest))
if options.use_big_file: if options.use_big_file:

127
tests/test_client.py Normal file
View File

@ -0,0 +1,127 @@
# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
# Paramiko is free software; you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 2.1 of the License, or (at your option)
# any later version.
#
# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
"""
Some unit tests for SSHClient.
"""
import socket
import threading
import unittest
import paramiko
class NullServer (paramiko.ServerInterface):
def get_allowed_auths(self, username):
if username == 'slowdive':
return 'publickey,password'
return 'publickey'
def check_auth_password(self, username, password):
if (username == 'slowdive') and (password == 'pygmalion'):
return paramiko.AUTH_SUCCESSFUL
return paramiko.AUTH_FAILED
def check_channel_request(self, kind, chanid):
return paramiko.OPEN_SUCCEEDED
def check_channel_exec_request(self, channel, command):
if command != 'yes':
return False
return True
class SSHClientTest (unittest.TestCase):
def setUp(self):
self.sockl = socket.socket()
self.sockl.bind(('localhost', 0))
self.sockl.listen(1)
self.addr, self.port = self.sockl.getsockname()
self.event = threading.Event()
thread = threading.Thread(target=self._run)
thread.start()
def tearDown(self):
self.tc.close()
self.ts.close()
self.socks.close()
self.sockl.close()
def _run(self):
self.socks, addr = self.sockl.accept()
self.ts = paramiko.Transport(self.socks)
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
self.ts.add_server_key(host_key)
server = NullServer()
self.ts.start_server(self.event, server)
def test_1_client(self):
"""
verify that the SSHClient stuff works too.
"""
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = paramiko.RSAKey(data=str(host_key))
self.tc = paramiko.SSHClient()
self.tc.get_host_keys().add(self.addr, 'ssh-rsa', public_host_key)
self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')
self.event.wait(1.0)
self.assert_(self.event.isSet())
self.assert_(self.ts.is_active())
self.assertEquals('slowdive', self.ts.get_username())
self.assertEquals(True, self.ts.is_authenticated())
stdin, stdout, stderr = self.tc.exec_command('yes')
schan = self.ts.accept(1.0)
schan.send('Hello there.\n')
schan.send_stderr('This is on stderr.\n')
schan.close()
self.assertEquals('Hello there.\n', stdout.readline())
self.assertEquals('', stdout.readline())
self.assertEquals('This is on stderr.\n', stderr.readline())
self.assertEquals('', stderr.readline())
stdin.close()
stdout.close()
stderr.close()
def test_2_auto_add_policy(self):
"""
verify that SSHClient's AutoAddPolicy works.
"""
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = paramiko.RSAKey(data=str(host_key))
self.tc = paramiko.SSHClient()
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.assertEquals(0, len(self.tc.get_host_keys()))
self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')
self.event.wait(1.0)
self.assert_(self.event.isSet())
self.assert_(self.ts.is_active())
self.assertEquals('slowdive', self.ts.get_username())
self.assertEquals(True, self.ts.is_authenticated())
self.assertEquals(1, len(self.tc.get_host_keys()))
self.assertEquals(public_host_key, self.tc.get_host_keys()[self.addr]['ssh-rsa'])