From 9662a7f779636f0328263a81cdeb76af25802970 Mon Sep 17 00:00:00 2001 From: Scott Maxwell Date: Fri, 1 Nov 2013 09:49:52 -0700 Subject: [PATCH] Changes inspired by the nischu7 branch --- paramiko/auth_handler.py | 4 ++-- paramiko/ber.py | 4 ++-- paramiko/channel.py | 4 ++-- paramiko/common.py | 2 -- paramiko/dsskey.py | 4 ++-- paramiko/ecdsakey.py | 2 +- paramiko/kex_group1.py | 12 +++++++----- paramiko/packet.py | 5 +++-- paramiko/primes.py | 2 +- paramiko/rsakey.py | 5 +++-- paramiko/sftp.py | 2 +- paramiko/sftp_file.py | 2 +- paramiko/sftp_server.py | 6 +++--- paramiko/transport.py | 11 +++++------ paramiko/util.py | 29 ++++++++++++++++------------- paramiko/win_pageant.py | 3 ++- tests/loop.py | 2 +- tests/test_client.py | 12 +++++++++--- tests/test_sftp.py | 6 +++--- tests/test_sftp_big.py | 8 ++++---- 20 files changed, 68 insertions(+), 57 deletions(-) diff --git a/paramiko/auth_handler.py b/paramiko/auth_handler.py index 2a65355..83f27a1 100644 --- a/paramiko/auth_handler.py +++ b/paramiko/auth_handler.py @@ -245,7 +245,7 @@ class AuthHandler (object): m.add_byte(cMSG_USERAUTH_INFO_REQUEST) m.add_string(q.name) m.add_string(q.instructions) - m.add_string('') + m.add_string(bytes()) m.add_int(len(q.prompts)) for p in q.prompts: m.add_string(p[0]) @@ -375,7 +375,7 @@ class AuthHandler (object): def _parse_userauth_banner(self, m): banner = m.get_string() lang = m.get_string() - self.transport._log(INFO, 'Auth banner: ' + banner) + self.transport._log(INFO, 'Auth banner: %s' % banner) # who cares. def _parse_userauth_info_request(self, m): diff --git a/paramiko/ber.py b/paramiko/ber.py index f4d2acc..c4f3521 100644 --- a/paramiko/ber.py +++ b/paramiko/ber.py @@ -113,9 +113,9 @@ class BER(object): def encode(self, x): if type(x) is bool: if x: - self.encode_tlv(1, '\xff') + self.encode_tlv(1, max_byte) else: - self.encode_tlv(1, '\x00') + self.encode_tlv(1, zero_byte) elif (type(x) is int) or (type(x) is long): self.encode_tlv(2, util.deflate_long(x)) elif type(x) is str: diff --git a/paramiko/channel.py b/paramiko/channel.py index 6a8a798..9980fce 100644 --- a/paramiko/channel.py +++ b/paramiko/channel.py @@ -157,7 +157,7 @@ class Channel (object): m.add_int(height) m.add_int(width_pixels) m.add_int(height_pixels) - m.add_string('') + m.add_string(bytes()) self._event_pending() self.transport._send_user_message(m) self._wait_for_event() @@ -477,7 +477,7 @@ class Channel (object): @since: 1.1 """ - data = '' + data = bytes() self.lock.acquire() try: old = self.combine_stderr diff --git a/paramiko/common.py b/paramiko/common.py index 37d5ee8..476ebf5 100644 --- a/paramiko/common.py +++ b/paramiko/common.py @@ -148,8 +148,6 @@ def asbytes(s): xffffffff = long(0xffffffff) x80000000 = long(0x80000000) -long_zero = long(0) -long_one = long(1) o666 = 438 o660 = 432 o644 = 420 diff --git a/paramiko/dsskey.py b/paramiko/dsskey.py index 715f9f6..4c97b26 100644 --- a/paramiko/dsskey.py +++ b/paramiko/dsskey.py @@ -110,9 +110,9 @@ class DSSKey (PKey): rstr = util.deflate_long(r, 0) sstr = util.deflate_long(s, 0) if len(rstr) < 20: - rstr = '\x00' * (20 - len(rstr)) + rstr + rstr = zero_byte * (20 - len(rstr)) + rstr if len(sstr) < 20: - sstr = '\x00' * (20 - len(sstr)) + sstr + sstr = zero_byte * (20 - len(sstr)) + sstr m.add_string(rstr + sstr) return m diff --git a/paramiko/ecdsakey.py b/paramiko/ecdsakey.py index 8585e6f..5f9dff2 100644 --- a/paramiko/ecdsakey.py +++ b/paramiko/ecdsakey.py @@ -164,7 +164,7 @@ class ECDSAKey (PKey): s, padding = der.remove_sequence(data) if padding: if padding not in self.ALLOWED_PADDINGS: - raise ValueError("weird padding: %s" % (binascii.hexlify(data))) + raise ValueError("weird padding: %s" % u(binascii.hexlify(data))) data = data[:-len(padding)] key = SigningKey.from_der(data) self.signing_key = key diff --git a/paramiko/kex_group1.py b/paramiko/kex_group1.py index ea452b3..05693a1 100644 --- a/paramiko/kex_group1.py +++ b/paramiko/kex_group1.py @@ -36,6 +36,8 @@ c_MSG_KEXDH_INIT, c_MSG_KEXDH_REPLY = [byte_chr(c) for c in range(30, 32)] P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF G = 2 +b7fffffffffffffff = byte_chr(0x7f) + max_byte * 7 +b0000000000000000 = zero_byte * 8 class KexGroup1(object): @@ -43,9 +45,9 @@ class KexGroup1(object): def __init__(self, transport): self.transport = transport - self.x = long_zero - self.e = long_zero - self.f = long_zero + self.x = long(0) + self.e = long(0) + self.f = long(0) def start_kex(self): self._generate_x() @@ -82,8 +84,8 @@ class KexGroup1(object): while 1: x_bytes = self.transport.rng.read(128) x_bytes = byte_mask(x_bytes[0], 0x7f) + x_bytes[1:] - if (x_bytes[:8] != '\x7F\xFF\xFF\xFF\xFF\xFF\xFF\xFF') and \ - (x_bytes[:8] != '\x00\x00\x00\x00\x00\x00\x00\x00'): + if (x_bytes[:8] != b7fffffffffffffff) and \ + (x_bytes[:8] != b0000000000000000): break self.x = util.inflate_long(x_bytes) diff --git a/paramiko/packet.py b/paramiko/packet.py index fa5ceff..37028bb 100644 --- a/paramiko/packet.py +++ b/paramiko/packet.py @@ -91,8 +91,8 @@ class Packetizer (object): self.__mac_key_in = bytes() self.__compress_engine_out = None self.__compress_engine_in = None - self.__sequence_number_out = long_zero - self.__sequence_number_in = long_zero + self.__sequence_number_out = 0 + self.__sequence_number_in = 0 # lock around outbound writes (packet computation) self.__write_lock = threading.RLock() @@ -153,6 +153,7 @@ class Packetizer (object): def close(self): self.__closed = True + self.__socket.close() def set_hexdump(self, hexdump): self.__dump_packets = hexdump diff --git a/paramiko/primes.py b/paramiko/primes.py index 144454a..4db6d52 100644 --- a/paramiko/primes.py +++ b/paramiko/primes.py @@ -125,7 +125,7 @@ class ModulusPack (object): f.close() def get_modulus(self, min, prefer, max): - bitsizes = sorted(self.pack.keys(), key=hash) + bitsizes = sorted(self.pack.keys()) if len(bitsizes) == 0: raise SSHException('no moduli available') good = -1 diff --git a/paramiko/rsakey.py b/paramiko/rsakey.py index b4222a3..0a27119 100644 --- a/paramiko/rsakey.py +++ b/paramiko/rsakey.py @@ -32,6 +32,8 @@ from paramiko.ber import BER, BERException from paramiko.pkey import PKey from paramiko.ssh_exception import SSHException +SHA1_DIGESTINFO = unhexlify(b('3021300906052b0e03021a05000414')) + class RSAKey (PKey): """ @@ -92,7 +94,7 @@ class RSAKey (PKey): def sign_ssh_data(self, rpool, data): digest = SHA.new(data).digest() rsa = RSA.construct((long(self.n), long(self.e), long(self.d))) - sig = util.deflate_long(rsa.sign(self._pkcs1imify(digest), '')[0], 0) + sig = util.deflate_long(rsa.sign(self._pkcs1imify(digest), bytes())[0], 0) m = Message() m.add_string('ssh-rsa') m.add_string(sig) @@ -158,7 +160,6 @@ class RSAKey (PKey): turn a 20-byte SHA1 hash into a blob of data as large as the key's N, using PKCS1's \"emsa-pkcs1-v1_5\" encoding. totally bizarre. """ - SHA1_DIGESTINFO = unhexlify(b('3021300906052b0e03021a05000414')) size = len(util.deflate_long(self.n, 0)) filler = max_byte * (size - len(SHA1_DIGESTINFO) - len(data) - 3) return zero_byte + one_byte + filler + zero_byte + SHA1_DIGESTINFO + data diff --git a/paramiko/sftp.py b/paramiko/sftp.py index 4186460..3e05de9 100644 --- a/paramiko/sftp.py +++ b/paramiko/sftp.py @@ -186,4 +186,4 @@ class BaseSFTP (object): t = byte_ord(data[0]) #self._log(DEBUG2, 'read: %s (len=%d)' % (CMD_NAMES.get(t), '0x%02x' % t, len(data)-1)) return t, data[1:] - return 0, '' + return 0, bytes() diff --git a/paramiko/sftp_file.py b/paramiko/sftp_file.py index 754d038..480c371 100644 --- a/paramiko/sftp_file.py +++ b/paramiko/sftp_file.py @@ -94,7 +94,7 @@ class SFTPFile (BufferedFile): k = [i for i in self._prefetch_reads if i[0] <= offset] if len(k) == 0: return False - k.sort(key=hash) + k.sort(key=lambda x: x[0]) buf_offset, buf_size = k[-1] if buf_offset + buf_size <= offset: # prefetch request ends before this one begins diff --git a/paramiko/sftp_server.py b/paramiko/sftp_server.py index 19f1ab9..f591a21 100644 --- a/paramiko/sftp_server.py +++ b/paramiko/sftp_server.py @@ -272,7 +272,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): self._send_status(request_number, SFTP_FAILURE, 'Block size too small') return - sum_out = '' + sum_out = bytes() offset = start while offset < start + length: blocklen = min(block_size, start + length - offset) @@ -342,7 +342,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') return data = self.file_table[handle].read(offset, length) - if type(data) is str: + if isinstance(data, (bytes_types, string_types)): if len(data) == 0: self._send_status(request_number, SFTP_EOF) else: @@ -420,7 +420,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): elif t == CMD_READLINK: path = msg.get_text() resp = self.server.readlink(path) - if isinstance(resp, string_types): + if isinstance(resp, (bytes_types, string_types)): self._response(request_number, CMD_NAME, 1, resp, '', SFTPAttributes()) else: self._send_status(request_number, resp) diff --git a/paramiko/transport.py b/paramiko/transport.py index 46e0fe8..2008ecb 100644 --- a/paramiko/transport.py +++ b/paramiko/transport.py @@ -80,6 +80,7 @@ class SecurityOptions (object): tuple to one of the fields, C{TypeError} will be raised. """ #__slots__ = [ 'ciphers', 'digests', 'key_types', 'kex', 'compression', '_transport' ] + __slots__ = '_transport' def __init__(self, transport): self._transport = transport @@ -402,6 +403,7 @@ class Transport (threading.Thread): @since: 1.5.3 """ + self.sock.close() self.close() def get_security_options(self): @@ -691,7 +693,7 @@ class Transport (threading.Thread): """ return self.open_channel('auth-agent@openssh.com') - def open_forwarded_tcpip_channel(self, src_addr_port, dest_addr_port): + def open_forwarded_tcpip_channel(self, src_addr, dest_addr): """ Request a new channel back to the client, of type C{"forwarded-tcpip"}. This is used after a client has requested port forwarding, for sending @@ -702,9 +704,7 @@ class Transport (threading.Thread): @param dest_addr: local (server) connected address @param dest_port: local (server) connected port """ - src_addr, src_port = src_addr_port - dest_addr, dest_port = dest_addr_port - return self.open_channel('forwarded-tcpip', (dest_addr, dest_port), (src_addr, src_port)) + return self.open_channel('forwarded-tcpip', dest_addr, src_addr) def open_channel(self, kind, dest_addr=None, src_addr=None): """ @@ -805,7 +805,6 @@ class Transport (threading.Thread): """ if not self.active: raise SSHException('SSH session not active') - address = address port = int(port) response = self.global_request('tcpip-forward', (address, port), wait=True) if response is None: @@ -1642,7 +1641,7 @@ class Transport (threading.Thread): self.completion_event.set() if self.auth_handler is not None: self.auth_handler.abort() - for event in list(self.channel_events.values()): + for event in self.channel_events.values(): event.set() try: self.lock.acquire() diff --git a/paramiko/util.py b/paramiko/util.py index 7844fc6..b2ac3f5 100644 --- a/paramiko/util.py +++ b/paramiko/util.py @@ -48,7 +48,7 @@ if sys.version_info < (2,3): def inflate_long(s, always_positive=False): "turns a normalized byte string into a long-int (adapted from Crypto.Util.number)" - out = long_zero + out = long(0) negative = 0 if not always_positive and (len(s) > 0) and (byte_ord(s[0]) >= 0x80): negative = 1 @@ -60,7 +60,7 @@ def inflate_long(s, always_positive=False): for i in range(0, len(s), 4): out = (out << 32) + struct.unpack('>I', s[i:i+4])[0] if negative: - out -= (long_one << (8 * len(s))) + out -= (long(1) << (8 * len(s))) return out deflate_zero = 0 if PY3 else zero_byte @@ -128,15 +128,18 @@ def safe_string(s): # ''.join([['%%%02X' % ord(c), c][(ord(c) >= 32) and (ord(c) <= 127)] for c in s]) def bit_length(n): - norm = deflate_long(n, 0) - hbyte = byte_ord(norm[0]) - if hbyte == 0: - return 1 - bitlen = len(norm) * 8 - while not (hbyte & 0x80): - hbyte <<= 1 - bitlen -= 1 - return bitlen + try: + return n.bitlength() + except AttributeError: + norm = deflate_long(n, 0) + hbyte = byte_ord(norm[0]) + if hbyte == 0: + return 1 + bitlen = len(norm) * 8 + while not (hbyte & 0x80): + hbyte <<= 1 + bitlen -= 1 + return bitlen def tb_strings(): return ''.join(traceback.format_exception(*sys.exc_info())).split('\n') @@ -276,7 +279,7 @@ def retry_on_signal(function): class Counter (object): """Stateful counter for CTR mode crypto""" - def __init__(self, nbits, initial_value=long_one, overflow=long_zero): + def __init__(self, nbits, initial_value=long(1), overflow=long(0)): self.blocksize = nbits / 8 self.overflow = overflow # start with value - 1 so we don't have to store intermediate values when counting @@ -300,6 +303,6 @@ class Counter (object): self.value = array.array('c', zero_byte * (self.blocksize - len(x)) + x) return self.value.tostring() - def new(cls, nbits, initial_value=long_one, overflow=long_zero): + def new(cls, nbits, initial_value=long(1), overflow=long(0)): return cls(nbits, initial_value=initial_value, overflow=overflow) new = classmethod(new) diff --git a/paramiko/win_pageant.py b/paramiko/win_pageant.py index de1cd64..c0d0f4a 100644 --- a/paramiko/win_pageant.py +++ b/paramiko/win_pageant.py @@ -28,6 +28,7 @@ import threading import array import platform import ctypes.wintypes +from paramiko.util import * from . import _winapi @@ -82,7 +83,7 @@ def _query_pageant(msg): with pymap: pymap.write(msg) # Create an array buffer containing the mapped filename - char_buffer = array.array("c", map_name + '\0') + char_buffer = array.array("c", b(map_name) + zero_byte) char_buffer_address, char_buffer_size = char_buffer.buffer_info() # Create a string to use for the SendMessage function call cds = COPYDATASTRUCT(_AGENT_COPYDATA_ID, char_buffer_size, diff --git a/tests/loop.py b/tests/loop.py index b2c73dd..6e933f8 100644 --- a/tests/loop.py +++ b/tests/loop.py @@ -59,7 +59,7 @@ class LoopSocket (object): try: if self.__mate is None: # EOF - return '' + return bytes() if len(self.__in_buffer) == 0: self.__cv.wait(self.__timeout) if len(self.__in_buffer) == 0: diff --git a/tests/test_client.py b/tests/test_client.py index a8d0463..e6d1069 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -207,8 +207,14 @@ class SSHClientTest (unittest.TestCase): self.assert_(p() is not None) del self.tc # hrm, sometimes p isn't cleared right away. why is that? - st = time.time() - while (time.time() - st < 5.0) and (p() is not None): - time.sleep(0.1) + #st = time.time() + #while (time.time() - st < 5.0) and (p() is not None): + # time.sleep(0.1) + + # instead of dumbly waiting for the GC to collect, force a collection + # to see whether the SSHClient object is deallocated correctly + import gc + gc.collect() + self.assert_(p() is None) diff --git a/tests/test_sftp.py b/tests/test_sftp.py index 460e04c..20f68d8 100755 --- a/tests/test_sftp.py +++ b/tests/test_sftp.py @@ -611,12 +611,12 @@ class SFTPTest (unittest.TestCase): try: f = sftp.open(FOLDER + '/kitty.txt', 'r') sum = f.check('sha1') - self.assertEqual('91059CFC6615941378D413CB5ADAF4C5EB293402', hexlify(sum).upper()) + self.assertEqual('91059CFC6615941378D413CB5ADAF4C5EB293402', u(hexlify(sum)).upper()) sum = f.check('md5', 0, 512) - self.assertEqual('93DE4788FCA28D471516963A1FE3856A', hexlify(sum).upper()) + self.assertEqual('93DE4788FCA28D471516963A1FE3856A', u(hexlify(sum)).upper()) sum = f.check('md5', 0, 0, 510) self.assertEqual('EB3B45B8CD55A0707D99B177544A319F373183D241432BB2157AB9E46358C4AC90370B5CADE5D90336FC1716F90B36D6', - hexlify(sum).upper()) + u(hexlify(sum)).upper()) f.close() finally: sftp.unlink(FOLDER + '/kitty.txt') diff --git a/tests/test_sftp_big.py b/tests/test_sftp_big.py index b7ffe0b..c1d34d7 100644 --- a/tests/test_sftp_big.py +++ b/tests/test_sftp_big.py @@ -92,7 +92,7 @@ class BigSFTPTest (unittest.TestCase): write a 1MB file with no buffering. """ sftp = get_sftp() - kblob = (1024 * 'x') + kblob = (1024 * b('x')) start = time.time() try: f = sftp.open('%s/hongry.txt' % FOLDER, 'w') @@ -246,7 +246,7 @@ class BigSFTPTest (unittest.TestCase): without using it, to verify that paramiko doesn't get confused. """ sftp = get_sftp() - kblob = (1024 * 'x') + kblob = (1024 * b('x')) try: f = sftp.open('%s/hongry.txt' % FOLDER, 'w') f.set_pipelined(True) @@ -347,7 +347,7 @@ class BigSFTPTest (unittest.TestCase): write a 1MB file, with no linefeeds, and a big buffer. """ sftp = get_sftp() - mblob = (1024 * 1024 * 'x') + mblob = (1024 * 1024 * b('x')) try: f = sftp.open('%s/hongry.txt' % FOLDER, 'w', 128 * 1024) f.write(mblob) @@ -364,7 +364,7 @@ class BigSFTPTest (unittest.TestCase): sftp = get_sftp() t = sftp.sock.get_transport() t.packetizer.REKEY_BYTES = 512 * 1024 - k32blob = (32 * 1024 * 'x') + k32blob = (32 * 1024 * b('x')) try: f = sftp.open('%s/hongry.txt' % FOLDER, 'w', 128 * 1024) for i in range(32):