diff --git a/README b/README index ffecafe..01b8939 100644 --- a/README +++ b/README @@ -231,3 +231,5 @@ v0.9 FEAROW * would be nice to have an ftp-like interface to sftp (put, get, chdir...) +* speed up file transfers! +* what is psyco? diff --git a/paramiko/channel.py b/paramiko/channel.py index c6915e1..cd866c0 100644 --- a/paramiko/channel.py +++ b/paramiko/channel.py @@ -298,7 +298,6 @@ class Channel (object): m.add_boolean(0) m.add_int(status) self.transport._send_user_message(m) - self._log(DEBUG, 'EXIT-STATUS') def get_transport(self): """ @@ -468,7 +467,7 @@ class Channel (object): it means you may need to wait before more data arrives. @return: C{True} if a L{recv} call on this channel would immediately - return at least one byte; C{False} otherwise. + return at least one byte; C{False} otherwise. @rtype: boolean """ self.lock.acquire() @@ -492,7 +491,7 @@ class Channel (object): @rtype: str @raise socket.timeout: if no data is ready before the timeout set by - L{settimeout}. + L{settimeout}. """ out = '' self.lock.acquire() diff --git a/paramiko/packet.py b/paramiko/packet.py new file mode 100644 index 0000000..d93227c --- /dev/null +++ b/paramiko/packet.py @@ -0,0 +1,401 @@ +#!/usr/bin/python + +# Copyright (C) 2003-2005 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Packetizer. +""" + +import select, socket, struct, threading, time +from Crypto.Hash import HMAC +from common import * +from message import Message +import util + + +class Packetizer (object): + """ + Implementation of the base SSH packet protocol. + """ + + # READ the secsh RFC's before raising these values. if anything, + # they should probably be lower. + REKEY_PACKETS = pow(2, 30) + REKEY_BYTES = pow(2, 30) + + def __init__(self, socket): + self.__socket = socket + self.__logger = None + self.__closed = False + self.__dump_packets = False + self.__need_rekey = False + + # used for noticing when to re-key: + self.__sent_bytes = 0 + self.__sent_packets = 0 + self.__received_bytes = 0 + self.__received_packets = 0 + self.__received_packets_overflow = 0 + + # current inbound/outbound ciphering: + self.__block_size_out = 8 + self.__block_size_in = 8 + self.__mac_size_out = 0 + self.__mac_size_in = 0 + self.__block_engine_out = None + self.__block_engine_in = None + self.__mac_engine_out = None + self.__mac_engine_in = None + self.__mac_key_out = '' + self.__mac_key_in = '' + self.__sequence_number_out = 0L + self.__sequence_number_in = 0L + + # lock around outbound writes (packet computation) + self.__write_lock = threading.RLock() + + # keepalives: + self.__keepalive_interval = 0 + self.__keepalive_last = time.time() + self.__keepalive_callback = None + + + def set_log(self, log): + """ + Set the python log object to use for logging. + """ + self.__logger = log + + def set_outbound_cipher(self, block_engine, block_size, mac_engine, mac_size, mac_key): + """ + Switch outbound data cipher. + """ + self.__block_engine_out = block_engine + self.__block_size_out = block_size + self.__mac_engine_out = mac_engine + self.__mac_size_out = mac_size + self.__mac_key_out = mac_key + self.__sent_bytes = 0 + self.__sent_packets = 0 + self.__need_rekey = False + + def set_inbound_cipher(self, block_engine, block_size, mac_engine, mac_size, mac_key): + """ + Switch inbound data cipher. + """ + self.__block_engine_in = block_engine + self.__block_size_in = block_size + self.__mac_engine_in = mac_engine + self.__mac_size_in = mac_size + self.__mac_key_in = mac_key + self.__received_bytes = 0 + self.__received_packets = 0 + self.__received_packets_overflow = 0 + self.__need_rekey = False + + def close(self): + self.__closed = True + self.__block_engine_in = None + self.__block_engine_out = None + self.__socket.close() + + def set_hexdump(self, hexdump): + self.__dump_packets = hexdump + + def get_hexdump(self): + return self.__dump_packets + + def get_mac_size_in(self): + return self.__mac_size_in + + def get_mac_size_out(self): + return self.__mac_size_out + + def need_rekey(self): + """ + Returns C{True} if a new set of keys needs to be negotiated. This + will be triggered during a packet read or write, so it should be + checked after every read or write, or at least after every few. + + @return: C{True} if a new set of keys needs to be negotiated + """ + return self.__need_rekey + + def set_keepalive(self, interval, callback): + """ + Turn on/off the callback keepalive. If C{interval} seconds pass with + no data read from or written to the socket, the callback will be + executed and the timer will be reset. + """ + self.__keepalive_interval = interval + self.__keepalive_callback = callback + self.__keepalive_last = time.time() + self._log(DEBUG, 'SET KEEPALIVE %r' % interval) + + def read_all(self, n): + """ + Read as close to N bytes as possible, blocking as long as necessary. + + @param n: number of bytes to read + @type n: int + @return: the data read + @rtype: str + @throw EOFError: if the socket was closed before all the bytes could + be read + """ + if PY22: + return self._py22_read_all(n) + out = '' + while n > 0: + try: + x = self.__socket.recv(n) + if len(x) == 0: + raise EOFError() + out += x + n -= len(x) + except socket.timeout: + if self.__closed: + raise EOFError() + self._check_keepalive() + return out + + def write_all(self, out): + self.__keepalive_last = time.time() + while len(out) > 0: + try: + n = self.__socket.send(out) + except socket.timeout: + n = 0 + if self.__closed: + n = -1 + except Exception, x: + # could be: (32, 'Broken pipe') + n = -1 + if n < 0: + raise EOFError() + if n == len(out): + return + out = out[n:] + return + + def readline(self, timeout): + """ + Read a line from the socket. This is done in a fairly inefficient + way, but is only used for initial banner negotiation so it's not worth + optimising. + """ + buffer = '' + while not '\n' in buffer: + buffer += self._read_timeout(timeout) + buffer = buffer[:-1] + if (len(buffer) > 0) and (buffer[-1] == '\r'): + buffer = buffer[:-1] + return buffer + + def send_message(self, data): + """ + Write a block of data using the current cipher, as an SSH block. + """ + # encrypt this sucka + data = str(data) + cmd = ord(data[0]) + if cmd in MSG_NAMES: + cmd_name = MSG_NAMES[cmd] + else: + cmd_name = '$%x' % cmd + self._log(DEBUG, 'Write packet <%s>, length %d' % (cmd_name, len(data))) + packet = self._build_packet(data) + if self.__dump_packets: + self._log(DEBUG, util.format_binary(packet, 'OUT: ')) + if self.__block_engine_out != None: + out = self.__block_engine_out.encrypt(packet) + else: + out = packet + # + mac + self.__write_lock.acquire() + try: + if self.__block_engine_out != None: + payload = struct.pack('>I', self.__sequence_number_out) + packet + out += HMAC.HMAC(self.__mac_key_out, payload, self.__mac_engine_out).digest()[:self.__mac_size_out] + self.__sequence_number_out = (self.__sequence_number_out + 1) & 0xffffffffL + self.write_all(out) + + self.__sent_bytes += len(out) + self.__sent_packets += 1 + if ((self.__sent_packets >= self.REKEY_PACKETS) or (self.__sent_bytes >= self.REKEY_BYTES)) \ + and not self.__need_rekey: + # only ask once for rekeying + self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes sent)' % + (self.__sent_packets, self.__sent_bytes)) + self.__received_packets_overflow = 0 + self._trigger_rekey() + finally: + self.__write_lock.release() + + def read_message(self): + """ + Only one thread should ever be in this function (no other locking is + done). + + @throw SSHException: if the packet is mangled + """ + header = self.read_all(self.__block_size_in) + if self.__block_engine_in != None: + header = self.__block_engine_in.decrypt(header) + if self.__dump_packets: + self._log(DEBUG, util.format_binary(header, 'IN: ')); + packet_size = struct.unpack('>I', header[:4])[0] + # leftover contains decrypted bytes from the first block (after the length field) + leftover = header[4:] + if (packet_size - len(leftover)) % self.__block_size_in != 0: + raise SSHException('Invalid packet blocking') + buffer = self.read_all(packet_size + self.__mac_size_in - len(leftover)) + packet = buffer[:packet_size - len(leftover)] + post_packet = buffer[packet_size - len(leftover):] + if self.__block_engine_in != None: + packet = self.__block_engine_in.decrypt(packet) + if self.__dump_packets: + self._log(DEBUG, util.format_binary(packet, 'IN: ')); + packet = leftover + packet + + if self.__mac_size_in > 0: + mac = post_packet[:self.__mac_size_in] + mac_payload = struct.pack('>II', self.__sequence_number_in, packet_size) + packet + my_mac = HMAC.HMAC(self.__mac_key_in, mac_payload, self.__mac_engine_in).digest()[:self.__mac_size_in] + if my_mac != mac: + raise SSHException('Mismatched MAC') + padding = ord(packet[0]) + payload = packet[1:packet_size - padding + 1] + randpool.add_event(packet[packet_size - padding + 1]) + if self.__dump_packets: + self._log(DEBUG, 'Got payload (%d bytes, %d padding)' % (packet_size, padding)) + + msg = Message(payload[1:]) + msg.seqno = self.__sequence_number_in + self.__sequence_number_in = (self.__sequence_number_in + 1) & 0xffffffffL + + # check for rekey + self.__received_bytes += packet_size + self.__mac_size_in + 4 + self.__received_packets += 1 + if self.__need_rekey: + # we've asked to rekey -- give them 20 packets to comply before + # dropping the connection + self.__received_packets_overflow += 1 + if self.__received_packets_overflow >= 20: + raise SSHException('Remote transport is ignoring rekey requests') + elif (self.__received_packets >= self.REKEY_PACKETS) or \ + (self.__received_bytes >= self.REKEY_BYTES): + # only ask once for rekeying + self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes received)' % + (self.__received_packets, self.__received_bytes)) + self.__received_packets_overflow = 0 + self._trigger_rekey() + + cmd = ord(payload[0]) + if cmd in MSG_NAMES: + cmd_name = MSG_NAMES[cmd] + else: + cmd_name = '$%x' % cmd + self._log(DEBUG, 'Read packet <%s>, length %d' % (cmd_name, len(payload))) + return cmd, msg + + + ########## protected + + + def _log(self, level, msg): + if self.__logger is None: + return + if issubclass(type(msg), list): + for m in msg: + self.__logger.log(level, m) + else: + self.__logger.log(level, msg) + + def _check_keepalive(self): + if (not self.__keepalive_interval) or (not self.__block_engine_out) or \ + self.__need_rekey: + # wait till we're encrypting, and not in the middle of rekeying + return + now = time.time() + if now > self.__keepalive_last + self.__keepalive_interval: + self.__keepalive_callback() + self.__keepalive_last = now + + def _py22_read_all(self, n): + out = '' + while n > 0: + r, w, e = select.select([self.__socket], [], [], 0.1) + if self.__socket not in r: + if self.__closed: + raise EOFError() + self._check_keepalive() + else: + x = self.__socket.recv(n) + if len(x) == 0: + raise EOFError() + out += x + n -= len(x) + return out + + def _py22_read_timeout(self, timeout): + start = time.time() + while True: + r, w, e = select.select([self.__socket], [], [], 0.1) + if self.__socket in r: + x = self.__socket.recv(1) + if len(x) == 0: + raise EOFError() + return x + if self.__closed: + raise EOFError() + now = time.time() + if now - start >= timeout: + raise socket.timeout() + + def _read_timeout(self, timeout): + if PY22: + return self._py22_read_timeout(n) + start = time.time() + while True: + try: + x = self.__socket.recv(1) + if len(x) == 0: + raise EOFError() + return x + except socket.timeout: + pass + if self.__closed: + raise EOFError() + now = time.time() + if now - start >= timeout: + raise socket.timeout() + + def _build_packet(self, payload): + # pad up at least 4 bytes, to nearest block-size (usually 8) + bsize = self.__block_size_out + padding = 3 + bsize - ((len(payload) + 8) % bsize) + packet = struct.pack('>IB', len(payload) + padding + 1, padding) + packet += payload + packet += randpool.get_bytes(padding) + return packet + + def _trigger_rekey(self): + # outside code should check for this flag + self.__need_rekey = True diff --git a/paramiko/sftp_client.py b/paramiko/sftp_client.py index b29f71e..fcf7706 100644 --- a/paramiko/sftp_client.py +++ b/paramiko/sftp_client.py @@ -53,7 +53,7 @@ class SFTPClient (BaseSFTP): transport = self.sock.get_transport() self.logger = util.get_logger(transport.get_log_channel() + '.' + self.sock.get_name() + '.sftp') - self.ultra_debug = transport.ultra_debug + self.ultra_debug = transport.get_hexdump() self._send_version() def from_transport(selfclass, t): diff --git a/paramiko/sftp_server.py b/paramiko/sftp_server.py index 6797184..94a9e6c 100644 --- a/paramiko/sftp_server.py +++ b/paramiko/sftp_server.py @@ -60,7 +60,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler): transport = channel.get_transport() self.logger = util.get_logger(transport.get_log_channel() + '.' + channel.get_name() + '.sftp') - self.ultra_debug = transport.ultra_debug + self.ultra_debug = transport.get_hexdump() self.next_handle = 1 # map of handle-string to SFTPHandle for files & folders: self.file_table = { } diff --git a/paramiko/transport.py b/paramiko/transport.py index da252ce..b436847 100644 --- a/paramiko/transport.py +++ b/paramiko/transport.py @@ -1,5 +1,3 @@ -#!/usr/bin/python - # Copyright (C) 2003-2005 Robey Pointer # # This file is part of paramiko. @@ -30,6 +28,7 @@ from message import Message from channel import Channel from sftp_client import SFTPClient import util +from packet import Packetizer from rsakey import RSAKey from dsskey import DSSKey from kex_group1 import KexGroup1 @@ -49,7 +48,7 @@ from Crypto.Hash import SHA, MD5, HMAC _active_threads = [] def _join_lingering_threads(): for thr in _active_threads: - thr.active = False + thr.stop_thread() import atexit atexit.register(_join_lingering_threads) @@ -162,10 +161,6 @@ class BaseTransport (threading.Thread): 'diffie-hellman-group-exchange-sha1': KexGex, } - # READ the secsh RFC's before raising these values. if anything, - # they should probably be lower. - REKEY_PACKETS = pow(2, 30) - REKEY_BYTES = pow(2, 30) _modulus_pack = None @@ -222,42 +217,29 @@ class BaseTransport (threading.Thread): except AttributeError: pass # negotiated crypto parameters + self.packetizer = Packetizer(sock) self.local_version = 'SSH-' + self._PROTO_ID + '-' + self._CLIENT_ID self.remote_version = '' - self.block_size_out = self.block_size_in = 8 - self.local_mac_len = self.remote_mac_len = 0 - self.engine_in = self.engine_out = None self.local_cipher = self.remote_cipher = '' - self.sequence_number_in = self.sequence_number_out = 0L self.local_kex_init = self.remote_kex_init = None self.session_id = None # /negotiated crypto parameters self.expected_packet = 0 self.active = False self.initial_kex_done = False - self.write_lock = threading.RLock() # lock around outbound writes (packet computation) - self.lock = threading.Lock() # synchronization (always higher level than write_lock) - self.channels = { } # (id -> Channel) - self.channel_events = { } # (id -> Event) + self.lock = threading.Lock() # synchronization (always higher level than write_lock) + self.channels = { } # (id -> Channel) + self.channel_events = { } # (id -> Event) self.channel_counter = 1 self.window_size = 65536 self.max_packet_size = 32768 - self.ultra_debug = False self.saved_exception = None self.clear_to_send = threading.Event() self.log_name = 'paramiko.transport' self.logger = util.get_logger(self.log_name) - # used for noticing when to re-key: - self.received_bytes = 0 - self.received_packets = 0 - self.received_packets_overflow = 0 - self.sent_bytes = 0 - self.sent_packets = 0 + self.packetizer.set_log(self.logger) # user-defined event callbacks: self.completion_event = None - # keepalives: - self.keepalive_interval = 0 - self.keepalive_last = time.time() # server mode: self.server_mode = False self.server_object = None @@ -293,7 +275,7 @@ class BaseTransport (threading.Thread): preference for them. @return: an object that can be used to change the preferred algorithms - for encryption, digest (hash), public key, and key exchange. + for encryption, digest (hash), public key, and key exchange. @rtype: L{SecurityOptions} @since: ivysaur @@ -316,8 +298,8 @@ class BaseTransport (threading.Thread): @note: L{connect} is a simpler method for connecting as a client. @note: After calling this method (or L{start_server} or L{connect}), - you should no longer directly read from or write to the original socket - object. + you should no longer directly read from or write to the original + socket object. @param event: an event to trigger when negotiation is complete. @type event: threading.Event @@ -350,13 +332,13 @@ class BaseTransport (threading.Thread): given C{server} object to allow channels to be opened. @note: After calling this method (or L{start_client} or L{connect}), - you should no longer directly read from or write to the original socket - object. + you should no longer directly read from or write to the original + socket object. @param event: an event to trigger when negotiation is complete. @type event: threading.Event @param server: an object used to perform authentication and create - L{Channel}s. + L{Channel}s. @type server: L{server.ServerInterface} """ if server is None: @@ -376,7 +358,7 @@ class BaseTransport (threading.Thread): key info, not just the public half. @param key: the host key to add, usually an L{RSAKey } or - L{DSSKey }. + L{DSSKey }. @type key: L{PKey } """ self.server_key_dict[key.get_name()] = key @@ -418,10 +400,10 @@ class BaseTransport (threading.Thread): support that method of key negotiation. @param filename: optional path to the moduli file, if you happen to - know that it's not in a standard location. + know that it's not in a standard location. @type filename: str @return: True if a moduli file was successfully loaded; False - otherwise. + otherwise. @rtype: bool @since: doduo @@ -449,8 +431,7 @@ class BaseTransport (threading.Thread): Close this session, and any open channels that are tied to it. """ self.active = False - self.engine_in = self.engine_out = None - self.sequence_number_in = self.sequence_number_out = 0L + self.packetizer.close() for chan in self.channels.values(): chan._unlink() @@ -459,9 +440,9 @@ class BaseTransport (threading.Thread): Return the host key of the server (in client mode). @note: Previously this call returned a tuple of (key type, key string). - You can get the same effect by calling - L{PKey.get_name } for the key type, and C{str(key)} - for the key string. + You can get the same effect by calling + L{PKey.get_name } for the key type, and + C{str(key)} for the key string. @raise SSHException: if no session is currently active. @@ -476,7 +457,8 @@ class BaseTransport (threading.Thread): """ Return true if this session is active (open). - @return: True if the session is still active (open); False if the session is closed. + @return: True if the session is still active (open); False if the + session is closed. @rtype: bool """ return self.active @@ -487,7 +469,7 @@ class BaseTransport (threading.Thread): is just an alias for C{open_channel('session')}. @return: a new L{Channel} on success, or C{None} if the request is - rejected or the session ends prematurely. + rejected or the session ends prematurely. @rtype: L{Channel} """ return self.open_channel('session') @@ -500,17 +482,17 @@ class BaseTransport (threading.Thread): L{connect} or L{start_client}) and authenticating. @param kind: the kind of channel requested (usually C{"session"}, - C{"forwarded-tcpip"} or C{"direct-tcpip"}). + C{"forwarded-tcpip"} or C{"direct-tcpip"}). @type kind: str @param dest_addr: the destination address of this port forwarding, - if C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"} (ignored - for other channel types). + if C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"} (ignored + for other channel types). @type dest_addr: (str, int) @param src_addr: the source address of this port forwarding, if - C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"}. + C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"}. @type src_addr: (str, int) @return: a new L{Channel} on success, or C{None} if the request is - rejected or the session ends prematurely. + rejected or the session ends prematurely. @rtype: L{Channel} """ chan = None @@ -599,7 +581,7 @@ class BaseTransport (threading.Thread): session has died mid-negotiation. @return: True if the renegotiation was successful, and the link is - using new keys; False if the session dropped during renegotiation. + using new keys; False if the session dropped during renegotiation. @rtype: bool """ self.completion_event = threading.Event() @@ -620,12 +602,13 @@ class BaseTransport (threading.Thread): can be useful to keep connections alive over a NAT, for example. @param interval: seconds to wait before sending a keepalive packet (or - 0 to disable keepalives). + 0 to disable keepalives). @type interval: int @since: fearow """ - self.keepalive_interval = interval + self.packetizer.set_keepalive(interval, + lambda x=self: x.global_request('keepalive@lag.net', wait=False)) def global_request(self, kind, data=None, wait=True): """ @@ -635,14 +618,14 @@ class BaseTransport (threading.Thread): @param kind: name of the request. @type kind: str @param data: an optional tuple containing additional data to attach - to the request. + to the request. @type data: tuple @param wait: C{True} if this method should not return until a response - is received; C{False} otherwise. + is received; C{False} otherwise. @type wait: bool @return: a L{Message} containing possible additional data if the - request was successful (or an empty L{Message} if C{wait} was - C{False}); C{None} if the request was denied. + request was successful (or an empty L{Message} if C{wait} was + C{False}); C{None} if the request was denied. @rtype: L{Message} @since: fearow @@ -833,7 +816,23 @@ class BaseTransport (threading.Thread): C{False} otherwise. @type hexdump: bool """ - self.ultra_debug = hexdump + self.packetizer.set_hexdump(hexdump) + + def get_hexdump(self): + """ + Return C{True} if the transport is currently logging hex dumps of + protocol traffic. + + @return: C{True} if hex dumps are being logged + @rtype: bool + + @since: 1.4 + """ + return self.packetizer.get_hexdump() + + def stop_thread(self): + self.active = False + self.packetizer.close() ### internals... @@ -859,113 +858,10 @@ class BaseTransport (threading.Thread): finally: self.lock.release() - def _check_keepalive(self): - if (not self.keepalive_interval) or (not self.initial_kex_done): - return - now = time.time() - if now > self.keepalive_last + self.keepalive_interval: - self.global_request('keepalive@lag.net', wait=False) - - def _py22_read_all(self, n): - out = '' - while n > 0: - r, w, e = select.select([self.sock], [], [], 0.1) - if self.sock not in r: - if not self.active: - raise EOFError() - self._check_keepalive() - else: - x = self.sock.recv(n) - if len(x) == 0: - raise EOFError() - out += x - n -= len(x) - return out - - def _read_all(self, n): - if PY22: - return self._py22_read_all(n) - out = '' - while n > 0: - try: - x = self.sock.recv(n) - if len(x) == 0: - raise EOFError() - out += x - n -= len(x) - except socket.timeout: - if not self.active: - raise EOFError() - self._check_keepalive() - return out - - def _write_all(self, out): - self.keepalive_last = time.time() - while len(out) > 0: - try: - n = self.sock.send(out) - except socket.timeout: - n = 0 - if not self.active: - n = -1 - except Exception, x: - # could be: (32, 'Broken pipe') - n = -1 - if n < 0: - raise EOFError() - if n == len(out): - return - out = out[n:] - return - - def _build_packet(self, payload): - # pad up at least 4 bytes, to nearest block-size (usually 8) - bsize = self.block_size_out - padding = 3 + bsize - ((len(payload) + 8) % bsize) - packet = struct.pack('>I', len(payload) + padding + 1) - packet += chr(padding) - packet += payload - packet += randpool.get_bytes(padding) - return packet - def _send_message(self, data): - # encrypt this sucka - data = str(data) - cmd = ord(data[0]) - if cmd in MSG_NAMES: - cmd_name = MSG_NAMES[cmd] - else: - cmd_name = '$%x' % cmd - self._log(DEBUG, 'Write packet <%s>, length %d' % (cmd_name, len(data))) - packet = self._build_packet(data) - if self.ultra_debug: - self._log(DEBUG, util.format_binary(packet, 'OUT: ')) - if self.engine_out != None: - out = self.engine_out.encrypt(packet) - else: - out = packet - # + mac - try: - self.write_lock.acquire() - if self.engine_out != None: - payload = struct.pack('>I', self.sequence_number_out) + packet - out += HMAC.HMAC(self.mac_key_out, payload, self.local_mac_engine).digest()[:self.local_mac_len] - self.sequence_number_out += 1L - self.sequence_number_out %= 0x100000000L - self._write_all(out) - - self.sent_bytes += len(out) - self.sent_packets += 1 - if ((self.sent_packets >= self.REKEY_PACKETS) or (self.sent_bytes >= self.REKEY_BYTES)) \ - and (self.local_kex_init is None): - # only ask once for rekeying - self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes sent)' % - (self.sent_packets, self.sent_bytes)) - self.received_packets_overflow = 0 - # this may do a recursive lock, but that's okay: - self._send_kex_init() - finally: - self.write_lock.release() + self.packetizer.send_message(data) + if self.packetizer.need_rekey(): + self._send_kex_init() def _send_user_message(self, data): """ @@ -981,65 +877,6 @@ class BaseTransport (threading.Thread): break self._send_message(data) - def _read_message(self): - "only one thread will ever be in this function" - header = self._read_all(self.block_size_in) - if self.engine_in != None: - header = self.engine_in.decrypt(header) - if self.ultra_debug: - self._log(DEBUG, util.format_binary(header, 'IN: ')); - packet_size = struct.unpack('>I', header[:4])[0] - # leftover contains decrypted bytes from the first block (after the length field) - leftover = header[4:] - if (packet_size - len(leftover)) % self.block_size_in != 0: - raise SSHException('Invalid packet blocking') - buffer = self._read_all(packet_size + self.remote_mac_len - len(leftover)) - packet = buffer[:packet_size - len(leftover)] - post_packet = buffer[packet_size - len(leftover):] - if self.engine_in != None: - packet = self.engine_in.decrypt(packet) - if self.ultra_debug: - self._log(DEBUG, util.format_binary(packet, 'IN: ')); - packet = leftover + packet - if self.remote_mac_len > 0: - mac = post_packet[:self.remote_mac_len] - mac_payload = struct.pack('>II', self.sequence_number_in, packet_size) + packet - my_mac = HMAC.HMAC(self.mac_key_in, mac_payload, self.remote_mac_engine).digest()[:self.remote_mac_len] - if my_mac != mac: - raise SSHException('Mismatched MAC') - padding = ord(packet[0]) - payload = packet[1:packet_size - padding + 1] - randpool.add_event(packet[packet_size - padding + 1]) - if self.ultra_debug: - self._log(DEBUG, 'Got payload (%d bytes, %d padding)' % (packet_size, padding)) - msg = Message(payload[1:]) - msg.seqno = self.sequence_number_in - self.sequence_number_in = (self.sequence_number_in + 1) & 0xffffffffL - # check for rekey - self.received_bytes += packet_size + self.remote_mac_len + 4 - self.received_packets += 1 - if self.local_kex_init is not None: - # we've asked to rekey -- give them 20 packets to comply before - # dropping the connection - self.received_packets_overflow += 1 - if self.received_packets_overflow >= 20: - raise SSHException('Remote transport is ignoring rekey requests') - elif (self.received_packets >= self.REKEY_PACKETS) or \ - (self.received_bytes >= self.REKEY_BYTES): - # only ask once for rekeying - self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes received)' % - (self.received_packets, self.received_bytes)) - self.received_packets_overflow = 0 - self._send_kex_init() - - cmd = ord(payload[0]) - if cmd in MSG_NAMES: - cmd_name = MSG_NAMES[cmd] - else: - cmd_name = '$%x' % cmd - self._log(DEBUG, 'Read packet <%s>, length %d' % (cmd_name, len(payload))) - return cmd, msg - def _set_K_H(self, k, h): "used by a kex object to set the K (root key) and H (exchange hash)" self.K = k @@ -1090,18 +927,21 @@ class BaseTransport (threading.Thread): else: self._log(DEBUG, 'starting thread (client mode): %s' % hex(long(id(self)) & 0xffffffffL)) try: - self._write_all(self.local_version + '\r\n') + self.packetizer.write_all(self.local_version + '\r\n') self._check_banner() self._send_kex_init() self.expected_packet = MSG_KEXINIT while self.active: - ptype, m = self._read_message() + if self.packetizer.need_rekey(): + self._send_kex_init() + ptype, m = self.packetizer.read_message() if ptype == MSG_IGNORE: continue elif ptype == MSG_DISCONNECT: self._parse_disconnect(m) self.active = False + self.packetizer.close() break elif ptype == MSG_DEBUG: self._parse_debug(m) @@ -1123,6 +963,7 @@ class BaseTransport (threading.Thread): else: self._log(ERROR, 'Channel request for unknown channel %d' % chanid) self.active = False + self.packetizer.close() else: self._log(WARNING, 'Oops, unhandled type %d' % ptype) msg = Message() @@ -1146,6 +987,7 @@ class BaseTransport (threading.Thread): chan._unlink() if self.active: self.active = False + self.packetizer.close() if self.completion_event != None: self.completion_event.set() if self.auth_event != None: @@ -1170,12 +1012,15 @@ class BaseTransport (threading.Thread): def _check_banner(self): # this is slow, but we only have to do it once for i in range(5): - buffer = '' - while not '\n' in buffer: - buffer += self._read_all(1) - buffer = buffer[:-1] - if (len(buffer) > 0) and (buffer[-1] == '\r'): - buffer = buffer[:-1] + # give them 5 seconds for the first line, then just 2 seconds each additional line + if i == 0: + timeout = 5 + else: + timeout = 2 + try: + buffer = self.packetizer.readline(timeout) + except Exception, x: + raise SSHException('Error reading SSH protocol banner' + str(x)) if buffer[:4] == 'SSH-': break self._log(DEBUG, 'Banner: ' + buffer) @@ -1236,13 +1081,6 @@ class BaseTransport (threading.Thread): self._send_message(m) def _parse_kex_init(self, m): - # reset counters of when to re-key, since we are now re-keying - self.received_bytes = 0 - self.received_packets = 0 - self.received_packets_overflow = 0 - self.sent_bytes = 0 - self.sent_packets = 0 - cookie = m.get_bytes(16) kex_algo_list = m.get_list() server_key_algo_list = m.get_list() @@ -1334,44 +1172,46 @@ class BaseTransport (threading.Thread): def _activate_inbound(self): "switch on newly negotiated encryption parameters for inbound traffic" - self.block_size_in = self._cipher_info[self.remote_cipher]['block-size'] + block_size = self._cipher_info[self.remote_cipher]['block-size'] if self.server_mode: - IV_in = self._compute_key('A', self.block_size_in) + IV_in = self._compute_key('A', block_size) key_in = self._compute_key('C', self._cipher_info[self.remote_cipher]['key-size']) else: - IV_in = self._compute_key('B', self.block_size_in) + IV_in = self._compute_key('B', block_size) key_in = self._compute_key('D', self._cipher_info[self.remote_cipher]['key-size']) - self.engine_in = self._get_cipher(self.remote_cipher, key_in, IV_in) - self.remote_mac_len = self._mac_info[self.remote_mac]['size'] - self.remote_mac_engine = self._mac_info[self.remote_mac]['class'] + engine = self._get_cipher(self.remote_cipher, key_in, IV_in) + mac_size = self._mac_info[self.remote_mac]['size'] + mac_engine = self._mac_info[self.remote_mac]['class'] # initial mac keys are done in the hash's natural size (not the potentially truncated # transmission size) if self.server_mode: - self.mac_key_in = self._compute_key('E', self.remote_mac_engine.digest_size) + mac_key = self._compute_key('E', mac_engine.digest_size) else: - self.mac_key_in = self._compute_key('F', self.remote_mac_engine.digest_size) + mac_key = self._compute_key('F', mac_engine.digest_size) + self.packetizer.set_inbound_cipher(engine, block_size, mac_engine, mac_size, mac_key) def _activate_outbound(self): "switch on newly negotiated encryption parameters for outbound traffic" m = Message() m.add_byte(chr(MSG_NEWKEYS)) self._send_message(m) - self.block_size_out = self._cipher_info[self.local_cipher]['block-size'] + block_size = self._cipher_info[self.local_cipher]['block-size'] if self.server_mode: - IV_out = self._compute_key('B', self.block_size_out) + IV_out = self._compute_key('B', block_size) key_out = self._compute_key('D', self._cipher_info[self.local_cipher]['key-size']) else: - IV_out = self._compute_key('A', self.block_size_out) + IV_out = self._compute_key('A', block_size) key_out = self._compute_key('C', self._cipher_info[self.local_cipher]['key-size']) - self.engine_out = self._get_cipher(self.local_cipher, key_out, IV_out) - self.local_mac_len = self._mac_info[self.local_mac]['size'] - self.local_mac_engine = self._mac_info[self.local_mac]['class'] + engine = self._get_cipher(self.local_cipher, key_out, IV_out) + mac_size = self._mac_info[self.local_mac]['size'] + mac_engine = self._mac_info[self.local_mac]['class'] # initial mac keys are done in the hash's natural size (not the potentially truncated # transmission size) if self.server_mode: - self.mac_key_out = self._compute_key('F', self.local_mac_engine.digest_size) + mac_key = self._compute_key('F', mac_engine.digest_size) else: - self.mac_key_out = self._compute_key('E', self.local_mac_engine.digest_size) + mac_key = self._compute_key('E', mac_engine.digest_size) + self.packetizer.set_outbound_cipher(engine, block_size, mac_engine, mac_size, mac_key) # we always expect to receive NEWKEYS now self.expected_packet = MSG_NEWKEYS diff --git a/test.py b/test.py index 2c4a28a..a97354e 100755 --- a/test.py +++ b/test.py @@ -70,6 +70,9 @@ options, args = parser.parse_args() if len(args) > 0: parser.error('unknown argument(s)') +# setup logging +paramiko.util.log_to_file('test.log') + if options.use_sftp: if options.use_loopback_sftp: SFTPTest.init_loopback() @@ -78,9 +81,6 @@ if options.use_sftp: if not options.use_big_file: SFTPTest.set_big_file_test(False) -# setup logging -paramiko.util.log_to_file('test.log') - suite = unittest.TestSuite() suite.addTest(unittest.makeSuite(MessageTest)) suite.addTest(unittest.makeSuite(BufferedFileTest)) diff --git a/tests/test_sftp.py b/tests/test_sftp.py index 5d4d921..5031f02 100755 --- a/tests/test_sftp.py +++ b/tests/test_sftp.py @@ -1,5 +1,3 @@ -#!/usr/bin/python - # Copyright (C) 2003-2005 Robey Pointer # # This file is part of paramiko. @@ -145,7 +143,7 @@ class SFTPTest (unittest.TestCase): finally: sftp.remove(FOLDER + '/test') - def test_1a_close(self): + def test_2_close(self): """ verify that closing the sftp session doesn't do anything bad, and that a new one can be opened. @@ -159,7 +157,7 @@ class SFTPTest (unittest.TestCase): pass sftp = paramiko.SFTP.from_transport(tc) - def test_2_write(self): + def test_3_write(self): """ verify that a file can be created and written, and the size is correct. """ @@ -171,7 +169,7 @@ class SFTPTest (unittest.TestCase): finally: sftp.remove(FOLDER + '/duck.txt') - def test_3_append(self): + def test_4_append(self): """ verify that a file can be opened for append, and tell() still works. """ @@ -191,7 +189,7 @@ class SFTPTest (unittest.TestCase): finally: sftp.remove(FOLDER + '/append.txt') - def test_4_rename(self): + def test_5_rename(self): """ verify that renaming a file works. """ @@ -219,7 +217,7 @@ class SFTPTest (unittest.TestCase): except: pass - def test_5_folder(self): + def test_6_folder(self): """ create a temporary folder, verify that we can create a file in it, then remove the folder and verify that we can't create a file in it anymore. @@ -236,7 +234,7 @@ class SFTPTest (unittest.TestCase): except IOError: pass - def test_6_listdir(self): + def test_7_listdir(self): """ verify that a folder can be created, a bunch of files can be placed in it, and those files show up in sftp.listdir. @@ -262,7 +260,7 @@ class SFTPTest (unittest.TestCase): sftp.remove(FOLDER + '/fish.txt') sftp.remove(FOLDER + '/tertiary.py') - def test_7_setstat(self): + def test_8_setstat(self): """ verify that the setstat functions (chown, chmod, utime) work. """ @@ -285,7 +283,7 @@ class SFTPTest (unittest.TestCase): finally: sftp.remove(FOLDER + '/special') - def test_8_readline_seek(self): + def test_9_readline_seek(self): """ create a text file and write a bunch of text into it. then count the lines in the file, and seek around to retreive particular lines. this should @@ -315,7 +313,7 @@ class SFTPTest (unittest.TestCase): finally: sftp.remove(FOLDER + '/duck.txt') - def test_9_write_seek(self): + def test_A_write_seek(self): """ create a text file, seek back and change part of it, and verify that the changes worked. @@ -335,7 +333,7 @@ class SFTPTest (unittest.TestCase): finally: sftp.remove(FOLDER + '/testing.txt') - def test_A_symlink(self): + def test_B_symlink(self): """ create a symlink and then check that lstat doesn't follow it. """ @@ -378,7 +376,7 @@ class SFTPTest (unittest.TestCase): except: pass - def test_B_flush_seek(self): + def test_C_flush_seek(self): """ verify that buffered writes are automatically flushed on seek. """ @@ -400,7 +398,7 @@ class SFTPTest (unittest.TestCase): except: pass - def test_C_lots_of_files(self): + def test_D_lots_of_files(self): """ create a bunch of files over the same session. """ @@ -431,7 +429,7 @@ class SFTPTest (unittest.TestCase): except: pass - def test_D_big_file(self): + def test_E_big_file(self): """ write a 1MB file, with no linefeeds, using line buffering. FIXME: this is slow! what causes the slowness? @@ -453,7 +451,7 @@ class SFTPTest (unittest.TestCase): finally: sftp.remove('%s/hongry.txt' % FOLDER) - def test_E_big_file_big_buffer(self): + def test_F_big_file_big_buffer(self): """ write a 1MB file, with no linefeeds, and a big buffer. """ @@ -470,7 +468,7 @@ class SFTPTest (unittest.TestCase): finally: sftp.remove('%s/hongry.txt' % FOLDER) - def test_F_realpath(self): + def test_G_realpath(self): """ test that realpath is returning something non-empty and not an error. @@ -481,7 +479,7 @@ class SFTPTest (unittest.TestCase): self.assert_(len(f) > 0) self.assertEquals(os.path.join(pwd, FOLDER), f) - def test_G_mkdir(self): + def test_H_mkdir(self): """ verify that mkdir/rmdir work. """ diff --git a/tests/test_transport.py b/tests/test_transport.py index bd11487..5afc2e1 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -22,7 +22,7 @@ Some unit tests for the ssh2 protocol in Transport. """ -import sys, unittest, threading +import sys, time, threading, unittest from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, \ SSHException, BadAuthenticationType from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL @@ -77,6 +77,10 @@ class NullServer (ServerInterface): def check_channel_shell_request(self, channel): return True + + def check_global_request(self, kind, msg): + self._global_request = kind + return False class TransportTest (unittest.TestCase): @@ -160,14 +164,38 @@ class TransportTest (unittest.TestCase): self.assert_(self.ts.is_active()) self.assertEquals('aes256-cbc', self.tc.local_cipher) self.assertEquals('aes256-cbc', self.tc.remote_cipher) - self.assertEquals(12, self.tc.local_mac_len) - self.assertEquals(12, self.tc.remote_mac_len) + self.assertEquals(12, self.tc.packetizer.get_mac_size_out()) + self.assertEquals(12, self.tc.packetizer.get_mac_size_in()) self.tc.send_ignore(1024) self.assert_(self.tc.renegotiate_keys()) self.ts.send_ignore(1024) - def test_4_bad_auth_type(self): + def test_4_keepalive(self): + """ + verify that the keepalive will be sent. + """ + self.tc.set_hexdump(True) + + host_key = RSAKey.from_private_key_file('tests/test_rsa.key') + public_host_key = RSAKey(data=str(host_key)) + self.ts.add_server_key(host_key) + event = threading.Event() + server = NullServer() + self.assert_(not event.isSet()) + self.ts.start_server(event, server) + self.tc.connect(hostkey=public_host_key, + username='slowdive', password='pygmalion') + event.wait(1.0) + self.assert_(event.isSet()) + self.assert_(self.ts.is_active()) + + self.assertEquals(None, getattr(server, '_global_request', None)) + self.tc.set_keepalive(1) + time.sleep(2) + self.assertEquals('keepalive@lag.net', server._global_request) + + def test_5_bad_auth_type(self): """ verify that we get the right exception when an unsupported auth type is requested. @@ -188,7 +216,7 @@ class TransportTest (unittest.TestCase): self.assertEquals(BadAuthenticationType, etype) self.assertEquals(['publickey'], evalue.allowed_types) - def test_5_bad_password(self): + def test_6_bad_password(self): """ verify that a bad password gets the right exception, and that a retry with the right password works. @@ -213,7 +241,7 @@ class TransportTest (unittest.TestCase): self.assert_(event.isSet()) self.assert_(self.ts.is_active()) - def test_6_multipart_auth(self): + def test_7_multipart_auth(self): """ verify that multipart auth works. """ @@ -235,7 +263,7 @@ class TransportTest (unittest.TestCase): self.assert_(event.isSet()) self.assert_(self.ts.is_active()) - def test_7_exec_command(self): + def test_8_exec_command(self): """ verify that exec_command() does something reasonable. """ @@ -285,7 +313,7 @@ class TransportTest (unittest.TestCase): self.assertEquals('This is on stderr.\n', f.readline()) self.assertEquals('', f.readline()) - def test_8_invoke_shell(self): + def test_9_invoke_shell(self): """ verify that invoke_shell() does something reasonable. """ @@ -312,7 +340,7 @@ class TransportTest (unittest.TestCase): chan.close() self.assertEquals('', f.readline()) - def test_9_exit_status(self): + def test_A_exit_status(self): """ verify that get_exit_status() works. """