Merge remote-tracking branch 'scottkmaxwell/py3-support-without-py25' into python3

Conflicts:
	dev-requirements.txt
	paramiko/__init__.py
	paramiko/file.py
	paramiko/hostkeys.py
	paramiko/message.py
	paramiko/proxy.py
	paramiko/server.py
	paramiko/transport.py
	paramiko/util.py
	paramiko/win_pageant.py
	setup.py
This commit is contained in:
Jeff Forcier 2014-03-05 17:03:37 -08:00
commit b2be63ec62
64 changed files with 1970 additions and 1593 deletions

View File

@ -2,6 +2,8 @@ language: python
python: python:
- "2.6" - "2.6"
- "2.7" - "2.7"
- "3.2"
- "3.3"
install: install:
# Self-install for setup.py-driven deps # Self-install for setup.py-driven deps
- pip install -e . - pip install -e .

4
README
View File

@ -15,7 +15,7 @@ What
---- ----
"paramiko" is a combination of the esperanto words for "paranoid" and "paramiko" is a combination of the esperanto words for "paranoid" and
"friend". it's a module for python 2.5+ that implements the SSH2 protocol "friend". it's a module for python 2.6+ that implements the SSH2 protocol
for secure (encrypted and authenticated) connections to remote machines. for secure (encrypted and authenticated) connections to remote machines.
unlike SSL (aka TLS), SSH2 protocol does not require hierarchical unlike SSL (aka TLS), SSH2 protocol does not require hierarchical
certificates signed by a powerful central authority. you may know SSH2 as certificates signed by a powerful central authority. you may know SSH2 as
@ -34,7 +34,7 @@ that should have come with this archive.
Requirements Requirements
------------ ------------
- python 2.5 or better <http://www.python.org/> - python 2.6 or better <http://www.python.org/>
- pycrypto 2.1 or better <https://www.dlitz.net/software/pycrypto/> - pycrypto 2.1 or better <https://www.dlitz.net/software/pycrypto/>
- ecdsa 0.9 or better <https://pypi.python.org/pypi/ecdsa> - ecdsa 0.9 or better <https://pypi.python.org/pypi/ecdsa>

View File

@ -28,9 +28,13 @@ import socket
import sys import sys
import time import time
import traceback import traceback
from paramiko.py3compat import input
import paramiko import paramiko
import interactive try:
import interactive
except ImportError:
from . import interactive
def agent_auth(transport, username): def agent_auth(transport, username):
@ -45,24 +49,24 @@ def agent_auth(transport, username):
return return
for key in agent_keys: for key in agent_keys:
print 'Trying ssh-agent key %s' % hexlify(key.get_fingerprint()), print('Trying ssh-agent key %s' % hexlify(key.get_fingerprint()))
try: try:
transport.auth_publickey(username, key) transport.auth_publickey(username, key)
print '... success!' print('... success!')
return return
except paramiko.SSHException: except paramiko.SSHException:
print '... nope.' print('... nope.')
def manual_auth(username, hostname): def manual_auth(username, hostname):
default_auth = 'p' default_auth = 'p'
auth = raw_input('Auth by (p)assword, (r)sa key, or (d)ss key? [%s] ' % default_auth) auth = input('Auth by (p)assword, (r)sa key, or (d)ss key? [%s] ' % default_auth)
if len(auth) == 0: if len(auth) == 0:
auth = default_auth auth = default_auth
if auth == 'r': if auth == 'r':
default_path = os.path.join(os.environ['HOME'], '.ssh', 'id_rsa') default_path = os.path.join(os.environ['HOME'], '.ssh', 'id_rsa')
path = raw_input('RSA key [%s]: ' % default_path) path = input('RSA key [%s]: ' % default_path)
if len(path) == 0: if len(path) == 0:
path = default_path path = default_path
try: try:
@ -73,7 +77,7 @@ def manual_auth(username, hostname):
t.auth_publickey(username, key) t.auth_publickey(username, key)
elif auth == 'd': elif auth == 'd':
default_path = os.path.join(os.environ['HOME'], '.ssh', 'id_dsa') default_path = os.path.join(os.environ['HOME'], '.ssh', 'id_dsa')
path = raw_input('DSS key [%s]: ' % default_path) path = input('DSS key [%s]: ' % default_path)
if len(path) == 0: if len(path) == 0:
path = default_path path = default_path
try: try:
@ -96,9 +100,9 @@ if len(sys.argv) > 1:
if hostname.find('@') >= 0: if hostname.find('@') >= 0:
username, hostname = hostname.split('@') username, hostname = hostname.split('@')
else: else:
hostname = raw_input('Hostname: ') hostname = input('Hostname: ')
if len(hostname) == 0: if len(hostname) == 0:
print '*** Hostname required.' print('*** Hostname required.')
sys.exit(1) sys.exit(1)
port = 22 port = 22
if hostname.find(':') >= 0: if hostname.find(':') >= 0:
@ -109,8 +113,8 @@ if hostname.find(':') >= 0:
try: try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((hostname, port)) sock.connect((hostname, port))
except Exception, e: except Exception as e:
print '*** Connect failed: ' + str(e) print('*** Connect failed: ' + str(e))
traceback.print_exc() traceback.print_exc()
sys.exit(1) sys.exit(1)
@ -119,7 +123,7 @@ try:
try: try:
t.start_client() t.start_client()
except paramiko.SSHException: except paramiko.SSHException:
print '*** SSH negotiation failed.' print('*** SSH negotiation failed.')
sys.exit(1) sys.exit(1)
try: try:
@ -128,25 +132,25 @@ try:
try: try:
keys = paramiko.util.load_host_keys(os.path.expanduser('~/ssh/known_hosts')) keys = paramiko.util.load_host_keys(os.path.expanduser('~/ssh/known_hosts'))
except IOError: except IOError:
print '*** Unable to open host keys file' print('*** Unable to open host keys file')
keys = {} keys = {}
# check server's host key -- this is important. # check server's host key -- this is important.
key = t.get_remote_server_key() key = t.get_remote_server_key()
if not keys.has_key(hostname): if hostname not in keys:
print '*** WARNING: Unknown host key!' print('*** WARNING: Unknown host key!')
elif not keys[hostname].has_key(key.get_name()): elif key.get_name() not in keys[hostname]:
print '*** WARNING: Unknown host key!' print('*** WARNING: Unknown host key!')
elif keys[hostname][key.get_name()] != key: elif keys[hostname][key.get_name()] != key:
print '*** WARNING: Host key has changed!!!' print('*** WARNING: Host key has changed!!!')
sys.exit(1) sys.exit(1)
else: else:
print '*** Host key OK.' print('*** Host key OK.')
# get username # get username
if username == '': if username == '':
default_username = getpass.getuser() default_username = getpass.getuser()
username = raw_input('Username [%s]: ' % default_username) username = input('Username [%s]: ' % default_username)
if len(username) == 0: if len(username) == 0:
username = default_username username = default_username
@ -154,21 +158,20 @@ try:
if not t.is_authenticated(): if not t.is_authenticated():
manual_auth(username, hostname) manual_auth(username, hostname)
if not t.is_authenticated(): if not t.is_authenticated():
print '*** Authentication failed. :(' print('*** Authentication failed. :(')
t.close() t.close()
sys.exit(1) sys.exit(1)
chan = t.open_session() chan = t.open_session()
chan.get_pty() chan.get_pty()
chan.invoke_shell() chan.invoke_shell()
print '*** Here we go!' print('*** Here we go!\n')
print
interactive.interactive_shell(chan) interactive.interactive_shell(chan)
chan.close() chan.close()
t.close() t.close()
except Exception, e: except Exception as e:
print '*** Caught exception: ' + str(e.__class__) + ': ' + str(e) print('*** Caught exception: ' + str(e.__class__) + ': ' + str(e))
traceback.print_exc() traceback.print_exc()
try: try:
t.close() t.close()

View File

@ -17,9 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Paramiko; if not, write to the Free Software Foundation, Inc., # along with Paramiko; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. # 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
from __future__ import with_statement
import string
import sys import sys
from binascii import hexlify from binascii import hexlify
@ -28,6 +26,7 @@ from optparse import OptionParser
from paramiko import DSSKey from paramiko import DSSKey
from paramiko import RSAKey from paramiko import RSAKey
from paramiko.ssh_exception import SSHException from paramiko.ssh_exception import SSHException
from paramiko.py3compat import u
usage=""" usage="""
%prog [-v] [-b bits] -t type [-N new_passphrase] [-f output_keyfile]""" %prog [-v] [-b bits] -t type [-N new_passphrase] [-f output_keyfile]"""
@ -47,16 +46,16 @@ key_dispatch_table = {
def progress(arg=None): def progress(arg=None):
if not arg: if not arg:
print '0%\x08\x08\x08', sys.stdout.write('0%\x08\x08\x08 ')
sys.stdout.flush() sys.stdout.flush()
elif arg[0] == 'p': elif arg[0] == 'p':
print '25%\x08\x08\x08\x08', sys.stdout.write('25%\x08\x08\x08\x08 ')
sys.stdout.flush() sys.stdout.flush()
elif arg[0] == 'h': elif arg[0] == 'h':
print '50%\x08\x08\x08\x08', sys.stdout.write('50%\x08\x08\x08\x08 ')
sys.stdout.flush() sys.stdout.flush()
elif arg[0] == 'x': elif arg[0] == 'x':
print '75%\x08\x08\x08\x08', sys.stdout.write('75%\x08\x08\x08\x08 ')
sys.stdout.flush() sys.stdout.flush()
if __name__ == '__main__': if __name__ == '__main__':
@ -92,8 +91,8 @@ if __name__ == '__main__':
parser.print_help() parser.print_help()
sys.exit(0) sys.exit(0)
for o in default_values.keys(): for o in list(default_values.keys()):
globals()[o] = getattr(options, o, default_values[string.lower(o)]) globals()[o] = getattr(options, o, default_values[o.lower()])
if options.newphrase: if options.newphrase:
phrase = getattr(options, 'newphrase') phrase = getattr(options, 'newphrase')
@ -106,7 +105,7 @@ if __name__ == '__main__':
if ktype == 'dsa' and bits > 1024: if ktype == 'dsa' and bits > 1024:
raise SSHException("DSA Keys must be 1024 bits") raise SSHException("DSA Keys must be 1024 bits")
if not key_dispatch_table.has_key(ktype): if ktype not in key_dispatch_table:
raise SSHException("Unknown %s algorithm to generate keys pair" % ktype) raise SSHException("Unknown %s algorithm to generate keys pair" % ktype)
# generating private key # generating private key
@ -121,7 +120,7 @@ if __name__ == '__main__':
f.write(" %s" % comment) f.write(" %s" % comment)
if options.verbose: if options.verbose:
print "done." print("done.")
hash = hexlify(pub.get_fingerprint()) hash = u(hexlify(pub.get_fingerprint()))
print "Fingerprint: %d %s %s.pub (%s)" % (bits, ":".join([ hash[i:2+i] for i in range(0, len(hash), 2)]), filename, string.upper(ktype)) print("Fingerprint: %d %s %s.pub (%s)" % (bits, ":".join([ hash[i:2+i] for i in range(0, len(hash), 2)]), filename, ktype.upper()))

View File

@ -27,6 +27,7 @@ import threading
import traceback import traceback
import paramiko import paramiko
from paramiko.py3compat import b, u, decodebytes
# setup logging # setup logging
@ -35,17 +36,17 @@ paramiko.util.log_to_file('demo_server.log')
host_key = paramiko.RSAKey(filename='test_rsa.key') host_key = paramiko.RSAKey(filename='test_rsa.key')
#host_key = paramiko.DSSKey(filename='test_dss.key') #host_key = paramiko.DSSKey(filename='test_dss.key')
print 'Read key: ' + hexlify(host_key.get_fingerprint()) print('Read key: ' + u(hexlify(host_key.get_fingerprint())))
class Server (paramiko.ServerInterface): class Server (paramiko.ServerInterface):
# 'data' is the output of base64.encodestring(str(key)) # 'data' is the output of base64.encodestring(str(key))
# (using the "user_rsa_key" files) # (using the "user_rsa_key" files)
data = 'AAAAB3NzaC1yc2EAAAABIwAAAIEAyO4it3fHlmGZWJaGrfeHOVY7RWO3P9M7hp' + \ data = (b'AAAAB3NzaC1yc2EAAAABIwAAAIEAyO4it3fHlmGZWJaGrfeHOVY7RWO3P9M7hp'
'fAu7jJ2d7eothvfeuoRFtJwhUmZDluRdFyhFY/hFAh76PJKGAusIqIQKlkJxMC' + \ b'fAu7jJ2d7eothvfeuoRFtJwhUmZDluRdFyhFY/hFAh76PJKGAusIqIQKlkJxMC'
'KDqIexkgHAfID/6mqvmnSJf0b5W8v5h2pI/stOSwTQ+pxVhwJ9ctYDhRSlF0iT' + \ b'KDqIexkgHAfID/6mqvmnSJf0b5W8v5h2pI/stOSwTQ+pxVhwJ9ctYDhRSlF0iT'
'UWT10hcuO4Ks8=' b'UWT10hcuO4Ks8=')
good_pub_key = paramiko.RSAKey(data=base64.decodestring(data)) good_pub_key = paramiko.RSAKey(data=decodebytes(data))
def __init__(self): def __init__(self):
self.event = threading.Event() self.event = threading.Event()
@ -61,7 +62,7 @@ class Server (paramiko.ServerInterface):
return paramiko.AUTH_FAILED return paramiko.AUTH_FAILED
def check_auth_publickey(self, username, key): def check_auth_publickey(self, username, key):
print 'Auth attempt with key: ' + hexlify(key.get_fingerprint()) print('Auth attempt with key: ' + u(hexlify(key.get_fingerprint())))
if (username == 'robey') and (key == self.good_pub_key): if (username == 'robey') and (key == self.good_pub_key):
return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_SUCCESSFUL
return paramiko.AUTH_FAILED return paramiko.AUTH_FAILED
@ -83,47 +84,47 @@ try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(('', 2200)) sock.bind(('', 2200))
except Exception, e: except Exception as e:
print '*** Bind failed: ' + str(e) print('*** Bind failed: ' + str(e))
traceback.print_exc() traceback.print_exc()
sys.exit(1) sys.exit(1)
try: try:
sock.listen(100) sock.listen(100)
print 'Listening for connection ...' print('Listening for connection ...')
client, addr = sock.accept() client, addr = sock.accept()
except Exception, e: except Exception as e:
print '*** Listen/accept failed: ' + str(e) print('*** Listen/accept failed: ' + str(e))
traceback.print_exc() traceback.print_exc()
sys.exit(1) sys.exit(1)
print 'Got a connection!' print('Got a connection!')
try: try:
t = paramiko.Transport(client) t = paramiko.Transport(client)
try: try:
t.load_server_moduli() t.load_server_moduli()
except: except:
print '(Failed to load moduli -- gex will be unsupported.)' print('(Failed to load moduli -- gex will be unsupported.)')
raise raise
t.add_server_key(host_key) t.add_server_key(host_key)
server = Server() server = Server()
try: try:
t.start_server(server=server) t.start_server(server=server)
except paramiko.SSHException, x: except paramiko.SSHException:
print '*** SSH negotiation failed.' print('*** SSH negotiation failed.')
sys.exit(1) sys.exit(1)
# wait for auth # wait for auth
chan = t.accept(20) chan = t.accept(20)
if chan is None: if chan is None:
print '*** No channel.' print('*** No channel.')
sys.exit(1) sys.exit(1)
print 'Authenticated!' print('Authenticated!')
server.event.wait(10) server.event.wait(10)
if not server.event.isSet(): if not server.event.isSet():
print '*** Client never asked for a shell.' print('*** Client never asked for a shell.')
sys.exit(1) sys.exit(1)
chan.send('\r\n\r\nWelcome to my dorky little BBS!\r\n\r\n') chan.send('\r\n\r\nWelcome to my dorky little BBS!\r\n\r\n')
@ -135,8 +136,8 @@ try:
chan.send('\r\nI don\'t like you, ' + username + '.\r\n') chan.send('\r\nI don\'t like you, ' + username + '.\r\n')
chan.close() chan.close()
except Exception, e: except Exception as e:
print '*** Caught exception: ' + str(e.__class__) + ': ' + str(e) print('*** Caught exception: ' + str(e.__class__) + ': ' + str(e))
traceback.print_exc() traceback.print_exc()
try: try:
t.close() t.close()

View File

@ -28,6 +28,7 @@ import sys
import traceback import traceback
import paramiko import paramiko
from paramiko.py3compat import input
# setup logging # setup logging
@ -40,9 +41,9 @@ if len(sys.argv) > 1:
if hostname.find('@') >= 0: if hostname.find('@') >= 0:
username, hostname = hostname.split('@') username, hostname = hostname.split('@')
else: else:
hostname = raw_input('Hostname: ') hostname = input('Hostname: ')
if len(hostname) == 0: if len(hostname) == 0:
print '*** Hostname required.' print('*** Hostname required.')
sys.exit(1) sys.exit(1)
port = 22 port = 22
if hostname.find(':') >= 0: if hostname.find(':') >= 0:
@ -53,7 +54,7 @@ if hostname.find(':') >= 0:
# get username # get username
if username == '': if username == '':
default_username = getpass.getuser() default_username = getpass.getuser()
username = raw_input('Username [%s]: ' % default_username) username = input('Username [%s]: ' % default_username)
if len(username) == 0: if len(username) == 0:
username = default_username username = default_username
password = getpass.getpass('Password for %s@%s: ' % (username, hostname)) password = getpass.getpass('Password for %s@%s: ' % (username, hostname))
@ -69,13 +70,13 @@ except IOError:
# try ~/ssh/ too, because windows can't have a folder named ~/.ssh/ # try ~/ssh/ too, because windows can't have a folder named ~/.ssh/
host_keys = paramiko.util.load_host_keys(os.path.expanduser('~/ssh/known_hosts')) host_keys = paramiko.util.load_host_keys(os.path.expanduser('~/ssh/known_hosts'))
except IOError: except IOError:
print '*** Unable to open host keys file' print('*** Unable to open host keys file')
host_keys = {} host_keys = {}
if host_keys.has_key(hostname): if hostname in host_keys:
hostkeytype = host_keys[hostname].keys()[0] hostkeytype = host_keys[hostname].keys()[0]
hostkey = host_keys[hostname][hostkeytype] hostkey = host_keys[hostname][hostkeytype]
print 'Using host key of type %s' % hostkeytype print('Using host key of type %s' % hostkeytype)
# now, connect and use paramiko Transport to negotiate SSH2 across the connection # now, connect and use paramiko Transport to negotiate SSH2 across the connection
@ -86,22 +87,26 @@ try:
# dirlist on remote host # dirlist on remote host
dirlist = sftp.listdir('.') dirlist = sftp.listdir('.')
print "Dirlist:", dirlist print("Dirlist: %s" % dirlist)
# copy this demo onto the server # copy this demo onto the server
try: try:
sftp.mkdir("demo_sftp_folder") sftp.mkdir("demo_sftp_folder")
except IOError: except IOError:
print '(assuming demo_sftp_folder/ already exists)' print('(assuming demo_sftp_folder/ already exists)')
sftp.open('demo_sftp_folder/README', 'w').write('This was created by demo_sftp.py.\n') with sftp.open('demo_sftp_folder/README', 'w') as f:
data = open('demo_sftp.py', 'r').read() f.write('This was created by demo_sftp.py.\n')
with open('demo_sftp.py', 'r') as f:
data = f.read()
sftp.open('demo_sftp_folder/demo_sftp.py', 'w').write(data) sftp.open('demo_sftp_folder/demo_sftp.py', 'w').write(data)
print 'created demo_sftp_folder/ on the server' print('created demo_sftp_folder/ on the server')
# copy the README back here # copy the README back here
data = sftp.open('demo_sftp_folder/README', 'r').read() with sftp.open('demo_sftp_folder/README', 'r') as f:
open('README_demo_sftp', 'w').write(data) data = f.read()
print 'copied README back here' with open('README_demo_sftp', 'w') as f:
f.write(data)
print('copied README back here')
# BETTER: use the get() and put() methods # BETTER: use the get() and put() methods
sftp.put('demo_sftp.py', 'demo_sftp_folder/demo_sftp.py') sftp.put('demo_sftp.py', 'demo_sftp_folder/demo_sftp.py')
@ -109,8 +114,8 @@ try:
t.close() t.close()
except Exception, e: except Exception as e:
print '*** Caught exception: %s: %s' % (e.__class__, e) print('*** Caught exception: %s: %s' % (e.__class__, e))
traceback.print_exc() traceback.print_exc()
try: try:
t.close() t.close()

View File

@ -25,9 +25,13 @@ import os
import socket import socket
import sys import sys
import traceback import traceback
from paramiko.py3compat import input
import paramiko import paramiko
import interactive try:
import interactive
except ImportError:
from . import interactive
# setup logging # setup logging
@ -40,9 +44,9 @@ if len(sys.argv) > 1:
if hostname.find('@') >= 0: if hostname.find('@') >= 0:
username, hostname = hostname.split('@') username, hostname = hostname.split('@')
else: else:
hostname = raw_input('Hostname: ') hostname = input('Hostname: ')
if len(hostname) == 0: if len(hostname) == 0:
print '*** Hostname required.' print('*** Hostname required.')
sys.exit(1) sys.exit(1)
port = 22 port = 22
if hostname.find(':') >= 0: if hostname.find(':') >= 0:
@ -53,7 +57,7 @@ if hostname.find(':') >= 0:
# get username # get username
if username == '': if username == '':
default_username = getpass.getuser() default_username = getpass.getuser()
username = raw_input('Username [%s]: ' % default_username) username = input('Username [%s]: ' % default_username)
if len(username) == 0: if len(username) == 0:
username = default_username username = default_username
password = getpass.getpass('Password for %s@%s: ' % (username, hostname)) password = getpass.getpass('Password for %s@%s: ' % (username, hostname))
@ -64,18 +68,17 @@ try:
client = paramiko.SSHClient() client = paramiko.SSHClient()
client.load_system_host_keys() client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.WarningPolicy()) client.set_missing_host_key_policy(paramiko.WarningPolicy())
print '*** Connecting...' print('*** Connecting...')
client.connect(hostname, port, username, password) client.connect(hostname, port, username, password)
chan = client.invoke_shell() chan = client.invoke_shell()
print repr(client.get_transport()) print(repr(client.get_transport()))
print '*** Here we go!' print('*** Here we go!\n')
print
interactive.interactive_shell(chan) interactive.interactive_shell(chan)
chan.close() chan.close()
client.close() client.close()
except Exception, e: except Exception as e:
print '*** Caught exception: %s: %s' % (e.__class__, e) print('*** Caught exception: %s: %s' % (e.__class__, e))
traceback.print_exc() traceback.print_exc()
try: try:
client.close() client.close()

View File

@ -30,7 +30,11 @@ import getpass
import os import os
import socket import socket
import select import select
import SocketServer try:
import SocketServer
except ImportError:
import socketserver as SocketServer
import sys import sys
from optparse import OptionParser from optparse import OptionParser
@ -54,7 +58,7 @@ class Handler (SocketServer.BaseRequestHandler):
chan = self.ssh_transport.open_channel('direct-tcpip', chan = self.ssh_transport.open_channel('direct-tcpip',
(self.chain_host, self.chain_port), (self.chain_host, self.chain_port),
self.request.getpeername()) self.request.getpeername())
except Exception, e: except Exception as e:
verbose('Incoming request to %s:%d failed: %s' % (self.chain_host, verbose('Incoming request to %s:%d failed: %s' % (self.chain_host,
self.chain_port, self.chain_port,
repr(e))) repr(e)))
@ -98,7 +102,7 @@ def forward_tunnel(local_port, remote_host, remote_port, transport):
def verbose(s): def verbose(s):
if g_verbose: if g_verbose:
print s print(s)
HELP = """\ HELP = """\
@ -165,8 +169,8 @@ def main():
try: try:
client.connect(server[0], server[1], username=options.user, key_filename=options.keyfile, client.connect(server[0], server[1], username=options.user, key_filename=options.keyfile,
look_for_keys=options.look_for_keys, password=password) look_for_keys=options.look_for_keys, password=password)
except Exception, e: except Exception as e:
print '*** Failed to connect to %s:%d: %r' % (server[0], server[1], e) print('*** Failed to connect to %s:%d: %r' % (server[0], server[1], e))
sys.exit(1) sys.exit(1)
verbose('Now forwarding port %d to %s:%d ...' % (options.port, remote[0], remote[1])) verbose('Now forwarding port %d to %s:%d ...' % (options.port, remote[0], remote[1]))
@ -174,7 +178,7 @@ def main():
try: try:
forward_tunnel(options.port, remote[0], remote[1], client.get_transport()) forward_tunnel(options.port, remote[0], remote[1], client.get_transport())
except KeyboardInterrupt: except KeyboardInterrupt:
print 'C-c: Port forwarding stopped.' print('C-c: Port forwarding stopped.')
sys.exit(0) sys.exit(0)

View File

@ -19,6 +19,7 @@
import socket import socket
import sys import sys
from paramiko.py3compat import u
# windows does not have termios... # windows does not have termios...
try: try:
@ -49,9 +50,9 @@ def posix_shell(chan):
r, w, e = select.select([chan, sys.stdin], [], []) r, w, e = select.select([chan, sys.stdin], [], [])
if chan in r: if chan in r:
try: try:
x = chan.recv(1024) x = u(chan.recv(1024))
if len(x) == 0: if len(x) == 0:
print '\r\n*** EOF\r\n', sys.stdout.write('\r\n*** EOF\r\n')
break break
sys.stdout.write(x) sys.stdout.write(x)
sys.stdout.flush() sys.stdout.flush()

View File

@ -46,7 +46,7 @@ def handler(chan, host, port):
sock = socket.socket() sock = socket.socket()
try: try:
sock.connect((host, port)) sock.connect((host, port))
except Exception, e: except Exception as e:
verbose('Forwarding request to %s:%d failed: %r' % (host, port, e)) verbose('Forwarding request to %s:%d failed: %r' % (host, port, e))
return return
@ -82,7 +82,7 @@ def reverse_forward_tunnel(server_port, remote_host, remote_port, transport):
def verbose(s): def verbose(s):
if g_verbose: if g_verbose:
print s print(s)
HELP = """\ HELP = """\
@ -150,8 +150,8 @@ def main():
try: try:
client.connect(server[0], server[1], username=options.user, key_filename=options.keyfile, client.connect(server[0], server[1], username=options.user, key_filename=options.keyfile,
look_for_keys=options.look_for_keys, password=password) look_for_keys=options.look_for_keys, password=password)
except Exception, e: except Exception as e:
print '*** Failed to connect to %s:%d: %r' % (server[0], server[1], e) print('*** Failed to connect to %s:%d: %r' % (server[0], server[1], e))
sys.exit(1) sys.exit(1)
verbose('Now forwarding remote port %d to %s:%d ...' % (options.port, remote[0], remote[1])) verbose('Now forwarding remote port %d to %s:%d ...' % (options.port, remote[0], remote[1]))
@ -159,7 +159,7 @@ def main():
try: try:
reverse_forward_tunnel(options.port, remote[0], remote[1], client.get_transport()) reverse_forward_tunnel(options.port, remote[0], remote[1], client.get_transport())
except KeyboardInterrupt: except KeyboardInterrupt:
print 'C-c: Port forwarding stopped.' print('C-c: Port forwarding stopped.')
sys.exit(0) sys.exit(0)

View File

@ -18,51 +18,51 @@
import sys import sys
if sys.version_info < (2, 5): if sys.version_info < (2, 6):
raise RuntimeError('You need Python 2.5+ for this module.') raise RuntimeError('You need Python 2.6+ for this module.')
__author__ = "Jeff Forcier <jeff@bitprophet.org>" __author__ = "Jeff Forcier <jeff@bitprophet.org>"
__version__ = "1.12.2" __version__ = "1.13.0"
__version_info__ = tuple([ int(d) for d in __version__.split(".") ]) __version_info__ = tuple([ int(d) for d in __version__.split(".") ])
__license__ = "GNU Lesser General Public License (LGPL)" __license__ = "GNU Lesser General Public License (LGPL)"
from transport import SecurityOptions, Transport from paramiko.transport import SecurityOptions, Transport
from client import SSHClient, MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy, WarningPolicy from paramiko.client import SSHClient, MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy, WarningPolicy
from auth_handler import AuthHandler from paramiko.auth_handler import AuthHandler
from channel import Channel, ChannelFile from paramiko.channel import Channel, ChannelFile
from ssh_exception import SSHException, PasswordRequiredException, \ from paramiko.ssh_exception import SSHException, PasswordRequiredException, \
BadAuthenticationType, ChannelException, BadHostKeyException, \ BadAuthenticationType, ChannelException, BadHostKeyException, \
AuthenticationException, ProxyCommandFailure AuthenticationException, ProxyCommandFailure
from server import ServerInterface, SubsystemHandler, InteractiveQuery from paramiko.server import ServerInterface, SubsystemHandler, InteractiveQuery
from rsakey import RSAKey from paramiko.rsakey import RSAKey
from dsskey import DSSKey from paramiko.dsskey import DSSKey
from ecdsakey import ECDSAKey from paramiko.ecdsakey import ECDSAKey
from sftp import SFTPError, BaseSFTP from paramiko.sftp import SFTPError, BaseSFTP
from sftp_client import SFTP, SFTPClient from paramiko.sftp_client import SFTP, SFTPClient
from sftp_server import SFTPServer from paramiko.sftp_server import SFTPServer
from sftp_attr import SFTPAttributes from paramiko.sftp_attr import SFTPAttributes
from sftp_handle import SFTPHandle from paramiko.sftp_handle import SFTPHandle
from sftp_si import SFTPServerInterface from paramiko.sftp_si import SFTPServerInterface
from sftp_file import SFTPFile from paramiko.sftp_file import SFTPFile
from message import Message from paramiko.message import Message
from packet import Packetizer from paramiko.packet import Packetizer
from file import BufferedFile from paramiko.file import BufferedFile
from agent import Agent, AgentKey from paramiko.agent import Agent, AgentKey
from pkey import PKey from paramiko.pkey import PKey
from hostkeys import HostKeys from paramiko.hostkeys import HostKeys
from config import SSHConfig from paramiko.config import SSHConfig
from proxy import ProxyCommand from paramiko.proxy import ProxyCommand
from common import AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED, \ from paramiko.common import AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED, \
OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, OPEN_FAILED_CONNECT_FAILED, \ OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, OPEN_FAILED_CONNECT_FAILED, \
OPEN_FAILED_UNKNOWN_CHANNEL_TYPE, OPEN_FAILED_RESOURCE_SHORTAGE OPEN_FAILED_UNKNOWN_CHANNEL_TYPE, OPEN_FAILED_RESOURCE_SHORTAGE
from sftp import SFTP_OK, SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, SFTP_FAILURE, \ from paramiko.sftp import SFTP_OK, SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, SFTP_FAILURE, \
SFTP_BAD_MESSAGE, SFTP_NO_CONNECTION, SFTP_CONNECTION_LOST, SFTP_OP_UNSUPPORTED SFTP_BAD_MESSAGE, SFTP_NO_CONNECTION, SFTP_CONNECTION_LOST, SFTP_OP_UNSUPPORTED
from common import io_sleep from paramiko.common import io_sleep
__all__ = [ 'Transport', __all__ = [ 'Transport',
'SSHClient', 'SSHClient',

View File

@ -8,7 +8,11 @@ in jaraco.windows and asking the author to port the fixes back here.
import ctypes import ctypes
import ctypes.wintypes import ctypes.wintypes
import __builtin__ from paramiko.py3compat import u
try:
import builtins
except ImportError:
import __builtin__ as builtins
try: try:
USHORT = ctypes.wintypes.USHORT USHORT = ctypes.wintypes.USHORT
@ -40,7 +44,7 @@ def format_system_message(errno):
result_buffer = ctypes.wintypes.LPWSTR() result_buffer = ctypes.wintypes.LPWSTR()
buffer_size = 0 buffer_size = 0
arguments = None arguments = None
bytes = ctypes.windll.kernel32.FormatMessageW( format_bytes = ctypes.windll.kernel32.FormatMessageW(
flags, flags,
source, source,
message_id, message_id,
@ -52,13 +56,13 @@ def format_system_message(errno):
# note the following will cause an infinite loop if GetLastError # note the following will cause an infinite loop if GetLastError
# repeatedly returns an error that cannot be formatted, although # repeatedly returns an error that cannot be formatted, although
# this should not happen. # this should not happen.
handle_nonzero_success(bytes) handle_nonzero_success(format_bytes)
message = result_buffer.value message = result_buffer.value
ctypes.windll.kernel32.LocalFree(result_buffer) ctypes.windll.kernel32.LocalFree(result_buffer)
return message return message
class WindowsError(__builtin__.WindowsError): class WindowsError(builtins.WindowsError):
"more info about errors at http://msdn.microsoft.com/en-us/library/ms681381(VS.85).aspx" "more info about errors at http://msdn.microsoft.com/en-us/library/ms681381(VS.85).aspx"
def __init__(self, value=None): def __init__(self, value=None):
@ -120,7 +124,7 @@ class MemoryMap(object):
FILE_MAP_WRITE = 0x2 FILE_MAP_WRITE = 0x2
filemap = ctypes.windll.kernel32.CreateFileMappingW( filemap = ctypes.windll.kernel32.CreateFileMappingW(
INVALID_HANDLE_VALUE, p_SA, PAGE_READWRITE, 0, self.length, INVALID_HANDLE_VALUE, p_SA, PAGE_READWRITE, 0, self.length,
unicode(self.name)) u(self.name))
handle_nonzero_success(filemap) handle_nonzero_success(filemap)
if filemap == INVALID_HANDLE_VALUE: if filemap == INVALID_HANDLE_VALUE:
raise Exception("Failed to create file mapping") raise Exception("Failed to create file mapping")

View File

@ -34,11 +34,14 @@ from paramiko.ssh_exception import SSHException
from paramiko.message import Message from paramiko.message import Message
from paramiko.pkey import PKey from paramiko.pkey import PKey
from paramiko.channel import Channel from paramiko.channel import Channel
from paramiko.common import io_sleep from paramiko.common import *
from paramiko.util import retry_on_signal from paramiko.util import retry_on_signal
SSH2_AGENTC_REQUEST_IDENTITIES, SSH2_AGENT_IDENTITIES_ANSWER, \ cSSH2_AGENTC_REQUEST_IDENTITIES = byte_chr(11)
SSH2_AGENTC_SIGN_REQUEST, SSH2_AGENT_SIGN_RESPONSE = range(11, 15) SSH2_AGENT_IDENTITIES_ANSWER = 12
cSSH2_AGENTC_SIGN_REQUEST = byte_chr(13)
SSH2_AGENT_SIGN_RESPONSE = 14
class AgentSSH(object): class AgentSSH(object):
@ -60,12 +63,12 @@ class AgentSSH(object):
def _connect(self, conn): def _connect(self, conn):
self._conn = conn self._conn = conn
ptype, result = self._send_message(chr(SSH2_AGENTC_REQUEST_IDENTITIES)) ptype, result = self._send_message(cSSH2_AGENTC_REQUEST_IDENTITIES)
if ptype != SSH2_AGENT_IDENTITIES_ANSWER: if ptype != SSH2_AGENT_IDENTITIES_ANSWER:
raise SSHException('could not get keys from ssh-agent') raise SSHException('could not get keys from ssh-agent')
keys = [] keys = []
for i in range(result.get_int()): for i in range(result.get_int()):
keys.append(AgentKey(self, result.get_string())) keys.append(AgentKey(self, result.get_binary()))
result.get_string() result.get_string()
self._keys = tuple(keys) self._keys = tuple(keys)
@ -75,7 +78,7 @@ class AgentSSH(object):
self._keys = () self._keys = ()
def _send_message(self, msg): def _send_message(self, msg):
msg = str(msg) msg = asbytes(msg)
self._conn.send(struct.pack('>I', len(msg)) + msg) self._conn.send(struct.pack('>I', len(msg)) + msg)
l = self._read_all(4) l = self._read_all(4)
msg = Message(self._read_all(struct.unpack('>I', l)[0])) msg = Message(self._read_all(struct.unpack('>I', l)[0]))
@ -212,7 +215,7 @@ class AgentClientProxy(object):
# probably a dangling env var: the ssh agent is gone # probably a dangling env var: the ssh agent is gone
return return
elif sys.platform == 'win32': elif sys.platform == 'win32':
import win_pageant import paramiko.win_pageant as win_pageant
if win_pageant.can_talk_to_agent(): if win_pageant.can_talk_to_agent():
conn = win_pageant.PageantConnection() conn = win_pageant.PageantConnection()
else: else:
@ -328,7 +331,7 @@ class Agent(AgentSSH):
# probably a dangling env var: the ssh agent is gone # probably a dangling env var: the ssh agent is gone
return return
elif sys.platform == 'win32': elif sys.platform == 'win32':
import win_pageant from . import win_pageant
if win_pageant.can_talk_to_agent(): if win_pageant.can_talk_to_agent():
conn = win_pageant.PageantConnection() conn = win_pageant.PageantConnection()
else: else:
@ -354,21 +357,24 @@ class AgentKey(PKey):
def __init__(self, agent, blob): def __init__(self, agent, blob):
self.agent = agent self.agent = agent
self.blob = blob self.blob = blob
self.name = Message(blob).get_string() self.name = Message(blob).get_text()
def asbytes(self):
return self.blob
def __str__(self): def __str__(self):
return self.blob return self.asbytes()
def get_name(self): def get_name(self):
return self.name return self.name
def sign_ssh_data(self, rng, data): def sign_ssh_data(self, rng, data):
msg = Message() msg = Message()
msg.add_byte(chr(SSH2_AGENTC_SIGN_REQUEST)) msg.add_byte(cSSH2_AGENTC_SIGN_REQUEST)
msg.add_string(self.blob) msg.add_string(self.blob)
msg.add_string(data) msg.add_string(data)
msg.add_int(0) msg.add_int(0)
ptype, result = self.agent._send_message(msg) ptype, result = self.agent._send_message(msg)
if ptype != SSH2_AGENT_SIGN_RESPONSE: if ptype != SSH2_AGENT_SIGN_RESPONSE:
raise SSHException('key cannot be used for signing') raise SSHException('key cannot be used for signing')
return result.get_string() return result.get_binary()

View File

@ -120,13 +120,13 @@ class AuthHandler (object):
def _request_auth(self): def _request_auth(self):
m = Message() m = Message()
m.add_byte(chr(MSG_SERVICE_REQUEST)) m.add_byte(cMSG_SERVICE_REQUEST)
m.add_string('ssh-userauth') m.add_string('ssh-userauth')
self.transport._send_message(m) self.transport._send_message(m)
def _disconnect_service_not_available(self): def _disconnect_service_not_available(self):
m = Message() m = Message()
m.add_byte(chr(MSG_DISCONNECT)) m.add_byte(cMSG_DISCONNECT)
m.add_int(DISCONNECT_SERVICE_NOT_AVAILABLE) m.add_int(DISCONNECT_SERVICE_NOT_AVAILABLE)
m.add_string('Service not available') m.add_string('Service not available')
m.add_string('en') m.add_string('en')
@ -135,7 +135,7 @@ class AuthHandler (object):
def _disconnect_no_more_auth(self): def _disconnect_no_more_auth(self):
m = Message() m = Message()
m.add_byte(chr(MSG_DISCONNECT)) m.add_byte(cMSG_DISCONNECT)
m.add_int(DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE) m.add_int(DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE)
m.add_string('No more auth methods available') m.add_string('No more auth methods available')
m.add_string('en') m.add_string('en')
@ -145,14 +145,14 @@ class AuthHandler (object):
def _get_session_blob(self, key, service, username): def _get_session_blob(self, key, service, username):
m = Message() m = Message()
m.add_string(self.transport.session_id) m.add_string(self.transport.session_id)
m.add_byte(chr(MSG_USERAUTH_REQUEST)) m.add_byte(cMSG_USERAUTH_REQUEST)
m.add_string(username) m.add_string(username)
m.add_string(service) m.add_string(service)
m.add_string('publickey') m.add_string('publickey')
m.add_boolean(1) m.add_boolean(1)
m.add_string(key.get_name()) m.add_string(key.get_name())
m.add_string(str(key)) m.add_string(key)
return str(m) return m.asbytes()
def wait_for_response(self, event): def wait_for_response(self, event):
while True: while True:
@ -176,11 +176,11 @@ class AuthHandler (object):
return [] return []
def _parse_service_request(self, m): def _parse_service_request(self, m):
service = m.get_string() service = m.get_text()
if self.transport.server_mode and (service == 'ssh-userauth'): if self.transport.server_mode and (service == 'ssh-userauth'):
# accepted # accepted
m = Message() m = Message()
m.add_byte(chr(MSG_SERVICE_ACCEPT)) m.add_byte(cMSG_SERVICE_ACCEPT)
m.add_string(service) m.add_string(service)
self.transport._send_message(m) self.transport._send_message(m)
return return
@ -188,27 +188,25 @@ class AuthHandler (object):
self._disconnect_service_not_available() self._disconnect_service_not_available()
def _parse_service_accept(self, m): def _parse_service_accept(self, m):
service = m.get_string() service = m.get_text()
if service == 'ssh-userauth': if service == 'ssh-userauth':
self.transport._log(DEBUG, 'userauth is OK') self.transport._log(DEBUG, 'userauth is OK')
m = Message() m = Message()
m.add_byte(chr(MSG_USERAUTH_REQUEST)) m.add_byte(cMSG_USERAUTH_REQUEST)
m.add_string(self.username) m.add_string(self.username)
m.add_string('ssh-connection') m.add_string('ssh-connection')
m.add_string(self.auth_method) m.add_string(self.auth_method)
if self.auth_method == 'password': if self.auth_method == 'password':
m.add_boolean(False) m.add_boolean(False)
password = self.password password = bytestring(self.password)
if isinstance(password, unicode):
password = password.encode('UTF-8')
m.add_string(password) m.add_string(password)
elif self.auth_method == 'publickey': elif self.auth_method == 'publickey':
m.add_boolean(True) m.add_boolean(True)
m.add_string(self.private_key.get_name()) m.add_string(self.private_key.get_name())
m.add_string(str(self.private_key)) m.add_string(self.private_key)
blob = self._get_session_blob(self.private_key, 'ssh-connection', self.username) blob = self._get_session_blob(self.private_key, 'ssh-connection', self.username)
sig = self.private_key.sign_ssh_data(self.transport.rng, blob) sig = self.private_key.sign_ssh_data(self.transport.rng, blob)
m.add_string(str(sig)) m.add_string(sig)
elif self.auth_method == 'keyboard-interactive': elif self.auth_method == 'keyboard-interactive':
m.add_string('') m.add_string('')
m.add_string(self.submethods) m.add_string(self.submethods)
@ -225,11 +223,11 @@ class AuthHandler (object):
m = Message() m = Message()
if result == AUTH_SUCCESSFUL: if result == AUTH_SUCCESSFUL:
self.transport._log(INFO, 'Auth granted (%s).' % method) self.transport._log(INFO, 'Auth granted (%s).' % method)
m.add_byte(chr(MSG_USERAUTH_SUCCESS)) m.add_byte(cMSG_USERAUTH_SUCCESS)
self.authenticated = True self.authenticated = True
else: else:
self.transport._log(INFO, 'Auth rejected (%s).' % method) self.transport._log(INFO, 'Auth rejected (%s).' % method)
m.add_byte(chr(MSG_USERAUTH_FAILURE)) m.add_byte(cMSG_USERAUTH_FAILURE)
m.add_string(self.transport.server_object.get_allowed_auths(username)) m.add_string(self.transport.server_object.get_allowed_auths(username))
if result == AUTH_PARTIALLY_SUCCESSFUL: if result == AUTH_PARTIALLY_SUCCESSFUL:
m.add_boolean(1) m.add_boolean(1)
@ -245,10 +243,10 @@ class AuthHandler (object):
def _interactive_query(self, q): def _interactive_query(self, q):
# make interactive query instead of response # make interactive query instead of response
m = Message() m = Message()
m.add_byte(chr(MSG_USERAUTH_INFO_REQUEST)) m.add_byte(cMSG_USERAUTH_INFO_REQUEST)
m.add_string(q.name) m.add_string(q.name)
m.add_string(q.instructions) m.add_string(q.instructions)
m.add_string('') m.add_string(bytes())
m.add_int(len(q.prompts)) m.add_int(len(q.prompts))
for p in q.prompts: for p in q.prompts:
m.add_string(p[0]) m.add_string(p[0])
@ -259,7 +257,7 @@ class AuthHandler (object):
if not self.transport.server_mode: if not self.transport.server_mode:
# er, uh... what? # er, uh... what?
m = Message() m = Message()
m.add_byte(chr(MSG_USERAUTH_FAILURE)) m.add_byte(cMSG_USERAUTH_FAILURE)
m.add_string('none') m.add_string('none')
m.add_boolean(0) m.add_boolean(0)
self.transport._send_message(m) self.transport._send_message(m)
@ -267,9 +265,9 @@ class AuthHandler (object):
if self.authenticated: if self.authenticated:
# ignore # ignore
return return
username = m.get_string() username = m.get_text()
service = m.get_string() service = m.get_text()
method = m.get_string() method = m.get_text()
self.transport._log(DEBUG, 'Auth request (type=%s) service=%s, username=%s' % (method, service, username)) self.transport._log(DEBUG, 'Auth request (type=%s) service=%s, username=%s' % (method, service, username))
if service != 'ssh-connection': if service != 'ssh-connection':
self._disconnect_service_not_available() self._disconnect_service_not_available()
@ -284,7 +282,7 @@ class AuthHandler (object):
result = self.transport.server_object.check_auth_none(username) result = self.transport.server_object.check_auth_none(username)
elif method == 'password': elif method == 'password':
changereq = m.get_boolean() changereq = m.get_boolean()
password = m.get_string() password = m.get_binary()
try: try:
password = password.decode('UTF-8') password = password.decode('UTF-8')
except UnicodeError: except UnicodeError:
@ -295,7 +293,7 @@ class AuthHandler (object):
# always treated as failure, since we don't support changing passwords, but collect # always treated as failure, since we don't support changing passwords, but collect
# the list of valid auth types from the callback anyway # the list of valid auth types from the callback anyway
self.transport._log(DEBUG, 'Auth request to change passwords (rejected)') self.transport._log(DEBUG, 'Auth request to change passwords (rejected)')
newpassword = m.get_string() newpassword = m.get_binary()
try: try:
newpassword = newpassword.decode('UTF-8', 'replace') newpassword = newpassword.decode('UTF-8', 'replace')
except UnicodeError: except UnicodeError:
@ -305,11 +303,11 @@ class AuthHandler (object):
result = self.transport.server_object.check_auth_password(username, password) result = self.transport.server_object.check_auth_password(username, password)
elif method == 'publickey': elif method == 'publickey':
sig_attached = m.get_boolean() sig_attached = m.get_boolean()
keytype = m.get_string() keytype = m.get_text()
keyblob = m.get_string() keyblob = m.get_binary()
try: try:
key = self.transport._key_info[keytype](Message(keyblob)) key = self.transport._key_info[keytype](Message(keyblob))
except SSHException, e: except SSHException as e:
self.transport._log(INFO, 'Auth rejected: public key: %s' % str(e)) self.transport._log(INFO, 'Auth rejected: public key: %s' % str(e))
key = None key = None
except: except:
@ -326,12 +324,12 @@ class AuthHandler (object):
# client wants to know if this key is acceptable, before it # client wants to know if this key is acceptable, before it
# signs anything... send special "ok" message # signs anything... send special "ok" message
m = Message() m = Message()
m.add_byte(chr(MSG_USERAUTH_PK_OK)) m.add_byte(cMSG_USERAUTH_PK_OK)
m.add_string(keytype) m.add_string(keytype)
m.add_string(keyblob) m.add_string(keyblob)
self.transport._send_message(m) self.transport._send_message(m)
return return
sig = Message(m.get_string()) sig = Message(m.get_binary())
blob = self._get_session_blob(key, service, username) blob = self._get_session_blob(key, service, username)
if not key.verify_ssh_sig(blob, sig): if not key.verify_ssh_sig(blob, sig):
self.transport._log(INFO, 'Auth rejected: invalid signature') self.transport._log(INFO, 'Auth rejected: invalid signature')
@ -378,23 +376,23 @@ class AuthHandler (object):
banner = m.get_string() banner = m.get_string()
self.banner = banner self.banner = banner
lang = m.get_string() lang = m.get_string()
self.transport._log(INFO, 'Auth banner: ' + banner) self.transport._log(INFO, 'Auth banner: %s' % banner)
# who cares. # who cares.
def _parse_userauth_info_request(self, m): def _parse_userauth_info_request(self, m):
if self.auth_method != 'keyboard-interactive': if self.auth_method != 'keyboard-interactive':
raise SSHException('Illegal info request from server') raise SSHException('Illegal info request from server')
title = m.get_string() title = m.get_text()
instructions = m.get_string() instructions = m.get_text()
m.get_string() # lang m.get_binary() # lang
prompts = m.get_int() prompts = m.get_int()
prompt_list = [] prompt_list = []
for i in range(prompts): for i in range(prompts):
prompt_list.append((m.get_string(), m.get_boolean())) prompt_list.append((m.get_text(), m.get_boolean()))
response_list = self.interactive_handler(title, instructions, prompt_list) response_list = self.interactive_handler(title, instructions, prompt_list)
m = Message() m = Message()
m.add_byte(chr(MSG_USERAUTH_INFO_RESPONSE)) m.add_byte(cMSG_USERAUTH_INFO_RESPONSE)
m.add_int(len(response_list)) m.add_int(len(response_list))
for r in response_list: for r in response_list:
m.add_string(r) m.add_string(r)
@ -406,7 +404,7 @@ class AuthHandler (object):
n = m.get_int() n = m.get_int()
responses = [] responses = []
for i in range(n): for i in range(n):
responses.append(m.get_string()) responses.append(m.get_text())
result = self.transport.server_object.check_auth_interactive_response(responses) result = self.transport.server_object.check_auth_interactive_response(responses)
if isinstance(type(result), InteractiveQuery): if isinstance(type(result), InteractiveQuery):
# make interactive query instead of response # make interactive query instead of response

View File

@ -17,7 +17,8 @@
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. # 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
import util import paramiko.util as util
from paramiko.common import *
class BERException (Exception): class BERException (Exception):
@ -29,13 +30,16 @@ class BER(object):
Robey's tiny little attempt at a BER decoder. Robey's tiny little attempt at a BER decoder.
""" """
def __init__(self, content=''): def __init__(self, content=bytes()):
self.content = content self.content = b(content)
self.idx = 0 self.idx = 0
def __str__(self): def asbytes(self):
return self.content return self.content
def __str__(self):
return self.asbytes()
def __repr__(self): def __repr__(self):
return 'BER(\'' + repr(self.content) + '\')' return 'BER(\'' + repr(self.content) + '\')'
@ -45,13 +49,13 @@ class BER(object):
def decode_next(self): def decode_next(self):
if self.idx >= len(self.content): if self.idx >= len(self.content):
return None return None
ident = ord(self.content[self.idx]) ident = byte_ord(self.content[self.idx])
self.idx += 1 self.idx += 1
if (ident & 31) == 31: if (ident & 31) == 31:
# identifier > 30 # identifier > 30
ident = 0 ident = 0
while self.idx < len(self.content): while self.idx < len(self.content):
t = ord(self.content[self.idx]) t = byte_ord(self.content[self.idx])
self.idx += 1 self.idx += 1
ident = (ident << 7) | (t & 0x7f) ident = (ident << 7) | (t & 0x7f)
if not (t & 0x80): if not (t & 0x80):
@ -59,7 +63,7 @@ class BER(object):
if self.idx >= len(self.content): if self.idx >= len(self.content):
return None return None
# now fetch length # now fetch length
size = ord(self.content[self.idx]) size = byte_ord(self.content[self.idx])
self.idx += 1 self.idx += 1
if size & 0x80: if size & 0x80:
# more complimicated... # more complimicated...
@ -98,20 +102,20 @@ class BER(object):
def encode_tlv(self, ident, val): def encode_tlv(self, ident, val):
# no need to support ident > 31 here # no need to support ident > 31 here
self.content += chr(ident) self.content += byte_chr(ident)
if len(val) > 0x7f: if len(val) > 0x7f:
lenstr = util.deflate_long(len(val)) lenstr = util.deflate_long(len(val))
self.content += chr(0x80 + len(lenstr)) + lenstr self.content += byte_chr(0x80 + len(lenstr)) + lenstr
else: else:
self.content += chr(len(val)) self.content += byte_chr(len(val))
self.content += val self.content += val
def encode(self, x): def encode(self, x):
if type(x) is bool: if type(x) is bool:
if x: if x:
self.encode_tlv(1, '\xff') self.encode_tlv(1, max_byte)
else: else:
self.encode_tlv(1, '\x00') self.encode_tlv(1, zero_byte)
elif (type(x) is int) or (type(x) is long): elif (type(x) is int) or (type(x) is long):
self.encode_tlv(2, util.deflate_long(x)) self.encode_tlv(2, util.deflate_long(x))
elif type(x) is str: elif type(x) is str:
@ -125,5 +129,5 @@ class BER(object):
b = BER() b = BER()
for item in data: for item in data:
b.encode(item) b.encode(item)
return str(b) return b.asbytes()
encode_sequence = staticmethod(encode_sequence) encode_sequence = staticmethod(encode_sequence)

View File

@ -25,6 +25,7 @@ read operations are blocking and can have a timeout set.
import array import array
import threading import threading
import time import time
from paramiko.common import *
class PipeTimeout (IOError): class PipeTimeout (IOError):
@ -48,6 +49,20 @@ class BufferedPipe (object):
self._buffer = array.array('B') self._buffer = array.array('B')
self._closed = False self._closed = False
if PY2:
def _buffer_frombytes(self, data):
self._buffer.fromstring(data)
def _buffer_tobytes(self, limit=None):
return self._buffer[:limit].tostring()
else:
def _buffer_frombytes(self, data):
self._buffer.frombytes(data)
def _buffer_tobytes(self, limit=None):
return self._buffer[:limit].tobytes()
def set_event(self, event): def set_event(self, event):
""" """
Set an event on this buffer. When data is ready to be read (or the Set an event on this buffer. When data is ready to be read (or the
@ -73,7 +88,7 @@ class BufferedPipe (object):
try: try:
if self._event is not None: if self._event is not None:
self._event.set() self._event.set()
self._buffer.fromstring(data) self._buffer_frombytes(b(data))
self._cv.notifyAll() self._cv.notifyAll()
finally: finally:
self._lock.release() self._lock.release()
@ -117,7 +132,7 @@ class BufferedPipe (object):
if a timeout was specified and no data was ready before that if a timeout was specified and no data was ready before that
timeout timeout
""" """
out = '' out = bytes()
self._lock.acquire() self._lock.acquire()
try: try:
if len(self._buffer) == 0: if len(self._buffer) == 0:
@ -138,12 +153,12 @@ class BufferedPipe (object):
# something's in the buffer and we have the lock! # something's in the buffer and we have the lock!
if len(self._buffer) <= nbytes: if len(self._buffer) <= nbytes:
out = self._buffer.tostring() out = self._buffer_tobytes()
del self._buffer[:] del self._buffer[:]
if (self._event is not None) and not self._closed: if (self._event is not None) and not self._closed:
self._event.clear() self._event.clear()
else: else:
out = self._buffer[:nbytes].tostring() out = self._buffer_tobytes(nbytes)
del self._buffer[:nbytes] del self._buffer[:nbytes]
finally: finally:
self._lock.release() self._lock.release()
@ -160,7 +175,7 @@ class BufferedPipe (object):
""" """
self._lock.acquire() self._lock.acquire()
try: try:
out = self._buffer.tostring() out = self._buffer_tobytes()
del self._buffer[:] del self._buffer[:]
if (self._event is not None) and not self._closed: if (self._event is not None) and not self._closed:
self._event.clear() self._event.clear()

View File

@ -140,7 +140,7 @@ class Channel (object):
if self.closed or self.eof_received or self.eof_sent or not self.active: if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open') raise SSHException('Channel is not open')
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_string('pty-req') m.add_string('pty-req')
m.add_boolean(True) m.add_boolean(True)
@ -149,7 +149,7 @@ class Channel (object):
m.add_int(height) m.add_int(height)
m.add_int(width_pixels) m.add_int(width_pixels)
m.add_int(height_pixels) m.add_int(height_pixels)
m.add_string('') m.add_string(bytes())
self._event_pending() self._event_pending()
self.transport._send_user_message(m) self.transport._send_user_message(m)
self._wait_for_event() self._wait_for_event()
@ -173,7 +173,7 @@ class Channel (object):
if self.closed or self.eof_received or self.eof_sent or not self.active: if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open') raise SSHException('Channel is not open')
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_string('shell') m.add_string('shell')
m.add_boolean(1) m.add_boolean(1)
@ -199,7 +199,7 @@ class Channel (object):
if self.closed or self.eof_received or self.eof_sent or not self.active: if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open') raise SSHException('Channel is not open')
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_string('exec') m.add_string('exec')
m.add_boolean(True) m.add_boolean(True)
@ -225,7 +225,7 @@ class Channel (object):
if self.closed or self.eof_received or self.eof_sent or not self.active: if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open') raise SSHException('Channel is not open')
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_string('subsystem') m.add_string('subsystem')
m.add_boolean(True) m.add_boolean(True)
@ -250,7 +250,7 @@ class Channel (object):
if self.closed or self.eof_received or self.eof_sent or not self.active: if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open') raise SSHException('Channel is not open')
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_string('window-change') m.add_string('window-change')
m.add_boolean(False) m.add_boolean(False)
@ -304,7 +304,7 @@ class Channel (object):
# in many cases, the channel will not still be open here. # in many cases, the channel will not still be open here.
# that's fine. # that's fine.
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_string('exit-status') m.add_string('exit-status')
m.add_boolean(False) m.add_boolean(False)
@ -359,7 +359,7 @@ class Channel (object):
auth_cookie = binascii.hexlify(self.transport.rng.read(16)) auth_cookie = binascii.hexlify(self.transport.rng.read(16))
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_string('x11-req') m.add_string('x11-req')
m.add_boolean(True) m.add_boolean(True)
@ -389,7 +389,7 @@ class Channel (object):
raise SSHException('Channel is not open') raise SSHException('Channel is not open')
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_string('auth-agent-req@openssh.com') m.add_string('auth-agent-req@openssh.com')
m.add_boolean(False) m.add_boolean(False)
@ -451,7 +451,7 @@ class Channel (object):
.. versionadded:: 1.1 .. versionadded:: 1.1
""" """
data = '' data = bytes()
self.lock.acquire() self.lock.acquire()
try: try:
old = self.combine_stderr old = self.combine_stderr
@ -581,14 +581,14 @@ class Channel (object):
""" """
try: try:
out = self.in_buffer.read(nbytes, self.timeout) out = self.in_buffer.read(nbytes, self.timeout)
except PipeTimeout, e: except PipeTimeout:
raise socket.timeout() raise socket.timeout()
ack = self._check_add_window(len(out)) ack = self._check_add_window(len(out))
# no need to hold the channel lock when sending this # no need to hold the channel lock when sending this
if ack > 0: if ack > 0:
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST)) m.add_byte(cMSG_CHANNEL_WINDOW_ADJUST)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_int(ack) m.add_int(ack)
self.transport._send_user_message(m) self.transport._send_user_message(m)
@ -629,14 +629,14 @@ class Channel (object):
""" """
try: try:
out = self.in_stderr_buffer.read(nbytes, self.timeout) out = self.in_stderr_buffer.read(nbytes, self.timeout)
except PipeTimeout, e: except PipeTimeout:
raise socket.timeout() raise socket.timeout()
ack = self._check_add_window(len(out)) ack = self._check_add_window(len(out))
# no need to hold the channel lock when sending this # no need to hold the channel lock when sending this
if ack > 0: if ack > 0:
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST)) m.add_byte(cMSG_CHANNEL_WINDOW_ADJUST)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_int(ack) m.add_int(ack)
self.transport._send_user_message(m) self.transport._send_user_message(m)
@ -686,7 +686,7 @@ class Channel (object):
# eof or similar # eof or similar
return 0 return 0
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_DATA)) m.add_byte(cMSG_CHANNEL_DATA)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_string(s[:size]) m.add_string(s[:size])
finally: finally:
@ -721,7 +721,7 @@ class Channel (object):
# eof or similar # eof or similar
return 0 return 0
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_EXTENDED_DATA)) m.add_byte(cMSG_CHANNEL_EXTENDED_DATA)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_int(1) m.add_int(1)
m.add_string(s[:size]) m.add_string(s[:size])
@ -925,16 +925,16 @@ class Channel (object):
self.transport._send_user_message(m) self.transport._send_user_message(m)
def _feed(self, m): def _feed(self, m):
if type(m) is str: if isinstance(m, bytes_types):
# passed from _feed_extended # passed from _feed_extended
s = m s = m
else: else:
s = m.get_string() s = m.get_binary()
self.in_buffer.feed(s) self.in_buffer.feed(s)
def _feed_extended(self, m): def _feed_extended(self, m):
code = m.get_int() code = m.get_int()
s = m.get_string() s = m.get_binary()
if code != 1: if code != 1:
self._log(ERROR, 'unknown extended_data type %d; discarding' % code) self._log(ERROR, 'unknown extended_data type %d; discarding' % code)
return return
@ -955,7 +955,7 @@ class Channel (object):
self.lock.release() self.lock.release()
def _handle_request(self, m): def _handle_request(self, m):
key = m.get_string() key = m.get_text()
want_reply = m.get_boolean() want_reply = m.get_boolean()
server = self.transport.server_object server = self.transport.server_object
ok = False ok = False
@ -991,13 +991,13 @@ class Channel (object):
else: else:
ok = server.check_channel_env_request(self, name, value) ok = server.check_channel_env_request(self, name, value)
elif key == 'exec': elif key == 'exec':
cmd = m.get_string() cmd = m.get_text()
if server is None: if server is None:
ok = False ok = False
else: else:
ok = server.check_channel_exec_request(self, cmd) ok = server.check_channel_exec_request(self, cmd)
elif key == 'subsystem': elif key == 'subsystem':
name = m.get_string() name = m.get_text()
if server is None: if server is None:
ok = False ok = False
else: else:
@ -1014,8 +1014,8 @@ class Channel (object):
pixelheight) pixelheight)
elif key == 'x11-req': elif key == 'x11-req':
single_connection = m.get_boolean() single_connection = m.get_boolean()
auth_proto = m.get_string() auth_proto = m.get_text()
auth_cookie = m.get_string() auth_cookie = m.get_binary()
screen_number = m.get_int() screen_number = m.get_int()
if server is None: if server is None:
ok = False ok = False
@ -1033,9 +1033,9 @@ class Channel (object):
if want_reply: if want_reply:
m = Message() m = Message()
if ok: if ok:
m.add_byte(chr(MSG_CHANNEL_SUCCESS)) m.add_byte(cMSG_CHANNEL_SUCCESS)
else: else:
m.add_byte(chr(MSG_CHANNEL_FAILURE)) m.add_byte(cMSG_CHANNEL_FAILURE)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
self.transport._send_user_message(m) self.transport._send_user_message(m)
@ -1101,7 +1101,7 @@ class Channel (object):
if self.eof_sent: if self.eof_sent:
return None return None
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_EOF)) m.add_byte(cMSG_CHANNEL_EOF)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
self.eof_sent = True self.eof_sent = True
self._log(DEBUG, 'EOF sent (%s)', self._name) self._log(DEBUG, 'EOF sent (%s)', self._name)
@ -1113,7 +1113,7 @@ class Channel (object):
return None, None return None, None
m1 = self._send_eof() m1 = self._send_eof()
m2 = Message() m2 = Message()
m2.add_byte(chr(MSG_CHANNEL_CLOSE)) m2.add_byte(cMSG_CHANNEL_CLOSE)
m2.add_int(self.remote_chanid) m2.add_int(self.remote_chanid)
self._set_closed() self._set_closed()
# can't unlink from the Transport yet -- the remote side may still # can't unlink from the Transport yet -- the remote side may still

View File

@ -132,11 +132,10 @@ class SSHClient (object):
if self._host_keys_filename is not None: if self._host_keys_filename is not None:
self.load_host_keys(self._host_keys_filename) self.load_host_keys(self._host_keys_filename)
f = open(filename, 'w') with open(filename, 'w') as f:
for hostname, keys in self._host_keys.iteritems(): for hostname, keys in self._host_keys.items():
for keytype, key in keys.iteritems(): for keytype, key in keys.items():
f.write('%s %s %s\n' % (hostname, keytype, key.get_base64())) f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
f.close()
def get_host_keys(self): def get_host_keys(self):
""" """
@ -266,7 +265,7 @@ class SSHClient (object):
if key_filename is None: if key_filename is None:
key_filenames = [] key_filenames = []
elif isinstance(key_filename, (str, unicode)): elif isinstance(key_filename, string_types):
key_filenames = [ key_filename ] key_filenames = [ key_filename ]
else: else:
key_filenames = key_filename key_filenames = key_filename
@ -310,8 +309,8 @@ class SSHClient (object):
chan.settimeout(timeout) chan.settimeout(timeout)
chan.exec_command(command) chan.exec_command(command)
stdin = chan.makefile('wb', bufsize) stdin = chan.makefile('wb', bufsize)
stdout = chan.makefile('rb', bufsize) stdout = chan.makefile('r', bufsize)
stderr = chan.makefile_stderr('rb', bufsize) stderr = chan.makefile_stderr('r', bufsize)
return stdin, stdout, stderr return stdin, stdout, stderr
def invoke_shell(self, term='vt100', width=80, height=24, width_pixels=0, def invoke_shell(self, term='vt100', width=80, height=24, width_pixels=0,
@ -377,7 +376,7 @@ class SSHClient (object):
two_factor = (allowed_types == ['password']) two_factor = (allowed_types == ['password'])
if not two_factor: if not two_factor:
return return
except SSHException, e: except SSHException as e:
saved_exception = e saved_exception = e
if not two_factor: if not two_factor:
@ -391,7 +390,7 @@ class SSHClient (object):
if not two_factor: if not two_factor:
return return
break break
except SSHException, e: except SSHException as e:
saved_exception = e saved_exception = e
if not two_factor and allow_agent: if not two_factor and allow_agent:
@ -407,7 +406,7 @@ class SSHClient (object):
if not two_factor: if not two_factor:
return return
break break
except SSHException, e: except SSHException as e:
saved_exception = e saved_exception = e
if not two_factor: if not two_factor:
@ -439,17 +438,15 @@ class SSHClient (object):
if not two_factor: if not two_factor:
return return
break break
except SSHException, e: except (SSHException, IOError) as e:
saved_exception = e
except IOError, e:
saved_exception = e saved_exception = e
if password is not None: if password is not None:
try: try:
self._transport.auth_password(username, password) self._transport.auth_password(username, password)
return return
except SSHException, e: except SSHException:
saved_exception = e saved_exception = sys.exc_info()[1]
elif two_factor: elif two_factor:
raise SSHException('Two-factor authentication requires a password') raise SSHException('Two-factor authentication requires a password')

View File

@ -19,12 +19,13 @@
""" """
Common constants and global variables. Common constants and global variables.
""" """
from paramiko.py3compat import *
MSG_DISCONNECT, MSG_IGNORE, MSG_UNIMPLEMENTED, MSG_DEBUG, MSG_SERVICE_REQUEST, \ MSG_DISCONNECT, MSG_IGNORE, MSG_UNIMPLEMENTED, MSG_DEBUG, MSG_SERVICE_REQUEST, \
MSG_SERVICE_ACCEPT = range(1, 7) MSG_SERVICE_ACCEPT = range(1, 7)
MSG_KEXINIT, MSG_NEWKEYS = range(20, 22) MSG_KEXINIT, MSG_NEWKEYS = range(20, 22)
MSG_USERAUTH_REQUEST, MSG_USERAUTH_FAILURE, MSG_USERAUTH_SUCCESS, \ MSG_USERAUTH_REQUEST, MSG_USERAUTH_FAILURE, MSG_USERAUTH_SUCCESS, \
MSG_USERAUTH_BANNER = range(50, 54) MSG_USERAUTH_BANNER = range(50, 54)
MSG_USERAUTH_PK_OK = 60 MSG_USERAUTH_PK_OK = 60
MSG_USERAUTH_INFO_REQUEST, MSG_USERAUTH_INFO_RESPONSE = range(60, 62) MSG_USERAUTH_INFO_REQUEST, MSG_USERAUTH_INFO_RESPONSE = range(60, 62)
MSG_GLOBAL_REQUEST, MSG_REQUEST_SUCCESS, MSG_REQUEST_FAILURE = range(80, 83) MSG_GLOBAL_REQUEST, MSG_REQUEST_SUCCESS, MSG_REQUEST_FAILURE = range(80, 83)
@ -33,6 +34,10 @@ MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_SUCCESS, MSG_CHANNEL_OPEN_FAILURE, \
MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST, \ MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST, \
MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101) MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101)
for key in list(locals().keys()):
if key.startswith('MSG_'):
locals()['c' + key] = byte_chr(locals()[key])
del key
# for debugging: # for debugging:
MSG_NAMES = { MSG_NAMES = {
@ -69,7 +74,7 @@ MSG_NAMES = {
MSG_CHANNEL_REQUEST: 'channel-request', MSG_CHANNEL_REQUEST: 'channel-request',
MSG_CHANNEL_SUCCESS: 'channel-success', MSG_CHANNEL_SUCCESS: 'channel-success',
MSG_CHANNEL_FAILURE: 'channel-failure' MSG_CHANNEL_FAILURE: 'channel-failure'
} }
# authentication request return codes: # authentication request return codes:
@ -118,6 +123,42 @@ else:
import logging import logging
PY22 = False PY22 = False
zero_byte = byte_chr(0)
one_byte = byte_chr(1)
four_byte = byte_chr(4)
max_byte = byte_chr(0xff)
cr_byte = byte_chr(13)
linefeed_byte = byte_chr(10)
crlf = cr_byte + linefeed_byte
if PY2:
cr_byte_value = cr_byte
linefeed_byte_value = linefeed_byte
else:
cr_byte_value = 13
linefeed_byte_value = 10
def asbytes(s):
if not isinstance(s, bytes_types):
if isinstance(s, string_types):
s = b(s)
else:
try:
s = s.asbytes()
except Exception:
raise Exception('Unknown type')
return s
xffffffff = long(0xffffffff)
x80000000 = long(0x80000000)
o666 = 438
o660 = 432
o644 = 420
o600 = 384
o777 = 511
o700 = 448
o70 = 56
DEBUG = logging.DEBUG DEBUG = logging.DEBUG
INFO = logging.INFO INFO = logging.INFO

View File

@ -116,7 +116,7 @@ class SSHConfig (object):
ret = {} ret = {}
for match in matches: for match in matches:
for key, value in match['config'].iteritems(): for key, value in match['config'].items():
if key not in ret: if key not in ret:
# Create a copy of the original value, # Create a copy of the original value,
# else it will reference the original list # else it will reference the original list

View File

@ -56,7 +56,7 @@ class DSSKey (PKey):
else: else:
if msg is None: if msg is None:
raise SSHException('Key object may not be empty') raise SSHException('Key object may not be empty')
if msg.get_string() != 'ssh-dss': if msg.get_text() != 'ssh-dss':
raise SSHException('Invalid key') raise SSHException('Invalid key')
self.p = msg.get_mpint() self.p = msg.get_mpint()
self.q = msg.get_mpint() self.q = msg.get_mpint()
@ -64,14 +64,17 @@ class DSSKey (PKey):
self.y = msg.get_mpint() self.y = msg.get_mpint()
self.size = util.bit_length(self.p) self.size = util.bit_length(self.p)
def __str__(self): def asbytes(self):
m = Message() m = Message()
m.add_string('ssh-dss') m.add_string('ssh-dss')
m.add_mpint(self.p) m.add_mpint(self.p)
m.add_mpint(self.q) m.add_mpint(self.q)
m.add_mpint(self.g) m.add_mpint(self.g)
m.add_mpint(self.y) m.add_mpint(self.y)
return str(m) return m.asbytes()
def __str__(self):
return self.asbytes()
def __hash__(self): def __hash__(self):
h = hash(self.get_name()) h = hash(self.get_name())
@ -107,21 +110,21 @@ class DSSKey (PKey):
rstr = util.deflate_long(r, 0) rstr = util.deflate_long(r, 0)
sstr = util.deflate_long(s, 0) sstr = util.deflate_long(s, 0)
if len(rstr) < 20: if len(rstr) < 20:
rstr = '\x00' * (20 - len(rstr)) + rstr rstr = zero_byte * (20 - len(rstr)) + rstr
if len(sstr) < 20: if len(sstr) < 20:
sstr = '\x00' * (20 - len(sstr)) + sstr sstr = zero_byte * (20 - len(sstr)) + sstr
m.add_string(rstr + sstr) m.add_string(rstr + sstr)
return m return m
def verify_ssh_sig(self, data, msg): def verify_ssh_sig(self, data, msg):
if len(str(msg)) == 40: if len(msg.asbytes()) == 40:
# spies.com bug: signature has no header # spies.com bug: signature has no header
sig = str(msg) sig = msg.asbytes()
else: else:
kind = msg.get_string() kind = msg.get_text()
if kind != 'ssh-dss': if kind != 'ssh-dss':
return 0 return 0
sig = msg.get_string() sig = msg.get_binary()
# pull out (r, s) which are NOT encoded as mpints # pull out (r, s) which are NOT encoded as mpints
sigR = util.inflate_long(sig[:20], 1) sigR = util.inflate_long(sig[:20], 1)
@ -140,7 +143,7 @@ class DSSKey (PKey):
b.encode(keylist) b.encode(keylist)
except BERException: except BERException:
raise SSHException('Unable to create ber encoding of key') raise SSHException('Unable to create ber encoding of key')
return str(b) return b.asbytes()
def write_private_key_file(self, filename, password=None): def write_private_key_file(self, filename, password=None):
self._write_private_key_file('DSA', filename, self._encode_key(), password) self._write_private_key_file('DSA', filename, self._encode_key(), password)
@ -182,8 +185,8 @@ class DSSKey (PKey):
# DSAPrivateKey = { version = 0, p, q, g, y, x } # DSAPrivateKey = { version = 0, p, q, g, y, x }
try: try:
keylist = BER(data).decode() keylist = BER(data).decode()
except BERException, x: except BERException as e:
raise SSHException('Unable to parse key file: ' + str(x)) raise SSHException('Unable to parse key file: ' + str(e))
if (type(keylist) is not list) or (len(keylist) < 6) or (keylist[0] != 0): if (type(keylist) is not list) or (len(keylist) < 6) or (keylist[0] != 0):
raise SSHException('not a valid DSA private key file (bad ber encoding)') raise SSHException('not a valid DSA private key file (bad ber encoding)')
self.p = keylist[1] self.p = keylist[1]

View File

@ -56,30 +56,33 @@ class ECDSAKey (PKey):
else: else:
if msg is None: if msg is None:
raise SSHException('Key object may not be empty') raise SSHException('Key object may not be empty')
if msg.get_string() != 'ecdsa-sha2-nistp256': if msg.get_text() != 'ecdsa-sha2-nistp256':
raise SSHException('Invalid key') raise SSHException('Invalid key')
curvename = msg.get_string() curvename = msg.get_text()
if curvename != 'nistp256': if curvename != 'nistp256':
raise SSHException("Can't handle curve of type %s" % curvename) raise SSHException("Can't handle curve of type %s" % curvename)
pointinfo = msg.get_string() pointinfo = msg.get_binary()
if pointinfo[0] != "\x04": if pointinfo[0:1] != four_byte:
raise SSHException('Point compression is being used: %s'% raise SSHException('Point compression is being used: %s' %
binascii.hexlify(pointinfo)) binascii.hexlify(pointinfo))
self.verifying_key = VerifyingKey.from_string(pointinfo[1:], self.verifying_key = VerifyingKey.from_string(pointinfo[1:],
curve=curves.NIST256p) curve=curves.NIST256p)
self.size = 256 self.size = 256
def __str__(self): def asbytes(self):
key = self.verifying_key key = self.verifying_key
m = Message() m = Message()
m.add_string('ecdsa-sha2-nistp256') m.add_string('ecdsa-sha2-nistp256')
m.add_string('nistp256') m.add_string('nistp256')
point_str = "\x04" + key.to_string() point_str = four_byte + key.to_string()
m.add_string(point_str) m.add_string(point_str)
return str(m) return m.asbytes()
def __str__(self):
return self.asbytes()
def __hash__(self): def __hash__(self):
h = hash(self.get_name()) h = hash(self.get_name())
@ -106,9 +109,9 @@ class ECDSAKey (PKey):
return m return m
def verify_ssh_sig(self, data, msg): def verify_ssh_sig(self, data, msg):
if msg.get_string() != 'ecdsa-sha2-nistp256': if msg.get_text() != 'ecdsa-sha2-nistp256':
return False return False
sig = msg.get_string() sig = msg.get_binary()
# verify the signature by SHA'ing the data and encrypting it # verify the signature by SHA'ing the data and encrypting it
# using the public key. # using the public key.
@ -154,14 +157,13 @@ class ECDSAKey (PKey):
data = self._read_private_key('EC', file_obj, password) data = self._read_private_key('EC', file_obj, password)
self._decode_key(data) self._decode_key(data)
ALLOWED_PADDINGS = ['\x01', '\x02\x02', '\x03\x03\x03', '\x04\x04\x04\x04', ALLOWED_PADDINGS = [one_byte, byte_chr(2) * 2, byte_chr(3) * 3, byte_chr(4) * 4,
'\x05\x05\x05\x05\x05', '\x06\x06\x06\x06\x06\x06', byte_chr(5) * 5, byte_chr(6) * 6, byte_chr(7) * 7]
'\x07\x07\x07\x07\x07\x07\x07']
def _decode_key(self, data): def _decode_key(self, data):
s, padding = der.remove_sequence(data) s, padding = der.remove_sequence(data)
if padding: if padding:
if padding not in self.ALLOWED_PADDINGS: if padding not in self.ALLOWED_PADDINGS:
raise ValueError, "weird padding: %s" % (binascii.hexlify(empty)) raise ValueError("weird padding: %s" % u(binascii.hexlify(data)))
data = data[:-len(padding)] data = data[:-len(padding)]
key = SigningKey.from_der(data) key = SigningKey.from_der(data)
self.signing_key = key self.signing_key = key
@ -172,7 +174,7 @@ class ECDSAKey (PKey):
msg = Message() msg = Message()
msg.add_mpint(r) msg.add_mpint(r)
msg.add_mpint(s) msg.add_mpint(s)
return str(msg) return msg.asbytes()
def _sigdecode(self, sig, order): def _sigdecode(self, sig, order):
msg = Message(sig) msg = Message(sig)

View File

@ -16,7 +16,7 @@
# along with Paramiko; if not, write to the Free Software Foundation, Inc., # along with Paramiko; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. # 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
from cStringIO import StringIO from paramiko.common import *
class BufferedFile (object): class BufferedFile (object):
@ -43,8 +43,8 @@ class BufferedFile (object):
self.newlines = None self.newlines = None
self._flags = 0 self._flags = 0
self._bufsize = self._DEFAULT_BUFSIZE self._bufsize = self._DEFAULT_BUFSIZE
self._wbuffer = StringIO() self._wbuffer = BytesIO()
self._rbuffer = '' self._rbuffer = bytes()
self._at_trailing_cr = False self._at_trailing_cr = False
self._closed = False self._closed = False
# pos - position within the file, according to the user # pos - position within the file, according to the user
@ -82,9 +82,10 @@ class BufferedFile (object):
buffering is not turned on. buffering is not turned on.
""" """
self._write_all(self._wbuffer.getvalue()) self._write_all(self._wbuffer.getvalue())
self._wbuffer = StringIO() self._wbuffer = BytesIO()
return return
if PY2:
def next(self): def next(self):
""" """
Returns the next line from the input, or raises Returns the next line from the input, or raises
@ -99,6 +100,22 @@ class BufferedFile (object):
if not line: if not line:
raise StopIteration raise StopIteration
return line return line
else:
def __next__(self):
"""
Returns the next line from the input, or raises L{StopIteration} when
EOF is hit. Unlike python file objects, it's okay to mix calls to
C{next} and L{readline}.
@raise StopIteration: when the end of the file is reached.
@return: a line read from the file.
@rtype: str
"""
line = self.readline()
if not line:
raise StopIteration
return line
def read(self, size=None): def read(self, size=None):
""" """
@ -118,7 +135,7 @@ class BufferedFile (object):
if (size is None) or (size < 0): if (size is None) or (size < 0):
# go for broke # go for broke
result = self._rbuffer result = self._rbuffer
self._rbuffer = '' self._rbuffer = bytes()
self._pos += len(result) self._pos += len(result)
while True: while True:
try: try:
@ -130,12 +147,12 @@ class BufferedFile (object):
result += new_data result += new_data
self._realpos += len(new_data) self._realpos += len(new_data)
self._pos += len(new_data) self._pos += len(new_data)
return result return result if self._flags & self.FLAG_BINARY else u(result)
if size <= len(self._rbuffer): if size <= len(self._rbuffer):
result = self._rbuffer[:size] result = self._rbuffer[:size]
self._rbuffer = self._rbuffer[size:] self._rbuffer = self._rbuffer[size:]
self._pos += len(result) self._pos += len(result)
return result return result if self._flags & self.FLAG_BINARY else u(result)
while len(self._rbuffer) < size: while len(self._rbuffer) < size:
read_size = size - len(self._rbuffer) read_size = size - len(self._rbuffer)
if self._flags & self.FLAG_BUFFERED: if self._flags & self.FLAG_BUFFERED:
@ -151,7 +168,7 @@ class BufferedFile (object):
result = self._rbuffer[:size] result = self._rbuffer[:size]
self._rbuffer = self._rbuffer[size:] self._rbuffer = self._rbuffer[size:]
self._pos += len(result) self._pos += len(result)
return result return result if self._flags & self.FLAG_BINARY else u(result)
def readline(self, size=None): def readline(self, size=None):
""" """
@ -181,11 +198,11 @@ class BufferedFile (object):
if self._at_trailing_cr and (self._flags & self.FLAG_UNIVERSAL_NEWLINE) and (len(line) > 0): if self._at_trailing_cr and (self._flags & self.FLAG_UNIVERSAL_NEWLINE) and (len(line) > 0):
# edge case: the newline may be '\r\n' and we may have read # edge case: the newline may be '\r\n' and we may have read
# only the first '\r' last time. # only the first '\r' last time.
if line[0] == '\n': if line[0] == linefeed_byte_value:
line = line[1:] line = line[1:]
self._record_newline('\r\n') self._record_newline(crlf)
else: else:
self._record_newline('\r') self._record_newline(cr_byte)
self._at_trailing_cr = False self._at_trailing_cr = False
# check size before looking for a linefeed, in case we already have # check size before looking for a linefeed, in case we already have
# enough. # enough.
@ -195,42 +212,42 @@ class BufferedFile (object):
self._rbuffer = line[size:] self._rbuffer = line[size:]
line = line[:size] line = line[:size]
self._pos += len(line) self._pos += len(line)
return line return line if self._flags & self.FLAG_BINARY else u(line)
n = size - len(line) n = size - len(line)
else: else:
n = self._bufsize n = self._bufsize
if ('\n' in line) or ((self._flags & self.FLAG_UNIVERSAL_NEWLINE) and ('\r' in line)): if (linefeed_byte in line) or ((self._flags & self.FLAG_UNIVERSAL_NEWLINE) and (cr_byte in line)):
break break
try: try:
new_data = self._read(n) new_data = self._read(n)
except EOFError: except EOFError:
new_data = None new_data = None
if (new_data is None) or (len(new_data) == 0): if (new_data is None) or (len(new_data) == 0):
self._rbuffer = '' self._rbuffer = bytes()
self._pos += len(line) self._pos += len(line)
return line return line if self._flags & self.FLAG_BINARY else u(line)
line += new_data line += new_data
self._realpos += len(new_data) self._realpos += len(new_data)
# find the newline # find the newline
pos = line.find('\n') pos = line.find(linefeed_byte)
if self._flags & self.FLAG_UNIVERSAL_NEWLINE: if self._flags & self.FLAG_UNIVERSAL_NEWLINE:
rpos = line.find('\r') rpos = line.find(cr_byte)
if (rpos >= 0) and ((rpos < pos) or (pos < 0)): if (rpos >= 0) and ((rpos < pos) or (pos < 0)):
pos = rpos pos = rpos
xpos = pos + 1 xpos = pos + 1
if (line[pos] == '\r') and (xpos < len(line)) and (line[xpos] == '\n'): if (line[pos] == cr_byte_value) and (xpos < len(line)) and (line[xpos] == linefeed_byte_value):
xpos += 1 xpos += 1
self._rbuffer = line[xpos:] self._rbuffer = line[xpos:]
lf = line[pos:xpos] lf = line[pos:xpos]
line = line[:pos] + '\n' line = line[:pos] + linefeed_byte
if (len(self._rbuffer) == 0) and (lf == '\r'): if (len(self._rbuffer) == 0) and (lf == cr_byte):
# we could read the line up to a '\r' and there could still be a # we could read the line up to a '\r' and there could still be a
# '\n' following that we read next time. note that and eat it. # '\n' following that we read next time. note that and eat it.
self._at_trailing_cr = True self._at_trailing_cr = True
else: else:
self._record_newline(lf) self._record_newline(lf)
self._pos += len(line) self._pos += len(line)
return line return line if self._flags & self.FLAG_BINARY else u(line)
def readlines(self, sizehint=None): def readlines(self, sizehint=None):
""" """
@ -243,14 +260,14 @@ class BufferedFile (object):
:return: `list` of lines read from the file. :return: `list` of lines read from the file.
""" """
lines = [] lines = []
bytes = 0 byte_count = 0
while True: while True:
line = self.readline() line = self.readline()
if len(line) == 0: if len(line) == 0:
break break
lines.append(line) lines.append(line)
bytes += len(line) byte_count += len(line)
if (sizehint is not None) and (bytes >= sizehint): if (sizehint is not None) and (byte_count >= sizehint):
break break
return lines return lines
@ -292,6 +309,7 @@ class BufferedFile (object):
:param str data: data to write :param str data: data to write
""" """
data = b(data)
if self._closed: if self._closed:
raise IOError('File is closed') raise IOError('File is closed')
if not (self._flags & self.FLAG_WRITE): if not (self._flags & self.FLAG_WRITE):
@ -302,12 +320,12 @@ class BufferedFile (object):
self._wbuffer.write(data) self._wbuffer.write(data)
if self._flags & self.FLAG_LINE_BUFFERED: if self._flags & self.FLAG_LINE_BUFFERED:
# only scan the new data for linefeed, to avoid wasting time. # only scan the new data for linefeed, to avoid wasting time.
last_newline_pos = data.rfind('\n') last_newline_pos = data.rfind(linefeed_byte)
if last_newline_pos >= 0: if last_newline_pos >= 0:
wbuf = self._wbuffer.getvalue() wbuf = self._wbuffer.getvalue()
last_newline_pos += len(wbuf) - len(data) last_newline_pos += len(wbuf) - len(data)
self._write_all(wbuf[:last_newline_pos + 1]) self._write_all(wbuf[:last_newline_pos + 1])
self._wbuffer = StringIO() self._wbuffer = BytesIO()
self._wbuffer.write(wbuf[last_newline_pos + 1:]) self._wbuffer.write(wbuf[last_newline_pos + 1:])
return return
# even if we're line buffering, if the buffer has grown past the # even if we're line buffering, if the buffer has grown past the
@ -436,7 +454,7 @@ class BufferedFile (object):
return return
if self.newlines is None: if self.newlines is None:
self.newlines = newline self.newlines = newline
elif (type(self.newlines) is str) and (self.newlines != newline): elif self.newlines != newline and isinstance(self.newlines, bytes_types):
self.newlines = (self.newlines, newline) self.newlines = (self.newlines, newline)
elif newline not in self.newlines: elif newline not in self.newlines:
self.newlines += (newline,) self.newlines += (newline,)

View File

@ -20,7 +20,10 @@
import base64 import base64
import binascii import binascii
from Crypto.Hash import SHA, HMAC from Crypto.Hash import SHA, HMAC
import UserDict try:
from collections import MutableMapping
except ImportError:
from UserDict import DictMixin as MutableMapping
from paramiko.common import * from paramiko.common import *
from paramiko.dsskey import DSSKey from paramiko.dsskey import DSSKey
@ -29,7 +32,7 @@ from paramiko.util import get_logger, constant_time_bytes_eq
from paramiko.ecdsakey import ECDSAKey from paramiko.ecdsakey import ECDSAKey
class HostKeys (UserDict.DictMixin): class HostKeys (MutableMapping):
""" """
Representation of an OpenSSH-style "known hosts" file. Host keys can be Representation of an OpenSSH-style "known hosts" file. Host keys can be
read from one or more files, and then individual hosts can be looked up to read from one or more files, and then individual hosts can be looked up to
@ -83,20 +86,19 @@ class HostKeys (UserDict.DictMixin):
:raises IOError: if there was an error reading the file :raises IOError: if there was an error reading the file
""" """
f = open(filename, 'r') with open(filename, 'r') as f:
for lineno, line in enumerate(f): for lineno, line in enumerate(f):
line = line.strip() line = line.strip()
if (len(line) == 0) or (line[0] == '#'): if (len(line) == 0) or (line[0] == '#'):
continue continue
e = HostKeyEntry.from_line(line, lineno) e = HostKeyEntry.from_line(line, lineno)
if e is not None: if e is not None:
_hostnames = e.hostnames _hostnames = e.hostnames
for h in _hostnames: for h in _hostnames:
if self.check(h, e.key): if self.check(h, e.key):
e.hostnames.remove(h) e.hostnames.remove(h)
if len(e.hostnames): if len(e.hostnames):
self._entries.append(e) self._entries.append(e)
f.close()
def save(self, filename): def save(self, filename):
""" """
@ -111,12 +113,11 @@ class HostKeys (UserDict.DictMixin):
.. versionadded:: 1.6.1 .. versionadded:: 1.6.1
""" """
f = open(filename, 'w') with open(filename, 'w') as f:
for e in self._entries: for e in self._entries:
line = e.to_line() line = e.to_line()
if line: if line:
f.write(line) f.write(line)
f.close()
def lookup(self, hostname): def lookup(self, hostname):
""" """
@ -127,12 +128,26 @@ class HostKeys (UserDict.DictMixin):
:param str hostname: the hostname (or IP) to lookup :param str hostname: the hostname (or IP) to lookup
:return: dict of `str` -> `.PKey` keys associated with this host (or ``None``) :return: dict of `str` -> `.PKey` keys associated with this host (or ``None``)
""" """
class SubDict (UserDict.DictMixin): class SubDict (MutableMapping):
def __init__(self, hostname, entries, hostkeys): def __init__(self, hostname, entries, hostkeys):
self._hostname = hostname self._hostname = hostname
self._entries = entries self._entries = entries
self._hostkeys = hostkeys self._hostkeys = hostkeys
def __iter__(self):
for k in self.keys():
yield k
def __len__(self):
return len(self.keys())
def __delitem__(self, key):
for e in list(self._entries):
if e.key.get_name() == key:
self._entries.remove(e)
else:
raise KeyError(key)
def __getitem__(self, key): def __getitem__(self, key):
for e in self._entries: for e in self._entries:
if e.key.get_name() == key: if e.key.get_name() == key:
@ -181,7 +196,7 @@ class HostKeys (UserDict.DictMixin):
host_key = k.get(key.get_name(), None) host_key = k.get(key.get_name(), None)
if host_key is None: if host_key is None:
return False return False
return str(host_key) == str(key) return host_key.asbytes() == key.asbytes()
def clear(self): def clear(self):
""" """
@ -189,6 +204,17 @@ class HostKeys (UserDict.DictMixin):
""" """
self._entries = [] self._entries = []
def __iter__(self):
for k in self.keys():
yield k
def __len__(self):
return len(self.keys())
def __delitem__(self, key):
k = self[key]
pass
def __getitem__(self, key): def __getitem__(self, key):
ret = self.lookup(key) ret = self.lookup(key)
if ret is None: if ret is None:
@ -239,10 +265,10 @@ class HostKeys (UserDict.DictMixin):
else: else:
if salt.startswith('|1|'): if salt.startswith('|1|'):
salt = salt.split('|')[2] salt = salt.split('|')[2]
salt = base64.decodestring(salt) salt = decodebytes(b(salt))
assert len(salt) == SHA.digest_size assert len(salt) == SHA.digest_size
hmac = HMAC.HMAC(salt, hostname, SHA).digest() hmac = HMAC.HMAC(salt, b(hostname), SHA).digest()
hostkey = '|1|%s|%s' % (base64.encodestring(salt), base64.encodestring(hmac)) hostkey = '|1|%s|%s' % (u(encodebytes(salt)), u(encodebytes(hmac)))
return hostkey.replace('\n', '') return hostkey.replace('\n', '')
hash_host = staticmethod(hash_host) hash_host = staticmethod(hash_host)
@ -292,17 +318,17 @@ class HostKeyEntry:
# to hold it accordingly. # to hold it accordingly.
try: try:
if keytype == 'ssh-rsa': if keytype == 'ssh-rsa':
key = RSAKey(data=base64.decodestring(key)) key = RSAKey(data=decodebytes(key))
elif keytype == 'ssh-dss': elif keytype == 'ssh-dss':
key = DSSKey(data=base64.decodestring(key)) key = DSSKey(data=decodebytes(key))
elif keytype == 'ecdsa-sha2-nistp256': elif keytype == 'ecdsa-sha2-nistp256':
key = ECDSAKey(data=base64.decodestring(key)) key = ECDSAKey(data=decodebytes(key))
else: else:
log.info("Unable to handle key of type %s" % (keytype,)) log.info("Unable to handle key of type %s" % (keytype,))
return None return None
except binascii.Error, e: except binascii.Error as e:
raise InvalidHostKey(line, e) raise InvalidHostKey(line, sys.exc_info()[1])
return cls(names, key) return cls(names, key)
from_line = classmethod(from_line) from_line = classmethod(from_line)

View File

@ -33,6 +33,8 @@ from paramiko.ssh_exception import SSHException
_MSG_KEXDH_GEX_REQUEST_OLD, _MSG_KEXDH_GEX_GROUP, _MSG_KEXDH_GEX_INIT, \ _MSG_KEXDH_GEX_REQUEST_OLD, _MSG_KEXDH_GEX_GROUP, _MSG_KEXDH_GEX_INIT, \
_MSG_KEXDH_GEX_REPLY, _MSG_KEXDH_GEX_REQUEST = range(30, 35) _MSG_KEXDH_GEX_REPLY, _MSG_KEXDH_GEX_REQUEST = range(30, 35)
c_MSG_KEXDH_GEX_REQUEST_OLD, c_MSG_KEXDH_GEX_GROUP, c_MSG_KEXDH_GEX_INIT, \
c_MSG_KEXDH_GEX_REPLY, c_MSG_KEXDH_GEX_REQUEST = [byte_chr(c) for c in range(30, 35)]
class KexGex (object): class KexGex (object):
@ -62,11 +64,11 @@ class KexGex (object):
m = Message() m = Message()
if _test_old_style: if _test_old_style:
# only used for unit tests: we shouldn't ever send this # only used for unit tests: we shouldn't ever send this
m.add_byte(chr(_MSG_KEXDH_GEX_REQUEST_OLD)) m.add_byte(c_MSG_KEXDH_GEX_REQUEST_OLD)
m.add_int(self.preferred_bits) m.add_int(self.preferred_bits)
self.old_style = True self.old_style = True
else: else:
m.add_byte(chr(_MSG_KEXDH_GEX_REQUEST)) m.add_byte(c_MSG_KEXDH_GEX_REQUEST)
m.add_int(self.min_bits) m.add_int(self.min_bits)
m.add_int(self.preferred_bits) m.add_int(self.preferred_bits)
m.add_int(self.max_bits) m.add_int(self.max_bits)
@ -94,15 +96,15 @@ class KexGex (object):
# generate an "x" (1 < x < (p-1)/2). # generate an "x" (1 < x < (p-1)/2).
q = (self.p - 1) // 2 q = (self.p - 1) // 2
qnorm = util.deflate_long(q, 0) qnorm = util.deflate_long(q, 0)
qhbyte = ord(qnorm[0]) qhbyte = byte_ord(qnorm[0])
bytes = len(qnorm) byte_count = len(qnorm)
qmask = 0xff qmask = 0xff
while not (qhbyte & 0x80): while not (qhbyte & 0x80):
qhbyte <<= 1 qhbyte <<= 1
qmask >>= 1 qmask >>= 1
while True: while True:
x_bytes = self.transport.rng.read(bytes) x_bytes = self.transport.rng.read(byte_count)
x_bytes = chr(ord(x_bytes[0]) & qmask) + x_bytes[1:] x_bytes = byte_mask(x_bytes[0], qmask) + x_bytes[1:]
x = util.inflate_long(x_bytes, 1) x = util.inflate_long(x_bytes, 1)
if (x > 1) and (x < q): if (x > 1) and (x < q):
break break
@ -135,7 +137,7 @@ class KexGex (object):
self.transport._log(DEBUG, 'Picking p (%d <= %d <= %d bits)' % (minbits, preferredbits, maxbits)) self.transport._log(DEBUG, 'Picking p (%d <= %d <= %d bits)' % (minbits, preferredbits, maxbits))
self.g, self.p = pack.get_modulus(minbits, preferredbits, maxbits) self.g, self.p = pack.get_modulus(minbits, preferredbits, maxbits)
m = Message() m = Message()
m.add_byte(chr(_MSG_KEXDH_GEX_GROUP)) m.add_byte(c_MSG_KEXDH_GEX_GROUP)
m.add_mpint(self.p) m.add_mpint(self.p)
m.add_mpint(self.g) m.add_mpint(self.g)
self.transport._send_message(m) self.transport._send_message(m)
@ -156,7 +158,7 @@ class KexGex (object):
self.transport._log(DEBUG, 'Picking p (~ %d bits)' % (self.preferred_bits,)) self.transport._log(DEBUG, 'Picking p (~ %d bits)' % (self.preferred_bits,))
self.g, self.p = pack.get_modulus(self.min_bits, self.preferred_bits, self.max_bits) self.g, self.p = pack.get_modulus(self.min_bits, self.preferred_bits, self.max_bits)
m = Message() m = Message()
m.add_byte(chr(_MSG_KEXDH_GEX_GROUP)) m.add_byte(c_MSG_KEXDH_GEX_GROUP)
m.add_mpint(self.p) m.add_mpint(self.p)
m.add_mpint(self.g) m.add_mpint(self.g)
self.transport._send_message(m) self.transport._send_message(m)
@ -175,7 +177,7 @@ class KexGex (object):
# now compute e = g^x mod p # now compute e = g^x mod p
self.e = pow(self.g, self.x, self.p) self.e = pow(self.g, self.x, self.p)
m = Message() m = Message()
m.add_byte(chr(_MSG_KEXDH_GEX_INIT)) m.add_byte(c_MSG_KEXDH_GEX_INIT)
m.add_mpint(self.e) m.add_mpint(self.e)
self.transport._send_message(m) self.transport._send_message(m)
self.transport._expect_packet(_MSG_KEXDH_GEX_REPLY) self.transport._expect_packet(_MSG_KEXDH_GEX_REPLY)
@ -187,7 +189,7 @@ class KexGex (object):
self._generate_x() self._generate_x()
self.f = pow(self.g, self.x, self.p) self.f = pow(self.g, self.x, self.p)
K = pow(self.e, self.x, self.p) K = pow(self.e, self.x, self.p)
key = str(self.transport.get_server_key()) key = self.transport.get_server_key().asbytes()
# okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) # okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K)
hm = Message() hm = Message()
hm.add(self.transport.remote_version, self.transport.local_version, hm.add(self.transport.remote_version, self.transport.local_version,
@ -203,16 +205,16 @@ class KexGex (object):
hm.add_mpint(self.e) hm.add_mpint(self.e)
hm.add_mpint(self.f) hm.add_mpint(self.f)
hm.add_mpint(K) hm.add_mpint(K)
H = SHA.new(str(hm)).digest() H = SHA.new(hm.asbytes()).digest()
self.transport._set_K_H(K, H) self.transport._set_K_H(K, H)
# sign it # sign it
sig = self.transport.get_server_key().sign_ssh_data(self.transport.rng, H) sig = self.transport.get_server_key().sign_ssh_data(self.transport.rng, H)
# send reply # send reply
m = Message() m = Message()
m.add_byte(chr(_MSG_KEXDH_GEX_REPLY)) m.add_byte(c_MSG_KEXDH_GEX_REPLY)
m.add_string(key) m.add_string(key)
m.add_mpint(self.f) m.add_mpint(self.f)
m.add_string(str(sig)) m.add_string(sig)
self.transport._send_message(m) self.transport._send_message(m)
self.transport._activate_outbound() self.transport._activate_outbound()
@ -238,6 +240,6 @@ class KexGex (object):
hm.add_mpint(self.e) hm.add_mpint(self.e)
hm.add_mpint(self.f) hm.add_mpint(self.f)
hm.add_mpint(K) hm.add_mpint(K)
self.transport._set_K_H(K, SHA.new(str(hm)).digest()) self.transport._set_K_H(K, SHA.new(hm.asbytes()).digest())
self.transport._verify_key(host_key, sig) self.transport._verify_key(host_key, sig)
self.transport._activate_outbound() self.transport._activate_outbound()

View File

@ -30,11 +30,14 @@ from paramiko.ssh_exception import SSHException
_MSG_KEXDH_INIT, _MSG_KEXDH_REPLY = range(30, 32) _MSG_KEXDH_INIT, _MSG_KEXDH_REPLY = range(30, 32)
c_MSG_KEXDH_INIT, c_MSG_KEXDH_REPLY = [byte_chr(c) for c in range(30, 32)]
# draft-ietf-secsh-transport-09.txt, page 17 # draft-ietf-secsh-transport-09.txt, page 17
P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFFL P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF
G = 2 G = 2
b7fffffffffffffff = byte_chr(0x7f) + max_byte * 7
b0000000000000000 = zero_byte * 8
class KexGroup1(object): class KexGroup1(object):
@ -42,9 +45,9 @@ class KexGroup1(object):
def __init__(self, transport): def __init__(self, transport):
self.transport = transport self.transport = transport
self.x = 0L self.x = long(0)
self.e = 0L self.e = long(0)
self.f = 0L self.f = long(0)
def start_kex(self): def start_kex(self):
self._generate_x() self._generate_x()
@ -56,7 +59,7 @@ class KexGroup1(object):
# compute e = g^x mod p (where g=2), and send it # compute e = g^x mod p (where g=2), and send it
self.e = pow(G, self.x, P) self.e = pow(G, self.x, P)
m = Message() m = Message()
m.add_byte(chr(_MSG_KEXDH_INIT)) m.add_byte(c_MSG_KEXDH_INIT)
m.add_mpint(self.e) m.add_mpint(self.e)
self.transport._send_message(m) self.transport._send_message(m)
self.transport._expect_packet(_MSG_KEXDH_REPLY) self.transport._expect_packet(_MSG_KEXDH_REPLY)
@ -80,9 +83,9 @@ class KexGroup1(object):
# larger than q (but this is a tiny tiny subset of potential x). # larger than q (but this is a tiny tiny subset of potential x).
while 1: while 1:
x_bytes = self.transport.rng.read(128) x_bytes = self.transport.rng.read(128)
x_bytes = chr(ord(x_bytes[0]) & 0x7f) + x_bytes[1:] x_bytes = byte_mask(x_bytes[0], 0x7f) + x_bytes[1:]
if (x_bytes[:8] != '\x7F\xFF\xFF\xFF\xFF\xFF\xFF\xFF') and \ if (x_bytes[:8] != b7fffffffffffffff) and \
(x_bytes[:8] != '\x00\x00\x00\x00\x00\x00\x00\x00'): (x_bytes[:8] != b0000000000000000):
break break
self.x = util.inflate_long(x_bytes) self.x = util.inflate_long(x_bytes)
@ -92,7 +95,7 @@ class KexGroup1(object):
self.f = m.get_mpint() self.f = m.get_mpint()
if (self.f < 1) or (self.f > P - 1): if (self.f < 1) or (self.f > P - 1):
raise SSHException('Server kex "f" is out of range') raise SSHException('Server kex "f" is out of range')
sig = m.get_string() sig = m.get_binary()
K = pow(self.f, self.x, P) K = pow(self.f, self.x, P)
# okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || e || f || K) # okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || e || f || K)
hm = Message() hm = Message()
@ -102,7 +105,7 @@ class KexGroup1(object):
hm.add_mpint(self.e) hm.add_mpint(self.e)
hm.add_mpint(self.f) hm.add_mpint(self.f)
hm.add_mpint(K) hm.add_mpint(K)
self.transport._set_K_H(K, SHA.new(str(hm)).digest()) self.transport._set_K_H(K, SHA.new(hm.asbytes()).digest())
self.transport._verify_key(host_key, sig) self.transport._verify_key(host_key, sig)
self.transport._activate_outbound() self.transport._activate_outbound()
@ -112,7 +115,7 @@ class KexGroup1(object):
if (self.e < 1) or (self.e > P - 1): if (self.e < 1) or (self.e > P - 1):
raise SSHException('Client kex "e" is out of range') raise SSHException('Client kex "e" is out of range')
K = pow(self.e, self.x, P) K = pow(self.e, self.x, P)
key = str(self.transport.get_server_key()) key = self.transport.get_server_key().asbytes()
# okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || e || f || K) # okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || e || f || K)
hm = Message() hm = Message()
hm.add(self.transport.remote_version, self.transport.local_version, hm.add(self.transport.remote_version, self.transport.local_version,
@ -121,15 +124,15 @@ class KexGroup1(object):
hm.add_mpint(self.e) hm.add_mpint(self.e)
hm.add_mpint(self.f) hm.add_mpint(self.f)
hm.add_mpint(K) hm.add_mpint(K)
H = SHA.new(str(hm)).digest() H = SHA.new(hm.asbytes()).digest()
self.transport._set_K_H(K, H) self.transport._set_K_H(K, H)
# sign it # sign it
sig = self.transport.get_server_key().sign_ssh_data(self.transport.rng, H) sig = self.transport.get_server_key().sign_ssh_data(self.transport.rng, H)
# send reply # send reply
m = Message() m = Message()
m.add_byte(chr(_MSG_KEXDH_REPLY)) m.add_byte(c_MSG_KEXDH_REPLY)
m.add_string(key) m.add_string(key)
m.add_mpint(self.f) m.add_mpint(self.f)
m.add_string(str(sig)) m.add_string(sig)
self.transport._send_message(m) self.transport._send_message(m)
self.transport._activate_outbound() self.transport._activate_outbound()

View File

@ -21,9 +21,9 @@ Implementation of an SSH2 "message".
""" """
import struct import struct
import cStringIO
from paramiko import util from paramiko import util
from paramiko.common import *
class Message (object): class Message (object):
@ -37,6 +37,8 @@ class Message (object):
paramiko doesn't support yet. paramiko doesn't support yet.
""" """
big_int = long(0xff000000)
def __init__(self, content=None): def __init__(self, content=None):
""" """
Create a new SSH2 message. Create a new SSH2 message.
@ -46,15 +48,15 @@ class Message (object):
decomposing a message). decomposing a message).
""" """
if content != None: if content != None:
self.packet = cStringIO.StringIO(content) self.packet = BytesIO(content)
else: else:
self.packet = cStringIO.StringIO() self.packet = BytesIO()
def __str__(self): def __str__(self):
""" """
Return the byte stream content of this message, as a string. Return the byte stream content of this message, as a string/bytes obj.
""" """
return self.packet.getvalue() return self.asbytes()
def __repr__(self): def __repr__(self):
""" """
@ -62,6 +64,15 @@ class Message (object):
""" """
return 'paramiko.Message(' + repr(self.packet.getvalue()) + ')' return 'paramiko.Message(' + repr(self.packet.getvalue()) + ')'
def asbytes(self):
"""
Return the byte stream content of this Message, as bytes.
@return: the contents of this Message.
@rtype: bytes
"""
return self.packet.getvalue()
def rewind(self): def rewind(self):
""" """
Rewind the message to the beginning as if no items had been parsed Rewind the message to the beginning as if no items had been parsed
@ -99,7 +110,7 @@ class Message (object):
b = self.packet.read(n) b = self.packet.read(n)
max_pad_size = 1<<20 # Limit padding to 1 MB max_pad_size = 1<<20 # Limit padding to 1 MB
if len(b) < n and n < max_pad_size: if len(b) < n and n < max_pad_size:
return b + '\x00' * (n - len(b)) return b + zero_byte * (n - len(b))
return b return b
def get_byte(self): def get_byte(self):
@ -118,7 +129,7 @@ class Message (object):
Fetch a boolean from the stream. Fetch a boolean from the stream.
""" """
b = self.get_bytes(1) b = self.get_bytes(1)
return b != '\x00' return b != zero_byte
def get_int(self): def get_int(self):
""" """
@ -126,6 +137,19 @@ class Message (object):
:return: a 32-bit unsigned `int`. :return: a 32-bit unsigned `int`.
""" """
byte = self.get_bytes(1)
if byte == max_byte:
return util.inflate_long(self.get_binary())
byte += self.get_bytes(3)
return struct.unpack('>I', byte)[0]
def get_size(self):
"""
Fetch an int from the stream.
@return: a 32-bit unsigned integer.
@rtype: int
"""
return struct.unpack('>I', self.get_bytes(4))[0] return struct.unpack('>I', self.get_bytes(4))[0]
def get_int64(self): def get_int64(self):
@ -142,7 +166,7 @@ class Message (object):
:return: an arbitrary-length integer (`long`). :return: an arbitrary-length integer (`long`).
""" """
return util.inflate_long(self.get_string()) return util.inflate_long(self.get_binary())
def get_string(self): def get_string(self):
""" """
@ -150,7 +174,30 @@ class Message (object):
contain unprintable characters. (It's not unheard of for a string to contain unprintable characters. (It's not unheard of for a string to
contain another byte-stream message.) contain another byte-stream message.)
""" """
return self.get_bytes(self.get_int()) return self.get_bytes(self.get_size())
def get_text(self):
"""
Fetch a string from the stream. This could be a byte string and may
contain unprintable characters. (It's not unheard of for a string to
contain another byte-stream Message.)
@return: a string.
@rtype: string
"""
return u(self.get_bytes(self.get_size()))
#return self.get_bytes(self.get_size())
def get_binary(self):
"""
Fetch a string from the stream. This could be a byte string and may
contain unprintable characters. (It's not unheard of for a string to
contain another byte-stream Message.)
@return: a string.
@rtype: string
"""
return self.get_bytes(self.get_size())
def get_list(self): def get_list(self):
""" """
@ -158,7 +205,7 @@ class Message (object):
These are trivially encoded as comma-separated values in a string. These are trivially encoded as comma-separated values in a string.
""" """
return self.get_string().split(',') return self.get_text().split(',')
def add_bytes(self, b): def add_bytes(self, b):
""" """
@ -185,9 +232,19 @@ class Message (object):
:param bool b: boolean value to add :param bool b: boolean value to add
""" """
if b: if b:
self.add_byte('\x01') self.packet.write(one_byte)
else: else:
self.add_byte('\x00') self.packet.write(zero_byte)
return self
def add_size(self, n):
"""
Add an integer to the stream.
@param n: integer to add
@type n: int
"""
self.packet.write(struct.pack('>I', n))
return self return self
def add_int(self, n): def add_int(self, n):
@ -196,6 +253,10 @@ class Message (object):
:param int n: integer to add :param int n: integer to add
""" """
if n >= Message.big_int:
self.packet.write(max_byte)
self.add_string(util.deflate_long(n))
else:
self.packet.write(struct.pack('>I', n)) self.packet.write(struct.pack('>I', n))
return self return self
@ -224,7 +285,8 @@ class Message (object):
:param str s: string to add :param str s: string to add
""" """
self.add_int(len(s)) s = asbytes(s)
self.add_size(len(s))
self.packet.write(s) self.packet.write(s)
return self return self
@ -240,21 +302,14 @@ class Message (object):
return self return self
def _add(self, i): def _add(self, i):
if type(i) is str: if type(i) is bool:
return self.add_string(i)
elif type(i) is int:
return self.add_int(i)
elif type(i) is long:
if i > 0xffffffffL:
return self.add_mpint(i)
else:
return self.add_int(i)
elif type(i) is bool:
return self.add_boolean(i) return self.add_boolean(i)
elif isinstance(i, integer_types):
return self.add_int(i)
elif type(i) is list: elif type(i) is list:
return self.add_list(i) return self.add_list(i)
else: else:
raise Exception('Unknown type') return self.add_string(i)
def add(self, *seq): def add(self, *seq):
""" """

View File

@ -38,6 +38,7 @@ try:
except ImportError: except ImportError:
from Crypto.Hash.HMAC import HMAC from Crypto.Hash.HMAC import HMAC
def compute_hmac(key, message, digest_class): def compute_hmac(key, message, digest_class):
return HMAC(key, message, digest_class).digest() return HMAC(key, message, digest_class).digest()
@ -66,7 +67,7 @@ class Packetizer (object):
self.__dump_packets = False self.__dump_packets = False
self.__need_rekey = False self.__need_rekey = False
self.__init_count = 0 self.__init_count = 0
self.__remainder = '' self.__remainder = bytes()
# used for noticing when to re-key: # used for noticing when to re-key:
self.__sent_bytes = 0 self.__sent_bytes = 0
@ -86,12 +87,12 @@ class Packetizer (object):
self.__sdctr_out = False self.__sdctr_out = False
self.__mac_engine_out = None self.__mac_engine_out = None
self.__mac_engine_in = None self.__mac_engine_in = None
self.__mac_key_out = '' self.__mac_key_out = bytes()
self.__mac_key_in = '' self.__mac_key_in = bytes()
self.__compress_engine_out = None self.__compress_engine_out = None
self.__compress_engine_in = None self.__compress_engine_in = None
self.__sequence_number_out = 0L self.__sequence_number_out = 0
self.__sequence_number_in = 0L self.__sequence_number_in = 0
# lock around outbound writes (packet computation) # lock around outbound writes (packet computation)
self.__write_lock = threading.RLock() self.__write_lock = threading.RLock()
@ -152,6 +153,7 @@ class Packetizer (object):
def close(self): def close(self):
self.__closed = True self.__closed = True
self.__socket.close()
def set_hexdump(self, hexdump): def set_hexdump(self, hexdump):
self.__dump_packets = hexdump self.__dump_packets = hexdump
@ -193,7 +195,7 @@ class Packetizer (object):
:raises EOFError: :raises EOFError:
if the socket was closed before all the bytes could be read if the socket was closed before all the bytes could be read
""" """
out = '' out = bytes()
# handle over-reading from reading the banner line # handle over-reading from reading the banner line
if len(self.__remainder) > 0: if len(self.__remainder) > 0:
out = self.__remainder[:n] out = self.__remainder[:n]
@ -211,7 +213,7 @@ class Packetizer (object):
n -= len(x) n -= len(x)
except socket.timeout: except socket.timeout:
got_timeout = True got_timeout = True
except socket.error, e: except socket.error as e:
# on Linux, sometimes instead of socket.timeout, we get # on Linux, sometimes instead of socket.timeout, we get
# EAGAIN. this is a bug in recent (> 2.6.9) kernels but # EAGAIN. this is a bug in recent (> 2.6.9) kernels but
# we need to work around it. # we need to work around it.
@ -240,7 +242,7 @@ class Packetizer (object):
n = self.__socket.send(out) n = self.__socket.send(out)
except socket.timeout: except socket.timeout:
retry_write = True retry_write = True
except socket.error, e: except socket.error as e:
if (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EAGAIN): if (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EAGAIN):
retry_write = True retry_write = True
elif (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EINTR): elif (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EINTR):
@ -270,22 +272,22 @@ class Packetizer (object):
line, so it's okay to attempt large reads. line, so it's okay to attempt large reads.
""" """
buf = self.__remainder buf = self.__remainder
while not '\n' in buf: while not linefeed_byte in buf:
buf += self._read_timeout(timeout) buf += self._read_timeout(timeout)
n = buf.index('\n') n = buf.index(linefeed_byte)
self.__remainder = buf[n+1:] self.__remainder = buf[n+1:]
buf = buf[:n] buf = buf[:n]
if (len(buf) > 0) and (buf[-1] == '\r'): if (len(buf) > 0) and (buf[-1] == cr_byte_value):
buf = buf[:-1] buf = buf[:-1]
return buf return u(buf)
def send_message(self, data): def send_message(self, data):
""" """
Write a block of data using the current cipher, as an SSH block. Write a block of data using the current cipher, as an SSH block.
""" """
# encrypt this sucka # encrypt this sucka
data = str(data) data = asbytes(data)
cmd = ord(data[0]) cmd = byte_ord(data[0])
if cmd in MSG_NAMES: if cmd in MSG_NAMES:
cmd_name = MSG_NAMES[cmd] cmd_name = MSG_NAMES[cmd]
else: else:
@ -307,7 +309,7 @@ class Packetizer (object):
if self.__block_engine_out != None: if self.__block_engine_out != None:
payload = struct.pack('>I', self.__sequence_number_out) + packet payload = struct.pack('>I', self.__sequence_number_out) + packet
out += compute_hmac(self.__mac_key_out, payload, self.__mac_engine_out)[:self.__mac_size_out] out += compute_hmac(self.__mac_key_out, payload, self.__mac_engine_out)[:self.__mac_size_out]
self.__sequence_number_out = (self.__sequence_number_out + 1) & 0xffffffffL self.__sequence_number_out = (self.__sequence_number_out + 1) & xffffffff
self.write_all(out) self.write_all(out)
self.__sent_bytes += len(out) self.__sent_bytes += len(out)
@ -356,7 +358,7 @@ class Packetizer (object):
my_mac = compute_hmac(self.__mac_key_in, mac_payload, self.__mac_engine_in)[:self.__mac_size_in] my_mac = compute_hmac(self.__mac_key_in, mac_payload, self.__mac_engine_in)[:self.__mac_size_in]
if not util.constant_time_bytes_eq(my_mac, mac): if not util.constant_time_bytes_eq(my_mac, mac):
raise SSHException('Mismatched MAC') raise SSHException('Mismatched MAC')
padding = ord(packet[0]) padding = byte_ord(packet[0])
payload = packet[1:packet_size - padding] payload = packet[1:packet_size - padding]
if self.__dump_packets: if self.__dump_packets:
@ -367,7 +369,7 @@ class Packetizer (object):
msg = Message(payload[1:]) msg = Message(payload[1:])
msg.seqno = self.__sequence_number_in msg.seqno = self.__sequence_number_in
self.__sequence_number_in = (self.__sequence_number_in + 1) & 0xffffffffL self.__sequence_number_in = (self.__sequence_number_in + 1) & xffffffff
# check for rekey # check for rekey
raw_packet_size = packet_size + self.__mac_size_in + 4 raw_packet_size = packet_size + self.__mac_size_in + 4
@ -390,7 +392,7 @@ class Packetizer (object):
self.__received_packets_overflow = 0 self.__received_packets_overflow = 0
self._trigger_rekey() self._trigger_rekey()
cmd = ord(payload[0]) cmd = byte_ord(payload[0])
if cmd in MSG_NAMES: if cmd in MSG_NAMES:
cmd_name = MSG_NAMES[cmd] cmd_name = MSG_NAMES[cmd]
else: else:
@ -465,7 +467,7 @@ class Packetizer (object):
break break
except socket.timeout: except socket.timeout:
pass pass
except EnvironmentError, e: except EnvironmentError as e:
if ((type(e.args) is tuple) and (len(e.args) > 0) and if ((type(e.args) is tuple) and (len(e.args) > 0) and
(e.args[0] == errno.EINTR)): (e.args[0] == errno.EINTR)):
pass pass
@ -487,7 +489,7 @@ class Packetizer (object):
if self.__sdctr_out or self.__block_engine_out is None: if self.__sdctr_out or self.__block_engine_out is None:
# cute trick i caught openssh doing: if we're not encrypting or SDCTR mode (RFC4344), # cute trick i caught openssh doing: if we're not encrypting or SDCTR mode (RFC4344),
# don't waste random bytes for the padding # don't waste random bytes for the padding
packet += (chr(0) * padding) packet += (zero_byte * padding)
else: else:
packet += rng.read(padding) packet += rng.read(padding)
return packet return packet

View File

@ -28,6 +28,7 @@ will trigger as readable in `select <select.select>`.
import sys import sys
import os import os
import socket import socket
from paramiko.py3compat import b
def make_pipe (): def make_pipe ():
@ -64,7 +65,7 @@ class PosixPipe (object):
if self._set or self._closed: if self._set or self._closed:
return return
self._set = True self._set = True
os.write(self._wfd, '*') os.write(self._wfd, b'*')
def set_forever (self): def set_forever (self):
self._forever = True self._forever = True
@ -110,7 +111,7 @@ class WindowsPipe (object):
if self._set or self._closed: if self._set or self._closed:
return return
self._set = True self._set = True
self._wsock.send('*') self._wsock.send(b'*')
def set_forever (self): def set_forever (self):
self._forever = True self._forever = True

View File

@ -62,13 +62,16 @@ class PKey (object):
""" """
pass pass
def __str__(self): def asbytes(self):
""" """
Return a string of an SSH `.Message` made up of the public part(s) of Return a string of an SSH `.Message` made up of the public part(s) of
this key. This string is suitable for passing to `__init__` to this key. This string is suitable for passing to `__init__` to
re-create the key object later. re-create the key object later.
""" """
return '' return bytes()
def __str__(self):
return self.asbytes()
def __cmp__(self, other): def __cmp__(self, other):
""" """
@ -83,7 +86,10 @@ class PKey (object):
ho = hash(other) ho = hash(other)
if hs != ho: if hs != ho:
return cmp(hs, ho) return cmp(hs, ho)
return cmp(str(self), str(other)) return cmp(self.asbytes(), other.asbytes())
def __eq__(self, other):
return hash(self) == hash(other)
def get_name(self): def get_name(self):
""" """
@ -120,7 +126,7 @@ class PKey (object):
a 16-byte `string <str>` (binary) of the MD5 fingerprint, in SSH a 16-byte `string <str>` (binary) of the MD5 fingerprint, in SSH
format. format.
""" """
return MD5.new(str(self)).digest() return MD5.new(self.asbytes()).digest()
def get_base64(self): def get_base64(self):
""" """
@ -130,7 +136,7 @@ class PKey (object):
:return: a base64 `string <str>` containing the public part of the key. :return: a base64 `string <str>` containing the public part of the key.
""" """
return base64.encodestring(str(self)).replace('\n', '') return u(encodebytes(self.asbytes())).replace('\n', '')
def sign_ssh_data(self, rng, data): def sign_ssh_data(self, rng, data):
""" """
@ -141,7 +147,7 @@ class PKey (object):
:param str data: the data to sign. :param str data: the data to sign.
:return: an SSH signature `message <.Message>`. :return: an SSH signature `message <.Message>`.
""" """
return '' return bytes()
def verify_ssh_sig(self, data, msg): def verify_ssh_sig(self, data, msg):
""" """
@ -246,9 +252,8 @@ class PKey (object):
encrypted, and ``password`` is ``None``. encrypted, and ``password`` is ``None``.
:raises SSHException: if the key file is invalid. :raises SSHException: if the key file is invalid.
""" """
f = open(filename, 'r') with open(filename, 'r') as f:
data = self._read_private_key(tag, f, password) data = self._read_private_key(tag, f, password)
f.close()
return data return data
def _read_private_key(self, tag, f, password=None): def _read_private_key(self, tag, f, password=None):
@ -273,8 +278,8 @@ class PKey (object):
end += 1 end += 1
# if we trudged to the end of the file, just try to cope. # if we trudged to the end of the file, just try to cope.
try: try:
data = base64.decodestring(''.join(lines[start:end])) data = decodebytes(b(''.join(lines[start:end])))
except base64.binascii.Error, e: except base64.binascii.Error as e:
raise SSHException('base64 decoding error: ' + str(e)) raise SSHException('base64 decoding error: ' + str(e))
if 'proc-type' not in headers: if 'proc-type' not in headers:
# unencryped: done # unencryped: done
@ -285,7 +290,7 @@ class PKey (object):
try: try:
encryption_type, saltstr = headers['dek-info'].split(',') encryption_type, saltstr = headers['dek-info'].split(',')
except: except:
raise SSHException('Can\'t parse DEK-info in private key file') raise SSHException("Can't parse DEK-info in private key file")
if encryption_type not in self._CIPHER_TABLE: if encryption_type not in self._CIPHER_TABLE:
raise SSHException('Unknown private key cipher "%s"' % encryption_type) raise SSHException('Unknown private key cipher "%s"' % encryption_type)
# if no password was passed in, raise an exception pointing out that we need one # if no password was passed in, raise an exception pointing out that we need one
@ -294,7 +299,7 @@ class PKey (object):
cipher = self._CIPHER_TABLE[encryption_type]['cipher'] cipher = self._CIPHER_TABLE[encryption_type]['cipher']
keysize = self._CIPHER_TABLE[encryption_type]['keysize'] keysize = self._CIPHER_TABLE[encryption_type]['keysize']
mode = self._CIPHER_TABLE[encryption_type]['mode'] mode = self._CIPHER_TABLE[encryption_type]['mode']
salt = unhexlify(saltstr) salt = unhexlify(b(saltstr))
key = util.generate_key_bytes(MD5, salt, password, keysize) key = util.generate_key_bytes(MD5, salt, password, keysize)
return cipher.new(key, mode, salt).decrypt(data) return cipher.new(key, mode, salt).decrypt(data)
@ -312,33 +317,32 @@ class PKey (object):
:raises IOError: if there was an error writing the file. :raises IOError: if there was an error writing the file.
""" """
f = open(filename, 'w', 0600) with open(filename, 'w', o600) as f:
# grrr... the mode doesn't always take hold # grrr... the mode doesn't always take hold
os.chmod(filename, 0600) os.chmod(filename, o600)
self._write_private_key(tag, f, data, password) self._write_private_key(tag, f, data, password)
f.close()
def _write_private_key(self, tag, f, data, password=None): def _write_private_key(self, tag, f, data, password=None):
f.write('-----BEGIN %s PRIVATE KEY-----\n' % tag) f.write('-----BEGIN %s PRIVATE KEY-----\n' % tag)
if password is not None: if password is not None:
# since we only support one cipher here, use it # since we only support one cipher here, use it
cipher_name = self._CIPHER_TABLE.keys()[0] cipher_name = list(self._CIPHER_TABLE.keys())[0]
cipher = self._CIPHER_TABLE[cipher_name]['cipher'] cipher = self._CIPHER_TABLE[cipher_name]['cipher']
keysize = self._CIPHER_TABLE[cipher_name]['keysize'] keysize = self._CIPHER_TABLE[cipher_name]['keysize']
blocksize = self._CIPHER_TABLE[cipher_name]['blocksize'] blocksize = self._CIPHER_TABLE[cipher_name]['blocksize']
mode = self._CIPHER_TABLE[cipher_name]['mode'] mode = self._CIPHER_TABLE[cipher_name]['mode']
salt = rng.read(8) salt = rng.read(16)
key = util.generate_key_bytes(MD5, salt, password, keysize) key = util.generate_key_bytes(MD5, salt, password, keysize)
if len(data) % blocksize != 0: if len(data) % blocksize != 0:
n = blocksize - len(data) % blocksize n = blocksize - len(data) % blocksize
#data += rng.read(n) #data += rng.read(n)
# that would make more sense ^, but it confuses openssh. # that would make more sense ^, but it confuses openssh.
data += '\0' * n data += zero_byte * n
data = cipher.new(key, mode, salt).encrypt(data) data = cipher.new(key, mode, salt).encrypt(data)
f.write('Proc-Type: 4,ENCRYPTED\n') f.write('Proc-Type: 4,ENCRYPTED\n')
f.write('DEK-Info: %s,%s\n' % (cipher_name, hexlify(salt).upper())) f.write('DEK-Info: %s,%s\n' % (cipher_name, u(hexlify(salt)).upper()))
f.write('\n') f.write('\n')
s = base64.encodestring(data) s = u(encodebytes(data))
# re-wrap to 64-char lines # re-wrap to 64-char lines
s = ''.join(s.split('\n')) s = ''.join(s.split('\n'))
s = '\n'.join([s[i : i+64] for i in range(0, len(s), 64)]) s = '\n'.join([s[i : i+64] for i in range(0, len(s), 64)])

View File

@ -24,6 +24,7 @@ from Crypto.Util import number
from paramiko import util from paramiko import util
from paramiko.ssh_exception import SSHException from paramiko.ssh_exception import SSHException
from paramiko.common import *
def _generate_prime(bits, rng): def _generate_prime(bits, rng):
@ -33,7 +34,7 @@ def _generate_prime(bits, rng):
# loop catches the case where we increment n into a higher bit-range # loop catches the case where we increment n into a higher bit-range
x = rng.read((bits+7) // 8) x = rng.read((bits+7) // 8)
if hbyte_mask > 0: if hbyte_mask > 0:
x = chr(ord(x[0]) & hbyte_mask) + x[1:] x = byte_mask(x[0], hbyte_mask) + x[1:]
n = util.inflate_long(x, 1) n = util.inflate_long(x, 1)
n |= 1 n |= 1
n |= (1 << (bits - 1)) n |= (1 << (bits - 1))
@ -46,7 +47,7 @@ def _generate_prime(bits, rng):
def _roll_random(rng, n): def _roll_random(rng, n):
"returns a random # from 0 to N-1" "returns a random # from 0 to N-1"
bits = util.bit_length(n-1) bits = util.bit_length(n-1)
bytes = (bits + 7) // 8 byte_count = (bits + 7) // 8
hbyte_mask = pow(2, bits % 8) - 1 hbyte_mask = pow(2, bits % 8) - 1
# so here's the plan: # so here's the plan:
@ -56,9 +57,9 @@ def _roll_random(rng, n):
# fits, so i can't guarantee that this loop will ever finish, but the odds # fits, so i can't guarantee that this loop will ever finish, but the odds
# of it looping forever should be infinitesimal. # of it looping forever should be infinitesimal.
while True: while True:
x = rng.read(bytes) x = rng.read(byte_count)
if hbyte_mask > 0: if hbyte_mask > 0:
x = chr(ord(x[0]) & hbyte_mask) + x[1:] x = byte_mask(x[0], hbyte_mask) + x[1:]
num = util.inflate_long(x, 1) num = util.inflate_long(x, 1)
if num < n: if num < n:
break break
@ -112,7 +113,7 @@ class ModulusPack (object):
:raises IOError: passed from any file operations that fail. :raises IOError: passed from any file operations that fail.
""" """
self.pack = {} self.pack = {}
f = open(filename, 'r') with open(filename, 'r') as f:
for line in f: for line in f:
line = line.strip() line = line.strip()
if (len(line) == 0) or (line[0] == '#'): if (len(line) == 0) or (line[0] == '#'):
@ -121,11 +122,9 @@ class ModulusPack (object):
self._parse_modulus(line) self._parse_modulus(line)
except: except:
continue continue
f.close()
def get_modulus(self, min, prefer, max): def get_modulus(self, min, prefer, max):
bitsizes = self.pack.keys() bitsizes = sorted(self.pack.keys())
bitsizes.sort()
if len(bitsizes) == 0: if len(bitsizes) == 0:
raise SSHException('no moduli available') raise SSHException('no moduli available')
good = -1 good = -1

View File

@ -59,7 +59,7 @@ class ProxyCommand(object):
""" """
try: try:
self.process.stdin.write(content) self.process.stdin.write(content)
except IOError, e: except IOError as e:
# There was a problem with the child process. It probably # There was a problem with the child process. It probably
# died and we can't proceed. The best option here is to # died and we can't proceed. The best option here is to
# raise an exception informing the user that the informed # raise an exception informing the user that the informed
@ -95,7 +95,7 @@ class ProxyCommand(object):
return result return result
except socket.timeout: except socket.timeout:
raise # socket.timeout is a subclass of IOError raise # socket.timeout is a subclass of IOError
except IOError, e: except IOError as e:
raise ProxyCommandFailure(' '.join(self.cmd), e.strerror) raise ProxyCommandFailure(' '.join(self.cmd), e.strerror)
def close(self): def close(self):

160
paramiko/py3compat.py Normal file
View File

@ -0,0 +1,160 @@
import sys
import base64
__all__ = ['PY2', 'string_types', 'integer_types', 'text_type', 'bytes_types', 'bytes', 'long', 'input',
'decodebytes', 'encodebytes', 'bytestring', 'byte_ord', 'byte_chr', 'byte_mask',
'b', 'u', 'b2s', 'StringIO', 'BytesIO', 'is_callable', 'MAXSIZE', 'next']
PY2 = sys.version_info[0] < 3
if PY2:
string_types = basestring
text_type = unicode
bytes_types = str
bytes = str
integer_types = (int, long)
long = long
input = raw_input
decodebytes = base64.decodestring
encodebytes = base64.encodestring
def bytestring(s): # NOQA
if isinstance(s, unicode):
return s.encode('utf-8')
return s
byte_ord = ord # NOQA
byte_chr = chr # NOQA
def byte_mask(c, mask):
return chr(ord(c) & mask)
def b(s, encoding='utf8'): # NOQA
"""cast unicode or bytes to bytes"""
if isinstance(s, str):
return s
elif isinstance(s, unicode):
return s.encode(encoding)
else:
raise TypeError("Expected unicode or bytes, got %r" % s)
def u(s, encoding='utf8'): # NOQA
"""cast bytes or unicode to unicode"""
if isinstance(s, str):
return s.decode(encoding)
elif isinstance(s, unicode):
return s
else:
raise TypeError("Expected unicode or bytes, got %r" % s)
def b2s(s):
return s
try:
import cStringIO
StringIO = cStringIO.StringIO # NOQA
except ImportError:
import StringIO
StringIO = StringIO.StringIO # NOQA
BytesIO = StringIO
def is_callable(c): # NOQA
return callable(c)
def get_next(c): # NOQA
return c.next
def next(c):
return c.next()
# It's possible to have sizeof(long) != sizeof(Py_ssize_t).
class X(object):
def __len__(self):
return 1 << 31
try:
len(X())
except OverflowError:
# 32-bit
MAXSIZE = int((1 << 31) - 1) # NOQA
else:
# 64-bit
MAXSIZE = int((1 << 63) - 1) # NOQA
del X
else:
import collections
import struct
string_types = str
text_type = str
bytes = bytes
bytes_types = bytes
integer_types = int
class long(int):
pass
input = input
decodebytes = base64.decodebytes
encodebytes = base64.encodebytes
def bytestring(s):
return s
def byte_ord(c):
assert isinstance(c, int)
return c
def byte_chr(c):
assert isinstance(c, int)
return struct.pack('B', c)
def byte_mask(c, mask):
assert isinstance(c, int)
return struct.pack('B', c & mask)
def b(s, encoding='utf8'):
"""cast unicode or bytes to bytes"""
if isinstance(s, bytes):
return s
elif isinstance(s, str):
return s.encode(encoding)
else:
raise TypeError("Expected unicode or bytes, got %r" % s)
def u(s, encoding='utf8'):
"""cast bytes or unicode to unicode"""
if isinstance(s, bytes):
return s.decode(encoding)
elif isinstance(s, str):
return s
else:
raise TypeError("Expected unicode or bytes, got %r" % s)
def b2s(s):
return s.decode() if isinstance(s, bytes) else s
import io
StringIO = io.StringIO # NOQA
BytesIO = io.BytesIO # NOQA
def is_callable(c):
return isinstance(c, collections.Callable)
def get_next(c):
return c.__next__
next = next
MAXSIZE = sys.maxsize # NOQA

View File

@ -31,6 +31,8 @@ from paramiko.ber import BER, BERException
from paramiko.pkey import PKey from paramiko.pkey import PKey
from paramiko.ssh_exception import SSHException from paramiko.ssh_exception import SSHException
SHA1_DIGESTINFO = b'\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14'
class RSAKey (PKey): class RSAKey (PKey):
""" """
@ -57,18 +59,21 @@ class RSAKey (PKey):
else: else:
if msg is None: if msg is None:
raise SSHException('Key object may not be empty') raise SSHException('Key object may not be empty')
if msg.get_string() != 'ssh-rsa': if msg.get_text() != 'ssh-rsa':
raise SSHException('Invalid key') raise SSHException('Invalid key')
self.e = msg.get_mpint() self.e = msg.get_mpint()
self.n = msg.get_mpint() self.n = msg.get_mpint()
self.size = util.bit_length(self.n) self.size = util.bit_length(self.n)
def __str__(self): def asbytes(self):
m = Message() m = Message()
m.add_string('ssh-rsa') m.add_string('ssh-rsa')
m.add_mpint(self.e) m.add_mpint(self.e)
m.add_mpint(self.n) m.add_mpint(self.n)
return str(m) return m.asbytes()
def __str__(self):
return self.asbytes()
def __hash__(self): def __hash__(self):
h = hash(self.get_name()) h = hash(self.get_name())
@ -88,16 +93,16 @@ class RSAKey (PKey):
def sign_ssh_data(self, rpool, data): def sign_ssh_data(self, rpool, data):
digest = SHA.new(data).digest() digest = SHA.new(data).digest()
rsa = RSA.construct((long(self.n), long(self.e), long(self.d))) rsa = RSA.construct((long(self.n), long(self.e), long(self.d)))
sig = util.deflate_long(rsa.sign(self._pkcs1imify(digest), '')[0], 0) sig = util.deflate_long(rsa.sign(self._pkcs1imify(digest), bytes())[0], 0)
m = Message() m = Message()
m.add_string('ssh-rsa') m.add_string('ssh-rsa')
m.add_string(sig) m.add_string(sig)
return m return m
def verify_ssh_sig(self, data, msg): def verify_ssh_sig(self, data, msg):
if msg.get_string() != 'ssh-rsa': if msg.get_text() != 'ssh-rsa':
return False return False
sig = util.inflate_long(msg.get_string(), True) sig = util.inflate_long(msg.get_binary(), True)
# verify the signature by SHA'ing the data and encrypting it using the # verify the signature by SHA'ing the data and encrypting it using the
# public key. some wackiness ensues where we "pkcs1imify" the 20-byte # public key. some wackiness ensues where we "pkcs1imify" the 20-byte
# hash into a string as long as the RSA key. # hash into a string as long as the RSA key.
@ -116,7 +121,7 @@ class RSAKey (PKey):
b.encode(keylist) b.encode(keylist)
except BERException: except BERException:
raise SSHException('Unable to create ber encoding of key') raise SSHException('Unable to create ber encoding of key')
return str(b) return b.asbytes()
def write_private_key_file(self, filename, password=None): def write_private_key_file(self, filename, password=None):
self._write_private_key_file('RSA', filename, self._encode_key(), password) self._write_private_key_file('RSA', filename, self._encode_key(), password)
@ -152,10 +157,9 @@ class RSAKey (PKey):
turn a 20-byte SHA1 hash into a blob of data as large as the key's N, turn a 20-byte SHA1 hash into a blob of data as large as the key's N,
using PKCS1's \"emsa-pkcs1-v1_5\" encoding. totally bizarre. using PKCS1's \"emsa-pkcs1-v1_5\" encoding. totally bizarre.
""" """
SHA1_DIGESTINFO = '\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14'
size = len(util.deflate_long(self.n, 0)) size = len(util.deflate_long(self.n, 0))
filler = '\xff' * (size - len(SHA1_DIGESTINFO) - len(data) - 3) filler = max_byte * (size - len(SHA1_DIGESTINFO) - len(data) - 3)
return '\x00\x01' + filler + '\x00' + SHA1_DIGESTINFO + data return zero_byte + one_byte + filler + zero_byte + SHA1_DIGESTINFO + data
def _from_private_key_file(self, filename, password): def _from_private_key_file(self, filename, password):
data = self._read_private_key_file('RSA', filename, password) data = self._read_private_key_file('RSA', filename, password)

View File

@ -514,7 +514,7 @@ class InteractiveQuery (object):
self.instructions = instructions self.instructions = instructions
self.prompts = [] self.prompts = []
for x in prompts: for x in prompts:
if (type(x) is str) or (type(x) is unicode): if isinstance(x, string_types):
self.add_prompt(x) self.add_prompt(x)
else: else:
self.add_prompt(x[0], x[1]) self.add_prompt(x[0], x[1])
@ -576,7 +576,7 @@ class SubsystemHandler (threading.Thread):
try: try:
self.__transport._log(DEBUG, 'Starting handler for subsystem %s' % self.__name) self.__transport._log(DEBUG, 'Starting handler for subsystem %s' % self.__name)
self.start_subsystem(self.__name, self.__transport, self.__channel) self.start_subsystem(self.__name, self.__transport, self.__channel)
except Exception, e: except Exception as e:
self.__transport._log(ERROR, 'Exception in subsystem handler for "%s": %s' % self.__transport._log(ERROR, 'Exception in subsystem handler for "%s": %s' %
(self.__name, str(e))) (self.__name, str(e)))
self.__transport._log(ERROR, util.tb_strings()) self.__transport._log(ERROR, util.tb_strings())

View File

@ -86,7 +86,7 @@ CMD_NAMES = {
CMD_ATTRS: 'attrs', CMD_ATTRS: 'attrs',
CMD_EXTENDED: 'extended', CMD_EXTENDED: 'extended',
CMD_EXTENDED_REPLY: 'extended_reply' CMD_EXTENDED_REPLY: 'extended_reply'
} }
class SFTPError (Exception): class SFTPError (Exception):
@ -125,7 +125,7 @@ class BaseSFTP (object):
msg = Message() msg = Message()
msg.add_int(_VERSION) msg.add_int(_VERSION)
msg.add(*extension_pairs) msg.add(*extension_pairs)
self._send_packet(CMD_VERSION, str(msg)) self._send_packet(CMD_VERSION, msg)
return version return version
def _log(self, level, msg, *args): def _log(self, level, msg, *args):
@ -142,7 +142,7 @@ class BaseSFTP (object):
return return
def _read_all(self, n): def _read_all(self, n):
out = '' out = bytes()
while n > 0: while n > 0:
if isinstance(self.sock, socket.socket): if isinstance(self.sock, socket.socket):
# sometimes sftp is used directly over a socket instead of # sometimes sftp is used directly over a socket instead of
@ -166,7 +166,8 @@ class BaseSFTP (object):
def _send_packet(self, t, packet): def _send_packet(self, t, packet):
#self._log(DEBUG2, 'write: %s (len=%d)' % (CMD_NAMES.get(t, '0x%02x' % t), len(packet))) #self._log(DEBUG2, 'write: %s (len=%d)' % (CMD_NAMES.get(t, '0x%02x' % t), len(packet)))
out = struct.pack('>I', len(packet) + 1) + chr(t) + packet packet = asbytes(packet)
out = struct.pack('>I', len(packet) + 1) + byte_chr(t) + packet
if self.ultra_debug: if self.ultra_debug:
self._log(DEBUG, util.format_binary(out, 'OUT: ')) self._log(DEBUG, util.format_binary(out, 'OUT: '))
self._write_all(out) self._write_all(out)
@ -175,14 +176,14 @@ class BaseSFTP (object):
x = self._read_all(4) x = self._read_all(4)
# most sftp servers won't accept packets larger than about 32k, so # most sftp servers won't accept packets larger than about 32k, so
# anything with the high byte set (> 16MB) is just garbage. # anything with the high byte set (> 16MB) is just garbage.
if x[0] != '\x00': if byte_ord(x[0]):
raise SFTPError('Garbage packet received') raise SFTPError('Garbage packet received')
size = struct.unpack('>I', x)[0] size = struct.unpack('>I', x)[0]
data = self._read_all(size) data = self._read_all(size)
if self.ultra_debug: if self.ultra_debug:
self._log(DEBUG, util.format_binary(data, 'IN: ')); self._log(DEBUG, util.format_binary(data, 'IN: '));
if size > 0: if size > 0:
t = ord(data[0]) t = byte_ord(data[0])
#self._log(DEBUG2, 'read: %s (len=%d)' % (CMD_NAMES.get(t), '0x%02x' % t, len(data)-1)) #self._log(DEBUG2, 'read: %s (len=%d)' % (CMD_NAMES.get(t), '0x%02x' % t, len(data)-1))
return t, data[1:] return t, data[1:]
return 0, '' return 0, bytes()

View File

@ -45,7 +45,7 @@ class SFTPAttributes (object):
FLAG_UIDGID = 2 FLAG_UIDGID = 2
FLAG_PERMISSIONS = 4 FLAG_PERMISSIONS = 4
FLAG_AMTIME = 8 FLAG_AMTIME = 8
FLAG_EXTENDED = 0x80000000L FLAG_EXTENDED = x80000000
def __init__(self): def __init__(self):
""" """
@ -141,7 +141,7 @@ class SFTPAttributes (object):
msg.add_int(long(self.st_mtime)) msg.add_int(long(self.st_mtime))
if self._flags & self.FLAG_EXTENDED: if self._flags & self.FLAG_EXTENDED:
msg.add_int(len(self.attr)) msg.add_int(len(self.attr))
for key, val in self.attr.iteritems(): for key, val in self.attr.items():
msg.add_string(key) msg.add_string(key)
msg.add_string(val) msg.add_string(val)
return return
@ -156,7 +156,7 @@ class SFTPAttributes (object):
out += 'mode=' + oct(self.st_mode) + ' ' out += 'mode=' + oct(self.st_mode) + ' '
if (self.st_atime is not None) and (self.st_mtime is not None): if (self.st_atime is not None) and (self.st_mtime is not None):
out += 'atime=%d mtime=%d ' % (self.st_atime, self.st_mtime) out += 'atime=%d mtime=%d ' % (self.st_atime, self.st_mtime)
for k, v in self.attr.iteritems(): for k, v in self.attr.items():
out += '"%s"=%r ' % (str(k), v) out += '"%s"=%r ' % (str(k), v)
out += ']' out += ']'
return out return out
@ -192,13 +192,13 @@ class SFTPAttributes (object):
ks = 's' ks = 's'
else: else:
ks = '?' ks = '?'
ks += self._rwx((self.st_mode & 0700) >> 6, self.st_mode & stat.S_ISUID) ks += self._rwx((self.st_mode & o700) >> 6, self.st_mode & stat.S_ISUID)
ks += self._rwx((self.st_mode & 070) >> 3, self.st_mode & stat.S_ISGID) ks += self._rwx((self.st_mode & o70) >> 3, self.st_mode & stat.S_ISGID)
ks += self._rwx(self.st_mode & 7, self.st_mode & stat.S_ISVTX, True) ks += self._rwx(self.st_mode & 7, self.st_mode & stat.S_ISVTX, True)
else: else:
ks = '?---------' ks = '?---------'
# compute display date # compute display date
if (self.st_mtime is None) or (self.st_mtime == 0xffffffffL): if (self.st_mtime is None) or (self.st_mtime == xffffffff):
# shouldn't really happen # shouldn't really happen
datestr = '(unknown date)' datestr = '(unknown date)'
else: else:
@ -219,3 +219,5 @@ class SFTPAttributes (object):
return '%s 1 %-8d %-8d %8d %-12s %s' % (ks, uid, gid, self.st_size, datestr, filename) return '%s 1 %-8d %-8d %8d %-12s %s' % (ks, uid, gid, self.st_size, datestr, filename)
def asbytes(self):
return b(str(self))

View File

@ -39,12 +39,13 @@ def _to_unicode(s):
""" """
try: try:
return s.encode('ascii') return s.encode('ascii')
except UnicodeError: except (UnicodeError, AttributeError):
try: try:
return s.decode('utf-8') return s.decode('utf-8')
except UnicodeError: except UnicodeError:
return s return s
b_slash = b'/'
class SFTPClient(BaseSFTP): class SFTPClient(BaseSFTP):
""" """
@ -82,7 +83,7 @@ class SFTPClient(BaseSFTP):
self.ultra_debug = transport.get_hexdump() self.ultra_debug = transport.get_hexdump()
try: try:
server_version = self._send_version() server_version = self._send_version()
except EOFError, x: except EOFError:
raise SSHException('EOF during negotiation') raise SSHException('EOF during negotiation')
self._log(INFO, 'Opened sftp connection (server version %d)' % server_version) self._log(INFO, 'Opened sftp connection (server version %d)' % server_version)
@ -162,20 +163,20 @@ class SFTPClient(BaseSFTP):
t, msg = self._request(CMD_OPENDIR, path) t, msg = self._request(CMD_OPENDIR, path)
if t != CMD_HANDLE: if t != CMD_HANDLE:
raise SFTPError('Expected handle') raise SFTPError('Expected handle')
handle = msg.get_string() handle = msg.get_binary()
filelist = [] filelist = []
while True: while True:
try: try:
t, msg = self._request(CMD_READDIR, handle) t, msg = self._request(CMD_READDIR, handle)
except EOFError, e: except EOFError:
# done with handle # done with handle
break break
if t != CMD_NAME: if t != CMD_NAME:
raise SFTPError('Expected name response') raise SFTPError('Expected name response')
count = msg.get_int() count = msg.get_int()
for i in range(count): for i in range(count):
filename = _to_unicode(msg.get_string()) filename = msg.get_text()
longname = _to_unicode(msg.get_string()) longname = msg.get_text()
attr = SFTPAttributes._from_msg(msg, filename, longname) attr = SFTPAttributes._from_msg(msg, filename, longname)
if (filename != '.') and (filename != '..'): if (filename != '.') and (filename != '..'):
filelist.append(attr) filelist.append(attr)
@ -231,7 +232,7 @@ class SFTPClient(BaseSFTP):
t, msg = self._request(CMD_OPEN, filename, imode, attrblock) t, msg = self._request(CMD_OPEN, filename, imode, attrblock)
if t != CMD_HANDLE: if t != CMD_HANDLE:
raise SFTPError('Expected handle') raise SFTPError('Expected handle')
handle = msg.get_string() handle = msg.get_binary()
self._log(DEBUG, 'open(%r, %r) -> %s' % (filename, mode, hexlify(handle))) self._log(DEBUG, 'open(%r, %r) -> %s' % (filename, mode, hexlify(handle)))
return SFTPFile(self, handle, mode, bufsize) return SFTPFile(self, handle, mode, bufsize)
@ -268,7 +269,7 @@ class SFTPClient(BaseSFTP):
self._log(DEBUG, 'rename(%r, %r)' % (oldpath, newpath)) self._log(DEBUG, 'rename(%r, %r)' % (oldpath, newpath))
self._request(CMD_RENAME, oldpath, newpath) self._request(CMD_RENAME, oldpath, newpath)
def mkdir(self, path, mode=0777): def mkdir(self, path, mode=o777):
""" """
Create a folder (directory) named ``path`` with numeric mode ``mode``. Create a folder (directory) named ``path`` with numeric mode ``mode``.
The default mode is 0777 (octal). On some systems, mode is ignored. The default mode is 0777 (octal). On some systems, mode is ignored.
@ -347,8 +348,7 @@ class SFTPClient(BaseSFTP):
""" """
dest = self._adjust_cwd(dest) dest = self._adjust_cwd(dest)
self._log(DEBUG, 'symlink(%r, %r)' % (source, dest)) self._log(DEBUG, 'symlink(%r, %r)' % (source, dest))
if type(source) is unicode: source = bytestring(source)
source = source.encode('utf-8')
self._request(CMD_SYMLINK, source, dest) self._request(CMD_SYMLINK, source, dest)
def chmod(self, path, mode): def chmod(self, path, mode):
@ -462,9 +462,9 @@ class SFTPClient(BaseSFTP):
count = msg.get_int() count = msg.get_int()
if count != 1: if count != 1:
raise SFTPError('Realpath returned %d results' % count) raise SFTPError('Realpath returned %d results' % count)
return _to_unicode(msg.get_string()) return msg.get_text()
def chdir(self, path): def chdir(self, path=None):
""" """
Change the "current directory" of this SFTP session. Since SFTP Change the "current directory" of this SFTP session. Since SFTP
doesn't really have the concept of a current working directory, this is doesn't really have the concept of a current working directory, this is
@ -484,7 +484,7 @@ class SFTPClient(BaseSFTP):
return return
if not stat.S_ISDIR(self.stat(path).st_mode): if not stat.S_ISDIR(self.stat(path).st_mode):
raise SFTPError(errno.ENOTDIR, "%s: %s" % (os.strerror(errno.ENOTDIR), path)) raise SFTPError(errno.ENOTDIR, "%s: %s" % (os.strerror(errno.ENOTDIR), path))
self._cwd = self.normalize(path).encode('utf-8') self._cwd = b(self.normalize(path))
def getcwd(self): def getcwd(self):
""" """
@ -494,7 +494,7 @@ class SFTPClient(BaseSFTP):
.. versionadded:: 1.4 .. versionadded:: 1.4
""" """
return self._cwd return self._cwd and u(self._cwd)
def putfo(self, fl, remotepath, file_size=0, callback=None, confirm=True): def putfo(self, fl, remotepath, file_size=0, callback=None, confirm=True):
""" """
@ -525,10 +525,9 @@ class SFTPClient(BaseSFTP):
.. versionchanged:: 1.7.4 .. versionchanged:: 1.7.4
Began returning rich attribute objects. Began returning rich attribute objects.
""" """
fr = self.file(remotepath, 'wb') with self.file(remotepath, 'wb') as fr:
fr.set_pipelined(True) fr.set_pipelined(True)
size = 0 size = 0
try:
while True: while True:
data = fl.read(32768) data = fl.read(32768)
fr.write(data) fr.write(data)
@ -537,8 +536,6 @@ class SFTPClient(BaseSFTP):
callback(size, file_size) callback(size, file_size)
if len(data) == 0: if len(data) == 0:
break break
finally:
fr.close()
if confirm: if confirm:
s = self.stat(remotepath) s = self.stat(remotepath)
if s.st_size != size: if s.st_size != size:
@ -573,11 +570,8 @@ class SFTPClient(BaseSFTP):
``confirm`` param added. ``confirm`` param added.
""" """
file_size = os.stat(localpath).st_size file_size = os.stat(localpath).st_size
fl = file(localpath, 'rb') with open(localpath, 'rb') as fl:
try:
return self.putfo(fl, remotepath, os.stat(localpath).st_size, callback, confirm) return self.putfo(fl, remotepath, os.stat(localpath).st_size, callback, confirm)
finally:
fl.close()
def getfo(self, remotepath, fl, callback=None): def getfo(self, remotepath, fl, callback=None):
""" """
@ -598,10 +592,9 @@ class SFTPClient(BaseSFTP):
.. versionchanged:: 1.7.4 .. versionchanged:: 1.7.4
Added the ``callable`` param. Added the ``callable`` param.
""" """
fr = self.file(remotepath, 'rb') with self.open(remotepath, 'rb') as fr:
file_size = self.stat(remotepath).st_size file_size = self.stat(remotepath).st_size
fr.prefetch() fr.prefetch()
try:
size = 0 size = 0
while True: while True:
data = fr.read(32768) data = fr.read(32768)
@ -611,8 +604,6 @@ class SFTPClient(BaseSFTP):
callback(size, file_size) callback(size, file_size)
if len(data) == 0: if len(data) == 0:
break break
finally:
fr.close()
return size return size
def get(self, remotepath, localpath, callback=None): def get(self, remotepath, localpath, callback=None):
@ -632,11 +623,8 @@ class SFTPClient(BaseSFTP):
Added the ``callback`` param Added the ``callback`` param
""" """
file_size = self.stat(remotepath).st_size file_size = self.stat(remotepath).st_size
fl = file(localpath, 'wb') with open(localpath, 'wb') as fl:
try:
size = self.getfo(remotepath, fl, callback) size = self.getfo(remotepath, fl, callback)
finally:
fl.close()
s = os.stat(localpath) s = os.stat(localpath)
if s.st_size != size: if s.st_size != size:
raise IOError('size mismatch in get! %d != %d' % (s.st_size, size)) raise IOError('size mismatch in get! %d != %d' % (s.st_size, size))
@ -656,11 +644,11 @@ class SFTPClient(BaseSFTP):
msg = Message() msg = Message()
msg.add_int(self.request_number) msg.add_int(self.request_number)
for item in arg: for item in arg:
if isinstance(item, int): if isinstance(item, long):
msg.add_int(item)
elif isinstance(item, long):
msg.add_int64(item) msg.add_int64(item)
elif isinstance(item, str): elif isinstance(item, int):
msg.add_int(item)
elif isinstance(item, (string_types, bytes_types)):
msg.add_string(item) msg.add_string(item)
elif isinstance(item, SFTPAttributes): elif isinstance(item, SFTPAttributes):
item._pack(msg) item._pack(msg)
@ -668,7 +656,7 @@ class SFTPClient(BaseSFTP):
raise Exception('unknown type for %r type %r' % (item, type(item))) raise Exception('unknown type for %r type %r' % (item, type(item)))
num = self.request_number num = self.request_number
self._expecting[num] = fileobj self._expecting[num] = fileobj
self._send_packet(t, str(msg)) self._send_packet(t, msg)
self.request_number += 1 self.request_number += 1
finally: finally:
self._lock.release() self._lock.release()
@ -678,8 +666,8 @@ class SFTPClient(BaseSFTP):
while True: while True:
try: try:
t, data = self._read_packet() t, data = self._read_packet()
except EOFError, e: except EOFError as e:
raise SSHException('Server connection dropped: %s' % (str(e),)) raise SSHException('Server connection dropped: %s' % str(e))
msg = Message(data) msg = Message(data)
num = msg.get_int() num = msg.get_int()
if num not in self._expecting: if num not in self._expecting:
@ -713,7 +701,7 @@ class SFTPClient(BaseSFTP):
Raises EOFError or IOError on error status; otherwise does nothing. Raises EOFError or IOError on error status; otherwise does nothing.
""" """
code = msg.get_int() code = msg.get_int()
text = msg.get_string() text = msg.get_text()
if code == SFTP_OK: if code == SFTP_OK:
return return
elif code == SFTP_EOF: elif code == SFTP_EOF:
@ -731,16 +719,15 @@ class SFTPClient(BaseSFTP):
Return an adjusted path if we're emulating a "current working Return an adjusted path if we're emulating a "current working
directory" for the server. directory" for the server.
""" """
if type(path) is unicode: path = b(path)
path = path.encode('utf-8')
if self._cwd is None: if self._cwd is None:
return path return path
if (len(path) > 0) and (path[0] == '/'): if len(path) and path[0:1] == b_slash:
# absolute path # absolute path
return path return path
if self._cwd == '/': if self._cwd == b_slash:
return self._cwd + path return self._cwd + path
return self._cwd + '/' + path return self._cwd + b_slash + path
class SFTP(SFTPClient): class SFTP(SFTPClient):

View File

@ -100,7 +100,7 @@ class SFTPFile (BufferedFile):
k = [x for x in self._prefetch_extents.values() if x[0] <= offset] k = [x for x in self._prefetch_extents.values() if x[0] <= offset]
if len(k) == 0: if len(k) == 0:
return False return False
k.sort(lambda x, y: cmp(x[0], y[0])) k.sort(key=lambda x: x[0])
buf_offset, buf_size = k[-1] buf_offset, buf_size = k[-1]
if buf_offset + buf_size <= offset: if buf_offset + buf_size <= offset:
# prefetch request ends before this one begins # prefetch request ends before this one begins
@ -171,7 +171,7 @@ class SFTPFile (BufferedFile):
def _write(self, data): def _write(self, data):
# may write less than requested if it would exceed max packet size # may write less than requested if it would exceed max packet size
chunk = min(len(data), self.MAX_REQUEST_SIZE) chunk = min(len(data), self.MAX_REQUEST_SIZE)
self._reqs.append(self.sftp._async_request(type(None), CMD_WRITE, self.handle, long(self._realpos), str(data[:chunk]))) self._reqs.append(self.sftp._async_request(type(None), CMD_WRITE, self.handle, long(self._realpos), data[:chunk]))
if not self.pipelined or (len(self._reqs) > 100 and self.sftp.sock.recv_ready()): if not self.pipelined or (len(self._reqs) > 100 and self.sftp.sock.recv_ready()):
while len(self._reqs): while len(self._reqs):
req = self._reqs.popleft() req = self._reqs.popleft()
@ -224,7 +224,7 @@ class SFTPFile (BufferedFile):
self._realpos = self._pos self._realpos = self._pos
else: else:
self._realpos = self._pos = self._get_size() + offset self._realpos = self._pos = self._get_size() + offset
self._rbuffer = '' self._rbuffer = bytes()
def stat(self): def stat(self):
""" """
@ -352,8 +352,8 @@ class SFTPFile (BufferedFile):
""" """
t, msg = self.sftp._request(CMD_EXTENDED, 'check-file', self.handle, t, msg = self.sftp._request(CMD_EXTENDED, 'check-file', self.handle,
hash_algorithm, long(offset), long(length), block_size) hash_algorithm, long(offset), long(length), block_size)
ext = msg.get_string() ext = msg.get_text()
alg = msg.get_string() alg = msg.get_text()
data = msg.get_remainder() data = msg.get_remainder()
return data return data
@ -469,8 +469,8 @@ class SFTPFile (BufferedFile):
# save exception and re-raise it on next file operation # save exception and re-raise it on next file operation
try: try:
self.sftp._convert_status(msg) self.sftp._convert_status(msg)
except Exception, x: except Exception as e:
self._saved_exception = x self._saved_exception = e
return return
if t != CMD_DATA: if t != CMD_DATA:
raise SFTPError('Expected data') raise SFTPError('Expected data')

View File

@ -97,7 +97,7 @@ class SFTPHandle (object):
readfile.seek(offset) readfile.seek(offset)
self.__tell = offset self.__tell = offset
data = readfile.read(length) data = readfile.read(length)
except IOError, e: except IOError as e:
self.__tell = None self.__tell = None
return SFTPServer.convert_errno(e.errno) return SFTPServer.convert_errno(e.errno)
self.__tell += len(data) self.__tell += len(data)
@ -135,7 +135,7 @@ class SFTPHandle (object):
self.__tell = offset self.__tell = offset
writefile.write(data) writefile.write(data)
writefile.flush() writefile.flush()
except IOError, e: except IOError as e:
self.__tell = None self.__tell = None
return SFTPServer.convert_errno(e.errno) return SFTPServer.convert_errno(e.errno)
if self.__tell is not None: if self.__tell is not None:

View File

@ -89,7 +89,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
except EOFError: except EOFError:
self._log(DEBUG, 'EOF -- end of session') self._log(DEBUG, 'EOF -- end of session')
return return
except Exception, e: except Exception as e:
self._log(DEBUG, 'Exception on channel: ' + str(e)) self._log(DEBUG, 'Exception on channel: ' + str(e))
self._log(DEBUG, util.tb_strings()) self._log(DEBUG, util.tb_strings())
return return
@ -97,7 +97,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
request_number = msg.get_int() request_number = msg.get_int()
try: try:
self._process(t, request_number, msg) self._process(t, request_number, msg)
except Exception, e: except Exception as e:
self._log(DEBUG, 'Exception in server processing: ' + str(e)) self._log(DEBUG, 'Exception in server processing: ' + str(e))
self._log(DEBUG, util.tb_strings()) self._log(DEBUG, util.tb_strings())
# send some kind of failure message, at least # send some kind of failure message, at least
@ -110,9 +110,9 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
self.server.session_ended() self.server.session_ended()
super(SFTPServer, self).finish_subsystem() super(SFTPServer, self).finish_subsystem()
# close any file handles that were left open (so we can return them to the OS quickly) # close any file handles that were left open (so we can return them to the OS quickly)
for f in self.file_table.itervalues(): for f in self.file_table.values():
f.close() f.close()
for f in self.folder_table.itervalues(): for f in self.folder_table.values():
f.close() f.close()
self.file_table = {} self.file_table = {}
self.folder_table = {} self.folder_table = {}
@ -159,7 +159,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
if attr._flags & attr.FLAG_AMTIME: if attr._flags & attr.FLAG_AMTIME:
os.utime(filename, (attr.st_atime, attr.st_mtime)) os.utime(filename, (attr.st_atime, attr.st_mtime))
if attr._flags & attr.FLAG_SIZE: if attr._flags & attr.FLAG_SIZE:
open(filename, 'w+').truncate(attr.st_size) with open(filename, 'w+') as f:
f.truncate(attr.st_size)
set_file_attr = staticmethod(set_file_attr) set_file_attr = staticmethod(set_file_attr)
@ -170,24 +171,24 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
msg = Message() msg = Message()
msg.add_int(request_number) msg.add_int(request_number)
for item in arg: for item in arg:
if type(item) is int: if isinstance(item, long):
msg.add_int(item)
elif type(item) is long:
msg.add_int64(item) msg.add_int64(item)
elif type(item) is str: elif isinstance(item, int):
msg.add_int(item)
elif isinstance(item, (string_types, bytes_types)):
msg.add_string(item) msg.add_string(item)
elif type(item) is SFTPAttributes: elif type(item) is SFTPAttributes:
item._pack(msg) item._pack(msg)
else: else:
raise Exception('unknown type for ' + repr(item) + ' type ' + repr(type(item))) raise Exception('unknown type for ' + repr(item) + ' type ' + repr(type(item)))
self._send_packet(t, str(msg)) self._send_packet(t, msg)
def _send_handle_response(self, request_number, handle, folder=False): def _send_handle_response(self, request_number, handle, folder=False):
if not issubclass(type(handle), SFTPHandle): if not issubclass(type(handle), SFTPHandle):
# must be error code # must be error code
self._send_status(request_number, handle) self._send_status(request_number, handle)
return return
handle._set_name('hx%d' % self.next_handle) handle._set_name(b('hx%d' % self.next_handle))
self.next_handle += 1 self.next_handle += 1
if folder: if folder:
self.folder_table[handle._get_name()] = handle self.folder_table[handle._get_name()] = handle
@ -225,16 +226,16 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
msg.add_int(len(flist)) msg.add_int(len(flist))
for attr in flist: for attr in flist:
msg.add_string(attr.filename) msg.add_string(attr.filename)
msg.add_string(str(attr)) msg.add_string(attr)
attr._pack(msg) attr._pack(msg)
self._send_packet(CMD_NAME, str(msg)) self._send_packet(CMD_NAME, msg)
def _check_file(self, request_number, msg): def _check_file(self, request_number, msg):
# this extension actually comes from v6 protocol, but since it's an # this extension actually comes from v6 protocol, but since it's an
# extension, i feel like we can reasonably support it backported. # extension, i feel like we can reasonably support it backported.
# it's very useful for verifying uploaded files or checking for # it's very useful for verifying uploaded files or checking for
# rsync-like differences between local and remote files. # rsync-like differences between local and remote files.
handle = msg.get_string() handle = msg.get_binary()
alg_list = msg.get_list() alg_list = msg.get_list()
start = msg.get_int64() start = msg.get_int64()
length = msg.get_int64() length = msg.get_int64()
@ -263,7 +264,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
self._send_status(request_number, SFTP_FAILURE, 'Block size too small') self._send_status(request_number, SFTP_FAILURE, 'Block size too small')
return return
sum_out = '' sum_out = bytes()
offset = start offset = start
while offset < start + length: while offset < start + length:
blocklen = min(block_size, start + length - offset) blocklen = min(block_size, start + length - offset)
@ -273,7 +274,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
hash_obj = alg.new() hash_obj = alg.new()
while count < blocklen: while count < blocklen:
data = f.read(offset, chunklen) data = f.read(offset, chunklen)
if not type(data) is str: if not isinstance(data, bytes_types):
self._send_status(request_number, data, 'Unable to hash file') self._send_status(request_number, data, 'Unable to hash file')
return return
hash_obj.update(data) hash_obj.update(data)
@ -286,7 +287,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
msg.add_string('check-file') msg.add_string('check-file')
msg.add_string(algname) msg.add_string(algname)
msg.add_bytes(sum_out) msg.add_bytes(sum_out)
self._send_packet(CMD_EXTENDED_REPLY, str(msg)) self._send_packet(CMD_EXTENDED_REPLY, msg)
def _convert_pflags(self, pflags): def _convert_pflags(self, pflags):
"convert SFTP-style open() flags to Python's os.open() flags" "convert SFTP-style open() flags to Python's os.open() flags"
@ -309,12 +310,12 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
def _process(self, t, request_number, msg): def _process(self, t, request_number, msg):
self._log(DEBUG, 'Request: %s' % CMD_NAMES[t]) self._log(DEBUG, 'Request: %s' % CMD_NAMES[t])
if t == CMD_OPEN: if t == CMD_OPEN:
path = msg.get_string() path = msg.get_text()
flags = self._convert_pflags(msg.get_int()) flags = self._convert_pflags(msg.get_int())
attr = SFTPAttributes._from_msg(msg) attr = SFTPAttributes._from_msg(msg)
self._send_handle_response(request_number, self.server.open(path, flags, attr)) self._send_handle_response(request_number, self.server.open(path, flags, attr))
elif t == CMD_CLOSE: elif t == CMD_CLOSE:
handle = msg.get_string() handle = msg.get_binary()
if handle in self.folder_table: if handle in self.folder_table:
del self.folder_table[handle] del self.folder_table[handle]
self._send_status(request_number, SFTP_OK) self._send_status(request_number, SFTP_OK)
@ -326,14 +327,14 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
return return
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
elif t == CMD_READ: elif t == CMD_READ:
handle = msg.get_string() handle = msg.get_binary()
offset = msg.get_int64() offset = msg.get_int64()
length = msg.get_int() length = msg.get_int()
if handle not in self.file_table: if handle not in self.file_table:
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return return
data = self.file_table[handle].read(offset, length) data = self.file_table[handle].read(offset, length)
if type(data) is str: if isinstance(data, (bytes_types, string_types)):
if len(data) == 0: if len(data) == 0:
self._send_status(request_number, SFTP_EOF) self._send_status(request_number, SFTP_EOF)
else: else:
@ -341,54 +342,54 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
else: else:
self._send_status(request_number, data) self._send_status(request_number, data)
elif t == CMD_WRITE: elif t == CMD_WRITE:
handle = msg.get_string() handle = msg.get_binary()
offset = msg.get_int64() offset = msg.get_int64()
data = msg.get_string() data = msg.get_binary()
if handle not in self.file_table: if handle not in self.file_table:
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return return
self._send_status(request_number, self.file_table[handle].write(offset, data)) self._send_status(request_number, self.file_table[handle].write(offset, data))
elif t == CMD_REMOVE: elif t == CMD_REMOVE:
path = msg.get_string() path = msg.get_text()
self._send_status(request_number, self.server.remove(path)) self._send_status(request_number, self.server.remove(path))
elif t == CMD_RENAME: elif t == CMD_RENAME:
oldpath = msg.get_string() oldpath = msg.get_text()
newpath = msg.get_string() newpath = msg.get_text()
self._send_status(request_number, self.server.rename(oldpath, newpath)) self._send_status(request_number, self.server.rename(oldpath, newpath))
elif t == CMD_MKDIR: elif t == CMD_MKDIR:
path = msg.get_string() path = msg.get_text()
attr = SFTPAttributes._from_msg(msg) attr = SFTPAttributes._from_msg(msg)
self._send_status(request_number, self.server.mkdir(path, attr)) self._send_status(request_number, self.server.mkdir(path, attr))
elif t == CMD_RMDIR: elif t == CMD_RMDIR:
path = msg.get_string() path = msg.get_text()
self._send_status(request_number, self.server.rmdir(path)) self._send_status(request_number, self.server.rmdir(path))
elif t == CMD_OPENDIR: elif t == CMD_OPENDIR:
path = msg.get_string() path = msg.get_text()
self._open_folder(request_number, path) self._open_folder(request_number, path)
return return
elif t == CMD_READDIR: elif t == CMD_READDIR:
handle = msg.get_string() handle = msg.get_binary()
if handle not in self.folder_table: if handle not in self.folder_table:
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return return
folder = self.folder_table[handle] folder = self.folder_table[handle]
self._read_folder(request_number, folder) self._read_folder(request_number, folder)
elif t == CMD_STAT: elif t == CMD_STAT:
path = msg.get_string() path = msg.get_text()
resp = self.server.stat(path) resp = self.server.stat(path)
if issubclass(type(resp), SFTPAttributes): if issubclass(type(resp), SFTPAttributes):
self._response(request_number, CMD_ATTRS, resp) self._response(request_number, CMD_ATTRS, resp)
else: else:
self._send_status(request_number, resp) self._send_status(request_number, resp)
elif t == CMD_LSTAT: elif t == CMD_LSTAT:
path = msg.get_string() path = msg.get_text()
resp = self.server.lstat(path) resp = self.server.lstat(path)
if issubclass(type(resp), SFTPAttributes): if issubclass(type(resp), SFTPAttributes):
self._response(request_number, CMD_ATTRS, resp) self._response(request_number, CMD_ATTRS, resp)
else: else:
self._send_status(request_number, resp) self._send_status(request_number, resp)
elif t == CMD_FSTAT: elif t == CMD_FSTAT:
handle = msg.get_string() handle = msg.get_binary()
if handle not in self.file_table: if handle not in self.file_table:
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return return
@ -398,34 +399,34 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
else: else:
self._send_status(request_number, resp) self._send_status(request_number, resp)
elif t == CMD_SETSTAT: elif t == CMD_SETSTAT:
path = msg.get_string() path = msg.get_text()
attr = SFTPAttributes._from_msg(msg) attr = SFTPAttributes._from_msg(msg)
self._send_status(request_number, self.server.chattr(path, attr)) self._send_status(request_number, self.server.chattr(path, attr))
elif t == CMD_FSETSTAT: elif t == CMD_FSETSTAT:
handle = msg.get_string() handle = msg.get_binary()
attr = SFTPAttributes._from_msg(msg) attr = SFTPAttributes._from_msg(msg)
if handle not in self.file_table: if handle not in self.file_table:
self._response(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') self._response(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return return
self._send_status(request_number, self.file_table[handle].chattr(attr)) self._send_status(request_number, self.file_table[handle].chattr(attr))
elif t == CMD_READLINK: elif t == CMD_READLINK:
path = msg.get_string() path = msg.get_text()
resp = self.server.readlink(path) resp = self.server.readlink(path)
if type(resp) is str: if isinstance(resp, (bytes_types, string_types)):
self._response(request_number, CMD_NAME, 1, resp, '', SFTPAttributes()) self._response(request_number, CMD_NAME, 1, resp, '', SFTPAttributes())
else: else:
self._send_status(request_number, resp) self._send_status(request_number, resp)
elif t == CMD_SYMLINK: elif t == CMD_SYMLINK:
# the sftp 2 draft is incorrect here! path always follows target_path # the sftp 2 draft is incorrect here! path always follows target_path
target_path = msg.get_string() target_path = msg.get_text()
path = msg.get_string() path = msg.get_text()
self._send_status(request_number, self.server.symlink(target_path, path)) self._send_status(request_number, self.server.symlink(target_path, path))
elif t == CMD_REALPATH: elif t == CMD_REALPATH:
path = msg.get_string() path = msg.get_text()
rpath = self.server.canonicalize(path) rpath = self.server.canonicalize(path)
self._response(request_number, CMD_NAME, 1, rpath, '', SFTPAttributes()) self._response(request_number, CMD_NAME, 1, rpath, '', SFTPAttributes())
elif t == CMD_EXTENDED: elif t == CMD_EXTENDED:
tag = msg.get_string() tag = msg.get_text()
if tag == 'check-file': if tag == 'check-file':
self._check_file(request_number, msg) self._check_file(request_number, msg)
else: else:

View File

@ -155,7 +155,7 @@ class Transport (threading.Thread):
:param socket sock: :param socket sock:
a socket or socket-like object to create the session over. a socket or socket-like object to create the session over.
""" """
if isinstance(sock, (str, unicode)): if isinstance(sock, string_types):
# convert "host:port" into (host, port) # convert "host:port" into (host, port)
hl = sock.split(':', 1) hl = sock.split(':', 1)
if len(hl) == 1: if len(hl) == 1:
@ -173,7 +173,7 @@ class Transport (threading.Thread):
sock = socket.socket(af, socket.SOCK_STREAM) sock = socket.socket(af, socket.SOCK_STREAM)
try: try:
retry_on_signal(lambda: sock.connect((hostname, port))) retry_on_signal(lambda: sock.connect((hostname, port)))
except socket.error, e: except socket.error as e:
reason = str(e) reason = str(e)
else: else:
break break
@ -253,7 +253,7 @@ class Transport (threading.Thread):
""" """
Returns a string representation of this object, for debugging. Returns a string representation of this object, for debugging.
""" """
out = '<paramiko.Transport at %s' % hex(long(id(self)) & 0xffffffffL) out = '<paramiko.Transport at %s' % hex(long(id(self)) & xffffffff)
if not self.active: if not self.active:
out += ' (unconnected)' out += ' (unconnected)'
else: else:
@ -279,6 +279,7 @@ class Transport (threading.Thread):
.. versionadded:: 1.5.3 .. versionadded:: 1.5.3
""" """
self.sock.close()
self.close() self.close()
def get_security_options(self): def get_security_options(self):
@ -489,7 +490,7 @@ class Transport (threading.Thread):
if not self.active: if not self.active:
return return
self.stop_thread() self.stop_thread()
for chan in self._channels.values(): for chan in list(self._channels.values()):
chan._unlink() chan._unlink()
self.sock.close() self.sock.close()
@ -562,18 +563,16 @@ class Transport (threading.Thread):
""" """
return self.open_channel('auth-agent@openssh.com') return self.open_channel('auth-agent@openssh.com')
def open_forwarded_tcpip_channel(self, (src_addr, src_port), (dest_addr, dest_port)): def open_forwarded_tcpip_channel(self, src_addr, dest_addr):
""" """
Request a new channel back to the client, of type ``"forwarded-tcpip"``. Request a new channel back to the client, of type ``"forwarded-tcpip"``.
This is used after a client has requested port forwarding, for sending This is used after a client has requested port forwarding, for sending
incoming connections back to the client. incoming connections back to the client.
:param src_addr: originator's address :param src_addr: originator's address
:param src_port: originator's port
:param dest_addr: local (server) connected address :param dest_addr: local (server) connected address
:param dest_port: local (server) connected port
""" """
return self.open_channel('forwarded-tcpip', (dest_addr, dest_port), (src_addr, src_port)) return self.open_channel('forwarded-tcpip', dest_addr, src_addr)
def open_channel(self, kind, dest_addr=None, src_addr=None): def open_channel(self, kind, dest_addr=None, src_addr=None):
""" """
@ -602,7 +601,7 @@ class Transport (threading.Thread):
try: try:
chanid = self._next_channel() chanid = self._next_channel()
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_OPEN)) m.add_byte(cMSG_CHANNEL_OPEN)
m.add_string(kind) m.add_string(kind)
m.add_int(chanid) m.add_int(chanid)
m.add_int(self.window_size) m.add_int(self.window_size)
@ -670,7 +669,6 @@ class Transport (threading.Thread):
""" """
if not self.active: if not self.active:
raise SSHException('SSH session not active') raise SSHException('SSH session not active')
address = str(address)
port = int(port) port = int(port)
response = self.global_request('tcpip-forward', (address, port), wait=True) response = self.global_request('tcpip-forward', (address, port), wait=True)
if response is None: if response is None:
@ -678,7 +676,9 @@ class Transport (threading.Thread):
if port == 0: if port == 0:
port = response.get_int() port = response.get_int()
if handler is None: if handler is None:
def default_handler(channel, (src_addr, src_port), (dest_addr, dest_port)): def default_handler(channel, src_addr, dest_addr):
#src_addr, src_port = src_addr_port
#dest_addr, dest_port = dest_addr_port
self._queue_incoming_channel(channel) self._queue_incoming_channel(channel)
handler = default_handler handler = default_handler
self._tcp_handler = handler self._tcp_handler = handler
@ -710,22 +710,22 @@ class Transport (threading.Thread):
""" """
return SFTPClient.from_transport(self) return SFTPClient.from_transport(self)
def send_ignore(self, bytes=None): def send_ignore(self, byte_count=None):
""" """
Send a junk packet across the encrypted link. This is sometimes used Send a junk packet across the encrypted link. This is sometimes used
to add "noise" to a connection to confuse would-be attackers. It can to add "noise" to a connection to confuse would-be attackers. It can
also be used as a keep-alive for long lived connections traversing also be used as a keep-alive for long lived connections traversing
firewalls. firewalls.
:param int bytes: :param int byte_count:
the number of random bytes to send in the payload of the ignored the number of random bytes to send in the payload of the ignored
packet -- defaults to a random number from 10 to 41. packet -- defaults to a random number from 10 to 41.
""" """
m = Message() m = Message()
m.add_byte(chr(MSG_IGNORE)) m.add_byte(cMSG_IGNORE)
if bytes is None: if byte_count is None:
bytes = (ord(rng.read(1)) % 32) + 10 byte_count = (ord(rng.read(1)) % 32) + 10
m.add_bytes(rng.read(bytes)) m.add_bytes(rng.read(byte_count))
self._send_user_message(m) self._send_user_message(m)
def renegotiate_keys(self): def renegotiate_keys(self):
@ -787,7 +787,7 @@ class Transport (threading.Thread):
if wait: if wait:
self.completion_event = threading.Event() self.completion_event = threading.Event()
m = Message() m = Message()
m.add_byte(chr(MSG_GLOBAL_REQUEST)) m.add_byte(cMSG_GLOBAL_REQUEST)
m.add_string(kind) m.add_string(kind)
m.add_boolean(wait) m.add_boolean(wait)
if data is not None: if data is not None:
@ -871,10 +871,10 @@ class Transport (threading.Thread):
# check host key if we were given one # check host key if we were given one
if (hostkey is not None): if (hostkey is not None):
key = self.get_remote_server_key() key = self.get_remote_server_key()
if (key.get_name() != hostkey.get_name()) or (str(key) != str(hostkey)): if (key.get_name() != hostkey.get_name()) or (key.asbytes() != hostkey.asbytes()):
self._log(DEBUG, 'Bad host key from server') self._log(DEBUG, 'Bad host key from server')
self._log(DEBUG, 'Expected: %s: %s' % (hostkey.get_name(), repr(str(hostkey)))) self._log(DEBUG, 'Expected: %s: %s' % (hostkey.get_name(), repr(hostkey.asbytes())))
self._log(DEBUG, 'Got : %s: %s' % (key.get_name(), repr(str(key)))) self._log(DEBUG, 'Got : %s: %s' % (key.get_name(), repr(key.asbytes())))
raise SSHException('Bad host key from server') raise SSHException('Bad host key from server')
self._log(DEBUG, 'Host key verified (%s)' % hostkey.get_name()) self._log(DEBUG, 'Host key verified (%s)' % hostkey.get_name())
@ -1048,9 +1048,9 @@ class Transport (threading.Thread):
return [] return []
try: try:
return self.auth_handler.wait_for_response(my_event) return self.auth_handler.wait_for_response(my_event)
except BadAuthenticationType, x: except BadAuthenticationType as e:
# if password auth isn't allowed, but keyboard-interactive *is*, try to fudge it # if password auth isn't allowed, but keyboard-interactive *is*, try to fudge it
if not fallback or ('keyboard-interactive' not in x.allowed_types): if not fallback or ('keyboard-interactive' not in e.allowed_types):
raise raise
try: try:
def handler(title, instructions, fields): def handler(title, instructions, fields):
@ -1064,9 +1064,9 @@ class Transport (threading.Thread):
return [] return []
return [ password ] return [ password ]
return self.auth_interactive(username, handler) return self.auth_interactive(username, handler)
except SSHException, ignored: except SSHException:
# attempt failed; just raise the original exception # attempt failed; just raise the original exception
raise x raise e
return None return None
def auth_publickey(self, username, key, event=None): def auth_publickey(self, username, key, event=None):
@ -1331,15 +1331,15 @@ class Transport (threading.Thread):
m = Message() m = Message()
m.add_mpint(self.K) m.add_mpint(self.K)
m.add_bytes(self.H) m.add_bytes(self.H)
m.add_byte(id) m.add_byte(b(id))
m.add_bytes(self.session_id) m.add_bytes(self.session_id)
out = sofar = SHA.new(str(m)).digest() out = sofar = SHA.new(m.asbytes()).digest()
while len(out) < nbytes: while len(out) < nbytes:
m = Message() m = Message()
m.add_mpint(self.K) m.add_mpint(self.K)
m.add_bytes(self.H) m.add_bytes(self.H)
m.add_bytes(sofar) m.add_bytes(sofar)
digest = SHA.new(str(m)).digest() digest = SHA.new(m.asbytes()).digest()
out += digest out += digest
sofar += digest sofar += digest
return out[:nbytes] return out[:nbytes]
@ -1373,7 +1373,7 @@ class Transport (threading.Thread):
# only called if a channel has turned on x11 forwarding # only called if a channel has turned on x11 forwarding
if handler is None: if handler is None:
# by default, use the same mechanism as accept() # by default, use the same mechanism as accept()
def default_handler(channel, (src_addr, src_port)): def default_handler(channel, src_addr_port):
self._queue_incoming_channel(channel) self._queue_incoming_channel(channel)
self._x11_handler = default_handler self._x11_handler = default_handler
else: else:
@ -1404,12 +1404,12 @@ class Transport (threading.Thread):
# active=True occurs before the thread is launched, to avoid a race # active=True occurs before the thread is launched, to avoid a race
_active_threads.append(self) _active_threads.append(self)
if self.server_mode: if self.server_mode:
self._log(DEBUG, 'starting thread (server mode): %s' % hex(long(id(self)) & 0xffffffffL)) self._log(DEBUG, 'starting thread (server mode): %s' % hex(long(id(self)) & xffffffff))
else: else:
self._log(DEBUG, 'starting thread (client mode): %s' % hex(long(id(self)) & 0xffffffffL)) self._log(DEBUG, 'starting thread (client mode): %s' % hex(long(id(self)) & xffffffff))
try: try:
try: try:
self.packetizer.write_all(self.local_version + '\r\n') self.packetizer.write_all(b(self.local_version + '\r\n'))
self._check_banner() self._check_banner()
self._send_kex_init() self._send_kex_init()
self._expect_packet(MSG_KEXINIT) self._expect_packet(MSG_KEXINIT)
@ -1457,18 +1457,18 @@ class Transport (threading.Thread):
else: else:
self._log(WARNING, 'Oops, unhandled type %d' % ptype) self._log(WARNING, 'Oops, unhandled type %d' % ptype)
msg = Message() msg = Message()
msg.add_byte(chr(MSG_UNIMPLEMENTED)) msg.add_byte(cMSG_UNIMPLEMENTED)
msg.add_int(m.seqno) msg.add_int(m.seqno)
self._send_message(msg) self._send_message(msg)
except SSHException, e: except SSHException as e:
self._log(ERROR, 'Exception: ' + str(e)) self._log(ERROR, 'Exception: ' + str(e))
self._log(ERROR, util.tb_strings()) self._log(ERROR, util.tb_strings())
self.saved_exception = e self.saved_exception = e
except EOFError, e: except EOFError as e:
self._log(DEBUG, 'EOF in transport thread') self._log(DEBUG, 'EOF in transport thread')
#self._log(DEBUG, util.tb_strings()) #self._log(DEBUG, util.tb_strings())
self.saved_exception = e self.saved_exception = e
except socket.error, e: except socket.error as e:
if type(e.args) is tuple: if type(e.args) is tuple:
if e.args: if e.args:
emsg = '%s (%d)' % (e.args[1], e.args[0]) emsg = '%s (%d)' % (e.args[1], e.args[0])
@ -1478,12 +1478,12 @@ class Transport (threading.Thread):
emsg = e.args emsg = e.args
self._log(ERROR, 'Socket exception: ' + emsg) self._log(ERROR, 'Socket exception: ' + emsg)
self.saved_exception = e self.saved_exception = e
except Exception, e: except Exception as e:
self._log(ERROR, 'Unknown exception: ' + str(e)) self._log(ERROR, 'Unknown exception: ' + str(e))
self._log(ERROR, util.tb_strings()) self._log(ERROR, util.tb_strings())
self.saved_exception = e self.saved_exception = e
_active_threads.remove(self) _active_threads.remove(self)
for chan in self._channels.values(): for chan in list(self._channels.values()):
chan._unlink() chan._unlink()
if self.active: if self.active:
self.active = False self.active = False
@ -1538,8 +1538,8 @@ class Transport (threading.Thread):
buf = self.packetizer.readline(timeout) buf = self.packetizer.readline(timeout)
except ProxyCommandFailure: except ProxyCommandFailure:
raise raise
except Exception, x: except Exception as e:
raise SSHException('Error reading SSH protocol banner' + str(x)) raise SSHException('Error reading SSH protocol banner' + str(e))
if buf[:4] == 'SSH-': if buf[:4] == 'SSH-':
break break
self._log(DEBUG, 'Banner: ' + buf) self._log(DEBUG, 'Banner: ' + buf)
@ -1549,7 +1549,7 @@ class Transport (threading.Thread):
self.remote_version = buf self.remote_version = buf
# pull off any attached comment # pull off any attached comment
comment = '' comment = ''
i = string.find(buf, ' ') i = buf.find(' ')
if i >= 0: if i >= 0:
comment = buf[i+1:] comment = buf[i+1:]
buf = buf[:i] buf = buf[:i]
@ -1580,13 +1580,13 @@ class Transport (threading.Thread):
pkex = list(self.get_security_options().kex) pkex = list(self.get_security_options().kex)
pkex.remove('diffie-hellman-group-exchange-sha1') pkex.remove('diffie-hellman-group-exchange-sha1')
self.get_security_options().kex = pkex self.get_security_options().kex = pkex
available_server_keys = filter(self.server_key_dict.keys().__contains__, available_server_keys = list(filter(list(self.server_key_dict.keys()).__contains__,
self._preferred_keys) self._preferred_keys))
else: else:
available_server_keys = self._preferred_keys available_server_keys = self._preferred_keys
m = Message() m = Message()
m.add_byte(chr(MSG_KEXINIT)) m.add_byte(cMSG_KEXINIT)
m.add_bytes(rng.read(16)) m.add_bytes(rng.read(16))
m.add_list(self._preferred_kex) m.add_list(self._preferred_kex)
m.add_list(available_server_keys) m.add_list(available_server_keys)
@ -1596,12 +1596,12 @@ class Transport (threading.Thread):
m.add_list(self._preferred_macs) m.add_list(self._preferred_macs)
m.add_list(self._preferred_compression) m.add_list(self._preferred_compression)
m.add_list(self._preferred_compression) m.add_list(self._preferred_compression)
m.add_string('') m.add_string(bytes())
m.add_string('') m.add_string(bytes())
m.add_boolean(False) m.add_boolean(False)
m.add_int(0) m.add_int(0)
# save a copy for later (needed to compute a hash) # save a copy for later (needed to compute a hash)
self.local_kex_init = str(m) self.local_kex_init = m.asbytes()
self._send_message(m) self._send_message(m)
def _parse_kex_init(self, m): def _parse_kex_init(self, m):
@ -1633,19 +1633,19 @@ class Transport (threading.Thread):
# as a server, we pick the first item in the client's list that we support. # as a server, we pick the first item in the client's list that we support.
# as a client, we pick the first item in our list that the server supports. # as a client, we pick the first item in our list that the server supports.
if self.server_mode: if self.server_mode:
agreed_kex = filter(self._preferred_kex.__contains__, kex_algo_list) agreed_kex = list(filter(self._preferred_kex.__contains__, kex_algo_list))
else: else:
agreed_kex = filter(kex_algo_list.__contains__, self._preferred_kex) agreed_kex = list(filter(kex_algo_list.__contains__, self._preferred_kex))
if len(agreed_kex) == 0: if len(agreed_kex) == 0:
raise SSHException('Incompatible ssh peer (no acceptable kex algorithm)') raise SSHException('Incompatible ssh peer (no acceptable kex algorithm)')
self.kex_engine = self._kex_info[agreed_kex[0]](self) self.kex_engine = self._kex_info[agreed_kex[0]](self)
if self.server_mode: if self.server_mode:
available_server_keys = filter(self.server_key_dict.keys().__contains__, available_server_keys = list(filter(list(self.server_key_dict.keys()).__contains__,
self._preferred_keys) self._preferred_keys))
agreed_keys = filter(available_server_keys.__contains__, server_key_algo_list) agreed_keys = list(filter(available_server_keys.__contains__, server_key_algo_list))
else: else:
agreed_keys = filter(server_key_algo_list.__contains__, self._preferred_keys) agreed_keys = list(filter(server_key_algo_list.__contains__, self._preferred_keys))
if len(agreed_keys) == 0: if len(agreed_keys) == 0:
raise SSHException('Incompatible ssh peer (no acceptable host key)') raise SSHException('Incompatible ssh peer (no acceptable host key)')
self.host_key_type = agreed_keys[0] self.host_key_type = agreed_keys[0]
@ -1653,15 +1653,15 @@ class Transport (threading.Thread):
raise SSHException('Incompatible ssh peer (can\'t match requested host key type)') raise SSHException('Incompatible ssh peer (can\'t match requested host key type)')
if self.server_mode: if self.server_mode:
agreed_local_ciphers = filter(self._preferred_ciphers.__contains__, agreed_local_ciphers = list(filter(self._preferred_ciphers.__contains__,
server_encrypt_algo_list) server_encrypt_algo_list))
agreed_remote_ciphers = filter(self._preferred_ciphers.__contains__, agreed_remote_ciphers = list(filter(self._preferred_ciphers.__contains__,
client_encrypt_algo_list) client_encrypt_algo_list))
else: else:
agreed_local_ciphers = filter(client_encrypt_algo_list.__contains__, agreed_local_ciphers = list(filter(client_encrypt_algo_list.__contains__,
self._preferred_ciphers) self._preferred_ciphers))
agreed_remote_ciphers = filter(server_encrypt_algo_list.__contains__, agreed_remote_ciphers = list(filter(server_encrypt_algo_list.__contains__,
self._preferred_ciphers) self._preferred_ciphers))
if (len(agreed_local_ciphers) == 0) or (len(agreed_remote_ciphers) == 0): if (len(agreed_local_ciphers) == 0) or (len(agreed_remote_ciphers) == 0):
raise SSHException('Incompatible ssh server (no acceptable ciphers)') raise SSHException('Incompatible ssh server (no acceptable ciphers)')
self.local_cipher = agreed_local_ciphers[0] self.local_cipher = agreed_local_ciphers[0]
@ -1669,22 +1669,22 @@ class Transport (threading.Thread):
self._log(DEBUG, 'Ciphers agreed: local=%s, remote=%s' % (self.local_cipher, self.remote_cipher)) self._log(DEBUG, 'Ciphers agreed: local=%s, remote=%s' % (self.local_cipher, self.remote_cipher))
if self.server_mode: if self.server_mode:
agreed_remote_macs = filter(self._preferred_macs.__contains__, client_mac_algo_list) agreed_remote_macs = list(filter(self._preferred_macs.__contains__, client_mac_algo_list))
agreed_local_macs = filter(self._preferred_macs.__contains__, server_mac_algo_list) agreed_local_macs = list(filter(self._preferred_macs.__contains__, server_mac_algo_list))
else: else:
agreed_local_macs = filter(client_mac_algo_list.__contains__, self._preferred_macs) agreed_local_macs = list(filter(client_mac_algo_list.__contains__, self._preferred_macs))
agreed_remote_macs = filter(server_mac_algo_list.__contains__, self._preferred_macs) agreed_remote_macs = list(filter(server_mac_algo_list.__contains__, self._preferred_macs))
if (len(agreed_local_macs) == 0) or (len(agreed_remote_macs) == 0): if (len(agreed_local_macs) == 0) or (len(agreed_remote_macs) == 0):
raise SSHException('Incompatible ssh server (no acceptable macs)') raise SSHException('Incompatible ssh server (no acceptable macs)')
self.local_mac = agreed_local_macs[0] self.local_mac = agreed_local_macs[0]
self.remote_mac = agreed_remote_macs[0] self.remote_mac = agreed_remote_macs[0]
if self.server_mode: if self.server_mode:
agreed_remote_compression = filter(self._preferred_compression.__contains__, client_compress_algo_list) agreed_remote_compression = list(filter(self._preferred_compression.__contains__, client_compress_algo_list))
agreed_local_compression = filter(self._preferred_compression.__contains__, server_compress_algo_list) agreed_local_compression = list(filter(self._preferred_compression.__contains__, server_compress_algo_list))
else: else:
agreed_local_compression = filter(client_compress_algo_list.__contains__, self._preferred_compression) agreed_local_compression = list(filter(client_compress_algo_list.__contains__, self._preferred_compression))
agreed_remote_compression = filter(server_compress_algo_list.__contains__, self._preferred_compression) agreed_remote_compression = list(filter(server_compress_algo_list.__contains__, self._preferred_compression))
if (len(agreed_local_compression) == 0) or (len(agreed_remote_compression) == 0): if (len(agreed_local_compression) == 0) or (len(agreed_remote_compression) == 0):
raise SSHException('Incompatible ssh server (no acceptable compression) %r %r %r' % (agreed_local_compression, agreed_remote_compression, self._preferred_compression)) raise SSHException('Incompatible ssh server (no acceptable compression) %r %r %r' % (agreed_local_compression, agreed_remote_compression, self._preferred_compression))
self.local_compression = agreed_local_compression[0] self.local_compression = agreed_local_compression[0]
@ -1699,7 +1699,7 @@ class Transport (threading.Thread):
# actually some extra bytes (one NUL byte in openssh's case) added to # actually some extra bytes (one NUL byte in openssh's case) added to
# the end of the packet but not parsed. turns out we need to throw # the end of the packet but not parsed. turns out we need to throw
# away those bytes because they aren't part of the hash. # away those bytes because they aren't part of the hash.
self.remote_kex_init = chr(MSG_KEXINIT) + m.get_so_far() self.remote_kex_init = cMSG_KEXINIT + m.get_so_far()
def _activate_inbound(self): def _activate_inbound(self):
"switch on newly negotiated encryption parameters for inbound traffic" "switch on newly negotiated encryption parameters for inbound traffic"
@ -1728,7 +1728,7 @@ class Transport (threading.Thread):
def _activate_outbound(self): def _activate_outbound(self):
"switch on newly negotiated encryption parameters for outbound traffic" "switch on newly negotiated encryption parameters for outbound traffic"
m = Message() m = Message()
m.add_byte(chr(MSG_NEWKEYS)) m.add_byte(MSG_NEWKEYS)
self._send_message(m) self._send_message(m)
block_size = self._cipher_info[self.local_cipher]['block-size'] block_size = self._cipher_info[self.local_cipher]['block-size']
if self.server_mode: if self.server_mode:
@ -1797,24 +1797,24 @@ class Transport (threading.Thread):
def _parse_disconnect(self, m): def _parse_disconnect(self, m):
code = m.get_int() code = m.get_int()
desc = m.get_string() desc = m.get_text()
self._log(INFO, 'Disconnect (code %d): %s' % (code, desc)) self._log(INFO, 'Disconnect (code %d): %s' % (code, desc))
def _parse_global_request(self, m): def _parse_global_request(self, m):
kind = m.get_string() kind = m.get_text()
self._log(DEBUG, 'Received global request "%s"' % kind) self._log(DEBUG, 'Received global request "%s"' % kind)
want_reply = m.get_boolean() want_reply = m.get_boolean()
if not self.server_mode: if not self.server_mode:
self._log(DEBUG, 'Rejecting "%s" global request from server.' % kind) self._log(DEBUG, 'Rejecting "%s" global request from server.' % kind)
ok = False ok = False
elif kind == 'tcpip-forward': elif kind == 'tcpip-forward':
address = m.get_string() address = m.get_text()
port = m.get_int() port = m.get_int()
ok = self.server_object.check_port_forward_request(address, port) ok = self.server_object.check_port_forward_request(address, port)
if ok != False: if ok != False:
ok = (ok,) ok = (ok,)
elif kind == 'cancel-tcpip-forward': elif kind == 'cancel-tcpip-forward':
address = m.get_string() address = m.get_test()
port = m.get_int() port = m.get_int()
self.server_object.cancel_port_forward_request(address, port) self.server_object.cancel_port_forward_request(address, port)
ok = True ok = True
@ -1827,10 +1827,10 @@ class Transport (threading.Thread):
if want_reply: if want_reply:
msg = Message() msg = Message()
if ok: if ok:
msg.add_byte(chr(MSG_REQUEST_SUCCESS)) msg.add_byte(cMSG_REQUEST_SUCCESS)
msg.add(*extra) msg.add(*extra)
else: else:
msg.add_byte(chr(MSG_REQUEST_FAILURE)) msg.add_byte(cMSG_REQUEST_FAILURE)
self._send_message(msg) self._send_message(msg)
def _parse_request_success(self, m): def _parse_request_success(self, m):
@ -1868,8 +1868,8 @@ class Transport (threading.Thread):
def _parse_channel_open_failure(self, m): def _parse_channel_open_failure(self, m):
chanid = m.get_int() chanid = m.get_int()
reason = m.get_int() reason = m.get_int()
reason_str = m.get_string() reason_str = m.get_text()
lang = m.get_string() lang = m.get_text()
reason_text = CONNECTION_FAILED_CODE.get(reason, '(unknown code)') reason_text = CONNECTION_FAILED_CODE.get(reason, '(unknown code)')
self._log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text)) self._log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text))
self.lock.acquire() self.lock.acquire()
@ -1885,7 +1885,7 @@ class Transport (threading.Thread):
return return
def _parse_channel_open(self, m): def _parse_channel_open(self, m):
kind = m.get_string() kind = m.get_text()
chanid = m.get_int() chanid = m.get_int()
initial_window_size = m.get_int() initial_window_size = m.get_int()
max_packet_size = m.get_int() max_packet_size = m.get_int()
@ -1898,7 +1898,7 @@ class Transport (threading.Thread):
finally: finally:
self.lock.release() self.lock.release()
elif (kind == 'x11') and (self._x11_handler is not None): elif (kind == 'x11') and (self._x11_handler is not None):
origin_addr = m.get_string() origin_addr = m.get_text()
origin_port = m.get_int() origin_port = m.get_int()
self._log(DEBUG, 'Incoming x11 connection from %s:%d' % (origin_addr, origin_port)) self._log(DEBUG, 'Incoming x11 connection from %s:%d' % (origin_addr, origin_port))
self.lock.acquire() self.lock.acquire()
@ -1907,9 +1907,9 @@ class Transport (threading.Thread):
finally: finally:
self.lock.release() self.lock.release()
elif (kind == 'forwarded-tcpip') and (self._tcp_handler is not None): elif (kind == 'forwarded-tcpip') and (self._tcp_handler is not None):
server_addr = m.get_string() server_addr = m.get_text()
server_port = m.get_int() server_port = m.get_int()
origin_addr = m.get_string() origin_addr = m.get_text()
origin_port = m.get_int() origin_port = m.get_int()
self._log(DEBUG, 'Incoming tcp forwarded connection from %s:%d' % (origin_addr, origin_port)) self._log(DEBUG, 'Incoming tcp forwarded connection from %s:%d' % (origin_addr, origin_port))
self.lock.acquire() self.lock.acquire()
@ -1929,9 +1929,9 @@ class Transport (threading.Thread):
self.lock.release() self.lock.release()
if kind == 'direct-tcpip': if kind == 'direct-tcpip':
# handle direct-tcpip requests comming from the client # handle direct-tcpip requests comming from the client
dest_addr = m.get_string() dest_addr = m.get_text()
dest_port = m.get_int() dest_port = m.get_int()
origin_addr = m.get_string() origin_addr = m.get_text()
origin_port = m.get_int() origin_port = m.get_int()
reason = self.server_object.check_channel_direct_tcpip_request( reason = self.server_object.check_channel_direct_tcpip_request(
my_chanid, (origin_addr, origin_port), my_chanid, (origin_addr, origin_port),
@ -1943,7 +1943,7 @@ class Transport (threading.Thread):
reject = True reject = True
if reject: if reject:
msg = Message() msg = Message()
msg.add_byte(chr(MSG_CHANNEL_OPEN_FAILURE)) msg.add_byte(cMSG_CHANNEL_OPEN_FAILURE)
msg.add_int(chanid) msg.add_int(chanid)
msg.add_int(reason) msg.add_int(reason)
msg.add_string('') msg.add_string('')
@ -1962,7 +1962,7 @@ class Transport (threading.Thread):
finally: finally:
self.lock.release() self.lock.release()
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_OPEN_SUCCESS)) m.add_byte(cMSG_CHANNEL_OPEN_SUCCESS)
m.add_int(chanid) m.add_int(chanid)
m.add_int(my_chanid) m.add_int(my_chanid)
m.add_int(self.window_size) m.add_int(self.window_size)
@ -2029,7 +2029,8 @@ class SecurityOptions (object):
``ValueError`` will be raised. If you try to assign something besides a ``ValueError`` will be raised. If you try to assign something besides a
tuple to one of the fields, ``TypeError`` will be raised. tuple to one of the fields, ``TypeError`` will be raised.
""" """
__slots__ = [ 'ciphers', 'digests', 'key_types', 'kex', 'compression', '_transport' ] #__slots__ = [ 'ciphers', 'digests', 'key_types', 'kex', 'compression', '_transport' ]
__slots__ = '_transport'
def __init__(self, transport): def __init__(self, transport):
self._transport = transport self._transport = transport
@ -2060,8 +2061,8 @@ class SecurityOptions (object):
x = tuple(x) x = tuple(x)
if type(x) is not tuple: if type(x) is not tuple:
raise TypeError('expected tuple or list') raise TypeError('expected tuple or list')
possible = getattr(self._transport, orig).keys() possible = list(getattr(self._transport, orig).keys())
forbidden = filter(lambda n: n not in possible, x) forbidden = [n for n in x if n not in possible]
if len(forbidden) > 0: if len(forbidden) > 0:
raise ValueError('unknown cipher') raise ValueError('unknown cipher')
setattr(self._transport, name, x) setattr(self._transport, name, x)
@ -2125,7 +2126,7 @@ class ChannelMap (object):
def values(self): def values(self):
self._lock.acquire() self._lock.acquire()
try: try:
return self._map.values() return list(self._map.values())
finally: finally:
self._lock.release() self._lock.release()

View File

@ -48,60 +48,53 @@ if sys.version_info < (2,3):
def inflate_long(s, always_positive=False): def inflate_long(s, always_positive=False):
"turns a normalized byte string into a long-int (adapted from Crypto.Util.number)" "turns a normalized byte string into a long-int (adapted from Crypto.Util.number)"
out = 0L out = long(0)
negative = 0 negative = 0
if not always_positive and (len(s) > 0) and (ord(s[0]) >= 0x80): if not always_positive and (len(s) > 0) and (byte_ord(s[0]) >= 0x80):
negative = 1 negative = 1
if len(s) % 4: if len(s) % 4:
filler = '\x00' filler = zero_byte
if negative: if negative:
filler = '\xff' filler = max_byte
s = filler * (4 - len(s) % 4) + s s = filler * (4 - len(s) % 4) + s
for i in range(0, len(s), 4): for i in range(0, len(s), 4):
out = (out << 32) + struct.unpack('>I', s[i:i+4])[0] out = (out << 32) + struct.unpack('>I', s[i:i+4])[0]
if negative: if negative:
out -= (1L << (8 * len(s))) out -= (long(1) << (8 * len(s)))
return out return out
deflate_zero = zero_byte if PY2 else 0
deflate_ff = max_byte if PY2 else 0xff
def deflate_long(n, add_sign_padding=True): def deflate_long(n, add_sign_padding=True):
"turns a long-int into a normalized byte string (adapted from Crypto.Util.number)" "turns a long-int into a normalized byte string (adapted from Crypto.Util.number)"
# after much testing, this algorithm was deemed to be the fastest # after much testing, this algorithm was deemed to be the fastest
s = '' s = bytes()
n = long(n) n = long(n)
while (n != 0) and (n != -1): while (n != 0) and (n != -1):
s = struct.pack('>I', n & 0xffffffffL) + s s = struct.pack('>I', n & xffffffff) + s
n = n >> 32 n = n >> 32
# strip off leading zeros, FFs # strip off leading zeros, FFs
for i in enumerate(s): for i in enumerate(s):
if (n == 0) and (i[1] != '\000'): if (n == 0) and (i[1] != deflate_zero):
break break
if (n == -1) and (i[1] != '\xff'): if (n == -1) and (i[1] != deflate_ff):
break break
else: else:
# degenerate case, n was either 0 or -1 # degenerate case, n was either 0 or -1
i = (0,) i = (0,)
if n == 0: if n == 0:
s = '\000' s = zero_byte
else: else:
s = '\xff' s = max_byte
s = s[i[0]:] s = s[i[0]:]
if add_sign_padding: if add_sign_padding:
if (n == 0) and (ord(s[0]) >= 0x80): if (n == 0) and (byte_ord(s[0]) >= 0x80):
s = '\x00' + s s = zero_byte + s
if (n == -1) and (ord(s[0]) < 0x80): if (n == -1) and (byte_ord(s[0]) < 0x80):
s = '\xff' + s s = max_byte + s
return s return s
def format_binary_weird(data):
out = ''
for i in enumerate(data):
out += '%02X' % ord(i[1])
if i[0] % 2:
out += ' '
if i[0] % 16 == 15:
out += '\n'
return out
def format_binary(data, prefix=''): def format_binary(data, prefix=''):
x = 0 x = 0
out = [] out = []
@ -113,8 +106,8 @@ def format_binary(data, prefix=''):
return [prefix + x for x in out] return [prefix + x for x in out]
def format_binary_line(data): def format_binary_line(data):
left = ' '.join(['%02X' % ord(c) for c in data]) left = ' '.join(['%02X' % byte_ord(c) for c in data])
right = ''.join([('.%c..' % c)[(ord(c)+63)//95] for c in data]) right = ''.join([('.%c..' % c)[(byte_ord(c)+63)//95] for c in data])
return '%-50s %s' % (left, right) return '%-50s %s' % (left, right)
def hexify(s): def hexify(s):
@ -126,17 +119,20 @@ def unhexify(s):
def safe_string(s): def safe_string(s):
out = '' out = ''
for c in s: for c in s:
if (ord(c) >= 32) and (ord(c) <= 127): if (byte_ord(c) >= 32) and (byte_ord(c) <= 127):
out += c out += c
else: else:
out += '%%%02X' % ord(c) out += '%%%02X' % byte_ord(c)
return out return out
# ''.join([['%%%02X' % ord(c), c][(ord(c) >= 32) and (ord(c) <= 127)] for c in s]) # ''.join([['%%%02X' % ord(c), c][(ord(c) >= 32) and (ord(c) <= 127)] for c in s])
def bit_length(n): def bit_length(n):
try:
return n.bitlength()
except AttributeError:
norm = deflate_long(n, 0) norm = deflate_long(n, 0)
hbyte = ord(norm[0]) hbyte = byte_ord(norm[0])
if hbyte == 0: if hbyte == 0:
return 1 return 1
bitlen = len(norm) * 8 bitlen = len(norm) * 8
@ -157,20 +153,21 @@ def generate_key_bytes(hashclass, salt, key, nbytes):
:param class hashclass: :param class hashclass:
class from `Crypto.Hash` that can be used as a secure hashing function class from `Crypto.Hash` that can be used as a secure hashing function
(like ``MD5`` or ``SHA``). (like ``MD5`` or ``SHA``).
:param str salt: data to salt the hash with. :param salt: data to salt the hash with.
:type salt: byte string
:param str key: human-entered password or passphrase. :param str key: human-entered password or passphrase.
:param int nbytes: number of bytes to generate. :param int nbytes: number of bytes to generate.
:return: Key data `str` :return: Key data `str`
""" """
keydata = '' keydata = bytes()
digest = '' digest = bytes()
if len(salt) > 8: if len(salt) > 8:
salt = salt[:8] salt = salt[:8]
while nbytes > 0: while nbytes > 0:
hash_obj = hashclass.new() hash_obj = hashclass.new()
if len(digest) > 0: if len(digest) > 0:
hash_obj.update(digest) hash_obj.update(digest)
hash_obj.update(key) hash_obj.update(b(key))
hash_obj.update(salt) hash_obj.update(salt)
digest = hash_obj.digest() digest = hash_obj.digest()
size = min(nbytes, len(digest)) size = min(nbytes, len(digest))
@ -271,37 +268,37 @@ def retry_on_signal(function):
while True: while True:
try: try:
return function() return function()
except EnvironmentError, e: except EnvironmentError as e:
if e.errno != errno.EINTR: if e.errno != errno.EINTR:
raise raise
class Counter (object): class Counter (object):
"""Stateful counter for CTR mode crypto""" """Stateful counter for CTR mode crypto"""
def __init__(self, nbits, initial_value=1L, overflow=0L): def __init__(self, nbits, initial_value=long(1), overflow=long(0)):
self.blocksize = nbits / 8 self.blocksize = nbits / 8
self.overflow = overflow self.overflow = overflow
# start with value - 1 so we don't have to store intermediate values when counting # start with value - 1 so we don't have to store intermediate values when counting
# could the iv be 0? # could the iv be 0?
if initial_value == 0: if initial_value == 0:
self.value = array.array('c', '\xFF' * self.blocksize) self.value = array.array('c', max_byte * self.blocksize)
else: else:
x = deflate_long(initial_value - 1, add_sign_padding=False) x = deflate_long(initial_value - 1, add_sign_padding=False)
self.value = array.array('c', '\x00' * (self.blocksize - len(x)) + x) self.value = array.array('c', zero_byte * (self.blocksize - len(x)) + x)
def __call__(self): def __call__(self):
"""Increament the counter and return the new value""" """Increament the counter and return the new value"""
i = self.blocksize - 1 i = self.blocksize - 1
while i > -1: while i > -1:
c = self.value[i] = chr((ord(self.value[i]) + 1) % 256) c = self.value[i] = byte_chr((byte_ord(self.value[i]) + 1) % 256)
if c != '\x00': if c != zero_byte:
return self.value.tostring() return self.value.tostring()
i -= 1 i -= 1
# counter reset # counter reset
x = deflate_long(self.overflow, add_sign_padding=False) x = deflate_long(self.overflow, add_sign_padding=False)
self.value = array.array('c', '\x00' * (self.blocksize - len(x)) + x) self.value = array.array('c', zero_byte * (self.blocksize - len(x)) + x)
return self.value.tostring() return self.value.tostring()
def new(cls, nbits, initial_value=1L, overflow=0L): def new(cls, nbits, initial_value=long(1), overflow=long(0)):
return cls(nbits, initial_value=initial_value, overflow=overflow) return cls(nbits, initial_value=initial_value, overflow=overflow)
new = classmethod(new) new = classmethod(new)

View File

@ -27,6 +27,7 @@ import array
import ctypes.wintypes import ctypes.wintypes
import platform import platform
import struct import struct
from paramiko.util import *
try: try:
import _thread as thread # Python 3.x import _thread as thread # Python 3.x
@ -91,7 +92,7 @@ def _query_pageant(msg):
with pymap: with pymap:
pymap.write(msg) pymap.write(msg)
# Create an array buffer containing the mapped filename # Create an array buffer containing the mapped filename
char_buffer = array.array("c", map_name + '\0') char_buffer = array.array("c", b(map_name) + zero_byte)
char_buffer_address, char_buffer_size = char_buffer.buffer_info() char_buffer_address, char_buffer_size = char_buffer.buffer_info()
# Create a string to use for the SendMessage function call # Create a string to use for the SendMessage function call
cds = COPYDATASTRUCT(_AGENT_COPYDATA_ID, char_buffer_size, cds = COPYDATASTRUCT(_AGENT_COPYDATA_ID, char_buffer_size,

View File

@ -54,7 +54,7 @@ if sys.platform == 'darwin':
setup(name = "paramiko", setup(name = "paramiko",
version = "1.12.2", version = "1.13.0",
description = "SSH2 protocol library", description = "SSH2 protocol library",
author = "Jeff Forcier", author = "Jeff Forcier",
author_email = "jeff@bitprophet.org", author_email = "jeff@bitprophet.org",

33
test.py
View File

@ -29,22 +29,21 @@ import unittest
from optparse import OptionParser from optparse import OptionParser
import paramiko import paramiko
import threading import threading
from paramiko.py3compat import PY2
sys.path.append('tests') sys.path.append('tests')
from test_message import MessageTest from tests.test_message import MessageTest
from test_file import BufferedFileTest from tests.test_file import BufferedFileTest
from test_buffered_pipe import BufferedPipeTest from tests.test_buffered_pipe import BufferedPipeTest
from test_util import UtilTest from tests.test_util import UtilTest
from test_hostkeys import HostKeysTest from tests.test_hostkeys import HostKeysTest
from test_pkey import KeyTest from tests.test_pkey import KeyTest
from test_kex import KexTest from tests.test_kex import KexTest
from test_packetizer import PacketizerTest from tests.test_packetizer import PacketizerTest
from test_auth import AuthTest from tests.test_auth import AuthTest
from test_transport import TransportTest from tests.test_transport import TransportTest
from test_sftp import SFTPTest from tests.test_client import SSHClientTest
from test_sftp_big import BigSFTPTest
from test_client import SSHClientTest
default_host = 'localhost' default_host = 'localhost'
default_user = os.environ.get('USER', 'nobody') default_user = os.environ.get('USER', 'nobody')
@ -109,12 +108,15 @@ def main():
paramiko.util.log_to_file('test.log') paramiko.util.log_to_file('test.log')
if options.use_sftp: if options.use_sftp:
from tests.test_sftp import SFTPTest
if options.use_loopback_sftp: if options.use_loopback_sftp:
SFTPTest.init_loopback() SFTPTest.init_loopback()
else: else:
SFTPTest.init(options.hostname, options.username, options.keyfile, options.password) SFTPTest.init(options.hostname, options.username, options.keyfile, options.password)
if not options.use_big_file: if not options.use_big_file:
SFTPTest.set_big_file_test(False) SFTPTest.set_big_file_test(False)
if options.use_big_file:
from tests.test_sftp_big import BigSFTPTest
suite = unittest.TestSuite() suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(MessageTest)) suite.addTest(unittest.makeSuite(MessageTest))
@ -147,7 +149,10 @@ def main():
# TODO: make that not a problem, jeez # TODO: make that not a problem, jeez
for thread in threading.enumerate(): for thread in threading.enumerate():
if thread is not threading.currentThread(): if thread is not threading.currentThread():
thread._Thread__stop() if PY2:
thread._Thread__stop()
else:
thread._stop()
# Exit correctly # Exit correctly
if not result.wasSuccessful(): if not result.wasSuccessful():
sys.exit(1) sys.exit(1)

0
tests/__init__.py Normal file
View File

View File

@ -21,6 +21,7 @@
""" """
import threading, socket import threading, socket
from paramiko.common import *
class LoopSocket (object): class LoopSocket (object):
@ -31,7 +32,7 @@ class LoopSocket (object):
""" """
def __init__(self): def __init__(self):
self.__in_buffer = '' self.__in_buffer = bytes()
self.__lock = threading.Lock() self.__lock = threading.Lock()
self.__cv = threading.Condition(self.__lock) self.__cv = threading.Condition(self.__lock)
self.__timeout = None self.__timeout = None
@ -41,11 +42,12 @@ class LoopSocket (object):
self.__unlink() self.__unlink()
try: try:
self.__lock.acquire() self.__lock.acquire()
self.__in_buffer = '' self.__in_buffer = bytes()
finally: finally:
self.__lock.release() self.__lock.release()
def send(self, data): def send(self, data):
data = asbytes(data)
if self.__mate is None: if self.__mate is None:
# EOF # EOF
raise EOFError() raise EOFError()
@ -57,7 +59,7 @@ class LoopSocket (object):
try: try:
if self.__mate is None: if self.__mate is None:
# EOF # EOF
return '' return bytes()
if len(self.__in_buffer) == 0: if len(self.__in_buffer) == 0:
self.__cv.wait(self.__timeout) self.__cv.wait(self.__timeout)
if len(self.__in_buffer) == 0: if len(self.__in_buffer) == 0:

View File

@ -21,8 +21,10 @@ A stub SFTP server for loopback SFTP testing.
""" """
import os import os
import sys
from paramiko import ServerInterface, SFTPServerInterface, SFTPServer, SFTPAttributes, \ from paramiko import ServerInterface, SFTPServerInterface, SFTPServer, SFTPAttributes, \
SFTPHandle, SFTP_OK, AUTH_SUCCESSFUL, OPEN_SUCCEEDED SFTPHandle, SFTP_OK, AUTH_SUCCESSFUL, OPEN_SUCCEEDED
from paramiko.common import *
class StubServer (ServerInterface): class StubServer (ServerInterface):
@ -38,7 +40,7 @@ class StubSFTPHandle (SFTPHandle):
def stat(self): def stat(self):
try: try:
return SFTPAttributes.from_stat(os.fstat(self.readfile.fileno())) return SFTPAttributes.from_stat(os.fstat(self.readfile.fileno()))
except OSError, e: except OSError as e:
return SFTPServer.convert_errno(e.errno) return SFTPServer.convert_errno(e.errno)
def chattr(self, attr): def chattr(self, attr):
@ -47,7 +49,7 @@ class StubSFTPHandle (SFTPHandle):
try: try:
SFTPServer.set_file_attr(self.filename, attr) SFTPServer.set_file_attr(self.filename, attr)
return SFTP_OK return SFTP_OK
except OSError, e: except OSError as e:
return SFTPServer.convert_errno(e.errno) return SFTPServer.convert_errno(e.errno)
@ -69,21 +71,21 @@ class StubSFTPServer (SFTPServerInterface):
attr.filename = fname attr.filename = fname
out.append(attr) out.append(attr)
return out return out
except OSError, e: except OSError as e:
return SFTPServer.convert_errno(e.errno) return SFTPServer.convert_errno(e.errno)
def stat(self, path): def stat(self, path):
path = self._realpath(path) path = self._realpath(path)
try: try:
return SFTPAttributes.from_stat(os.stat(path)) return SFTPAttributes.from_stat(os.stat(path))
except OSError, e: except OSError as e:
return SFTPServer.convert_errno(e.errno) return SFTPServer.convert_errno(e.errno)
def lstat(self, path): def lstat(self, path):
path = self._realpath(path) path = self._realpath(path)
try: try:
return SFTPAttributes.from_stat(os.lstat(path)) return SFTPAttributes.from_stat(os.lstat(path))
except OSError, e: except OSError as e:
return SFTPServer.convert_errno(e.errno) return SFTPServer.convert_errno(e.errno)
def open(self, path, flags, attr): def open(self, path, flags, attr):
@ -97,8 +99,8 @@ class StubSFTPServer (SFTPServerInterface):
else: else:
# os.open() defaults to 0777 which is # os.open() defaults to 0777 which is
# an odd default mode for files # an odd default mode for files
fd = os.open(path, flags, 0666) fd = os.open(path, flags, o666)
except OSError, e: except OSError as e:
return SFTPServer.convert_errno(e.errno) return SFTPServer.convert_errno(e.errno)
if (flags & os.O_CREAT) and (attr is not None): if (flags & os.O_CREAT) and (attr is not None):
attr._flags &= ~attr.FLAG_PERMISSIONS attr._flags &= ~attr.FLAG_PERMISSIONS
@ -118,7 +120,7 @@ class StubSFTPServer (SFTPServerInterface):
fstr = 'rb' fstr = 'rb'
try: try:
f = os.fdopen(fd, fstr) f = os.fdopen(fd, fstr)
except OSError, e: except OSError as e:
return SFTPServer.convert_errno(e.errno) return SFTPServer.convert_errno(e.errno)
fobj = StubSFTPHandle(flags) fobj = StubSFTPHandle(flags)
fobj.filename = path fobj.filename = path
@ -130,7 +132,7 @@ class StubSFTPServer (SFTPServerInterface):
path = self._realpath(path) path = self._realpath(path)
try: try:
os.remove(path) os.remove(path)
except OSError, e: except OSError as e:
return SFTPServer.convert_errno(e.errno) return SFTPServer.convert_errno(e.errno)
return SFTP_OK return SFTP_OK
@ -139,7 +141,7 @@ class StubSFTPServer (SFTPServerInterface):
newpath = self._realpath(newpath) newpath = self._realpath(newpath)
try: try:
os.rename(oldpath, newpath) os.rename(oldpath, newpath)
except OSError, e: except OSError as e:
return SFTPServer.convert_errno(e.errno) return SFTPServer.convert_errno(e.errno)
return SFTP_OK return SFTP_OK
@ -149,7 +151,7 @@ class StubSFTPServer (SFTPServerInterface):
os.mkdir(path) os.mkdir(path)
if attr is not None: if attr is not None:
SFTPServer.set_file_attr(path, attr) SFTPServer.set_file_attr(path, attr)
except OSError, e: except OSError as e:
return SFTPServer.convert_errno(e.errno) return SFTPServer.convert_errno(e.errno)
return SFTP_OK return SFTP_OK
@ -157,7 +159,7 @@ class StubSFTPServer (SFTPServerInterface):
path = self._realpath(path) path = self._realpath(path)
try: try:
os.rmdir(path) os.rmdir(path)
except OSError, e: except OSError as e:
return SFTPServer.convert_errno(e.errno) return SFTPServer.convert_errno(e.errno)
return SFTP_OK return SFTP_OK
@ -165,7 +167,7 @@ class StubSFTPServer (SFTPServerInterface):
path = self._realpath(path) path = self._realpath(path)
try: try:
SFTPServer.set_file_attr(path, attr) SFTPServer.set_file_attr(path, attr)
except OSError, e: except OSError as e:
return SFTPServer.convert_errno(e.errno) return SFTPServer.convert_errno(e.errno)
return SFTP_OK return SFTP_OK
@ -185,7 +187,7 @@ class StubSFTPServer (SFTPServerInterface):
target_path = '<error>' target_path = '<error>'
try: try:
os.symlink(target_path, path) os.symlink(target_path, path)
except OSError, e: except OSError as e:
return SFTPServer.convert_errno(e.errno) return SFTPServer.convert_errno(e.errno)
return SFTP_OK return SFTP_OK
@ -193,7 +195,7 @@ class StubSFTPServer (SFTPServerInterface):
path = self._realpath(path) path = self._realpath(path)
try: try:
symlink = os.readlink(path) symlink = os.readlink(path)
except OSError, e: except OSError as e:
return SFTPServer.convert_errno(e.errno) return SFTPServer.convert_errno(e.errno)
# if it's absolute, remove the root # if it's absolute, remove the root
if os.path.isabs(symlink): if os.path.isabs(symlink):

View File

@ -29,13 +29,17 @@ from paramiko import Transport, ServerInterface, RSAKey, DSSKey, \
AuthenticationException AuthenticationException
from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL
from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
from loop import LoopSocket from paramiko.py3compat import u
from tests.loop import LoopSocket
from tests.util import test_path
_pwd = u('\u2022')
class NullServer (ServerInterface): class NullServer (ServerInterface):
paranoid_did_password = False paranoid_did_password = False
paranoid_did_public_key = False paranoid_did_public_key = False
paranoid_key = DSSKey.from_private_key_file('tests/test_dss.key') paranoid_key = DSSKey.from_private_key_file(test_path('test_dss.key'))
def get_allowed_auths(self, username): def get_allowed_auths(self, username):
if username == 'slowdive': if username == 'slowdive':
@ -64,7 +68,7 @@ class NullServer (ServerInterface):
if self.paranoid_did_public_key: if self.paranoid_did_public_key:
return AUTH_SUCCESSFUL return AUTH_SUCCESSFUL
return AUTH_PARTIALLY_SUCCESSFUL return AUTH_PARTIALLY_SUCCESSFUL
if (username == 'utf8') and (password == u'\u2022'): if (username == 'utf8') and (password == _pwd):
return AUTH_SUCCESSFUL return AUTH_SUCCESSFUL
if (username == 'non-utf8') and (password == '\xff'): if (username == 'non-utf8') and (password == '\xff'):
return AUTH_SUCCESSFUL return AUTH_SUCCESSFUL
@ -110,18 +114,18 @@ class AuthTest (unittest.TestCase):
self.sockc.close() self.sockc.close()
def start_server(self): def start_server(self):
host_key = RSAKey.from_private_key_file('tests/test_rsa.key') host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
self.public_host_key = RSAKey(data=str(host_key)) self.public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key) self.ts.add_server_key(host_key)
self.event = threading.Event() self.event = threading.Event()
self.server = NullServer() self.server = NullServer()
self.assert_(not self.event.isSet()) self.assertTrue(not self.event.isSet())
self.ts.start_server(self.event, self.server) self.ts.start_server(self.event, self.server)
def verify_finished(self): def verify_finished(self):
self.event.wait(1.0) self.event.wait(1.0)
self.assert_(self.event.isSet()) self.assertTrue(self.event.isSet())
self.assert_(self.ts.is_active()) self.assertTrue(self.ts.is_active())
def test_1_bad_auth_type(self): def test_1_bad_auth_type(self):
""" """
@ -132,11 +136,11 @@ class AuthTest (unittest.TestCase):
try: try:
self.tc.connect(hostkey=self.public_host_key, self.tc.connect(hostkey=self.public_host_key,
username='unknown', password='error') username='unknown', password='error')
self.assert_(False) self.assertTrue(False)
except: except:
etype, evalue, etb = sys.exc_info() etype, evalue, etb = sys.exc_info()
self.assertEquals(BadAuthenticationType, etype) self.assertEqual(BadAuthenticationType, etype)
self.assertEquals(['publickey'], evalue.allowed_types) self.assertEqual(['publickey'], evalue.allowed_types)
def test_2_bad_password(self): def test_2_bad_password(self):
""" """
@ -147,10 +151,10 @@ class AuthTest (unittest.TestCase):
self.tc.connect(hostkey=self.public_host_key) self.tc.connect(hostkey=self.public_host_key)
try: try:
self.tc.auth_password(username='slowdive', password='error') self.tc.auth_password(username='slowdive', password='error')
self.assert_(False) self.assertTrue(False)
except: except:
etype, evalue, etb = sys.exc_info() etype, evalue, etb = sys.exc_info()
self.assert_(issubclass(etype, AuthenticationException)) self.assertTrue(issubclass(etype, AuthenticationException))
self.tc.auth_password(username='slowdive', password='pygmalion') self.tc.auth_password(username='slowdive', password='pygmalion')
self.verify_finished() self.verify_finished()
@ -161,10 +165,10 @@ class AuthTest (unittest.TestCase):
self.start_server() self.start_server()
self.tc.connect(hostkey=self.public_host_key) self.tc.connect(hostkey=self.public_host_key)
remain = self.tc.auth_password(username='paranoid', password='paranoid') remain = self.tc.auth_password(username='paranoid', password='paranoid')
self.assertEquals(['publickey'], remain) self.assertEqual(['publickey'], remain)
key = DSSKey.from_private_key_file('tests/test_dss.key') key = DSSKey.from_private_key_file(test_path('test_dss.key'))
remain = self.tc.auth_publickey(username='paranoid', key=key) remain = self.tc.auth_publickey(username='paranoid', key=key)
self.assertEquals([], remain) self.assertEqual([], remain)
self.verify_finished() self.verify_finished()
def test_4_interactive_auth(self): def test_4_interactive_auth(self):
@ -180,9 +184,9 @@ class AuthTest (unittest.TestCase):
self.got_prompts = prompts self.got_prompts = prompts
return ['cat'] return ['cat']
remain = self.tc.auth_interactive('commie', handler) remain = self.tc.auth_interactive('commie', handler)
self.assertEquals(self.got_title, 'password') self.assertEqual(self.got_title, 'password')
self.assertEquals(self.got_prompts, [('Password', False)]) self.assertEqual(self.got_prompts, [('Password', False)])
self.assertEquals([], remain) self.assertEqual([], remain)
self.verify_finished() self.verify_finished()
def test_5_interactive_auth_fallback(self): def test_5_interactive_auth_fallback(self):
@ -193,7 +197,7 @@ class AuthTest (unittest.TestCase):
self.start_server() self.start_server()
self.tc.connect(hostkey=self.public_host_key) self.tc.connect(hostkey=self.public_host_key)
remain = self.tc.auth_password('commie', 'cat') remain = self.tc.auth_password('commie', 'cat')
self.assertEquals([], remain) self.assertEqual([], remain)
self.verify_finished() self.verify_finished()
def test_6_auth_utf8(self): def test_6_auth_utf8(self):
@ -202,8 +206,8 @@ class AuthTest (unittest.TestCase):
""" """
self.start_server() self.start_server()
self.tc.connect(hostkey=self.public_host_key) self.tc.connect(hostkey=self.public_host_key)
remain = self.tc.auth_password('utf8', u'\u2022') remain = self.tc.auth_password('utf8', _pwd)
self.assertEquals([], remain) self.assertEqual([], remain)
self.verify_finished() self.verify_finished()
def test_7_auth_non_utf8(self): def test_7_auth_non_utf8(self):
@ -214,7 +218,7 @@ class AuthTest (unittest.TestCase):
self.start_server() self.start_server()
self.tc.connect(hostkey=self.public_host_key) self.tc.connect(hostkey=self.public_host_key)
remain = self.tc.auth_password('non-utf8', '\xff') remain = self.tc.auth_password('non-utf8', '\xff')
self.assertEquals([], remain) self.assertEqual([], remain)
self.verify_finished() self.verify_finished()
def test_8_auth_gets_disconnected(self): def test_8_auth_gets_disconnected(self):
@ -228,4 +232,4 @@ class AuthTest (unittest.TestCase):
remain = self.tc.auth_password('bad-server', 'hello') remain = self.tc.auth_password('bad-server', 'hello')
except: except:
etype, evalue, etb = sys.exc_info() etype, evalue, etb = sys.exc_info()
self.assert_(issubclass(etype, AuthenticationException)) self.assertTrue(issubclass(etype, AuthenticationException))

View File

@ -25,8 +25,9 @@ import time
import unittest import unittest
from paramiko.buffered_pipe import BufferedPipe, PipeTimeout from paramiko.buffered_pipe import BufferedPipe, PipeTimeout
from paramiko import pipe from paramiko import pipe
from paramiko.py3compat import b
from util import ParamikoTest from tests.util import ParamikoTest
def delay_thread(pipe): def delay_thread(pipe):
@ -44,39 +45,39 @@ def close_thread(pipe):
class BufferedPipeTest(ParamikoTest): class BufferedPipeTest(ParamikoTest):
def test_1_buffered_pipe(self): def test_1_buffered_pipe(self):
p = BufferedPipe() p = BufferedPipe()
self.assert_(not p.read_ready()) self.assertTrue(not p.read_ready())
p.feed('hello.') p.feed('hello.')
self.assert_(p.read_ready()) self.assertTrue(p.read_ready())
data = p.read(6) data = p.read(6)
self.assertEquals('hello.', data) self.assertEqual(b'hello.', data)
p.feed('plus/minus') p.feed('plus/minus')
self.assertEquals('plu', p.read(3)) self.assertEqual(b'plu', p.read(3))
self.assertEquals('s/m', p.read(3)) self.assertEqual(b's/m', p.read(3))
self.assertEquals('inus', p.read(4)) self.assertEqual(b'inus', p.read(4))
p.close() p.close()
self.assert_(not p.read_ready()) self.assertTrue(not p.read_ready())
self.assertEquals('', p.read(1)) self.assertEqual(b'', p.read(1))
def test_2_delay(self): def test_2_delay(self):
p = BufferedPipe() p = BufferedPipe()
self.assert_(not p.read_ready()) self.assertTrue(not p.read_ready())
threading.Thread(target=delay_thread, args=(p,)).start() threading.Thread(target=delay_thread, args=(p,)).start()
self.assertEquals('a', p.read(1, 0.1)) self.assertEqual(b'a', p.read(1, 0.1))
try: try:
p.read(1, 0.1) p.read(1, 0.1)
self.assert_(False) self.assertTrue(False)
except PipeTimeout: except PipeTimeout:
pass pass
self.assertEquals('b', p.read(1, 1.0)) self.assertEqual(b'b', p.read(1, 1.0))
self.assertEquals('', p.read(1)) self.assertEqual(b'', p.read(1))
def test_3_close_while_reading(self): def test_3_close_while_reading(self):
p = BufferedPipe() p = BufferedPipe()
threading.Thread(target=close_thread, args=(p,)).start() threading.Thread(target=close_thread, args=(p,)).start()
data = p.read(1, 1.0) data = p.read(1, 1.0)
self.assertEquals('', data) self.assertEqual(b'', data)
def test_4_or_pipe(self): def test_4_or_pipe(self):
p = pipe.make_pipe() p = pipe.make_pipe()

View File

@ -20,16 +20,14 @@
Some unit tests for SSHClient. Some unit tests for SSHClient.
""" """
from __future__ import with_statement # Python 2.5 support
import socket import socket
from tempfile import mkstemp
import threading import threading
import time
import unittest import unittest
import weakref import weakref
import warnings import warnings
import os import os
from binascii import hexlify from tests.util import test_path
import paramiko import paramiko
@ -46,7 +44,7 @@ class NullServer (paramiko.ServerInterface):
return paramiko.AUTH_FAILED return paramiko.AUTH_FAILED
def check_auth_publickey(self, username, key): def check_auth_publickey(self, username, key):
if (key.get_name() == 'ssh-dss') and (hexlify(key.get_fingerprint()) == '4478f0b9a23cc5182009ff755bc1d26c'): if (key.get_name() == 'ssh-dss') and key.get_fingerprint() == b'\x44\x78\xf0\xb9\xa2\x3c\xc5\x18\x20\x09\xff\x75\x5b\xc1\xd2\x6c':
return paramiko.AUTH_SUCCESSFUL return paramiko.AUTH_SUCCESSFUL
return paramiko.AUTH_FAILED return paramiko.AUTH_FAILED
@ -67,8 +65,6 @@ class SSHClientTest (unittest.TestCase):
self.sockl.listen(1) self.sockl.listen(1)
self.addr, self.port = self.sockl.getsockname() self.addr, self.port = self.sockl.getsockname()
self.event = threading.Event() self.event = threading.Event()
thread = threading.Thread(target=self._run)
thread.start()
def tearDown(self): def tearDown(self):
for attr in "tc ts socks sockl".split(): for attr in "tc ts socks sockl".split():
@ -78,28 +74,28 @@ class SSHClientTest (unittest.TestCase):
def _run(self): def _run(self):
self.socks, addr = self.sockl.accept() self.socks, addr = self.sockl.accept()
self.ts = paramiko.Transport(self.socks) self.ts = paramiko.Transport(self.socks)
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
self.ts.add_server_key(host_key) self.ts.add_server_key(host_key)
server = NullServer() server = NullServer()
self.ts.start_server(self.event, server) self.ts.start_server(self.event, server)
def test_1_client(self): def test_1_client(self):
""" """
verify that the SSHClient stuff works too. verify that the SSHClient stuff works too.
""" """
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') threading.Thread(target=self._run).start()
public_host_key = paramiko.RSAKey(data=str(host_key)) host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient() self.tc = paramiko.SSHClient()
self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key) self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key)
self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion') self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')
self.event.wait(1.0) self.event.wait(1.0)
self.assert_(self.event.isSet()) self.assertTrue(self.event.isSet())
self.assert_(self.ts.is_active()) self.assertTrue(self.ts.is_active())
self.assertEquals('slowdive', self.ts.get_username()) self.assertEqual('slowdive', self.ts.get_username())
self.assertEquals(True, self.ts.is_authenticated()) self.assertEqual(True, self.ts.is_authenticated())
stdin, stdout, stderr = self.tc.exec_command('yes') stdin, stdout, stderr = self.tc.exec_command('yes')
schan = self.ts.accept(1.0) schan = self.ts.accept(1.0)
@ -108,10 +104,10 @@ class SSHClientTest (unittest.TestCase):
schan.send_stderr('This is on stderr.\n') schan.send_stderr('This is on stderr.\n')
schan.close() schan.close()
self.assertEquals('Hello there.\n', stdout.readline()) self.assertEqual('Hello there.\n', stdout.readline())
self.assertEquals('', stdout.readline()) self.assertEqual('', stdout.readline())
self.assertEquals('This is on stderr.\n', stderr.readline()) self.assertEqual('This is on stderr.\n', stderr.readline())
self.assertEquals('', stderr.readline()) self.assertEqual('', stderr.readline())
stdin.close() stdin.close()
stdout.close() stdout.close()
@ -121,18 +117,19 @@ class SSHClientTest (unittest.TestCase):
""" """
verify that SSHClient works with a DSA key. verify that SSHClient works with a DSA key.
""" """
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') threading.Thread(target=self._run).start()
public_host_key = paramiko.RSAKey(data=str(host_key)) host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient() self.tc = paramiko.SSHClient()
self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key) self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key)
self.tc.connect(self.addr, self.port, username='slowdive', key_filename='tests/test_dss.key') self.tc.connect(self.addr, self.port, username='slowdive', key_filename=test_path('test_dss.key'))
self.event.wait(1.0) self.event.wait(1.0)
self.assert_(self.event.isSet()) self.assertTrue(self.event.isSet())
self.assert_(self.ts.is_active()) self.assertTrue(self.ts.is_active())
self.assertEquals('slowdive', self.ts.get_username()) self.assertEqual('slowdive', self.ts.get_username())
self.assertEquals(True, self.ts.is_authenticated()) self.assertEqual(True, self.ts.is_authenticated())
stdin, stdout, stderr = self.tc.exec_command('yes') stdin, stdout, stderr = self.tc.exec_command('yes')
schan = self.ts.accept(1.0) schan = self.ts.accept(1.0)
@ -141,10 +138,10 @@ class SSHClientTest (unittest.TestCase):
schan.send_stderr('This is on stderr.\n') schan.send_stderr('This is on stderr.\n')
schan.close() schan.close()
self.assertEquals('Hello there.\n', stdout.readline()) self.assertEqual('Hello there.\n', stdout.readline())
self.assertEquals('', stdout.readline()) self.assertEqual('', stdout.readline())
self.assertEquals('This is on stderr.\n', stderr.readline()) self.assertEqual('This is on stderr.\n', stderr.readline())
self.assertEquals('', stderr.readline()) self.assertEqual('', stderr.readline())
stdin.close() stdin.close()
stdout.close() stdout.close()
@ -154,38 +151,40 @@ class SSHClientTest (unittest.TestCase):
""" """
verify that SSHClient accepts and tries multiple key files. verify that SSHClient accepts and tries multiple key files.
""" """
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') threading.Thread(target=self._run).start()
public_host_key = paramiko.RSAKey(data=str(host_key)) host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient() self.tc = paramiko.SSHClient()
self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key) self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key)
self.tc.connect(self.addr, self.port, username='slowdive', key_filename=[ 'tests/test_rsa.key', 'tests/test_dss.key' ]) self.tc.connect(self.addr, self.port, username='slowdive', key_filename=[ test_path('test_rsa.key'), test_path('test_dss.key') ])
self.event.wait(1.0) self.event.wait(1.0)
self.assert_(self.event.isSet()) self.assertTrue(self.event.isSet())
self.assert_(self.ts.is_active()) self.assertTrue(self.ts.is_active())
self.assertEquals('slowdive', self.ts.get_username()) self.assertEqual('slowdive', self.ts.get_username())
self.assertEquals(True, self.ts.is_authenticated()) self.assertEqual(True, self.ts.is_authenticated())
def test_4_auto_add_policy(self): def test_4_auto_add_policy(self):
""" """
verify that SSHClient's AutoAddPolicy works. verify that SSHClient's AutoAddPolicy works.
""" """
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') threading.Thread(target=self._run).start()
public_host_key = paramiko.RSAKey(data=str(host_key)) host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient() self.tc = paramiko.SSHClient()
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.assertEquals(0, len(self.tc.get_host_keys())) self.assertEqual(0, len(self.tc.get_host_keys()))
self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion') self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')
self.event.wait(1.0) self.event.wait(1.0)
self.assert_(self.event.isSet()) self.assertTrue(self.event.isSet())
self.assert_(self.ts.is_active()) self.assertTrue(self.ts.is_active())
self.assertEquals('slowdive', self.ts.get_username()) self.assertEqual('slowdive', self.ts.get_username())
self.assertEquals(True, self.ts.is_authenticated()) self.assertEqual(True, self.ts.is_authenticated())
self.assertEquals(1, len(self.tc.get_host_keys())) self.assertEqual(1, len(self.tc.get_host_keys()))
self.assertEquals(public_host_key, self.tc.get_host_keys()['[%s]:%d' % (self.addr, self.port)]['ssh-rsa']) self.assertEqual(public_host_key, self.tc.get_host_keys()['[%s]:%d' % (self.addr, self.port)]['ssh-rsa'])
def test_5_save_host_keys(self): def test_5_save_host_keys(self):
""" """
@ -193,9 +192,10 @@ class SSHClientTest (unittest.TestCase):
""" """
warnings.filterwarnings('ignore', 'tempnam.*') warnings.filterwarnings('ignore', 'tempnam.*')
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = paramiko.RSAKey(data=str(host_key)) public_host_key = paramiko.RSAKey(data=host_key.asbytes())
localname = os.tempnam() fd, localname = mkstemp()
os.close(fd)
client = paramiko.SSHClient() client = paramiko.SSHClient()
self.assertEquals(0, len(client.get_host_keys())) self.assertEquals(0, len(client.get_host_keys()))
@ -218,24 +218,32 @@ class SSHClientTest (unittest.TestCase):
verify that when an SSHClient is collected, its transport (and the verify that when an SSHClient is collected, its transport (and the
transport's packetizer) is closed. transport's packetizer) is closed.
""" """
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') threading.Thread(target=self._run).start()
public_host_key = paramiko.RSAKey(data=str(host_key)) host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient() self.tc = paramiko.SSHClient()
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.assertEquals(0, len(self.tc.get_host_keys())) self.assertEqual(0, len(self.tc.get_host_keys()))
self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion') self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')
self.event.wait(1.0) self.event.wait(1.0)
self.assert_(self.event.isSet()) self.assertTrue(self.event.isSet())
self.assert_(self.ts.is_active()) self.assertTrue(self.ts.is_active())
p = weakref.ref(self.tc._transport.packetizer) p = weakref.ref(self.tc._transport.packetizer)
self.assert_(p() is not None) self.assertTrue(p() is not None)
self.tc.close()
del self.tc del self.tc
# hrm, sometimes p isn't cleared right away. why is that?
st = time.time()
while (time.time() - st < 5.0) and (p() is not None):
time.sleep(0.1)
self.assert_(p() is None)
# hrm, sometimes p isn't cleared right away. why is that?
#st = time.time()
#while (time.time() - st < 5.0) and (p() is not None):
# time.sleep(0.1)
# instead of dumbly waiting for the GC to collect, force a collection
# to see whether the SSHClient object is deallocated correctly
import gc
gc.collect()
self.assertTrue(p() is None)

View File

@ -22,6 +22,7 @@ Some unit tests for the BufferedFile abstraction.
import unittest import unittest
from paramiko.file import BufferedFile from paramiko.file import BufferedFile
from paramiko.common import *
class LoopbackFile (BufferedFile): class LoopbackFile (BufferedFile):
@ -31,7 +32,7 @@ class LoopbackFile (BufferedFile):
def __init__(self, mode='r', bufsize=-1): def __init__(self, mode='r', bufsize=-1):
BufferedFile.__init__(self) BufferedFile.__init__(self)
self._set_mode(mode, bufsize) self._set_mode(mode, bufsize)
self.buffer = '' self.buffer = bytes()
def _read(self, size): def _read(self, size):
if len(self.buffer) == 0: if len(self.buffer) == 0:
@ -53,7 +54,7 @@ class BufferedFileTest (unittest.TestCase):
f = LoopbackFile('r') f = LoopbackFile('r')
try: try:
f.write('hi') f.write('hi')
self.assert_(False, 'no exception on write to read-only file') self.assertTrue(False, 'no exception on write to read-only file')
except: except:
pass pass
f.close() f.close()
@ -61,7 +62,7 @@ class BufferedFileTest (unittest.TestCase):
f = LoopbackFile('w') f = LoopbackFile('w')
try: try:
f.read(1) f.read(1)
self.assert_(False, 'no exception to read from write-only file') self.assertTrue(False, 'no exception to read from write-only file')
except: except:
pass pass
f.close() f.close()
@ -80,12 +81,12 @@ class BufferedFileTest (unittest.TestCase):
f.close() f.close()
try: try:
f.readline() f.readline()
self.assert_(False, 'no exception on readline of closed file') self.assertTrue(False, 'no exception on readline of closed file')
except IOError: except IOError:
pass pass
self.assert_('\n' in f.newlines) self.assertTrue(linefeed_byte in f.newlines)
self.assert_('\r\n' in f.newlines) self.assertTrue(crlf in f.newlines)
self.assert_('\r' not in f.newlines) self.assertTrue(cr_byte not in f.newlines)
def test_3_lf(self): def test_3_lf(self):
""" """
@ -97,7 +98,7 @@ class BufferedFileTest (unittest.TestCase):
f.write('\nSecond.\r\n') f.write('\nSecond.\r\n')
self.assertEqual(f.readline(), 'Second.\n') self.assertEqual(f.readline(), 'Second.\n')
f.close() f.close()
self.assertEqual(f.newlines, '\r\n') self.assertEqual(f.newlines, crlf)
def test_4_write(self): def test_4_write(self):
""" """

View File

@ -25,6 +25,7 @@ from binascii import hexlify
import os import os
import unittest import unittest
import paramiko import paramiko
from paramiko.py3compat import b, decodebytes
test_hosts_file = """\ test_hosts_file = """\
@ -36,12 +37,12 @@ BGQ3GQ/Fc7SX6gkpXkwcZryoi4kNFhHu5LvHcZPdxXV1D+uTMfGS1eyd2Yz/DoNWXNAl8TI0cAsW\
5ymME3bQ4J/k1IKxCtz/bAlAqFgKoc+EolMziDYqWIATtW0rYTJvzGAzTmMj80/QpsFH+Pc2M= 5ymME3bQ4J/k1IKxCtz/bAlAqFgKoc+EolMziDYqWIATtW0rYTJvzGAzTmMj80/QpsFH+Pc2M=
""" """
keyblob = """\ keyblob = b"""\
AAAAB3NzaC1yc2EAAAABIwAAAIEA8bP1ZA7DCZDB9J0s50l31MBGQ3GQ/Fc7SX6gkpXkwcZryoi4k\ AAAAB3NzaC1yc2EAAAABIwAAAIEA8bP1ZA7DCZDB9J0s50l31MBGQ3GQ/Fc7SX6gkpXkwcZryoi4k\
NFhHu5LvHcZPdxXV1D+uTMfGS1eyd2Yz/DoNWXNAl8TI0cAsW5ymME3bQ4J/k1IKxCtz/bAlAqFgK\ NFhHu5LvHcZPdxXV1D+uTMfGS1eyd2Yz/DoNWXNAl8TI0cAsW5ymME3bQ4J/k1IKxCtz/bAlAqFgK\
oc+EolMziDYqWIATtW0rYTJvzGAzTmMj80/QpsFH+Pc2M=""" oc+EolMziDYqWIATtW0rYTJvzGAzTmMj80/QpsFH+Pc2M="""
keyblob_dss = """\ keyblob_dss = b"""\
AAAAB3NzaC1kc3MAAACBAOeBpgNnfRzr/twmAQRu2XwWAp3CFtrVnug6s6fgwj/oLjYbVtjAy6pl/\ AAAAB3NzaC1kc3MAAACBAOeBpgNnfRzr/twmAQRu2XwWAp3CFtrVnug6s6fgwj/oLjYbVtjAy6pl/\
h0EKCWx2rf1IetyNsTxWrniA9I6HeDj65X1FyDkg6g8tvCnaNB8Xp/UUhuzHuGsMIipRxBxw9LF60\ h0EKCWx2rf1IetyNsTxWrniA9I6HeDj65X1FyDkg6g8tvCnaNB8Xp/UUhuzHuGsMIipRxBxw9LF60\
8EqZcj1E3ytktoW5B5OcjrkEoz3xG7C+rpIjYvAAAAFQDwz4UnmsGiSNu5iqjn3uTzwUpshwAAAIE\ 8EqZcj1E3ytktoW5B5OcjrkEoz3xG7C+rpIjYvAAAAFQDwz4UnmsGiSNu5iqjn3uTzwUpshwAAAIE\
@ -55,51 +56,50 @@ Ngw3qIch/WgRmMHy4kBq1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg\
class HostKeysTest (unittest.TestCase): class HostKeysTest (unittest.TestCase):
def setUp(self): def setUp(self):
f = open('hostfile.temp', 'w') with open('hostfile.temp', 'w') as f:
f.write(test_hosts_file) f.write(test_hosts_file)
f.close()
def tearDown(self): def tearDown(self):
os.unlink('hostfile.temp') os.unlink('hostfile.temp')
def test_1_load(self): def test_1_load(self):
hostdict = paramiko.HostKeys('hostfile.temp') hostdict = paramiko.HostKeys('hostfile.temp')
self.assertEquals(2, len(hostdict)) self.assertEqual(2, len(hostdict))
self.assertEquals(1, len(hostdict.values()[0])) self.assertEqual(1, len(list(hostdict.values())[0]))
self.assertEquals(1, len(hostdict.values()[1])) self.assertEqual(1, len(list(hostdict.values())[1]))
fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper()
self.assertEquals('E6684DB30E109B67B70FF1DC5C7F1363', fp) self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp)
def test_2_add(self): def test_2_add(self):
hostdict = paramiko.HostKeys('hostfile.temp') hostdict = paramiko.HostKeys('hostfile.temp')
hh = '|1|BMsIC6cUIP2zBuXR3t2LRcJYjzM=|hpkJMysjTk/+zzUUzxQEa2ieq6c=' hh = '|1|BMsIC6cUIP2zBuXR3t2LRcJYjzM=|hpkJMysjTk/+zzUUzxQEa2ieq6c='
key = paramiko.RSAKey(data=base64.decodestring(keyblob)) key = paramiko.RSAKey(data=decodebytes(keyblob))
hostdict.add(hh, 'ssh-rsa', key) hostdict.add(hh, 'ssh-rsa', key)
self.assertEquals(3, len(hostdict)) self.assertEqual(3, len(list(hostdict)))
x = hostdict['foo.example.com'] x = hostdict['foo.example.com']
fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper() fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper()
self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp) self.assertEqual(b'7EC91BB336CB6D810B124B1353C32396', fp)
self.assert_(hostdict.check('foo.example.com', key)) self.assertTrue(hostdict.check('foo.example.com', key))
def test_3_dict(self): def test_3_dict(self):
hostdict = paramiko.HostKeys('hostfile.temp') hostdict = paramiko.HostKeys('hostfile.temp')
self.assert_('secure.example.com' in hostdict) self.assertTrue('secure.example.com' in hostdict)
self.assert_('not.example.com' not in hostdict) self.assertTrue('not.example.com' not in hostdict)
self.assert_(hostdict.has_key('secure.example.com')) self.assertTrue('secure.example.com' in hostdict)
self.assert_(not hostdict.has_key('not.example.com')) self.assertTrue('not.example.com' not in hostdict)
x = hostdict.get('secure.example.com', None) x = hostdict.get('secure.example.com', None)
self.assert_(x is not None) self.assertTrue(x is not None)
fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper() fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper()
self.assertEquals('E6684DB30E109B67B70FF1DC5C7F1363', fp) self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp)
i = 0 i = 0
for key in hostdict: for key in hostdict:
i += 1 i += 1
self.assertEquals(2, i) self.assertEqual(2, i)
def test_4_dict_set(self): def test_4_dict_set(self):
hostdict = paramiko.HostKeys('hostfile.temp') hostdict = paramiko.HostKeys('hostfile.temp')
key = paramiko.RSAKey(data=base64.decodestring(keyblob)) key = paramiko.RSAKey(data=decodebytes(keyblob))
key_dss = paramiko.DSSKey(data=base64.decodestring(keyblob_dss)) key_dss = paramiko.DSSKey(data=decodebytes(keyblob_dss))
hostdict['secure.example.com'] = { hostdict['secure.example.com'] = {
'ssh-rsa': key, 'ssh-rsa': key,
'ssh-dss': key_dss 'ssh-dss': key_dss
@ -107,11 +107,11 @@ class HostKeysTest (unittest.TestCase):
hostdict['fake.example.com'] = {} hostdict['fake.example.com'] = {}
hostdict['fake.example.com']['ssh-rsa'] = key hostdict['fake.example.com']['ssh-rsa'] = key
self.assertEquals(3, len(hostdict)) self.assertEqual(3, len(hostdict))
self.assertEquals(2, len(hostdict.values()[0])) self.assertEqual(2, len(list(hostdict.values())[0]))
self.assertEquals(1, len(hostdict.values()[1])) self.assertEqual(1, len(list(hostdict.values())[1]))
self.assertEquals(1, len(hostdict.values()[2])) self.assertEqual(1, len(list(hostdict.values())[2]))
fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper()
self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp) self.assertEqual(b'7EC91BB336CB6D810B124B1353C32396', fp)
fp = hexlify(hostdict['secure.example.com']['ssh-dss'].get_fingerprint()).upper() fp = hexlify(hostdict['secure.example.com']['ssh-dss'].get_fingerprint()).upper()
self.assertEquals('4478F0B9A23CC5182009FF755BC1D26C', fp) self.assertEqual(b'4478F0B9A23CC5182009FF755BC1D26C', fp)

View File

@ -26,22 +26,25 @@ import paramiko.util
from paramiko.kex_group1 import KexGroup1 from paramiko.kex_group1 import KexGroup1
from paramiko.kex_gex import KexGex from paramiko.kex_gex import KexGex
from paramiko import Message from paramiko import Message
from paramiko.common import *
class FakeRng (object): class FakeRng (object):
def read(self, n): def read(self, n):
return chr(0xcc) * n return byte_chr(0xcc) * n
class FakeKey (object): class FakeKey (object):
def __str__(self): def __str__(self):
return 'fake-key' return 'fake-key'
def asbytes(self):
return b'fake-key'
def sign_ssh_data(self, rng, H): def sign_ssh_data(self, rng, H):
return 'fake-sig' return b'fake-sig'
class FakeModulusPack (object): class FakeModulusPack (object):
P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFFL P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF
G = 2 G = 2
def get_modulus(self, min, ask, max): def get_modulus(self, min, ask, max):
return self.G, self.P return self.G, self.P
@ -75,7 +78,7 @@ class FakeTransport (object):
class KexTest (unittest.TestCase): class KexTest (unittest.TestCase):
K = 14730343317708716439807310032871972459448364195094179797249681733965528989482751523943515690110179031004049109375612685505881911274101441415545039654102474376472240501616988799699744135291070488314748284283496055223852115360852283821334858541043710301057312858051901453919067023103730011648890038847384890504L K = 14730343317708716439807310032871972459448364195094179797249681733965528989482751523943515690110179031004049109375612685505881911274101441415545039654102474376472240501616988799699744135291070488314748284283496055223852115360852283821334858541043710301057312858051901453919067023103730011648890038847384890504
def setUp(self): def setUp(self):
pass pass
@ -88,9 +91,9 @@ class KexTest (unittest.TestCase):
transport.server_mode = False transport.server_mode = False
kex = KexGroup1(transport) kex = KexGroup1(transport)
kex.start_kex() kex.start_kex()
x = '1E000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' x = b'1E000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
self.assertEquals(x, hexlify(str(transport._message)).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
self.assertEquals((paramiko.kex_group1._MSG_KEXDH_REPLY,), transport._expect) self.assertEqual((paramiko.kex_group1._MSG_KEXDH_REPLY,), transport._expect)
# fake "reply" # fake "reply"
msg = Message() msg = Message()
@ -99,47 +102,47 @@ class KexTest (unittest.TestCase):
msg.add_string('fake-sig') msg.add_string('fake-sig')
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_group1._MSG_KEXDH_REPLY, msg) kex.parse_next(paramiko.kex_group1._MSG_KEXDH_REPLY, msg)
H = '03079780F3D3AD0B3C6DB30C8D21685F367A86D2' H = b'03079780F3D3AD0B3C6DB30C8D21685F367A86D2'
self.assertEquals(self.K, transport._K) self.assertEqual(self.K, transport._K)
self.assertEquals(H, hexlify(transport._H).upper()) self.assertEqual(H, hexlify(transport._H).upper())
self.assertEquals(('fake-host-key', 'fake-sig'), transport._verify) self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify)
self.assert_(transport._activated) self.assertTrue(transport._activated)
def test_2_group1_server(self): def test_2_group1_server(self):
transport = FakeTransport() transport = FakeTransport()
transport.server_mode = True transport.server_mode = True
kex = KexGroup1(transport) kex = KexGroup1(transport)
kex.start_kex() kex.start_kex()
self.assertEquals((paramiko.kex_group1._MSG_KEXDH_INIT,), transport._expect) self.assertEqual((paramiko.kex_group1._MSG_KEXDH_INIT,), transport._expect)
msg = Message() msg = Message()
msg.add_mpint(69) msg.add_mpint(69)
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_group1._MSG_KEXDH_INIT, msg) kex.parse_next(paramiko.kex_group1._MSG_KEXDH_INIT, msg)
H = 'B16BF34DD10945EDE84E9C1EF24A14BFDC843389' H = b'B16BF34DD10945EDE84E9C1EF24A14BFDC843389'
x = '1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' x = b'1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
self.assertEquals(self.K, transport._K) self.assertEqual(self.K, transport._K)
self.assertEquals(H, hexlify(transport._H).upper()) self.assertEqual(H, hexlify(transport._H).upper())
self.assertEquals(x, hexlify(str(transport._message)).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
self.assert_(transport._activated) self.assertTrue(transport._activated)
def test_3_gex_client(self): def test_3_gex_client(self):
transport = FakeTransport() transport = FakeTransport()
transport.server_mode = False transport.server_mode = False
kex = KexGex(transport) kex = KexGex(transport)
kex.start_kex() kex.start_kex()
x = '22000004000000080000002000' x = b'22000004000000080000002000'
self.assertEquals(x, hexlify(str(transport._message)).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect) self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect)
msg = Message() msg = Message()
msg.add_mpint(FakeModulusPack.P) msg.add_mpint(FakeModulusPack.P)
msg.add_mpint(FakeModulusPack.G) msg.add_mpint(FakeModulusPack.G)
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg)
x = '20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
self.assertEquals(x, hexlify(str(transport._message)).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect) self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect)
msg = Message() msg = Message()
msg.add_string('fake-host-key') msg.add_string('fake-host-key')
@ -147,29 +150,29 @@ class KexTest (unittest.TestCase):
msg.add_string('fake-sig') msg.add_string('fake-sig')
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg)
H = 'A265563F2FA87F1A89BF007EE90D58BE2E4A4BD0' H = b'A265563F2FA87F1A89BF007EE90D58BE2E4A4BD0'
self.assertEquals(self.K, transport._K) self.assertEqual(self.K, transport._K)
self.assertEquals(H, hexlify(transport._H).upper()) self.assertEqual(H, hexlify(transport._H).upper())
self.assertEquals(('fake-host-key', 'fake-sig'), transport._verify) self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify)
self.assert_(transport._activated) self.assertTrue(transport._activated)
def test_4_gex_old_client(self): def test_4_gex_old_client(self):
transport = FakeTransport() transport = FakeTransport()
transport.server_mode = False transport.server_mode = False
kex = KexGex(transport) kex = KexGex(transport)
kex.start_kex(_test_old_style=True) kex.start_kex(_test_old_style=True)
x = '1E00000800' x = b'1E00000800'
self.assertEquals(x, hexlify(str(transport._message)).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect) self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect)
msg = Message() msg = Message()
msg.add_mpint(FakeModulusPack.P) msg.add_mpint(FakeModulusPack.P)
msg.add_mpint(FakeModulusPack.G) msg.add_mpint(FakeModulusPack.G)
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg)
x = '20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
self.assertEquals(x, hexlify(str(transport._message)).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect) self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect)
msg = Message() msg = Message()
msg.add_string('fake-host-key') msg.add_string('fake-host-key')
@ -177,18 +180,18 @@ class KexTest (unittest.TestCase):
msg.add_string('fake-sig') msg.add_string('fake-sig')
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg)
H = '807F87B269EF7AC5EC7E75676808776A27D5864C' H = b'807F87B269EF7AC5EC7E75676808776A27D5864C'
self.assertEquals(self.K, transport._K) self.assertEqual(self.K, transport._K)
self.assertEquals(H, hexlify(transport._H).upper()) self.assertEqual(H, hexlify(transport._H).upper())
self.assertEquals(('fake-host-key', 'fake-sig'), transport._verify) self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify)
self.assert_(transport._activated) self.assertTrue(transport._activated)
def test_5_gex_server(self): def test_5_gex_server(self):
transport = FakeTransport() transport = FakeTransport()
transport.server_mode = True transport.server_mode = True
kex = KexGex(transport) kex = KexGex(transport)
kex.start_kex() kex.start_kex()
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect) self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect)
msg = Message() msg = Message()
msg.add_int(1024) msg.add_int(1024)
@ -196,45 +199,45 @@ class KexTest (unittest.TestCase):
msg.add_int(4096) msg.add_int(4096)
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, msg)
x = '1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102' x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102'
self.assertEquals(x, hexlify(str(transport._message)).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect) self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect)
msg = Message() msg = Message()
msg.add_mpint(12345) msg.add_mpint(12345)
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg)
K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581L K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581
H = 'CE754197C21BF3452863B4F44D0B3951F12516EF' H = b'CE754197C21BF3452863B4F44D0B3951F12516EF'
x = '210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
self.assertEquals(K, transport._K) self.assertEqual(K, transport._K)
self.assertEquals(H, hexlify(transport._H).upper()) self.assertEqual(H, hexlify(transport._H).upper())
self.assertEquals(x, hexlify(str(transport._message)).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
self.assert_(transport._activated) self.assertTrue(transport._activated)
def test_6_gex_server_with_old_client(self): def test_6_gex_server_with_old_client(self):
transport = FakeTransport() transport = FakeTransport()
transport.server_mode = True transport.server_mode = True
kex = KexGex(transport) kex = KexGex(transport)
kex.start_kex() kex.start_kex()
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect) self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect)
msg = Message() msg = Message()
msg.add_int(2048) msg.add_int(2048)
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, msg)
x = '1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102' x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102'
self.assertEquals(x, hexlify(str(transport._message)).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect) self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect)
msg = Message() msg = Message()
msg.add_mpint(12345) msg.add_mpint(12345)
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg)
K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581L K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581
H = 'B41A06B2E59043CEFC1AE16EC31F1E2D12EC455B' H = b'B41A06B2E59043CEFC1AE16EC31F1E2D12EC455B'
x = '210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
self.assertEquals(K, transport._K) self.assertEqual(K, transport._K)
self.assertEquals(H, hexlify(transport._H).upper()) self.assertEqual(H, hexlify(transport._H).upper())
self.assertEquals(x, hexlify(str(transport._message)).upper()) self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
self.assert_(transport._activated) self.assertTrue(transport._activated)

View File

@ -22,14 +22,15 @@ Some unit tests for ssh protocol message blocks.
import unittest import unittest
from paramiko.message import Message from paramiko.message import Message
from paramiko.common import *
class MessageTest (unittest.TestCase): class MessageTest (unittest.TestCase):
__a = '\x00\x00\x00\x17\x07\x60\xe0\x90\x00\x00\x00\x01q\x00\x00\x00\x05hello\x00\x00\x03\xe8' + ('x' * 1000) __a = b'\x00\x00\x00\x17\x07\x60\xe0\x90\x00\x00\x00\x01\x71\x00\x00\x00\x05\x68\x65\x6c\x6c\x6f\x00\x00\x03\xe8' + b'x' * 1000
__b = '\x01\x00\xf3\x00\x3f\x00\x00\x00\x10huey,dewey,louie' __b = b'\x01\x00\xf3\x00\x3f\x00\x00\x00\x10\x68\x75\x65\x79\x2c\x64\x65\x77\x65\x79\x2c\x6c\x6f\x75\x69\x65'
__c = '\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x01\x11\x00\x00\x00\x07\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x06\x9a\x1b\x2c\x3d\x4e\xf7' __c = b'\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x01\x11\x00\x00\x00\x07\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x06\x9a\x1b\x2c\x3d\x4e\xf7'
__d = '\x00\x00\x00\x05\x00\x00\x00\x05\x11\x22\x33\x44\x55\x01\x00\x00\x00\x03cat\x00\x00\x00\x03a,b' __d = b'\x00\x00\x00\x05\xff\x00\x00\x00\x05\x11\x22\x33\x44\x55\xff\x00\x00\x00\x0a\x00\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x63\x61\x74\x00\x00\x00\x03\x61\x2c\x62'
def test_1_encode(self): def test_1_encode(self):
msg = Message() msg = Message()
@ -38,63 +39,65 @@ class MessageTest (unittest.TestCase):
msg.add_string('q') msg.add_string('q')
msg.add_string('hello') msg.add_string('hello')
msg.add_string('x' * 1000) msg.add_string('x' * 1000)
self.assertEquals(str(msg), self.__a) self.assertEqual(msg.asbytes(), self.__a)
msg = Message() msg = Message()
msg.add_boolean(True) msg.add_boolean(True)
msg.add_boolean(False) msg.add_boolean(False)
msg.add_byte('\xf3') msg.add_byte(byte_chr(0xf3))
msg.add_bytes('\x00\x3f')
msg.add_bytes(zero_byte + byte_chr(0x3f))
msg.add_list(['huey', 'dewey', 'louie']) msg.add_list(['huey', 'dewey', 'louie'])
self.assertEquals(str(msg), self.__b) self.assertEqual(msg.asbytes(), self.__b)
msg = Message() msg = Message()
msg.add_int64(5) msg.add_int64(5)
msg.add_int64(0xf5e4d3c2b109L) msg.add_int64(0xf5e4d3c2b109)
msg.add_mpint(17) msg.add_mpint(17)
msg.add_mpint(0xf5e4d3c2b109L) msg.add_mpint(0xf5e4d3c2b109)
msg.add_mpint(-0x65e4d3c2b109L) msg.add_mpint(-0x65e4d3c2b109)
self.assertEquals(str(msg), self.__c) self.assertEqual(msg.asbytes(), self.__c)
def test_2_decode(self): def test_2_decode(self):
msg = Message(self.__a) msg = Message(self.__a)
self.assertEquals(msg.get_int(), 23) self.assertEqual(msg.get_int(), 23)
self.assertEquals(msg.get_int(), 123789456) self.assertEqual(msg.get_int(), 123789456)
self.assertEquals(msg.get_string(), 'q') self.assertEqual(msg.get_text(), 'q')
self.assertEquals(msg.get_string(), 'hello') self.assertEqual(msg.get_text(), 'hello')
self.assertEquals(msg.get_string(), 'x' * 1000) self.assertEqual(msg.get_text(), 'x' * 1000)
msg = Message(self.__b) msg = Message(self.__b)
self.assertEquals(msg.get_boolean(), True) self.assertEqual(msg.get_boolean(), True)
self.assertEquals(msg.get_boolean(), False) self.assertEqual(msg.get_boolean(), False)
self.assertEquals(msg.get_byte(), '\xf3') self.assertEqual(msg.get_byte(), byte_chr(0xf3))
self.assertEquals(msg.get_bytes(2), '\x00\x3f') self.assertEqual(msg.get_bytes(2), zero_byte + byte_chr(0x3f))
self.assertEquals(msg.get_list(), ['huey', 'dewey', 'louie']) self.assertEqual(msg.get_list(), ['huey', 'dewey', 'louie'])
msg = Message(self.__c) msg = Message(self.__c)
self.assertEquals(msg.get_int64(), 5) self.assertEqual(msg.get_int64(), 5)
self.assertEquals(msg.get_int64(), 0xf5e4d3c2b109L) self.assertEqual(msg.get_int64(), 0xf5e4d3c2b109)
self.assertEquals(msg.get_mpint(), 17) self.assertEqual(msg.get_mpint(), 17)
self.assertEquals(msg.get_mpint(), 0xf5e4d3c2b109L) self.assertEqual(msg.get_mpint(), 0xf5e4d3c2b109)
self.assertEquals(msg.get_mpint(), -0x65e4d3c2b109L) self.assertEqual(msg.get_mpint(), -0x65e4d3c2b109)
def test_3_add(self): def test_3_add(self):
msg = Message() msg = Message()
msg.add(5) msg.add(5)
msg.add(0x1122334455L) msg.add(0x1122334455)
msg.add(0xf00000000000000000)
msg.add(True) msg.add(True)
msg.add('cat') msg.add('cat')
msg.add(['a', 'b']) msg.add(['a', 'b'])
self.assertEquals(str(msg), self.__d) self.assertEqual(msg.asbytes(), self.__d)
def test_4_misc(self): def test_4_misc(self):
msg = Message(self.__d) msg = Message(self.__d)
self.assertEquals(msg.get_int(), 5) self.assertEqual(msg.get_int(), 5)
self.assertEquals(msg.get_mpint(), 0x1122334455L) self.assertEqual(msg.get_int(), 0x1122334455)
self.assertEquals(msg.get_so_far(), self.__d[:13]) self.assertEqual(msg.get_int(), 0xf00000000000000000)
self.assertEquals(msg.get_remainder(), self.__d[13:]) self.assertEqual(msg.get_so_far(), self.__d[:29])
self.assertEqual(msg.get_remainder(), self.__d[29:])
msg.rewind() msg.rewind()
self.assertEquals(msg.get_int(), 5) self.assertEqual(msg.get_int(), 5)
self.assertEquals(msg.get_so_far(), self.__d[:4]) self.assertEqual(msg.get_so_far(), self.__d[:4])
self.assertEquals(msg.get_remainder(), self.__d[4:]) self.assertEqual(msg.get_remainder(), self.__d[4:])

View File

@ -21,10 +21,15 @@ Some unit tests for the ssh2 protocol in Transport.
""" """
import unittest import unittest
from loop import LoopSocket from tests.loop import LoopSocket
from Crypto.Cipher import AES from Crypto.Cipher import AES
from Crypto.Hash import SHA, HMAC from Crypto.Hash import SHA, HMAC
from paramiko import Message, Packetizer, util from paramiko import Message, Packetizer, util
from paramiko.common import *
x55 = byte_chr(0x55)
x1f = byte_chr(0x1f)
class PacketizerTest (unittest.TestCase): class PacketizerTest (unittest.TestCase):
@ -35,21 +40,21 @@ class PacketizerTest (unittest.TestCase):
p = Packetizer(wsock) p = Packetizer(wsock)
p.set_log(util.get_logger('paramiko.transport')) p.set_log(util.get_logger('paramiko.transport'))
p.set_hexdump(True) p.set_hexdump(True)
cipher = AES.new('\x00' * 16, AES.MODE_CBC, '\x55' * 16) cipher = AES.new(zero_byte * 16, AES.MODE_CBC, x55 * 16)
p.set_outbound_cipher(cipher, 16, SHA, 12, '\x1f' * 20) p.set_outbound_cipher(cipher, 16, SHA, 12, x1f * 20)
# message has to be at least 16 bytes long, so we'll have at least one # message has to be at least 16 bytes long, so we'll have at least one
# block of data encrypted that contains zero random padding bytes # block of data encrypted that contains zero random padding bytes
m = Message() m = Message()
m.add_byte(chr(100)) m.add_byte(byte_chr(100))
m.add_int(100) m.add_int(100)
m.add_int(1) m.add_int(1)
m.add_int(900) m.add_int(900)
p.send_message(m) p.send_message(m)
data = rsock.recv(100) data = rsock.recv(100)
# 32 + 12 bytes of MAC = 44 # 32 + 12 bytes of MAC = 44
self.assertEquals(44, len(data)) self.assertEqual(44, len(data))
self.assertEquals('\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0', data[:16]) self.assertEqual(b'\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0', data[:16])
def test_2_read (self): def test_2_read (self):
rsock = LoopSocket() rsock = LoopSocket()
@ -58,13 +63,11 @@ class PacketizerTest (unittest.TestCase):
p = Packetizer(rsock) p = Packetizer(rsock)
p.set_log(util.get_logger('paramiko.transport')) p.set_log(util.get_logger('paramiko.transport'))
p.set_hexdump(True) p.set_hexdump(True)
cipher = AES.new('\x00' * 16, AES.MODE_CBC, '\x55' * 16) cipher = AES.new(zero_byte * 16, AES.MODE_CBC, x55 * 16)
p.set_inbound_cipher(cipher, 16, SHA, 12, '\x1f' * 20) p.set_inbound_cipher(cipher, 16, SHA, 12, x1f * 20)
wsock.send(b'\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0\x90\xd2\x16\x56\x0d\x71\x73\x61\x38\x7c\x4c\x3d\xfb\x97\x7d\xe2\x6e\x03\xb1\xa0\xc2\x1c\xd6\x41\x41\x4c\xb4\x59')
wsock.send('C\x91\x97\xbd[P\xac%\x87\xc2\xc4k\xc7\xe98\xc0' + \
'\x90\xd2\x16V\rqsa8|L=\xfb\x97}\xe2n\x03\xb1\xa0\xc2\x1c\xd6AAL\xb4Y')
cmd, m = p.read_message() cmd, m = p.read_message()
self.assertEquals(100, cmd) self.assertEqual(100, cmd)
self.assertEquals(100, m.get_int()) self.assertEqual(100, m.get_int())
self.assertEquals(1, m.get_int()) self.assertEqual(1, m.get_int())
self.assertEquals(900, m.get_int()) self.assertEqual(900, m.get_int())

View File

@ -20,11 +20,11 @@
Some unit tests for public/private key objects. Some unit tests for public/private key objects.
""" """
from binascii import hexlify, unhexlify from binascii import hexlify
import StringIO
import unittest import unittest
from paramiko import RSAKey, DSSKey, ECDSAKey, Message, util from paramiko import RSAKey, DSSKey, ECDSAKey, Message, util
from paramiko.common import rng from paramiko.common import rng, StringIO, byte_chr, b, bytes
from tests.util import test_path
# from openssh's ssh-keygen # from openssh's ssh-keygen
PUB_RSA = 'ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEA049W6geFpmsljTwfvI1UmKWWJPNFI74+vNKTk4dmzkQY2yAMs6FhlvhlI8ysU4oj71ZsRYMecHbBbxdN79+JRFVYTKaLqjwGENeTd+yv4q+V2PvZv3fLnzApI3l7EJCqhWwJUHJ1jAkZzqDx0tyOL4uoZpww3nmE0kb3y21tH4c=' PUB_RSA = 'ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEA049W6geFpmsljTwfvI1UmKWWJPNFI74+vNKTk4dmzkQY2yAMs6FhlvhlI8ysU4oj71ZsRYMecHbBbxdN79+JRFVYTKaLqjwGENeTd+yv4q+V2PvZv3fLnzApI3l7EJCqhWwJUHJ1jAkZzqDx0tyOL4uoZpww3nmE0kb3y21tH4c='
@ -77,6 +77,9 @@ ADRvOqQ5R98Sxst765CAqXmRtz8vwoD96g==
-----END EC PRIVATE KEY----- -----END EC PRIVATE KEY-----
""" """
x1234 = b'\x01\x02\x03\x04'
class KeyTest (unittest.TestCase): class KeyTest (unittest.TestCase):
def setUp(self): def setUp(self):
@ -87,164 +90,164 @@ class KeyTest (unittest.TestCase):
def test_1_generate_key_bytes(self): def test_1_generate_key_bytes(self):
from Crypto.Hash import MD5 from Crypto.Hash import MD5
key = util.generate_key_bytes(MD5, '\x01\x02\x03\x04', 'happy birthday', 30) key = util.generate_key_bytes(MD5, x1234, 'happy birthday', 30)
exp = unhexlify('61E1F272F4C1C4561586BD322498C0E924672780F47BB37DDA7D54019E64') exp = b'\x61\xE1\xF2\x72\xF4\xC1\xC4\x56\x15\x86\xBD\x32\x24\x98\xC0\xE9\x24\x67\x27\x80\xF4\x7B\xB3\x7D\xDA\x7D\x54\x01\x9E\x64'
self.assertEquals(exp, key) self.assertEqual(exp, key)
def test_2_load_rsa(self): def test_2_load_rsa(self):
key = RSAKey.from_private_key_file('tests/test_rsa.key') key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
self.assertEquals('ssh-rsa', key.get_name()) self.assertEqual('ssh-rsa', key.get_name())
exp_rsa = FINGER_RSA.split()[1].replace(':', '') exp_rsa = b(FINGER_RSA.split()[1].replace(':', ''))
my_rsa = hexlify(key.get_fingerprint()) my_rsa = hexlify(key.get_fingerprint())
self.assertEquals(exp_rsa, my_rsa) self.assertEqual(exp_rsa, my_rsa)
self.assertEquals(PUB_RSA.split()[1], key.get_base64()) self.assertEqual(PUB_RSA.split()[1], key.get_base64())
self.assertEquals(1024, key.get_bits()) self.assertEqual(1024, key.get_bits())
s = StringIO.StringIO() s = StringIO()
key.write_private_key(s) key.write_private_key(s)
self.assertEquals(RSA_PRIVATE_OUT, s.getvalue()) self.assertEqual(RSA_PRIVATE_OUT, s.getvalue())
s.seek(0) s.seek(0)
key2 = RSAKey.from_private_key(s) key2 = RSAKey.from_private_key(s)
self.assertEquals(key, key2) self.assertEqual(key, key2)
def test_3_load_rsa_password(self): def test_3_load_rsa_password(self):
key = RSAKey.from_private_key_file('tests/test_rsa_password.key', 'television') key = RSAKey.from_private_key_file(test_path('test_rsa_password.key'), 'television')
self.assertEquals('ssh-rsa', key.get_name()) self.assertEqual('ssh-rsa', key.get_name())
exp_rsa = FINGER_RSA.split()[1].replace(':', '') exp_rsa = b(FINGER_RSA.split()[1].replace(':', ''))
my_rsa = hexlify(key.get_fingerprint()) my_rsa = hexlify(key.get_fingerprint())
self.assertEquals(exp_rsa, my_rsa) self.assertEqual(exp_rsa, my_rsa)
self.assertEquals(PUB_RSA.split()[1], key.get_base64()) self.assertEqual(PUB_RSA.split()[1], key.get_base64())
self.assertEquals(1024, key.get_bits()) self.assertEqual(1024, key.get_bits())
def test_4_load_dss(self): def test_4_load_dss(self):
key = DSSKey.from_private_key_file('tests/test_dss.key') key = DSSKey.from_private_key_file(test_path('test_dss.key'))
self.assertEquals('ssh-dss', key.get_name()) self.assertEqual('ssh-dss', key.get_name())
exp_dss = FINGER_DSS.split()[1].replace(':', '') exp_dss = b(FINGER_DSS.split()[1].replace(':', ''))
my_dss = hexlify(key.get_fingerprint()) my_dss = hexlify(key.get_fingerprint())
self.assertEquals(exp_dss, my_dss) self.assertEqual(exp_dss, my_dss)
self.assertEquals(PUB_DSS.split()[1], key.get_base64()) self.assertEqual(PUB_DSS.split()[1], key.get_base64())
self.assertEquals(1024, key.get_bits()) self.assertEqual(1024, key.get_bits())
s = StringIO.StringIO() s = StringIO()
key.write_private_key(s) key.write_private_key(s)
self.assertEquals(DSS_PRIVATE_OUT, s.getvalue()) self.assertEqual(DSS_PRIVATE_OUT, s.getvalue())
s.seek(0) s.seek(0)
key2 = DSSKey.from_private_key(s) key2 = DSSKey.from_private_key(s)
self.assertEquals(key, key2) self.assertEqual(key, key2)
def test_5_load_dss_password(self): def test_5_load_dss_password(self):
key = DSSKey.from_private_key_file('tests/test_dss_password.key', 'television') key = DSSKey.from_private_key_file(test_path('test_dss_password.key'), 'television')
self.assertEquals('ssh-dss', key.get_name()) self.assertEqual('ssh-dss', key.get_name())
exp_dss = FINGER_DSS.split()[1].replace(':', '') exp_dss = b(FINGER_DSS.split()[1].replace(':', ''))
my_dss = hexlify(key.get_fingerprint()) my_dss = hexlify(key.get_fingerprint())
self.assertEquals(exp_dss, my_dss) self.assertEqual(exp_dss, my_dss)
self.assertEquals(PUB_DSS.split()[1], key.get_base64()) self.assertEqual(PUB_DSS.split()[1], key.get_base64())
self.assertEquals(1024, key.get_bits()) self.assertEqual(1024, key.get_bits())
def test_6_compare_rsa(self): def test_6_compare_rsa(self):
# verify that the private & public keys compare equal # verify that the private & public keys compare equal
key = RSAKey.from_private_key_file('tests/test_rsa.key') key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
self.assertEquals(key, key) self.assertEqual(key, key)
pub = RSAKey(data=str(key)) pub = RSAKey(data=key.asbytes())
self.assert_(key.can_sign()) self.assertTrue(key.can_sign())
self.assert_(not pub.can_sign()) self.assertTrue(not pub.can_sign())
self.assertEquals(key, pub) self.assertEqual(key, pub)
def test_7_compare_dss(self): def test_7_compare_dss(self):
# verify that the private & public keys compare equal # verify that the private & public keys compare equal
key = DSSKey.from_private_key_file('tests/test_dss.key') key = DSSKey.from_private_key_file(test_path('test_dss.key'))
self.assertEquals(key, key) self.assertEqual(key, key)
pub = DSSKey(data=str(key)) pub = DSSKey(data=key.asbytes())
self.assert_(key.can_sign()) self.assertTrue(key.can_sign())
self.assert_(not pub.can_sign()) self.assertTrue(not pub.can_sign())
self.assertEquals(key, pub) self.assertEqual(key, pub)
def test_8_sign_rsa(self): def test_8_sign_rsa(self):
# verify that the rsa private key can sign and verify # verify that the rsa private key can sign and verify
key = RSAKey.from_private_key_file('tests/test_rsa.key') key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
msg = key.sign_ssh_data(rng, 'ice weasels') msg = key.sign_ssh_data(rng, b'ice weasels')
self.assert_(type(msg) is Message) self.assertTrue(type(msg) is Message)
msg.rewind() msg.rewind()
self.assertEquals('ssh-rsa', msg.get_string()) self.assertEqual('ssh-rsa', msg.get_text())
sig = ''.join([chr(int(x, 16)) for x in SIGNED_RSA.split(':')]) sig = bytes().join([byte_chr(int(x, 16)) for x in SIGNED_RSA.split(':')])
self.assertEquals(sig, msg.get_string()) self.assertEqual(sig, msg.get_binary())
msg.rewind() msg.rewind()
pub = RSAKey(data=str(key)) pub = RSAKey(data=key.asbytes())
self.assert_(pub.verify_ssh_sig('ice weasels', msg)) self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg))
def test_9_sign_dss(self): def test_9_sign_dss(self):
# verify that the dss private key can sign and verify # verify that the dss private key can sign and verify
key = DSSKey.from_private_key_file('tests/test_dss.key') key = DSSKey.from_private_key_file(test_path('test_dss.key'))
msg = key.sign_ssh_data(rng, 'ice weasels') msg = key.sign_ssh_data(rng, b'ice weasels')
self.assert_(type(msg) is Message) self.assertTrue(type(msg) is Message)
msg.rewind() msg.rewind()
self.assertEquals('ssh-dss', msg.get_string()) self.assertEqual('ssh-dss', msg.get_text())
# can't do the same test as we do for RSA, because DSS signatures # can't do the same test as we do for RSA, because DSS signatures
# are usually different each time. but we can test verification # are usually different each time. but we can test verification
# anyway so it's ok. # anyway so it's ok.
self.assertEquals(40, len(msg.get_string())) self.assertEqual(40, len(msg.get_binary()))
msg.rewind() msg.rewind()
pub = DSSKey(data=str(key)) pub = DSSKey(data=key.asbytes())
self.assert_(pub.verify_ssh_sig('ice weasels', msg)) self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg))
def test_A_generate_rsa(self): def test_A_generate_rsa(self):
key = RSAKey.generate(1024) key = RSAKey.generate(1024)
msg = key.sign_ssh_data(rng, 'jerri blank') msg = key.sign_ssh_data(rng, b'jerri blank')
msg.rewind() msg.rewind()
self.assert_(key.verify_ssh_sig('jerri blank', msg)) self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg))
def test_B_generate_dss(self): def test_B_generate_dss(self):
key = DSSKey.generate(1024) key = DSSKey.generate(1024)
msg = key.sign_ssh_data(rng, 'jerri blank') msg = key.sign_ssh_data(rng, b'jerri blank')
msg.rewind() msg.rewind()
self.assert_(key.verify_ssh_sig('jerri blank', msg)) self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg))
def test_10_load_ecdsa(self): def test_10_load_ecdsa(self):
key = ECDSAKey.from_private_key_file('tests/test_ecdsa.key') key = ECDSAKey.from_private_key_file(test_path('test_ecdsa.key'))
self.assertEquals('ecdsa-sha2-nistp256', key.get_name()) self.assertEqual('ecdsa-sha2-nistp256', key.get_name())
exp_ecdsa = FINGER_ECDSA.split()[1].replace(':', '') exp_ecdsa = b(FINGER_ECDSA.split()[1].replace(':', ''))
my_ecdsa = hexlify(key.get_fingerprint()) my_ecdsa = hexlify(key.get_fingerprint())
self.assertEquals(exp_ecdsa, my_ecdsa) self.assertEqual(exp_ecdsa, my_ecdsa)
self.assertEquals(PUB_ECDSA.split()[1], key.get_base64()) self.assertEqual(PUB_ECDSA.split()[1], key.get_base64())
self.assertEquals(256, key.get_bits()) self.assertEqual(256, key.get_bits())
s = StringIO.StringIO() s = StringIO()
key.write_private_key(s) key.write_private_key(s)
self.assertEquals(ECDSA_PRIVATE_OUT, s.getvalue()) self.assertEqual(ECDSA_PRIVATE_OUT, s.getvalue())
s.seek(0) s.seek(0)
key2 = ECDSAKey.from_private_key(s) key2 = ECDSAKey.from_private_key(s)
self.assertEquals(key, key2) self.assertEqual(key, key2)
def test_11_load_ecdsa_password(self): def test_11_load_ecdsa_password(self):
key = ECDSAKey.from_private_key_file('tests/test_ecdsa_password.key', 'television') key = ECDSAKey.from_private_key_file(test_path('test_ecdsa_password.key'), b'television')
self.assertEquals('ecdsa-sha2-nistp256', key.get_name()) self.assertEqual('ecdsa-sha2-nistp256', key.get_name())
exp_ecdsa = FINGER_ECDSA.split()[1].replace(':', '') exp_ecdsa = b(FINGER_ECDSA.split()[1].replace(':', ''))
my_ecdsa = hexlify(key.get_fingerprint()) my_ecdsa = hexlify(key.get_fingerprint())
self.assertEquals(exp_ecdsa, my_ecdsa) self.assertEqual(exp_ecdsa, my_ecdsa)
self.assertEquals(PUB_ECDSA.split()[1], key.get_base64()) self.assertEqual(PUB_ECDSA.split()[1], key.get_base64())
self.assertEquals(256, key.get_bits()) self.assertEqual(256, key.get_bits())
def test_12_compare_ecdsa(self): def test_12_compare_ecdsa(self):
# verify that the private & public keys compare equal # verify that the private & public keys compare equal
key = ECDSAKey.from_private_key_file('tests/test_ecdsa.key') key = ECDSAKey.from_private_key_file(test_path('test_ecdsa.key'))
self.assertEquals(key, key) self.assertEqual(key, key)
pub = ECDSAKey(data=str(key)) pub = ECDSAKey(data=key.asbytes())
self.assert_(key.can_sign()) self.assertTrue(key.can_sign())
self.assert_(not pub.can_sign()) self.assertTrue(not pub.can_sign())
self.assertEquals(key, pub) self.assertEqual(key, pub)
def test_13_sign_ecdsa(self): def test_13_sign_ecdsa(self):
# verify that the rsa private key can sign and verify # verify that the rsa private key can sign and verify
key = ECDSAKey.from_private_key_file('tests/test_ecdsa.key') key = ECDSAKey.from_private_key_file(test_path('test_ecdsa.key'))
msg = key.sign_ssh_data(rng, 'ice weasels') msg = key.sign_ssh_data(rng, b'ice weasels')
self.assert_(type(msg) is Message) self.assertTrue(type(msg) is Message)
msg.rewind() msg.rewind()
self.assertEquals('ecdsa-sha2-nistp256', msg.get_string()) self.assertEqual('ecdsa-sha2-nistp256', msg.get_text())
# ECDSA signatures, like DSS signatures, tend to be different # ECDSA signatures, like DSS signatures, tend to be different
# each time, so we can't compare against a "known correct" # each time, so we can't compare against a "known correct"
# signature. # signature.
# Even the length of the signature can change. # Even the length of the signature can change.
msg.rewind() msg.rewind()
pub = ECDSAKey(data=str(key)) pub = ECDSAKey(data=key.asbytes())
self.assert_(pub.verify_ssh_sig('ice weasels', msg)) self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg))

View File

@ -23,19 +23,18 @@ a real actual sftp server is contacted, and a new folder is created there to
do test file operations in (so no existing files will be harmed). do test file operations in (so no existing files will be harmed).
""" """
from __future__ import with_statement
from binascii import hexlify from binascii import hexlify
import os import os
import warnings import warnings
import sys
import threading import threading
import unittest import unittest
import StringIO from tempfile import mkstemp
import paramiko import paramiko
from stub_sftp import StubServer, StubSFTPServer from paramiko.common import *
from loop import LoopSocket from tests.stub_sftp import StubServer, StubSFTPServer
from tests.loop import LoopSocket
from tests.util import test_path
from paramiko.sftp_attr import SFTPAttributes from paramiko.sftp_attr import SFTPAttributes
ARTICLE = ''' ARTICLE = '''
@ -70,6 +69,10 @@ FOLDER = os.environ.get('TEST_FOLDER', 'temp-testing000')
sftp = None sftp = None
tc = None tc = None
g_big_file_test = True g_big_file_test = True
# we need to use eval(compile()) here because Py3.2 doesn't support the 'u' marker for unicode
# this test is the only line in the entire program that has to be treated specially to support Py3.2
unicode_folder = eval(compile(r"u'\u00fcnic\u00f8de'" if PY2 else r"'\u00fcnic\u00f8de'", 'test_sftp.py', 'eval'))
utf8_folder = b'/\xc3\xbcnic\xc3\xb8\x64\x65'
def get_sftp(): def get_sftp():
@ -121,7 +124,7 @@ class SFTPTest (unittest.TestCase):
tc = paramiko.Transport(sockc) tc = paramiko.Transport(sockc)
ts = paramiko.Transport(socks) ts = paramiko.Transport(socks)
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key') host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
ts.add_server_key(host_key) ts.add_server_key(host_key)
event = threading.Event() event = threading.Event()
server = StubServer() server = StubServer()
@ -140,7 +143,7 @@ class SFTPTest (unittest.TestCase):
def setUp(self): def setUp(self):
global FOLDER global FOLDER
for i in xrange(1000): for i in range(1000):
FOLDER = FOLDER[:-3] + '%03d' % i FOLDER = FOLDER[:-3] + '%03d' % i
try: try:
sftp.mkdir(FOLDER) sftp.mkdir(FOLDER)
@ -149,6 +152,7 @@ class SFTPTest (unittest.TestCase):
pass pass
def tearDown(self): def tearDown(self):
#sftp.chdir()
sftp.rmdir(FOLDER) sftp.rmdir(FOLDER)
def test_1_file(self): def test_1_file(self):
@ -158,8 +162,8 @@ class SFTPTest (unittest.TestCase):
f = sftp.open(FOLDER + '/test', 'w') f = sftp.open(FOLDER + '/test', 'w')
try: try:
self.assertEqual(f.stat().st_size, 0) self.assertEqual(f.stat().st_size, 0)
f.close()
finally: finally:
f.close()
sftp.remove(FOLDER + '/test') sftp.remove(FOLDER + '/test')
def test_2_close(self): def test_2_close(self):
@ -180,10 +184,9 @@ class SFTPTest (unittest.TestCase):
""" """
verify that a file can be created and written, and the size is correct. verify that a file can be created and written, and the size is correct.
""" """
f = sftp.open(FOLDER + '/duck.txt', 'w')
try: try:
f.write(ARTICLE) with sftp.open(FOLDER + '/duck.txt', 'w') as f:
f.close() f.write(ARTICLE)
self.assertEqual(sftp.stat(FOLDER + '/duck.txt').st_size, 1483) self.assertEqual(sftp.stat(FOLDER + '/duck.txt').st_size, 1483)
finally: finally:
sftp.remove(FOLDER + '/duck.txt') sftp.remove(FOLDER + '/duck.txt')
@ -203,19 +206,17 @@ class SFTPTest (unittest.TestCase):
""" """
verify that a file can be opened for append, and tell() still works. verify that a file can be opened for append, and tell() still works.
""" """
f = sftp.open(FOLDER + '/append.txt', 'w')
try: try:
f.write('first line\nsecond line\n') with sftp.open(FOLDER + '/append.txt', 'w') as f:
self.assertEqual(f.tell(), 23) f.write('first line\nsecond line\n')
f.close() self.assertEqual(f.tell(), 23)
f = sftp.open(FOLDER + '/append.txt', 'a+') with sftp.open(FOLDER + '/append.txt', 'a+') as f:
f.write('third line!!!\n') f.write('third line!!!\n')
self.assertEqual(f.tell(), 37) self.assertEqual(f.tell(), 37)
self.assertEqual(f.stat().st_size, 37) self.assertEqual(f.stat().st_size, 37)
f.seek(-26, f.SEEK_CUR) f.seek(-26, f.SEEK_CUR)
self.assertEqual(f.readline(), 'second line\n') self.assertEqual(f.readline(), 'second line\n')
f.close()
finally: finally:
sftp.remove(FOLDER + '/append.txt') sftp.remove(FOLDER + '/append.txt')
@ -223,20 +224,18 @@ class SFTPTest (unittest.TestCase):
""" """
verify that renaming a file works. verify that renaming a file works.
""" """
f = sftp.open(FOLDER + '/first.txt', 'w')
try: try:
f.write('content!\n') with sftp.open(FOLDER + '/first.txt', 'w') as f:
f.close() f.write('content!\n')
sftp.rename(FOLDER + '/first.txt', FOLDER + '/second.txt') sftp.rename(FOLDER + '/first.txt', FOLDER + '/second.txt')
try: try:
f = sftp.open(FOLDER + '/first.txt', 'r') sftp.open(FOLDER + '/first.txt', 'r')
self.assert_(False, 'no exception on reading nonexistent file') self.assertTrue(False, 'no exception on reading nonexistent file')
except IOError: except IOError:
pass pass
f = sftp.open(FOLDER + '/second.txt', 'r') with sftp.open(FOLDER + '/second.txt', 'r') as f:
f.seek(-6, f.SEEK_END) f.seek(-6, f.SEEK_END)
self.assertEqual(f.read(4), 'tent') self.assertEqual(u(f.read(4)), 'tent')
f.close()
finally: finally:
try: try:
sftp.remove(FOLDER + '/first.txt') sftp.remove(FOLDER + '/first.txt')
@ -253,14 +252,13 @@ class SFTPTest (unittest.TestCase):
remove the folder and verify that we can't create a file in it anymore. remove the folder and verify that we can't create a file in it anymore.
""" """
sftp.mkdir(FOLDER + '/subfolder') sftp.mkdir(FOLDER + '/subfolder')
f = sftp.open(FOLDER + '/subfolder/test', 'w') sftp.open(FOLDER + '/subfolder/test', 'w').close()
f.close()
sftp.remove(FOLDER + '/subfolder/test') sftp.remove(FOLDER + '/subfolder/test')
sftp.rmdir(FOLDER + '/subfolder') sftp.rmdir(FOLDER + '/subfolder')
try: try:
f = sftp.open(FOLDER + '/subfolder/test') sftp.open(FOLDER + '/subfolder/test')
# shouldn't be able to create that file # shouldn't be able to create that file
self.assert_(False, 'no exception at dummy file creation') self.assertTrue(False, 'no exception at dummy file creation')
except IOError: except IOError:
pass pass
@ -270,21 +268,16 @@ class SFTPTest (unittest.TestCase):
and those files show up in sftp.listdir. and those files show up in sftp.listdir.
""" """
try: try:
f = sftp.open(FOLDER + '/duck.txt', 'w') sftp.open(FOLDER + '/duck.txt', 'w').close()
f.close() sftp.open(FOLDER + '/fish.txt', 'w').close()
sftp.open(FOLDER + '/tertiary.py', 'w').close()
f = sftp.open(FOLDER + '/fish.txt', 'w')
f.close()
f = sftp.open(FOLDER + '/tertiary.py', 'w')
f.close()
x = sftp.listdir(FOLDER) x = sftp.listdir(FOLDER)
self.assertEqual(len(x), 3) self.assertEqual(len(x), 3)
self.assert_('duck.txt' in x) self.assertTrue('duck.txt' in x)
self.assert_('fish.txt' in x) self.assertTrue('fish.txt' in x)
self.assert_('tertiary.py' in x) self.assertTrue('tertiary.py' in x)
self.assert_('random' not in x) self.assertTrue('random' not in x)
finally: finally:
sftp.remove(FOLDER + '/duck.txt') sftp.remove(FOLDER + '/duck.txt')
sftp.remove(FOLDER + '/fish.txt') sftp.remove(FOLDER + '/fish.txt')
@ -294,22 +287,21 @@ class SFTPTest (unittest.TestCase):
""" """
verify that the setstat functions (chown, chmod, utime, truncate) work. verify that the setstat functions (chown, chmod, utime, truncate) work.
""" """
f = sftp.open(FOLDER + '/special', 'w')
try: try:
f.write('x' * 1024) with sftp.open(FOLDER + '/special', 'w') as f:
f.close() f.write('x' * 1024)
stat = sftp.stat(FOLDER + '/special') stat = sftp.stat(FOLDER + '/special')
sftp.chmod(FOLDER + '/special', (stat.st_mode & ~0777) | 0600) sftp.chmod(FOLDER + '/special', (stat.st_mode & ~o777) | o600)
stat = sftp.stat(FOLDER + '/special') stat = sftp.stat(FOLDER + '/special')
expected_mode = 0600 expected_mode = o600
if sys.platform == 'win32': if sys.platform == 'win32':
# chmod not really functional on windows # chmod not really functional on windows
expected_mode = 0666 expected_mode = o666
if sys.platform == 'cygwin': if sys.platform == 'cygwin':
# even worse. # even worse.
expected_mode = 0644 expected_mode = o644
self.assertEqual(stat.st_mode & 0777, expected_mode) self.assertEqual(stat.st_mode & o777, expected_mode)
self.assertEqual(stat.st_size, 1024) self.assertEqual(stat.st_size, 1024)
mtime = stat.st_mtime - 3600 mtime = stat.st_mtime - 3600
@ -333,40 +325,38 @@ class SFTPTest (unittest.TestCase):
verify that the fsetstat functions (chown, chmod, utime, truncate) verify that the fsetstat functions (chown, chmod, utime, truncate)
work on open files. work on open files.
""" """
f = sftp.open(FOLDER + '/special', 'w')
try: try:
f.write('x' * 1024) with sftp.open(FOLDER + '/special', 'w') as f:
f.close() f.write('x' * 1024)
f = sftp.open(FOLDER + '/special', 'r+') with sftp.open(FOLDER + '/special', 'r+') as f:
stat = f.stat() stat = f.stat()
f.chmod((stat.st_mode & ~0777) | 0600) f.chmod((stat.st_mode & ~o777) | o600)
stat = f.stat() stat = f.stat()
expected_mode = 0600 expected_mode = o600
if sys.platform == 'win32': if sys.platform == 'win32':
# chmod not really functional on windows # chmod not really functional on windows
expected_mode = 0666 expected_mode = o666
if sys.platform == 'cygwin': if sys.platform == 'cygwin':
# even worse. # even worse.
expected_mode = 0644 expected_mode = o644
self.assertEqual(stat.st_mode & 0777, expected_mode) self.assertEqual(stat.st_mode & o777, expected_mode)
self.assertEqual(stat.st_size, 1024) self.assertEqual(stat.st_size, 1024)
mtime = stat.st_mtime - 3600 mtime = stat.st_mtime - 3600
atime = stat.st_atime - 1800 atime = stat.st_atime - 1800
f.utime((atime, mtime)) f.utime((atime, mtime))
stat = f.stat() stat = f.stat()
self.assertEqual(stat.st_mtime, mtime) self.assertEqual(stat.st_mtime, mtime)
if sys.platform not in ('win32', 'cygwin'): if sys.platform not in ('win32', 'cygwin'):
self.assertEqual(stat.st_atime, atime) self.assertEqual(stat.st_atime, atime)
# can't really test chown, since we'd have to know a valid uid. # can't really test chown, since we'd have to know a valid uid.
f.truncate(512) f.truncate(512)
stat = f.stat() stat = f.stat()
self.assertEqual(stat.st_size, 512) self.assertEqual(stat.st_size, 512)
f.close()
finally: finally:
sftp.remove(FOLDER + '/special') sftp.remove(FOLDER + '/special')
@ -378,25 +368,23 @@ class SFTPTest (unittest.TestCase):
buffering is reset on 'seek'. buffering is reset on 'seek'.
""" """
try: try:
f = sftp.open(FOLDER + '/duck.txt', 'w') with sftp.open(FOLDER + '/duck.txt', 'w') as f:
f.write(ARTICLE) f.write(ARTICLE)
f.close()
f = sftp.open(FOLDER + '/duck.txt', 'r+') with sftp.open(FOLDER + '/duck.txt', 'r+') as f:
line_number = 0 line_number = 0
loc = 0 loc = 0
pos_list = [] pos_list = []
for line in f: for line in f:
line_number += 1 line_number += 1
pos_list.append(loc) pos_list.append(loc)
loc = f.tell() loc = f.tell()
f.seek(pos_list[6], f.SEEK_SET) f.seek(pos_list[6], f.SEEK_SET)
self.assertEqual(f.readline(), 'Nouzilly, France.\n') self.assertEqual(f.readline(), 'Nouzilly, France.\n')
f.seek(pos_list[17], f.SEEK_SET) f.seek(pos_list[17], f.SEEK_SET)
self.assertEqual(f.readline()[:4], 'duck') self.assertEqual(f.readline()[:4], 'duck')
f.seek(pos_list[10], f.SEEK_SET) f.seek(pos_list[10], f.SEEK_SET)
self.assertEqual(f.readline(), 'duck types were equally resistant to exogenous insulin compared with chicken.\n') self.assertEqual(f.readline(), 'duck types were equally resistant to exogenous insulin compared with chicken.\n')
f.close()
finally: finally:
sftp.remove(FOLDER + '/duck.txt') sftp.remove(FOLDER + '/duck.txt')
@ -405,17 +393,15 @@ class SFTPTest (unittest.TestCase):
create a text file, seek back and change part of it, and verify that the create a text file, seek back and change part of it, and verify that the
changes worked. changes worked.
""" """
f = sftp.open(FOLDER + '/testing.txt', 'w')
try: try:
f.write('hello kitty.\n') with sftp.open(FOLDER + '/testing.txt', 'w') as f:
f.seek(-5, f.SEEK_CUR) f.write('hello kitty.\n')
f.write('dd') f.seek(-5, f.SEEK_CUR)
f.close() f.write('dd')
self.assertEqual(sftp.stat(FOLDER + '/testing.txt').st_size, 13) self.assertEqual(sftp.stat(FOLDER + '/testing.txt').st_size, 13)
f = sftp.open(FOLDER + '/testing.txt', 'r') with sftp.open(FOLDER + '/testing.txt', 'r') as f:
data = f.read(20) data = f.read(20)
f.close()
self.assertEqual(data, 'hello kiddy.\n') self.assertEqual(data, 'hello kiddy.\n')
finally: finally:
sftp.remove(FOLDER + '/testing.txt') sftp.remove(FOLDER + '/testing.txt')
@ -428,16 +414,14 @@ class SFTPTest (unittest.TestCase):
# skip symlink tests on windows # skip symlink tests on windows
return return
f = sftp.open(FOLDER + '/original.txt', 'w')
try: try:
f.write('original\n') with sftp.open(FOLDER + '/original.txt', 'w') as f:
f.close() f.write('original\n')
sftp.symlink('original.txt', FOLDER + '/link.txt') sftp.symlink('original.txt', FOLDER + '/link.txt')
self.assertEqual(sftp.readlink(FOLDER + '/link.txt'), 'original.txt') self.assertEqual(sftp.readlink(FOLDER + '/link.txt'), 'original.txt')
f = sftp.open(FOLDER + '/link.txt', 'r') with sftp.open(FOLDER + '/link.txt', 'r') as f:
self.assertEqual(f.readlines(), ['original\n']) self.assertEqual(f.readlines(), ['original\n'])
f.close()
cwd = sftp.normalize('.') cwd = sftp.normalize('.')
if cwd[-1] == '/': if cwd[-1] == '/':
@ -450,7 +434,7 @@ class SFTPTest (unittest.TestCase):
self.assertEqual(sftp.stat(FOLDER + '/link.txt').st_size, 9) self.assertEqual(sftp.stat(FOLDER + '/link.txt').st_size, 9)
# the sftp server may be hiding extra path members from us, so the # the sftp server may be hiding extra path members from us, so the
# length may be longer than we expect: # length may be longer than we expect:
self.assert_(sftp.lstat(FOLDER + '/link2.txt').st_size >= len(abs_path)) self.assertTrue(sftp.lstat(FOLDER + '/link2.txt').st_size >= len(abs_path))
self.assertEqual(sftp.stat(FOLDER + '/link2.txt').st_size, 9) self.assertEqual(sftp.stat(FOLDER + '/link2.txt').st_size, 9)
self.assertEqual(sftp.stat(FOLDER + '/original.txt').st_size, 9) self.assertEqual(sftp.stat(FOLDER + '/original.txt').st_size, 9)
finally: finally:
@ -471,18 +455,16 @@ class SFTPTest (unittest.TestCase):
""" """
verify that buffered writes are automatically flushed on seek. verify that buffered writes are automatically flushed on seek.
""" """
f = sftp.open(FOLDER + '/happy.txt', 'w', 1)
try: try:
f.write('full line.\n') with sftp.open(FOLDER + '/happy.txt', 'w', 1) as f:
f.write('partial') f.write('full line.\n')
f.seek(9, f.SEEK_SET) f.write('partial')
f.write('?\n') f.seek(9, f.SEEK_SET)
f.close() f.write('?\n')
f = sftp.open(FOLDER + '/happy.txt', 'r') with sftp.open(FOLDER + '/happy.txt', 'r') as f:
self.assertEqual(f.readline(), 'full line?\n') self.assertEqual(f.readline(), 'full line?\n')
self.assertEqual(f.read(7), 'partial') self.assertEqual(f.read(7), 'partial')
f.close()
finally: finally:
try: try:
sftp.remove(FOLDER + '/happy.txt') sftp.remove(FOLDER + '/happy.txt')
@ -495,10 +477,10 @@ class SFTPTest (unittest.TestCase):
error. error.
""" """
pwd = sftp.normalize('.') pwd = sftp.normalize('.')
self.assert_(len(pwd) > 0) self.assertTrue(len(pwd) > 0)
f = sftp.normalize('./' + FOLDER) f = sftp.normalize('./' + FOLDER)
self.assert_(len(f) > 0) self.assertTrue(len(f) > 0)
self.assertEquals(os.path.join(pwd, FOLDER), f) self.assertEqual(os.path.join(pwd, FOLDER), f)
def test_F_mkdir(self): def test_F_mkdir(self):
""" """
@ -507,19 +489,19 @@ class SFTPTest (unittest.TestCase):
try: try:
sftp.mkdir(FOLDER + '/subfolder') sftp.mkdir(FOLDER + '/subfolder')
except: except:
self.assert_(False, 'exception creating subfolder') self.assertTrue(False, 'exception creating subfolder')
try: try:
sftp.mkdir(FOLDER + '/subfolder') sftp.mkdir(FOLDER + '/subfolder')
self.assert_(False, 'no exception overwriting subfolder') self.assertTrue(False, 'no exception overwriting subfolder')
except IOError: except IOError:
pass pass
try: try:
sftp.rmdir(FOLDER + '/subfolder') sftp.rmdir(FOLDER + '/subfolder')
except: except:
self.assert_(False, 'exception removing subfolder') self.assertTrue(False, 'exception removing subfolder')
try: try:
sftp.rmdir(FOLDER + '/subfolder') sftp.rmdir(FOLDER + '/subfolder')
self.assert_(False, 'no exception removing nonexistent subfolder') self.assertTrue(False, 'no exception removing nonexistent subfolder')
except IOError: except IOError:
pass pass
@ -534,17 +516,16 @@ class SFTPTest (unittest.TestCase):
sftp.mkdir(FOLDER + '/alpha') sftp.mkdir(FOLDER + '/alpha')
sftp.chdir(FOLDER + '/alpha') sftp.chdir(FOLDER + '/alpha')
sftp.mkdir('beta') sftp.mkdir('beta')
self.assertEquals(root + FOLDER + '/alpha', sftp.getcwd()) self.assertEqual(root + FOLDER + '/alpha', sftp.getcwd())
self.assertEquals(['beta'], sftp.listdir('.')) self.assertEqual(['beta'], sftp.listdir('.'))
sftp.chdir('beta') sftp.chdir('beta')
f = sftp.open('fish', 'w') with sftp.open('fish', 'w') as f:
f.write('hello\n') f.write('hello\n')
f.close()
sftp.chdir('..') sftp.chdir('..')
self.assertEquals(['fish'], sftp.listdir('beta')) self.assertEqual(['fish'], sftp.listdir('beta'))
sftp.chdir('..') sftp.chdir('..')
self.assertEquals(['fish'], sftp.listdir('alpha/beta')) self.assertEqual(['fish'], sftp.listdir('alpha/beta'))
finally: finally:
sftp.chdir(root) sftp.chdir(root)
try: try:
@ -566,30 +547,29 @@ class SFTPTest (unittest.TestCase):
""" """
warnings.filterwarnings('ignore', 'tempnam.*') warnings.filterwarnings('ignore', 'tempnam.*')
localname = os.tempnam() fd, localname = mkstemp()
text = 'All I wanted was a plastic bunny rabbit.\n' os.close(fd)
f = open(localname, 'wb') text = b'All I wanted was a plastic bunny rabbit.\n'
f.write(text) with open(localname, 'wb') as f:
f.close() f.write(text)
saved_progress = [] saved_progress = []
def progress_callback(x, y): def progress_callback(x, y):
saved_progress.append((x, y)) saved_progress.append((x, y))
sftp.put(localname, FOLDER + '/bunny.txt', progress_callback) sftp.put(localname, FOLDER + '/bunny.txt', progress_callback)
f = sftp.open(FOLDER + '/bunny.txt', 'r') with sftp.open(FOLDER + '/bunny.txt', 'rb') as f:
self.assertEquals(text, f.read(128)) self.assertEqual(text, f.read(128))
f.close() self.assertEqual((41, 41), saved_progress[-1])
self.assertEquals((41, 41), saved_progress[-1])
os.unlink(localname) os.unlink(localname)
localname = os.tempnam() fd, localname = mkstemp()
os.close(fd)
saved_progress = [] saved_progress = []
sftp.get(FOLDER + '/bunny.txt', localname, progress_callback) sftp.get(FOLDER + '/bunny.txt', localname, progress_callback)
f = open(localname, 'rb') with open(localname, 'rb') as f:
self.assertEquals(text, f.read(128)) self.assertEqual(text, f.read(128))
f.close() self.assertEqual((41, 41), saved_progress[-1])
self.assertEquals((41, 41), saved_progress[-1])
os.unlink(localname) os.unlink(localname)
sftp.unlink(FOLDER + '/bunny.txt') sftp.unlink(FOLDER + '/bunny.txt')
@ -600,20 +580,18 @@ class SFTPTest (unittest.TestCase):
(it's an sftp extension that we support, and may be the only ones who (it's an sftp extension that we support, and may be the only ones who
support it.) support it.)
""" """
f = sftp.open(FOLDER + '/kitty.txt', 'w') with sftp.open(FOLDER + '/kitty.txt', 'w') as f:
f.write('here kitty kitty' * 64) f.write('here kitty kitty' * 64)
f.close()
try: try:
f = sftp.open(FOLDER + '/kitty.txt', 'r') with sftp.open(FOLDER + '/kitty.txt', 'r') as f:
sum = f.check('sha1') sum = f.check('sha1')
self.assertEquals('91059CFC6615941378D413CB5ADAF4C5EB293402', hexlify(sum).upper()) self.assertEqual('91059CFC6615941378D413CB5ADAF4C5EB293402', u(hexlify(sum)).upper())
sum = f.check('md5', 0, 512) sum = f.check('md5', 0, 512)
self.assertEquals('93DE4788FCA28D471516963A1FE3856A', hexlify(sum).upper()) self.assertEqual('93DE4788FCA28D471516963A1FE3856A', u(hexlify(sum)).upper())
sum = f.check('md5', 0, 0, 510) sum = f.check('md5', 0, 0, 510)
self.assertEquals('EB3B45B8CD55A0707D99B177544A319F373183D241432BB2157AB9E46358C4AC90370B5CADE5D90336FC1716F90B36D6', self.assertEqual('EB3B45B8CD55A0707D99B177544A319F373183D241432BB2157AB9E46358C4AC90370B5CADE5D90336FC1716F90B36D6',
hexlify(sum).upper()) u(hexlify(sum)).upper())
f.close()
finally: finally:
sftp.unlink(FOLDER + '/kitty.txt') sftp.unlink(FOLDER + '/kitty.txt')
@ -621,12 +599,11 @@ class SFTPTest (unittest.TestCase):
""" """
verify that the 'x' flag works when opening a file. verify that the 'x' flag works when opening a file.
""" """
f = sftp.open(FOLDER + '/unusual.txt', 'wx') sftp.open(FOLDER + '/unusual.txt', 'wx').close()
f.close()
try: try:
try: try:
f = sftp.open(FOLDER + '/unusual.txt', 'wx') sftp.open(FOLDER + '/unusual.txt', 'wx')
self.fail('expected exception') self.fail('expected exception')
except IOError: except IOError:
pass pass
@ -637,44 +614,39 @@ class SFTPTest (unittest.TestCase):
""" """
verify that unicode strings are encoded into utf8 correctly. verify that unicode strings are encoded into utf8 correctly.
""" """
f = sftp.open(FOLDER + '/something', 'w') with sftp.open(FOLDER + '/something', 'w') as f:
f.write('okay') f.write('okay')
f.close()
try: try:
sftp.rename(FOLDER + '/something', FOLDER + u'/\u00fcnic\u00f8de') sftp.rename(FOLDER + '/something', FOLDER + '/' + unicode_folder)
sftp.open(FOLDER + '/\xc3\xbcnic\xc3\xb8\x64\x65', 'r') sftp.open(b(FOLDER) + utf8_folder, 'r')
except Exception, e: except Exception as e:
self.fail('exception ' + e) self.fail('exception ' + str(e))
sftp.unlink(FOLDER + '/\xc3\xbcnic\xc3\xb8\x64\x65') sftp.unlink(b(FOLDER) + utf8_folder)
def test_L_utf8_chdir(self): def test_L_utf8_chdir(self):
sftp.mkdir(FOLDER + u'\u00fcnic\u00f8de') sftp.mkdir(FOLDER + '/' + unicode_folder)
try: try:
sftp.chdir(FOLDER + u'\u00fcnic\u00f8de') sftp.chdir(FOLDER + '/' + unicode_folder)
f = sftp.open('something', 'w') with sftp.open('something', 'w') as f:
f.write('okay') f.write('okay')
f.close()
sftp.unlink('something') sftp.unlink('something')
finally: finally:
sftp.chdir(None) sftp.chdir()
sftp.rmdir(FOLDER + u'\u00fcnic\u00f8de') sftp.rmdir(FOLDER + '/' + unicode_folder)
def test_M_bad_readv(self): def test_M_bad_readv(self):
""" """
verify that readv at the end of the file doesn't essplode. verify that readv at the end of the file doesn't essplode.
""" """
f = sftp.open(FOLDER + '/zero', 'w') sftp.open(FOLDER + '/zero', 'w').close()
f.close()
try: try:
f = sftp.open(FOLDER + '/zero', 'r') with sftp.open(FOLDER + '/zero', 'r') as f:
f.readv([(0, 12)]) f.readv([(0, 12)])
f.close()
f = sftp.open(FOLDER + '/zero', 'r') with sftp.open(FOLDER + '/zero', 'r') as f:
f.prefetch() f.prefetch()
f.read(100) f.read(100)
f.close()
finally: finally:
sftp.unlink(FOLDER + '/zero') sftp.unlink(FOLDER + '/zero')
@ -684,45 +656,61 @@ class SFTPTest (unittest.TestCase):
""" """
warnings.filterwarnings('ignore', 'tempnam.*') warnings.filterwarnings('ignore', 'tempnam.*')
localname = os.tempnam() fd, localname = mkstemp()
os.close(fd)
text = 'All I wanted was a plastic bunny rabbit.\n' text = 'All I wanted was a plastic bunny rabbit.\n'
f = open(localname, 'wb') with open(localname, 'w') as f:
f.write(text) f.write(text)
f.close()
saved_progress = [] saved_progress = []
def progress_callback(x, y): def progress_callback(x, y):
saved_progress.append((x, y)) saved_progress.append((x, y))
res = sftp.put(localname, FOLDER + '/bunny.txt', progress_callback, False) res = sftp.put(localname, FOLDER + '/bunny.txt', progress_callback, False)
self.assertEquals(SFTPAttributes().attr, res.attr) self.assertEqual(SFTPAttributes().attr, res.attr)
f = sftp.open(FOLDER + '/bunny.txt', 'r') with sftp.open(FOLDER + '/bunny.txt', 'r') as f:
self.assertEquals(text, f.read(128)) self.assertEqual(text, f.read(128))
f.close() self.assertEqual((41, 41), saved_progress[-1])
self.assertEquals((41, 41), saved_progress[-1])
os.unlink(localname) os.unlink(localname)
sftp.unlink(FOLDER + '/bunny.txt') sftp.unlink(FOLDER + '/bunny.txt')
def test_O_getcwd(self):
"""
verify that chdir/getcwd work.
"""
self.assertEqual(None, sftp.getcwd())
root = sftp.normalize('.')
if root[-1] != '/':
root += '/'
try:
sftp.mkdir(FOLDER + '/alpha')
sftp.chdir(FOLDER + '/alpha')
self.assertEqual('/' + FOLDER + '/alpha', sftp.getcwd())
finally:
sftp.chdir(root)
try:
sftp.rmdir(FOLDER + '/alpha')
except:
pass
def XXX_test_M_seek_append(self): def XXX_test_M_seek_append(self):
""" """
verify that seek does't affect writes during append. verify that seek does't affect writes during append.
does not work except through paramiko. :( openssh fails. does not work except through paramiko. :( openssh fails.
""" """
f = sftp.open(FOLDER + '/append.txt', 'a')
try: try:
f.write('first line\nsecond line\n') with sftp.open(FOLDER + '/append.txt', 'a') as f:
f.seek(11, f.SEEK_SET) f.write('first line\nsecond line\n')
f.write('third line\n') f.seek(11, f.SEEK_SET)
f.close() f.write('third line\n')
f = sftp.open(FOLDER + '/append.txt', 'r') with sftp.open(FOLDER + '/append.txt', 'r') as f:
self.assertEqual(f.stat().st_size, 34) self.assertEqual(f.stat().st_size, 34)
self.assertEqual(f.readline(), 'first line\n') self.assertEqual(f.readline(), 'first line\n')
self.assertEqual(f.readline(), 'second line\n') self.assertEqual(f.readline(), 'second line\n')
self.assertEqual(f.readline(), 'third line\n') self.assertEqual(f.readline(), 'third line\n')
f.close()
finally: finally:
sftp.remove(FOLDER + '/append.txt') sftp.remove(FOLDER + '/append.txt')
@ -731,10 +719,16 @@ class SFTPTest (unittest.TestCase):
Send an empty file and confirm it is sent. Send an empty file and confirm it is sent.
""" """
target = FOLDER + '/empty file.txt' target = FOLDER + '/empty file.txt'
stream = StringIO.StringIO() stream = StringIO()
try: try:
attrs = sftp.putfo(stream, target) attrs = sftp.putfo(stream, target)
# the returned attributes should not be null # the returned attributes should not be null
self.assertNotEqual(attrs, None) self.assertNotEqual(attrs, None)
finally: finally:
sftp.remove(target) sftp.remove(target)
if __name__ == '__main__':
SFTPTest.init_loopback()
from unittest import main
main()

View File

@ -33,9 +33,10 @@ import time
import unittest import unittest
import paramiko import paramiko
from stub_sftp import StubServer, StubSFTPServer from paramiko.common import *
from loop import LoopSocket from tests.stub_sftp import StubServer, StubSFTPServer
from test_sftp import get_sftp from tests.loop import LoopSocket
from tests.test_sftp import get_sftp
FOLDER = os.environ.get('TEST_FOLDER', 'temp-testing000') FOLDER = os.environ.get('TEST_FOLDER', 'temp-testing000')
@ -45,7 +46,7 @@ class BigSFTPTest (unittest.TestCase):
def setUp(self): def setUp(self):
global FOLDER global FOLDER
sftp = get_sftp() sftp = get_sftp()
for i in xrange(1000): for i in range(1000):
FOLDER = FOLDER[:-3] + '%03d' % i FOLDER = FOLDER[:-3] + '%03d' % i
try: try:
sftp.mkdir(FOLDER) sftp.mkdir(FOLDER)
@ -65,19 +66,17 @@ class BigSFTPTest (unittest.TestCase):
numfiles = 100 numfiles = 100
try: try:
for i in range(numfiles): for i in range(numfiles):
f = sftp.open('%s/file%d.txt' % (FOLDER, i), 'w', 1) with sftp.open('%s/file%d.txt' % (FOLDER, i), 'w', 1) as f:
f.write('this is file #%d.\n' % i) f.write('this is file #%d.\n' % i)
f.close() sftp.chmod('%s/file%d.txt' % (FOLDER, i), o660)
sftp.chmod('%s/file%d.txt' % (FOLDER, i), 0660)
# now make sure every file is there, by creating a list of filenmes # now make sure every file is there, by creating a list of filenmes
# and reading them in random order. # and reading them in random order.
numlist = range(numfiles) numlist = list(range(numfiles))
while len(numlist) > 0: while len(numlist) > 0:
r = numlist[random.randint(0, len(numlist) - 1)] r = numlist[random.randint(0, len(numlist) - 1)]
f = sftp.open('%s/file%d.txt' % (FOLDER, r)) with sftp.open('%s/file%d.txt' % (FOLDER, r)) as f:
self.assertEqual(f.readline(), 'this is file #%d.\n' % r) self.assertEqual(f.readline(), 'this is file #%d.\n' % r)
f.close()
numlist.remove(r) numlist.remove(r)
finally: finally:
for i in range(numfiles): for i in range(numfiles):
@ -94,12 +93,11 @@ class BigSFTPTest (unittest.TestCase):
kblob = (1024 * 'x') kblob = (1024 * 'x')
start = time.time() start = time.time()
try: try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w') with sftp.open('%s/hongry.txt' % FOLDER, 'w') as f:
for n in range(1024): for n in range(1024):
f.write(kblob) f.write(kblob)
if n % 128 == 0: if n % 128 == 0:
sys.stderr.write('.') sys.stderr.write('.')
f.close()
sys.stderr.write(' ') sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
@ -107,11 +105,10 @@ class BigSFTPTest (unittest.TestCase):
sys.stderr.write('%ds ' % round(end - start)) sys.stderr.write('%ds ' % round(end - start))
start = time.time() start = time.time()
f = sftp.open('%s/hongry.txt' % FOLDER, 'r') with sftp.open('%s/hongry.txt' % FOLDER, 'r') as f:
for n in range(1024): for n in range(1024):
data = f.read(1024) data = f.read(1024)
self.assertEqual(data, kblob) self.assertEqual(data, kblob)
f.close()
end = time.time() end = time.time()
sys.stderr.write('%ds ' % round(end - start)) sys.stderr.write('%ds ' % round(end - start))
@ -123,16 +120,15 @@ class BigSFTPTest (unittest.TestCase):
write a 1MB file, with no linefeeds, using pipelining. write a 1MB file, with no linefeeds, using pipelining.
""" """
sftp = get_sftp() sftp = get_sftp()
kblob = ''.join([struct.pack('>H', n) for n in xrange(512)]) kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
start = time.time() start = time.time()
try: try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w') with sftp.open('%s/hongry.txt' % FOLDER, 'wb') as f:
f.set_pipelined(True) f.set_pipelined(True)
for n in range(1024): for n in range(1024):
f.write(kblob) f.write(kblob)
if n % 128 == 0: if n % 128 == 0:
sys.stderr.write('.') sys.stderr.write('.')
f.close()
sys.stderr.write(' ') sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
@ -140,22 +136,21 @@ class BigSFTPTest (unittest.TestCase):
sys.stderr.write('%ds ' % round(end - start)) sys.stderr.write('%ds ' % round(end - start))
start = time.time() start = time.time()
f = sftp.open('%s/hongry.txt' % FOLDER, 'r') with sftp.open('%s/hongry.txt' % FOLDER, 'rb') as f:
f.prefetch() f.prefetch()
# read on odd boundaries to make sure the bytes aren't getting scrambled # read on odd boundaries to make sure the bytes aren't getting scrambled
n = 0 n = 0
k2blob = kblob + kblob k2blob = kblob + kblob
chunk = 629 chunk = 629
size = 1024 * 1024 size = 1024 * 1024
while n < size: while n < size:
if n + chunk > size: if n + chunk > size:
chunk = size - n chunk = size - n
data = f.read(chunk) data = f.read(chunk)
offset = n % 1024 offset = n % 1024
self.assertEqual(data, k2blob[offset:offset + chunk]) self.assertEqual(data, k2blob[offset:offset + chunk])
n += chunk n += chunk
f.close()
end = time.time() end = time.time()
sys.stderr.write('%ds ' % round(end - start)) sys.stderr.write('%ds ' % round(end - start))
@ -164,15 +159,14 @@ class BigSFTPTest (unittest.TestCase):
def test_4_prefetch_seek(self): def test_4_prefetch_seek(self):
sftp = get_sftp() sftp = get_sftp()
kblob = ''.join([struct.pack('>H', n) for n in xrange(512)]) kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
try: try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w') with sftp.open('%s/hongry.txt' % FOLDER, 'wb') as f:
f.set_pipelined(True) f.set_pipelined(True)
for n in range(1024): for n in range(1024):
f.write(kblob) f.write(kblob)
if n % 128 == 0: if n % 128 == 0:
sys.stderr.write('.') sys.stderr.write('.')
f.close()
sys.stderr.write(' ') sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
@ -180,21 +174,20 @@ class BigSFTPTest (unittest.TestCase):
start = time.time() start = time.time()
k2blob = kblob + kblob k2blob = kblob + kblob
chunk = 793 chunk = 793
for i in xrange(10): for i in range(10):
f = sftp.open('%s/hongry.txt' % FOLDER, 'r') with sftp.open('%s/hongry.txt' % FOLDER, 'rb') as f:
f.prefetch() f.prefetch()
base_offset = (512 * 1024) + 17 * random.randint(1000, 2000) base_offset = (512 * 1024) + 17 * random.randint(1000, 2000)
offsets = [base_offset + j * chunk for j in xrange(100)] offsets = [base_offset + j * chunk for j in range(100)]
# randomly seek around and read them out # randomly seek around and read them out
for j in xrange(100): for j in range(100):
offset = offsets[random.randint(0, len(offsets) - 1)] offset = offsets[random.randint(0, len(offsets) - 1)]
offsets.remove(offset) offsets.remove(offset)
f.seek(offset) f.seek(offset)
data = f.read(chunk) data = f.read(chunk)
n_offset = offset % 1024 n_offset = offset % 1024
self.assertEqual(data, k2blob[n_offset:n_offset + chunk]) self.assertEqual(data, k2blob[n_offset:n_offset + chunk])
offset += chunk offset += chunk
f.close()
end = time.time() end = time.time()
sys.stderr.write('%ds ' % round(end - start)) sys.stderr.write('%ds ' % round(end - start))
finally: finally:
@ -202,15 +195,14 @@ class BigSFTPTest (unittest.TestCase):
def test_5_readv_seek(self): def test_5_readv_seek(self):
sftp = get_sftp() sftp = get_sftp()
kblob = ''.join([struct.pack('>H', n) for n in xrange(512)]) kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
try: try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w') with sftp.open('%s/hongry.txt' % FOLDER, 'wb') as f:
f.set_pipelined(True) f.set_pipelined(True)
for n in range(1024): for n in range(1024):
f.write(kblob) f.write(kblob)
if n % 128 == 0: if n % 128 == 0:
sys.stderr.write('.') sys.stderr.write('.')
f.close()
sys.stderr.write(' ') sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
@ -218,22 +210,21 @@ class BigSFTPTest (unittest.TestCase):
start = time.time() start = time.time()
k2blob = kblob + kblob k2blob = kblob + kblob
chunk = 793 chunk = 793
for i in xrange(10): for i in range(10):
f = sftp.open('%s/hongry.txt' % FOLDER, 'r') with sftp.open('%s/hongry.txt' % FOLDER, 'rb') as f:
base_offset = (512 * 1024) + 17 * random.randint(1000, 2000) base_offset = (512 * 1024) + 17 * random.randint(1000, 2000)
# make a bunch of offsets and put them in random order # make a bunch of offsets and put them in random order
offsets = [base_offset + j * chunk for j in xrange(100)] offsets = [base_offset + j * chunk for j in range(100)]
readv_list = [] readv_list = []
for j in xrange(100): for j in range(100):
o = offsets[random.randint(0, len(offsets) - 1)] o = offsets[random.randint(0, len(offsets) - 1)]
offsets.remove(o) offsets.remove(o)
readv_list.append((o, chunk)) readv_list.append((o, chunk))
ret = f.readv(readv_list) ret = f.readv(readv_list)
for i in xrange(len(readv_list)): for i in range(len(readv_list)):
offset = readv_list[i][0] offset = readv_list[i][0]
n_offset = offset % 1024 n_offset = offset % 1024
self.assertEqual(ret.next(), k2blob[n_offset:n_offset + chunk]) self.assertEqual(next(ret), k2blob[n_offset:n_offset + chunk])
f.close()
end = time.time() end = time.time()
sys.stderr.write('%ds ' % round(end - start)) sys.stderr.write('%ds ' % round(end - start))
finally: finally:
@ -247,28 +238,26 @@ class BigSFTPTest (unittest.TestCase):
sftp = get_sftp() sftp = get_sftp()
kblob = (1024 * 'x') kblob = (1024 * 'x')
try: try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w') with sftp.open('%s/hongry.txt' % FOLDER, 'w') as f:
f.set_pipelined(True) f.set_pipelined(True)
for n in range(1024): for n in range(1024):
f.write(kblob) f.write(kblob)
if n % 128 == 0: if n % 128 == 0:
sys.stderr.write('.') sys.stderr.write('.')
f.close()
sys.stderr.write(' ') sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
for i in range(10): for i in range(10):
f = sftp.open('%s/hongry.txt' % FOLDER, 'r') with sftp.open('%s/hongry.txt' % FOLDER, 'r') as f:
f.prefetch()
with sftp.open('%s/hongry.txt' % FOLDER, 'r') as f:
f.prefetch() f.prefetch()
f = sftp.open('%s/hongry.txt' % FOLDER, 'r') for n in range(1024):
f.prefetch() data = f.read(1024)
for n in range(1024): self.assertEqual(data, kblob)
data = f.read(1024) if n % 128 == 0:
self.assertEqual(data, kblob) sys.stderr.write('.')
if n % 128 == 0:
sys.stderr.write('.')
f.close()
sys.stderr.write(' ') sys.stderr.write(' ')
finally: finally:
sftp.remove('%s/hongry.txt' % FOLDER) sftp.remove('%s/hongry.txt' % FOLDER)
@ -278,35 +267,33 @@ class BigSFTPTest (unittest.TestCase):
verify that prefetch and readv don't conflict with each other. verify that prefetch and readv don't conflict with each other.
""" """
sftp = get_sftp() sftp = get_sftp()
kblob = ''.join([struct.pack('>H', n) for n in xrange(512)]) kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
try: try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w') with sftp.open('%s/hongry.txt' % FOLDER, 'wb') as f:
f.set_pipelined(True) f.set_pipelined(True)
for n in range(1024): for n in range(1024):
f.write(kblob) f.write(kblob)
if n % 128 == 0: if n % 128 == 0:
sys.stderr.write('.') sys.stderr.write('.')
f.close()
sys.stderr.write(' ') sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
f = sftp.open('%s/hongry.txt' % FOLDER, 'r') with sftp.open('%s/hongry.txt' % FOLDER, 'rb') as f:
f.prefetch() f.prefetch()
data = f.read(1024) data = f.read(1024)
self.assertEqual(data, kblob) self.assertEqual(data, kblob)
chunk_size = 793 chunk_size = 793
base_offset = 512 * 1024 base_offset = 512 * 1024
k2blob = kblob + kblob k2blob = kblob + kblob
chunks = [(base_offset + (chunk_size * i), chunk_size) for i in range(20)] chunks = [(base_offset + (chunk_size * i), chunk_size) for i in range(20)]
for data in f.readv(chunks): for data in f.readv(chunks):
offset = base_offset % 1024 offset = base_offset % 1024
self.assertEqual(chunk_size, len(data)) self.assertEqual(chunk_size, len(data))
self.assertEqual(k2blob[offset:offset + chunk_size], data) self.assertEqual(k2blob[offset:offset + chunk_size], data)
base_offset += chunk_size base_offset += chunk_size
f.close()
sys.stderr.write(' ') sys.stderr.write(' ')
finally: finally:
sftp.remove('%s/hongry.txt' % FOLDER) sftp.remove('%s/hongry.txt' % FOLDER)
@ -317,26 +304,24 @@ class BigSFTPTest (unittest.TestCase):
returned as a single blob. returned as a single blob.
""" """
sftp = get_sftp() sftp = get_sftp()
kblob = ''.join([struct.pack('>H', n) for n in xrange(512)]) kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
try: try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w') with sftp.open('%s/hongry.txt' % FOLDER, 'wb') as f:
f.set_pipelined(True) f.set_pipelined(True)
for n in range(1024): for n in range(1024):
f.write(kblob) f.write(kblob)
if n % 128 == 0: if n % 128 == 0:
sys.stderr.write('.') sys.stderr.write('.')
f.close()
sys.stderr.write(' ') sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
f = sftp.open('%s/hongry.txt' % FOLDER, 'r') with sftp.open('%s/hongry.txt' % FOLDER, 'rb') as f:
data = list(f.readv([(23 * 1024, 128 * 1024)])) data = list(f.readv([(23 * 1024, 128 * 1024)]))
self.assertEqual(1, len(data)) self.assertEqual(1, len(data))
data = data[0] data = data[0]
self.assertEqual(128 * 1024, len(data)) self.assertEqual(128 * 1024, len(data))
f.close()
sys.stderr.write(' ') sys.stderr.write(' ')
finally: finally:
sftp.remove('%s/hongry.txt' % FOLDER) sftp.remove('%s/hongry.txt' % FOLDER)
@ -348,9 +333,8 @@ class BigSFTPTest (unittest.TestCase):
sftp = get_sftp() sftp = get_sftp()
mblob = (1024 * 1024 * 'x') mblob = (1024 * 1024 * 'x')
try: try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w', 128 * 1024) with sftp.open('%s/hongry.txt' % FOLDER, 'w', 128 * 1024) as f:
f.write(mblob) f.write(mblob)
f.close()
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
finally: finally:
@ -365,21 +349,26 @@ class BigSFTPTest (unittest.TestCase):
t.packetizer.REKEY_BYTES = 512 * 1024 t.packetizer.REKEY_BYTES = 512 * 1024
k32blob = (32 * 1024 * 'x') k32blob = (32 * 1024 * 'x')
try: try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w', 128 * 1024) with sftp.open('%s/hongry.txt' % FOLDER, 'w', 128 * 1024) as f:
for i in xrange(32): for i in range(32):
f.write(k32blob) f.write(k32blob)
f.close()
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
self.assertNotEquals(t.H, t.session_id) self.assertNotEqual(t.H, t.session_id)
# try to read it too. # try to read it too.
f = sftp.open('%s/hongry.txt' % FOLDER, 'r', 128 * 1024) with sftp.open('%s/hongry.txt' % FOLDER, 'r', 128 * 1024) as f:
f.prefetch() f.prefetch()
total = 0 total = 0
while total < 1024 * 1024: while total < 1024 * 1024:
total += len(f.read(32 * 1024)) total += len(f.read(32 * 1024))
f.close()
finally: finally:
sftp.remove('%s/hongry.txt' % FOLDER) sftp.remove('%s/hongry.txt' % FOLDER)
t.packetizer.REKEY_BYTES = pow(2, 30) t.packetizer.REKEY_BYTES = pow(2, 30)
if __name__ == '__main__':
from tests.test_sftp import SFTPTest
SFTPTest.init_loopback()
from unittest import main
main()

View File

@ -20,7 +20,7 @@
Some unit tests for the ssh2 protocol in Transport. Some unit tests for the ssh2 protocol in Transport.
""" """
from binascii import hexlify, unhexlify from binascii import hexlify
import select import select
import socket import socket
import sys import sys
@ -33,10 +33,10 @@ from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey
SSHException, BadAuthenticationType, InteractiveQuery, ChannelException SSHException, BadAuthenticationType, InteractiveQuery, ChannelException
from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL
from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
from paramiko.common import MSG_KEXINIT, MSG_CHANNEL_WINDOW_ADJUST from paramiko.common import MSG_KEXINIT, MSG_CHANNEL_WINDOW_ADJUST, b, bytes
from paramiko.message import Message from paramiko.message import Message
from loop import LoopSocket from tests.loop import LoopSocket
from util import ParamikoTest from tests.util import ParamikoTest, test_path
LONG_BANNER = """\ LONG_BANNER = """\
@ -55,7 +55,7 @@ Maybe.
class NullServer (ServerInterface): class NullServer (ServerInterface):
paranoid_did_password = False paranoid_did_password = False
paranoid_did_public_key = False paranoid_did_public_key = False
paranoid_key = DSSKey.from_private_key_file('tests/test_dss.key') paranoid_key = DSSKey.from_private_key_file(test_path('test_dss.key'))
def get_allowed_auths(self, username): def get_allowed_auths(self, username):
if username == 'slowdive': if username == 'slowdive':
@ -121,8 +121,8 @@ class TransportTest(ParamikoTest):
self.sockc.close() self.sockc.close()
def setup_test_server(self, client_options=None, server_options=None): def setup_test_server(self, client_options=None, server_options=None):
host_key = RSAKey.from_private_key_file('tests/test_rsa.key') host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = RSAKey(data=str(host_key)) public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key) self.ts.add_server_key(host_key)
if client_options is not None: if client_options is not None:
@ -132,37 +132,37 @@ class TransportTest(ParamikoTest):
event = threading.Event() event = threading.Event()
self.server = NullServer() self.server = NullServer()
self.assert_(not event.isSet()) self.assertTrue(not event.isSet())
self.ts.start_server(event, self.server) self.ts.start_server(event, self.server)
self.tc.connect(hostkey=public_host_key, self.tc.connect(hostkey=public_host_key,
username='slowdive', password='pygmalion') username='slowdive', password='pygmalion')
event.wait(1.0) event.wait(1.0)
self.assert_(event.isSet()) self.assertTrue(event.isSet())
self.assert_(self.ts.is_active()) self.assertTrue(self.ts.is_active())
def test_1_security_options(self): def test_1_security_options(self):
o = self.tc.get_security_options() o = self.tc.get_security_options()
self.assertEquals(type(o), SecurityOptions) self.assertEqual(type(o), SecurityOptions)
self.assert_(('aes256-cbc', 'blowfish-cbc') != o.ciphers) self.assertTrue(('aes256-cbc', 'blowfish-cbc') != o.ciphers)
o.ciphers = ('aes256-cbc', 'blowfish-cbc') o.ciphers = ('aes256-cbc', 'blowfish-cbc')
self.assertEquals(('aes256-cbc', 'blowfish-cbc'), o.ciphers) self.assertEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
try: try:
o.ciphers = ('aes256-cbc', 'made-up-cipher') o.ciphers = ('aes256-cbc', 'made-up-cipher')
self.assert_(False) self.assertTrue(False)
except ValueError: except ValueError:
pass pass
try: try:
o.ciphers = 23 o.ciphers = 23
self.assert_(False) self.assertTrue(False)
except TypeError: except TypeError:
pass pass
def test_2_compute_key(self): def test_2_compute_key(self):
self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929L self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929
self.tc.H = unhexlify('0C8307CDE6856FF30BA93684EB0F04C2520E9ED3') self.tc.H = b'\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3'
self.tc.session_id = self.tc.H self.tc.session_id = self.tc.H
key = self.tc._compute_key('C', 32) key = self.tc._compute_key('C', 32)
self.assertEquals('207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995', self.assertEqual(b'207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
hexlify(key).upper()) hexlify(key).upper())
def test_3_simple(self): def test_3_simple(self):
@ -171,44 +171,44 @@ class TransportTest(ParamikoTest):
loopback sockets. this is hardly "simple" but it's simpler than the loopback sockets. this is hardly "simple" but it's simpler than the
later tests. :) later tests. :)
""" """
host_key = RSAKey.from_private_key_file('tests/test_rsa.key') host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = RSAKey(data=str(host_key)) public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key) self.ts.add_server_key(host_key)
event = threading.Event() event = threading.Event()
server = NullServer() server = NullServer()
self.assert_(not event.isSet()) self.assertTrue(not event.isSet())
self.assertEquals(None, self.tc.get_username()) self.assertEqual(None, self.tc.get_username())
self.assertEquals(None, self.ts.get_username()) self.assertEqual(None, self.ts.get_username())
self.assertEquals(False, self.tc.is_authenticated()) self.assertEqual(False, self.tc.is_authenticated())
self.assertEquals(False, self.ts.is_authenticated()) self.assertEqual(False, self.ts.is_authenticated())
self.ts.start_server(event, server) self.ts.start_server(event, server)
self.tc.connect(hostkey=public_host_key, self.tc.connect(hostkey=public_host_key,
username='slowdive', password='pygmalion') username='slowdive', password='pygmalion')
event.wait(1.0) event.wait(1.0)
self.assert_(event.isSet()) self.assertTrue(event.isSet())
self.assert_(self.ts.is_active()) self.assertTrue(self.ts.is_active())
self.assertEquals('slowdive', self.tc.get_username()) self.assertEqual('slowdive', self.tc.get_username())
self.assertEquals('slowdive', self.ts.get_username()) self.assertEqual('slowdive', self.ts.get_username())
self.assertEquals(True, self.tc.is_authenticated()) self.assertEqual(True, self.tc.is_authenticated())
self.assertEquals(True, self.ts.is_authenticated()) self.assertEqual(True, self.ts.is_authenticated())
def test_3a_long_banner(self): def test_3a_long_banner(self):
""" """
verify that a long banner doesn't mess up the handshake. verify that a long banner doesn't mess up the handshake.
""" """
host_key = RSAKey.from_private_key_file('tests/test_rsa.key') host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = RSAKey(data=str(host_key)) public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key) self.ts.add_server_key(host_key)
event = threading.Event() event = threading.Event()
server = NullServer() server = NullServer()
self.assert_(not event.isSet()) self.assertTrue(not event.isSet())
self.socks.send(LONG_BANNER) self.socks.send(LONG_BANNER)
self.ts.start_server(event, server) self.ts.start_server(event, server)
self.tc.connect(hostkey=public_host_key, self.tc.connect(hostkey=public_host_key,
username='slowdive', password='pygmalion') username='slowdive', password='pygmalion')
event.wait(1.0) event.wait(1.0)
self.assert_(event.isSet()) self.assertTrue(event.isSet())
self.assert_(self.ts.is_active()) self.assertTrue(self.ts.is_active())
def test_4_special(self): def test_4_special(self):
""" """
@ -219,10 +219,10 @@ class TransportTest(ParamikoTest):
options.ciphers = ('aes256-cbc',) options.ciphers = ('aes256-cbc',)
options.digests = ('hmac-md5-96',) options.digests = ('hmac-md5-96',)
self.setup_test_server(client_options=force_algorithms) self.setup_test_server(client_options=force_algorithms)
self.assertEquals('aes256-cbc', self.tc.local_cipher) self.assertEqual('aes256-cbc', self.tc.local_cipher)
self.assertEquals('aes256-cbc', self.tc.remote_cipher) self.assertEqual('aes256-cbc', self.tc.remote_cipher)
self.assertEquals(12, self.tc.packetizer.get_mac_size_out()) self.assertEqual(12, self.tc.packetizer.get_mac_size_out())
self.assertEquals(12, self.tc.packetizer.get_mac_size_in()) self.assertEqual(12, self.tc.packetizer.get_mac_size_in())
self.tc.send_ignore(1024) self.tc.send_ignore(1024)
self.tc.renegotiate_keys() self.tc.renegotiate_keys()
@ -233,10 +233,10 @@ class TransportTest(ParamikoTest):
verify that the keepalive will be sent. verify that the keepalive will be sent.
""" """
self.setup_test_server() self.setup_test_server()
self.assertEquals(None, getattr(self.server, '_global_request', None)) self.assertEqual(None, getattr(self.server, '_global_request', None))
self.tc.set_keepalive(1) self.tc.set_keepalive(1)
time.sleep(2) time.sleep(2)
self.assertEquals('keepalive@lag.net', self.server._global_request) self.assertEqual('keepalive@lag.net', self.server._global_request)
def test_6_exec_command(self): def test_6_exec_command(self):
""" """
@ -248,8 +248,8 @@ class TransportTest(ParamikoTest):
schan = self.ts.accept(1.0) schan = self.ts.accept(1.0)
try: try:
chan.exec_command('no') chan.exec_command('no')
self.assert_(False) self.assertTrue(False)
except SSHException, x: except SSHException:
pass pass
chan = self.tc.open_session() chan = self.tc.open_session()
@ -260,11 +260,11 @@ class TransportTest(ParamikoTest):
schan.close() schan.close()
f = chan.makefile() f = chan.makefile()
self.assertEquals('Hello there.\n', f.readline()) self.assertEqual('Hello there.\n', f.readline())
self.assertEquals('', f.readline()) self.assertEqual('', f.readline())
f = chan.makefile_stderr() f = chan.makefile_stderr()
self.assertEquals('This is on stderr.\n', f.readline()) self.assertEqual('This is on stderr.\n', f.readline())
self.assertEquals('', f.readline()) self.assertEqual('', f.readline())
# now try it with combined stdout/stderr # now try it with combined stdout/stderr
chan = self.tc.open_session() chan = self.tc.open_session()
@ -276,9 +276,9 @@ class TransportTest(ParamikoTest):
chan.set_combine_stderr(True) chan.set_combine_stderr(True)
f = chan.makefile() f = chan.makefile()
self.assertEquals('Hello there.\n', f.readline()) self.assertEqual('Hello there.\n', f.readline())
self.assertEquals('This is on stderr.\n', f.readline()) self.assertEqual('This is on stderr.\n', f.readline())
self.assertEquals('', f.readline()) self.assertEqual('', f.readline())
def test_7_invoke_shell(self): def test_7_invoke_shell(self):
""" """
@ -290,9 +290,9 @@ class TransportTest(ParamikoTest):
schan = self.ts.accept(1.0) schan = self.ts.accept(1.0)
chan.send('communist j. cat\n') chan.send('communist j. cat\n')
f = schan.makefile() f = schan.makefile()
self.assertEquals('communist j. cat\n', f.readline()) self.assertEqual('communist j. cat\n', f.readline())
chan.close() chan.close()
self.assertEquals('', f.readline()) self.assertEqual('', f.readline())
def test_8_channel_exception(self): def test_8_channel_exception(self):
""" """
@ -302,8 +302,8 @@ class TransportTest(ParamikoTest):
try: try:
chan = self.tc.open_channel('bogus') chan = self.tc.open_channel('bogus')
self.fail('expected exception') self.fail('expected exception')
except ChannelException, x: except ChannelException as e:
self.assert_(x.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED) self.assertTrue(e.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED)
def test_9_exit_status(self): def test_9_exit_status(self):
""" """
@ -315,7 +315,7 @@ class TransportTest(ParamikoTest):
schan = self.ts.accept(1.0) schan = self.ts.accept(1.0)
chan.exec_command('yes') chan.exec_command('yes')
schan.send('Hello there.\n') schan.send('Hello there.\n')
self.assert_(not chan.exit_status_ready()) self.assertTrue(not chan.exit_status_ready())
# trigger an EOF # trigger an EOF
schan.shutdown_read() schan.shutdown_read()
schan.shutdown_write() schan.shutdown_write()
@ -323,15 +323,15 @@ class TransportTest(ParamikoTest):
schan.close() schan.close()
f = chan.makefile() f = chan.makefile()
self.assertEquals('Hello there.\n', f.readline()) self.assertEqual('Hello there.\n', f.readline())
self.assertEquals('', f.readline()) self.assertEqual('', f.readline())
count = 0 count = 0
while not chan.exit_status_ready(): while not chan.exit_status_ready():
time.sleep(0.1) time.sleep(0.1)
count += 1 count += 1
if count > 50: if count > 50:
raise Exception("timeout") raise Exception("timeout")
self.assertEquals(23, chan.recv_exit_status()) self.assertEqual(23, chan.recv_exit_status())
chan.close() chan.close()
def test_A_select(self): def test_A_select(self):
@ -345,9 +345,9 @@ class TransportTest(ParamikoTest):
# nothing should be ready # nothing should be ready
r, w, e = select.select([chan], [], [], 0.1) r, w, e = select.select([chan], [], [], 0.1)
self.assertEquals([], r) self.assertEqual([], r)
self.assertEquals([], w) self.assertEqual([], w)
self.assertEquals([], e) self.assertEqual([], e)
schan.send('hello\n') schan.send('hello\n')
@ -357,17 +357,17 @@ class TransportTest(ParamikoTest):
if chan in r: if chan in r:
break break
time.sleep(0.1) time.sleep(0.1)
self.assertEquals([chan], r) self.assertEqual([chan], r)
self.assertEquals([], w) self.assertEqual([], w)
self.assertEquals([], e) self.assertEqual([], e)
self.assertEquals('hello\n', chan.recv(6)) self.assertEqual(b'hello\n', chan.recv(6))
# and, should be dead again now # and, should be dead again now
r, w, e = select.select([chan], [], [], 0.1) r, w, e = select.select([chan], [], [], 0.1)
self.assertEquals([], r) self.assertEqual([], r)
self.assertEquals([], w) self.assertEqual([], w)
self.assertEquals([], e) self.assertEqual([], e)
schan.close() schan.close()
@ -377,17 +377,17 @@ class TransportTest(ParamikoTest):
if chan in r: if chan in r:
break break
time.sleep(0.1) time.sleep(0.1)
self.assertEquals([chan], r) self.assertEqual([chan], r)
self.assertEquals([], w) self.assertEqual([], w)
self.assertEquals([], e) self.assertEqual([], e)
self.assertEquals('', chan.recv(16)) self.assertEqual(bytes(), chan.recv(16))
# make sure the pipe is still open for now... # make sure the pipe is still open for now...
p = chan._pipe p = chan._pipe
self.assertEquals(False, p._closed) self.assertEqual(False, p._closed)
chan.close() chan.close()
# ...and now is closed. # ...and now is closed.
self.assertEquals(True, p._closed) self.assertEqual(True, p._closed)
def test_B_renegotiate(self): def test_B_renegotiate(self):
""" """
@ -399,17 +399,17 @@ class TransportTest(ParamikoTest):
chan.exec_command('yes') chan.exec_command('yes')
schan = self.ts.accept(1.0) schan = self.ts.accept(1.0)
self.assertEquals(self.tc.H, self.tc.session_id) self.assertEqual(self.tc.H, self.tc.session_id)
for i in range(20): for i in range(20):
chan.send('x' * 1024) chan.send('x' * 1024)
chan.close() chan.close()
# allow a few seconds for the rekeying to complete # allow a few seconds for the rekeying to complete
for i in xrange(50): for i in range(50):
if self.tc.H != self.tc.session_id: if self.tc.H != self.tc.session_id:
break break
time.sleep(0.1) time.sleep(0.1)
self.assertNotEquals(self.tc.H, self.tc.session_id) self.assertNotEqual(self.tc.H, self.tc.session_id)
schan.close() schan.close()
@ -428,8 +428,8 @@ class TransportTest(ParamikoTest):
chan.send('x' * 1024) chan.send('x' * 1024)
bytes2 = self.tc.packetizer._Packetizer__sent_bytes bytes2 = self.tc.packetizer._Packetizer__sent_bytes
# tests show this is actually compressed to *52 bytes*! including packet overhead! nice!! :) # tests show this is actually compressed to *52 bytes*! including packet overhead! nice!! :)
self.assert_(bytes2 - bytes < 1024) self.assertTrue(bytes2 - bytes < 1024)
self.assertEquals(52, bytes2 - bytes) self.assertEqual(52, bytes2 - bytes)
chan.close() chan.close()
schan.close() schan.close()
@ -444,24 +444,25 @@ class TransportTest(ParamikoTest):
schan = self.ts.accept(1.0) schan = self.ts.accept(1.0)
requested = [] requested = []
def handler(c, (addr, port)): def handler(c, addr_port):
addr, port = addr_port
requested.append((addr, port)) requested.append((addr, port))
self.tc._queue_incoming_channel(c) self.tc._queue_incoming_channel(c)
self.assertEquals(None, getattr(self.server, '_x11_screen_number', None)) self.assertEqual(None, getattr(self.server, '_x11_screen_number', None))
cookie = chan.request_x11(0, single_connection=True, handler=handler) cookie = chan.request_x11(0, single_connection=True, handler=handler)
self.assertEquals(0, self.server._x11_screen_number) self.assertEqual(0, self.server._x11_screen_number)
self.assertEquals('MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol) self.assertEqual('MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol)
self.assertEquals(cookie, self.server._x11_auth_cookie) self.assertEqual(cookie, self.server._x11_auth_cookie)
self.assertEquals(True, self.server._x11_single_connection) self.assertEqual(True, self.server._x11_single_connection)
x11_server = self.ts.open_x11_channel(('localhost', 6093)) x11_server = self.ts.open_x11_channel(('localhost', 6093))
x11_client = self.tc.accept() x11_client = self.tc.accept()
self.assertEquals('localhost', requested[0][0]) self.assertEqual('localhost', requested[0][0])
self.assertEquals(6093, requested[0][1]) self.assertEqual(6093, requested[0][1])
x11_server.send('hello') x11_server.send('hello')
self.assertEquals('hello', x11_client.recv(5)) self.assertEqual(b'hello', x11_client.recv(5))
x11_server.close() x11_server.close()
x11_client.close() x11_client.close()
@ -479,13 +480,13 @@ class TransportTest(ParamikoTest):
schan = self.ts.accept(1.0) schan = self.ts.accept(1.0)
requested = [] requested = []
def handler(c, (origin_addr, origin_port), (server_addr, server_port)): def handler(c, origin_addr_port, server_addr_port):
requested.append((origin_addr, origin_port)) requested.append(origin_addr_port)
requested.append((server_addr, server_port)) requested.append(server_addr_port)
self.tc._queue_incoming_channel(c) self.tc._queue_incoming_channel(c)
port = self.tc.request_port_forward('127.0.0.1', 0, handler) port = self.tc.request_port_forward('127.0.0.1', 0, handler)
self.assertEquals(port, self.server._listen.getsockname()[1]) self.assertEqual(port, self.server._listen.getsockname()[1])
cs = socket.socket() cs = socket.socket()
cs.connect(('127.0.0.1', port)) cs.connect(('127.0.0.1', port))
@ -494,7 +495,7 @@ class TransportTest(ParamikoTest):
cch = self.tc.accept() cch = self.tc.accept()
sch.send('hello') sch.send('hello')
self.assertEquals('hello', cch.recv(5)) self.assertEqual(b'hello', cch.recv(5))
sch.close() sch.close()
cch.close() cch.close()
ss.close() ss.close()
@ -526,12 +527,12 @@ class TransportTest(ParamikoTest):
cch.connect(self.server._tcpip_dest) cch.connect(self.server._tcpip_dest)
ss, _ = greeting_server.accept() ss, _ = greeting_server.accept()
ss.send('Hello!\n') ss.send(b'Hello!\n')
ss.close() ss.close()
sch.send(cch.recv(8192)) sch.send(cch.recv(8192))
sch.close() sch.close()
self.assertEquals('Hello!\n', cs.recv(7)) self.assertEqual(b'Hello!\n', cs.recv(7))
cs.close() cs.close()
def test_G_stderr_select(self): def test_G_stderr_select(self):
@ -546,9 +547,9 @@ class TransportTest(ParamikoTest):
# nothing should be ready # nothing should be ready
r, w, e = select.select([chan], [], [], 0.1) r, w, e = select.select([chan], [], [], 0.1)
self.assertEquals([], r) self.assertEqual([], r)
self.assertEquals([], w) self.assertEqual([], w)
self.assertEquals([], e) self.assertEqual([], e)
schan.send_stderr('hello\n') schan.send_stderr('hello\n')
@ -558,17 +559,17 @@ class TransportTest(ParamikoTest):
if chan in r: if chan in r:
break break
time.sleep(0.1) time.sleep(0.1)
self.assertEquals([chan], r) self.assertEqual([chan], r)
self.assertEquals([], w) self.assertEqual([], w)
self.assertEquals([], e) self.assertEqual([], e)
self.assertEquals('hello\n', chan.recv_stderr(6)) self.assertEqual(b'hello\n', chan.recv_stderr(6))
# and, should be dead again now # and, should be dead again now
r, w, e = select.select([chan], [], [], 0.1) r, w, e = select.select([chan], [], [], 0.1)
self.assertEquals([], r) self.assertEqual([], r)
self.assertEquals([], w) self.assertEqual([], w)
self.assertEquals([], e) self.assertEqual([], e)
schan.close() schan.close()
chan.close() chan.close()
@ -582,7 +583,7 @@ class TransportTest(ParamikoTest):
chan.invoke_shell() chan.invoke_shell()
schan = self.ts.accept(1.0) schan = self.ts.accept(1.0)
self.assertEquals(chan.send_ready(), True) self.assertEqual(chan.send_ready(), True)
total = 0 total = 0
K = '*' * 1024 K = '*' * 1024
while total < 1024 * 1024: while total < 1024 * 1024:
@ -590,11 +591,11 @@ class TransportTest(ParamikoTest):
total += len(K) total += len(K)
if not chan.send_ready(): if not chan.send_ready():
break break
self.assert_(total < 1024 * 1024) self.assertTrue(total < 1024 * 1024)
schan.close() schan.close()
chan.close() chan.close()
self.assertEquals(chan.send_ready(), True) self.assertEqual(chan.send_ready(), True)
def test_I_rekey_deadlock(self): def test_I_rekey_deadlock(self):
""" """
@ -657,7 +658,7 @@ class TransportTest(ParamikoTest):
def run(self): def run(self):
try: try:
for i in xrange(1, 1+self.iterations): for i in range(1, 1+self.iterations):
if self.done_event.isSet(): if self.done_event.isSet():
break break
self.watchdog_event.set() self.watchdog_event.set()
@ -706,7 +707,7 @@ class TransportTest(ParamikoTest):
# Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it # Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it
# before responding to the incoming MSG_KEXINIT. # before responding to the incoming MSG_KEXINIT.
m2 = Message() m2 = Message()
m2.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST)) m2.add_byte(cMSG_CHANNEL_WINDOW_ADJUST)
m2.add_int(chan.remote_chanid) m2.add_int(chan.remote_chanid)
m2.add_int(1) # bytes to add m2.add_int(1) # bytes to add
self._send_message(m2) self._send_message(m2)

View File

@ -21,15 +21,15 @@ Some unit tests for utility functions.
""" """
from binascii import hexlify from binascii import hexlify
import cStringIO
import errno import errno
import os import os
import unittest import unittest
from Crypto.Hash import SHA from Crypto.Hash import SHA
import paramiko.util import paramiko.util
from paramiko.util import lookup_ssh_host_config as host_config from paramiko.util import lookup_ssh_host_config as host_config
from paramiko.py3compat import StringIO, byte_ord, b
from util import ParamikoTest from tests.util import ParamikoTest
test_config_file = """\ test_config_file = """\
Host * Host *
@ -65,7 +65,7 @@ class UtilTest(ParamikoTest):
""" """
verify that all the classes can be imported from paramiko. verify that all the classes can be imported from paramiko.
""" """
symbols = globals().keys() symbols = list(globals().keys())
self.assertTrue('Transport' in symbols) self.assertTrue('Transport' in symbols)
self.assertTrue('SSHClient' in symbols) self.assertTrue('SSHClient' in symbols)
self.assertTrue('MissingHostKeyPolicy' in symbols) self.assertTrue('MissingHostKeyPolicy' in symbols)
@ -101,9 +101,9 @@ class UtilTest(ParamikoTest):
def test_2_parse_config(self): def test_2_parse_config(self):
global test_config_file global test_config_file
f = cStringIO.StringIO(test_config_file) f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f) config = paramiko.util.parse_ssh_config(f)
self.assertEquals(config._config, self.assertEqual(config._config,
[{'host': ['*'], 'config': {}}, {'host': ['*'], 'config': {'identityfile': ['~/.ssh/id_rsa'], 'user': 'robey'}}, [{'host': ['*'], 'config': {}}, {'host': ['*'], 'config': {'identityfile': ['~/.ssh/id_rsa'], 'user': 'robey'}},
{'host': ['*.example.com'], 'config': {'user': 'bjork', 'port': '3333'}}, {'host': ['*.example.com'], 'config': {'user': 'bjork', 'port': '3333'}},
{'host': ['*'], 'config': {'crazy': 'something dumb '}}, {'host': ['*'], 'config': {'crazy': 'something dumb '}},
@ -111,7 +111,7 @@ class UtilTest(ParamikoTest):
def test_3_host_config(self): def test_3_host_config(self):
global test_config_file global test_config_file
f = cStringIO.StringIO(test_config_file) f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f) config = paramiko.util.parse_ssh_config(f)
for host, values in { for host, values in {
@ -131,27 +131,26 @@ class UtilTest(ParamikoTest):
hostname=host, hostname=host,
identityfile=[os.path.expanduser("~/.ssh/id_rsa")] identityfile=[os.path.expanduser("~/.ssh/id_rsa")]
) )
self.assertEquals( self.assertEqual(
paramiko.util.lookup_ssh_host_config(host, config), paramiko.util.lookup_ssh_host_config(host, config),
values values
) )
def test_4_generate_key_bytes(self): def test_4_generate_key_bytes(self):
x = paramiko.util.generate_key_bytes(SHA, 'ABCDEFGH', 'This is my secret passphrase.', 64) x = paramiko.util.generate_key_bytes(SHA, b'ABCDEFGH', 'This is my secret passphrase.', 64)
hex = ''.join(['%02x' % ord(c) for c in x]) hex = ''.join(['%02x' % byte_ord(c) for c in x])
self.assertEquals(hex, '9110e2f6793b69363e58173e9436b13a5a4b339005741d5c680e505f57d871347b4239f14fb5c46e857d5e100424873ba849ac699cea98d729e57b3e84378e8b') self.assertEqual(hex, '9110e2f6793b69363e58173e9436b13a5a4b339005741d5c680e505f57d871347b4239f14fb5c46e857d5e100424873ba849ac699cea98d729e57b3e84378e8b')
def test_5_host_keys(self): def test_5_host_keys(self):
f = open('hostfile.temp', 'w') with open('hostfile.temp', 'w') as f:
f.write(test_hosts_file) f.write(test_hosts_file)
f.close()
try: try:
hostdict = paramiko.util.load_host_keys('hostfile.temp') hostdict = paramiko.util.load_host_keys('hostfile.temp')
self.assertEquals(2, len(hostdict)) self.assertEqual(2, len(hostdict))
self.assertEquals(1, len(hostdict.values()[0])) self.assertEqual(1, len(list(hostdict.values())[0]))
self.assertEquals(1, len(hostdict.values()[1])) self.assertEqual(1, len(list(hostdict.values())[1]))
fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper()
self.assertEquals('E6684DB30E109B67B70FF1DC5C7F1363', fp) self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp)
finally: finally:
os.unlink('hostfile.temp') os.unlink('hostfile.temp')
@ -159,7 +158,7 @@ class UtilTest(ParamikoTest):
from paramiko.common import rng from paramiko.common import rng
# just verify that we can pull out 32 bytes and not get an exception. # just verify that we can pull out 32 bytes and not get an exception.
x = rng.read(32) x = rng.read(32)
self.assertEquals(len(x), 32) self.assertEqual(len(x), 32)
def test_7_host_config_expose_issue_33(self): def test_7_host_config_expose_issue_33(self):
test_config_file = """ test_config_file = """
@ -172,16 +171,16 @@ Host *.example.com
Host * Host *
Port 3333 Port 3333
""" """
f = cStringIO.StringIO(test_config_file) f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f) config = paramiko.util.parse_ssh_config(f)
host = 'www13.example.com' host = 'www13.example.com'
self.assertEquals( self.assertEqual(
paramiko.util.lookup_ssh_host_config(host, config), paramiko.util.lookup_ssh_host_config(host, config),
{'hostname': host, 'port': '22'} {'hostname': host, 'port': '22'}
) )
def test_8_eintr_retry(self): def test_8_eintr_retry(self):
self.assertEquals('foo', paramiko.util.retry_on_signal(lambda: 'foo')) self.assertEqual('foo', paramiko.util.retry_on_signal(lambda: 'foo'))
# Variables that are set by raises_intr # Variables that are set by raises_intr
intr_errors_remaining = [3] intr_errors_remaining = [3]
@ -192,8 +191,8 @@ Host *
intr_errors_remaining[0] -= 1 intr_errors_remaining[0] -= 1
raise IOError(errno.EINTR, 'file', 'interrupted system call') raise IOError(errno.EINTR, 'file', 'interrupted system call')
self.assertTrue(paramiko.util.retry_on_signal(raises_intr) is None) self.assertTrue(paramiko.util.retry_on_signal(raises_intr) is None)
self.assertEquals(0, intr_errors_remaining[0]) self.assertEqual(0, intr_errors_remaining[0])
self.assertEquals(4, call_count[0]) self.assertEqual(4, call_count[0])
def raises_ioerror_not_eintr(): def raises_ioerror_not_eintr():
raise IOError(errno.ENOENT, 'file', 'file not found') raise IOError(errno.ENOENT, 'file', 'file not found')
@ -216,10 +215,10 @@ Host space-delimited
Host equals-delimited Host equals-delimited
ProxyCommand=foo bar=biz baz ProxyCommand=foo bar=biz baz
""" """
f = cStringIO.StringIO(conf) f = StringIO(conf)
config = paramiko.util.parse_ssh_config(f) config = paramiko.util.parse_ssh_config(f)
for host in ('space-delimited', 'equals-delimited'): for host in ('space-delimited', 'equals-delimited'):
self.assertEquals( self.assertEqual(
host_config(host, config)['proxycommand'], host_config(host, config)['proxycommand'],
'foo bar=biz baz' 'foo bar=biz baz'
) )
@ -228,7 +227,7 @@ Host equals-delimited
""" """
ProxyCommand should perform interpolation on the value ProxyCommand should perform interpolation on the value
""" """
config = paramiko.util.parse_ssh_config(cStringIO.StringIO(""" config = paramiko.util.parse_ssh_config(StringIO("""
Host specific Host specific
Port 37 Port 37
ProxyCommand host %h port %p lol ProxyCommand host %h port %p lol
@ -245,7 +244,7 @@ Host *
('specific', "host specific port 37 lol"), ('specific', "host specific port 37 lol"),
('portonly', "host portonly port 155"), ('portonly', "host portonly port 155"),
): ):
self.assertEquals( self.assertEqual(
host_config(host, config)['proxycommand'], host_config(host, config)['proxycommand'],
val val
) )
@ -264,10 +263,10 @@ Host www13.*
Host * Host *
Port 3333 Port 3333
""" """
f = cStringIO.StringIO(test_config_file) f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f) config = paramiko.util.parse_ssh_config(f)
host = 'www13.example.com' host = 'www13.example.com'
self.assertEquals( self.assertEqual(
paramiko.util.lookup_ssh_host_config(host, config), paramiko.util.lookup_ssh_host_config(host, config),
{'hostname': host, 'port': '8080'} {'hostname': host, 'port': '8080'}
) )
@ -293,9 +292,9 @@ ProxyCommand foo=bar:%h-%p
'foo=bar:proxy-without-equal-divisor-22'} 'foo=bar:proxy-without-equal-divisor-22'}
}.items(): }.items():
f = cStringIO.StringIO(test_config_file) f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f) config = paramiko.util.parse_ssh_config(f)
self.assertEquals( self.assertEqual(
paramiko.util.lookup_ssh_host_config(host, config), paramiko.util.lookup_ssh_host_config(host, config),
values values
) )
@ -323,9 +322,9 @@ IdentityFile id_dsa22
'identityfile': ['id_dsa0', 'id_dsa1', 'id_dsa22']} 'identityfile': ['id_dsa0', 'id_dsa1', 'id_dsa22']}
}.items(): }.items():
f = cStringIO.StringIO(test_config_file) f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f) config = paramiko.util.parse_ssh_config(f)
self.assertEquals( self.assertEqual(
paramiko.util.lookup_ssh_host_config(host, config), paramiko.util.lookup_ssh_host_config(host, config),
values values
) )
@ -338,5 +337,5 @@ IdentityFile id_dsa22
AddressFamily inet AddressFamily inet
IdentityFile something_%l_using_fqdn IdentityFile something_%l_using_fqdn
""" """
config = paramiko.util.parse_ssh_config(cStringIO.StringIO(test_config)) config = paramiko.util.parse_ssh_config(StringIO(test_config))
assert config.lookup('meh') # will die during lookup() if bug regresses assert config.lookup('meh') # will die during lookup() if bug regresses

View File

@ -1,5 +1,8 @@
import os
import unittest import unittest
root_path = os.path.dirname(os.path.realpath(__file__))
class ParamikoTest(unittest.TestCase): class ParamikoTest(unittest.TestCase):
# for Python 2.3 and below # for Python 2.3 and below
@ -8,3 +11,7 @@ class ParamikoTest(unittest.TestCase):
if not hasattr(unittest.TestCase, 'assertFalse'): if not hasattr(unittest.TestCase, 'assertFalse'):
assertFalse = unittest.TestCase.failIf assertFalse = unittest.TestCase.failIf
def test_path(filename):
return os.path.join(root_path, filename)

View File

@ -1,5 +1,5 @@
[tox] [tox]
envlist = py25,py26,py27 envlist = py25,py26,py27,py32,py33
[testenv] [testenv]
commands = pip install --use-mirrors -q -r tox-requirements.txt commands = pip install --use-mirrors -q -r tox-requirements.txt