add unit tests for SSHClient, and fix a few bugs that uncovered
This commit is contained in:
parent
de1e072c73
commit
2a03425e27
|
@ -61,10 +61,7 @@ class AutoAddPolicy (MissingHostKeyPolicy):
|
|||
"""
|
||||
|
||||
def missing_host_key(self, client, hostname, key):
|
||||
if not client._host_keys.has_key(hostname):
|
||||
client._host_keys[hostname] = {}
|
||||
client._host_keys[hostname][key.get_name()] = key
|
||||
our_server_key = server_key
|
||||
client._host_keys.add(hostname, key.get_name(), key)
|
||||
if client._host_keys_filename is not None:
|
||||
client.save_host_keys(client._host_keys_filename)
|
||||
client._log(DEBUG, 'Adding %s host key for %s: %s' %
|
||||
|
@ -97,6 +94,8 @@ class SSHClient (object):
|
|||
You may pass in explicit overrides for authentication and server host key
|
||||
checking. The default mechanism is to try to use local key files or an
|
||||
SSH agent (if one is running).
|
||||
|
||||
@since: 1.6
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
|
@ -177,7 +176,24 @@ class SSHClient (object):
|
|||
f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
|
||||
f.close()
|
||||
|
||||
def get_host_keys(self):
|
||||
"""
|
||||
Get the local L{HostKeys} object. This can be used to examine the
|
||||
local host keys or change them.
|
||||
|
||||
@return: the local host keys
|
||||
@rtype: L{HostKeys}
|
||||
"""
|
||||
return self._host_keys
|
||||
|
||||
def set_log_channel(self, channel):
|
||||
"""
|
||||
Set the channel for logging. The default is C{"paramiko.transport"}
|
||||
but it can be set to anything you want.
|
||||
|
||||
@param name: new channel name for logging
|
||||
@type name: str
|
||||
"""
|
||||
self._log_channel = channel
|
||||
|
||||
def set_missing_host_key_policy(self, policy):
|
||||
|
@ -232,7 +248,7 @@ class SSHClient (object):
|
|||
@raise SSHException: if there was an error authenticating or verifying
|
||||
the server's host key
|
||||
"""
|
||||
t = Transport((hostname, port))
|
||||
t = self._transport = Transport((hostname, port))
|
||||
if self._log_channel is not None:
|
||||
t.set_log_channel(self._log_channel)
|
||||
t.start_client()
|
||||
|
@ -247,6 +263,8 @@ class SSHClient (object):
|
|||
if our_server_key is None:
|
||||
# will raise exception if the key is rejected; let that fall out
|
||||
self._policy.missing_host_key(self, hostname, server_key)
|
||||
# if this continues, assume the key is ok
|
||||
our_server_key = server_key
|
||||
|
||||
our_server_key_hex = hexify(our_server_key.get_fingerprint())
|
||||
|
||||
|
@ -254,7 +272,6 @@ class SSHClient (object):
|
|||
raise SSHException('Host key for server %s does not match! (%s != %s)' %
|
||||
(hostname, our_server_key_kex, server_key_hex))
|
||||
|
||||
self._transport = t
|
||||
if username is None:
|
||||
username = getpass.getuser()
|
||||
self._auth(username, password, pkey, key_filename)
|
||||
|
@ -281,20 +298,42 @@ class SSHClient (object):
|
|||
@raise SSHException: if the server fails to execute the command
|
||||
"""
|
||||
chan = self._transport.open_session()
|
||||
if not chan.exec_command(command):
|
||||
raise SSHException('Command execution failed.')
|
||||
chan.exec_command(command)
|
||||
stdin = chan.makefile('wb')
|
||||
stdout = chan.makefile('rb')
|
||||
stderr = chan.makefile_stderr('rb')
|
||||
return stdin, stdout, stderr
|
||||
|
||||
def invoke_shell(self):
|
||||
pass
|
||||
#FIXME
|
||||
def invoke_shell(self, term='vt100', width=80, height=24):
|
||||
"""
|
||||
Start an interactive shell session on the SSH server. A new L{Channel}
|
||||
is opened and connected to a pseudo-terminal using the requested
|
||||
terminal type and size.
|
||||
|
||||
@param term: the terminal type to emulate (for example, C{"vt100"})
|
||||
@type term: str
|
||||
@param width: the width (in characters) of the terminal window
|
||||
@type width: int
|
||||
@param height: the height (in characters) of the terminal window
|
||||
@type height: int
|
||||
@return: a new channel connected to the remote shell
|
||||
@rtype: L{Channel}
|
||||
|
||||
@raise SSHException: if the server fails to invoke a shell
|
||||
"""
|
||||
chan = self._transport.open_session()
|
||||
chan.get_pty(term, width, height)
|
||||
chan.invoke_shell()
|
||||
return chan
|
||||
|
||||
def open_sftp(self):
|
||||
pass
|
||||
# FIXME
|
||||
"""
|
||||
Open an SFTP session on the SSH server.
|
||||
|
||||
@return: a new SFTP session object
|
||||
@rtype: L{SFTPClient}
|
||||
"""
|
||||
return self._transport.open_sftp_client()
|
||||
|
||||
def _auth(self, username, password, pkey, key_filename):
|
||||
"""
|
||||
|
@ -318,7 +357,7 @@ class SSHClient (object):
|
|||
saved_exception = e
|
||||
|
||||
if key_filename is not None:
|
||||
for pkey_class in (paramiko.RSAKey, paramiko.DSSKey):
|
||||
for pkey_class in (RSAKey, DSSKey):
|
||||
try:
|
||||
key = pkey_class.from_private_key_file(key_filename, password)
|
||||
self._log(DEBUG, 'Trying key %s from %s' % (hexify(key.get_fingerprint()), key_filename))
|
||||
|
@ -335,8 +374,8 @@ class SSHClient (object):
|
|||
except SSHException, e:
|
||||
saved_exception = e
|
||||
|
||||
for pkey_class, filename in ((paramiko.RSAKey, 'id_rsa'),
|
||||
(paramiko.DSSKey, 'id_dsa')):
|
||||
for pkey_class, filename in ((RSAKey, 'id_rsa'),
|
||||
(DSSKey, 'id_dsa')):
|
||||
filename = os.path.expanduser('~/.ssh/' + filename)
|
||||
try:
|
||||
key = pkey_class.from_private_key_file(filename, password)
|
||||
|
@ -345,10 +384,12 @@ class SSHClient (object):
|
|||
return
|
||||
except SSHException, e:
|
||||
saved_exception = e
|
||||
except IOError, e:
|
||||
saved_exception = e
|
||||
|
||||
if password is not None:
|
||||
try:
|
||||
transport.auth_password(username, password)
|
||||
self._transport.auth_password(username, password)
|
||||
return
|
||||
except SSHException, e:
|
||||
saved_exception = e
|
||||
|
|
2
test.py
2
test.py
|
@ -39,6 +39,7 @@ from test_packetizer import PacketizerTest
|
|||
from test_transport import TransportTest
|
||||
from test_sftp import SFTPTest
|
||||
from test_sftp_big import BigSFTPTest
|
||||
from test_client import SSHClientTest
|
||||
|
||||
default_host = 'localhost'
|
||||
default_user = os.environ.get('USER', 'nobody')
|
||||
|
@ -100,6 +101,7 @@ suite.addTest(unittest.makeSuite(KexTest))
|
|||
suite.addTest(unittest.makeSuite(PacketizerTest))
|
||||
if options.use_transport:
|
||||
suite.addTest(unittest.makeSuite(TransportTest))
|
||||
suite.addTest(unittest.makeSuite(SSHClientTest))
|
||||
if options.use_sftp:
|
||||
suite.addTest(unittest.makeSuite(SFTPTest))
|
||||
if options.use_big_file:
|
||||
|
|
|
@ -0,0 +1,127 @@
|
|||
# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
|
||||
#
|
||||
# This file is part of paramiko.
|
||||
#
|
||||
# Paramiko is free software; you can redistribute it and/or modify it under the
|
||||
# terms of the GNU Lesser General Public License as published by the Free
|
||||
# Software Foundation; either version 2.1 of the License, or (at your option)
|
||||
# any later version.
|
||||
#
|
||||
# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY
|
||||
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
|
||||
# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
|
||||
# details.
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
|
||||
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
|
||||
|
||||
"""
|
||||
Some unit tests for SSHClient.
|
||||
"""
|
||||
|
||||
import socket
|
||||
import threading
|
||||
import unittest
|
||||
import paramiko
|
||||
|
||||
|
||||
class NullServer (paramiko.ServerInterface):
|
||||
|
||||
def get_allowed_auths(self, username):
|
||||
if username == 'slowdive':
|
||||
return 'publickey,password'
|
||||
return 'publickey'
|
||||
|
||||
def check_auth_password(self, username, password):
|
||||
if (username == 'slowdive') and (password == 'pygmalion'):
|
||||
return paramiko.AUTH_SUCCESSFUL
|
||||
return paramiko.AUTH_FAILED
|
||||
|
||||
def check_channel_request(self, kind, chanid):
|
||||
return paramiko.OPEN_SUCCEEDED
|
||||
|
||||
def check_channel_exec_request(self, channel, command):
|
||||
if command != 'yes':
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class SSHClientTest (unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.sockl = socket.socket()
|
||||
self.sockl.bind(('localhost', 0))
|
||||
self.sockl.listen(1)
|
||||
self.addr, self.port = self.sockl.getsockname()
|
||||
self.event = threading.Event()
|
||||
thread = threading.Thread(target=self._run)
|
||||
thread.start()
|
||||
|
||||
def tearDown(self):
|
||||
self.tc.close()
|
||||
self.ts.close()
|
||||
self.socks.close()
|
||||
self.sockl.close()
|
||||
|
||||
def _run(self):
|
||||
self.socks, addr = self.sockl.accept()
|
||||
self.ts = paramiko.Transport(self.socks)
|
||||
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
|
||||
self.ts.add_server_key(host_key)
|
||||
server = NullServer()
|
||||
self.ts.start_server(self.event, server)
|
||||
|
||||
|
||||
def test_1_client(self):
|
||||
"""
|
||||
verify that the SSHClient stuff works too.
|
||||
"""
|
||||
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
|
||||
public_host_key = paramiko.RSAKey(data=str(host_key))
|
||||
|
||||
self.tc = paramiko.SSHClient()
|
||||
self.tc.get_host_keys().add(self.addr, 'ssh-rsa', public_host_key)
|
||||
self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')
|
||||
|
||||
self.event.wait(1.0)
|
||||
self.assert_(self.event.isSet())
|
||||
self.assert_(self.ts.is_active())
|
||||
self.assertEquals('slowdive', self.ts.get_username())
|
||||
self.assertEquals(True, self.ts.is_authenticated())
|
||||
|
||||
stdin, stdout, stderr = self.tc.exec_command('yes')
|
||||
schan = self.ts.accept(1.0)
|
||||
|
||||
schan.send('Hello there.\n')
|
||||
schan.send_stderr('This is on stderr.\n')
|
||||
schan.close()
|
||||
|
||||
self.assertEquals('Hello there.\n', stdout.readline())
|
||||
self.assertEquals('', stdout.readline())
|
||||
self.assertEquals('This is on stderr.\n', stderr.readline())
|
||||
self.assertEquals('', stderr.readline())
|
||||
|
||||
stdin.close()
|
||||
stdout.close()
|
||||
stderr.close()
|
||||
|
||||
def test_2_auto_add_policy(self):
|
||||
"""
|
||||
verify that SSHClient's AutoAddPolicy works.
|
||||
"""
|
||||
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
|
||||
public_host_key = paramiko.RSAKey(data=str(host_key))
|
||||
|
||||
self.tc = paramiko.SSHClient()
|
||||
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
self.assertEquals(0, len(self.tc.get_host_keys()))
|
||||
self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')
|
||||
|
||||
self.event.wait(1.0)
|
||||
self.assert_(self.event.isSet())
|
||||
self.assert_(self.ts.is_active())
|
||||
self.assertEquals('slowdive', self.ts.get_username())
|
||||
self.assertEquals(True, self.ts.is_authenticated())
|
||||
self.assertEquals(1, len(self.tc.get_host_keys()))
|
||||
self.assertEquals(public_host_key, self.tc.get_host_keys()[self.addr]['ssh-rsa'])
|
Loading…
Reference in New Issue