add unit tests for SSHClient, and fix a few bugs that uncovered
This commit is contained in:
parent
de1e072c73
commit
2a03425e27
|
@ -61,10 +61,7 @@ class AutoAddPolicy (MissingHostKeyPolicy):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def missing_host_key(self, client, hostname, key):
|
def missing_host_key(self, client, hostname, key):
|
||||||
if not client._host_keys.has_key(hostname):
|
client._host_keys.add(hostname, key.get_name(), key)
|
||||||
client._host_keys[hostname] = {}
|
|
||||||
client._host_keys[hostname][key.get_name()] = key
|
|
||||||
our_server_key = server_key
|
|
||||||
if client._host_keys_filename is not None:
|
if client._host_keys_filename is not None:
|
||||||
client.save_host_keys(client._host_keys_filename)
|
client.save_host_keys(client._host_keys_filename)
|
||||||
client._log(DEBUG, 'Adding %s host key for %s: %s' %
|
client._log(DEBUG, 'Adding %s host key for %s: %s' %
|
||||||
|
@ -97,6 +94,8 @@ class SSHClient (object):
|
||||||
You may pass in explicit overrides for authentication and server host key
|
You may pass in explicit overrides for authentication and server host key
|
||||||
checking. The default mechanism is to try to use local key files or an
|
checking. The default mechanism is to try to use local key files or an
|
||||||
SSH agent (if one is running).
|
SSH agent (if one is running).
|
||||||
|
|
||||||
|
@since: 1.6
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -177,7 +176,24 @@ class SSHClient (object):
|
||||||
f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
|
f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
|
def get_host_keys(self):
|
||||||
|
"""
|
||||||
|
Get the local L{HostKeys} object. This can be used to examine the
|
||||||
|
local host keys or change them.
|
||||||
|
|
||||||
|
@return: the local host keys
|
||||||
|
@rtype: L{HostKeys}
|
||||||
|
"""
|
||||||
|
return self._host_keys
|
||||||
|
|
||||||
def set_log_channel(self, channel):
|
def set_log_channel(self, channel):
|
||||||
|
"""
|
||||||
|
Set the channel for logging. The default is C{"paramiko.transport"}
|
||||||
|
but it can be set to anything you want.
|
||||||
|
|
||||||
|
@param name: new channel name for logging
|
||||||
|
@type name: str
|
||||||
|
"""
|
||||||
self._log_channel = channel
|
self._log_channel = channel
|
||||||
|
|
||||||
def set_missing_host_key_policy(self, policy):
|
def set_missing_host_key_policy(self, policy):
|
||||||
|
@ -232,7 +248,7 @@ class SSHClient (object):
|
||||||
@raise SSHException: if there was an error authenticating or verifying
|
@raise SSHException: if there was an error authenticating or verifying
|
||||||
the server's host key
|
the server's host key
|
||||||
"""
|
"""
|
||||||
t = Transport((hostname, port))
|
t = self._transport = Transport((hostname, port))
|
||||||
if self._log_channel is not None:
|
if self._log_channel is not None:
|
||||||
t.set_log_channel(self._log_channel)
|
t.set_log_channel(self._log_channel)
|
||||||
t.start_client()
|
t.start_client()
|
||||||
|
@ -247,6 +263,8 @@ class SSHClient (object):
|
||||||
if our_server_key is None:
|
if our_server_key is None:
|
||||||
# will raise exception if the key is rejected; let that fall out
|
# will raise exception if the key is rejected; let that fall out
|
||||||
self._policy.missing_host_key(self, hostname, server_key)
|
self._policy.missing_host_key(self, hostname, server_key)
|
||||||
|
# if this continues, assume the key is ok
|
||||||
|
our_server_key = server_key
|
||||||
|
|
||||||
our_server_key_hex = hexify(our_server_key.get_fingerprint())
|
our_server_key_hex = hexify(our_server_key.get_fingerprint())
|
||||||
|
|
||||||
|
@ -254,7 +272,6 @@ class SSHClient (object):
|
||||||
raise SSHException('Host key for server %s does not match! (%s != %s)' %
|
raise SSHException('Host key for server %s does not match! (%s != %s)' %
|
||||||
(hostname, our_server_key_kex, server_key_hex))
|
(hostname, our_server_key_kex, server_key_hex))
|
||||||
|
|
||||||
self._transport = t
|
|
||||||
if username is None:
|
if username is None:
|
||||||
username = getpass.getuser()
|
username = getpass.getuser()
|
||||||
self._auth(username, password, pkey, key_filename)
|
self._auth(username, password, pkey, key_filename)
|
||||||
|
@ -281,20 +298,42 @@ class SSHClient (object):
|
||||||
@raise SSHException: if the server fails to execute the command
|
@raise SSHException: if the server fails to execute the command
|
||||||
"""
|
"""
|
||||||
chan = self._transport.open_session()
|
chan = self._transport.open_session()
|
||||||
if not chan.exec_command(command):
|
chan.exec_command(command)
|
||||||
raise SSHException('Command execution failed.')
|
|
||||||
stdin = chan.makefile('wb')
|
stdin = chan.makefile('wb')
|
||||||
stdout = chan.makefile('rb')
|
stdout = chan.makefile('rb')
|
||||||
stderr = chan.makefile_stderr('rb')
|
stderr = chan.makefile_stderr('rb')
|
||||||
return stdin, stdout, stderr
|
return stdin, stdout, stderr
|
||||||
|
|
||||||
def invoke_shell(self):
|
def invoke_shell(self, term='vt100', width=80, height=24):
|
||||||
pass
|
"""
|
||||||
#FIXME
|
Start an interactive shell session on the SSH server. A new L{Channel}
|
||||||
|
is opened and connected to a pseudo-terminal using the requested
|
||||||
|
terminal type and size.
|
||||||
|
|
||||||
|
@param term: the terminal type to emulate (for example, C{"vt100"})
|
||||||
|
@type term: str
|
||||||
|
@param width: the width (in characters) of the terminal window
|
||||||
|
@type width: int
|
||||||
|
@param height: the height (in characters) of the terminal window
|
||||||
|
@type height: int
|
||||||
|
@return: a new channel connected to the remote shell
|
||||||
|
@rtype: L{Channel}
|
||||||
|
|
||||||
|
@raise SSHException: if the server fails to invoke a shell
|
||||||
|
"""
|
||||||
|
chan = self._transport.open_session()
|
||||||
|
chan.get_pty(term, width, height)
|
||||||
|
chan.invoke_shell()
|
||||||
|
return chan
|
||||||
|
|
||||||
def open_sftp(self):
|
def open_sftp(self):
|
||||||
pass
|
"""
|
||||||
# FIXME
|
Open an SFTP session on the SSH server.
|
||||||
|
|
||||||
|
@return: a new SFTP session object
|
||||||
|
@rtype: L{SFTPClient}
|
||||||
|
"""
|
||||||
|
return self._transport.open_sftp_client()
|
||||||
|
|
||||||
def _auth(self, username, password, pkey, key_filename):
|
def _auth(self, username, password, pkey, key_filename):
|
||||||
"""
|
"""
|
||||||
|
@ -318,7 +357,7 @@ class SSHClient (object):
|
||||||
saved_exception = e
|
saved_exception = e
|
||||||
|
|
||||||
if key_filename is not None:
|
if key_filename is not None:
|
||||||
for pkey_class in (paramiko.RSAKey, paramiko.DSSKey):
|
for pkey_class in (RSAKey, DSSKey):
|
||||||
try:
|
try:
|
||||||
key = pkey_class.from_private_key_file(key_filename, password)
|
key = pkey_class.from_private_key_file(key_filename, password)
|
||||||
self._log(DEBUG, 'Trying key %s from %s' % (hexify(key.get_fingerprint()), key_filename))
|
self._log(DEBUG, 'Trying key %s from %s' % (hexify(key.get_fingerprint()), key_filename))
|
||||||
|
@ -335,8 +374,8 @@ class SSHClient (object):
|
||||||
except SSHException, e:
|
except SSHException, e:
|
||||||
saved_exception = e
|
saved_exception = e
|
||||||
|
|
||||||
for pkey_class, filename in ((paramiko.RSAKey, 'id_rsa'),
|
for pkey_class, filename in ((RSAKey, 'id_rsa'),
|
||||||
(paramiko.DSSKey, 'id_dsa')):
|
(DSSKey, 'id_dsa')):
|
||||||
filename = os.path.expanduser('~/.ssh/' + filename)
|
filename = os.path.expanduser('~/.ssh/' + filename)
|
||||||
try:
|
try:
|
||||||
key = pkey_class.from_private_key_file(filename, password)
|
key = pkey_class.from_private_key_file(filename, password)
|
||||||
|
@ -345,10 +384,12 @@ class SSHClient (object):
|
||||||
return
|
return
|
||||||
except SSHException, e:
|
except SSHException, e:
|
||||||
saved_exception = e
|
saved_exception = e
|
||||||
|
except IOError, e:
|
||||||
|
saved_exception = e
|
||||||
|
|
||||||
if password is not None:
|
if password is not None:
|
||||||
try:
|
try:
|
||||||
transport.auth_password(username, password)
|
self._transport.auth_password(username, password)
|
||||||
return
|
return
|
||||||
except SSHException, e:
|
except SSHException, e:
|
||||||
saved_exception = e
|
saved_exception = e
|
||||||
|
|
2
test.py
2
test.py
|
@ -39,6 +39,7 @@ from test_packetizer import PacketizerTest
|
||||||
from test_transport import TransportTest
|
from test_transport import TransportTest
|
||||||
from test_sftp import SFTPTest
|
from test_sftp import SFTPTest
|
||||||
from test_sftp_big import BigSFTPTest
|
from test_sftp_big import BigSFTPTest
|
||||||
|
from test_client import SSHClientTest
|
||||||
|
|
||||||
default_host = 'localhost'
|
default_host = 'localhost'
|
||||||
default_user = os.environ.get('USER', 'nobody')
|
default_user = os.environ.get('USER', 'nobody')
|
||||||
|
@ -100,6 +101,7 @@ suite.addTest(unittest.makeSuite(KexTest))
|
||||||
suite.addTest(unittest.makeSuite(PacketizerTest))
|
suite.addTest(unittest.makeSuite(PacketizerTest))
|
||||||
if options.use_transport:
|
if options.use_transport:
|
||||||
suite.addTest(unittest.makeSuite(TransportTest))
|
suite.addTest(unittest.makeSuite(TransportTest))
|
||||||
|
suite.addTest(unittest.makeSuite(SSHClientTest))
|
||||||
if options.use_sftp:
|
if options.use_sftp:
|
||||||
suite.addTest(unittest.makeSuite(SFTPTest))
|
suite.addTest(unittest.makeSuite(SFTPTest))
|
||||||
if options.use_big_file:
|
if options.use_big_file:
|
||||||
|
|
|
@ -0,0 +1,127 @@
|
||||||
|
# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
|
||||||
|
#
|
||||||
|
# This file is part of paramiko.
|
||||||
|
#
|
||||||
|
# Paramiko is free software; you can redistribute it and/or modify it under the
|
||||||
|
# terms of the GNU Lesser General Public License as published by the Free
|
||||||
|
# Software Foundation; either version 2.1 of the License, or (at your option)
|
||||||
|
# any later version.
|
||||||
|
#
|
||||||
|
# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY
|
||||||
|
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
|
||||||
|
# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
|
||||||
|
# details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
|
# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
|
||||||
|
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Some unit tests for SSHClient.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import socket
|
||||||
|
import threading
|
||||||
|
import unittest
|
||||||
|
import paramiko
|
||||||
|
|
||||||
|
|
||||||
|
class NullServer (paramiko.ServerInterface):
|
||||||
|
|
||||||
|
def get_allowed_auths(self, username):
|
||||||
|
if username == 'slowdive':
|
||||||
|
return 'publickey,password'
|
||||||
|
return 'publickey'
|
||||||
|
|
||||||
|
def check_auth_password(self, username, password):
|
||||||
|
if (username == 'slowdive') and (password == 'pygmalion'):
|
||||||
|
return paramiko.AUTH_SUCCESSFUL
|
||||||
|
return paramiko.AUTH_FAILED
|
||||||
|
|
||||||
|
def check_channel_request(self, kind, chanid):
|
||||||
|
return paramiko.OPEN_SUCCEEDED
|
||||||
|
|
||||||
|
def check_channel_exec_request(self, channel, command):
|
||||||
|
if command != 'yes':
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class SSHClientTest (unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.sockl = socket.socket()
|
||||||
|
self.sockl.bind(('localhost', 0))
|
||||||
|
self.sockl.listen(1)
|
||||||
|
self.addr, self.port = self.sockl.getsockname()
|
||||||
|
self.event = threading.Event()
|
||||||
|
thread = threading.Thread(target=self._run)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.tc.close()
|
||||||
|
self.ts.close()
|
||||||
|
self.socks.close()
|
||||||
|
self.sockl.close()
|
||||||
|
|
||||||
|
def _run(self):
|
||||||
|
self.socks, addr = self.sockl.accept()
|
||||||
|
self.ts = paramiko.Transport(self.socks)
|
||||||
|
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
|
||||||
|
self.ts.add_server_key(host_key)
|
||||||
|
server = NullServer()
|
||||||
|
self.ts.start_server(self.event, server)
|
||||||
|
|
||||||
|
|
||||||
|
def test_1_client(self):
|
||||||
|
"""
|
||||||
|
verify that the SSHClient stuff works too.
|
||||||
|
"""
|
||||||
|
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
|
||||||
|
public_host_key = paramiko.RSAKey(data=str(host_key))
|
||||||
|
|
||||||
|
self.tc = paramiko.SSHClient()
|
||||||
|
self.tc.get_host_keys().add(self.addr, 'ssh-rsa', public_host_key)
|
||||||
|
self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')
|
||||||
|
|
||||||
|
self.event.wait(1.0)
|
||||||
|
self.assert_(self.event.isSet())
|
||||||
|
self.assert_(self.ts.is_active())
|
||||||
|
self.assertEquals('slowdive', self.ts.get_username())
|
||||||
|
self.assertEquals(True, self.ts.is_authenticated())
|
||||||
|
|
||||||
|
stdin, stdout, stderr = self.tc.exec_command('yes')
|
||||||
|
schan = self.ts.accept(1.0)
|
||||||
|
|
||||||
|
schan.send('Hello there.\n')
|
||||||
|
schan.send_stderr('This is on stderr.\n')
|
||||||
|
schan.close()
|
||||||
|
|
||||||
|
self.assertEquals('Hello there.\n', stdout.readline())
|
||||||
|
self.assertEquals('', stdout.readline())
|
||||||
|
self.assertEquals('This is on stderr.\n', stderr.readline())
|
||||||
|
self.assertEquals('', stderr.readline())
|
||||||
|
|
||||||
|
stdin.close()
|
||||||
|
stdout.close()
|
||||||
|
stderr.close()
|
||||||
|
|
||||||
|
def test_2_auto_add_policy(self):
|
||||||
|
"""
|
||||||
|
verify that SSHClient's AutoAddPolicy works.
|
||||||
|
"""
|
||||||
|
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
|
||||||
|
public_host_key = paramiko.RSAKey(data=str(host_key))
|
||||||
|
|
||||||
|
self.tc = paramiko.SSHClient()
|
||||||
|
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||||
|
self.assertEquals(0, len(self.tc.get_host_keys()))
|
||||||
|
self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')
|
||||||
|
|
||||||
|
self.event.wait(1.0)
|
||||||
|
self.assert_(self.event.isSet())
|
||||||
|
self.assert_(self.ts.is_active())
|
||||||
|
self.assertEquals('slowdive', self.ts.get_username())
|
||||||
|
self.assertEquals(True, self.ts.is_authenticated())
|
||||||
|
self.assertEquals(1, len(self.tc.get_host_keys()))
|
||||||
|
self.assertEquals(public_host_key, self.tc.get_host_keys()[self.addr]['ssh-rsa'])
|
Loading…
Reference in New Issue