Merge remote-tracking branch 'scottkmaxwell/py3-support-without-py25' into python3

Conflicts:
	dev-requirements.txt
	paramiko/__init__.py
	paramiko/file.py
	paramiko/hostkeys.py
	paramiko/message.py
	paramiko/proxy.py
	paramiko/server.py
	paramiko/transport.py
	paramiko/util.py
	paramiko/win_pageant.py
	setup.py
This commit is contained in:
Jeff Forcier 2014-03-05 17:03:37 -08:00
commit b2be63ec62
64 changed files with 1970 additions and 1593 deletions

View File

@ -2,6 +2,8 @@ language: python
python:
- "2.6"
- "2.7"
- "3.2"
- "3.3"
install:
# Self-install for setup.py-driven deps
- pip install -e .

4
README
View File

@ -15,7 +15,7 @@ What
----
"paramiko" is a combination of the esperanto words for "paranoid" and
"friend". it's a module for python 2.5+ that implements the SSH2 protocol
"friend". it's a module for python 2.6+ that implements the SSH2 protocol
for secure (encrypted and authenticated) connections to remote machines.
unlike SSL (aka TLS), SSH2 protocol does not require hierarchical
certificates signed by a powerful central authority. you may know SSH2 as
@ -34,7 +34,7 @@ that should have come with this archive.
Requirements
------------
- python 2.5 or better <http://www.python.org/>
- python 2.6 or better <http://www.python.org/>
- pycrypto 2.1 or better <https://www.dlitz.net/software/pycrypto/>
- ecdsa 0.9 or better <https://pypi.python.org/pypi/ecdsa>

View File

@ -28,9 +28,13 @@ import socket
import sys
import time
import traceback
from paramiko.py3compat import input
import paramiko
import interactive
try:
import interactive
except ImportError:
from . import interactive
def agent_auth(transport, username):
@ -45,24 +49,24 @@ def agent_auth(transport, username):
return
for key in agent_keys:
print 'Trying ssh-agent key %s' % hexlify(key.get_fingerprint()),
print('Trying ssh-agent key %s' % hexlify(key.get_fingerprint()))
try:
transport.auth_publickey(username, key)
print '... success!'
print('... success!')
return
except paramiko.SSHException:
print '... nope.'
print('... nope.')
def manual_auth(username, hostname):
default_auth = 'p'
auth = raw_input('Auth by (p)assword, (r)sa key, or (d)ss key? [%s] ' % default_auth)
auth = input('Auth by (p)assword, (r)sa key, or (d)ss key? [%s] ' % default_auth)
if len(auth) == 0:
auth = default_auth
if auth == 'r':
default_path = os.path.join(os.environ['HOME'], '.ssh', 'id_rsa')
path = raw_input('RSA key [%s]: ' % default_path)
path = input('RSA key [%s]: ' % default_path)
if len(path) == 0:
path = default_path
try:
@ -73,7 +77,7 @@ def manual_auth(username, hostname):
t.auth_publickey(username, key)
elif auth == 'd':
default_path = os.path.join(os.environ['HOME'], '.ssh', 'id_dsa')
path = raw_input('DSS key [%s]: ' % default_path)
path = input('DSS key [%s]: ' % default_path)
if len(path) == 0:
path = default_path
try:
@ -96,9 +100,9 @@ if len(sys.argv) > 1:
if hostname.find('@') >= 0:
username, hostname = hostname.split('@')
else:
hostname = raw_input('Hostname: ')
hostname = input('Hostname: ')
if len(hostname) == 0:
print '*** Hostname required.'
print('*** Hostname required.')
sys.exit(1)
port = 22
if hostname.find(':') >= 0:
@ -109,8 +113,8 @@ if hostname.find(':') >= 0:
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((hostname, port))
except Exception, e:
print '*** Connect failed: ' + str(e)
except Exception as e:
print('*** Connect failed: ' + str(e))
traceback.print_exc()
sys.exit(1)
@ -119,7 +123,7 @@ try:
try:
t.start_client()
except paramiko.SSHException:
print '*** SSH negotiation failed.'
print('*** SSH negotiation failed.')
sys.exit(1)
try:
@ -128,25 +132,25 @@ try:
try:
keys = paramiko.util.load_host_keys(os.path.expanduser('~/ssh/known_hosts'))
except IOError:
print '*** Unable to open host keys file'
print('*** Unable to open host keys file')
keys = {}
# check server's host key -- this is important.
key = t.get_remote_server_key()
if not keys.has_key(hostname):
print '*** WARNING: Unknown host key!'
elif not keys[hostname].has_key(key.get_name()):
print '*** WARNING: Unknown host key!'
if hostname not in keys:
print('*** WARNING: Unknown host key!')
elif key.get_name() not in keys[hostname]:
print('*** WARNING: Unknown host key!')
elif keys[hostname][key.get_name()] != key:
print '*** WARNING: Host key has changed!!!'
print('*** WARNING: Host key has changed!!!')
sys.exit(1)
else:
print '*** Host key OK.'
print('*** Host key OK.')
# get username
if username == '':
default_username = getpass.getuser()
username = raw_input('Username [%s]: ' % default_username)
username = input('Username [%s]: ' % default_username)
if len(username) == 0:
username = default_username
@ -154,21 +158,20 @@ try:
if not t.is_authenticated():
manual_auth(username, hostname)
if not t.is_authenticated():
print '*** Authentication failed. :('
print('*** Authentication failed. :(')
t.close()
sys.exit(1)
chan = t.open_session()
chan.get_pty()
chan.invoke_shell()
print '*** Here we go!'
print
print('*** Here we go!\n')
interactive.interactive_shell(chan)
chan.close()
t.close()
except Exception, e:
print '*** Caught exception: ' + str(e.__class__) + ': ' + str(e)
except Exception as e:
print('*** Caught exception: ' + str(e.__class__) + ': ' + str(e))
traceback.print_exc()
try:
t.close()

View File

@ -17,9 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
from __future__ import with_statement
import string
import sys
from binascii import hexlify
@ -28,6 +26,7 @@ from optparse import OptionParser
from paramiko import DSSKey
from paramiko import RSAKey
from paramiko.ssh_exception import SSHException
from paramiko.py3compat import u
usage="""
%prog [-v] [-b bits] -t type [-N new_passphrase] [-f output_keyfile]"""
@ -47,16 +46,16 @@ key_dispatch_table = {
def progress(arg=None):
if not arg:
print '0%\x08\x08\x08',
sys.stdout.write('0%\x08\x08\x08 ')
sys.stdout.flush()
elif arg[0] == 'p':
print '25%\x08\x08\x08\x08',
sys.stdout.write('25%\x08\x08\x08\x08 ')
sys.stdout.flush()
elif arg[0] == 'h':
print '50%\x08\x08\x08\x08',
sys.stdout.write('50%\x08\x08\x08\x08 ')
sys.stdout.flush()
elif arg[0] == 'x':
print '75%\x08\x08\x08\x08',
sys.stdout.write('75%\x08\x08\x08\x08 ')
sys.stdout.flush()
if __name__ == '__main__':
@ -92,8 +91,8 @@ if __name__ == '__main__':
parser.print_help()
sys.exit(0)
for o in default_values.keys():
globals()[o] = getattr(options, o, default_values[string.lower(o)])
for o in list(default_values.keys()):
globals()[o] = getattr(options, o, default_values[o.lower()])
if options.newphrase:
phrase = getattr(options, 'newphrase')
@ -106,7 +105,7 @@ if __name__ == '__main__':
if ktype == 'dsa' and bits > 1024:
raise SSHException("DSA Keys must be 1024 bits")
if not key_dispatch_table.has_key(ktype):
if ktype not in key_dispatch_table:
raise SSHException("Unknown %s algorithm to generate keys pair" % ktype)
# generating private key
@ -121,7 +120,7 @@ if __name__ == '__main__':
f.write(" %s" % comment)
if options.verbose:
print "done."
print("done.")
hash = hexlify(pub.get_fingerprint())
print "Fingerprint: %d %s %s.pub (%s)" % (bits, ":".join([ hash[i:2+i] for i in range(0, len(hash), 2)]), filename, string.upper(ktype))
hash = u(hexlify(pub.get_fingerprint()))
print("Fingerprint: %d %s %s.pub (%s)" % (bits, ":".join([ hash[i:2+i] for i in range(0, len(hash), 2)]), filename, ktype.upper()))

View File

@ -27,6 +27,7 @@ import threading
import traceback
import paramiko
from paramiko.py3compat import b, u, decodebytes
# setup logging
@ -35,17 +36,17 @@ paramiko.util.log_to_file('demo_server.log')
host_key = paramiko.RSAKey(filename='test_rsa.key')
#host_key = paramiko.DSSKey(filename='test_dss.key')
print 'Read key: ' + hexlify(host_key.get_fingerprint())
print('Read key: ' + u(hexlify(host_key.get_fingerprint())))
class Server (paramiko.ServerInterface):
# 'data' is the output of base64.encodestring(str(key))
# (using the "user_rsa_key" files)
data = 'AAAAB3NzaC1yc2EAAAABIwAAAIEAyO4it3fHlmGZWJaGrfeHOVY7RWO3P9M7hp' + \
'fAu7jJ2d7eothvfeuoRFtJwhUmZDluRdFyhFY/hFAh76PJKGAusIqIQKlkJxMC' + \
'KDqIexkgHAfID/6mqvmnSJf0b5W8v5h2pI/stOSwTQ+pxVhwJ9ctYDhRSlF0iT' + \
'UWT10hcuO4Ks8='
good_pub_key = paramiko.RSAKey(data=base64.decodestring(data))
data = (b'AAAAB3NzaC1yc2EAAAABIwAAAIEAyO4it3fHlmGZWJaGrfeHOVY7RWO3P9M7hp'
b'fAu7jJ2d7eothvfeuoRFtJwhUmZDluRdFyhFY/hFAh76PJKGAusIqIQKlkJxMC'
b'KDqIexkgHAfID/6mqvmnSJf0b5W8v5h2pI/stOSwTQ+pxVhwJ9ctYDhRSlF0iT'
b'UWT10hcuO4Ks8=')
good_pub_key = paramiko.RSAKey(data=decodebytes(data))
def __init__(self):
self.event = threading.Event()
@ -61,7 +62,7 @@ class Server (paramiko.ServerInterface):
return paramiko.AUTH_FAILED
def check_auth_publickey(self, username, key):
print 'Auth attempt with key: ' + hexlify(key.get_fingerprint())
print('Auth attempt with key: ' + u(hexlify(key.get_fingerprint())))
if (username == 'robey') and (key == self.good_pub_key):
return paramiko.AUTH_SUCCESSFUL
return paramiko.AUTH_FAILED
@ -83,47 +84,47 @@ try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(('', 2200))
except Exception, e:
print '*** Bind failed: ' + str(e)
except Exception as e:
print('*** Bind failed: ' + str(e))
traceback.print_exc()
sys.exit(1)
try:
sock.listen(100)
print 'Listening for connection ...'
print('Listening for connection ...')
client, addr = sock.accept()
except Exception, e:
print '*** Listen/accept failed: ' + str(e)
except Exception as e:
print('*** Listen/accept failed: ' + str(e))
traceback.print_exc()
sys.exit(1)
print 'Got a connection!'
print('Got a connection!')
try:
t = paramiko.Transport(client)
try:
t.load_server_moduli()
except:
print '(Failed to load moduli -- gex will be unsupported.)'
print('(Failed to load moduli -- gex will be unsupported.)')
raise
t.add_server_key(host_key)
server = Server()
try:
t.start_server(server=server)
except paramiko.SSHException, x:
print '*** SSH negotiation failed.'
except paramiko.SSHException:
print('*** SSH negotiation failed.')
sys.exit(1)
# wait for auth
chan = t.accept(20)
if chan is None:
print '*** No channel.'
print('*** No channel.')
sys.exit(1)
print 'Authenticated!'
print('Authenticated!')
server.event.wait(10)
if not server.event.isSet():
print '*** Client never asked for a shell.'
print('*** Client never asked for a shell.')
sys.exit(1)
chan.send('\r\n\r\nWelcome to my dorky little BBS!\r\n\r\n')
@ -135,8 +136,8 @@ try:
chan.send('\r\nI don\'t like you, ' + username + '.\r\n')
chan.close()
except Exception, e:
print '*** Caught exception: ' + str(e.__class__) + ': ' + str(e)
except Exception as e:
print('*** Caught exception: ' + str(e.__class__) + ': ' + str(e))
traceback.print_exc()
try:
t.close()

View File

@ -28,6 +28,7 @@ import sys
import traceback
import paramiko
from paramiko.py3compat import input
# setup logging
@ -40,9 +41,9 @@ if len(sys.argv) > 1:
if hostname.find('@') >= 0:
username, hostname = hostname.split('@')
else:
hostname = raw_input('Hostname: ')
hostname = input('Hostname: ')
if len(hostname) == 0:
print '*** Hostname required.'
print('*** Hostname required.')
sys.exit(1)
port = 22
if hostname.find(':') >= 0:
@ -53,7 +54,7 @@ if hostname.find(':') >= 0:
# get username
if username == '':
default_username = getpass.getuser()
username = raw_input('Username [%s]: ' % default_username)
username = input('Username [%s]: ' % default_username)
if len(username) == 0:
username = default_username
password = getpass.getpass('Password for %s@%s: ' % (username, hostname))
@ -69,13 +70,13 @@ except IOError:
# try ~/ssh/ too, because windows can't have a folder named ~/.ssh/
host_keys = paramiko.util.load_host_keys(os.path.expanduser('~/ssh/known_hosts'))
except IOError:
print '*** Unable to open host keys file'
print('*** Unable to open host keys file')
host_keys = {}
if host_keys.has_key(hostname):
if hostname in host_keys:
hostkeytype = host_keys[hostname].keys()[0]
hostkey = host_keys[hostname][hostkeytype]
print 'Using host key of type %s' % hostkeytype
print('Using host key of type %s' % hostkeytype)
# now, connect and use paramiko Transport to negotiate SSH2 across the connection
@ -86,22 +87,26 @@ try:
# dirlist on remote host
dirlist = sftp.listdir('.')
print "Dirlist:", dirlist
print("Dirlist: %s" % dirlist)
# copy this demo onto the server
try:
sftp.mkdir("demo_sftp_folder")
except IOError:
print '(assuming demo_sftp_folder/ already exists)'
sftp.open('demo_sftp_folder/README', 'w').write('This was created by demo_sftp.py.\n')
data = open('demo_sftp.py', 'r').read()
print('(assuming demo_sftp_folder/ already exists)')
with sftp.open('demo_sftp_folder/README', 'w') as f:
f.write('This was created by demo_sftp.py.\n')
with open('demo_sftp.py', 'r') as f:
data = f.read()
sftp.open('demo_sftp_folder/demo_sftp.py', 'w').write(data)
print 'created demo_sftp_folder/ on the server'
print('created demo_sftp_folder/ on the server')
# copy the README back here
data = sftp.open('demo_sftp_folder/README', 'r').read()
open('README_demo_sftp', 'w').write(data)
print 'copied README back here'
with sftp.open('demo_sftp_folder/README', 'r') as f:
data = f.read()
with open('README_demo_sftp', 'w') as f:
f.write(data)
print('copied README back here')
# BETTER: use the get() and put() methods
sftp.put('demo_sftp.py', 'demo_sftp_folder/demo_sftp.py')
@ -109,8 +114,8 @@ try:
t.close()
except Exception, e:
print '*** Caught exception: %s: %s' % (e.__class__, e)
except Exception as e:
print('*** Caught exception: %s: %s' % (e.__class__, e))
traceback.print_exc()
try:
t.close()

View File

@ -25,9 +25,13 @@ import os
import socket
import sys
import traceback
from paramiko.py3compat import input
import paramiko
import interactive
try:
import interactive
except ImportError:
from . import interactive
# setup logging
@ -40,9 +44,9 @@ if len(sys.argv) > 1:
if hostname.find('@') >= 0:
username, hostname = hostname.split('@')
else:
hostname = raw_input('Hostname: ')
hostname = input('Hostname: ')
if len(hostname) == 0:
print '*** Hostname required.'
print('*** Hostname required.')
sys.exit(1)
port = 22
if hostname.find(':') >= 0:
@ -53,7 +57,7 @@ if hostname.find(':') >= 0:
# get username
if username == '':
default_username = getpass.getuser()
username = raw_input('Username [%s]: ' % default_username)
username = input('Username [%s]: ' % default_username)
if len(username) == 0:
username = default_username
password = getpass.getpass('Password for %s@%s: ' % (username, hostname))
@ -64,18 +68,17 @@ try:
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.WarningPolicy())
print '*** Connecting...'
print('*** Connecting...')
client.connect(hostname, port, username, password)
chan = client.invoke_shell()
print repr(client.get_transport())
print '*** Here we go!'
print
print(repr(client.get_transport()))
print('*** Here we go!\n')
interactive.interactive_shell(chan)
chan.close()
client.close()
except Exception, e:
print '*** Caught exception: %s: %s' % (e.__class__, e)
except Exception as e:
print('*** Caught exception: %s: %s' % (e.__class__, e))
traceback.print_exc()
try:
client.close()

View File

@ -30,7 +30,11 @@ import getpass
import os
import socket
import select
import SocketServer
try:
import SocketServer
except ImportError:
import socketserver as SocketServer
import sys
from optparse import OptionParser
@ -54,7 +58,7 @@ class Handler (SocketServer.BaseRequestHandler):
chan = self.ssh_transport.open_channel('direct-tcpip',
(self.chain_host, self.chain_port),
self.request.getpeername())
except Exception, e:
except Exception as e:
verbose('Incoming request to %s:%d failed: %s' % (self.chain_host,
self.chain_port,
repr(e)))
@ -98,7 +102,7 @@ def forward_tunnel(local_port, remote_host, remote_port, transport):
def verbose(s):
if g_verbose:
print s
print(s)
HELP = """\
@ -165,8 +169,8 @@ def main():
try:
client.connect(server[0], server[1], username=options.user, key_filename=options.keyfile,
look_for_keys=options.look_for_keys, password=password)
except Exception, e:
print '*** Failed to connect to %s:%d: %r' % (server[0], server[1], e)
except Exception as e:
print('*** Failed to connect to %s:%d: %r' % (server[0], server[1], e))
sys.exit(1)
verbose('Now forwarding port %d to %s:%d ...' % (options.port, remote[0], remote[1]))
@ -174,7 +178,7 @@ def main():
try:
forward_tunnel(options.port, remote[0], remote[1], client.get_transport())
except KeyboardInterrupt:
print 'C-c: Port forwarding stopped.'
print('C-c: Port forwarding stopped.')
sys.exit(0)

View File

@ -19,6 +19,7 @@
import socket
import sys
from paramiko.py3compat import u
# windows does not have termios...
try:
@ -49,9 +50,9 @@ def posix_shell(chan):
r, w, e = select.select([chan, sys.stdin], [], [])
if chan in r:
try:
x = chan.recv(1024)
x = u(chan.recv(1024))
if len(x) == 0:
print '\r\n*** EOF\r\n',
sys.stdout.write('\r\n*** EOF\r\n')
break
sys.stdout.write(x)
sys.stdout.flush()

View File

@ -46,7 +46,7 @@ def handler(chan, host, port):
sock = socket.socket()
try:
sock.connect((host, port))
except Exception, e:
except Exception as e:
verbose('Forwarding request to %s:%d failed: %r' % (host, port, e))
return
@ -82,7 +82,7 @@ def reverse_forward_tunnel(server_port, remote_host, remote_port, transport):
def verbose(s):
if g_verbose:
print s
print(s)
HELP = """\
@ -150,8 +150,8 @@ def main():
try:
client.connect(server[0], server[1], username=options.user, key_filename=options.keyfile,
look_for_keys=options.look_for_keys, password=password)
except Exception, e:
print '*** Failed to connect to %s:%d: %r' % (server[0], server[1], e)
except Exception as e:
print('*** Failed to connect to %s:%d: %r' % (server[0], server[1], e))
sys.exit(1)
verbose('Now forwarding remote port %d to %s:%d ...' % (options.port, remote[0], remote[1]))
@ -159,7 +159,7 @@ def main():
try:
reverse_forward_tunnel(options.port, remote[0], remote[1], client.get_transport())
except KeyboardInterrupt:
print 'C-c: Port forwarding stopped.'
print('C-c: Port forwarding stopped.')
sys.exit(0)

View File

@ -18,51 +18,51 @@
import sys
if sys.version_info < (2, 5):
raise RuntimeError('You need Python 2.5+ for this module.')
if sys.version_info < (2, 6):
raise RuntimeError('You need Python 2.6+ for this module.')
__author__ = "Jeff Forcier <jeff@bitprophet.org>"
__version__ = "1.12.2"
__version__ = "1.13.0"
__version_info__ = tuple([ int(d) for d in __version__.split(".") ])
__license__ = "GNU Lesser General Public License (LGPL)"
from transport import SecurityOptions, Transport
from client import SSHClient, MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy, WarningPolicy
from auth_handler import AuthHandler
from channel import Channel, ChannelFile
from ssh_exception import SSHException, PasswordRequiredException, \
from paramiko.transport import SecurityOptions, Transport
from paramiko.client import SSHClient, MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy, WarningPolicy
from paramiko.auth_handler import AuthHandler
from paramiko.channel import Channel, ChannelFile
from paramiko.ssh_exception import SSHException, PasswordRequiredException, \
BadAuthenticationType, ChannelException, BadHostKeyException, \
AuthenticationException, ProxyCommandFailure
from server import ServerInterface, SubsystemHandler, InteractiveQuery
from rsakey import RSAKey
from dsskey import DSSKey
from ecdsakey import ECDSAKey
from sftp import SFTPError, BaseSFTP
from sftp_client import SFTP, SFTPClient
from sftp_server import SFTPServer
from sftp_attr import SFTPAttributes
from sftp_handle import SFTPHandle
from sftp_si import SFTPServerInterface
from sftp_file import SFTPFile
from message import Message
from packet import Packetizer
from file import BufferedFile
from agent import Agent, AgentKey
from pkey import PKey
from hostkeys import HostKeys
from config import SSHConfig
from proxy import ProxyCommand
from paramiko.server import ServerInterface, SubsystemHandler, InteractiveQuery
from paramiko.rsakey import RSAKey
from paramiko.dsskey import DSSKey
from paramiko.ecdsakey import ECDSAKey
from paramiko.sftp import SFTPError, BaseSFTP
from paramiko.sftp_client import SFTP, SFTPClient
from paramiko.sftp_server import SFTPServer
from paramiko.sftp_attr import SFTPAttributes
from paramiko.sftp_handle import SFTPHandle
from paramiko.sftp_si import SFTPServerInterface
from paramiko.sftp_file import SFTPFile
from paramiko.message import Message
from paramiko.packet import Packetizer
from paramiko.file import BufferedFile
from paramiko.agent import Agent, AgentKey
from paramiko.pkey import PKey
from paramiko.hostkeys import HostKeys
from paramiko.config import SSHConfig
from paramiko.proxy import ProxyCommand
from common import AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED, \
from paramiko.common import AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED, \
OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, OPEN_FAILED_CONNECT_FAILED, \
OPEN_FAILED_UNKNOWN_CHANNEL_TYPE, OPEN_FAILED_RESOURCE_SHORTAGE
from sftp import SFTP_OK, SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, SFTP_FAILURE, \
from paramiko.sftp import SFTP_OK, SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, SFTP_FAILURE, \
SFTP_BAD_MESSAGE, SFTP_NO_CONNECTION, SFTP_CONNECTION_LOST, SFTP_OP_UNSUPPORTED
from common import io_sleep
from paramiko.common import io_sleep
__all__ = [ 'Transport',
'SSHClient',

View File

@ -8,7 +8,11 @@ in jaraco.windows and asking the author to port the fixes back here.
import ctypes
import ctypes.wintypes
import __builtin__
from paramiko.py3compat import u
try:
import builtins
except ImportError:
import __builtin__ as builtins
try:
USHORT = ctypes.wintypes.USHORT
@ -40,7 +44,7 @@ def format_system_message(errno):
result_buffer = ctypes.wintypes.LPWSTR()
buffer_size = 0
arguments = None
bytes = ctypes.windll.kernel32.FormatMessageW(
format_bytes = ctypes.windll.kernel32.FormatMessageW(
flags,
source,
message_id,
@ -52,13 +56,13 @@ def format_system_message(errno):
# note the following will cause an infinite loop if GetLastError
# repeatedly returns an error that cannot be formatted, although
# this should not happen.
handle_nonzero_success(bytes)
handle_nonzero_success(format_bytes)
message = result_buffer.value
ctypes.windll.kernel32.LocalFree(result_buffer)
return message
class WindowsError(__builtin__.WindowsError):
class WindowsError(builtins.WindowsError):
"more info about errors at http://msdn.microsoft.com/en-us/library/ms681381(VS.85).aspx"
def __init__(self, value=None):
@ -120,7 +124,7 @@ class MemoryMap(object):
FILE_MAP_WRITE = 0x2
filemap = ctypes.windll.kernel32.CreateFileMappingW(
INVALID_HANDLE_VALUE, p_SA, PAGE_READWRITE, 0, self.length,
unicode(self.name))
u(self.name))
handle_nonzero_success(filemap)
if filemap == INVALID_HANDLE_VALUE:
raise Exception("Failed to create file mapping")

View File

@ -34,11 +34,14 @@ from paramiko.ssh_exception import SSHException
from paramiko.message import Message
from paramiko.pkey import PKey
from paramiko.channel import Channel
from paramiko.common import io_sleep
from paramiko.common import *
from paramiko.util import retry_on_signal
SSH2_AGENTC_REQUEST_IDENTITIES, SSH2_AGENT_IDENTITIES_ANSWER, \
SSH2_AGENTC_SIGN_REQUEST, SSH2_AGENT_SIGN_RESPONSE = range(11, 15)
cSSH2_AGENTC_REQUEST_IDENTITIES = byte_chr(11)
SSH2_AGENT_IDENTITIES_ANSWER = 12
cSSH2_AGENTC_SIGN_REQUEST = byte_chr(13)
SSH2_AGENT_SIGN_RESPONSE = 14
class AgentSSH(object):
@ -60,12 +63,12 @@ class AgentSSH(object):
def _connect(self, conn):
self._conn = conn
ptype, result = self._send_message(chr(SSH2_AGENTC_REQUEST_IDENTITIES))
ptype, result = self._send_message(cSSH2_AGENTC_REQUEST_IDENTITIES)
if ptype != SSH2_AGENT_IDENTITIES_ANSWER:
raise SSHException('could not get keys from ssh-agent')
keys = []
for i in range(result.get_int()):
keys.append(AgentKey(self, result.get_string()))
keys.append(AgentKey(self, result.get_binary()))
result.get_string()
self._keys = tuple(keys)
@ -75,7 +78,7 @@ class AgentSSH(object):
self._keys = ()
def _send_message(self, msg):
msg = str(msg)
msg = asbytes(msg)
self._conn.send(struct.pack('>I', len(msg)) + msg)
l = self._read_all(4)
msg = Message(self._read_all(struct.unpack('>I', l)[0]))
@ -212,7 +215,7 @@ class AgentClientProxy(object):
# probably a dangling env var: the ssh agent is gone
return
elif sys.platform == 'win32':
import win_pageant
import paramiko.win_pageant as win_pageant
if win_pageant.can_talk_to_agent():
conn = win_pageant.PageantConnection()
else:
@ -328,7 +331,7 @@ class Agent(AgentSSH):
# probably a dangling env var: the ssh agent is gone
return
elif sys.platform == 'win32':
import win_pageant
from . import win_pageant
if win_pageant.can_talk_to_agent():
conn = win_pageant.PageantConnection()
else:
@ -354,21 +357,24 @@ class AgentKey(PKey):
def __init__(self, agent, blob):
self.agent = agent
self.blob = blob
self.name = Message(blob).get_string()
self.name = Message(blob).get_text()
def asbytes(self):
return self.blob
def __str__(self):
return self.blob
return self.asbytes()
def get_name(self):
return self.name
def sign_ssh_data(self, rng, data):
msg = Message()
msg.add_byte(chr(SSH2_AGENTC_SIGN_REQUEST))
msg.add_byte(cSSH2_AGENTC_SIGN_REQUEST)
msg.add_string(self.blob)
msg.add_string(data)
msg.add_int(0)
ptype, result = self.agent._send_message(msg)
if ptype != SSH2_AGENT_SIGN_RESPONSE:
raise SSHException('key cannot be used for signing')
return result.get_string()
return result.get_binary()

View File

@ -120,13 +120,13 @@ class AuthHandler (object):
def _request_auth(self):
m = Message()
m.add_byte(chr(MSG_SERVICE_REQUEST))
m.add_byte(cMSG_SERVICE_REQUEST)
m.add_string('ssh-userauth')
self.transport._send_message(m)
def _disconnect_service_not_available(self):
m = Message()
m.add_byte(chr(MSG_DISCONNECT))
m.add_byte(cMSG_DISCONNECT)
m.add_int(DISCONNECT_SERVICE_NOT_AVAILABLE)
m.add_string('Service not available')
m.add_string('en')
@ -135,7 +135,7 @@ class AuthHandler (object):
def _disconnect_no_more_auth(self):
m = Message()
m.add_byte(chr(MSG_DISCONNECT))
m.add_byte(cMSG_DISCONNECT)
m.add_int(DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE)
m.add_string('No more auth methods available')
m.add_string('en')
@ -145,14 +145,14 @@ class AuthHandler (object):
def _get_session_blob(self, key, service, username):
m = Message()
m.add_string(self.transport.session_id)
m.add_byte(chr(MSG_USERAUTH_REQUEST))
m.add_byte(cMSG_USERAUTH_REQUEST)
m.add_string(username)
m.add_string(service)
m.add_string('publickey')
m.add_boolean(1)
m.add_string(key.get_name())
m.add_string(str(key))
return str(m)
m.add_string(key)
return m.asbytes()
def wait_for_response(self, event):
while True:
@ -176,11 +176,11 @@ class AuthHandler (object):
return []
def _parse_service_request(self, m):
service = m.get_string()
service = m.get_text()
if self.transport.server_mode and (service == 'ssh-userauth'):
# accepted
m = Message()
m.add_byte(chr(MSG_SERVICE_ACCEPT))
m.add_byte(cMSG_SERVICE_ACCEPT)
m.add_string(service)
self.transport._send_message(m)
return
@ -188,27 +188,25 @@ class AuthHandler (object):
self._disconnect_service_not_available()
def _parse_service_accept(self, m):
service = m.get_string()
service = m.get_text()
if service == 'ssh-userauth':
self.transport._log(DEBUG, 'userauth is OK')
m = Message()
m.add_byte(chr(MSG_USERAUTH_REQUEST))
m.add_byte(cMSG_USERAUTH_REQUEST)
m.add_string(self.username)
m.add_string('ssh-connection')
m.add_string(self.auth_method)
if self.auth_method == 'password':
m.add_boolean(False)
password = self.password
if isinstance(password, unicode):
password = password.encode('UTF-8')
password = bytestring(self.password)
m.add_string(password)
elif self.auth_method == 'publickey':
m.add_boolean(True)
m.add_string(self.private_key.get_name())
m.add_string(str(self.private_key))
m.add_string(self.private_key)
blob = self._get_session_blob(self.private_key, 'ssh-connection', self.username)
sig = self.private_key.sign_ssh_data(self.transport.rng, blob)
m.add_string(str(sig))
m.add_string(sig)
elif self.auth_method == 'keyboard-interactive':
m.add_string('')
m.add_string(self.submethods)
@ -225,11 +223,11 @@ class AuthHandler (object):
m = Message()
if result == AUTH_SUCCESSFUL:
self.transport._log(INFO, 'Auth granted (%s).' % method)
m.add_byte(chr(MSG_USERAUTH_SUCCESS))
m.add_byte(cMSG_USERAUTH_SUCCESS)
self.authenticated = True
else:
self.transport._log(INFO, 'Auth rejected (%s).' % method)
m.add_byte(chr(MSG_USERAUTH_FAILURE))
m.add_byte(cMSG_USERAUTH_FAILURE)
m.add_string(self.transport.server_object.get_allowed_auths(username))
if result == AUTH_PARTIALLY_SUCCESSFUL:
m.add_boolean(1)
@ -245,10 +243,10 @@ class AuthHandler (object):
def _interactive_query(self, q):
# make interactive query instead of response
m = Message()
m.add_byte(chr(MSG_USERAUTH_INFO_REQUEST))
m.add_byte(cMSG_USERAUTH_INFO_REQUEST)
m.add_string(q.name)
m.add_string(q.instructions)
m.add_string('')
m.add_string(bytes())
m.add_int(len(q.prompts))
for p in q.prompts:
m.add_string(p[0])
@ -259,7 +257,7 @@ class AuthHandler (object):
if not self.transport.server_mode:
# er, uh... what?
m = Message()
m.add_byte(chr(MSG_USERAUTH_FAILURE))
m.add_byte(cMSG_USERAUTH_FAILURE)
m.add_string('none')
m.add_boolean(0)
self.transport._send_message(m)
@ -267,9 +265,9 @@ class AuthHandler (object):
if self.authenticated:
# ignore
return
username = m.get_string()
service = m.get_string()
method = m.get_string()
username = m.get_text()
service = m.get_text()
method = m.get_text()
self.transport._log(DEBUG, 'Auth request (type=%s) service=%s, username=%s' % (method, service, username))
if service != 'ssh-connection':
self._disconnect_service_not_available()
@ -284,7 +282,7 @@ class AuthHandler (object):
result = self.transport.server_object.check_auth_none(username)
elif method == 'password':
changereq = m.get_boolean()
password = m.get_string()
password = m.get_binary()
try:
password = password.decode('UTF-8')
except UnicodeError:
@ -295,7 +293,7 @@ class AuthHandler (object):
# always treated as failure, since we don't support changing passwords, but collect
# the list of valid auth types from the callback anyway
self.transport._log(DEBUG, 'Auth request to change passwords (rejected)')
newpassword = m.get_string()
newpassword = m.get_binary()
try:
newpassword = newpassword.decode('UTF-8', 'replace')
except UnicodeError:
@ -305,11 +303,11 @@ class AuthHandler (object):
result = self.transport.server_object.check_auth_password(username, password)
elif method == 'publickey':
sig_attached = m.get_boolean()
keytype = m.get_string()
keyblob = m.get_string()
keytype = m.get_text()
keyblob = m.get_binary()
try:
key = self.transport._key_info[keytype](Message(keyblob))
except SSHException, e:
except SSHException as e:
self.transport._log(INFO, 'Auth rejected: public key: %s' % str(e))
key = None
except:
@ -326,12 +324,12 @@ class AuthHandler (object):
# client wants to know if this key is acceptable, before it
# signs anything... send special "ok" message
m = Message()
m.add_byte(chr(MSG_USERAUTH_PK_OK))
m.add_byte(cMSG_USERAUTH_PK_OK)
m.add_string(keytype)
m.add_string(keyblob)
self.transport._send_message(m)
return
sig = Message(m.get_string())
sig = Message(m.get_binary())
blob = self._get_session_blob(key, service, username)
if not key.verify_ssh_sig(blob, sig):
self.transport._log(INFO, 'Auth rejected: invalid signature')
@ -378,23 +376,23 @@ class AuthHandler (object):
banner = m.get_string()
self.banner = banner
lang = m.get_string()
self.transport._log(INFO, 'Auth banner: ' + banner)
self.transport._log(INFO, 'Auth banner: %s' % banner)
# who cares.
def _parse_userauth_info_request(self, m):
if self.auth_method != 'keyboard-interactive':
raise SSHException('Illegal info request from server')
title = m.get_string()
instructions = m.get_string()
m.get_string() # lang
title = m.get_text()
instructions = m.get_text()
m.get_binary() # lang
prompts = m.get_int()
prompt_list = []
for i in range(prompts):
prompt_list.append((m.get_string(), m.get_boolean()))
prompt_list.append((m.get_text(), m.get_boolean()))
response_list = self.interactive_handler(title, instructions, prompt_list)
m = Message()
m.add_byte(chr(MSG_USERAUTH_INFO_RESPONSE))
m.add_byte(cMSG_USERAUTH_INFO_RESPONSE)
m.add_int(len(response_list))
for r in response_list:
m.add_string(r)
@ -406,7 +404,7 @@ class AuthHandler (object):
n = m.get_int()
responses = []
for i in range(n):
responses.append(m.get_string())
responses.append(m.get_text())
result = self.transport.server_object.check_auth_interactive_response(responses)
if isinstance(type(result), InteractiveQuery):
# make interactive query instead of response

View File

@ -17,7 +17,8 @@
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
import util
import paramiko.util as util
from paramiko.common import *
class BERException (Exception):
@ -29,13 +30,16 @@ class BER(object):
Robey's tiny little attempt at a BER decoder.
"""
def __init__(self, content=''):
self.content = content
def __init__(self, content=bytes()):
self.content = b(content)
self.idx = 0
def __str__(self):
def asbytes(self):
return self.content
def __str__(self):
return self.asbytes()
def __repr__(self):
return 'BER(\'' + repr(self.content) + '\')'
@ -45,13 +49,13 @@ class BER(object):
def decode_next(self):
if self.idx >= len(self.content):
return None
ident = ord(self.content[self.idx])
ident = byte_ord(self.content[self.idx])
self.idx += 1
if (ident & 31) == 31:
# identifier > 30
ident = 0
while self.idx < len(self.content):
t = ord(self.content[self.idx])
t = byte_ord(self.content[self.idx])
self.idx += 1
ident = (ident << 7) | (t & 0x7f)
if not (t & 0x80):
@ -59,7 +63,7 @@ class BER(object):
if self.idx >= len(self.content):
return None
# now fetch length
size = ord(self.content[self.idx])
size = byte_ord(self.content[self.idx])
self.idx += 1
if size & 0x80:
# more complimicated...
@ -98,20 +102,20 @@ class BER(object):
def encode_tlv(self, ident, val):
# no need to support ident > 31 here
self.content += chr(ident)
self.content += byte_chr(ident)
if len(val) > 0x7f:
lenstr = util.deflate_long(len(val))
self.content += chr(0x80 + len(lenstr)) + lenstr
self.content += byte_chr(0x80 + len(lenstr)) + lenstr
else:
self.content += chr(len(val))
self.content += byte_chr(len(val))
self.content += val
def encode(self, x):
if type(x) is bool:
if x:
self.encode_tlv(1, '\xff')
self.encode_tlv(1, max_byte)
else:
self.encode_tlv(1, '\x00')
self.encode_tlv(1, zero_byte)
elif (type(x) is int) or (type(x) is long):
self.encode_tlv(2, util.deflate_long(x))
elif type(x) is str:
@ -125,5 +129,5 @@ class BER(object):
b = BER()
for item in data:
b.encode(item)
return str(b)
return b.asbytes()
encode_sequence = staticmethod(encode_sequence)

View File

@ -25,6 +25,7 @@ read operations are blocking and can have a timeout set.
import array
import threading
import time
from paramiko.common import *
class PipeTimeout (IOError):
@ -48,6 +49,20 @@ class BufferedPipe (object):
self._buffer = array.array('B')
self._closed = False
if PY2:
def _buffer_frombytes(self, data):
self._buffer.fromstring(data)
def _buffer_tobytes(self, limit=None):
return self._buffer[:limit].tostring()
else:
def _buffer_frombytes(self, data):
self._buffer.frombytes(data)
def _buffer_tobytes(self, limit=None):
return self._buffer[:limit].tobytes()
def set_event(self, event):
"""
Set an event on this buffer. When data is ready to be read (or the
@ -73,7 +88,7 @@ class BufferedPipe (object):
try:
if self._event is not None:
self._event.set()
self._buffer.fromstring(data)
self._buffer_frombytes(b(data))
self._cv.notifyAll()
finally:
self._lock.release()
@ -117,7 +132,7 @@ class BufferedPipe (object):
if a timeout was specified and no data was ready before that
timeout
"""
out = ''
out = bytes()
self._lock.acquire()
try:
if len(self._buffer) == 0:
@ -138,12 +153,12 @@ class BufferedPipe (object):
# something's in the buffer and we have the lock!
if len(self._buffer) <= nbytes:
out = self._buffer.tostring()
out = self._buffer_tobytes()
del self._buffer[:]
if (self._event is not None) and not self._closed:
self._event.clear()
else:
out = self._buffer[:nbytes].tostring()
out = self._buffer_tobytes(nbytes)
del self._buffer[:nbytes]
finally:
self._lock.release()
@ -160,7 +175,7 @@ class BufferedPipe (object):
"""
self._lock.acquire()
try:
out = self._buffer.tostring()
out = self._buffer_tobytes()
del self._buffer[:]
if (self._event is not None) and not self._closed:
self._event.clear()

View File

@ -140,7 +140,7 @@ class Channel (object):
if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open')
m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST))
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
m.add_string('pty-req')
m.add_boolean(True)
@ -149,7 +149,7 @@ class Channel (object):
m.add_int(height)
m.add_int(width_pixels)
m.add_int(height_pixels)
m.add_string('')
m.add_string(bytes())
self._event_pending()
self.transport._send_user_message(m)
self._wait_for_event()
@ -173,7 +173,7 @@ class Channel (object):
if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open')
m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST))
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
m.add_string('shell')
m.add_boolean(1)
@ -199,7 +199,7 @@ class Channel (object):
if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open')
m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST))
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
m.add_string('exec')
m.add_boolean(True)
@ -225,7 +225,7 @@ class Channel (object):
if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open')
m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST))
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
m.add_string('subsystem')
m.add_boolean(True)
@ -250,7 +250,7 @@ class Channel (object):
if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open')
m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST))
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
m.add_string('window-change')
m.add_boolean(False)
@ -304,7 +304,7 @@ class Channel (object):
# in many cases, the channel will not still be open here.
# that's fine.
m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST))
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
m.add_string('exit-status')
m.add_boolean(False)
@ -359,7 +359,7 @@ class Channel (object):
auth_cookie = binascii.hexlify(self.transport.rng.read(16))
m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST))
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
m.add_string('x11-req')
m.add_boolean(True)
@ -389,7 +389,7 @@ class Channel (object):
raise SSHException('Channel is not open')
m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST))
m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid)
m.add_string('auth-agent-req@openssh.com')
m.add_boolean(False)
@ -451,7 +451,7 @@ class Channel (object):
.. versionadded:: 1.1
"""
data = ''
data = bytes()
self.lock.acquire()
try:
old = self.combine_stderr
@ -581,14 +581,14 @@ class Channel (object):
"""
try:
out = self.in_buffer.read(nbytes, self.timeout)
except PipeTimeout, e:
except PipeTimeout:
raise socket.timeout()
ack = self._check_add_window(len(out))
# no need to hold the channel lock when sending this
if ack > 0:
m = Message()
m.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST))
m.add_byte(cMSG_CHANNEL_WINDOW_ADJUST)
m.add_int(self.remote_chanid)
m.add_int(ack)
self.transport._send_user_message(m)
@ -629,14 +629,14 @@ class Channel (object):
"""
try:
out = self.in_stderr_buffer.read(nbytes, self.timeout)
except PipeTimeout, e:
except PipeTimeout:
raise socket.timeout()
ack = self._check_add_window(len(out))
# no need to hold the channel lock when sending this
if ack > 0:
m = Message()
m.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST))
m.add_byte(cMSG_CHANNEL_WINDOW_ADJUST)
m.add_int(self.remote_chanid)
m.add_int(ack)
self.transport._send_user_message(m)
@ -686,7 +686,7 @@ class Channel (object):
# eof or similar
return 0
m = Message()
m.add_byte(chr(MSG_CHANNEL_DATA))
m.add_byte(cMSG_CHANNEL_DATA)
m.add_int(self.remote_chanid)
m.add_string(s[:size])
finally:
@ -721,7 +721,7 @@ class Channel (object):
# eof or similar
return 0
m = Message()
m.add_byte(chr(MSG_CHANNEL_EXTENDED_DATA))
m.add_byte(cMSG_CHANNEL_EXTENDED_DATA)
m.add_int(self.remote_chanid)
m.add_int(1)
m.add_string(s[:size])
@ -925,16 +925,16 @@ class Channel (object):
self.transport._send_user_message(m)
def _feed(self, m):
if type(m) is str:
if isinstance(m, bytes_types):
# passed from _feed_extended
s = m
else:
s = m.get_string()
s = m.get_binary()
self.in_buffer.feed(s)
def _feed_extended(self, m):
code = m.get_int()
s = m.get_string()
s = m.get_binary()
if code != 1:
self._log(ERROR, 'unknown extended_data type %d; discarding' % code)
return
@ -955,7 +955,7 @@ class Channel (object):
self.lock.release()
def _handle_request(self, m):
key = m.get_string()
key = m.get_text()
want_reply = m.get_boolean()
server = self.transport.server_object
ok = False
@ -991,13 +991,13 @@ class Channel (object):
else:
ok = server.check_channel_env_request(self, name, value)
elif key == 'exec':
cmd = m.get_string()
cmd = m.get_text()
if server is None:
ok = False
else:
ok = server.check_channel_exec_request(self, cmd)
elif key == 'subsystem':
name = m.get_string()
name = m.get_text()
if server is None:
ok = False
else:
@ -1014,8 +1014,8 @@ class Channel (object):
pixelheight)
elif key == 'x11-req':
single_connection = m.get_boolean()
auth_proto = m.get_string()
auth_cookie = m.get_string()
auth_proto = m.get_text()
auth_cookie = m.get_binary()
screen_number = m.get_int()
if server is None:
ok = False
@ -1033,9 +1033,9 @@ class Channel (object):
if want_reply:
m = Message()
if ok:
m.add_byte(chr(MSG_CHANNEL_SUCCESS))
m.add_byte(cMSG_CHANNEL_SUCCESS)
else:
m.add_byte(chr(MSG_CHANNEL_FAILURE))
m.add_byte(cMSG_CHANNEL_FAILURE)
m.add_int(self.remote_chanid)
self.transport._send_user_message(m)
@ -1101,7 +1101,7 @@ class Channel (object):
if self.eof_sent:
return None
m = Message()
m.add_byte(chr(MSG_CHANNEL_EOF))
m.add_byte(cMSG_CHANNEL_EOF)
m.add_int(self.remote_chanid)
self.eof_sent = True
self._log(DEBUG, 'EOF sent (%s)', self._name)
@ -1113,7 +1113,7 @@ class Channel (object):
return None, None
m1 = self._send_eof()
m2 = Message()
m2.add_byte(chr(MSG_CHANNEL_CLOSE))
m2.add_byte(cMSG_CHANNEL_CLOSE)
m2.add_int(self.remote_chanid)
self._set_closed()
# can't unlink from the Transport yet -- the remote side may still

View File

@ -132,11 +132,10 @@ class SSHClient (object):
if self._host_keys_filename is not None:
self.load_host_keys(self._host_keys_filename)
f = open(filename, 'w')
for hostname, keys in self._host_keys.iteritems():
for keytype, key in keys.iteritems():
with open(filename, 'w') as f:
for hostname, keys in self._host_keys.items():
for keytype, key in keys.items():
f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
f.close()
def get_host_keys(self):
"""
@ -266,7 +265,7 @@ class SSHClient (object):
if key_filename is None:
key_filenames = []
elif isinstance(key_filename, (str, unicode)):
elif isinstance(key_filename, string_types):
key_filenames = [ key_filename ]
else:
key_filenames = key_filename
@ -310,8 +309,8 @@ class SSHClient (object):
chan.settimeout(timeout)
chan.exec_command(command)
stdin = chan.makefile('wb', bufsize)
stdout = chan.makefile('rb', bufsize)
stderr = chan.makefile_stderr('rb', bufsize)
stdout = chan.makefile('r', bufsize)
stderr = chan.makefile_stderr('r', bufsize)
return stdin, stdout, stderr
def invoke_shell(self, term='vt100', width=80, height=24, width_pixels=0,
@ -377,7 +376,7 @@ class SSHClient (object):
two_factor = (allowed_types == ['password'])
if not two_factor:
return
except SSHException, e:
except SSHException as e:
saved_exception = e
if not two_factor:
@ -391,7 +390,7 @@ class SSHClient (object):
if not two_factor:
return
break
except SSHException, e:
except SSHException as e:
saved_exception = e
if not two_factor and allow_agent:
@ -407,7 +406,7 @@ class SSHClient (object):
if not two_factor:
return
break
except SSHException, e:
except SSHException as e:
saved_exception = e
if not two_factor:
@ -439,17 +438,15 @@ class SSHClient (object):
if not two_factor:
return
break
except SSHException, e:
saved_exception = e
except IOError, e:
except (SSHException, IOError) as e:
saved_exception = e
if password is not None:
try:
self._transport.auth_password(username, password)
return
except SSHException, e:
saved_exception = e
except SSHException:
saved_exception = sys.exc_info()[1]
elif two_factor:
raise SSHException('Two-factor authentication requires a password')

View File

@ -19,12 +19,13 @@
"""
Common constants and global variables.
"""
from paramiko.py3compat import *
MSG_DISCONNECT, MSG_IGNORE, MSG_UNIMPLEMENTED, MSG_DEBUG, MSG_SERVICE_REQUEST, \
MSG_SERVICE_ACCEPT = range(1, 7)
MSG_KEXINIT, MSG_NEWKEYS = range(20, 22)
MSG_USERAUTH_REQUEST, MSG_USERAUTH_FAILURE, MSG_USERAUTH_SUCCESS, \
MSG_USERAUTH_BANNER = range(50, 54)
MSG_USERAUTH_BANNER = range(50, 54)
MSG_USERAUTH_PK_OK = 60
MSG_USERAUTH_INFO_REQUEST, MSG_USERAUTH_INFO_RESPONSE = range(60, 62)
MSG_GLOBAL_REQUEST, MSG_REQUEST_SUCCESS, MSG_REQUEST_FAILURE = range(80, 83)
@ -33,6 +34,10 @@ MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_SUCCESS, MSG_CHANNEL_OPEN_FAILURE, \
MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST, \
MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101)
for key in list(locals().keys()):
if key.startswith('MSG_'):
locals()['c' + key] = byte_chr(locals()[key])
del key
# for debugging:
MSG_NAMES = {
@ -69,7 +74,7 @@ MSG_NAMES = {
MSG_CHANNEL_REQUEST: 'channel-request',
MSG_CHANNEL_SUCCESS: 'channel-success',
MSG_CHANNEL_FAILURE: 'channel-failure'
}
}
# authentication request return codes:
@ -118,6 +123,42 @@ else:
import logging
PY22 = False
zero_byte = byte_chr(0)
one_byte = byte_chr(1)
four_byte = byte_chr(4)
max_byte = byte_chr(0xff)
cr_byte = byte_chr(13)
linefeed_byte = byte_chr(10)
crlf = cr_byte + linefeed_byte
if PY2:
cr_byte_value = cr_byte
linefeed_byte_value = linefeed_byte
else:
cr_byte_value = 13
linefeed_byte_value = 10
def asbytes(s):
if not isinstance(s, bytes_types):
if isinstance(s, string_types):
s = b(s)
else:
try:
s = s.asbytes()
except Exception:
raise Exception('Unknown type')
return s
xffffffff = long(0xffffffff)
x80000000 = long(0x80000000)
o666 = 438
o660 = 432
o644 = 420
o600 = 384
o777 = 511
o700 = 448
o70 = 56
DEBUG = logging.DEBUG
INFO = logging.INFO

View File

@ -116,7 +116,7 @@ class SSHConfig (object):
ret = {}
for match in matches:
for key, value in match['config'].iteritems():
for key, value in match['config'].items():
if key not in ret:
# Create a copy of the original value,
# else it will reference the original list

View File

@ -56,7 +56,7 @@ class DSSKey (PKey):
else:
if msg is None:
raise SSHException('Key object may not be empty')
if msg.get_string() != 'ssh-dss':
if msg.get_text() != 'ssh-dss':
raise SSHException('Invalid key')
self.p = msg.get_mpint()
self.q = msg.get_mpint()
@ -64,14 +64,17 @@ class DSSKey (PKey):
self.y = msg.get_mpint()
self.size = util.bit_length(self.p)
def __str__(self):
def asbytes(self):
m = Message()
m.add_string('ssh-dss')
m.add_mpint(self.p)
m.add_mpint(self.q)
m.add_mpint(self.g)
m.add_mpint(self.y)
return str(m)
return m.asbytes()
def __str__(self):
return self.asbytes()
def __hash__(self):
h = hash(self.get_name())
@ -107,21 +110,21 @@ class DSSKey (PKey):
rstr = util.deflate_long(r, 0)
sstr = util.deflate_long(s, 0)
if len(rstr) < 20:
rstr = '\x00' * (20 - len(rstr)) + rstr
rstr = zero_byte * (20 - len(rstr)) + rstr
if len(sstr) < 20:
sstr = '\x00' * (20 - len(sstr)) + sstr
sstr = zero_byte * (20 - len(sstr)) + sstr
m.add_string(rstr + sstr)
return m
def verify_ssh_sig(self, data, msg):
if len(str(msg)) == 40:
if len(msg.asbytes()) == 40:
# spies.com bug: signature has no header
sig = str(msg)
sig = msg.asbytes()
else:
kind = msg.get_string()
kind = msg.get_text()
if kind != 'ssh-dss':
return 0
sig = msg.get_string()
sig = msg.get_binary()
# pull out (r, s) which are NOT encoded as mpints
sigR = util.inflate_long(sig[:20], 1)
@ -140,7 +143,7 @@ class DSSKey (PKey):
b.encode(keylist)
except BERException:
raise SSHException('Unable to create ber encoding of key')
return str(b)
return b.asbytes()
def write_private_key_file(self, filename, password=None):
self._write_private_key_file('DSA', filename, self._encode_key(), password)
@ -182,8 +185,8 @@ class DSSKey (PKey):
# DSAPrivateKey = { version = 0, p, q, g, y, x }
try:
keylist = BER(data).decode()
except BERException, x:
raise SSHException('Unable to parse key file: ' + str(x))
except BERException as e:
raise SSHException('Unable to parse key file: ' + str(e))
if (type(keylist) is not list) or (len(keylist) < 6) or (keylist[0] != 0):
raise SSHException('not a valid DSA private key file (bad ber encoding)')
self.p = keylist[1]

View File

@ -56,30 +56,33 @@ class ECDSAKey (PKey):
else:
if msg is None:
raise SSHException('Key object may not be empty')
if msg.get_string() != 'ecdsa-sha2-nistp256':
if msg.get_text() != 'ecdsa-sha2-nistp256':
raise SSHException('Invalid key')
curvename = msg.get_string()
curvename = msg.get_text()
if curvename != 'nistp256':
raise SSHException("Can't handle curve of type %s" % curvename)
pointinfo = msg.get_string()
if pointinfo[0] != "\x04":
raise SSHException('Point compression is being used: %s'%
pointinfo = msg.get_binary()
if pointinfo[0:1] != four_byte:
raise SSHException('Point compression is being used: %s' %
binascii.hexlify(pointinfo))
self.verifying_key = VerifyingKey.from_string(pointinfo[1:],
curve=curves.NIST256p)
curve=curves.NIST256p)
self.size = 256
def __str__(self):
def asbytes(self):
key = self.verifying_key
m = Message()
m.add_string('ecdsa-sha2-nistp256')
m.add_string('nistp256')
point_str = "\x04" + key.to_string()
point_str = four_byte + key.to_string()
m.add_string(point_str)
return str(m)
return m.asbytes()
def __str__(self):
return self.asbytes()
def __hash__(self):
h = hash(self.get_name())
@ -106,9 +109,9 @@ class ECDSAKey (PKey):
return m
def verify_ssh_sig(self, data, msg):
if msg.get_string() != 'ecdsa-sha2-nistp256':
if msg.get_text() != 'ecdsa-sha2-nistp256':
return False
sig = msg.get_string()
sig = msg.get_binary()
# verify the signature by SHA'ing the data and encrypting it
# using the public key.
@ -154,14 +157,13 @@ class ECDSAKey (PKey):
data = self._read_private_key('EC', file_obj, password)
self._decode_key(data)
ALLOWED_PADDINGS = ['\x01', '\x02\x02', '\x03\x03\x03', '\x04\x04\x04\x04',
'\x05\x05\x05\x05\x05', '\x06\x06\x06\x06\x06\x06',
'\x07\x07\x07\x07\x07\x07\x07']
ALLOWED_PADDINGS = [one_byte, byte_chr(2) * 2, byte_chr(3) * 3, byte_chr(4) * 4,
byte_chr(5) * 5, byte_chr(6) * 6, byte_chr(7) * 7]
def _decode_key(self, data):
s, padding = der.remove_sequence(data)
if padding:
if padding not in self.ALLOWED_PADDINGS:
raise ValueError, "weird padding: %s" % (binascii.hexlify(empty))
raise ValueError("weird padding: %s" % u(binascii.hexlify(data)))
data = data[:-len(padding)]
key = SigningKey.from_der(data)
self.signing_key = key
@ -172,7 +174,7 @@ class ECDSAKey (PKey):
msg = Message()
msg.add_mpint(r)
msg.add_mpint(s)
return str(msg)
return msg.asbytes()
def _sigdecode(self, sig, order):
msg = Message(sig)

View File

@ -16,7 +16,7 @@
# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
from cStringIO import StringIO
from paramiko.common import *
class BufferedFile (object):
@ -43,8 +43,8 @@ class BufferedFile (object):
self.newlines = None
self._flags = 0
self._bufsize = self._DEFAULT_BUFSIZE
self._wbuffer = StringIO()
self._rbuffer = ''
self._wbuffer = BytesIO()
self._rbuffer = bytes()
self._at_trailing_cr = False
self._closed = False
# pos - position within the file, according to the user
@ -82,9 +82,10 @@ class BufferedFile (object):
buffering is not turned on.
"""
self._write_all(self._wbuffer.getvalue())
self._wbuffer = StringIO()
self._wbuffer = BytesIO()
return
if PY2:
def next(self):
"""
Returns the next line from the input, or raises
@ -99,6 +100,22 @@ class BufferedFile (object):
if not line:
raise StopIteration
return line
else:
def __next__(self):
"""
Returns the next line from the input, or raises L{StopIteration} when
EOF is hit. Unlike python file objects, it's okay to mix calls to
C{next} and L{readline}.
@raise StopIteration: when the end of the file is reached.
@return: a line read from the file.
@rtype: str
"""
line = self.readline()
if not line:
raise StopIteration
return line
def read(self, size=None):
"""
@ -118,7 +135,7 @@ class BufferedFile (object):
if (size is None) or (size < 0):
# go for broke
result = self._rbuffer
self._rbuffer = ''
self._rbuffer = bytes()
self._pos += len(result)
while True:
try:
@ -130,12 +147,12 @@ class BufferedFile (object):
result += new_data
self._realpos += len(new_data)
self._pos += len(new_data)
return result
return result if self._flags & self.FLAG_BINARY else u(result)
if size <= len(self._rbuffer):
result = self._rbuffer[:size]
self._rbuffer = self._rbuffer[size:]
self._pos += len(result)
return result
return result if self._flags & self.FLAG_BINARY else u(result)
while len(self._rbuffer) < size:
read_size = size - len(self._rbuffer)
if self._flags & self.FLAG_BUFFERED:
@ -151,7 +168,7 @@ class BufferedFile (object):
result = self._rbuffer[:size]
self._rbuffer = self._rbuffer[size:]
self._pos += len(result)
return result
return result if self._flags & self.FLAG_BINARY else u(result)
def readline(self, size=None):
"""
@ -181,11 +198,11 @@ class BufferedFile (object):
if self._at_trailing_cr and (self._flags & self.FLAG_UNIVERSAL_NEWLINE) and (len(line) > 0):
# edge case: the newline may be '\r\n' and we may have read
# only the first '\r' last time.
if line[0] == '\n':
if line[0] == linefeed_byte_value:
line = line[1:]
self._record_newline('\r\n')
self._record_newline(crlf)
else:
self._record_newline('\r')
self._record_newline(cr_byte)
self._at_trailing_cr = False
# check size before looking for a linefeed, in case we already have
# enough.
@ -195,42 +212,42 @@ class BufferedFile (object):
self._rbuffer = line[size:]
line = line[:size]
self._pos += len(line)
return line
return line if self._flags & self.FLAG_BINARY else u(line)
n = size - len(line)
else:
n = self._bufsize
if ('\n' in line) or ((self._flags & self.FLAG_UNIVERSAL_NEWLINE) and ('\r' in line)):
if (linefeed_byte in line) or ((self._flags & self.FLAG_UNIVERSAL_NEWLINE) and (cr_byte in line)):
break
try:
new_data = self._read(n)
except EOFError:
new_data = None
if (new_data is None) or (len(new_data) == 0):
self._rbuffer = ''
self._rbuffer = bytes()
self._pos += len(line)
return line
return line if self._flags & self.FLAG_BINARY else u(line)
line += new_data
self._realpos += len(new_data)
# find the newline
pos = line.find('\n')
pos = line.find(linefeed_byte)
if self._flags & self.FLAG_UNIVERSAL_NEWLINE:
rpos = line.find('\r')
rpos = line.find(cr_byte)
if (rpos >= 0) and ((rpos < pos) or (pos < 0)):
pos = rpos
xpos = pos + 1
if (line[pos] == '\r') and (xpos < len(line)) and (line[xpos] == '\n'):
if (line[pos] == cr_byte_value) and (xpos < len(line)) and (line[xpos] == linefeed_byte_value):
xpos += 1
self._rbuffer = line[xpos:]
lf = line[pos:xpos]
line = line[:pos] + '\n'
if (len(self._rbuffer) == 0) and (lf == '\r'):
line = line[:pos] + linefeed_byte
if (len(self._rbuffer) == 0) and (lf == cr_byte):
# we could read the line up to a '\r' and there could still be a
# '\n' following that we read next time. note that and eat it.
self._at_trailing_cr = True
else:
self._record_newline(lf)
self._pos += len(line)
return line
return line if self._flags & self.FLAG_BINARY else u(line)
def readlines(self, sizehint=None):
"""
@ -243,14 +260,14 @@ class BufferedFile (object):
:return: `list` of lines read from the file.
"""
lines = []
bytes = 0
byte_count = 0
while True:
line = self.readline()
if len(line) == 0:
break
lines.append(line)
bytes += len(line)
if (sizehint is not None) and (bytes >= sizehint):
byte_count += len(line)
if (sizehint is not None) and (byte_count >= sizehint):
break
return lines
@ -292,6 +309,7 @@ class BufferedFile (object):
:param str data: data to write
"""
data = b(data)
if self._closed:
raise IOError('File is closed')
if not (self._flags & self.FLAG_WRITE):
@ -302,12 +320,12 @@ class BufferedFile (object):
self._wbuffer.write(data)
if self._flags & self.FLAG_LINE_BUFFERED:
# only scan the new data for linefeed, to avoid wasting time.
last_newline_pos = data.rfind('\n')
last_newline_pos = data.rfind(linefeed_byte)
if last_newline_pos >= 0:
wbuf = self._wbuffer.getvalue()
last_newline_pos += len(wbuf) - len(data)
self._write_all(wbuf[:last_newline_pos + 1])
self._wbuffer = StringIO()
self._wbuffer = BytesIO()
self._wbuffer.write(wbuf[last_newline_pos + 1:])
return
# even if we're line buffering, if the buffer has grown past the
@ -436,7 +454,7 @@ class BufferedFile (object):
return
if self.newlines is None:
self.newlines = newline
elif (type(self.newlines) is str) and (self.newlines != newline):
elif self.newlines != newline and isinstance(self.newlines, bytes_types):
self.newlines = (self.newlines, newline)
elif newline not in self.newlines:
self.newlines += (newline,)

View File

@ -20,7 +20,10 @@
import base64
import binascii
from Crypto.Hash import SHA, HMAC
import UserDict
try:
from collections import MutableMapping
except ImportError:
from UserDict import DictMixin as MutableMapping
from paramiko.common import *
from paramiko.dsskey import DSSKey
@ -29,7 +32,7 @@ from paramiko.util import get_logger, constant_time_bytes_eq
from paramiko.ecdsakey import ECDSAKey
class HostKeys (UserDict.DictMixin):
class HostKeys (MutableMapping):
"""
Representation of an OpenSSH-style "known hosts" file. Host keys can be
read from one or more files, and then individual hosts can be looked up to
@ -83,20 +86,19 @@ class HostKeys (UserDict.DictMixin):
:raises IOError: if there was an error reading the file
"""
f = open(filename, 'r')
for lineno, line in enumerate(f):
line = line.strip()
if (len(line) == 0) or (line[0] == '#'):
continue
e = HostKeyEntry.from_line(line, lineno)
if e is not None:
_hostnames = e.hostnames
for h in _hostnames:
if self.check(h, e.key):
e.hostnames.remove(h)
if len(e.hostnames):
self._entries.append(e)
f.close()
with open(filename, 'r') as f:
for lineno, line in enumerate(f):
line = line.strip()
if (len(line) == 0) or (line[0] == '#'):
continue
e = HostKeyEntry.from_line(line, lineno)
if e is not None:
_hostnames = e.hostnames
for h in _hostnames:
if self.check(h, e.key):
e.hostnames.remove(h)
if len(e.hostnames):
self._entries.append(e)
def save(self, filename):
"""
@ -111,12 +113,11 @@ class HostKeys (UserDict.DictMixin):
.. versionadded:: 1.6.1
"""
f = open(filename, 'w')
for e in self._entries:
line = e.to_line()
if line:
f.write(line)
f.close()
with open(filename, 'w') as f:
for e in self._entries:
line = e.to_line()
if line:
f.write(line)
def lookup(self, hostname):
"""
@ -127,12 +128,26 @@ class HostKeys (UserDict.DictMixin):
:param str hostname: the hostname (or IP) to lookup
:return: dict of `str` -> `.PKey` keys associated with this host (or ``None``)
"""
class SubDict (UserDict.DictMixin):
class SubDict (MutableMapping):
def __init__(self, hostname, entries, hostkeys):
self._hostname = hostname
self._entries = entries
self._hostkeys = hostkeys
def __iter__(self):
for k in self.keys():
yield k
def __len__(self):
return len(self.keys())
def __delitem__(self, key):
for e in list(self._entries):
if e.key.get_name() == key:
self._entries.remove(e)
else:
raise KeyError(key)
def __getitem__(self, key):
for e in self._entries:
if e.key.get_name() == key:
@ -181,7 +196,7 @@ class HostKeys (UserDict.DictMixin):
host_key = k.get(key.get_name(), None)
if host_key is None:
return False
return str(host_key) == str(key)
return host_key.asbytes() == key.asbytes()
def clear(self):
"""
@ -189,6 +204,17 @@ class HostKeys (UserDict.DictMixin):
"""
self._entries = []
def __iter__(self):
for k in self.keys():
yield k
def __len__(self):
return len(self.keys())
def __delitem__(self, key):
k = self[key]
pass
def __getitem__(self, key):
ret = self.lookup(key)
if ret is None:
@ -239,10 +265,10 @@ class HostKeys (UserDict.DictMixin):
else:
if salt.startswith('|1|'):
salt = salt.split('|')[2]
salt = base64.decodestring(salt)
salt = decodebytes(b(salt))
assert len(salt) == SHA.digest_size
hmac = HMAC.HMAC(salt, hostname, SHA).digest()
hostkey = '|1|%s|%s' % (base64.encodestring(salt), base64.encodestring(hmac))
hmac = HMAC.HMAC(salt, b(hostname), SHA).digest()
hostkey = '|1|%s|%s' % (u(encodebytes(salt)), u(encodebytes(hmac)))
return hostkey.replace('\n', '')
hash_host = staticmethod(hash_host)
@ -292,17 +318,17 @@ class HostKeyEntry:
# to hold it accordingly.
try:
if keytype == 'ssh-rsa':
key = RSAKey(data=base64.decodestring(key))
key = RSAKey(data=decodebytes(key))
elif keytype == 'ssh-dss':
key = DSSKey(data=base64.decodestring(key))
key = DSSKey(data=decodebytes(key))
elif keytype == 'ecdsa-sha2-nistp256':
key = ECDSAKey(data=base64.decodestring(key))
key = ECDSAKey(data=decodebytes(key))
else:
log.info("Unable to handle key of type %s" % (keytype,))
return None
except binascii.Error, e:
raise InvalidHostKey(line, e)
except binascii.Error as e:
raise InvalidHostKey(line, sys.exc_info()[1])
return cls(names, key)
from_line = classmethod(from_line)

View File

@ -33,6 +33,8 @@ from paramiko.ssh_exception import SSHException
_MSG_KEXDH_GEX_REQUEST_OLD, _MSG_KEXDH_GEX_GROUP, _MSG_KEXDH_GEX_INIT, \
_MSG_KEXDH_GEX_REPLY, _MSG_KEXDH_GEX_REQUEST = range(30, 35)
c_MSG_KEXDH_GEX_REQUEST_OLD, c_MSG_KEXDH_GEX_GROUP, c_MSG_KEXDH_GEX_INIT, \
c_MSG_KEXDH_GEX_REPLY, c_MSG_KEXDH_GEX_REQUEST = [byte_chr(c) for c in range(30, 35)]
class KexGex (object):
@ -62,11 +64,11 @@ class KexGex (object):
m = Message()
if _test_old_style:
# only used for unit tests: we shouldn't ever send this
m.add_byte(chr(_MSG_KEXDH_GEX_REQUEST_OLD))
m.add_byte(c_MSG_KEXDH_GEX_REQUEST_OLD)
m.add_int(self.preferred_bits)
self.old_style = True
else:
m.add_byte(chr(_MSG_KEXDH_GEX_REQUEST))
m.add_byte(c_MSG_KEXDH_GEX_REQUEST)
m.add_int(self.min_bits)
m.add_int(self.preferred_bits)
m.add_int(self.max_bits)
@ -94,15 +96,15 @@ class KexGex (object):
# generate an "x" (1 < x < (p-1)/2).
q = (self.p - 1) // 2
qnorm = util.deflate_long(q, 0)
qhbyte = ord(qnorm[0])
bytes = len(qnorm)
qhbyte = byte_ord(qnorm[0])
byte_count = len(qnorm)
qmask = 0xff
while not (qhbyte & 0x80):
qhbyte <<= 1
qmask >>= 1
while True:
x_bytes = self.transport.rng.read(bytes)
x_bytes = chr(ord(x_bytes[0]) & qmask) + x_bytes[1:]
x_bytes = self.transport.rng.read(byte_count)
x_bytes = byte_mask(x_bytes[0], qmask) + x_bytes[1:]
x = util.inflate_long(x_bytes, 1)
if (x > 1) and (x < q):
break
@ -135,7 +137,7 @@ class KexGex (object):
self.transport._log(DEBUG, 'Picking p (%d <= %d <= %d bits)' % (minbits, preferredbits, maxbits))
self.g, self.p = pack.get_modulus(minbits, preferredbits, maxbits)
m = Message()
m.add_byte(chr(_MSG_KEXDH_GEX_GROUP))
m.add_byte(c_MSG_KEXDH_GEX_GROUP)
m.add_mpint(self.p)
m.add_mpint(self.g)
self.transport._send_message(m)
@ -156,7 +158,7 @@ class KexGex (object):
self.transport._log(DEBUG, 'Picking p (~ %d bits)' % (self.preferred_bits,))
self.g, self.p = pack.get_modulus(self.min_bits, self.preferred_bits, self.max_bits)
m = Message()
m.add_byte(chr(_MSG_KEXDH_GEX_GROUP))
m.add_byte(c_MSG_KEXDH_GEX_GROUP)
m.add_mpint(self.p)
m.add_mpint(self.g)
self.transport._send_message(m)
@ -175,7 +177,7 @@ class KexGex (object):
# now compute e = g^x mod p
self.e = pow(self.g, self.x, self.p)
m = Message()
m.add_byte(chr(_MSG_KEXDH_GEX_INIT))
m.add_byte(c_MSG_KEXDH_GEX_INIT)
m.add_mpint(self.e)
self.transport._send_message(m)
self.transport._expect_packet(_MSG_KEXDH_GEX_REPLY)
@ -187,7 +189,7 @@ class KexGex (object):
self._generate_x()
self.f = pow(self.g, self.x, self.p)
K = pow(self.e, self.x, self.p)
key = str(self.transport.get_server_key())
key = self.transport.get_server_key().asbytes()
# okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K)
hm = Message()
hm.add(self.transport.remote_version, self.transport.local_version,
@ -203,16 +205,16 @@ class KexGex (object):
hm.add_mpint(self.e)
hm.add_mpint(self.f)
hm.add_mpint(K)
H = SHA.new(str(hm)).digest()
H = SHA.new(hm.asbytes()).digest()
self.transport._set_K_H(K, H)
# sign it
sig = self.transport.get_server_key().sign_ssh_data(self.transport.rng, H)
# send reply
m = Message()
m.add_byte(chr(_MSG_KEXDH_GEX_REPLY))
m.add_byte(c_MSG_KEXDH_GEX_REPLY)
m.add_string(key)
m.add_mpint(self.f)
m.add_string(str(sig))
m.add_string(sig)
self.transport._send_message(m)
self.transport._activate_outbound()
@ -238,6 +240,6 @@ class KexGex (object):
hm.add_mpint(self.e)
hm.add_mpint(self.f)
hm.add_mpint(K)
self.transport._set_K_H(K, SHA.new(str(hm)).digest())
self.transport._set_K_H(K, SHA.new(hm.asbytes()).digest())
self.transport._verify_key(host_key, sig)
self.transport._activate_outbound()

View File

@ -30,11 +30,14 @@ from paramiko.ssh_exception import SSHException
_MSG_KEXDH_INIT, _MSG_KEXDH_REPLY = range(30, 32)
c_MSG_KEXDH_INIT, c_MSG_KEXDH_REPLY = [byte_chr(c) for c in range(30, 32)]
# draft-ietf-secsh-transport-09.txt, page 17
P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFFL
P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF
G = 2
b7fffffffffffffff = byte_chr(0x7f) + max_byte * 7
b0000000000000000 = zero_byte * 8
class KexGroup1(object):
@ -42,9 +45,9 @@ class KexGroup1(object):
def __init__(self, transport):
self.transport = transport
self.x = 0L
self.e = 0L
self.f = 0L
self.x = long(0)
self.e = long(0)
self.f = long(0)
def start_kex(self):
self._generate_x()
@ -56,7 +59,7 @@ class KexGroup1(object):
# compute e = g^x mod p (where g=2), and send it
self.e = pow(G, self.x, P)
m = Message()
m.add_byte(chr(_MSG_KEXDH_INIT))
m.add_byte(c_MSG_KEXDH_INIT)
m.add_mpint(self.e)
self.transport._send_message(m)
self.transport._expect_packet(_MSG_KEXDH_REPLY)
@ -80,9 +83,9 @@ class KexGroup1(object):
# larger than q (but this is a tiny tiny subset of potential x).
while 1:
x_bytes = self.transport.rng.read(128)
x_bytes = chr(ord(x_bytes[0]) & 0x7f) + x_bytes[1:]
if (x_bytes[:8] != '\x7F\xFF\xFF\xFF\xFF\xFF\xFF\xFF') and \
(x_bytes[:8] != '\x00\x00\x00\x00\x00\x00\x00\x00'):
x_bytes = byte_mask(x_bytes[0], 0x7f) + x_bytes[1:]
if (x_bytes[:8] != b7fffffffffffffff) and \
(x_bytes[:8] != b0000000000000000):
break
self.x = util.inflate_long(x_bytes)
@ -92,7 +95,7 @@ class KexGroup1(object):
self.f = m.get_mpint()
if (self.f < 1) or (self.f > P - 1):
raise SSHException('Server kex "f" is out of range')
sig = m.get_string()
sig = m.get_binary()
K = pow(self.f, self.x, P)
# okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || e || f || K)
hm = Message()
@ -102,7 +105,7 @@ class KexGroup1(object):
hm.add_mpint(self.e)
hm.add_mpint(self.f)
hm.add_mpint(K)
self.transport._set_K_H(K, SHA.new(str(hm)).digest())
self.transport._set_K_H(K, SHA.new(hm.asbytes()).digest())
self.transport._verify_key(host_key, sig)
self.transport._activate_outbound()
@ -112,7 +115,7 @@ class KexGroup1(object):
if (self.e < 1) or (self.e > P - 1):
raise SSHException('Client kex "e" is out of range')
K = pow(self.e, self.x, P)
key = str(self.transport.get_server_key())
key = self.transport.get_server_key().asbytes()
# okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || e || f || K)
hm = Message()
hm.add(self.transport.remote_version, self.transport.local_version,
@ -121,15 +124,15 @@ class KexGroup1(object):
hm.add_mpint(self.e)
hm.add_mpint(self.f)
hm.add_mpint(K)
H = SHA.new(str(hm)).digest()
H = SHA.new(hm.asbytes()).digest()
self.transport._set_K_H(K, H)
# sign it
sig = self.transport.get_server_key().sign_ssh_data(self.transport.rng, H)
# send reply
m = Message()
m.add_byte(chr(_MSG_KEXDH_REPLY))
m.add_byte(c_MSG_KEXDH_REPLY)
m.add_string(key)
m.add_mpint(self.f)
m.add_string(str(sig))
m.add_string(sig)
self.transport._send_message(m)
self.transport._activate_outbound()

View File

@ -21,9 +21,9 @@ Implementation of an SSH2 "message".
"""
import struct
import cStringIO
from paramiko import util
from paramiko.common import *
class Message (object):
@ -37,6 +37,8 @@ class Message (object):
paramiko doesn't support yet.
"""
big_int = long(0xff000000)
def __init__(self, content=None):
"""
Create a new SSH2 message.
@ -46,15 +48,15 @@ class Message (object):
decomposing a message).
"""
if content != None:
self.packet = cStringIO.StringIO(content)
self.packet = BytesIO(content)
else:
self.packet = cStringIO.StringIO()
self.packet = BytesIO()
def __str__(self):
"""
Return the byte stream content of this message, as a string.
Return the byte stream content of this message, as a string/bytes obj.
"""
return self.packet.getvalue()
return self.asbytes()
def __repr__(self):
"""
@ -62,6 +64,15 @@ class Message (object):
"""
return 'paramiko.Message(' + repr(self.packet.getvalue()) + ')'
def asbytes(self):
"""
Return the byte stream content of this Message, as bytes.
@return: the contents of this Message.
@rtype: bytes
"""
return self.packet.getvalue()
def rewind(self):
"""
Rewind the message to the beginning as if no items had been parsed
@ -99,7 +110,7 @@ class Message (object):
b = self.packet.read(n)
max_pad_size = 1<<20 # Limit padding to 1 MB
if len(b) < n and n < max_pad_size:
return b + '\x00' * (n - len(b))
return b + zero_byte * (n - len(b))
return b
def get_byte(self):
@ -118,7 +129,7 @@ class Message (object):
Fetch a boolean from the stream.
"""
b = self.get_bytes(1)
return b != '\x00'
return b != zero_byte
def get_int(self):
"""
@ -126,6 +137,19 @@ class Message (object):
:return: a 32-bit unsigned `int`.
"""
byte = self.get_bytes(1)
if byte == max_byte:
return util.inflate_long(self.get_binary())
byte += self.get_bytes(3)
return struct.unpack('>I', byte)[0]
def get_size(self):
"""
Fetch an int from the stream.
@return: a 32-bit unsigned integer.
@rtype: int
"""
return struct.unpack('>I', self.get_bytes(4))[0]
def get_int64(self):
@ -142,7 +166,7 @@ class Message (object):
:return: an arbitrary-length integer (`long`).
"""
return util.inflate_long(self.get_string())
return util.inflate_long(self.get_binary())
def get_string(self):
"""
@ -150,7 +174,30 @@ class Message (object):
contain unprintable characters. (It's not unheard of for a string to
contain another byte-stream message.)
"""
return self.get_bytes(self.get_int())
return self.get_bytes(self.get_size())
def get_text(self):
"""
Fetch a string from the stream. This could be a byte string and may
contain unprintable characters. (It's not unheard of for a string to
contain another byte-stream Message.)
@return: a string.
@rtype: string
"""
return u(self.get_bytes(self.get_size()))
#return self.get_bytes(self.get_size())
def get_binary(self):
"""
Fetch a string from the stream. This could be a byte string and may
contain unprintable characters. (It's not unheard of for a string to
contain another byte-stream Message.)
@return: a string.
@rtype: string
"""
return self.get_bytes(self.get_size())
def get_list(self):
"""
@ -158,7 +205,7 @@ class Message (object):
These are trivially encoded as comma-separated values in a string.
"""
return self.get_string().split(',')
return self.get_text().split(',')
def add_bytes(self, b):
"""
@ -185,9 +232,19 @@ class Message (object):
:param bool b: boolean value to add
"""
if b:
self.add_byte('\x01')
self.packet.write(one_byte)
else:
self.add_byte('\x00')
self.packet.write(zero_byte)
return self
def add_size(self, n):
"""
Add an integer to the stream.
@param n: integer to add
@type n: int
"""
self.packet.write(struct.pack('>I', n))
return self
def add_int(self, n):
@ -196,6 +253,10 @@ class Message (object):
:param int n: integer to add
"""
if n >= Message.big_int:
self.packet.write(max_byte)
self.add_string(util.deflate_long(n))
else:
self.packet.write(struct.pack('>I', n))
return self
@ -224,7 +285,8 @@ class Message (object):
:param str s: string to add
"""
self.add_int(len(s))
s = asbytes(s)
self.add_size(len(s))
self.packet.write(s)
return self
@ -240,21 +302,14 @@ class Message (object):
return self
def _add(self, i):
if type(i) is str:
return self.add_string(i)
elif type(i) is int:
return self.add_int(i)
elif type(i) is long:
if i > 0xffffffffL:
return self.add_mpint(i)
else:
return self.add_int(i)
elif type(i) is bool:
if type(i) is bool:
return self.add_boolean(i)
elif isinstance(i, integer_types):
return self.add_int(i)
elif type(i) is list:
return self.add_list(i)
else:
raise Exception('Unknown type')
return self.add_string(i)
def add(self, *seq):
"""

View File

@ -38,6 +38,7 @@ try:
except ImportError:
from Crypto.Hash.HMAC import HMAC
def compute_hmac(key, message, digest_class):
return HMAC(key, message, digest_class).digest()
@ -66,7 +67,7 @@ class Packetizer (object):
self.__dump_packets = False
self.__need_rekey = False
self.__init_count = 0
self.__remainder = ''
self.__remainder = bytes()
# used for noticing when to re-key:
self.__sent_bytes = 0
@ -86,12 +87,12 @@ class Packetizer (object):
self.__sdctr_out = False
self.__mac_engine_out = None
self.__mac_engine_in = None
self.__mac_key_out = ''
self.__mac_key_in = ''
self.__mac_key_out = bytes()
self.__mac_key_in = bytes()
self.__compress_engine_out = None
self.__compress_engine_in = None
self.__sequence_number_out = 0L
self.__sequence_number_in = 0L
self.__sequence_number_out = 0
self.__sequence_number_in = 0
# lock around outbound writes (packet computation)
self.__write_lock = threading.RLock()
@ -152,6 +153,7 @@ class Packetizer (object):
def close(self):
self.__closed = True
self.__socket.close()
def set_hexdump(self, hexdump):
self.__dump_packets = hexdump
@ -193,7 +195,7 @@ class Packetizer (object):
:raises EOFError:
if the socket was closed before all the bytes could be read
"""
out = ''
out = bytes()
# handle over-reading from reading the banner line
if len(self.__remainder) > 0:
out = self.__remainder[:n]
@ -211,7 +213,7 @@ class Packetizer (object):
n -= len(x)
except socket.timeout:
got_timeout = True
except socket.error, e:
except socket.error as e:
# on Linux, sometimes instead of socket.timeout, we get
# EAGAIN. this is a bug in recent (> 2.6.9) kernels but
# we need to work around it.
@ -240,7 +242,7 @@ class Packetizer (object):
n = self.__socket.send(out)
except socket.timeout:
retry_write = True
except socket.error, e:
except socket.error as e:
if (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EAGAIN):
retry_write = True
elif (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EINTR):
@ -270,22 +272,22 @@ class Packetizer (object):
line, so it's okay to attempt large reads.
"""
buf = self.__remainder
while not '\n' in buf:
while not linefeed_byte in buf:
buf += self._read_timeout(timeout)
n = buf.index('\n')
n = buf.index(linefeed_byte)
self.__remainder = buf[n+1:]
buf = buf[:n]
if (len(buf) > 0) and (buf[-1] == '\r'):
if (len(buf) > 0) and (buf[-1] == cr_byte_value):
buf = buf[:-1]
return buf
return u(buf)
def send_message(self, data):
"""
Write a block of data using the current cipher, as an SSH block.
"""
# encrypt this sucka
data = str(data)
cmd = ord(data[0])
data = asbytes(data)
cmd = byte_ord(data[0])
if cmd in MSG_NAMES:
cmd_name = MSG_NAMES[cmd]
else:
@ -307,7 +309,7 @@ class Packetizer (object):
if self.__block_engine_out != None:
payload = struct.pack('>I', self.__sequence_number_out) + packet
out += compute_hmac(self.__mac_key_out, payload, self.__mac_engine_out)[:self.__mac_size_out]
self.__sequence_number_out = (self.__sequence_number_out + 1) & 0xffffffffL
self.__sequence_number_out = (self.__sequence_number_out + 1) & xffffffff
self.write_all(out)
self.__sent_bytes += len(out)
@ -356,7 +358,7 @@ class Packetizer (object):
my_mac = compute_hmac(self.__mac_key_in, mac_payload, self.__mac_engine_in)[:self.__mac_size_in]
if not util.constant_time_bytes_eq(my_mac, mac):
raise SSHException('Mismatched MAC')
padding = ord(packet[0])
padding = byte_ord(packet[0])
payload = packet[1:packet_size - padding]
if self.__dump_packets:
@ -367,7 +369,7 @@ class Packetizer (object):
msg = Message(payload[1:])
msg.seqno = self.__sequence_number_in
self.__sequence_number_in = (self.__sequence_number_in + 1) & 0xffffffffL
self.__sequence_number_in = (self.__sequence_number_in + 1) & xffffffff
# check for rekey
raw_packet_size = packet_size + self.__mac_size_in + 4
@ -390,7 +392,7 @@ class Packetizer (object):
self.__received_packets_overflow = 0
self._trigger_rekey()
cmd = ord(payload[0])
cmd = byte_ord(payload[0])
if cmd in MSG_NAMES:
cmd_name = MSG_NAMES[cmd]
else:
@ -465,7 +467,7 @@ class Packetizer (object):
break
except socket.timeout:
pass
except EnvironmentError, e:
except EnvironmentError as e:
if ((type(e.args) is tuple) and (len(e.args) > 0) and
(e.args[0] == errno.EINTR)):
pass
@ -487,7 +489,7 @@ class Packetizer (object):
if self.__sdctr_out or self.__block_engine_out is None:
# cute trick i caught openssh doing: if we're not encrypting or SDCTR mode (RFC4344),
# don't waste random bytes for the padding
packet += (chr(0) * padding)
packet += (zero_byte * padding)
else:
packet += rng.read(padding)
return packet

View File

@ -28,6 +28,7 @@ will trigger as readable in `select <select.select>`.
import sys
import os
import socket
from paramiko.py3compat import b
def make_pipe ():
@ -64,7 +65,7 @@ class PosixPipe (object):
if self._set or self._closed:
return
self._set = True
os.write(self._wfd, '*')
os.write(self._wfd, b'*')
def set_forever (self):
self._forever = True
@ -110,7 +111,7 @@ class WindowsPipe (object):
if self._set or self._closed:
return
self._set = True
self._wsock.send('*')
self._wsock.send(b'*')
def set_forever (self):
self._forever = True

View File

@ -62,13 +62,16 @@ class PKey (object):
"""
pass
def __str__(self):
def asbytes(self):
"""
Return a string of an SSH `.Message` made up of the public part(s) of
this key. This string is suitable for passing to `__init__` to
re-create the key object later.
"""
return ''
return bytes()
def __str__(self):
return self.asbytes()
def __cmp__(self, other):
"""
@ -83,7 +86,10 @@ class PKey (object):
ho = hash(other)
if hs != ho:
return cmp(hs, ho)
return cmp(str(self), str(other))
return cmp(self.asbytes(), other.asbytes())
def __eq__(self, other):
return hash(self) == hash(other)
def get_name(self):
"""
@ -120,7 +126,7 @@ class PKey (object):
a 16-byte `string <str>` (binary) of the MD5 fingerprint, in SSH
format.
"""
return MD5.new(str(self)).digest()
return MD5.new(self.asbytes()).digest()
def get_base64(self):
"""
@ -130,7 +136,7 @@ class PKey (object):
:return: a base64 `string <str>` containing the public part of the key.
"""
return base64.encodestring(str(self)).replace('\n', '')
return u(encodebytes(self.asbytes())).replace('\n', '')
def sign_ssh_data(self, rng, data):
"""
@ -141,7 +147,7 @@ class PKey (object):
:param str data: the data to sign.
:return: an SSH signature `message <.Message>`.
"""
return ''
return bytes()
def verify_ssh_sig(self, data, msg):
"""
@ -246,9 +252,8 @@ class PKey (object):
encrypted, and ``password`` is ``None``.
:raises SSHException: if the key file is invalid.
"""
f = open(filename, 'r')
with open(filename, 'r') as f:
data = self._read_private_key(tag, f, password)
f.close()
return data
def _read_private_key(self, tag, f, password=None):
@ -273,8 +278,8 @@ class PKey (object):
end += 1
# if we trudged to the end of the file, just try to cope.
try:
data = base64.decodestring(''.join(lines[start:end]))
except base64.binascii.Error, e:
data = decodebytes(b(''.join(lines[start:end])))
except base64.binascii.Error as e:
raise SSHException('base64 decoding error: ' + str(e))
if 'proc-type' not in headers:
# unencryped: done
@ -285,7 +290,7 @@ class PKey (object):
try:
encryption_type, saltstr = headers['dek-info'].split(',')
except:
raise SSHException('Can\'t parse DEK-info in private key file')
raise SSHException("Can't parse DEK-info in private key file")
if encryption_type not in self._CIPHER_TABLE:
raise SSHException('Unknown private key cipher "%s"' % encryption_type)
# if no password was passed in, raise an exception pointing out that we need one
@ -294,7 +299,7 @@ class PKey (object):
cipher = self._CIPHER_TABLE[encryption_type]['cipher']
keysize = self._CIPHER_TABLE[encryption_type]['keysize']
mode = self._CIPHER_TABLE[encryption_type]['mode']
salt = unhexlify(saltstr)
salt = unhexlify(b(saltstr))
key = util.generate_key_bytes(MD5, salt, password, keysize)
return cipher.new(key, mode, salt).decrypt(data)
@ -312,33 +317,32 @@ class PKey (object):
:raises IOError: if there was an error writing the file.
"""
f = open(filename, 'w', 0600)
with open(filename, 'w', o600) as f:
# grrr... the mode doesn't always take hold
os.chmod(filename, 0600)
os.chmod(filename, o600)
self._write_private_key(tag, f, data, password)
f.close()
def _write_private_key(self, tag, f, data, password=None):
f.write('-----BEGIN %s PRIVATE KEY-----\n' % tag)
if password is not None:
# since we only support one cipher here, use it
cipher_name = self._CIPHER_TABLE.keys()[0]
cipher_name = list(self._CIPHER_TABLE.keys())[0]
cipher = self._CIPHER_TABLE[cipher_name]['cipher']
keysize = self._CIPHER_TABLE[cipher_name]['keysize']
blocksize = self._CIPHER_TABLE[cipher_name]['blocksize']
mode = self._CIPHER_TABLE[cipher_name]['mode']
salt = rng.read(8)
salt = rng.read(16)
key = util.generate_key_bytes(MD5, salt, password, keysize)
if len(data) % blocksize != 0:
n = blocksize - len(data) % blocksize
#data += rng.read(n)
# that would make more sense ^, but it confuses openssh.
data += '\0' * n
data += zero_byte * n
data = cipher.new(key, mode, salt).encrypt(data)
f.write('Proc-Type: 4,ENCRYPTED\n')
f.write('DEK-Info: %s,%s\n' % (cipher_name, hexlify(salt).upper()))
f.write('DEK-Info: %s,%s\n' % (cipher_name, u(hexlify(salt)).upper()))
f.write('\n')
s = base64.encodestring(data)
s = u(encodebytes(data))
# re-wrap to 64-char lines
s = ''.join(s.split('\n'))
s = '\n'.join([s[i : i+64] for i in range(0, len(s), 64)])

View File

@ -24,6 +24,7 @@ from Crypto.Util import number
from paramiko import util
from paramiko.ssh_exception import SSHException
from paramiko.common import *
def _generate_prime(bits, rng):
@ -33,7 +34,7 @@ def _generate_prime(bits, rng):
# loop catches the case where we increment n into a higher bit-range
x = rng.read((bits+7) // 8)
if hbyte_mask > 0:
x = chr(ord(x[0]) & hbyte_mask) + x[1:]
x = byte_mask(x[0], hbyte_mask) + x[1:]
n = util.inflate_long(x, 1)
n |= 1
n |= (1 << (bits - 1))
@ -46,7 +47,7 @@ def _generate_prime(bits, rng):
def _roll_random(rng, n):
"returns a random # from 0 to N-1"
bits = util.bit_length(n-1)
bytes = (bits + 7) // 8
byte_count = (bits + 7) // 8
hbyte_mask = pow(2, bits % 8) - 1
# so here's the plan:
@ -56,9 +57,9 @@ def _roll_random(rng, n):
# fits, so i can't guarantee that this loop will ever finish, but the odds
# of it looping forever should be infinitesimal.
while True:
x = rng.read(bytes)
x = rng.read(byte_count)
if hbyte_mask > 0:
x = chr(ord(x[0]) & hbyte_mask) + x[1:]
x = byte_mask(x[0], hbyte_mask) + x[1:]
num = util.inflate_long(x, 1)
if num < n:
break
@ -112,7 +113,7 @@ class ModulusPack (object):
:raises IOError: passed from any file operations that fail.
"""
self.pack = {}
f = open(filename, 'r')
with open(filename, 'r') as f:
for line in f:
line = line.strip()
if (len(line) == 0) or (line[0] == '#'):
@ -121,11 +122,9 @@ class ModulusPack (object):
self._parse_modulus(line)
except:
continue
f.close()
def get_modulus(self, min, prefer, max):
bitsizes = self.pack.keys()
bitsizes.sort()
bitsizes = sorted(self.pack.keys())
if len(bitsizes) == 0:
raise SSHException('no moduli available')
good = -1

View File

@ -59,7 +59,7 @@ class ProxyCommand(object):
"""
try:
self.process.stdin.write(content)
except IOError, e:
except IOError as e:
# There was a problem with the child process. It probably
# died and we can't proceed. The best option here is to
# raise an exception informing the user that the informed
@ -95,7 +95,7 @@ class ProxyCommand(object):
return result
except socket.timeout:
raise # socket.timeout is a subclass of IOError
except IOError, e:
except IOError as e:
raise ProxyCommandFailure(' '.join(self.cmd), e.strerror)
def close(self):

160
paramiko/py3compat.py Normal file
View File

@ -0,0 +1,160 @@
import sys
import base64
__all__ = ['PY2', 'string_types', 'integer_types', 'text_type', 'bytes_types', 'bytes', 'long', 'input',
'decodebytes', 'encodebytes', 'bytestring', 'byte_ord', 'byte_chr', 'byte_mask',
'b', 'u', 'b2s', 'StringIO', 'BytesIO', 'is_callable', 'MAXSIZE', 'next']
PY2 = sys.version_info[0] < 3
if PY2:
string_types = basestring
text_type = unicode
bytes_types = str
bytes = str
integer_types = (int, long)
long = long
input = raw_input
decodebytes = base64.decodestring
encodebytes = base64.encodestring
def bytestring(s): # NOQA
if isinstance(s, unicode):
return s.encode('utf-8')
return s
byte_ord = ord # NOQA
byte_chr = chr # NOQA
def byte_mask(c, mask):
return chr(ord(c) & mask)
def b(s, encoding='utf8'): # NOQA
"""cast unicode or bytes to bytes"""
if isinstance(s, str):
return s
elif isinstance(s, unicode):
return s.encode(encoding)
else:
raise TypeError("Expected unicode or bytes, got %r" % s)
def u(s, encoding='utf8'): # NOQA
"""cast bytes or unicode to unicode"""
if isinstance(s, str):
return s.decode(encoding)
elif isinstance(s, unicode):
return s
else:
raise TypeError("Expected unicode or bytes, got %r" % s)
def b2s(s):
return s
try:
import cStringIO
StringIO = cStringIO.StringIO # NOQA
except ImportError:
import StringIO
StringIO = StringIO.StringIO # NOQA
BytesIO = StringIO
def is_callable(c): # NOQA
return callable(c)
def get_next(c): # NOQA
return c.next
def next(c):
return c.next()
# It's possible to have sizeof(long) != sizeof(Py_ssize_t).
class X(object):
def __len__(self):
return 1 << 31
try:
len(X())
except OverflowError:
# 32-bit
MAXSIZE = int((1 << 31) - 1) # NOQA
else:
# 64-bit
MAXSIZE = int((1 << 63) - 1) # NOQA
del X
else:
import collections
import struct
string_types = str
text_type = str
bytes = bytes
bytes_types = bytes
integer_types = int
class long(int):
pass
input = input
decodebytes = base64.decodebytes
encodebytes = base64.encodebytes
def bytestring(s):
return s
def byte_ord(c):
assert isinstance(c, int)
return c
def byte_chr(c):
assert isinstance(c, int)
return struct.pack('B', c)
def byte_mask(c, mask):
assert isinstance(c, int)
return struct.pack('B', c & mask)
def b(s, encoding='utf8'):
"""cast unicode or bytes to bytes"""
if isinstance(s, bytes):
return s
elif isinstance(s, str):
return s.encode(encoding)
else:
raise TypeError("Expected unicode or bytes, got %r" % s)
def u(s, encoding='utf8'):
"""cast bytes or unicode to unicode"""
if isinstance(s, bytes):
return s.decode(encoding)
elif isinstance(s, str):
return s
else:
raise TypeError("Expected unicode or bytes, got %r" % s)
def b2s(s):
return s.decode() if isinstance(s, bytes) else s
import io
StringIO = io.StringIO # NOQA
BytesIO = io.BytesIO # NOQA
def is_callable(c):
return isinstance(c, collections.Callable)
def get_next(c):
return c.__next__
next = next
MAXSIZE = sys.maxsize # NOQA

View File

@ -31,6 +31,8 @@ from paramiko.ber import BER, BERException
from paramiko.pkey import PKey
from paramiko.ssh_exception import SSHException
SHA1_DIGESTINFO = b'\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14'
class RSAKey (PKey):
"""
@ -57,18 +59,21 @@ class RSAKey (PKey):
else:
if msg is None:
raise SSHException('Key object may not be empty')
if msg.get_string() != 'ssh-rsa':
if msg.get_text() != 'ssh-rsa':
raise SSHException('Invalid key')
self.e = msg.get_mpint()
self.n = msg.get_mpint()
self.size = util.bit_length(self.n)
def __str__(self):
def asbytes(self):
m = Message()
m.add_string('ssh-rsa')
m.add_mpint(self.e)
m.add_mpint(self.n)
return str(m)
return m.asbytes()
def __str__(self):
return self.asbytes()
def __hash__(self):
h = hash(self.get_name())
@ -88,16 +93,16 @@ class RSAKey (PKey):
def sign_ssh_data(self, rpool, data):
digest = SHA.new(data).digest()
rsa = RSA.construct((long(self.n), long(self.e), long(self.d)))
sig = util.deflate_long(rsa.sign(self._pkcs1imify(digest), '')[0], 0)
sig = util.deflate_long(rsa.sign(self._pkcs1imify(digest), bytes())[0], 0)
m = Message()
m.add_string('ssh-rsa')
m.add_string(sig)
return m
def verify_ssh_sig(self, data, msg):
if msg.get_string() != 'ssh-rsa':
if msg.get_text() != 'ssh-rsa':
return False
sig = util.inflate_long(msg.get_string(), True)
sig = util.inflate_long(msg.get_binary(), True)
# verify the signature by SHA'ing the data and encrypting it using the
# public key. some wackiness ensues where we "pkcs1imify" the 20-byte
# hash into a string as long as the RSA key.
@ -116,7 +121,7 @@ class RSAKey (PKey):
b.encode(keylist)
except BERException:
raise SSHException('Unable to create ber encoding of key')
return str(b)
return b.asbytes()
def write_private_key_file(self, filename, password=None):
self._write_private_key_file('RSA', filename, self._encode_key(), password)
@ -152,10 +157,9 @@ class RSAKey (PKey):
turn a 20-byte SHA1 hash into a blob of data as large as the key's N,
using PKCS1's \"emsa-pkcs1-v1_5\" encoding. totally bizarre.
"""
SHA1_DIGESTINFO = '\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14'
size = len(util.deflate_long(self.n, 0))
filler = '\xff' * (size - len(SHA1_DIGESTINFO) - len(data) - 3)
return '\x00\x01' + filler + '\x00' + SHA1_DIGESTINFO + data
filler = max_byte * (size - len(SHA1_DIGESTINFO) - len(data) - 3)
return zero_byte + one_byte + filler + zero_byte + SHA1_DIGESTINFO + data
def _from_private_key_file(self, filename, password):
data = self._read_private_key_file('RSA', filename, password)

View File

@ -514,7 +514,7 @@ class InteractiveQuery (object):
self.instructions = instructions
self.prompts = []
for x in prompts:
if (type(x) is str) or (type(x) is unicode):
if isinstance(x, string_types):
self.add_prompt(x)
else:
self.add_prompt(x[0], x[1])
@ -576,7 +576,7 @@ class SubsystemHandler (threading.Thread):
try:
self.__transport._log(DEBUG, 'Starting handler for subsystem %s' % self.__name)
self.start_subsystem(self.__name, self.__transport, self.__channel)
except Exception, e:
except Exception as e:
self.__transport._log(ERROR, 'Exception in subsystem handler for "%s": %s' %
(self.__name, str(e)))
self.__transport._log(ERROR, util.tb_strings())

View File

@ -86,7 +86,7 @@ CMD_NAMES = {
CMD_ATTRS: 'attrs',
CMD_EXTENDED: 'extended',
CMD_EXTENDED_REPLY: 'extended_reply'
}
}
class SFTPError (Exception):
@ -125,7 +125,7 @@ class BaseSFTP (object):
msg = Message()
msg.add_int(_VERSION)
msg.add(*extension_pairs)
self._send_packet(CMD_VERSION, str(msg))
self._send_packet(CMD_VERSION, msg)
return version
def _log(self, level, msg, *args):
@ -142,7 +142,7 @@ class BaseSFTP (object):
return
def _read_all(self, n):
out = ''
out = bytes()
while n > 0:
if isinstance(self.sock, socket.socket):
# sometimes sftp is used directly over a socket instead of
@ -166,7 +166,8 @@ class BaseSFTP (object):
def _send_packet(self, t, packet):
#self._log(DEBUG2, 'write: %s (len=%d)' % (CMD_NAMES.get(t, '0x%02x' % t), len(packet)))
out = struct.pack('>I', len(packet) + 1) + chr(t) + packet
packet = asbytes(packet)
out = struct.pack('>I', len(packet) + 1) + byte_chr(t) + packet
if self.ultra_debug:
self._log(DEBUG, util.format_binary(out, 'OUT: '))
self._write_all(out)
@ -175,14 +176,14 @@ class BaseSFTP (object):
x = self._read_all(4)
# most sftp servers won't accept packets larger than about 32k, so
# anything with the high byte set (> 16MB) is just garbage.
if x[0] != '\x00':
if byte_ord(x[0]):
raise SFTPError('Garbage packet received')
size = struct.unpack('>I', x)[0]
data = self._read_all(size)
if self.ultra_debug:
self._log(DEBUG, util.format_binary(data, 'IN: '));
if size > 0:
t = ord(data[0])
t = byte_ord(data[0])
#self._log(DEBUG2, 'read: %s (len=%d)' % (CMD_NAMES.get(t), '0x%02x' % t, len(data)-1))
return t, data[1:]
return 0, ''
return 0, bytes()

View File

@ -45,7 +45,7 @@ class SFTPAttributes (object):
FLAG_UIDGID = 2
FLAG_PERMISSIONS = 4
FLAG_AMTIME = 8
FLAG_EXTENDED = 0x80000000L
FLAG_EXTENDED = x80000000
def __init__(self):
"""
@ -141,7 +141,7 @@ class SFTPAttributes (object):
msg.add_int(long(self.st_mtime))
if self._flags & self.FLAG_EXTENDED:
msg.add_int(len(self.attr))
for key, val in self.attr.iteritems():
for key, val in self.attr.items():
msg.add_string(key)
msg.add_string(val)
return
@ -156,7 +156,7 @@ class SFTPAttributes (object):
out += 'mode=' + oct(self.st_mode) + ' '
if (self.st_atime is not None) and (self.st_mtime is not None):
out += 'atime=%d mtime=%d ' % (self.st_atime, self.st_mtime)
for k, v in self.attr.iteritems():
for k, v in self.attr.items():
out += '"%s"=%r ' % (str(k), v)
out += ']'
return out
@ -192,13 +192,13 @@ class SFTPAttributes (object):
ks = 's'
else:
ks = '?'
ks += self._rwx((self.st_mode & 0700) >> 6, self.st_mode & stat.S_ISUID)
ks += self._rwx((self.st_mode & 070) >> 3, self.st_mode & stat.S_ISGID)
ks += self._rwx((self.st_mode & o700) >> 6, self.st_mode & stat.S_ISUID)
ks += self._rwx((self.st_mode & o70) >> 3, self.st_mode & stat.S_ISGID)
ks += self._rwx(self.st_mode & 7, self.st_mode & stat.S_ISVTX, True)
else:
ks = '?---------'
# compute display date
if (self.st_mtime is None) or (self.st_mtime == 0xffffffffL):
if (self.st_mtime is None) or (self.st_mtime == xffffffff):
# shouldn't really happen
datestr = '(unknown date)'
else:
@ -219,3 +219,5 @@ class SFTPAttributes (object):
return '%s 1 %-8d %-8d %8d %-12s %s' % (ks, uid, gid, self.st_size, datestr, filename)
def asbytes(self):
return b(str(self))

View File

@ -39,12 +39,13 @@ def _to_unicode(s):
"""
try:
return s.encode('ascii')
except UnicodeError:
except (UnicodeError, AttributeError):
try:
return s.decode('utf-8')
except UnicodeError:
return s
b_slash = b'/'
class SFTPClient(BaseSFTP):
"""
@ -82,7 +83,7 @@ class SFTPClient(BaseSFTP):
self.ultra_debug = transport.get_hexdump()
try:
server_version = self._send_version()
except EOFError, x:
except EOFError:
raise SSHException('EOF during negotiation')
self._log(INFO, 'Opened sftp connection (server version %d)' % server_version)
@ -162,20 +163,20 @@ class SFTPClient(BaseSFTP):
t, msg = self._request(CMD_OPENDIR, path)
if t != CMD_HANDLE:
raise SFTPError('Expected handle')
handle = msg.get_string()
handle = msg.get_binary()
filelist = []
while True:
try:
t, msg = self._request(CMD_READDIR, handle)
except EOFError, e:
except EOFError:
# done with handle
break
if t != CMD_NAME:
raise SFTPError('Expected name response')
count = msg.get_int()
for i in range(count):
filename = _to_unicode(msg.get_string())
longname = _to_unicode(msg.get_string())
filename = msg.get_text()
longname = msg.get_text()
attr = SFTPAttributes._from_msg(msg, filename, longname)
if (filename != '.') and (filename != '..'):
filelist.append(attr)
@ -231,7 +232,7 @@ class SFTPClient(BaseSFTP):
t, msg = self._request(CMD_OPEN, filename, imode, attrblock)
if t != CMD_HANDLE:
raise SFTPError('Expected handle')
handle = msg.get_string()
handle = msg.get_binary()
self._log(DEBUG, 'open(%r, %r) -> %s' % (filename, mode, hexlify(handle)))
return SFTPFile(self, handle, mode, bufsize)
@ -268,7 +269,7 @@ class SFTPClient(BaseSFTP):
self._log(DEBUG, 'rename(%r, %r)' % (oldpath, newpath))
self._request(CMD_RENAME, oldpath, newpath)
def mkdir(self, path, mode=0777):
def mkdir(self, path, mode=o777):
"""
Create a folder (directory) named ``path`` with numeric mode ``mode``.
The default mode is 0777 (octal). On some systems, mode is ignored.
@ -347,8 +348,7 @@ class SFTPClient(BaseSFTP):
"""
dest = self._adjust_cwd(dest)
self._log(DEBUG, 'symlink(%r, %r)' % (source, dest))
if type(source) is unicode:
source = source.encode('utf-8')
source = bytestring(source)
self._request(CMD_SYMLINK, source, dest)
def chmod(self, path, mode):
@ -462,9 +462,9 @@ class SFTPClient(BaseSFTP):
count = msg.get_int()
if count != 1:
raise SFTPError('Realpath returned %d results' % count)
return _to_unicode(msg.get_string())
return msg.get_text()
def chdir(self, path):
def chdir(self, path=None):
"""
Change the "current directory" of this SFTP session. Since SFTP
doesn't really have the concept of a current working directory, this is
@ -484,7 +484,7 @@ class SFTPClient(BaseSFTP):
return
if not stat.S_ISDIR(self.stat(path).st_mode):
raise SFTPError(errno.ENOTDIR, "%s: %s" % (os.strerror(errno.ENOTDIR), path))
self._cwd = self.normalize(path).encode('utf-8')
self._cwd = b(self.normalize(path))
def getcwd(self):
"""
@ -494,7 +494,7 @@ class SFTPClient(BaseSFTP):
.. versionadded:: 1.4
"""
return self._cwd
return self._cwd and u(self._cwd)
def putfo(self, fl, remotepath, file_size=0, callback=None, confirm=True):
"""
@ -525,10 +525,9 @@ class SFTPClient(BaseSFTP):
.. versionchanged:: 1.7.4
Began returning rich attribute objects.
"""
fr = self.file(remotepath, 'wb')
with self.file(remotepath, 'wb') as fr:
fr.set_pipelined(True)
size = 0
try:
while True:
data = fl.read(32768)
fr.write(data)
@ -537,8 +536,6 @@ class SFTPClient(BaseSFTP):
callback(size, file_size)
if len(data) == 0:
break
finally:
fr.close()
if confirm:
s = self.stat(remotepath)
if s.st_size != size:
@ -573,11 +570,8 @@ class SFTPClient(BaseSFTP):
``confirm`` param added.
"""
file_size = os.stat(localpath).st_size
fl = file(localpath, 'rb')
try:
with open(localpath, 'rb') as fl:
return self.putfo(fl, remotepath, os.stat(localpath).st_size, callback, confirm)
finally:
fl.close()
def getfo(self, remotepath, fl, callback=None):
"""
@ -598,10 +592,9 @@ class SFTPClient(BaseSFTP):
.. versionchanged:: 1.7.4
Added the ``callable`` param.
"""
fr = self.file(remotepath, 'rb')
with self.open(remotepath, 'rb') as fr:
file_size = self.stat(remotepath).st_size
fr.prefetch()
try:
size = 0
while True:
data = fr.read(32768)
@ -611,8 +604,6 @@ class SFTPClient(BaseSFTP):
callback(size, file_size)
if len(data) == 0:
break
finally:
fr.close()
return size
def get(self, remotepath, localpath, callback=None):
@ -632,11 +623,8 @@ class SFTPClient(BaseSFTP):
Added the ``callback`` param
"""
file_size = self.stat(remotepath).st_size
fl = file(localpath, 'wb')
try:
with open(localpath, 'wb') as fl:
size = self.getfo(remotepath, fl, callback)
finally:
fl.close()
s = os.stat(localpath)
if s.st_size != size:
raise IOError('size mismatch in get! %d != %d' % (s.st_size, size))
@ -656,11 +644,11 @@ class SFTPClient(BaseSFTP):
msg = Message()
msg.add_int(self.request_number)
for item in arg:
if isinstance(item, int):
msg.add_int(item)
elif isinstance(item, long):
if isinstance(item, long):
msg.add_int64(item)
elif isinstance(item, str):
elif isinstance(item, int):
msg.add_int(item)
elif isinstance(item, (string_types, bytes_types)):
msg.add_string(item)
elif isinstance(item, SFTPAttributes):
item._pack(msg)
@ -668,7 +656,7 @@ class SFTPClient(BaseSFTP):
raise Exception('unknown type for %r type %r' % (item, type(item)))
num = self.request_number
self._expecting[num] = fileobj
self._send_packet(t, str(msg))
self._send_packet(t, msg)
self.request_number += 1
finally:
self._lock.release()
@ -678,8 +666,8 @@ class SFTPClient(BaseSFTP):
while True:
try:
t, data = self._read_packet()
except EOFError, e:
raise SSHException('Server connection dropped: %s' % (str(e),))
except EOFError as e:
raise SSHException('Server connection dropped: %s' % str(e))
msg = Message(data)
num = msg.get_int()
if num not in self._expecting:
@ -713,7 +701,7 @@ class SFTPClient(BaseSFTP):
Raises EOFError or IOError on error status; otherwise does nothing.
"""
code = msg.get_int()
text = msg.get_string()
text = msg.get_text()
if code == SFTP_OK:
return
elif code == SFTP_EOF:
@ -731,16 +719,15 @@ class SFTPClient(BaseSFTP):
Return an adjusted path if we're emulating a "current working
directory" for the server.
"""
if type(path) is unicode:
path = path.encode('utf-8')
path = b(path)
if self._cwd is None:
return path
if (len(path) > 0) and (path[0] == '/'):
if len(path) and path[0:1] == b_slash:
# absolute path
return path
if self._cwd == '/':
if self._cwd == b_slash:
return self._cwd + path
return self._cwd + '/' + path
return self._cwd + b_slash + path
class SFTP(SFTPClient):

View File

@ -100,7 +100,7 @@ class SFTPFile (BufferedFile):
k = [x for x in self._prefetch_extents.values() if x[0] <= offset]
if len(k) == 0:
return False
k.sort(lambda x, y: cmp(x[0], y[0]))
k.sort(key=lambda x: x[0])
buf_offset, buf_size = k[-1]
if buf_offset + buf_size <= offset:
# prefetch request ends before this one begins
@ -171,7 +171,7 @@ class SFTPFile (BufferedFile):
def _write(self, data):
# may write less than requested if it would exceed max packet size
chunk = min(len(data), self.MAX_REQUEST_SIZE)
self._reqs.append(self.sftp._async_request(type(None), CMD_WRITE, self.handle, long(self._realpos), str(data[:chunk])))
self._reqs.append(self.sftp._async_request(type(None), CMD_WRITE, self.handle, long(self._realpos), data[:chunk]))
if not self.pipelined or (len(self._reqs) > 100 and self.sftp.sock.recv_ready()):
while len(self._reqs):
req = self._reqs.popleft()
@ -224,7 +224,7 @@ class SFTPFile (BufferedFile):
self._realpos = self._pos
else:
self._realpos = self._pos = self._get_size() + offset
self._rbuffer = ''
self._rbuffer = bytes()
def stat(self):
"""
@ -352,8 +352,8 @@ class SFTPFile (BufferedFile):
"""
t, msg = self.sftp._request(CMD_EXTENDED, 'check-file', self.handle,
hash_algorithm, long(offset), long(length), block_size)
ext = msg.get_string()
alg = msg.get_string()
ext = msg.get_text()
alg = msg.get_text()
data = msg.get_remainder()
return data
@ -469,8 +469,8 @@ class SFTPFile (BufferedFile):
# save exception and re-raise it on next file operation
try:
self.sftp._convert_status(msg)
except Exception, x:
self._saved_exception = x
except Exception as e:
self._saved_exception = e
return
if t != CMD_DATA:
raise SFTPError('Expected data')

View File

@ -97,7 +97,7 @@ class SFTPHandle (object):
readfile.seek(offset)
self.__tell = offset
data = readfile.read(length)
except IOError, e:
except IOError as e:
self.__tell = None
return SFTPServer.convert_errno(e.errno)
self.__tell += len(data)
@ -135,7 +135,7 @@ class SFTPHandle (object):
self.__tell = offset
writefile.write(data)
writefile.flush()
except IOError, e:
except IOError as e:
self.__tell = None
return SFTPServer.convert_errno(e.errno)
if self.__tell is not None:

View File

@ -89,7 +89,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
except EOFError:
self._log(DEBUG, 'EOF -- end of session')
return
except Exception, e:
except Exception as e:
self._log(DEBUG, 'Exception on channel: ' + str(e))
self._log(DEBUG, util.tb_strings())
return
@ -97,7 +97,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
request_number = msg.get_int()
try:
self._process(t, request_number, msg)
except Exception, e:
except Exception as e:
self._log(DEBUG, 'Exception in server processing: ' + str(e))
self._log(DEBUG, util.tb_strings())
# send some kind of failure message, at least
@ -110,9 +110,9 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
self.server.session_ended()
super(SFTPServer, self).finish_subsystem()
# close any file handles that were left open (so we can return them to the OS quickly)
for f in self.file_table.itervalues():
for f in self.file_table.values():
f.close()
for f in self.folder_table.itervalues():
for f in self.folder_table.values():
f.close()
self.file_table = {}
self.folder_table = {}
@ -159,7 +159,8 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
if attr._flags & attr.FLAG_AMTIME:
os.utime(filename, (attr.st_atime, attr.st_mtime))
if attr._flags & attr.FLAG_SIZE:
open(filename, 'w+').truncate(attr.st_size)
with open(filename, 'w+') as f:
f.truncate(attr.st_size)
set_file_attr = staticmethod(set_file_attr)
@ -170,24 +171,24 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
msg = Message()
msg.add_int(request_number)
for item in arg:
if type(item) is int:
msg.add_int(item)
elif type(item) is long:
if isinstance(item, long):
msg.add_int64(item)
elif type(item) is str:
elif isinstance(item, int):
msg.add_int(item)
elif isinstance(item, (string_types, bytes_types)):
msg.add_string(item)
elif type(item) is SFTPAttributes:
item._pack(msg)
else:
raise Exception('unknown type for ' + repr(item) + ' type ' + repr(type(item)))
self._send_packet(t, str(msg))
self._send_packet(t, msg)
def _send_handle_response(self, request_number, handle, folder=False):
if not issubclass(type(handle), SFTPHandle):
# must be error code
self._send_status(request_number, handle)
return
handle._set_name('hx%d' % self.next_handle)
handle._set_name(b('hx%d' % self.next_handle))
self.next_handle += 1
if folder:
self.folder_table[handle._get_name()] = handle
@ -225,16 +226,16 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
msg.add_int(len(flist))
for attr in flist:
msg.add_string(attr.filename)
msg.add_string(str(attr))
msg.add_string(attr)
attr._pack(msg)
self._send_packet(CMD_NAME, str(msg))
self._send_packet(CMD_NAME, msg)
def _check_file(self, request_number, msg):
# this extension actually comes from v6 protocol, but since it's an
# extension, i feel like we can reasonably support it backported.
# it's very useful for verifying uploaded files or checking for
# rsync-like differences between local and remote files.
handle = msg.get_string()
handle = msg.get_binary()
alg_list = msg.get_list()
start = msg.get_int64()
length = msg.get_int64()
@ -263,7 +264,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
self._send_status(request_number, SFTP_FAILURE, 'Block size too small')
return
sum_out = ''
sum_out = bytes()
offset = start
while offset < start + length:
blocklen = min(block_size, start + length - offset)
@ -273,7 +274,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
hash_obj = alg.new()
while count < blocklen:
data = f.read(offset, chunklen)
if not type(data) is str:
if not isinstance(data, bytes_types):
self._send_status(request_number, data, 'Unable to hash file')
return
hash_obj.update(data)
@ -286,7 +287,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
msg.add_string('check-file')
msg.add_string(algname)
msg.add_bytes(sum_out)
self._send_packet(CMD_EXTENDED_REPLY, str(msg))
self._send_packet(CMD_EXTENDED_REPLY, msg)
def _convert_pflags(self, pflags):
"convert SFTP-style open() flags to Python's os.open() flags"
@ -309,12 +310,12 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
def _process(self, t, request_number, msg):
self._log(DEBUG, 'Request: %s' % CMD_NAMES[t])
if t == CMD_OPEN:
path = msg.get_string()
path = msg.get_text()
flags = self._convert_pflags(msg.get_int())
attr = SFTPAttributes._from_msg(msg)
self._send_handle_response(request_number, self.server.open(path, flags, attr))
elif t == CMD_CLOSE:
handle = msg.get_string()
handle = msg.get_binary()
if handle in self.folder_table:
del self.folder_table[handle]
self._send_status(request_number, SFTP_OK)
@ -326,14 +327,14 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
return
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
elif t == CMD_READ:
handle = msg.get_string()
handle = msg.get_binary()
offset = msg.get_int64()
length = msg.get_int()
if handle not in self.file_table:
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return
data = self.file_table[handle].read(offset, length)
if type(data) is str:
if isinstance(data, (bytes_types, string_types)):
if len(data) == 0:
self._send_status(request_number, SFTP_EOF)
else:
@ -341,54 +342,54 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
else:
self._send_status(request_number, data)
elif t == CMD_WRITE:
handle = msg.get_string()
handle = msg.get_binary()
offset = msg.get_int64()
data = msg.get_string()
data = msg.get_binary()
if handle not in self.file_table:
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return
self._send_status(request_number, self.file_table[handle].write(offset, data))
elif t == CMD_REMOVE:
path = msg.get_string()
path = msg.get_text()
self._send_status(request_number, self.server.remove(path))
elif t == CMD_RENAME:
oldpath = msg.get_string()
newpath = msg.get_string()
oldpath = msg.get_text()
newpath = msg.get_text()
self._send_status(request_number, self.server.rename(oldpath, newpath))
elif t == CMD_MKDIR:
path = msg.get_string()
path = msg.get_text()
attr = SFTPAttributes._from_msg(msg)
self._send_status(request_number, self.server.mkdir(path, attr))
elif t == CMD_RMDIR:
path = msg.get_string()
path = msg.get_text()
self._send_status(request_number, self.server.rmdir(path))
elif t == CMD_OPENDIR:
path = msg.get_string()
path = msg.get_text()
self._open_folder(request_number, path)
return
elif t == CMD_READDIR:
handle = msg.get_string()
handle = msg.get_binary()
if handle not in self.folder_table:
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return
folder = self.folder_table[handle]
self._read_folder(request_number, folder)
elif t == CMD_STAT:
path = msg.get_string()
path = msg.get_text()
resp = self.server.stat(path)
if issubclass(type(resp), SFTPAttributes):
self._response(request_number, CMD_ATTRS, resp)
else:
self._send_status(request_number, resp)
elif t == CMD_LSTAT:
path = msg.get_string()
path = msg.get_text()
resp = self.server.lstat(path)
if issubclass(type(resp), SFTPAttributes):
self._response(request_number, CMD_ATTRS, resp)
else:
self._send_status(request_number, resp)
elif t == CMD_FSTAT:
handle = msg.get_string()
handle = msg.get_binary()
if handle not in self.file_table:
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return
@ -398,34 +399,34 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
else:
self._send_status(request_number, resp)
elif t == CMD_SETSTAT:
path = msg.get_string()
path = msg.get_text()
attr = SFTPAttributes._from_msg(msg)
self._send_status(request_number, self.server.chattr(path, attr))
elif t == CMD_FSETSTAT:
handle = msg.get_string()
handle = msg.get_binary()
attr = SFTPAttributes._from_msg(msg)
if handle not in self.file_table:
self._response(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return
self._send_status(request_number, self.file_table[handle].chattr(attr))
elif t == CMD_READLINK:
path = msg.get_string()
path = msg.get_text()
resp = self.server.readlink(path)
if type(resp) is str:
if isinstance(resp, (bytes_types, string_types)):
self._response(request_number, CMD_NAME, 1, resp, '', SFTPAttributes())
else:
self._send_status(request_number, resp)
elif t == CMD_SYMLINK:
# the sftp 2 draft is incorrect here! path always follows target_path
target_path = msg.get_string()
path = msg.get_string()
target_path = msg.get_text()
path = msg.get_text()
self._send_status(request_number, self.server.symlink(target_path, path))
elif t == CMD_REALPATH:
path = msg.get_string()
path = msg.get_text()
rpath = self.server.canonicalize(path)
self._response(request_number, CMD_NAME, 1, rpath, '', SFTPAttributes())
elif t == CMD_EXTENDED:
tag = msg.get_string()
tag = msg.get_text()
if tag == 'check-file':
self._check_file(request_number, msg)
else:

View File

@ -155,7 +155,7 @@ class Transport (threading.Thread):
:param socket sock:
a socket or socket-like object to create the session over.
"""
if isinstance(sock, (str, unicode)):
if isinstance(sock, string_types):
# convert "host:port" into (host, port)
hl = sock.split(':', 1)
if len(hl) == 1:
@ -173,7 +173,7 @@ class Transport (threading.Thread):
sock = socket.socket(af, socket.SOCK_STREAM)
try:
retry_on_signal(lambda: sock.connect((hostname, port)))
except socket.error, e:
except socket.error as e:
reason = str(e)
else:
break
@ -253,7 +253,7 @@ class Transport (threading.Thread):
"""
Returns a string representation of this object, for debugging.
"""
out = '<paramiko.Transport at %s' % hex(long(id(self)) & 0xffffffffL)
out = '<paramiko.Transport at %s' % hex(long(id(self)) & xffffffff)
if not self.active:
out += ' (unconnected)'
else:
@ -279,6 +279,7 @@ class Transport (threading.Thread):
.. versionadded:: 1.5.3
"""
self.sock.close()
self.close()
def get_security_options(self):
@ -489,7 +490,7 @@ class Transport (threading.Thread):
if not self.active:
return
self.stop_thread()
for chan in self._channels.values():
for chan in list(self._channels.values()):
chan._unlink()
self.sock.close()
@ -562,18 +563,16 @@ class Transport (threading.Thread):
"""
return self.open_channel('auth-agent@openssh.com')
def open_forwarded_tcpip_channel(self, (src_addr, src_port), (dest_addr, dest_port)):
def open_forwarded_tcpip_channel(self, src_addr, dest_addr):
"""
Request a new channel back to the client, of type ``"forwarded-tcpip"``.
This is used after a client has requested port forwarding, for sending
incoming connections back to the client.
:param src_addr: originator's address
:param src_port: originator's port
:param dest_addr: local (server) connected address
:param dest_port: local (server) connected port
"""
return self.open_channel('forwarded-tcpip', (dest_addr, dest_port), (src_addr, src_port))
return self.open_channel('forwarded-tcpip', dest_addr, src_addr)
def open_channel(self, kind, dest_addr=None, src_addr=None):
"""
@ -602,7 +601,7 @@ class Transport (threading.Thread):
try:
chanid = self._next_channel()
m = Message()
m.add_byte(chr(MSG_CHANNEL_OPEN))
m.add_byte(cMSG_CHANNEL_OPEN)
m.add_string(kind)
m.add_int(chanid)
m.add_int(self.window_size)
@ -670,7 +669,6 @@ class Transport (threading.Thread):
"""
if not self.active:
raise SSHException('SSH session not active')
address = str(address)
port = int(port)
response = self.global_request('tcpip-forward', (address, port), wait=True)
if response is None:
@ -678,7 +676,9 @@ class Transport (threading.Thread):
if port == 0:
port = response.get_int()
if handler is None:
def default_handler(channel, (src_addr, src_port), (dest_addr, dest_port)):
def default_handler(channel, src_addr, dest_addr):
#src_addr, src_port = src_addr_port
#dest_addr, dest_port = dest_addr_port
self._queue_incoming_channel(channel)
handler = default_handler
self._tcp_handler = handler
@ -710,22 +710,22 @@ class Transport (threading.Thread):
"""
return SFTPClient.from_transport(self)
def send_ignore(self, bytes=None):
def send_ignore(self, byte_count=None):
"""
Send a junk packet across the encrypted link. This is sometimes used
to add "noise" to a connection to confuse would-be attackers. It can
also be used as a keep-alive for long lived connections traversing
firewalls.
:param int bytes:
:param int byte_count:
the number of random bytes to send in the payload of the ignored
packet -- defaults to a random number from 10 to 41.
"""
m = Message()
m.add_byte(chr(MSG_IGNORE))
if bytes is None:
bytes = (ord(rng.read(1)) % 32) + 10
m.add_bytes(rng.read(bytes))
m.add_byte(cMSG_IGNORE)
if byte_count is None:
byte_count = (ord(rng.read(1)) % 32) + 10
m.add_bytes(rng.read(byte_count))
self._send_user_message(m)
def renegotiate_keys(self):
@ -787,7 +787,7 @@ class Transport (threading.Thread):
if wait:
self.completion_event = threading.Event()
m = Message()
m.add_byte(chr(MSG_GLOBAL_REQUEST))
m.add_byte(cMSG_GLOBAL_REQUEST)
m.add_string(kind)
m.add_boolean(wait)
if data is not None:
@ -871,10 +871,10 @@ class Transport (threading.Thread):
# check host key if we were given one
if (hostkey is not None):
key = self.get_remote_server_key()
if (key.get_name() != hostkey.get_name()) or (str(key) != str(hostkey)):
if (key.get_name() != hostkey.get_name()) or (key.asbytes() != hostkey.asbytes()):
self._log(DEBUG, 'Bad host key from server')
self._log(DEBUG, 'Expected: %s: %s' % (hostkey.get_name(), repr(str(hostkey))))
self._log(DEBUG, 'Got : %s: %s' % (key.get_name(), repr(str(key))))
self._log(DEBUG, 'Expected: %s: %s' % (hostkey.get_name(), repr(hostkey.asbytes())))
self._log(DEBUG, 'Got : %s: %s' % (key.get_name(), repr(key.asbytes())))
raise SSHException('Bad host key from server')
self._log(DEBUG, 'Host key verified (%s)' % hostkey.get_name())
@ -1048,9 +1048,9 @@ class Transport (threading.Thread):
return []
try:
return self.auth_handler.wait_for_response(my_event)
except BadAuthenticationType, x:
except BadAuthenticationType as e:
# if password auth isn't allowed, but keyboard-interactive *is*, try to fudge it
if not fallback or ('keyboard-interactive' not in x.allowed_types):
if not fallback or ('keyboard-interactive' not in e.allowed_types):
raise
try:
def handler(title, instructions, fields):
@ -1064,9 +1064,9 @@ class Transport (threading.Thread):
return []
return [ password ]
return self.auth_interactive(username, handler)
except SSHException, ignored:
except SSHException:
# attempt failed; just raise the original exception
raise x
raise e
return None
def auth_publickey(self, username, key, event=None):
@ -1331,15 +1331,15 @@ class Transport (threading.Thread):
m = Message()
m.add_mpint(self.K)
m.add_bytes(self.H)
m.add_byte(id)
m.add_byte(b(id))
m.add_bytes(self.session_id)
out = sofar = SHA.new(str(m)).digest()
out = sofar = SHA.new(m.asbytes()).digest()
while len(out) < nbytes:
m = Message()
m.add_mpint(self.K)
m.add_bytes(self.H)
m.add_bytes(sofar)
digest = SHA.new(str(m)).digest()
digest = SHA.new(m.asbytes()).digest()
out += digest
sofar += digest
return out[:nbytes]
@ -1373,7 +1373,7 @@ class Transport (threading.Thread):
# only called if a channel has turned on x11 forwarding
if handler is None:
# by default, use the same mechanism as accept()
def default_handler(channel, (src_addr, src_port)):
def default_handler(channel, src_addr_port):
self._queue_incoming_channel(channel)
self._x11_handler = default_handler
else:
@ -1404,12 +1404,12 @@ class Transport (threading.Thread):
# active=True occurs before the thread is launched, to avoid a race
_active_threads.append(self)
if self.server_mode:
self._log(DEBUG, 'starting thread (server mode): %s' % hex(long(id(self)) & 0xffffffffL))
self._log(DEBUG, 'starting thread (server mode): %s' % hex(long(id(self)) & xffffffff))
else:
self._log(DEBUG, 'starting thread (client mode): %s' % hex(long(id(self)) & 0xffffffffL))
self._log(DEBUG, 'starting thread (client mode): %s' % hex(long(id(self)) & xffffffff))
try:
try:
self.packetizer.write_all(self.local_version + '\r\n')
self.packetizer.write_all(b(self.local_version + '\r\n'))
self._check_banner()
self._send_kex_init()
self._expect_packet(MSG_KEXINIT)
@ -1457,18 +1457,18 @@ class Transport (threading.Thread):
else:
self._log(WARNING, 'Oops, unhandled type %d' % ptype)
msg = Message()
msg.add_byte(chr(MSG_UNIMPLEMENTED))
msg.add_byte(cMSG_UNIMPLEMENTED)
msg.add_int(m.seqno)
self._send_message(msg)
except SSHException, e:
except SSHException as e:
self._log(ERROR, 'Exception: ' + str(e))
self._log(ERROR, util.tb_strings())
self.saved_exception = e
except EOFError, e:
except EOFError as e:
self._log(DEBUG, 'EOF in transport thread')
#self._log(DEBUG, util.tb_strings())
self.saved_exception = e
except socket.error, e:
except socket.error as e:
if type(e.args) is tuple:
if e.args:
emsg = '%s (%d)' % (e.args[1], e.args[0])
@ -1478,12 +1478,12 @@ class Transport (threading.Thread):
emsg = e.args
self._log(ERROR, 'Socket exception: ' + emsg)
self.saved_exception = e
except Exception, e:
except Exception as e:
self._log(ERROR, 'Unknown exception: ' + str(e))
self._log(ERROR, util.tb_strings())
self.saved_exception = e
_active_threads.remove(self)
for chan in self._channels.values():
for chan in list(self._channels.values()):
chan._unlink()
if self.active:
self.active = False
@ -1538,8 +1538,8 @@ class Transport (threading.Thread):
buf = self.packetizer.readline(timeout)
except ProxyCommandFailure:
raise
except Exception, x:
raise SSHException('Error reading SSH protocol banner' + str(x))
except Exception as e:
raise SSHException('Error reading SSH protocol banner' + str(e))
if buf[:4] == 'SSH-':
break
self._log(DEBUG, 'Banner: ' + buf)
@ -1549,7 +1549,7 @@ class Transport (threading.Thread):
self.remote_version = buf
# pull off any attached comment
comment = ''
i = string.find(buf, ' ')
i = buf.find(' ')
if i >= 0:
comment = buf[i+1:]
buf = buf[:i]
@ -1580,13 +1580,13 @@ class Transport (threading.Thread):
pkex = list(self.get_security_options().kex)
pkex.remove('diffie-hellman-group-exchange-sha1')
self.get_security_options().kex = pkex
available_server_keys = filter(self.server_key_dict.keys().__contains__,
self._preferred_keys)
available_server_keys = list(filter(list(self.server_key_dict.keys()).__contains__,
self._preferred_keys))
else:
available_server_keys = self._preferred_keys
m = Message()
m.add_byte(chr(MSG_KEXINIT))
m.add_byte(cMSG_KEXINIT)
m.add_bytes(rng.read(16))
m.add_list(self._preferred_kex)
m.add_list(available_server_keys)
@ -1596,12 +1596,12 @@ class Transport (threading.Thread):
m.add_list(self._preferred_macs)
m.add_list(self._preferred_compression)
m.add_list(self._preferred_compression)
m.add_string('')
m.add_string('')
m.add_string(bytes())
m.add_string(bytes())
m.add_boolean(False)
m.add_int(0)
# save a copy for later (needed to compute a hash)
self.local_kex_init = str(m)
self.local_kex_init = m.asbytes()
self._send_message(m)
def _parse_kex_init(self, m):
@ -1633,19 +1633,19 @@ class Transport (threading.Thread):
# as a server, we pick the first item in the client's list that we support.
# as a client, we pick the first item in our list that the server supports.
if self.server_mode:
agreed_kex = filter(self._preferred_kex.__contains__, kex_algo_list)
agreed_kex = list(filter(self._preferred_kex.__contains__, kex_algo_list))
else:
agreed_kex = filter(kex_algo_list.__contains__, self._preferred_kex)
agreed_kex = list(filter(kex_algo_list.__contains__, self._preferred_kex))
if len(agreed_kex) == 0:
raise SSHException('Incompatible ssh peer (no acceptable kex algorithm)')
self.kex_engine = self._kex_info[agreed_kex[0]](self)
if self.server_mode:
available_server_keys = filter(self.server_key_dict.keys().__contains__,
self._preferred_keys)
agreed_keys = filter(available_server_keys.__contains__, server_key_algo_list)
available_server_keys = list(filter(list(self.server_key_dict.keys()).__contains__,
self._preferred_keys))
agreed_keys = list(filter(available_server_keys.__contains__, server_key_algo_list))
else:
agreed_keys = filter(server_key_algo_list.__contains__, self._preferred_keys)
agreed_keys = list(filter(server_key_algo_list.__contains__, self._preferred_keys))
if len(agreed_keys) == 0:
raise SSHException('Incompatible ssh peer (no acceptable host key)')
self.host_key_type = agreed_keys[0]
@ -1653,15 +1653,15 @@ class Transport (threading.Thread):
raise SSHException('Incompatible ssh peer (can\'t match requested host key type)')
if self.server_mode:
agreed_local_ciphers = filter(self._preferred_ciphers.__contains__,
server_encrypt_algo_list)
agreed_remote_ciphers = filter(self._preferred_ciphers.__contains__,
client_encrypt_algo_list)
agreed_local_ciphers = list(filter(self._preferred_ciphers.__contains__,
server_encrypt_algo_list))
agreed_remote_ciphers = list(filter(self._preferred_ciphers.__contains__,
client_encrypt_algo_list))
else:
agreed_local_ciphers = filter(client_encrypt_algo_list.__contains__,
self._preferred_ciphers)
agreed_remote_ciphers = filter(server_encrypt_algo_list.__contains__,
self._preferred_ciphers)
agreed_local_ciphers = list(filter(client_encrypt_algo_list.__contains__,
self._preferred_ciphers))
agreed_remote_ciphers = list(filter(server_encrypt_algo_list.__contains__,
self._preferred_ciphers))
if (len(agreed_local_ciphers) == 0) or (len(agreed_remote_ciphers) == 0):
raise SSHException('Incompatible ssh server (no acceptable ciphers)')
self.local_cipher = agreed_local_ciphers[0]
@ -1669,22 +1669,22 @@ class Transport (threading.Thread):
self._log(DEBUG, 'Ciphers agreed: local=%s, remote=%s' % (self.local_cipher, self.remote_cipher))
if self.server_mode:
agreed_remote_macs = filter(self._preferred_macs.__contains__, client_mac_algo_list)
agreed_local_macs = filter(self._preferred_macs.__contains__, server_mac_algo_list)
agreed_remote_macs = list(filter(self._preferred_macs.__contains__, client_mac_algo_list))
agreed_local_macs = list(filter(self._preferred_macs.__contains__, server_mac_algo_list))
else:
agreed_local_macs = filter(client_mac_algo_list.__contains__, self._preferred_macs)
agreed_remote_macs = filter(server_mac_algo_list.__contains__, self._preferred_macs)
agreed_local_macs = list(filter(client_mac_algo_list.__contains__, self._preferred_macs))
agreed_remote_macs = list(filter(server_mac_algo_list.__contains__, self._preferred_macs))
if (len(agreed_local_macs) == 0) or (len(agreed_remote_macs) == 0):
raise SSHException('Incompatible ssh server (no acceptable macs)')
self.local_mac = agreed_local_macs[0]
self.remote_mac = agreed_remote_macs[0]
if self.server_mode:
agreed_remote_compression = filter(self._preferred_compression.__contains__, client_compress_algo_list)
agreed_local_compression = filter(self._preferred_compression.__contains__, server_compress_algo_list)
agreed_remote_compression = list(filter(self._preferred_compression.__contains__, client_compress_algo_list))
agreed_local_compression = list(filter(self._preferred_compression.__contains__, server_compress_algo_list))
else:
agreed_local_compression = filter(client_compress_algo_list.__contains__, self._preferred_compression)
agreed_remote_compression = filter(server_compress_algo_list.__contains__, self._preferred_compression)
agreed_local_compression = list(filter(client_compress_algo_list.__contains__, self._preferred_compression))
agreed_remote_compression = list(filter(server_compress_algo_list.__contains__, self._preferred_compression))
if (len(agreed_local_compression) == 0) or (len(agreed_remote_compression) == 0):
raise SSHException('Incompatible ssh server (no acceptable compression) %r %r %r' % (agreed_local_compression, agreed_remote_compression, self._preferred_compression))
self.local_compression = agreed_local_compression[0]
@ -1699,7 +1699,7 @@ class Transport (threading.Thread):
# actually some extra bytes (one NUL byte in openssh's case) added to
# the end of the packet but not parsed. turns out we need to throw
# away those bytes because they aren't part of the hash.
self.remote_kex_init = chr(MSG_KEXINIT) + m.get_so_far()
self.remote_kex_init = cMSG_KEXINIT + m.get_so_far()
def _activate_inbound(self):
"switch on newly negotiated encryption parameters for inbound traffic"
@ -1728,7 +1728,7 @@ class Transport (threading.Thread):
def _activate_outbound(self):
"switch on newly negotiated encryption parameters for outbound traffic"
m = Message()
m.add_byte(chr(MSG_NEWKEYS))
m.add_byte(MSG_NEWKEYS)
self._send_message(m)
block_size = self._cipher_info[self.local_cipher]['block-size']
if self.server_mode:
@ -1797,24 +1797,24 @@ class Transport (threading.Thread):
def _parse_disconnect(self, m):
code = m.get_int()
desc = m.get_string()
desc = m.get_text()
self._log(INFO, 'Disconnect (code %d): %s' % (code, desc))
def _parse_global_request(self, m):
kind = m.get_string()
kind = m.get_text()
self._log(DEBUG, 'Received global request "%s"' % kind)
want_reply = m.get_boolean()
if not self.server_mode:
self._log(DEBUG, 'Rejecting "%s" global request from server.' % kind)
ok = False
elif kind == 'tcpip-forward':
address = m.get_string()
address = m.get_text()
port = m.get_int()
ok = self.server_object.check_port_forward_request(address, port)
if ok != False:
ok = (ok,)
elif kind == 'cancel-tcpip-forward':
address = m.get_string()
address = m.get_test()
port = m.get_int()
self.server_object.cancel_port_forward_request(address, port)
ok = True
@ -1827,10 +1827,10 @@ class Transport (threading.Thread):
if want_reply:
msg = Message()
if ok:
msg.add_byte(chr(MSG_REQUEST_SUCCESS))
msg.add_byte(cMSG_REQUEST_SUCCESS)
msg.add(*extra)
else:
msg.add_byte(chr(MSG_REQUEST_FAILURE))
msg.add_byte(cMSG_REQUEST_FAILURE)
self._send_message(msg)
def _parse_request_success(self, m):
@ -1868,8 +1868,8 @@ class Transport (threading.Thread):
def _parse_channel_open_failure(self, m):
chanid = m.get_int()
reason = m.get_int()
reason_str = m.get_string()
lang = m.get_string()
reason_str = m.get_text()
lang = m.get_text()
reason_text = CONNECTION_FAILED_CODE.get(reason, '(unknown code)')
self._log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text))
self.lock.acquire()
@ -1885,7 +1885,7 @@ class Transport (threading.Thread):
return
def _parse_channel_open(self, m):
kind = m.get_string()
kind = m.get_text()
chanid = m.get_int()
initial_window_size = m.get_int()
max_packet_size = m.get_int()
@ -1898,7 +1898,7 @@ class Transport (threading.Thread):
finally:
self.lock.release()
elif (kind == 'x11') and (self._x11_handler is not None):
origin_addr = m.get_string()
origin_addr = m.get_text()
origin_port = m.get_int()
self._log(DEBUG, 'Incoming x11 connection from %s:%d' % (origin_addr, origin_port))
self.lock.acquire()
@ -1907,9 +1907,9 @@ class Transport (threading.Thread):
finally:
self.lock.release()
elif (kind == 'forwarded-tcpip') and (self._tcp_handler is not None):
server_addr = m.get_string()
server_addr = m.get_text()
server_port = m.get_int()
origin_addr = m.get_string()
origin_addr = m.get_text()
origin_port = m.get_int()
self._log(DEBUG, 'Incoming tcp forwarded connection from %s:%d' % (origin_addr, origin_port))
self.lock.acquire()
@ -1929,9 +1929,9 @@ class Transport (threading.Thread):
self.lock.release()
if kind == 'direct-tcpip':
# handle direct-tcpip requests comming from the client
dest_addr = m.get_string()
dest_addr = m.get_text()
dest_port = m.get_int()
origin_addr = m.get_string()
origin_addr = m.get_text()
origin_port = m.get_int()
reason = self.server_object.check_channel_direct_tcpip_request(
my_chanid, (origin_addr, origin_port),
@ -1943,7 +1943,7 @@ class Transport (threading.Thread):
reject = True
if reject:
msg = Message()
msg.add_byte(chr(MSG_CHANNEL_OPEN_FAILURE))
msg.add_byte(cMSG_CHANNEL_OPEN_FAILURE)
msg.add_int(chanid)
msg.add_int(reason)
msg.add_string('')
@ -1962,7 +1962,7 @@ class Transport (threading.Thread):
finally:
self.lock.release()
m = Message()
m.add_byte(chr(MSG_CHANNEL_OPEN_SUCCESS))
m.add_byte(cMSG_CHANNEL_OPEN_SUCCESS)
m.add_int(chanid)
m.add_int(my_chanid)
m.add_int(self.window_size)
@ -2029,7 +2029,8 @@ class SecurityOptions (object):
``ValueError`` will be raised. If you try to assign something besides a
tuple to one of the fields, ``TypeError`` will be raised.
"""
__slots__ = [ 'ciphers', 'digests', 'key_types', 'kex', 'compression', '_transport' ]
#__slots__ = [ 'ciphers', 'digests', 'key_types', 'kex', 'compression', '_transport' ]
__slots__ = '_transport'
def __init__(self, transport):
self._transport = transport
@ -2060,8 +2061,8 @@ class SecurityOptions (object):
x = tuple(x)
if type(x) is not tuple:
raise TypeError('expected tuple or list')
possible = getattr(self._transport, orig).keys()
forbidden = filter(lambda n: n not in possible, x)
possible = list(getattr(self._transport, orig).keys())
forbidden = [n for n in x if n not in possible]
if len(forbidden) > 0:
raise ValueError('unknown cipher')
setattr(self._transport, name, x)
@ -2125,7 +2126,7 @@ class ChannelMap (object):
def values(self):
self._lock.acquire()
try:
return self._map.values()
return list(self._map.values())
finally:
self._lock.release()

View File

@ -48,60 +48,53 @@ if sys.version_info < (2,3):
def inflate_long(s, always_positive=False):
"turns a normalized byte string into a long-int (adapted from Crypto.Util.number)"
out = 0L
out = long(0)
negative = 0
if not always_positive and (len(s) > 0) and (ord(s[0]) >= 0x80):
if not always_positive and (len(s) > 0) and (byte_ord(s[0]) >= 0x80):
negative = 1
if len(s) % 4:
filler = '\x00'
filler = zero_byte
if negative:
filler = '\xff'
filler = max_byte
s = filler * (4 - len(s) % 4) + s
for i in range(0, len(s), 4):
out = (out << 32) + struct.unpack('>I', s[i:i+4])[0]
if negative:
out -= (1L << (8 * len(s)))
out -= (long(1) << (8 * len(s)))
return out
deflate_zero = zero_byte if PY2 else 0
deflate_ff = max_byte if PY2 else 0xff
def deflate_long(n, add_sign_padding=True):
"turns a long-int into a normalized byte string (adapted from Crypto.Util.number)"
# after much testing, this algorithm was deemed to be the fastest
s = ''
s = bytes()
n = long(n)
while (n != 0) and (n != -1):
s = struct.pack('>I', n & 0xffffffffL) + s
s = struct.pack('>I', n & xffffffff) + s
n = n >> 32
# strip off leading zeros, FFs
for i in enumerate(s):
if (n == 0) and (i[1] != '\000'):
if (n == 0) and (i[1] != deflate_zero):
break
if (n == -1) and (i[1] != '\xff'):
if (n == -1) and (i[1] != deflate_ff):
break
else:
# degenerate case, n was either 0 or -1
i = (0,)
if n == 0:
s = '\000'
s = zero_byte
else:
s = '\xff'
s = max_byte
s = s[i[0]:]
if add_sign_padding:
if (n == 0) and (ord(s[0]) >= 0x80):
s = '\x00' + s
if (n == -1) and (ord(s[0]) < 0x80):
s = '\xff' + s
if (n == 0) and (byte_ord(s[0]) >= 0x80):
s = zero_byte + s
if (n == -1) and (byte_ord(s[0]) < 0x80):
s = max_byte + s
return s
def format_binary_weird(data):
out = ''
for i in enumerate(data):
out += '%02X' % ord(i[1])
if i[0] % 2:
out += ' '
if i[0] % 16 == 15:
out += '\n'
return out
def format_binary(data, prefix=''):
x = 0
out = []
@ -113,8 +106,8 @@ def format_binary(data, prefix=''):
return [prefix + x for x in out]
def format_binary_line(data):
left = ' '.join(['%02X' % ord(c) for c in data])
right = ''.join([('.%c..' % c)[(ord(c)+63)//95] for c in data])
left = ' '.join(['%02X' % byte_ord(c) for c in data])
right = ''.join([('.%c..' % c)[(byte_ord(c)+63)//95] for c in data])
return '%-50s %s' % (left, right)
def hexify(s):
@ -126,17 +119,20 @@ def unhexify(s):
def safe_string(s):
out = ''
for c in s:
if (ord(c) >= 32) and (ord(c) <= 127):
if (byte_ord(c) >= 32) and (byte_ord(c) <= 127):
out += c
else:
out += '%%%02X' % ord(c)
out += '%%%02X' % byte_ord(c)
return out
# ''.join([['%%%02X' % ord(c), c][(ord(c) >= 32) and (ord(c) <= 127)] for c in s])
def bit_length(n):
try:
return n.bitlength()
except AttributeError:
norm = deflate_long(n, 0)
hbyte = ord(norm[0])
hbyte = byte_ord(norm[0])
if hbyte == 0:
return 1
bitlen = len(norm) * 8
@ -157,20 +153,21 @@ def generate_key_bytes(hashclass, salt, key, nbytes):
:param class hashclass:
class from `Crypto.Hash` that can be used as a secure hashing function
(like ``MD5`` or ``SHA``).
:param str salt: data to salt the hash with.
:param salt: data to salt the hash with.
:type salt: byte string
:param str key: human-entered password or passphrase.
:param int nbytes: number of bytes to generate.
:return: Key data `str`
"""
keydata = ''
digest = ''
keydata = bytes()
digest = bytes()
if len(salt) > 8:
salt = salt[:8]
while nbytes > 0:
hash_obj = hashclass.new()
if len(digest) > 0:
hash_obj.update(digest)
hash_obj.update(key)
hash_obj.update(b(key))
hash_obj.update(salt)
digest = hash_obj.digest()
size = min(nbytes, len(digest))
@ -271,37 +268,37 @@ def retry_on_signal(function):
while True:
try:
return function()
except EnvironmentError, e:
except EnvironmentError as e:
if e.errno != errno.EINTR:
raise
class Counter (object):
"""Stateful counter for CTR mode crypto"""
def __init__(self, nbits, initial_value=1L, overflow=0L):
def __init__(self, nbits, initial_value=long(1), overflow=long(0)):
self.blocksize = nbits / 8
self.overflow = overflow
# start with value - 1 so we don't have to store intermediate values when counting
# could the iv be 0?
if initial_value == 0:
self.value = array.array('c', '\xFF' * self.blocksize)
self.value = array.array('c', max_byte * self.blocksize)
else:
x = deflate_long(initial_value - 1, add_sign_padding=False)
self.value = array.array('c', '\x00' * (self.blocksize - len(x)) + x)
self.value = array.array('c', zero_byte * (self.blocksize - len(x)) + x)
def __call__(self):
"""Increament the counter and return the new value"""
i = self.blocksize - 1
while i > -1:
c = self.value[i] = chr((ord(self.value[i]) + 1) % 256)
if c != '\x00':
c = self.value[i] = byte_chr((byte_ord(self.value[i]) + 1) % 256)
if c != zero_byte:
return self.value.tostring()
i -= 1
# counter reset
x = deflate_long(self.overflow, add_sign_padding=False)
self.value = array.array('c', '\x00' * (self.blocksize - len(x)) + x)
self.value = array.array('c', zero_byte * (self.blocksize - len(x)) + x)
return self.value.tostring()
def new(cls, nbits, initial_value=1L, overflow=0L):
def new(cls, nbits, initial_value=long(1), overflow=long(0)):
return cls(nbits, initial_value=initial_value, overflow=overflow)
new = classmethod(new)

View File

@ -27,6 +27,7 @@ import array
import ctypes.wintypes
import platform
import struct
from paramiko.util import *
try:
import _thread as thread # Python 3.x
@ -91,7 +92,7 @@ def _query_pageant(msg):
with pymap:
pymap.write(msg)
# Create an array buffer containing the mapped filename
char_buffer = array.array("c", map_name + '\0')
char_buffer = array.array("c", b(map_name) + zero_byte)
char_buffer_address, char_buffer_size = char_buffer.buffer_info()
# Create a string to use for the SendMessage function call
cds = COPYDATASTRUCT(_AGENT_COPYDATA_ID, char_buffer_size,

View File

@ -54,7 +54,7 @@ if sys.platform == 'darwin':
setup(name = "paramiko",
version = "1.12.2",
version = "1.13.0",
description = "SSH2 protocol library",
author = "Jeff Forcier",
author_email = "jeff@bitprophet.org",

33
test.py
View File

@ -29,22 +29,21 @@ import unittest
from optparse import OptionParser
import paramiko
import threading
from paramiko.py3compat import PY2
sys.path.append('tests')
from test_message import MessageTest
from test_file import BufferedFileTest
from test_buffered_pipe import BufferedPipeTest
from test_util import UtilTest
from test_hostkeys import HostKeysTest
from test_pkey import KeyTest
from test_kex import KexTest
from test_packetizer import PacketizerTest
from test_auth import AuthTest
from test_transport import TransportTest
from test_sftp import SFTPTest
from test_sftp_big import BigSFTPTest
from test_client import SSHClientTest
from tests.test_message import MessageTest
from tests.test_file import BufferedFileTest
from tests.test_buffered_pipe import BufferedPipeTest
from tests.test_util import UtilTest
from tests.test_hostkeys import HostKeysTest
from tests.test_pkey import KeyTest
from tests.test_kex import KexTest
from tests.test_packetizer import PacketizerTest
from tests.test_auth import AuthTest
from tests.test_transport import TransportTest
from tests.test_client import SSHClientTest
default_host = 'localhost'
default_user = os.environ.get('USER', 'nobody')
@ -109,12 +108,15 @@ def main():
paramiko.util.log_to_file('test.log')
if options.use_sftp:
from tests.test_sftp import SFTPTest
if options.use_loopback_sftp:
SFTPTest.init_loopback()
else:
SFTPTest.init(options.hostname, options.username, options.keyfile, options.password)
if not options.use_big_file:
SFTPTest.set_big_file_test(False)
if options.use_big_file:
from tests.test_sftp_big import BigSFTPTest
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(MessageTest))
@ -147,7 +149,10 @@ def main():
# TODO: make that not a problem, jeez
for thread in threading.enumerate():
if thread is not threading.currentThread():
thread._Thread__stop()
if PY2:
thread._Thread__stop()
else:
thread._stop()
# Exit correctly
if not result.wasSuccessful():
sys.exit(1)

0
tests/__init__.py Normal file
View File

View File

@ -21,6 +21,7 @@
"""
import threading, socket
from paramiko.common import *
class LoopSocket (object):
@ -31,7 +32,7 @@ class LoopSocket (object):
"""
def __init__(self):
self.__in_buffer = ''
self.__in_buffer = bytes()
self.__lock = threading.Lock()
self.__cv = threading.Condition(self.__lock)
self.__timeout = None
@ -41,11 +42,12 @@ class LoopSocket (object):
self.__unlink()
try:
self.__lock.acquire()
self.__in_buffer = ''
self.__in_buffer = bytes()
finally:
self.__lock.release()
def send(self, data):
data = asbytes(data)
if self.__mate is None:
# EOF
raise EOFError()
@ -57,7 +59,7 @@ class LoopSocket (object):
try:
if self.__mate is None:
# EOF
return ''
return bytes()
if len(self.__in_buffer) == 0:
self.__cv.wait(self.__timeout)
if len(self.__in_buffer) == 0:

View File

@ -21,8 +21,10 @@ A stub SFTP server for loopback SFTP testing.
"""
import os
import sys
from paramiko import ServerInterface, SFTPServerInterface, SFTPServer, SFTPAttributes, \
SFTPHandle, SFTP_OK, AUTH_SUCCESSFUL, OPEN_SUCCEEDED
from paramiko.common import *
class StubServer (ServerInterface):
@ -38,7 +40,7 @@ class StubSFTPHandle (SFTPHandle):
def stat(self):
try:
return SFTPAttributes.from_stat(os.fstat(self.readfile.fileno()))
except OSError, e:
except OSError as e:
return SFTPServer.convert_errno(e.errno)
def chattr(self, attr):
@ -47,7 +49,7 @@ class StubSFTPHandle (SFTPHandle):
try:
SFTPServer.set_file_attr(self.filename, attr)
return SFTP_OK
except OSError, e:
except OSError as e:
return SFTPServer.convert_errno(e.errno)
@ -69,21 +71,21 @@ class StubSFTPServer (SFTPServerInterface):
attr.filename = fname
out.append(attr)
return out
except OSError, e:
except OSError as e:
return SFTPServer.convert_errno(e.errno)
def stat(self, path):
path = self._realpath(path)
try:
return SFTPAttributes.from_stat(os.stat(path))
except OSError, e:
except OSError as e:
return SFTPServer.convert_errno(e.errno)
def lstat(self, path):
path = self._realpath(path)
try:
return SFTPAttributes.from_stat(os.lstat(path))
except OSError, e:
except OSError as e:
return SFTPServer.convert_errno(e.errno)
def open(self, path, flags, attr):
@ -97,8 +99,8 @@ class StubSFTPServer (SFTPServerInterface):
else:
# os.open() defaults to 0777 which is
# an odd default mode for files
fd = os.open(path, flags, 0666)
except OSError, e:
fd = os.open(path, flags, o666)
except OSError as e:
return SFTPServer.convert_errno(e.errno)
if (flags & os.O_CREAT) and (attr is not None):
attr._flags &= ~attr.FLAG_PERMISSIONS
@ -118,7 +120,7 @@ class StubSFTPServer (SFTPServerInterface):
fstr = 'rb'
try:
f = os.fdopen(fd, fstr)
except OSError, e:
except OSError as e:
return SFTPServer.convert_errno(e.errno)
fobj = StubSFTPHandle(flags)
fobj.filename = path
@ -130,7 +132,7 @@ class StubSFTPServer (SFTPServerInterface):
path = self._realpath(path)
try:
os.remove(path)
except OSError, e:
except OSError as e:
return SFTPServer.convert_errno(e.errno)
return SFTP_OK
@ -139,7 +141,7 @@ class StubSFTPServer (SFTPServerInterface):
newpath = self._realpath(newpath)
try:
os.rename(oldpath, newpath)
except OSError, e:
except OSError as e:
return SFTPServer.convert_errno(e.errno)
return SFTP_OK
@ -149,7 +151,7 @@ class StubSFTPServer (SFTPServerInterface):
os.mkdir(path)
if attr is not None:
SFTPServer.set_file_attr(path, attr)
except OSError, e:
except OSError as e:
return SFTPServer.convert_errno(e.errno)
return SFTP_OK
@ -157,7 +159,7 @@ class StubSFTPServer (SFTPServerInterface):
path = self._realpath(path)
try:
os.rmdir(path)
except OSError, e:
except OSError as e:
return SFTPServer.convert_errno(e.errno)
return SFTP_OK
@ -165,7 +167,7 @@ class StubSFTPServer (SFTPServerInterface):
path = self._realpath(path)
try:
SFTPServer.set_file_attr(path, attr)
except OSError, e:
except OSError as e:
return SFTPServer.convert_errno(e.errno)
return SFTP_OK
@ -185,7 +187,7 @@ class StubSFTPServer (SFTPServerInterface):
target_path = '<error>'
try:
os.symlink(target_path, path)
except OSError, e:
except OSError as e:
return SFTPServer.convert_errno(e.errno)
return SFTP_OK
@ -193,7 +195,7 @@ class StubSFTPServer (SFTPServerInterface):
path = self._realpath(path)
try:
symlink = os.readlink(path)
except OSError, e:
except OSError as e:
return SFTPServer.convert_errno(e.errno)
# if it's absolute, remove the root
if os.path.isabs(symlink):

View File

@ -29,13 +29,17 @@ from paramiko import Transport, ServerInterface, RSAKey, DSSKey, \
AuthenticationException
from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL
from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
from loop import LoopSocket
from paramiko.py3compat import u
from tests.loop import LoopSocket
from tests.util import test_path
_pwd = u('\u2022')
class NullServer (ServerInterface):
paranoid_did_password = False
paranoid_did_public_key = False
paranoid_key = DSSKey.from_private_key_file('tests/test_dss.key')
paranoid_key = DSSKey.from_private_key_file(test_path('test_dss.key'))
def get_allowed_auths(self, username):
if username == 'slowdive':
@ -64,7 +68,7 @@ class NullServer (ServerInterface):
if self.paranoid_did_public_key:
return AUTH_SUCCESSFUL
return AUTH_PARTIALLY_SUCCESSFUL
if (username == 'utf8') and (password == u'\u2022'):
if (username == 'utf8') and (password == _pwd):
return AUTH_SUCCESSFUL
if (username == 'non-utf8') and (password == '\xff'):
return AUTH_SUCCESSFUL
@ -110,18 +114,18 @@ class AuthTest (unittest.TestCase):
self.sockc.close()
def start_server(self):
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
self.public_host_key = RSAKey(data=str(host_key))
host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
self.public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key)
self.event = threading.Event()
self.server = NullServer()
self.assert_(not self.event.isSet())
self.assertTrue(not self.event.isSet())
self.ts.start_server(self.event, self.server)
def verify_finished(self):
self.event.wait(1.0)
self.assert_(self.event.isSet())
self.assert_(self.ts.is_active())
self.assertTrue(self.event.isSet())
self.assertTrue(self.ts.is_active())
def test_1_bad_auth_type(self):
"""
@ -132,11 +136,11 @@ class AuthTest (unittest.TestCase):
try:
self.tc.connect(hostkey=self.public_host_key,
username='unknown', password='error')
self.assert_(False)
self.assertTrue(False)
except:
etype, evalue, etb = sys.exc_info()
self.assertEquals(BadAuthenticationType, etype)
self.assertEquals(['publickey'], evalue.allowed_types)
self.assertEqual(BadAuthenticationType, etype)
self.assertEqual(['publickey'], evalue.allowed_types)
def test_2_bad_password(self):
"""
@ -147,10 +151,10 @@ class AuthTest (unittest.TestCase):
self.tc.connect(hostkey=self.public_host_key)
try:
self.tc.auth_password(username='slowdive', password='error')
self.assert_(False)
self.assertTrue(False)
except:
etype, evalue, etb = sys.exc_info()
self.assert_(issubclass(etype, AuthenticationException))
self.assertTrue(issubclass(etype, AuthenticationException))
self.tc.auth_password(username='slowdive', password='pygmalion')
self.verify_finished()
@ -161,10 +165,10 @@ class AuthTest (unittest.TestCase):
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
remain = self.tc.auth_password(username='paranoid', password='paranoid')
self.assertEquals(['publickey'], remain)
key = DSSKey.from_private_key_file('tests/test_dss.key')
self.assertEqual(['publickey'], remain)
key = DSSKey.from_private_key_file(test_path('test_dss.key'))
remain = self.tc.auth_publickey(username='paranoid', key=key)
self.assertEquals([], remain)
self.assertEqual([], remain)
self.verify_finished()
def test_4_interactive_auth(self):
@ -180,9 +184,9 @@ class AuthTest (unittest.TestCase):
self.got_prompts = prompts
return ['cat']
remain = self.tc.auth_interactive('commie', handler)
self.assertEquals(self.got_title, 'password')
self.assertEquals(self.got_prompts, [('Password', False)])
self.assertEquals([], remain)
self.assertEqual(self.got_title, 'password')
self.assertEqual(self.got_prompts, [('Password', False)])
self.assertEqual([], remain)
self.verify_finished()
def test_5_interactive_auth_fallback(self):
@ -193,7 +197,7 @@ class AuthTest (unittest.TestCase):
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
remain = self.tc.auth_password('commie', 'cat')
self.assertEquals([], remain)
self.assertEqual([], remain)
self.verify_finished()
def test_6_auth_utf8(self):
@ -202,8 +206,8 @@ class AuthTest (unittest.TestCase):
"""
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
remain = self.tc.auth_password('utf8', u'\u2022')
self.assertEquals([], remain)
remain = self.tc.auth_password('utf8', _pwd)
self.assertEqual([], remain)
self.verify_finished()
def test_7_auth_non_utf8(self):
@ -214,7 +218,7 @@ class AuthTest (unittest.TestCase):
self.start_server()
self.tc.connect(hostkey=self.public_host_key)
remain = self.tc.auth_password('non-utf8', '\xff')
self.assertEquals([], remain)
self.assertEqual([], remain)
self.verify_finished()
def test_8_auth_gets_disconnected(self):
@ -228,4 +232,4 @@ class AuthTest (unittest.TestCase):
remain = self.tc.auth_password('bad-server', 'hello')
except:
etype, evalue, etb = sys.exc_info()
self.assert_(issubclass(etype, AuthenticationException))
self.assertTrue(issubclass(etype, AuthenticationException))

View File

@ -25,8 +25,9 @@ import time
import unittest
from paramiko.buffered_pipe import BufferedPipe, PipeTimeout
from paramiko import pipe
from paramiko.py3compat import b
from util import ParamikoTest
from tests.util import ParamikoTest
def delay_thread(pipe):
@ -44,39 +45,39 @@ def close_thread(pipe):
class BufferedPipeTest(ParamikoTest):
def test_1_buffered_pipe(self):
p = BufferedPipe()
self.assert_(not p.read_ready())
self.assertTrue(not p.read_ready())
p.feed('hello.')
self.assert_(p.read_ready())
self.assertTrue(p.read_ready())
data = p.read(6)
self.assertEquals('hello.', data)
self.assertEqual(b'hello.', data)
p.feed('plus/minus')
self.assertEquals('plu', p.read(3))
self.assertEquals('s/m', p.read(3))
self.assertEquals('inus', p.read(4))
self.assertEqual(b'plu', p.read(3))
self.assertEqual(b's/m', p.read(3))
self.assertEqual(b'inus', p.read(4))
p.close()
self.assert_(not p.read_ready())
self.assertEquals('', p.read(1))
self.assertTrue(not p.read_ready())
self.assertEqual(b'', p.read(1))
def test_2_delay(self):
p = BufferedPipe()
self.assert_(not p.read_ready())
self.assertTrue(not p.read_ready())
threading.Thread(target=delay_thread, args=(p,)).start()
self.assertEquals('a', p.read(1, 0.1))
self.assertEqual(b'a', p.read(1, 0.1))
try:
p.read(1, 0.1)
self.assert_(False)
self.assertTrue(False)
except PipeTimeout:
pass
self.assertEquals('b', p.read(1, 1.0))
self.assertEquals('', p.read(1))
self.assertEqual(b'b', p.read(1, 1.0))
self.assertEqual(b'', p.read(1))
def test_3_close_while_reading(self):
p = BufferedPipe()
threading.Thread(target=close_thread, args=(p,)).start()
data = p.read(1, 1.0)
self.assertEquals('', data)
self.assertEqual(b'', data)
def test_4_or_pipe(self):
p = pipe.make_pipe()

View File

@ -20,16 +20,14 @@
Some unit tests for SSHClient.
"""
from __future__ import with_statement # Python 2.5 support
import socket
from tempfile import mkstemp
import threading
import time
import unittest
import weakref
import warnings
import os
from binascii import hexlify
from tests.util import test_path
import paramiko
@ -46,7 +44,7 @@ class NullServer (paramiko.ServerInterface):
return paramiko.AUTH_FAILED
def check_auth_publickey(self, username, key):
if (key.get_name() == 'ssh-dss') and (hexlify(key.get_fingerprint()) == '4478f0b9a23cc5182009ff755bc1d26c'):
if (key.get_name() == 'ssh-dss') and key.get_fingerprint() == b'\x44\x78\xf0\xb9\xa2\x3c\xc5\x18\x20\x09\xff\x75\x5b\xc1\xd2\x6c':
return paramiko.AUTH_SUCCESSFUL
return paramiko.AUTH_FAILED
@ -67,8 +65,6 @@ class SSHClientTest (unittest.TestCase):
self.sockl.listen(1)
self.addr, self.port = self.sockl.getsockname()
self.event = threading.Event()
thread = threading.Thread(target=self._run)
thread.start()
def tearDown(self):
for attr in "tc ts socks sockl".split():
@ -78,28 +74,28 @@ class SSHClientTest (unittest.TestCase):
def _run(self):
self.socks, addr = self.sockl.accept()
self.ts = paramiko.Transport(self.socks)
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
self.ts.add_server_key(host_key)
server = NullServer()
self.ts.start_server(self.event, server)
def test_1_client(self):
"""
verify that the SSHClient stuff works too.
"""
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = paramiko.RSAKey(data=str(host_key))
threading.Thread(target=self._run).start()
host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient()
self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key)
self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')
self.event.wait(1.0)
self.assert_(self.event.isSet())
self.assert_(self.ts.is_active())
self.assertEquals('slowdive', self.ts.get_username())
self.assertEquals(True, self.ts.is_authenticated())
self.assertTrue(self.event.isSet())
self.assertTrue(self.ts.is_active())
self.assertEqual('slowdive', self.ts.get_username())
self.assertEqual(True, self.ts.is_authenticated())
stdin, stdout, stderr = self.tc.exec_command('yes')
schan = self.ts.accept(1.0)
@ -108,10 +104,10 @@ class SSHClientTest (unittest.TestCase):
schan.send_stderr('This is on stderr.\n')
schan.close()
self.assertEquals('Hello there.\n', stdout.readline())
self.assertEquals('', stdout.readline())
self.assertEquals('This is on stderr.\n', stderr.readline())
self.assertEquals('', stderr.readline())
self.assertEqual('Hello there.\n', stdout.readline())
self.assertEqual('', stdout.readline())
self.assertEqual('This is on stderr.\n', stderr.readline())
self.assertEqual('', stderr.readline())
stdin.close()
stdout.close()
@ -121,18 +117,19 @@ class SSHClientTest (unittest.TestCase):
"""
verify that SSHClient works with a DSA key.
"""
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = paramiko.RSAKey(data=str(host_key))
threading.Thread(target=self._run).start()
host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient()
self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key)
self.tc.connect(self.addr, self.port, username='slowdive', key_filename='tests/test_dss.key')
self.tc.connect(self.addr, self.port, username='slowdive', key_filename=test_path('test_dss.key'))
self.event.wait(1.0)
self.assert_(self.event.isSet())
self.assert_(self.ts.is_active())
self.assertEquals('slowdive', self.ts.get_username())
self.assertEquals(True, self.ts.is_authenticated())
self.assertTrue(self.event.isSet())
self.assertTrue(self.ts.is_active())
self.assertEqual('slowdive', self.ts.get_username())
self.assertEqual(True, self.ts.is_authenticated())
stdin, stdout, stderr = self.tc.exec_command('yes')
schan = self.ts.accept(1.0)
@ -141,10 +138,10 @@ class SSHClientTest (unittest.TestCase):
schan.send_stderr('This is on stderr.\n')
schan.close()
self.assertEquals('Hello there.\n', stdout.readline())
self.assertEquals('', stdout.readline())
self.assertEquals('This is on stderr.\n', stderr.readline())
self.assertEquals('', stderr.readline())
self.assertEqual('Hello there.\n', stdout.readline())
self.assertEqual('', stdout.readline())
self.assertEqual('This is on stderr.\n', stderr.readline())
self.assertEqual('', stderr.readline())
stdin.close()
stdout.close()
@ -154,38 +151,40 @@ class SSHClientTest (unittest.TestCase):
"""
verify that SSHClient accepts and tries multiple key files.
"""
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = paramiko.RSAKey(data=str(host_key))
threading.Thread(target=self._run).start()
host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient()
self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key)
self.tc.connect(self.addr, self.port, username='slowdive', key_filename=[ 'tests/test_rsa.key', 'tests/test_dss.key' ])
self.tc.connect(self.addr, self.port, username='slowdive', key_filename=[ test_path('test_rsa.key'), test_path('test_dss.key') ])
self.event.wait(1.0)
self.assert_(self.event.isSet())
self.assert_(self.ts.is_active())
self.assertEquals('slowdive', self.ts.get_username())
self.assertEquals(True, self.ts.is_authenticated())
self.assertTrue(self.event.isSet())
self.assertTrue(self.ts.is_active())
self.assertEqual('slowdive', self.ts.get_username())
self.assertEqual(True, self.ts.is_authenticated())
def test_4_auto_add_policy(self):
"""
verify that SSHClient's AutoAddPolicy works.
"""
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = paramiko.RSAKey(data=str(host_key))
threading.Thread(target=self._run).start()
host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient()
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.assertEquals(0, len(self.tc.get_host_keys()))
self.assertEqual(0, len(self.tc.get_host_keys()))
self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')
self.event.wait(1.0)
self.assert_(self.event.isSet())
self.assert_(self.ts.is_active())
self.assertEquals('slowdive', self.ts.get_username())
self.assertEquals(True, self.ts.is_authenticated())
self.assertEquals(1, len(self.tc.get_host_keys()))
self.assertEquals(public_host_key, self.tc.get_host_keys()['[%s]:%d' % (self.addr, self.port)]['ssh-rsa'])
self.assertTrue(self.event.isSet())
self.assertTrue(self.ts.is_active())
self.assertEqual('slowdive', self.ts.get_username())
self.assertEqual(True, self.ts.is_authenticated())
self.assertEqual(1, len(self.tc.get_host_keys()))
self.assertEqual(public_host_key, self.tc.get_host_keys()['[%s]:%d' % (self.addr, self.port)]['ssh-rsa'])
def test_5_save_host_keys(self):
"""
@ -193,9 +192,10 @@ class SSHClientTest (unittest.TestCase):
"""
warnings.filterwarnings('ignore', 'tempnam.*')
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = paramiko.RSAKey(data=str(host_key))
localname = os.tempnam()
host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
fd, localname = mkstemp()
os.close(fd)
client = paramiko.SSHClient()
self.assertEquals(0, len(client.get_host_keys()))
@ -218,24 +218,32 @@ class SSHClientTest (unittest.TestCase):
verify that when an SSHClient is collected, its transport (and the
transport's packetizer) is closed.
"""
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = paramiko.RSAKey(data=str(host_key))
threading.Thread(target=self._run).start()
host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient()
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.assertEquals(0, len(self.tc.get_host_keys()))
self.assertEqual(0, len(self.tc.get_host_keys()))
self.tc.connect(self.addr, self.port, username='slowdive', password='pygmalion')
self.event.wait(1.0)
self.assert_(self.event.isSet())
self.assert_(self.ts.is_active())
self.assertTrue(self.event.isSet())
self.assertTrue(self.ts.is_active())
p = weakref.ref(self.tc._transport.packetizer)
self.assert_(p() is not None)
self.assertTrue(p() is not None)
self.tc.close()
del self.tc
# hrm, sometimes p isn't cleared right away. why is that?
st = time.time()
while (time.time() - st < 5.0) and (p() is not None):
time.sleep(0.1)
self.assert_(p() is None)
# hrm, sometimes p isn't cleared right away. why is that?
#st = time.time()
#while (time.time() - st < 5.0) and (p() is not None):
# time.sleep(0.1)
# instead of dumbly waiting for the GC to collect, force a collection
# to see whether the SSHClient object is deallocated correctly
import gc
gc.collect()
self.assertTrue(p() is None)

View File

@ -22,6 +22,7 @@ Some unit tests for the BufferedFile abstraction.
import unittest
from paramiko.file import BufferedFile
from paramiko.common import *
class LoopbackFile (BufferedFile):
@ -31,7 +32,7 @@ class LoopbackFile (BufferedFile):
def __init__(self, mode='r', bufsize=-1):
BufferedFile.__init__(self)
self._set_mode(mode, bufsize)
self.buffer = ''
self.buffer = bytes()
def _read(self, size):
if len(self.buffer) == 0:
@ -53,7 +54,7 @@ class BufferedFileTest (unittest.TestCase):
f = LoopbackFile('r')
try:
f.write('hi')
self.assert_(False, 'no exception on write to read-only file')
self.assertTrue(False, 'no exception on write to read-only file')
except:
pass
f.close()
@ -61,7 +62,7 @@ class BufferedFileTest (unittest.TestCase):
f = LoopbackFile('w')
try:
f.read(1)
self.assert_(False, 'no exception to read from write-only file')
self.assertTrue(False, 'no exception to read from write-only file')
except:
pass
f.close()
@ -80,12 +81,12 @@ class BufferedFileTest (unittest.TestCase):
f.close()
try:
f.readline()
self.assert_(False, 'no exception on readline of closed file')
self.assertTrue(False, 'no exception on readline of closed file')
except IOError:
pass
self.assert_('\n' in f.newlines)
self.assert_('\r\n' in f.newlines)
self.assert_('\r' not in f.newlines)
self.assertTrue(linefeed_byte in f.newlines)
self.assertTrue(crlf in f.newlines)
self.assertTrue(cr_byte not in f.newlines)
def test_3_lf(self):
"""
@ -97,7 +98,7 @@ class BufferedFileTest (unittest.TestCase):
f.write('\nSecond.\r\n')
self.assertEqual(f.readline(), 'Second.\n')
f.close()
self.assertEqual(f.newlines, '\r\n')
self.assertEqual(f.newlines, crlf)
def test_4_write(self):
"""

View File

@ -25,6 +25,7 @@ from binascii import hexlify
import os
import unittest
import paramiko
from paramiko.py3compat import b, decodebytes
test_hosts_file = """\
@ -36,12 +37,12 @@ BGQ3GQ/Fc7SX6gkpXkwcZryoi4kNFhHu5LvHcZPdxXV1D+uTMfGS1eyd2Yz/DoNWXNAl8TI0cAsW\
5ymME3bQ4J/k1IKxCtz/bAlAqFgKoc+EolMziDYqWIATtW0rYTJvzGAzTmMj80/QpsFH+Pc2M=
"""
keyblob = """\
keyblob = b"""\
AAAAB3NzaC1yc2EAAAABIwAAAIEA8bP1ZA7DCZDB9J0s50l31MBGQ3GQ/Fc7SX6gkpXkwcZryoi4k\
NFhHu5LvHcZPdxXV1D+uTMfGS1eyd2Yz/DoNWXNAl8TI0cAsW5ymME3bQ4J/k1IKxCtz/bAlAqFgK\
oc+EolMziDYqWIATtW0rYTJvzGAzTmMj80/QpsFH+Pc2M="""
keyblob_dss = """\
keyblob_dss = b"""\
AAAAB3NzaC1kc3MAAACBAOeBpgNnfRzr/twmAQRu2XwWAp3CFtrVnug6s6fgwj/oLjYbVtjAy6pl/\
h0EKCWx2rf1IetyNsTxWrniA9I6HeDj65X1FyDkg6g8tvCnaNB8Xp/UUhuzHuGsMIipRxBxw9LF60\
8EqZcj1E3ytktoW5B5OcjrkEoz3xG7C+rpIjYvAAAAFQDwz4UnmsGiSNu5iqjn3uTzwUpshwAAAIE\
@ -55,51 +56,50 @@ Ngw3qIch/WgRmMHy4kBq1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg\
class HostKeysTest (unittest.TestCase):
def setUp(self):
f = open('hostfile.temp', 'w')
f.write(test_hosts_file)
f.close()
with open('hostfile.temp', 'w') as f:
f.write(test_hosts_file)
def tearDown(self):
os.unlink('hostfile.temp')
def test_1_load(self):
hostdict = paramiko.HostKeys('hostfile.temp')
self.assertEquals(2, len(hostdict))
self.assertEquals(1, len(hostdict.values()[0]))
self.assertEquals(1, len(hostdict.values()[1]))
self.assertEqual(2, len(hostdict))
self.assertEqual(1, len(list(hostdict.values())[0]))
self.assertEqual(1, len(list(hostdict.values())[1]))
fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper()
self.assertEquals('E6684DB30E109B67B70FF1DC5C7F1363', fp)
self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp)
def test_2_add(self):
hostdict = paramiko.HostKeys('hostfile.temp')
hh = '|1|BMsIC6cUIP2zBuXR3t2LRcJYjzM=|hpkJMysjTk/+zzUUzxQEa2ieq6c='
key = paramiko.RSAKey(data=base64.decodestring(keyblob))
key = paramiko.RSAKey(data=decodebytes(keyblob))
hostdict.add(hh, 'ssh-rsa', key)
self.assertEquals(3, len(hostdict))
self.assertEqual(3, len(list(hostdict)))
x = hostdict['foo.example.com']
fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper()
self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp)
self.assert_(hostdict.check('foo.example.com', key))
self.assertEqual(b'7EC91BB336CB6D810B124B1353C32396', fp)
self.assertTrue(hostdict.check('foo.example.com', key))
def test_3_dict(self):
hostdict = paramiko.HostKeys('hostfile.temp')
self.assert_('secure.example.com' in hostdict)
self.assert_('not.example.com' not in hostdict)
self.assert_(hostdict.has_key('secure.example.com'))
self.assert_(not hostdict.has_key('not.example.com'))
self.assertTrue('secure.example.com' in hostdict)
self.assertTrue('not.example.com' not in hostdict)
self.assertTrue('secure.example.com' in hostdict)
self.assertTrue('not.example.com' not in hostdict)
x = hostdict.get('secure.example.com', None)
self.assert_(x is not None)
self.assertTrue(x is not None)
fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper()
self.assertEquals('E6684DB30E109B67B70FF1DC5C7F1363', fp)
self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp)
i = 0
for key in hostdict:
i += 1
self.assertEquals(2, i)
self.assertEqual(2, i)
def test_4_dict_set(self):
hostdict = paramiko.HostKeys('hostfile.temp')
key = paramiko.RSAKey(data=base64.decodestring(keyblob))
key_dss = paramiko.DSSKey(data=base64.decodestring(keyblob_dss))
key = paramiko.RSAKey(data=decodebytes(keyblob))
key_dss = paramiko.DSSKey(data=decodebytes(keyblob_dss))
hostdict['secure.example.com'] = {
'ssh-rsa': key,
'ssh-dss': key_dss
@ -107,11 +107,11 @@ class HostKeysTest (unittest.TestCase):
hostdict['fake.example.com'] = {}
hostdict['fake.example.com']['ssh-rsa'] = key
self.assertEquals(3, len(hostdict))
self.assertEquals(2, len(hostdict.values()[0]))
self.assertEquals(1, len(hostdict.values()[1]))
self.assertEquals(1, len(hostdict.values()[2]))
self.assertEqual(3, len(hostdict))
self.assertEqual(2, len(list(hostdict.values())[0]))
self.assertEqual(1, len(list(hostdict.values())[1]))
self.assertEqual(1, len(list(hostdict.values())[2]))
fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper()
self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp)
self.assertEqual(b'7EC91BB336CB6D810B124B1353C32396', fp)
fp = hexlify(hostdict['secure.example.com']['ssh-dss'].get_fingerprint()).upper()
self.assertEquals('4478F0B9A23CC5182009FF755BC1D26C', fp)
self.assertEqual(b'4478F0B9A23CC5182009FF755BC1D26C', fp)

View File

@ -26,22 +26,25 @@ import paramiko.util
from paramiko.kex_group1 import KexGroup1
from paramiko.kex_gex import KexGex
from paramiko import Message
from paramiko.common import *
class FakeRng (object):
def read(self, n):
return chr(0xcc) * n
return byte_chr(0xcc) * n
class FakeKey (object):
def __str__(self):
return 'fake-key'
def asbytes(self):
return b'fake-key'
def sign_ssh_data(self, rng, H):
return 'fake-sig'
return b'fake-sig'
class FakeModulusPack (object):
P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFFL
P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF
G = 2
def get_modulus(self, min, ask, max):
return self.G, self.P
@ -75,7 +78,7 @@ class FakeTransport (object):
class KexTest (unittest.TestCase):
K = 14730343317708716439807310032871972459448364195094179797249681733965528989482751523943515690110179031004049109375612685505881911274101441415545039654102474376472240501616988799699744135291070488314748284283496055223852115360852283821334858541043710301057312858051901453919067023103730011648890038847384890504L
K = 14730343317708716439807310032871972459448364195094179797249681733965528989482751523943515690110179031004049109375612685505881911274101441415545039654102474376472240501616988799699744135291070488314748284283496055223852115360852283821334858541043710301057312858051901453919067023103730011648890038847384890504
def setUp(self):
pass
@ -88,9 +91,9 @@ class KexTest (unittest.TestCase):
transport.server_mode = False
kex = KexGroup1(transport)
kex.start_kex()
x = '1E000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
self.assertEquals(x, hexlify(str(transport._message)).upper())
self.assertEquals((paramiko.kex_group1._MSG_KEXDH_REPLY,), transport._expect)
x = b'1E000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
self.assertEqual((paramiko.kex_group1._MSG_KEXDH_REPLY,), transport._expect)
# fake "reply"
msg = Message()
@ -99,47 +102,47 @@ class KexTest (unittest.TestCase):
msg.add_string('fake-sig')
msg.rewind()
kex.parse_next(paramiko.kex_group1._MSG_KEXDH_REPLY, msg)
H = '03079780F3D3AD0B3C6DB30C8D21685F367A86D2'
self.assertEquals(self.K, transport._K)
self.assertEquals(H, hexlify(transport._H).upper())
self.assertEquals(('fake-host-key', 'fake-sig'), transport._verify)
self.assert_(transport._activated)
H = b'03079780F3D3AD0B3C6DB30C8D21685F367A86D2'
self.assertEqual(self.K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify)
self.assertTrue(transport._activated)
def test_2_group1_server(self):
transport = FakeTransport()
transport.server_mode = True
kex = KexGroup1(transport)
kex.start_kex()
self.assertEquals((paramiko.kex_group1._MSG_KEXDH_INIT,), transport._expect)
self.assertEqual((paramiko.kex_group1._MSG_KEXDH_INIT,), transport._expect)
msg = Message()
msg.add_mpint(69)
msg.rewind()
kex.parse_next(paramiko.kex_group1._MSG_KEXDH_INIT, msg)
H = 'B16BF34DD10945EDE84E9C1EF24A14BFDC843389'
x = '1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
self.assertEquals(self.K, transport._K)
self.assertEquals(H, hexlify(transport._H).upper())
self.assertEquals(x, hexlify(str(transport._message)).upper())
self.assert_(transport._activated)
H = b'B16BF34DD10945EDE84E9C1EF24A14BFDC843389'
x = b'1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
self.assertEqual(self.K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
self.assertTrue(transport._activated)
def test_3_gex_client(self):
transport = FakeTransport()
transport.server_mode = False
kex = KexGex(transport)
kex.start_kex()
x = '22000004000000080000002000'
self.assertEquals(x, hexlify(str(transport._message)).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect)
x = b'22000004000000080000002000'
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect)
msg = Message()
msg.add_mpint(FakeModulusPack.P)
msg.add_mpint(FakeModulusPack.G)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg)
x = '20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
self.assertEquals(x, hexlify(str(transport._message)).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect)
x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect)
msg = Message()
msg.add_string('fake-host-key')
@ -147,29 +150,29 @@ class KexTest (unittest.TestCase):
msg.add_string('fake-sig')
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg)
H = 'A265563F2FA87F1A89BF007EE90D58BE2E4A4BD0'
self.assertEquals(self.K, transport._K)
self.assertEquals(H, hexlify(transport._H).upper())
self.assertEquals(('fake-host-key', 'fake-sig'), transport._verify)
self.assert_(transport._activated)
H = b'A265563F2FA87F1A89BF007EE90D58BE2E4A4BD0'
self.assertEqual(self.K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify)
self.assertTrue(transport._activated)
def test_4_gex_old_client(self):
transport = FakeTransport()
transport.server_mode = False
kex = KexGex(transport)
kex.start_kex(_test_old_style=True)
x = '1E00000800'
self.assertEquals(x, hexlify(str(transport._message)).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect)
x = b'1E00000800'
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect)
msg = Message()
msg.add_mpint(FakeModulusPack.P)
msg.add_mpint(FakeModulusPack.G)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg)
x = '20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
self.assertEquals(x, hexlify(str(transport._message)).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect)
x = b'20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect)
msg = Message()
msg.add_string('fake-host-key')
@ -177,18 +180,18 @@ class KexTest (unittest.TestCase):
msg.add_string('fake-sig')
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg)
H = '807F87B269EF7AC5EC7E75676808776A27D5864C'
self.assertEquals(self.K, transport._K)
self.assertEquals(H, hexlify(transport._H).upper())
self.assertEquals(('fake-host-key', 'fake-sig'), transport._verify)
self.assert_(transport._activated)
H = b'807F87B269EF7AC5EC7E75676808776A27D5864C'
self.assertEqual(self.K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
self.assertEqual((b'fake-host-key', b'fake-sig'), transport._verify)
self.assertTrue(transport._activated)
def test_5_gex_server(self):
transport = FakeTransport()
transport.server_mode = True
kex = KexGex(transport)
kex.start_kex()
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect)
self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect)
msg = Message()
msg.add_int(1024)
@ -196,45 +199,45 @@ class KexTest (unittest.TestCase):
msg.add_int(4096)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, msg)
x = '1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102'
self.assertEquals(x, hexlify(str(transport._message)).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect)
x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102'
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect)
msg = Message()
msg.add_mpint(12345)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg)
K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581L
H = 'CE754197C21BF3452863B4F44D0B3951F12516EF'
x = '210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
self.assertEquals(K, transport._K)
self.assertEquals(H, hexlify(transport._H).upper())
self.assertEquals(x, hexlify(str(transport._message)).upper())
self.assert_(transport._activated)
K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581
H = b'CE754197C21BF3452863B4F44D0B3951F12516EF'
x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
self.assertEqual(K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
self.assertTrue(transport._activated)
def test_6_gex_server_with_old_client(self):
transport = FakeTransport()
transport.server_mode = True
kex = KexGex(transport)
kex.start_kex()
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect)
self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD), transport._expect)
msg = Message()
msg.add_int(2048)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, msg)
x = '1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102'
self.assertEquals(x, hexlify(str(transport._message)).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect)
x = b'1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102'
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
self.assertEqual((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect)
msg = Message()
msg.add_mpint(12345)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg)
K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581L
H = 'B41A06B2E59043CEFC1AE16EC31F1E2D12EC455B'
x = '210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
self.assertEquals(K, transport._K)
self.assertEquals(H, hexlify(transport._H).upper())
self.assertEquals(x, hexlify(str(transport._message)).upper())
self.assert_(transport._activated)
K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581
H = b'B41A06B2E59043CEFC1AE16EC31F1E2D12EC455B'
x = b'210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
self.assertEqual(K, transport._K)
self.assertEqual(H, hexlify(transport._H).upper())
self.assertEqual(x, hexlify(transport._message.asbytes()).upper())
self.assertTrue(transport._activated)

View File

@ -22,14 +22,15 @@ Some unit tests for ssh protocol message blocks.
import unittest
from paramiko.message import Message
from paramiko.common import *
class MessageTest (unittest.TestCase):
__a = '\x00\x00\x00\x17\x07\x60\xe0\x90\x00\x00\x00\x01q\x00\x00\x00\x05hello\x00\x00\x03\xe8' + ('x' * 1000)
__b = '\x01\x00\xf3\x00\x3f\x00\x00\x00\x10huey,dewey,louie'
__c = '\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x01\x11\x00\x00\x00\x07\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x06\x9a\x1b\x2c\x3d\x4e\xf7'
__d = '\x00\x00\x00\x05\x00\x00\x00\x05\x11\x22\x33\x44\x55\x01\x00\x00\x00\x03cat\x00\x00\x00\x03a,b'
__a = b'\x00\x00\x00\x17\x07\x60\xe0\x90\x00\x00\x00\x01\x71\x00\x00\x00\x05\x68\x65\x6c\x6c\x6f\x00\x00\x03\xe8' + b'x' * 1000
__b = b'\x01\x00\xf3\x00\x3f\x00\x00\x00\x10\x68\x75\x65\x79\x2c\x64\x65\x77\x65\x79\x2c\x6c\x6f\x75\x69\x65'
__c = b'\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x01\x11\x00\x00\x00\x07\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x06\x9a\x1b\x2c\x3d\x4e\xf7'
__d = b'\x00\x00\x00\x05\xff\x00\x00\x00\x05\x11\x22\x33\x44\x55\xff\x00\x00\x00\x0a\x00\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x63\x61\x74\x00\x00\x00\x03\x61\x2c\x62'
def test_1_encode(self):
msg = Message()
@ -38,63 +39,65 @@ class MessageTest (unittest.TestCase):
msg.add_string('q')
msg.add_string('hello')
msg.add_string('x' * 1000)
self.assertEquals(str(msg), self.__a)
self.assertEqual(msg.asbytes(), self.__a)
msg = Message()
msg.add_boolean(True)
msg.add_boolean(False)
msg.add_byte('\xf3')
msg.add_bytes('\x00\x3f')
msg.add_byte(byte_chr(0xf3))
msg.add_bytes(zero_byte + byte_chr(0x3f))
msg.add_list(['huey', 'dewey', 'louie'])
self.assertEquals(str(msg), self.__b)
self.assertEqual(msg.asbytes(), self.__b)
msg = Message()
msg.add_int64(5)
msg.add_int64(0xf5e4d3c2b109L)
msg.add_int64(0xf5e4d3c2b109)
msg.add_mpint(17)
msg.add_mpint(0xf5e4d3c2b109L)
msg.add_mpint(-0x65e4d3c2b109L)
self.assertEquals(str(msg), self.__c)
msg.add_mpint(0xf5e4d3c2b109)
msg.add_mpint(-0x65e4d3c2b109)
self.assertEqual(msg.asbytes(), self.__c)
def test_2_decode(self):
msg = Message(self.__a)
self.assertEquals(msg.get_int(), 23)
self.assertEquals(msg.get_int(), 123789456)
self.assertEquals(msg.get_string(), 'q')
self.assertEquals(msg.get_string(), 'hello')
self.assertEquals(msg.get_string(), 'x' * 1000)
self.assertEqual(msg.get_int(), 23)
self.assertEqual(msg.get_int(), 123789456)
self.assertEqual(msg.get_text(), 'q')
self.assertEqual(msg.get_text(), 'hello')
self.assertEqual(msg.get_text(), 'x' * 1000)
msg = Message(self.__b)
self.assertEquals(msg.get_boolean(), True)
self.assertEquals(msg.get_boolean(), False)
self.assertEquals(msg.get_byte(), '\xf3')
self.assertEquals(msg.get_bytes(2), '\x00\x3f')
self.assertEquals(msg.get_list(), ['huey', 'dewey', 'louie'])
self.assertEqual(msg.get_boolean(), True)
self.assertEqual(msg.get_boolean(), False)
self.assertEqual(msg.get_byte(), byte_chr(0xf3))
self.assertEqual(msg.get_bytes(2), zero_byte + byte_chr(0x3f))
self.assertEqual(msg.get_list(), ['huey', 'dewey', 'louie'])
msg = Message(self.__c)
self.assertEquals(msg.get_int64(), 5)
self.assertEquals(msg.get_int64(), 0xf5e4d3c2b109L)
self.assertEquals(msg.get_mpint(), 17)
self.assertEquals(msg.get_mpint(), 0xf5e4d3c2b109L)
self.assertEquals(msg.get_mpint(), -0x65e4d3c2b109L)
self.assertEqual(msg.get_int64(), 5)
self.assertEqual(msg.get_int64(), 0xf5e4d3c2b109)
self.assertEqual(msg.get_mpint(), 17)
self.assertEqual(msg.get_mpint(), 0xf5e4d3c2b109)
self.assertEqual(msg.get_mpint(), -0x65e4d3c2b109)
def test_3_add(self):
msg = Message()
msg.add(5)
msg.add(0x1122334455L)
msg.add(0x1122334455)
msg.add(0xf00000000000000000)
msg.add(True)
msg.add('cat')
msg.add(['a', 'b'])
self.assertEquals(str(msg), self.__d)
self.assertEqual(msg.asbytes(), self.__d)
def test_4_misc(self):
msg = Message(self.__d)
self.assertEquals(msg.get_int(), 5)
self.assertEquals(msg.get_mpint(), 0x1122334455L)
self.assertEquals(msg.get_so_far(), self.__d[:13])
self.assertEquals(msg.get_remainder(), self.__d[13:])
self.assertEqual(msg.get_int(), 5)
self.assertEqual(msg.get_int(), 0x1122334455)
self.assertEqual(msg.get_int(), 0xf00000000000000000)
self.assertEqual(msg.get_so_far(), self.__d[:29])
self.assertEqual(msg.get_remainder(), self.__d[29:])
msg.rewind()
self.assertEquals(msg.get_int(), 5)
self.assertEquals(msg.get_so_far(), self.__d[:4])
self.assertEquals(msg.get_remainder(), self.__d[4:])
self.assertEqual(msg.get_int(), 5)
self.assertEqual(msg.get_so_far(), self.__d[:4])
self.assertEqual(msg.get_remainder(), self.__d[4:])

View File

@ -21,10 +21,15 @@ Some unit tests for the ssh2 protocol in Transport.
"""
import unittest
from loop import LoopSocket
from tests.loop import LoopSocket
from Crypto.Cipher import AES
from Crypto.Hash import SHA, HMAC
from paramiko import Message, Packetizer, util
from paramiko.common import *
x55 = byte_chr(0x55)
x1f = byte_chr(0x1f)
class PacketizerTest (unittest.TestCase):
@ -35,21 +40,21 @@ class PacketizerTest (unittest.TestCase):
p = Packetizer(wsock)
p.set_log(util.get_logger('paramiko.transport'))
p.set_hexdump(True)
cipher = AES.new('\x00' * 16, AES.MODE_CBC, '\x55' * 16)
p.set_outbound_cipher(cipher, 16, SHA, 12, '\x1f' * 20)
cipher = AES.new(zero_byte * 16, AES.MODE_CBC, x55 * 16)
p.set_outbound_cipher(cipher, 16, SHA, 12, x1f * 20)
# message has to be at least 16 bytes long, so we'll have at least one
# block of data encrypted that contains zero random padding bytes
m = Message()
m.add_byte(chr(100))
m.add_byte(byte_chr(100))
m.add_int(100)
m.add_int(1)
m.add_int(900)
p.send_message(m)
data = rsock.recv(100)
# 32 + 12 bytes of MAC = 44
self.assertEquals(44, len(data))
self.assertEquals('\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0', data[:16])
self.assertEqual(44, len(data))
self.assertEqual(b'\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0', data[:16])
def test_2_read (self):
rsock = LoopSocket()
@ -58,13 +63,11 @@ class PacketizerTest (unittest.TestCase):
p = Packetizer(rsock)
p.set_log(util.get_logger('paramiko.transport'))
p.set_hexdump(True)
cipher = AES.new('\x00' * 16, AES.MODE_CBC, '\x55' * 16)
p.set_inbound_cipher(cipher, 16, SHA, 12, '\x1f' * 20)
wsock.send('C\x91\x97\xbd[P\xac%\x87\xc2\xc4k\xc7\xe98\xc0' + \
'\x90\xd2\x16V\rqsa8|L=\xfb\x97}\xe2n\x03\xb1\xa0\xc2\x1c\xd6AAL\xb4Y')
cipher = AES.new(zero_byte * 16, AES.MODE_CBC, x55 * 16)
p.set_inbound_cipher(cipher, 16, SHA, 12, x1f * 20)
wsock.send(b'\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0\x90\xd2\x16\x56\x0d\x71\x73\x61\x38\x7c\x4c\x3d\xfb\x97\x7d\xe2\x6e\x03\xb1\xa0\xc2\x1c\xd6\x41\x41\x4c\xb4\x59')
cmd, m = p.read_message()
self.assertEquals(100, cmd)
self.assertEquals(100, m.get_int())
self.assertEquals(1, m.get_int())
self.assertEquals(900, m.get_int())
self.assertEqual(100, cmd)
self.assertEqual(100, m.get_int())
self.assertEqual(1, m.get_int())
self.assertEqual(900, m.get_int())

View File

@ -20,11 +20,11 @@
Some unit tests for public/private key objects.
"""
from binascii import hexlify, unhexlify
import StringIO
from binascii import hexlify
import unittest
from paramiko import RSAKey, DSSKey, ECDSAKey, Message, util
from paramiko.common import rng
from paramiko.common import rng, StringIO, byte_chr, b, bytes
from tests.util import test_path
# from openssh's ssh-keygen
PUB_RSA = 'ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEA049W6geFpmsljTwfvI1UmKWWJPNFI74+vNKTk4dmzkQY2yAMs6FhlvhlI8ysU4oj71ZsRYMecHbBbxdN79+JRFVYTKaLqjwGENeTd+yv4q+V2PvZv3fLnzApI3l7EJCqhWwJUHJ1jAkZzqDx0tyOL4uoZpww3nmE0kb3y21tH4c='
@ -77,6 +77,9 @@ ADRvOqQ5R98Sxst765CAqXmRtz8vwoD96g==
-----END EC PRIVATE KEY-----
"""
x1234 = b'\x01\x02\x03\x04'
class KeyTest (unittest.TestCase):
def setUp(self):
@ -87,164 +90,164 @@ class KeyTest (unittest.TestCase):
def test_1_generate_key_bytes(self):
from Crypto.Hash import MD5
key = util.generate_key_bytes(MD5, '\x01\x02\x03\x04', 'happy birthday', 30)
exp = unhexlify('61E1F272F4C1C4561586BD322498C0E924672780F47BB37DDA7D54019E64')
self.assertEquals(exp, key)
key = util.generate_key_bytes(MD5, x1234, 'happy birthday', 30)
exp = b'\x61\xE1\xF2\x72\xF4\xC1\xC4\x56\x15\x86\xBD\x32\x24\x98\xC0\xE9\x24\x67\x27\x80\xF4\x7B\xB3\x7D\xDA\x7D\x54\x01\x9E\x64'
self.assertEqual(exp, key)
def test_2_load_rsa(self):
key = RSAKey.from_private_key_file('tests/test_rsa.key')
self.assertEquals('ssh-rsa', key.get_name())
exp_rsa = FINGER_RSA.split()[1].replace(':', '')
key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
self.assertEqual('ssh-rsa', key.get_name())
exp_rsa = b(FINGER_RSA.split()[1].replace(':', ''))
my_rsa = hexlify(key.get_fingerprint())
self.assertEquals(exp_rsa, my_rsa)
self.assertEquals(PUB_RSA.split()[1], key.get_base64())
self.assertEquals(1024, key.get_bits())
self.assertEqual(exp_rsa, my_rsa)
self.assertEqual(PUB_RSA.split()[1], key.get_base64())
self.assertEqual(1024, key.get_bits())
s = StringIO.StringIO()
s = StringIO()
key.write_private_key(s)
self.assertEquals(RSA_PRIVATE_OUT, s.getvalue())
self.assertEqual(RSA_PRIVATE_OUT, s.getvalue())
s.seek(0)
key2 = RSAKey.from_private_key(s)
self.assertEquals(key, key2)
self.assertEqual(key, key2)
def test_3_load_rsa_password(self):
key = RSAKey.from_private_key_file('tests/test_rsa_password.key', 'television')
self.assertEquals('ssh-rsa', key.get_name())
exp_rsa = FINGER_RSA.split()[1].replace(':', '')
key = RSAKey.from_private_key_file(test_path('test_rsa_password.key'), 'television')
self.assertEqual('ssh-rsa', key.get_name())
exp_rsa = b(FINGER_RSA.split()[1].replace(':', ''))
my_rsa = hexlify(key.get_fingerprint())
self.assertEquals(exp_rsa, my_rsa)
self.assertEquals(PUB_RSA.split()[1], key.get_base64())
self.assertEquals(1024, key.get_bits())
self.assertEqual(exp_rsa, my_rsa)
self.assertEqual(PUB_RSA.split()[1], key.get_base64())
self.assertEqual(1024, key.get_bits())
def test_4_load_dss(self):
key = DSSKey.from_private_key_file('tests/test_dss.key')
self.assertEquals('ssh-dss', key.get_name())
exp_dss = FINGER_DSS.split()[1].replace(':', '')
key = DSSKey.from_private_key_file(test_path('test_dss.key'))
self.assertEqual('ssh-dss', key.get_name())
exp_dss = b(FINGER_DSS.split()[1].replace(':', ''))
my_dss = hexlify(key.get_fingerprint())
self.assertEquals(exp_dss, my_dss)
self.assertEquals(PUB_DSS.split()[1], key.get_base64())
self.assertEquals(1024, key.get_bits())
self.assertEqual(exp_dss, my_dss)
self.assertEqual(PUB_DSS.split()[1], key.get_base64())
self.assertEqual(1024, key.get_bits())
s = StringIO.StringIO()
s = StringIO()
key.write_private_key(s)
self.assertEquals(DSS_PRIVATE_OUT, s.getvalue())
self.assertEqual(DSS_PRIVATE_OUT, s.getvalue())
s.seek(0)
key2 = DSSKey.from_private_key(s)
self.assertEquals(key, key2)
self.assertEqual(key, key2)
def test_5_load_dss_password(self):
key = DSSKey.from_private_key_file('tests/test_dss_password.key', 'television')
self.assertEquals('ssh-dss', key.get_name())
exp_dss = FINGER_DSS.split()[1].replace(':', '')
key = DSSKey.from_private_key_file(test_path('test_dss_password.key'), 'television')
self.assertEqual('ssh-dss', key.get_name())
exp_dss = b(FINGER_DSS.split()[1].replace(':', ''))
my_dss = hexlify(key.get_fingerprint())
self.assertEquals(exp_dss, my_dss)
self.assertEquals(PUB_DSS.split()[1], key.get_base64())
self.assertEquals(1024, key.get_bits())
self.assertEqual(exp_dss, my_dss)
self.assertEqual(PUB_DSS.split()[1], key.get_base64())
self.assertEqual(1024, key.get_bits())
def test_6_compare_rsa(self):
# verify that the private & public keys compare equal
key = RSAKey.from_private_key_file('tests/test_rsa.key')
self.assertEquals(key, key)
pub = RSAKey(data=str(key))
self.assert_(key.can_sign())
self.assert_(not pub.can_sign())
self.assertEquals(key, pub)
key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
self.assertEqual(key, key)
pub = RSAKey(data=key.asbytes())
self.assertTrue(key.can_sign())
self.assertTrue(not pub.can_sign())
self.assertEqual(key, pub)
def test_7_compare_dss(self):
# verify that the private & public keys compare equal
key = DSSKey.from_private_key_file('tests/test_dss.key')
self.assertEquals(key, key)
pub = DSSKey(data=str(key))
self.assert_(key.can_sign())
self.assert_(not pub.can_sign())
self.assertEquals(key, pub)
key = DSSKey.from_private_key_file(test_path('test_dss.key'))
self.assertEqual(key, key)
pub = DSSKey(data=key.asbytes())
self.assertTrue(key.can_sign())
self.assertTrue(not pub.can_sign())
self.assertEqual(key, pub)
def test_8_sign_rsa(self):
# verify that the rsa private key can sign and verify
key = RSAKey.from_private_key_file('tests/test_rsa.key')
msg = key.sign_ssh_data(rng, 'ice weasels')
self.assert_(type(msg) is Message)
key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
msg = key.sign_ssh_data(rng, b'ice weasels')
self.assertTrue(type(msg) is Message)
msg.rewind()
self.assertEquals('ssh-rsa', msg.get_string())
sig = ''.join([chr(int(x, 16)) for x in SIGNED_RSA.split(':')])
self.assertEquals(sig, msg.get_string())
self.assertEqual('ssh-rsa', msg.get_text())
sig = bytes().join([byte_chr(int(x, 16)) for x in SIGNED_RSA.split(':')])
self.assertEqual(sig, msg.get_binary())
msg.rewind()
pub = RSAKey(data=str(key))
self.assert_(pub.verify_ssh_sig('ice weasels', msg))
pub = RSAKey(data=key.asbytes())
self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg))
def test_9_sign_dss(self):
# verify that the dss private key can sign and verify
key = DSSKey.from_private_key_file('tests/test_dss.key')
msg = key.sign_ssh_data(rng, 'ice weasels')
self.assert_(type(msg) is Message)
key = DSSKey.from_private_key_file(test_path('test_dss.key'))
msg = key.sign_ssh_data(rng, b'ice weasels')
self.assertTrue(type(msg) is Message)
msg.rewind()
self.assertEquals('ssh-dss', msg.get_string())
self.assertEqual('ssh-dss', msg.get_text())
# can't do the same test as we do for RSA, because DSS signatures
# are usually different each time. but we can test verification
# anyway so it's ok.
self.assertEquals(40, len(msg.get_string()))
self.assertEqual(40, len(msg.get_binary()))
msg.rewind()
pub = DSSKey(data=str(key))
self.assert_(pub.verify_ssh_sig('ice weasels', msg))
pub = DSSKey(data=key.asbytes())
self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg))
def test_A_generate_rsa(self):
key = RSAKey.generate(1024)
msg = key.sign_ssh_data(rng, 'jerri blank')
msg = key.sign_ssh_data(rng, b'jerri blank')
msg.rewind()
self.assert_(key.verify_ssh_sig('jerri blank', msg))
self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg))
def test_B_generate_dss(self):
key = DSSKey.generate(1024)
msg = key.sign_ssh_data(rng, 'jerri blank')
msg = key.sign_ssh_data(rng, b'jerri blank')
msg.rewind()
self.assert_(key.verify_ssh_sig('jerri blank', msg))
self.assertTrue(key.verify_ssh_sig(b'jerri blank', msg))
def test_10_load_ecdsa(self):
key = ECDSAKey.from_private_key_file('tests/test_ecdsa.key')
self.assertEquals('ecdsa-sha2-nistp256', key.get_name())
exp_ecdsa = FINGER_ECDSA.split()[1].replace(':', '')
key = ECDSAKey.from_private_key_file(test_path('test_ecdsa.key'))
self.assertEqual('ecdsa-sha2-nistp256', key.get_name())
exp_ecdsa = b(FINGER_ECDSA.split()[1].replace(':', ''))
my_ecdsa = hexlify(key.get_fingerprint())
self.assertEquals(exp_ecdsa, my_ecdsa)
self.assertEquals(PUB_ECDSA.split()[1], key.get_base64())
self.assertEquals(256, key.get_bits())
self.assertEqual(exp_ecdsa, my_ecdsa)
self.assertEqual(PUB_ECDSA.split()[1], key.get_base64())
self.assertEqual(256, key.get_bits())
s = StringIO.StringIO()
s = StringIO()
key.write_private_key(s)
self.assertEquals(ECDSA_PRIVATE_OUT, s.getvalue())
self.assertEqual(ECDSA_PRIVATE_OUT, s.getvalue())
s.seek(0)
key2 = ECDSAKey.from_private_key(s)
self.assertEquals(key, key2)
self.assertEqual(key, key2)
def test_11_load_ecdsa_password(self):
key = ECDSAKey.from_private_key_file('tests/test_ecdsa_password.key', 'television')
self.assertEquals('ecdsa-sha2-nistp256', key.get_name())
exp_ecdsa = FINGER_ECDSA.split()[1].replace(':', '')
key = ECDSAKey.from_private_key_file(test_path('test_ecdsa_password.key'), b'television')
self.assertEqual('ecdsa-sha2-nistp256', key.get_name())
exp_ecdsa = b(FINGER_ECDSA.split()[1].replace(':', ''))
my_ecdsa = hexlify(key.get_fingerprint())
self.assertEquals(exp_ecdsa, my_ecdsa)
self.assertEquals(PUB_ECDSA.split()[1], key.get_base64())
self.assertEquals(256, key.get_bits())
self.assertEqual(exp_ecdsa, my_ecdsa)
self.assertEqual(PUB_ECDSA.split()[1], key.get_base64())
self.assertEqual(256, key.get_bits())
def test_12_compare_ecdsa(self):
# verify that the private & public keys compare equal
key = ECDSAKey.from_private_key_file('tests/test_ecdsa.key')
self.assertEquals(key, key)
pub = ECDSAKey(data=str(key))
self.assert_(key.can_sign())
self.assert_(not pub.can_sign())
self.assertEquals(key, pub)
key = ECDSAKey.from_private_key_file(test_path('test_ecdsa.key'))
self.assertEqual(key, key)
pub = ECDSAKey(data=key.asbytes())
self.assertTrue(key.can_sign())
self.assertTrue(not pub.can_sign())
self.assertEqual(key, pub)
def test_13_sign_ecdsa(self):
# verify that the rsa private key can sign and verify
key = ECDSAKey.from_private_key_file('tests/test_ecdsa.key')
msg = key.sign_ssh_data(rng, 'ice weasels')
self.assert_(type(msg) is Message)
key = ECDSAKey.from_private_key_file(test_path('test_ecdsa.key'))
msg = key.sign_ssh_data(rng, b'ice weasels')
self.assertTrue(type(msg) is Message)
msg.rewind()
self.assertEquals('ecdsa-sha2-nistp256', msg.get_string())
self.assertEqual('ecdsa-sha2-nistp256', msg.get_text())
# ECDSA signatures, like DSS signatures, tend to be different
# each time, so we can't compare against a "known correct"
# signature.
# Even the length of the signature can change.
msg.rewind()
pub = ECDSAKey(data=str(key))
self.assert_(pub.verify_ssh_sig('ice weasels', msg))
pub = ECDSAKey(data=key.asbytes())
self.assertTrue(pub.verify_ssh_sig(b'ice weasels', msg))

View File

@ -23,19 +23,18 @@ a real actual sftp server is contacted, and a new folder is created there to
do test file operations in (so no existing files will be harmed).
"""
from __future__ import with_statement
from binascii import hexlify
import os
import warnings
import sys
import threading
import unittest
import StringIO
from tempfile import mkstemp
import paramiko
from stub_sftp import StubServer, StubSFTPServer
from loop import LoopSocket
from paramiko.common import *
from tests.stub_sftp import StubServer, StubSFTPServer
from tests.loop import LoopSocket
from tests.util import test_path
from paramiko.sftp_attr import SFTPAttributes
ARTICLE = '''
@ -70,6 +69,10 @@ FOLDER = os.environ.get('TEST_FOLDER', 'temp-testing000')
sftp = None
tc = None
g_big_file_test = True
# we need to use eval(compile()) here because Py3.2 doesn't support the 'u' marker for unicode
# this test is the only line in the entire program that has to be treated specially to support Py3.2
unicode_folder = eval(compile(r"u'\u00fcnic\u00f8de'" if PY2 else r"'\u00fcnic\u00f8de'", 'test_sftp.py', 'eval'))
utf8_folder = b'/\xc3\xbcnic\xc3\xb8\x64\x65'
def get_sftp():
@ -121,7 +124,7 @@ class SFTPTest (unittest.TestCase):
tc = paramiko.Transport(sockc)
ts = paramiko.Transport(socks)
host_key = paramiko.RSAKey.from_private_key_file('tests/test_rsa.key')
host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
ts.add_server_key(host_key)
event = threading.Event()
server = StubServer()
@ -140,7 +143,7 @@ class SFTPTest (unittest.TestCase):
def setUp(self):
global FOLDER
for i in xrange(1000):
for i in range(1000):
FOLDER = FOLDER[:-3] + '%03d' % i
try:
sftp.mkdir(FOLDER)
@ -149,6 +152,7 @@ class SFTPTest (unittest.TestCase):
pass
def tearDown(self):
#sftp.chdir()
sftp.rmdir(FOLDER)
def test_1_file(self):
@ -158,8 +162,8 @@ class SFTPTest (unittest.TestCase):
f = sftp.open(FOLDER + '/test', 'w')
try:
self.assertEqual(f.stat().st_size, 0)
f.close()
finally:
f.close()
sftp.remove(FOLDER + '/test')
def test_2_close(self):
@ -180,10 +184,9 @@ class SFTPTest (unittest.TestCase):
"""
verify that a file can be created and written, and the size is correct.
"""
f = sftp.open(FOLDER + '/duck.txt', 'w')
try:
f.write(ARTICLE)
f.close()
with sftp.open(FOLDER + '/duck.txt', 'w') as f:
f.write(ARTICLE)
self.assertEqual(sftp.stat(FOLDER + '/duck.txt').st_size, 1483)
finally:
sftp.remove(FOLDER + '/duck.txt')
@ -203,19 +206,17 @@ class SFTPTest (unittest.TestCase):
"""
verify that a file can be opened for append, and tell() still works.
"""
f = sftp.open(FOLDER + '/append.txt', 'w')
try:
f.write('first line\nsecond line\n')
self.assertEqual(f.tell(), 23)
f.close()
with sftp.open(FOLDER + '/append.txt', 'w') as f:
f.write('first line\nsecond line\n')
self.assertEqual(f.tell(), 23)
f = sftp.open(FOLDER + '/append.txt', 'a+')
f.write('third line!!!\n')
self.assertEqual(f.tell(), 37)
self.assertEqual(f.stat().st_size, 37)
f.seek(-26, f.SEEK_CUR)
self.assertEqual(f.readline(), 'second line\n')
f.close()
with sftp.open(FOLDER + '/append.txt', 'a+') as f:
f.write('third line!!!\n')
self.assertEqual(f.tell(), 37)
self.assertEqual(f.stat().st_size, 37)
f.seek(-26, f.SEEK_CUR)
self.assertEqual(f.readline(), 'second line\n')
finally:
sftp.remove(FOLDER + '/append.txt')
@ -223,20 +224,18 @@ class SFTPTest (unittest.TestCase):
"""
verify that renaming a file works.
"""
f = sftp.open(FOLDER + '/first.txt', 'w')
try:
f.write('content!\n')
f.close()
with sftp.open(FOLDER + '/first.txt', 'w') as f:
f.write('content!\n')
sftp.rename(FOLDER + '/first.txt', FOLDER + '/second.txt')
try:
f = sftp.open(FOLDER + '/first.txt', 'r')
self.assert_(False, 'no exception on reading nonexistent file')
sftp.open(FOLDER + '/first.txt', 'r')
self.assertTrue(False, 'no exception on reading nonexistent file')
except IOError:
pass
f = sftp.open(FOLDER + '/second.txt', 'r')
f.seek(-6, f.SEEK_END)
self.assertEqual(f.read(4), 'tent')
f.close()
with sftp.open(FOLDER + '/second.txt', 'r') as f:
f.seek(-6, f.SEEK_END)
self.assertEqual(u(f.read(4)), 'tent')
finally:
try:
sftp.remove(FOLDER + '/first.txt')
@ -253,14 +252,13 @@ class SFTPTest (unittest.TestCase):
remove the folder and verify that we can't create a file in it anymore.
"""
sftp.mkdir(FOLDER + '/subfolder')
f = sftp.open(FOLDER + '/subfolder/test', 'w')
f.close()
sftp.open(FOLDER + '/subfolder/test', 'w').close()
sftp.remove(FOLDER + '/subfolder/test')
sftp.rmdir(FOLDER + '/subfolder')
try:
f = sftp.open(FOLDER + '/subfolder/test')
sftp.open(FOLDER + '/subfolder/test')
# shouldn't be able to create that file
self.assert_(False, 'no exception at dummy file creation')
self.assertTrue(False, 'no exception at dummy file creation')
except IOError:
pass
@ -270,21 +268,16 @@ class SFTPTest (unittest.TestCase):
and those files show up in sftp.listdir.
"""
try:
f = sftp.open(FOLDER + '/duck.txt', 'w')
f.close()
f = sftp.open(FOLDER + '/fish.txt', 'w')
f.close()
f = sftp.open(FOLDER + '/tertiary.py', 'w')
f.close()
sftp.open(FOLDER + '/duck.txt', 'w').close()
sftp.open(FOLDER + '/fish.txt', 'w').close()
sftp.open(FOLDER + '/tertiary.py', 'w').close()
x = sftp.listdir(FOLDER)
self.assertEqual(len(x), 3)
self.assert_('duck.txt' in x)
self.assert_('fish.txt' in x)
self.assert_('tertiary.py' in x)
self.assert_('random' not in x)
self.assertTrue('duck.txt' in x)
self.assertTrue('fish.txt' in x)
self.assertTrue('tertiary.py' in x)
self.assertTrue('random' not in x)
finally:
sftp.remove(FOLDER + '/duck.txt')
sftp.remove(FOLDER + '/fish.txt')
@ -294,22 +287,21 @@ class SFTPTest (unittest.TestCase):
"""
verify that the setstat functions (chown, chmod, utime, truncate) work.
"""
f = sftp.open(FOLDER + '/special', 'w')
try:
f.write('x' * 1024)
f.close()
with sftp.open(FOLDER + '/special', 'w') as f:
f.write('x' * 1024)
stat = sftp.stat(FOLDER + '/special')
sftp.chmod(FOLDER + '/special', (stat.st_mode & ~0777) | 0600)
sftp.chmod(FOLDER + '/special', (stat.st_mode & ~o777) | o600)
stat = sftp.stat(FOLDER + '/special')
expected_mode = 0600
expected_mode = o600
if sys.platform == 'win32':
# chmod not really functional on windows
expected_mode = 0666
expected_mode = o666
if sys.platform == 'cygwin':
# even worse.
expected_mode = 0644
self.assertEqual(stat.st_mode & 0777, expected_mode)
expected_mode = o644
self.assertEqual(stat.st_mode & o777, expected_mode)
self.assertEqual(stat.st_size, 1024)
mtime = stat.st_mtime - 3600
@ -333,40 +325,38 @@ class SFTPTest (unittest.TestCase):
verify that the fsetstat functions (chown, chmod, utime, truncate)
work on open files.
"""
f = sftp.open(FOLDER + '/special', 'w')
try:
f.write('x' * 1024)
f.close()
with sftp.open(FOLDER + '/special', 'w') as f:
f.write('x' * 1024)
f = sftp.open(FOLDER + '/special', 'r+')
stat = f.stat()
f.chmod((stat.st_mode & ~0777) | 0600)
stat = f.stat()
with sftp.open(FOLDER + '/special', 'r+') as f:
stat = f.stat()
f.chmod((stat.st_mode & ~o777) | o600)
stat = f.stat()
expected_mode = 0600
if sys.platform == 'win32':
# chmod not really functional on windows
expected_mode = 0666
if sys.platform == 'cygwin':
# even worse.
expected_mode = 0644
self.assertEqual(stat.st_mode & 0777, expected_mode)
self.assertEqual(stat.st_size, 1024)
expected_mode = o600
if sys.platform == 'win32':
# chmod not really functional on windows
expected_mode = o666
if sys.platform == 'cygwin':
# even worse.
expected_mode = o644
self.assertEqual(stat.st_mode & o777, expected_mode)
self.assertEqual(stat.st_size, 1024)
mtime = stat.st_mtime - 3600
atime = stat.st_atime - 1800
f.utime((atime, mtime))
stat = f.stat()
self.assertEqual(stat.st_mtime, mtime)
if sys.platform not in ('win32', 'cygwin'):
self.assertEqual(stat.st_atime, atime)
mtime = stat.st_mtime - 3600
atime = stat.st_atime - 1800
f.utime((atime, mtime))
stat = f.stat()
self.assertEqual(stat.st_mtime, mtime)
if sys.platform not in ('win32', 'cygwin'):
self.assertEqual(stat.st_atime, atime)
# can't really test chown, since we'd have to know a valid uid.
# can't really test chown, since we'd have to know a valid uid.
f.truncate(512)
stat = f.stat()
self.assertEqual(stat.st_size, 512)
f.close()
f.truncate(512)
stat = f.stat()
self.assertEqual(stat.st_size, 512)
finally:
sftp.remove(FOLDER + '/special')
@ -378,25 +368,23 @@ class SFTPTest (unittest.TestCase):
buffering is reset on 'seek'.
"""
try:
f = sftp.open(FOLDER + '/duck.txt', 'w')
f.write(ARTICLE)
f.close()
with sftp.open(FOLDER + '/duck.txt', 'w') as f:
f.write(ARTICLE)
f = sftp.open(FOLDER + '/duck.txt', 'r+')
line_number = 0
loc = 0
pos_list = []
for line in f:
line_number += 1
pos_list.append(loc)
loc = f.tell()
f.seek(pos_list[6], f.SEEK_SET)
self.assertEqual(f.readline(), 'Nouzilly, France.\n')
f.seek(pos_list[17], f.SEEK_SET)
self.assertEqual(f.readline()[:4], 'duck')
f.seek(pos_list[10], f.SEEK_SET)
self.assertEqual(f.readline(), 'duck types were equally resistant to exogenous insulin compared with chicken.\n')
f.close()
with sftp.open(FOLDER + '/duck.txt', 'r+') as f:
line_number = 0
loc = 0
pos_list = []
for line in f:
line_number += 1
pos_list.append(loc)
loc = f.tell()
f.seek(pos_list[6], f.SEEK_SET)
self.assertEqual(f.readline(), 'Nouzilly, France.\n')
f.seek(pos_list[17], f.SEEK_SET)
self.assertEqual(f.readline()[:4], 'duck')
f.seek(pos_list[10], f.SEEK_SET)
self.assertEqual(f.readline(), 'duck types were equally resistant to exogenous insulin compared with chicken.\n')
finally:
sftp.remove(FOLDER + '/duck.txt')
@ -405,17 +393,15 @@ class SFTPTest (unittest.TestCase):
create a text file, seek back and change part of it, and verify that the
changes worked.
"""
f = sftp.open(FOLDER + '/testing.txt', 'w')
try:
f.write('hello kitty.\n')
f.seek(-5, f.SEEK_CUR)
f.write('dd')
f.close()
with sftp.open(FOLDER + '/testing.txt', 'w') as f:
f.write('hello kitty.\n')
f.seek(-5, f.SEEK_CUR)
f.write('dd')
self.assertEqual(sftp.stat(FOLDER + '/testing.txt').st_size, 13)
f = sftp.open(FOLDER + '/testing.txt', 'r')
data = f.read(20)
f.close()
with sftp.open(FOLDER + '/testing.txt', 'r') as f:
data = f.read(20)
self.assertEqual(data, 'hello kiddy.\n')
finally:
sftp.remove(FOLDER + '/testing.txt')
@ -428,16 +414,14 @@ class SFTPTest (unittest.TestCase):
# skip symlink tests on windows
return
f = sftp.open(FOLDER + '/original.txt', 'w')
try:
f.write('original\n')
f.close()
with sftp.open(FOLDER + '/original.txt', 'w') as f:
f.write('original\n')
sftp.symlink('original.txt', FOLDER + '/link.txt')
self.assertEqual(sftp.readlink(FOLDER + '/link.txt'), 'original.txt')
f = sftp.open(FOLDER + '/link.txt', 'r')
self.assertEqual(f.readlines(), ['original\n'])
f.close()
with sftp.open(FOLDER + '/link.txt', 'r') as f:
self.assertEqual(f.readlines(), ['original\n'])
cwd = sftp.normalize('.')
if cwd[-1] == '/':
@ -450,7 +434,7 @@ class SFTPTest (unittest.TestCase):
self.assertEqual(sftp.stat(FOLDER + '/link.txt').st_size, 9)
# the sftp server may be hiding extra path members from us, so the
# length may be longer than we expect:
self.assert_(sftp.lstat(FOLDER + '/link2.txt').st_size >= len(abs_path))
self.assertTrue(sftp.lstat(FOLDER + '/link2.txt').st_size >= len(abs_path))
self.assertEqual(sftp.stat(FOLDER + '/link2.txt').st_size, 9)
self.assertEqual(sftp.stat(FOLDER + '/original.txt').st_size, 9)
finally:
@ -471,18 +455,16 @@ class SFTPTest (unittest.TestCase):
"""
verify that buffered writes are automatically flushed on seek.
"""
f = sftp.open(FOLDER + '/happy.txt', 'w', 1)
try:
f.write('full line.\n')
f.write('partial')
f.seek(9, f.SEEK_SET)
f.write('?\n')
f.close()
with sftp.open(FOLDER + '/happy.txt', 'w', 1) as f:
f.write('full line.\n')
f.write('partial')
f.seek(9, f.SEEK_SET)
f.write('?\n')
f = sftp.open(FOLDER + '/happy.txt', 'r')
self.assertEqual(f.readline(), 'full line?\n')
self.assertEqual(f.read(7), 'partial')
f.close()
with sftp.open(FOLDER + '/happy.txt', 'r') as f:
self.assertEqual(f.readline(), 'full line?\n')
self.assertEqual(f.read(7), 'partial')
finally:
try:
sftp.remove(FOLDER + '/happy.txt')
@ -495,10 +477,10 @@ class SFTPTest (unittest.TestCase):
error.
"""
pwd = sftp.normalize('.')
self.assert_(len(pwd) > 0)
self.assertTrue(len(pwd) > 0)
f = sftp.normalize('./' + FOLDER)
self.assert_(len(f) > 0)
self.assertEquals(os.path.join(pwd, FOLDER), f)
self.assertTrue(len(f) > 0)
self.assertEqual(os.path.join(pwd, FOLDER), f)
def test_F_mkdir(self):
"""
@ -507,19 +489,19 @@ class SFTPTest (unittest.TestCase):
try:
sftp.mkdir(FOLDER + '/subfolder')
except:
self.assert_(False, 'exception creating subfolder')
self.assertTrue(False, 'exception creating subfolder')
try:
sftp.mkdir(FOLDER + '/subfolder')
self.assert_(False, 'no exception overwriting subfolder')
self.assertTrue(False, 'no exception overwriting subfolder')
except IOError:
pass
try:
sftp.rmdir(FOLDER + '/subfolder')
except:
self.assert_(False, 'exception removing subfolder')
self.assertTrue(False, 'exception removing subfolder')
try:
sftp.rmdir(FOLDER + '/subfolder')
self.assert_(False, 'no exception removing nonexistent subfolder')
self.assertTrue(False, 'no exception removing nonexistent subfolder')
except IOError:
pass
@ -534,17 +516,16 @@ class SFTPTest (unittest.TestCase):
sftp.mkdir(FOLDER + '/alpha')
sftp.chdir(FOLDER + '/alpha')
sftp.mkdir('beta')
self.assertEquals(root + FOLDER + '/alpha', sftp.getcwd())
self.assertEquals(['beta'], sftp.listdir('.'))
self.assertEqual(root + FOLDER + '/alpha', sftp.getcwd())
self.assertEqual(['beta'], sftp.listdir('.'))
sftp.chdir('beta')
f = sftp.open('fish', 'w')
f.write('hello\n')
f.close()
with sftp.open('fish', 'w') as f:
f.write('hello\n')
sftp.chdir('..')
self.assertEquals(['fish'], sftp.listdir('beta'))
self.assertEqual(['fish'], sftp.listdir('beta'))
sftp.chdir('..')
self.assertEquals(['fish'], sftp.listdir('alpha/beta'))
self.assertEqual(['fish'], sftp.listdir('alpha/beta'))
finally:
sftp.chdir(root)
try:
@ -566,30 +547,29 @@ class SFTPTest (unittest.TestCase):
"""
warnings.filterwarnings('ignore', 'tempnam.*')
localname = os.tempnam()
text = 'All I wanted was a plastic bunny rabbit.\n'
f = open(localname, 'wb')
f.write(text)
f.close()
fd, localname = mkstemp()
os.close(fd)
text = b'All I wanted was a plastic bunny rabbit.\n'
with open(localname, 'wb') as f:
f.write(text)
saved_progress = []
def progress_callback(x, y):
saved_progress.append((x, y))
sftp.put(localname, FOLDER + '/bunny.txt', progress_callback)
f = sftp.open(FOLDER + '/bunny.txt', 'r')
self.assertEquals(text, f.read(128))
f.close()
self.assertEquals((41, 41), saved_progress[-1])
with sftp.open(FOLDER + '/bunny.txt', 'rb') as f:
self.assertEqual(text, f.read(128))
self.assertEqual((41, 41), saved_progress[-1])
os.unlink(localname)
localname = os.tempnam()
fd, localname = mkstemp()
os.close(fd)
saved_progress = []
sftp.get(FOLDER + '/bunny.txt', localname, progress_callback)
f = open(localname, 'rb')
self.assertEquals(text, f.read(128))
f.close()
self.assertEquals((41, 41), saved_progress[-1])
with open(localname, 'rb') as f:
self.assertEqual(text, f.read(128))
self.assertEqual((41, 41), saved_progress[-1])
os.unlink(localname)
sftp.unlink(FOLDER + '/bunny.txt')
@ -600,20 +580,18 @@ class SFTPTest (unittest.TestCase):
(it's an sftp extension that we support, and may be the only ones who
support it.)
"""
f = sftp.open(FOLDER + '/kitty.txt', 'w')
f.write('here kitty kitty' * 64)
f.close()
with sftp.open(FOLDER + '/kitty.txt', 'w') as f:
f.write('here kitty kitty' * 64)
try:
f = sftp.open(FOLDER + '/kitty.txt', 'r')
sum = f.check('sha1')
self.assertEquals('91059CFC6615941378D413CB5ADAF4C5EB293402', hexlify(sum).upper())
sum = f.check('md5', 0, 512)
self.assertEquals('93DE4788FCA28D471516963A1FE3856A', hexlify(sum).upper())
sum = f.check('md5', 0, 0, 510)
self.assertEquals('EB3B45B8CD55A0707D99B177544A319F373183D241432BB2157AB9E46358C4AC90370B5CADE5D90336FC1716F90B36D6',
hexlify(sum).upper())
f.close()
with sftp.open(FOLDER + '/kitty.txt', 'r') as f:
sum = f.check('sha1')
self.assertEqual('91059CFC6615941378D413CB5ADAF4C5EB293402', u(hexlify(sum)).upper())
sum = f.check('md5', 0, 512)
self.assertEqual('93DE4788FCA28D471516963A1FE3856A', u(hexlify(sum)).upper())
sum = f.check('md5', 0, 0, 510)
self.assertEqual('EB3B45B8CD55A0707D99B177544A319F373183D241432BB2157AB9E46358C4AC90370B5CADE5D90336FC1716F90B36D6',
u(hexlify(sum)).upper())
finally:
sftp.unlink(FOLDER + '/kitty.txt')
@ -621,12 +599,11 @@ class SFTPTest (unittest.TestCase):
"""
verify that the 'x' flag works when opening a file.
"""
f = sftp.open(FOLDER + '/unusual.txt', 'wx')
f.close()
sftp.open(FOLDER + '/unusual.txt', 'wx').close()
try:
try:
f = sftp.open(FOLDER + '/unusual.txt', 'wx')
sftp.open(FOLDER + '/unusual.txt', 'wx')
self.fail('expected exception')
except IOError:
pass
@ -637,44 +614,39 @@ class SFTPTest (unittest.TestCase):
"""
verify that unicode strings are encoded into utf8 correctly.
"""
f = sftp.open(FOLDER + '/something', 'w')
f.write('okay')
f.close()
with sftp.open(FOLDER + '/something', 'w') as f:
f.write('okay')
try:
sftp.rename(FOLDER + '/something', FOLDER + u'/\u00fcnic\u00f8de')
sftp.open(FOLDER + '/\xc3\xbcnic\xc3\xb8\x64\x65', 'r')
except Exception, e:
self.fail('exception ' + e)
sftp.unlink(FOLDER + '/\xc3\xbcnic\xc3\xb8\x64\x65')
sftp.rename(FOLDER + '/something', FOLDER + '/' + unicode_folder)
sftp.open(b(FOLDER) + utf8_folder, 'r')
except Exception as e:
self.fail('exception ' + str(e))
sftp.unlink(b(FOLDER) + utf8_folder)
def test_L_utf8_chdir(self):
sftp.mkdir(FOLDER + u'\u00fcnic\u00f8de')
sftp.mkdir(FOLDER + '/' + unicode_folder)
try:
sftp.chdir(FOLDER + u'\u00fcnic\u00f8de')
f = sftp.open('something', 'w')
f.write('okay')
f.close()
sftp.chdir(FOLDER + '/' + unicode_folder)
with sftp.open('something', 'w') as f:
f.write('okay')
sftp.unlink('something')
finally:
sftp.chdir(None)
sftp.rmdir(FOLDER + u'\u00fcnic\u00f8de')
sftp.chdir()
sftp.rmdir(FOLDER + '/' + unicode_folder)
def test_M_bad_readv(self):
"""
verify that readv at the end of the file doesn't essplode.
"""
f = sftp.open(FOLDER + '/zero', 'w')
f.close()
sftp.open(FOLDER + '/zero', 'w').close()
try:
f = sftp.open(FOLDER + '/zero', 'r')
f.readv([(0, 12)])
f.close()
with sftp.open(FOLDER + '/zero', 'r') as f:
f.readv([(0, 12)])
f = sftp.open(FOLDER + '/zero', 'r')
f.prefetch()
f.read(100)
f.close()
with sftp.open(FOLDER + '/zero', 'r') as f:
f.prefetch()
f.read(100)
finally:
sftp.unlink(FOLDER + '/zero')
@ -684,45 +656,61 @@ class SFTPTest (unittest.TestCase):
"""
warnings.filterwarnings('ignore', 'tempnam.*')
localname = os.tempnam()
fd, localname = mkstemp()
os.close(fd)
text = 'All I wanted was a plastic bunny rabbit.\n'
f = open(localname, 'wb')
f.write(text)
f.close()
with open(localname, 'w') as f:
f.write(text)
saved_progress = []
def progress_callback(x, y):
saved_progress.append((x, y))
res = sftp.put(localname, FOLDER + '/bunny.txt', progress_callback, False)
self.assertEquals(SFTPAttributes().attr, res.attr)
self.assertEqual(SFTPAttributes().attr, res.attr)
f = sftp.open(FOLDER + '/bunny.txt', 'r')
self.assertEquals(text, f.read(128))
f.close()
self.assertEquals((41, 41), saved_progress[-1])
with sftp.open(FOLDER + '/bunny.txt', 'r') as f:
self.assertEqual(text, f.read(128))
self.assertEqual((41, 41), saved_progress[-1])
os.unlink(localname)
sftp.unlink(FOLDER + '/bunny.txt')
def test_O_getcwd(self):
"""
verify that chdir/getcwd work.
"""
self.assertEqual(None, sftp.getcwd())
root = sftp.normalize('.')
if root[-1] != '/':
root += '/'
try:
sftp.mkdir(FOLDER + '/alpha')
sftp.chdir(FOLDER + '/alpha')
self.assertEqual('/' + FOLDER + '/alpha', sftp.getcwd())
finally:
sftp.chdir(root)
try:
sftp.rmdir(FOLDER + '/alpha')
except:
pass
def XXX_test_M_seek_append(self):
"""
verify that seek does't affect writes during append.
does not work except through paramiko. :( openssh fails.
"""
f = sftp.open(FOLDER + '/append.txt', 'a')
try:
f.write('first line\nsecond line\n')
f.seek(11, f.SEEK_SET)
f.write('third line\n')
f.close()
with sftp.open(FOLDER + '/append.txt', 'a') as f:
f.write('first line\nsecond line\n')
f.seek(11, f.SEEK_SET)
f.write('third line\n')
f = sftp.open(FOLDER + '/append.txt', 'r')
self.assertEqual(f.stat().st_size, 34)
self.assertEqual(f.readline(), 'first line\n')
self.assertEqual(f.readline(), 'second line\n')
self.assertEqual(f.readline(), 'third line\n')
f.close()
with sftp.open(FOLDER + '/append.txt', 'r') as f:
self.assertEqual(f.stat().st_size, 34)
self.assertEqual(f.readline(), 'first line\n')
self.assertEqual(f.readline(), 'second line\n')
self.assertEqual(f.readline(), 'third line\n')
finally:
sftp.remove(FOLDER + '/append.txt')
@ -731,10 +719,16 @@ class SFTPTest (unittest.TestCase):
Send an empty file and confirm it is sent.
"""
target = FOLDER + '/empty file.txt'
stream = StringIO.StringIO()
stream = StringIO()
try:
attrs = sftp.putfo(stream, target)
# the returned attributes should not be null
self.assertNotEqual(attrs, None)
finally:
sftp.remove(target)
if __name__ == '__main__':
SFTPTest.init_loopback()
from unittest import main
main()

View File

@ -33,9 +33,10 @@ import time
import unittest
import paramiko
from stub_sftp import StubServer, StubSFTPServer
from loop import LoopSocket
from test_sftp import get_sftp
from paramiko.common import *
from tests.stub_sftp import StubServer, StubSFTPServer
from tests.loop import LoopSocket
from tests.test_sftp import get_sftp
FOLDER = os.environ.get('TEST_FOLDER', 'temp-testing000')
@ -45,7 +46,7 @@ class BigSFTPTest (unittest.TestCase):
def setUp(self):
global FOLDER
sftp = get_sftp()
for i in xrange(1000):
for i in range(1000):
FOLDER = FOLDER[:-3] + '%03d' % i
try:
sftp.mkdir(FOLDER)
@ -65,19 +66,17 @@ class BigSFTPTest (unittest.TestCase):
numfiles = 100
try:
for i in range(numfiles):
f = sftp.open('%s/file%d.txt' % (FOLDER, i), 'w', 1)
f.write('this is file #%d.\n' % i)
f.close()
sftp.chmod('%s/file%d.txt' % (FOLDER, i), 0660)
with sftp.open('%s/file%d.txt' % (FOLDER, i), 'w', 1) as f:
f.write('this is file #%d.\n' % i)
sftp.chmod('%s/file%d.txt' % (FOLDER, i), o660)
# now make sure every file is there, by creating a list of filenmes
# and reading them in random order.
numlist = range(numfiles)
numlist = list(range(numfiles))
while len(numlist) > 0:
r = numlist[random.randint(0, len(numlist) - 1)]
f = sftp.open('%s/file%d.txt' % (FOLDER, r))
self.assertEqual(f.readline(), 'this is file #%d.\n' % r)
f.close()
with sftp.open('%s/file%d.txt' % (FOLDER, r)) as f:
self.assertEqual(f.readline(), 'this is file #%d.\n' % r)
numlist.remove(r)
finally:
for i in range(numfiles):
@ -94,12 +93,11 @@ class BigSFTPTest (unittest.TestCase):
kblob = (1024 * 'x')
start = time.time()
try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w')
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
sys.stderr.write('.')
f.close()
with sftp.open('%s/hongry.txt' % FOLDER, 'w') as f:
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
sys.stderr.write('.')
sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
@ -107,11 +105,10 @@ class BigSFTPTest (unittest.TestCase):
sys.stderr.write('%ds ' % round(end - start))
start = time.time()
f = sftp.open('%s/hongry.txt' % FOLDER, 'r')
for n in range(1024):
data = f.read(1024)
self.assertEqual(data, kblob)
f.close()
with sftp.open('%s/hongry.txt' % FOLDER, 'r') as f:
for n in range(1024):
data = f.read(1024)
self.assertEqual(data, kblob)
end = time.time()
sys.stderr.write('%ds ' % round(end - start))
@ -123,16 +120,15 @@ class BigSFTPTest (unittest.TestCase):
write a 1MB file, with no linefeeds, using pipelining.
"""
sftp = get_sftp()
kblob = ''.join([struct.pack('>H', n) for n in xrange(512)])
kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
start = time.time()
try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w')
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
sys.stderr.write('.')
f.close()
with sftp.open('%s/hongry.txt' % FOLDER, 'wb') as f:
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
sys.stderr.write('.')
sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
@ -140,22 +136,21 @@ class BigSFTPTest (unittest.TestCase):
sys.stderr.write('%ds ' % round(end - start))
start = time.time()
f = sftp.open('%s/hongry.txt' % FOLDER, 'r')
f.prefetch()
with sftp.open('%s/hongry.txt' % FOLDER, 'rb') as f:
f.prefetch()
# read on odd boundaries to make sure the bytes aren't getting scrambled
n = 0
k2blob = kblob + kblob
chunk = 629
size = 1024 * 1024
while n < size:
if n + chunk > size:
chunk = size - n
data = f.read(chunk)
offset = n % 1024
self.assertEqual(data, k2blob[offset:offset + chunk])
n += chunk
f.close()
# read on odd boundaries to make sure the bytes aren't getting scrambled
n = 0
k2blob = kblob + kblob
chunk = 629
size = 1024 * 1024
while n < size:
if n + chunk > size:
chunk = size - n
data = f.read(chunk)
offset = n % 1024
self.assertEqual(data, k2blob[offset:offset + chunk])
n += chunk
end = time.time()
sys.stderr.write('%ds ' % round(end - start))
@ -164,15 +159,14 @@ class BigSFTPTest (unittest.TestCase):
def test_4_prefetch_seek(self):
sftp = get_sftp()
kblob = ''.join([struct.pack('>H', n) for n in xrange(512)])
kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w')
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
sys.stderr.write('.')
f.close()
with sftp.open('%s/hongry.txt' % FOLDER, 'wb') as f:
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
sys.stderr.write('.')
sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
@ -180,21 +174,20 @@ class BigSFTPTest (unittest.TestCase):
start = time.time()
k2blob = kblob + kblob
chunk = 793
for i in xrange(10):
f = sftp.open('%s/hongry.txt' % FOLDER, 'r')
f.prefetch()
base_offset = (512 * 1024) + 17 * random.randint(1000, 2000)
offsets = [base_offset + j * chunk for j in xrange(100)]
# randomly seek around and read them out
for j in xrange(100):
offset = offsets[random.randint(0, len(offsets) - 1)]
offsets.remove(offset)
f.seek(offset)
data = f.read(chunk)
n_offset = offset % 1024
self.assertEqual(data, k2blob[n_offset:n_offset + chunk])
offset += chunk
f.close()
for i in range(10):
with sftp.open('%s/hongry.txt' % FOLDER, 'rb') as f:
f.prefetch()
base_offset = (512 * 1024) + 17 * random.randint(1000, 2000)
offsets = [base_offset + j * chunk for j in range(100)]
# randomly seek around and read them out
for j in range(100):
offset = offsets[random.randint(0, len(offsets) - 1)]
offsets.remove(offset)
f.seek(offset)
data = f.read(chunk)
n_offset = offset % 1024
self.assertEqual(data, k2blob[n_offset:n_offset + chunk])
offset += chunk
end = time.time()
sys.stderr.write('%ds ' % round(end - start))
finally:
@ -202,15 +195,14 @@ class BigSFTPTest (unittest.TestCase):
def test_5_readv_seek(self):
sftp = get_sftp()
kblob = ''.join([struct.pack('>H', n) for n in xrange(512)])
kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w')
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
sys.stderr.write('.')
f.close()
with sftp.open('%s/hongry.txt' % FOLDER, 'wb') as f:
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
sys.stderr.write('.')
sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
@ -218,22 +210,21 @@ class BigSFTPTest (unittest.TestCase):
start = time.time()
k2blob = kblob + kblob
chunk = 793
for i in xrange(10):
f = sftp.open('%s/hongry.txt' % FOLDER, 'r')
base_offset = (512 * 1024) + 17 * random.randint(1000, 2000)
# make a bunch of offsets and put them in random order
offsets = [base_offset + j * chunk for j in xrange(100)]
readv_list = []
for j in xrange(100):
o = offsets[random.randint(0, len(offsets) - 1)]
offsets.remove(o)
readv_list.append((o, chunk))
ret = f.readv(readv_list)
for i in xrange(len(readv_list)):
offset = readv_list[i][0]
n_offset = offset % 1024
self.assertEqual(ret.next(), k2blob[n_offset:n_offset + chunk])
f.close()
for i in range(10):
with sftp.open('%s/hongry.txt' % FOLDER, 'rb') as f:
base_offset = (512 * 1024) + 17 * random.randint(1000, 2000)
# make a bunch of offsets and put them in random order
offsets = [base_offset + j * chunk for j in range(100)]
readv_list = []
for j in range(100):
o = offsets[random.randint(0, len(offsets) - 1)]
offsets.remove(o)
readv_list.append((o, chunk))
ret = f.readv(readv_list)
for i in range(len(readv_list)):
offset = readv_list[i][0]
n_offset = offset % 1024
self.assertEqual(next(ret), k2blob[n_offset:n_offset + chunk])
end = time.time()
sys.stderr.write('%ds ' % round(end - start))
finally:
@ -247,28 +238,26 @@ class BigSFTPTest (unittest.TestCase):
sftp = get_sftp()
kblob = (1024 * 'x')
try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w')
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
sys.stderr.write('.')
f.close()
with sftp.open('%s/hongry.txt' % FOLDER, 'w') as f:
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
sys.stderr.write('.')
sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
for i in range(10):
f = sftp.open('%s/hongry.txt' % FOLDER, 'r')
with sftp.open('%s/hongry.txt' % FOLDER, 'r') as f:
f.prefetch()
with sftp.open('%s/hongry.txt' % FOLDER, 'r') as f:
f.prefetch()
f = sftp.open('%s/hongry.txt' % FOLDER, 'r')
f.prefetch()
for n in range(1024):
data = f.read(1024)
self.assertEqual(data, kblob)
if n % 128 == 0:
sys.stderr.write('.')
f.close()
for n in range(1024):
data = f.read(1024)
self.assertEqual(data, kblob)
if n % 128 == 0:
sys.stderr.write('.')
sys.stderr.write(' ')
finally:
sftp.remove('%s/hongry.txt' % FOLDER)
@ -278,35 +267,33 @@ class BigSFTPTest (unittest.TestCase):
verify that prefetch and readv don't conflict with each other.
"""
sftp = get_sftp()
kblob = ''.join([struct.pack('>H', n) for n in xrange(512)])
kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w')
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
sys.stderr.write('.')
f.close()
with sftp.open('%s/hongry.txt' % FOLDER, 'wb') as f:
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
sys.stderr.write('.')
sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
f = sftp.open('%s/hongry.txt' % FOLDER, 'r')
f.prefetch()
data = f.read(1024)
self.assertEqual(data, kblob)
with sftp.open('%s/hongry.txt' % FOLDER, 'rb') as f:
f.prefetch()
data = f.read(1024)
self.assertEqual(data, kblob)
chunk_size = 793
base_offset = 512 * 1024
k2blob = kblob + kblob
chunks = [(base_offset + (chunk_size * i), chunk_size) for i in range(20)]
for data in f.readv(chunks):
offset = base_offset % 1024
self.assertEqual(chunk_size, len(data))
self.assertEqual(k2blob[offset:offset + chunk_size], data)
base_offset += chunk_size
chunk_size = 793
base_offset = 512 * 1024
k2blob = kblob + kblob
chunks = [(base_offset + (chunk_size * i), chunk_size) for i in range(20)]
for data in f.readv(chunks):
offset = base_offset % 1024
self.assertEqual(chunk_size, len(data))
self.assertEqual(k2blob[offset:offset + chunk_size], data)
base_offset += chunk_size
f.close()
sys.stderr.write(' ')
finally:
sftp.remove('%s/hongry.txt' % FOLDER)
@ -317,26 +304,24 @@ class BigSFTPTest (unittest.TestCase):
returned as a single blob.
"""
sftp = get_sftp()
kblob = ''.join([struct.pack('>H', n) for n in xrange(512)])
kblob = bytes().join([struct.pack('>H', n) for n in range(512)])
try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w')
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
sys.stderr.write('.')
f.close()
with sftp.open('%s/hongry.txt' % FOLDER, 'wb') as f:
f.set_pipelined(True)
for n in range(1024):
f.write(kblob)
if n % 128 == 0:
sys.stderr.write('.')
sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
f = sftp.open('%s/hongry.txt' % FOLDER, 'r')
data = list(f.readv([(23 * 1024, 128 * 1024)]))
self.assertEqual(1, len(data))
data = data[0]
self.assertEqual(128 * 1024, len(data))
with sftp.open('%s/hongry.txt' % FOLDER, 'rb') as f:
data = list(f.readv([(23 * 1024, 128 * 1024)]))
self.assertEqual(1, len(data))
data = data[0]
self.assertEqual(128 * 1024, len(data))
f.close()
sys.stderr.write(' ')
finally:
sftp.remove('%s/hongry.txt' % FOLDER)
@ -348,9 +333,8 @@ class BigSFTPTest (unittest.TestCase):
sftp = get_sftp()
mblob = (1024 * 1024 * 'x')
try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w', 128 * 1024)
f.write(mblob)
f.close()
with sftp.open('%s/hongry.txt' % FOLDER, 'w', 128 * 1024) as f:
f.write(mblob)
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
finally:
@ -365,21 +349,26 @@ class BigSFTPTest (unittest.TestCase):
t.packetizer.REKEY_BYTES = 512 * 1024
k32blob = (32 * 1024 * 'x')
try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w', 128 * 1024)
for i in xrange(32):
f.write(k32blob)
f.close()
with sftp.open('%s/hongry.txt' % FOLDER, 'w', 128 * 1024) as f:
for i in range(32):
f.write(k32blob)
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
self.assertNotEquals(t.H, t.session_id)
self.assertNotEqual(t.H, t.session_id)
# try to read it too.
f = sftp.open('%s/hongry.txt' % FOLDER, 'r', 128 * 1024)
f.prefetch()
total = 0
while total < 1024 * 1024:
total += len(f.read(32 * 1024))
f.close()
with sftp.open('%s/hongry.txt' % FOLDER, 'r', 128 * 1024) as f:
f.prefetch()
total = 0
while total < 1024 * 1024:
total += len(f.read(32 * 1024))
finally:
sftp.remove('%s/hongry.txt' % FOLDER)
t.packetizer.REKEY_BYTES = pow(2, 30)
if __name__ == '__main__':
from tests.test_sftp import SFTPTest
SFTPTest.init_loopback()
from unittest import main
main()

View File

@ -20,7 +20,7 @@
Some unit tests for the ssh2 protocol in Transport.
"""
from binascii import hexlify, unhexlify
from binascii import hexlify
import select
import socket
import sys
@ -33,10 +33,10 @@ from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey
SSHException, BadAuthenticationType, InteractiveQuery, ChannelException
from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL
from paramiko import OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
from paramiko.common import MSG_KEXINIT, MSG_CHANNEL_WINDOW_ADJUST
from paramiko.common import MSG_KEXINIT, MSG_CHANNEL_WINDOW_ADJUST, b, bytes
from paramiko.message import Message
from loop import LoopSocket
from util import ParamikoTest
from tests.loop import LoopSocket
from tests.util import ParamikoTest, test_path
LONG_BANNER = """\
@ -55,7 +55,7 @@ Maybe.
class NullServer (ServerInterface):
paranoid_did_password = False
paranoid_did_public_key = False
paranoid_key = DSSKey.from_private_key_file('tests/test_dss.key')
paranoid_key = DSSKey.from_private_key_file(test_path('test_dss.key'))
def get_allowed_auths(self, username):
if username == 'slowdive':
@ -121,8 +121,8 @@ class TransportTest(ParamikoTest):
self.sockc.close()
def setup_test_server(self, client_options=None, server_options=None):
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key)
if client_options is not None:
@ -132,37 +132,37 @@ class TransportTest(ParamikoTest):
event = threading.Event()
self.server = NullServer()
self.assert_(not event.isSet())
self.assertTrue(not event.isSet())
self.ts.start_server(event, self.server)
self.tc.connect(hostkey=public_host_key,
username='slowdive', password='pygmalion')
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
self.assertTrue(event.isSet())
self.assertTrue(self.ts.is_active())
def test_1_security_options(self):
o = self.tc.get_security_options()
self.assertEquals(type(o), SecurityOptions)
self.assert_(('aes256-cbc', 'blowfish-cbc') != o.ciphers)
self.assertEqual(type(o), SecurityOptions)
self.assertTrue(('aes256-cbc', 'blowfish-cbc') != o.ciphers)
o.ciphers = ('aes256-cbc', 'blowfish-cbc')
self.assertEquals(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
self.assertEqual(('aes256-cbc', 'blowfish-cbc'), o.ciphers)
try:
o.ciphers = ('aes256-cbc', 'made-up-cipher')
self.assert_(False)
self.assertTrue(False)
except ValueError:
pass
try:
o.ciphers = 23
self.assert_(False)
self.assertTrue(False)
except TypeError:
pass
def test_2_compute_key(self):
self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929L
self.tc.H = unhexlify('0C8307CDE6856FF30BA93684EB0F04C2520E9ED3')
self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929
self.tc.H = b'\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3'
self.tc.session_id = self.tc.H
key = self.tc._compute_key('C', 32)
self.assertEquals('207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
self.assertEqual(b'207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995',
hexlify(key).upper())
def test_3_simple(self):
@ -171,44 +171,44 @@ class TransportTest(ParamikoTest):
loopback sockets. this is hardly "simple" but it's simpler than the
later tests. :)
"""
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.assertEquals(None, self.tc.get_username())
self.assertEquals(None, self.ts.get_username())
self.assertEquals(False, self.tc.is_authenticated())
self.assertEquals(False, self.ts.is_authenticated())
self.assertTrue(not event.isSet())
self.assertEqual(None, self.tc.get_username())
self.assertEqual(None, self.ts.get_username())
self.assertEqual(False, self.tc.is_authenticated())
self.assertEqual(False, self.ts.is_authenticated())
self.ts.start_server(event, server)
self.tc.connect(hostkey=public_host_key,
username='slowdive', password='pygmalion')
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
self.assertEquals('slowdive', self.tc.get_username())
self.assertEquals('slowdive', self.ts.get_username())
self.assertEquals(True, self.tc.is_authenticated())
self.assertEquals(True, self.ts.is_authenticated())
self.assertTrue(event.isSet())
self.assertTrue(self.ts.is_active())
self.assertEqual('slowdive', self.tc.get_username())
self.assertEqual('slowdive', self.ts.get_username())
self.assertEqual(True, self.tc.is_authenticated())
self.assertEqual(True, self.ts.is_authenticated())
def test_3a_long_banner(self):
"""
verify that a long banner doesn't mess up the handshake.
"""
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.assertTrue(not event.isSet())
self.socks.send(LONG_BANNER)
self.ts.start_server(event, server)
self.tc.connect(hostkey=public_host_key,
username='slowdive', password='pygmalion')
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
self.assertTrue(event.isSet())
self.assertTrue(self.ts.is_active())
def test_4_special(self):
"""
@ -219,10 +219,10 @@ class TransportTest(ParamikoTest):
options.ciphers = ('aes256-cbc',)
options.digests = ('hmac-md5-96',)
self.setup_test_server(client_options=force_algorithms)
self.assertEquals('aes256-cbc', self.tc.local_cipher)
self.assertEquals('aes256-cbc', self.tc.remote_cipher)
self.assertEquals(12, self.tc.packetizer.get_mac_size_out())
self.assertEquals(12, self.tc.packetizer.get_mac_size_in())
self.assertEqual('aes256-cbc', self.tc.local_cipher)
self.assertEqual('aes256-cbc', self.tc.remote_cipher)
self.assertEqual(12, self.tc.packetizer.get_mac_size_out())
self.assertEqual(12, self.tc.packetizer.get_mac_size_in())
self.tc.send_ignore(1024)
self.tc.renegotiate_keys()
@ -233,10 +233,10 @@ class TransportTest(ParamikoTest):
verify that the keepalive will be sent.
"""
self.setup_test_server()
self.assertEquals(None, getattr(self.server, '_global_request', None))
self.assertEqual(None, getattr(self.server, '_global_request', None))
self.tc.set_keepalive(1)
time.sleep(2)
self.assertEquals('keepalive@lag.net', self.server._global_request)
self.assertEqual('keepalive@lag.net', self.server._global_request)
def test_6_exec_command(self):
"""
@ -248,8 +248,8 @@ class TransportTest(ParamikoTest):
schan = self.ts.accept(1.0)
try:
chan.exec_command('no')
self.assert_(False)
except SSHException, x:
self.assertTrue(False)
except SSHException:
pass
chan = self.tc.open_session()
@ -260,11 +260,11 @@ class TransportTest(ParamikoTest):
schan.close()
f = chan.makefile()
self.assertEquals('Hello there.\n', f.readline())
self.assertEquals('', f.readline())
self.assertEqual('Hello there.\n', f.readline())
self.assertEqual('', f.readline())
f = chan.makefile_stderr()
self.assertEquals('This is on stderr.\n', f.readline())
self.assertEquals('', f.readline())
self.assertEqual('This is on stderr.\n', f.readline())
self.assertEqual('', f.readline())
# now try it with combined stdout/stderr
chan = self.tc.open_session()
@ -276,9 +276,9 @@ class TransportTest(ParamikoTest):
chan.set_combine_stderr(True)
f = chan.makefile()
self.assertEquals('Hello there.\n', f.readline())
self.assertEquals('This is on stderr.\n', f.readline())
self.assertEquals('', f.readline())
self.assertEqual('Hello there.\n', f.readline())
self.assertEqual('This is on stderr.\n', f.readline())
self.assertEqual('', f.readline())
def test_7_invoke_shell(self):
"""
@ -290,9 +290,9 @@ class TransportTest(ParamikoTest):
schan = self.ts.accept(1.0)
chan.send('communist j. cat\n')
f = schan.makefile()
self.assertEquals('communist j. cat\n', f.readline())
self.assertEqual('communist j. cat\n', f.readline())
chan.close()
self.assertEquals('', f.readline())
self.assertEqual('', f.readline())
def test_8_channel_exception(self):
"""
@ -302,8 +302,8 @@ class TransportTest(ParamikoTest):
try:
chan = self.tc.open_channel('bogus')
self.fail('expected exception')
except ChannelException, x:
self.assert_(x.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED)
except ChannelException as e:
self.assertTrue(e.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED)
def test_9_exit_status(self):
"""
@ -315,7 +315,7 @@ class TransportTest(ParamikoTest):
schan = self.ts.accept(1.0)
chan.exec_command('yes')
schan.send('Hello there.\n')
self.assert_(not chan.exit_status_ready())
self.assertTrue(not chan.exit_status_ready())
# trigger an EOF
schan.shutdown_read()
schan.shutdown_write()
@ -323,15 +323,15 @@ class TransportTest(ParamikoTest):
schan.close()
f = chan.makefile()
self.assertEquals('Hello there.\n', f.readline())
self.assertEquals('', f.readline())
self.assertEqual('Hello there.\n', f.readline())
self.assertEqual('', f.readline())
count = 0
while not chan.exit_status_ready():
time.sleep(0.1)
count += 1
if count > 50:
raise Exception("timeout")
self.assertEquals(23, chan.recv_exit_status())
self.assertEqual(23, chan.recv_exit_status())
chan.close()
def test_A_select(self):
@ -345,9 +345,9 @@ class TransportTest(ParamikoTest):
# nothing should be ready
r, w, e = select.select([chan], [], [], 0.1)
self.assertEquals([], r)
self.assertEquals([], w)
self.assertEquals([], e)
self.assertEqual([], r)
self.assertEqual([], w)
self.assertEqual([], e)
schan.send('hello\n')
@ -357,17 +357,17 @@ class TransportTest(ParamikoTest):
if chan in r:
break
time.sleep(0.1)
self.assertEquals([chan], r)
self.assertEquals([], w)
self.assertEquals([], e)
self.assertEqual([chan], r)
self.assertEqual([], w)
self.assertEqual([], e)
self.assertEquals('hello\n', chan.recv(6))
self.assertEqual(b'hello\n', chan.recv(6))
# and, should be dead again now
r, w, e = select.select([chan], [], [], 0.1)
self.assertEquals([], r)
self.assertEquals([], w)
self.assertEquals([], e)
self.assertEqual([], r)
self.assertEqual([], w)
self.assertEqual([], e)
schan.close()
@ -377,17 +377,17 @@ class TransportTest(ParamikoTest):
if chan in r:
break
time.sleep(0.1)
self.assertEquals([chan], r)
self.assertEquals([], w)
self.assertEquals([], e)
self.assertEquals('', chan.recv(16))
self.assertEqual([chan], r)
self.assertEqual([], w)
self.assertEqual([], e)
self.assertEqual(bytes(), chan.recv(16))
# make sure the pipe is still open for now...
p = chan._pipe
self.assertEquals(False, p._closed)
self.assertEqual(False, p._closed)
chan.close()
# ...and now is closed.
self.assertEquals(True, p._closed)
self.assertEqual(True, p._closed)
def test_B_renegotiate(self):
"""
@ -399,17 +399,17 @@ class TransportTest(ParamikoTest):
chan.exec_command('yes')
schan = self.ts.accept(1.0)
self.assertEquals(self.tc.H, self.tc.session_id)
self.assertEqual(self.tc.H, self.tc.session_id)
for i in range(20):
chan.send('x' * 1024)
chan.close()
# allow a few seconds for the rekeying to complete
for i in xrange(50):
for i in range(50):
if self.tc.H != self.tc.session_id:
break
time.sleep(0.1)
self.assertNotEquals(self.tc.H, self.tc.session_id)
self.assertNotEqual(self.tc.H, self.tc.session_id)
schan.close()
@ -428,8 +428,8 @@ class TransportTest(ParamikoTest):
chan.send('x' * 1024)
bytes2 = self.tc.packetizer._Packetizer__sent_bytes
# tests show this is actually compressed to *52 bytes*! including packet overhead! nice!! :)
self.assert_(bytes2 - bytes < 1024)
self.assertEquals(52, bytes2 - bytes)
self.assertTrue(bytes2 - bytes < 1024)
self.assertEqual(52, bytes2 - bytes)
chan.close()
schan.close()
@ -444,24 +444,25 @@ class TransportTest(ParamikoTest):
schan = self.ts.accept(1.0)
requested = []
def handler(c, (addr, port)):
def handler(c, addr_port):
addr, port = addr_port
requested.append((addr, port))
self.tc._queue_incoming_channel(c)
self.assertEquals(None, getattr(self.server, '_x11_screen_number', None))
self.assertEqual(None, getattr(self.server, '_x11_screen_number', None))
cookie = chan.request_x11(0, single_connection=True, handler=handler)
self.assertEquals(0, self.server._x11_screen_number)
self.assertEquals('MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol)
self.assertEquals(cookie, self.server._x11_auth_cookie)
self.assertEquals(True, self.server._x11_single_connection)
self.assertEqual(0, self.server._x11_screen_number)
self.assertEqual('MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol)
self.assertEqual(cookie, self.server._x11_auth_cookie)
self.assertEqual(True, self.server._x11_single_connection)
x11_server = self.ts.open_x11_channel(('localhost', 6093))
x11_client = self.tc.accept()
self.assertEquals('localhost', requested[0][0])
self.assertEquals(6093, requested[0][1])
self.assertEqual('localhost', requested[0][0])
self.assertEqual(6093, requested[0][1])
x11_server.send('hello')
self.assertEquals('hello', x11_client.recv(5))
self.assertEqual(b'hello', x11_client.recv(5))
x11_server.close()
x11_client.close()
@ -479,13 +480,13 @@ class TransportTest(ParamikoTest):
schan = self.ts.accept(1.0)
requested = []
def handler(c, (origin_addr, origin_port), (server_addr, server_port)):
requested.append((origin_addr, origin_port))
requested.append((server_addr, server_port))
def handler(c, origin_addr_port, server_addr_port):
requested.append(origin_addr_port)
requested.append(server_addr_port)
self.tc._queue_incoming_channel(c)
port = self.tc.request_port_forward('127.0.0.1', 0, handler)
self.assertEquals(port, self.server._listen.getsockname()[1])
self.assertEqual(port, self.server._listen.getsockname()[1])
cs = socket.socket()
cs.connect(('127.0.0.1', port))
@ -494,7 +495,7 @@ class TransportTest(ParamikoTest):
cch = self.tc.accept()
sch.send('hello')
self.assertEquals('hello', cch.recv(5))
self.assertEqual(b'hello', cch.recv(5))
sch.close()
cch.close()
ss.close()
@ -526,12 +527,12 @@ class TransportTest(ParamikoTest):
cch.connect(self.server._tcpip_dest)
ss, _ = greeting_server.accept()
ss.send('Hello!\n')
ss.send(b'Hello!\n')
ss.close()
sch.send(cch.recv(8192))
sch.close()
self.assertEquals('Hello!\n', cs.recv(7))
self.assertEqual(b'Hello!\n', cs.recv(7))
cs.close()
def test_G_stderr_select(self):
@ -546,9 +547,9 @@ class TransportTest(ParamikoTest):
# nothing should be ready
r, w, e = select.select([chan], [], [], 0.1)
self.assertEquals([], r)
self.assertEquals([], w)
self.assertEquals([], e)
self.assertEqual([], r)
self.assertEqual([], w)
self.assertEqual([], e)
schan.send_stderr('hello\n')
@ -558,17 +559,17 @@ class TransportTest(ParamikoTest):
if chan in r:
break
time.sleep(0.1)
self.assertEquals([chan], r)
self.assertEquals([], w)
self.assertEquals([], e)
self.assertEqual([chan], r)
self.assertEqual([], w)
self.assertEqual([], e)
self.assertEquals('hello\n', chan.recv_stderr(6))
self.assertEqual(b'hello\n', chan.recv_stderr(6))
# and, should be dead again now
r, w, e = select.select([chan], [], [], 0.1)
self.assertEquals([], r)
self.assertEquals([], w)
self.assertEquals([], e)
self.assertEqual([], r)
self.assertEqual([], w)
self.assertEqual([], e)
schan.close()
chan.close()
@ -582,7 +583,7 @@ class TransportTest(ParamikoTest):
chan.invoke_shell()
schan = self.ts.accept(1.0)
self.assertEquals(chan.send_ready(), True)
self.assertEqual(chan.send_ready(), True)
total = 0
K = '*' * 1024
while total < 1024 * 1024:
@ -590,11 +591,11 @@ class TransportTest(ParamikoTest):
total += len(K)
if not chan.send_ready():
break
self.assert_(total < 1024 * 1024)
self.assertTrue(total < 1024 * 1024)
schan.close()
chan.close()
self.assertEquals(chan.send_ready(), True)
self.assertEqual(chan.send_ready(), True)
def test_I_rekey_deadlock(self):
"""
@ -657,7 +658,7 @@ class TransportTest(ParamikoTest):
def run(self):
try:
for i in xrange(1, 1+self.iterations):
for i in range(1, 1+self.iterations):
if self.done_event.isSet():
break
self.watchdog_event.set()
@ -706,7 +707,7 @@ class TransportTest(ParamikoTest):
# Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it
# before responding to the incoming MSG_KEXINIT.
m2 = Message()
m2.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST))
m2.add_byte(cMSG_CHANNEL_WINDOW_ADJUST)
m2.add_int(chan.remote_chanid)
m2.add_int(1) # bytes to add
self._send_message(m2)

View File

@ -21,15 +21,15 @@ Some unit tests for utility functions.
"""
from binascii import hexlify
import cStringIO
import errno
import os
import unittest
from Crypto.Hash import SHA
import paramiko.util
from paramiko.util import lookup_ssh_host_config as host_config
from paramiko.py3compat import StringIO, byte_ord, b
from util import ParamikoTest
from tests.util import ParamikoTest
test_config_file = """\
Host *
@ -65,7 +65,7 @@ class UtilTest(ParamikoTest):
"""
verify that all the classes can be imported from paramiko.
"""
symbols = globals().keys()
symbols = list(globals().keys())
self.assertTrue('Transport' in symbols)
self.assertTrue('SSHClient' in symbols)
self.assertTrue('MissingHostKeyPolicy' in symbols)
@ -101,9 +101,9 @@ class UtilTest(ParamikoTest):
def test_2_parse_config(self):
global test_config_file
f = cStringIO.StringIO(test_config_file)
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
self.assertEquals(config._config,
self.assertEqual(config._config,
[{'host': ['*'], 'config': {}}, {'host': ['*'], 'config': {'identityfile': ['~/.ssh/id_rsa'], 'user': 'robey'}},
{'host': ['*.example.com'], 'config': {'user': 'bjork', 'port': '3333'}},
{'host': ['*'], 'config': {'crazy': 'something dumb '}},
@ -111,7 +111,7 @@ class UtilTest(ParamikoTest):
def test_3_host_config(self):
global test_config_file
f = cStringIO.StringIO(test_config_file)
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
for host, values in {
@ -131,27 +131,26 @@ class UtilTest(ParamikoTest):
hostname=host,
identityfile=[os.path.expanduser("~/.ssh/id_rsa")]
)
self.assertEquals(
self.assertEqual(
paramiko.util.lookup_ssh_host_config(host, config),
values
)
def test_4_generate_key_bytes(self):
x = paramiko.util.generate_key_bytes(SHA, 'ABCDEFGH', 'This is my secret passphrase.', 64)
hex = ''.join(['%02x' % ord(c) for c in x])
self.assertEquals(hex, '9110e2f6793b69363e58173e9436b13a5a4b339005741d5c680e505f57d871347b4239f14fb5c46e857d5e100424873ba849ac699cea98d729e57b3e84378e8b')
x = paramiko.util.generate_key_bytes(SHA, b'ABCDEFGH', 'This is my secret passphrase.', 64)
hex = ''.join(['%02x' % byte_ord(c) for c in x])
self.assertEqual(hex, '9110e2f6793b69363e58173e9436b13a5a4b339005741d5c680e505f57d871347b4239f14fb5c46e857d5e100424873ba849ac699cea98d729e57b3e84378e8b')
def test_5_host_keys(self):
f = open('hostfile.temp', 'w')
f.write(test_hosts_file)
f.close()
with open('hostfile.temp', 'w') as f:
f.write(test_hosts_file)
try:
hostdict = paramiko.util.load_host_keys('hostfile.temp')
self.assertEquals(2, len(hostdict))
self.assertEquals(1, len(hostdict.values()[0]))
self.assertEquals(1, len(hostdict.values()[1]))
self.assertEqual(2, len(hostdict))
self.assertEqual(1, len(list(hostdict.values())[0]))
self.assertEqual(1, len(list(hostdict.values())[1]))
fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper()
self.assertEquals('E6684DB30E109B67B70FF1DC5C7F1363', fp)
self.assertEqual(b'E6684DB30E109B67B70FF1DC5C7F1363', fp)
finally:
os.unlink('hostfile.temp')
@ -159,7 +158,7 @@ class UtilTest(ParamikoTest):
from paramiko.common import rng
# just verify that we can pull out 32 bytes and not get an exception.
x = rng.read(32)
self.assertEquals(len(x), 32)
self.assertEqual(len(x), 32)
def test_7_host_config_expose_issue_33(self):
test_config_file = """
@ -172,16 +171,16 @@ Host *.example.com
Host *
Port 3333
"""
f = cStringIO.StringIO(test_config_file)
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
host = 'www13.example.com'
self.assertEquals(
self.assertEqual(
paramiko.util.lookup_ssh_host_config(host, config),
{'hostname': host, 'port': '22'}
)
def test_8_eintr_retry(self):
self.assertEquals('foo', paramiko.util.retry_on_signal(lambda: 'foo'))
self.assertEqual('foo', paramiko.util.retry_on_signal(lambda: 'foo'))
# Variables that are set by raises_intr
intr_errors_remaining = [3]
@ -192,8 +191,8 @@ Host *
intr_errors_remaining[0] -= 1
raise IOError(errno.EINTR, 'file', 'interrupted system call')
self.assertTrue(paramiko.util.retry_on_signal(raises_intr) is None)
self.assertEquals(0, intr_errors_remaining[0])
self.assertEquals(4, call_count[0])
self.assertEqual(0, intr_errors_remaining[0])
self.assertEqual(4, call_count[0])
def raises_ioerror_not_eintr():
raise IOError(errno.ENOENT, 'file', 'file not found')
@ -216,10 +215,10 @@ Host space-delimited
Host equals-delimited
ProxyCommand=foo bar=biz baz
"""
f = cStringIO.StringIO(conf)
f = StringIO(conf)
config = paramiko.util.parse_ssh_config(f)
for host in ('space-delimited', 'equals-delimited'):
self.assertEquals(
self.assertEqual(
host_config(host, config)['proxycommand'],
'foo bar=biz baz'
)
@ -228,7 +227,7 @@ Host equals-delimited
"""
ProxyCommand should perform interpolation on the value
"""
config = paramiko.util.parse_ssh_config(cStringIO.StringIO("""
config = paramiko.util.parse_ssh_config(StringIO("""
Host specific
Port 37
ProxyCommand host %h port %p lol
@ -245,7 +244,7 @@ Host *
('specific', "host specific port 37 lol"),
('portonly', "host portonly port 155"),
):
self.assertEquals(
self.assertEqual(
host_config(host, config)['proxycommand'],
val
)
@ -264,10 +263,10 @@ Host www13.*
Host *
Port 3333
"""
f = cStringIO.StringIO(test_config_file)
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
host = 'www13.example.com'
self.assertEquals(
self.assertEqual(
paramiko.util.lookup_ssh_host_config(host, config),
{'hostname': host, 'port': '8080'}
)
@ -293,9 +292,9 @@ ProxyCommand foo=bar:%h-%p
'foo=bar:proxy-without-equal-divisor-22'}
}.items():
f = cStringIO.StringIO(test_config_file)
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
self.assertEquals(
self.assertEqual(
paramiko.util.lookup_ssh_host_config(host, config),
values
)
@ -323,9 +322,9 @@ IdentityFile id_dsa22
'identityfile': ['id_dsa0', 'id_dsa1', 'id_dsa22']}
}.items():
f = cStringIO.StringIO(test_config_file)
f = StringIO(test_config_file)
config = paramiko.util.parse_ssh_config(f)
self.assertEquals(
self.assertEqual(
paramiko.util.lookup_ssh_host_config(host, config),
values
)
@ -338,5 +337,5 @@ IdentityFile id_dsa22
AddressFamily inet
IdentityFile something_%l_using_fqdn
"""
config = paramiko.util.parse_ssh_config(cStringIO.StringIO(test_config))
config = paramiko.util.parse_ssh_config(StringIO(test_config))
assert config.lookup('meh') # will die during lookup() if bug regresses

View File

@ -1,5 +1,8 @@
import os
import unittest
root_path = os.path.dirname(os.path.realpath(__file__))
class ParamikoTest(unittest.TestCase):
# for Python 2.3 and below
@ -8,3 +11,7 @@ class ParamikoTest(unittest.TestCase):
if not hasattr(unittest.TestCase, 'assertFalse'):
assertFalse = unittest.TestCase.failIf
def test_path(filename):
return os.path.join(root_path, filename)

View File

@ -1,5 +1,5 @@
[tox]
envlist = py25,py26,py27
envlist = py25,py26,py27,py32,py33
[testenv]
commands = pip install --use-mirrors -q -r tox-requirements.txt