Changes inspired by the nischu7 branch

This commit is contained in:
Scott Maxwell 2013-11-01 09:49:52 -07:00
parent 06b866cf40
commit 9662a7f779
20 changed files with 68 additions and 57 deletions

View File

@ -245,7 +245,7 @@ class AuthHandler (object):
m.add_byte(cMSG_USERAUTH_INFO_REQUEST) m.add_byte(cMSG_USERAUTH_INFO_REQUEST)
m.add_string(q.name) m.add_string(q.name)
m.add_string(q.instructions) m.add_string(q.instructions)
m.add_string('') m.add_string(bytes())
m.add_int(len(q.prompts)) m.add_int(len(q.prompts))
for p in q.prompts: for p in q.prompts:
m.add_string(p[0]) m.add_string(p[0])
@ -375,7 +375,7 @@ class AuthHandler (object):
def _parse_userauth_banner(self, m): def _parse_userauth_banner(self, m):
banner = m.get_string() banner = m.get_string()
lang = m.get_string() lang = m.get_string()
self.transport._log(INFO, 'Auth banner: ' + banner) self.transport._log(INFO, 'Auth banner: %s' % banner)
# who cares. # who cares.
def _parse_userauth_info_request(self, m): def _parse_userauth_info_request(self, m):

View File

@ -113,9 +113,9 @@ class BER(object):
def encode(self, x): def encode(self, x):
if type(x) is bool: if type(x) is bool:
if x: if x:
self.encode_tlv(1, '\xff') self.encode_tlv(1, max_byte)
else: else:
self.encode_tlv(1, '\x00') self.encode_tlv(1, zero_byte)
elif (type(x) is int) or (type(x) is long): elif (type(x) is int) or (type(x) is long):
self.encode_tlv(2, util.deflate_long(x)) self.encode_tlv(2, util.deflate_long(x))
elif type(x) is str: elif type(x) is str:

View File

@ -157,7 +157,7 @@ class Channel (object):
m.add_int(height) m.add_int(height)
m.add_int(width_pixels) m.add_int(width_pixels)
m.add_int(height_pixels) m.add_int(height_pixels)
m.add_string('') m.add_string(bytes())
self._event_pending() self._event_pending()
self.transport._send_user_message(m) self.transport._send_user_message(m)
self._wait_for_event() self._wait_for_event()
@ -477,7 +477,7 @@ class Channel (object):
@since: 1.1 @since: 1.1
""" """
data = '' data = bytes()
self.lock.acquire() self.lock.acquire()
try: try:
old = self.combine_stderr old = self.combine_stderr

View File

@ -148,8 +148,6 @@ def asbytes(s):
xffffffff = long(0xffffffff) xffffffff = long(0xffffffff)
x80000000 = long(0x80000000) x80000000 = long(0x80000000)
long_zero = long(0)
long_one = long(1)
o666 = 438 o666 = 438
o660 = 432 o660 = 432
o644 = 420 o644 = 420

View File

@ -110,9 +110,9 @@ class DSSKey (PKey):
rstr = util.deflate_long(r, 0) rstr = util.deflate_long(r, 0)
sstr = util.deflate_long(s, 0) sstr = util.deflate_long(s, 0)
if len(rstr) < 20: if len(rstr) < 20:
rstr = '\x00' * (20 - len(rstr)) + rstr rstr = zero_byte * (20 - len(rstr)) + rstr
if len(sstr) < 20: if len(sstr) < 20:
sstr = '\x00' * (20 - len(sstr)) + sstr sstr = zero_byte * (20 - len(sstr)) + sstr
m.add_string(rstr + sstr) m.add_string(rstr + sstr)
return m return m

View File

@ -164,7 +164,7 @@ class ECDSAKey (PKey):
s, padding = der.remove_sequence(data) s, padding = der.remove_sequence(data)
if padding: if padding:
if padding not in self.ALLOWED_PADDINGS: if padding not in self.ALLOWED_PADDINGS:
raise ValueError("weird padding: %s" % (binascii.hexlify(data))) raise ValueError("weird padding: %s" % u(binascii.hexlify(data)))
data = data[:-len(padding)] data = data[:-len(padding)]
key = SigningKey.from_der(data) key = SigningKey.from_der(data)
self.signing_key = key self.signing_key = key

View File

@ -36,6 +36,8 @@ c_MSG_KEXDH_INIT, c_MSG_KEXDH_REPLY = [byte_chr(c) for c in range(30, 32)]
P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF
G = 2 G = 2
b7fffffffffffffff = byte_chr(0x7f) + max_byte * 7
b0000000000000000 = zero_byte * 8
class KexGroup1(object): class KexGroup1(object):
@ -43,9 +45,9 @@ class KexGroup1(object):
def __init__(self, transport): def __init__(self, transport):
self.transport = transport self.transport = transport
self.x = long_zero self.x = long(0)
self.e = long_zero self.e = long(0)
self.f = long_zero self.f = long(0)
def start_kex(self): def start_kex(self):
self._generate_x() self._generate_x()
@ -82,8 +84,8 @@ class KexGroup1(object):
while 1: while 1:
x_bytes = self.transport.rng.read(128) x_bytes = self.transport.rng.read(128)
x_bytes = byte_mask(x_bytes[0], 0x7f) + x_bytes[1:] x_bytes = byte_mask(x_bytes[0], 0x7f) + x_bytes[1:]
if (x_bytes[:8] != '\x7F\xFF\xFF\xFF\xFF\xFF\xFF\xFF') and \ if (x_bytes[:8] != b7fffffffffffffff) and \
(x_bytes[:8] != '\x00\x00\x00\x00\x00\x00\x00\x00'): (x_bytes[:8] != b0000000000000000):
break break
self.x = util.inflate_long(x_bytes) self.x = util.inflate_long(x_bytes)

View File

@ -91,8 +91,8 @@ class Packetizer (object):
self.__mac_key_in = bytes() self.__mac_key_in = bytes()
self.__compress_engine_out = None self.__compress_engine_out = None
self.__compress_engine_in = None self.__compress_engine_in = None
self.__sequence_number_out = long_zero self.__sequence_number_out = 0
self.__sequence_number_in = long_zero self.__sequence_number_in = 0
# lock around outbound writes (packet computation) # lock around outbound writes (packet computation)
self.__write_lock = threading.RLock() self.__write_lock = threading.RLock()
@ -153,6 +153,7 @@ class Packetizer (object):
def close(self): def close(self):
self.__closed = True self.__closed = True
self.__socket.close()
def set_hexdump(self, hexdump): def set_hexdump(self, hexdump):
self.__dump_packets = hexdump self.__dump_packets = hexdump

View File

@ -125,7 +125,7 @@ class ModulusPack (object):
f.close() f.close()
def get_modulus(self, min, prefer, max): def get_modulus(self, min, prefer, max):
bitsizes = sorted(self.pack.keys(), key=hash) bitsizes = sorted(self.pack.keys())
if len(bitsizes) == 0: if len(bitsizes) == 0:
raise SSHException('no moduli available') raise SSHException('no moduli available')
good = -1 good = -1

View File

@ -32,6 +32,8 @@ from paramiko.ber import BER, BERException
from paramiko.pkey import PKey from paramiko.pkey import PKey
from paramiko.ssh_exception import SSHException from paramiko.ssh_exception import SSHException
SHA1_DIGESTINFO = unhexlify(b('3021300906052b0e03021a05000414'))
class RSAKey (PKey): class RSAKey (PKey):
""" """
@ -92,7 +94,7 @@ class RSAKey (PKey):
def sign_ssh_data(self, rpool, data): def sign_ssh_data(self, rpool, data):
digest = SHA.new(data).digest() digest = SHA.new(data).digest()
rsa = RSA.construct((long(self.n), long(self.e), long(self.d))) rsa = RSA.construct((long(self.n), long(self.e), long(self.d)))
sig = util.deflate_long(rsa.sign(self._pkcs1imify(digest), '')[0], 0) sig = util.deflate_long(rsa.sign(self._pkcs1imify(digest), bytes())[0], 0)
m = Message() m = Message()
m.add_string('ssh-rsa') m.add_string('ssh-rsa')
m.add_string(sig) m.add_string(sig)
@ -158,7 +160,6 @@ class RSAKey (PKey):
turn a 20-byte SHA1 hash into a blob of data as large as the key's N, turn a 20-byte SHA1 hash into a blob of data as large as the key's N,
using PKCS1's \"emsa-pkcs1-v1_5\" encoding. totally bizarre. using PKCS1's \"emsa-pkcs1-v1_5\" encoding. totally bizarre.
""" """
SHA1_DIGESTINFO = unhexlify(b('3021300906052b0e03021a05000414'))
size = len(util.deflate_long(self.n, 0)) size = len(util.deflate_long(self.n, 0))
filler = max_byte * (size - len(SHA1_DIGESTINFO) - len(data) - 3) filler = max_byte * (size - len(SHA1_DIGESTINFO) - len(data) - 3)
return zero_byte + one_byte + filler + zero_byte + SHA1_DIGESTINFO + data return zero_byte + one_byte + filler + zero_byte + SHA1_DIGESTINFO + data

View File

@ -186,4 +186,4 @@ class BaseSFTP (object):
t = byte_ord(data[0]) t = byte_ord(data[0])
#self._log(DEBUG2, 'read: %s (len=%d)' % (CMD_NAMES.get(t), '0x%02x' % t, len(data)-1)) #self._log(DEBUG2, 'read: %s (len=%d)' % (CMD_NAMES.get(t), '0x%02x' % t, len(data)-1))
return t, data[1:] return t, data[1:]
return 0, '' return 0, bytes()

View File

@ -94,7 +94,7 @@ class SFTPFile (BufferedFile):
k = [i for i in self._prefetch_reads if i[0] <= offset] k = [i for i in self._prefetch_reads if i[0] <= offset]
if len(k) == 0: if len(k) == 0:
return False return False
k.sort(key=hash) k.sort(key=lambda x: x[0])
buf_offset, buf_size = k[-1] buf_offset, buf_size = k[-1]
if buf_offset + buf_size <= offset: if buf_offset + buf_size <= offset:
# prefetch request ends before this one begins # prefetch request ends before this one begins

View File

@ -272,7 +272,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
self._send_status(request_number, SFTP_FAILURE, 'Block size too small') self._send_status(request_number, SFTP_FAILURE, 'Block size too small')
return return
sum_out = '' sum_out = bytes()
offset = start offset = start
while offset < start + length: while offset < start + length:
blocklen = min(block_size, start + length - offset) blocklen = min(block_size, start + length - offset)
@ -342,7 +342,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return return
data = self.file_table[handle].read(offset, length) data = self.file_table[handle].read(offset, length)
if type(data) is str: if isinstance(data, (bytes_types, string_types)):
if len(data) == 0: if len(data) == 0:
self._send_status(request_number, SFTP_EOF) self._send_status(request_number, SFTP_EOF)
else: else:
@ -420,7 +420,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
elif t == CMD_READLINK: elif t == CMD_READLINK:
path = msg.get_text() path = msg.get_text()
resp = self.server.readlink(path) resp = self.server.readlink(path)
if isinstance(resp, string_types): if isinstance(resp, (bytes_types, string_types)):
self._response(request_number, CMD_NAME, 1, resp, '', SFTPAttributes()) self._response(request_number, CMD_NAME, 1, resp, '', SFTPAttributes())
else: else:
self._send_status(request_number, resp) self._send_status(request_number, resp)

View File

@ -80,6 +80,7 @@ class SecurityOptions (object):
tuple to one of the fields, C{TypeError} will be raised. tuple to one of the fields, C{TypeError} will be raised.
""" """
#__slots__ = [ 'ciphers', 'digests', 'key_types', 'kex', 'compression', '_transport' ] #__slots__ = [ 'ciphers', 'digests', 'key_types', 'kex', 'compression', '_transport' ]
__slots__ = '_transport'
def __init__(self, transport): def __init__(self, transport):
self._transport = transport self._transport = transport
@ -402,6 +403,7 @@ class Transport (threading.Thread):
@since: 1.5.3 @since: 1.5.3
""" """
self.sock.close()
self.close() self.close()
def get_security_options(self): def get_security_options(self):
@ -691,7 +693,7 @@ class Transport (threading.Thread):
""" """
return self.open_channel('auth-agent@openssh.com') return self.open_channel('auth-agent@openssh.com')
def open_forwarded_tcpip_channel(self, src_addr_port, dest_addr_port): def open_forwarded_tcpip_channel(self, src_addr, dest_addr):
""" """
Request a new channel back to the client, of type C{"forwarded-tcpip"}. Request a new channel back to the client, of type C{"forwarded-tcpip"}.
This is used after a client has requested port forwarding, for sending This is used after a client has requested port forwarding, for sending
@ -702,9 +704,7 @@ class Transport (threading.Thread):
@param dest_addr: local (server) connected address @param dest_addr: local (server) connected address
@param dest_port: local (server) connected port @param dest_port: local (server) connected port
""" """
src_addr, src_port = src_addr_port return self.open_channel('forwarded-tcpip', dest_addr, src_addr)
dest_addr, dest_port = dest_addr_port
return self.open_channel('forwarded-tcpip', (dest_addr, dest_port), (src_addr, src_port))
def open_channel(self, kind, dest_addr=None, src_addr=None): def open_channel(self, kind, dest_addr=None, src_addr=None):
""" """
@ -805,7 +805,6 @@ class Transport (threading.Thread):
""" """
if not self.active: if not self.active:
raise SSHException('SSH session not active') raise SSHException('SSH session not active')
address = address
port = int(port) port = int(port)
response = self.global_request('tcpip-forward', (address, port), wait=True) response = self.global_request('tcpip-forward', (address, port), wait=True)
if response is None: if response is None:
@ -1642,7 +1641,7 @@ class Transport (threading.Thread):
self.completion_event.set() self.completion_event.set()
if self.auth_handler is not None: if self.auth_handler is not None:
self.auth_handler.abort() self.auth_handler.abort()
for event in list(self.channel_events.values()): for event in self.channel_events.values():
event.set() event.set()
try: try:
self.lock.acquire() self.lock.acquire()

View File

@ -48,7 +48,7 @@ if sys.version_info < (2,3):
def inflate_long(s, always_positive=False): def inflate_long(s, always_positive=False):
"turns a normalized byte string into a long-int (adapted from Crypto.Util.number)" "turns a normalized byte string into a long-int (adapted from Crypto.Util.number)"
out = long_zero out = long(0)
negative = 0 negative = 0
if not always_positive and (len(s) > 0) and (byte_ord(s[0]) >= 0x80): if not always_positive and (len(s) > 0) and (byte_ord(s[0]) >= 0x80):
negative = 1 negative = 1
@ -60,7 +60,7 @@ def inflate_long(s, always_positive=False):
for i in range(0, len(s), 4): for i in range(0, len(s), 4):
out = (out << 32) + struct.unpack('>I', s[i:i+4])[0] out = (out << 32) + struct.unpack('>I', s[i:i+4])[0]
if negative: if negative:
out -= (long_one << (8 * len(s))) out -= (long(1) << (8 * len(s)))
return out return out
deflate_zero = 0 if PY3 else zero_byte deflate_zero = 0 if PY3 else zero_byte
@ -128,6 +128,9 @@ def safe_string(s):
# ''.join([['%%%02X' % ord(c), c][(ord(c) >= 32) and (ord(c) <= 127)] for c in s]) # ''.join([['%%%02X' % ord(c), c][(ord(c) >= 32) and (ord(c) <= 127)] for c in s])
def bit_length(n): def bit_length(n):
try:
return n.bitlength()
except AttributeError:
norm = deflate_long(n, 0) norm = deflate_long(n, 0)
hbyte = byte_ord(norm[0]) hbyte = byte_ord(norm[0])
if hbyte == 0: if hbyte == 0:
@ -276,7 +279,7 @@ def retry_on_signal(function):
class Counter (object): class Counter (object):
"""Stateful counter for CTR mode crypto""" """Stateful counter for CTR mode crypto"""
def __init__(self, nbits, initial_value=long_one, overflow=long_zero): def __init__(self, nbits, initial_value=long(1), overflow=long(0)):
self.blocksize = nbits / 8 self.blocksize = nbits / 8
self.overflow = overflow self.overflow = overflow
# start with value - 1 so we don't have to store intermediate values when counting # start with value - 1 so we don't have to store intermediate values when counting
@ -300,6 +303,6 @@ class Counter (object):
self.value = array.array('c', zero_byte * (self.blocksize - len(x)) + x) self.value = array.array('c', zero_byte * (self.blocksize - len(x)) + x)
return self.value.tostring() return self.value.tostring()
def new(cls, nbits, initial_value=long_one, overflow=long_zero): def new(cls, nbits, initial_value=long(1), overflow=long(0)):
return cls(nbits, initial_value=initial_value, overflow=overflow) return cls(nbits, initial_value=initial_value, overflow=overflow)
new = classmethod(new) new = classmethod(new)

View File

@ -28,6 +28,7 @@ import threading
import array import array
import platform import platform
import ctypes.wintypes import ctypes.wintypes
from paramiko.util import *
from . import _winapi from . import _winapi
@ -82,7 +83,7 @@ def _query_pageant(msg):
with pymap: with pymap:
pymap.write(msg) pymap.write(msg)
# Create an array buffer containing the mapped filename # Create an array buffer containing the mapped filename
char_buffer = array.array("c", map_name + '\0') char_buffer = array.array("c", b(map_name) + zero_byte)
char_buffer_address, char_buffer_size = char_buffer.buffer_info() char_buffer_address, char_buffer_size = char_buffer.buffer_info()
# Create a string to use for the SendMessage function call # Create a string to use for the SendMessage function call
cds = COPYDATASTRUCT(_AGENT_COPYDATA_ID, char_buffer_size, cds = COPYDATASTRUCT(_AGENT_COPYDATA_ID, char_buffer_size,

View File

@ -59,7 +59,7 @@ class LoopSocket (object):
try: try:
if self.__mate is None: if self.__mate is None:
# EOF # EOF
return '' return bytes()
if len(self.__in_buffer) == 0: if len(self.__in_buffer) == 0:
self.__cv.wait(self.__timeout) self.__cv.wait(self.__timeout)
if len(self.__in_buffer) == 0: if len(self.__in_buffer) == 0:

View File

@ -207,8 +207,14 @@ class SSHClientTest (unittest.TestCase):
self.assert_(p() is not None) self.assert_(p() is not None)
del self.tc del self.tc
# hrm, sometimes p isn't cleared right away. why is that? # hrm, sometimes p isn't cleared right away. why is that?
st = time.time() #st = time.time()
while (time.time() - st < 5.0) and (p() is not None): #while (time.time() - st < 5.0) and (p() is not None):
time.sleep(0.1) # time.sleep(0.1)
# instead of dumbly waiting for the GC to collect, force a collection
# to see whether the SSHClient object is deallocated correctly
import gc
gc.collect()
self.assert_(p() is None) self.assert_(p() is None)

View File

@ -611,12 +611,12 @@ class SFTPTest (unittest.TestCase):
try: try:
f = sftp.open(FOLDER + '/kitty.txt', 'r') f = sftp.open(FOLDER + '/kitty.txt', 'r')
sum = f.check('sha1') sum = f.check('sha1')
self.assertEqual('91059CFC6615941378D413CB5ADAF4C5EB293402', hexlify(sum).upper()) self.assertEqual('91059CFC6615941378D413CB5ADAF4C5EB293402', u(hexlify(sum)).upper())
sum = f.check('md5', 0, 512) sum = f.check('md5', 0, 512)
self.assertEqual('93DE4788FCA28D471516963A1FE3856A', hexlify(sum).upper()) self.assertEqual('93DE4788FCA28D471516963A1FE3856A', u(hexlify(sum)).upper())
sum = f.check('md5', 0, 0, 510) sum = f.check('md5', 0, 0, 510)
self.assertEqual('EB3B45B8CD55A0707D99B177544A319F373183D241432BB2157AB9E46358C4AC90370B5CADE5D90336FC1716F90B36D6', self.assertEqual('EB3B45B8CD55A0707D99B177544A319F373183D241432BB2157AB9E46358C4AC90370B5CADE5D90336FC1716F90B36D6',
hexlify(sum).upper()) u(hexlify(sum)).upper())
f.close() f.close()
finally: finally:
sftp.unlink(FOLDER + '/kitty.txt') sftp.unlink(FOLDER + '/kitty.txt')

View File

@ -92,7 +92,7 @@ class BigSFTPTest (unittest.TestCase):
write a 1MB file with no buffering. write a 1MB file with no buffering.
""" """
sftp = get_sftp() sftp = get_sftp()
kblob = (1024 * 'x') kblob = (1024 * b('x'))
start = time.time() start = time.time()
try: try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w') f = sftp.open('%s/hongry.txt' % FOLDER, 'w')
@ -246,7 +246,7 @@ class BigSFTPTest (unittest.TestCase):
without using it, to verify that paramiko doesn't get confused. without using it, to verify that paramiko doesn't get confused.
""" """
sftp = get_sftp() sftp = get_sftp()
kblob = (1024 * 'x') kblob = (1024 * b('x'))
try: try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w') f = sftp.open('%s/hongry.txt' % FOLDER, 'w')
f.set_pipelined(True) f.set_pipelined(True)
@ -347,7 +347,7 @@ class BigSFTPTest (unittest.TestCase):
write a 1MB file, with no linefeeds, and a big buffer. write a 1MB file, with no linefeeds, and a big buffer.
""" """
sftp = get_sftp() sftp = get_sftp()
mblob = (1024 * 1024 * 'x') mblob = (1024 * 1024 * b('x'))
try: try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w', 128 * 1024) f = sftp.open('%s/hongry.txt' % FOLDER, 'w', 128 * 1024)
f.write(mblob) f.write(mblob)
@ -364,7 +364,7 @@ class BigSFTPTest (unittest.TestCase):
sftp = get_sftp() sftp = get_sftp()
t = sftp.sock.get_transport() t = sftp.sock.get_transport()
t.packetizer.REKEY_BYTES = 512 * 1024 t.packetizer.REKEY_BYTES = 512 * 1024
k32blob = (32 * 1024 * 'x') k32blob = (32 * 1024 * b('x'))
try: try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w', 128 * 1024) f = sftp.open('%s/hongry.txt' % FOLDER, 'w', 128 * 1024)
for i in range(32): for i in range(32):