[project @ Arch-1:robey@lag.net--2005-master-shake%paramiko--dev--1--patch-5]

split out Packetizer, fix banner detection bug, new unit test
split out a chunk of BaseTransport into a Packetizer class, which handles
the in/out packet data, ciphers, etc.  it didn't make the code any smaller
(transport.py is still close to 1500 lines, which is awful) but it did split
out a coherent chunk of functionality into a discrete unit.

in the process, fixed a bug that alain spineux pointed out: the banner
check was too forgiving and would block forever waiting for an SSH banner.
now it waits 5 seconds for the first line, and 2 seconds for each subsequent
line, before giving up.

added a unit test to test keepalive, since i wasn't sure that was still
working after pulling out Packetizer.
This commit is contained in:
Robey Pointer 2005-05-01 08:04:59 +00:00
parent 2f2d7bdee8
commit 36055c5ac2
9 changed files with 553 additions and 285 deletions

2
README
View File

@ -231,3 +231,5 @@ v0.9 FEAROW
* would be nice to have an ftp-like interface to sftp (put, get, chdir...)
* speed up file transfers!
* what is psyco?

View File

@ -298,7 +298,6 @@ class Channel (object):
m.add_boolean(0)
m.add_int(status)
self.transport._send_user_message(m)
self._log(DEBUG, 'EXIT-STATUS')
def get_transport(self):
"""
@ -468,7 +467,7 @@ class Channel (object):
it means you may need to wait before more data arrives.
@return: C{True} if a L{recv} call on this channel would immediately
return at least one byte; C{False} otherwise.
return at least one byte; C{False} otherwise.
@rtype: boolean
"""
self.lock.acquire()
@ -492,7 +491,7 @@ class Channel (object):
@rtype: str
@raise socket.timeout: if no data is ready before the timeout set by
L{settimeout}.
L{settimeout}.
"""
out = ''
self.lock.acquire()

401
paramiko/packet.py Normal file
View File

@ -0,0 +1,401 @@
#!/usr/bin/python
# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
# Paramiko is free software; you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 2.1 of the License, or (at your option)
# any later version.
#
# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
"""
Packetizer.
"""
import select, socket, struct, threading, time
from Crypto.Hash import HMAC
from common import *
from message import Message
import util
class Packetizer (object):
"""
Implementation of the base SSH packet protocol.
"""
# READ the secsh RFC's before raising these values. if anything,
# they should probably be lower.
REKEY_PACKETS = pow(2, 30)
REKEY_BYTES = pow(2, 30)
def __init__(self, socket):
self.__socket = socket
self.__logger = None
self.__closed = False
self.__dump_packets = False
self.__need_rekey = False
# used for noticing when to re-key:
self.__sent_bytes = 0
self.__sent_packets = 0
self.__received_bytes = 0
self.__received_packets = 0
self.__received_packets_overflow = 0
# current inbound/outbound ciphering:
self.__block_size_out = 8
self.__block_size_in = 8
self.__mac_size_out = 0
self.__mac_size_in = 0
self.__block_engine_out = None
self.__block_engine_in = None
self.__mac_engine_out = None
self.__mac_engine_in = None
self.__mac_key_out = ''
self.__mac_key_in = ''
self.__sequence_number_out = 0L
self.__sequence_number_in = 0L
# lock around outbound writes (packet computation)
self.__write_lock = threading.RLock()
# keepalives:
self.__keepalive_interval = 0
self.__keepalive_last = time.time()
self.__keepalive_callback = None
def set_log(self, log):
"""
Set the python log object to use for logging.
"""
self.__logger = log
def set_outbound_cipher(self, block_engine, block_size, mac_engine, mac_size, mac_key):
"""
Switch outbound data cipher.
"""
self.__block_engine_out = block_engine
self.__block_size_out = block_size
self.__mac_engine_out = mac_engine
self.__mac_size_out = mac_size
self.__mac_key_out = mac_key
self.__sent_bytes = 0
self.__sent_packets = 0
self.__need_rekey = False
def set_inbound_cipher(self, block_engine, block_size, mac_engine, mac_size, mac_key):
"""
Switch inbound data cipher.
"""
self.__block_engine_in = block_engine
self.__block_size_in = block_size
self.__mac_engine_in = mac_engine
self.__mac_size_in = mac_size
self.__mac_key_in = mac_key
self.__received_bytes = 0
self.__received_packets = 0
self.__received_packets_overflow = 0
self.__need_rekey = False
def close(self):
self.__closed = True
self.__block_engine_in = None
self.__block_engine_out = None
self.__socket.close()
def set_hexdump(self, hexdump):
self.__dump_packets = hexdump
def get_hexdump(self):
return self.__dump_packets
def get_mac_size_in(self):
return self.__mac_size_in
def get_mac_size_out(self):
return self.__mac_size_out
def need_rekey(self):
"""
Returns C{True} if a new set of keys needs to be negotiated. This
will be triggered during a packet read or write, so it should be
checked after every read or write, or at least after every few.
@return: C{True} if a new set of keys needs to be negotiated
"""
return self.__need_rekey
def set_keepalive(self, interval, callback):
"""
Turn on/off the callback keepalive. If C{interval} seconds pass with
no data read from or written to the socket, the callback will be
executed and the timer will be reset.
"""
self.__keepalive_interval = interval
self.__keepalive_callback = callback
self.__keepalive_last = time.time()
self._log(DEBUG, 'SET KEEPALIVE %r' % interval)
def read_all(self, n):
"""
Read as close to N bytes as possible, blocking as long as necessary.
@param n: number of bytes to read
@type n: int
@return: the data read
@rtype: str
@throw EOFError: if the socket was closed before all the bytes could
be read
"""
if PY22:
return self._py22_read_all(n)
out = ''
while n > 0:
try:
x = self.__socket.recv(n)
if len(x) == 0:
raise EOFError()
out += x
n -= len(x)
except socket.timeout:
if self.__closed:
raise EOFError()
self._check_keepalive()
return out
def write_all(self, out):
self.__keepalive_last = time.time()
while len(out) > 0:
try:
n = self.__socket.send(out)
except socket.timeout:
n = 0
if self.__closed:
n = -1
except Exception, x:
# could be: (32, 'Broken pipe')
n = -1
if n < 0:
raise EOFError()
if n == len(out):
return
out = out[n:]
return
def readline(self, timeout):
"""
Read a line from the socket. This is done in a fairly inefficient
way, but is only used for initial banner negotiation so it's not worth
optimising.
"""
buffer = ''
while not '\n' in buffer:
buffer += self._read_timeout(timeout)
buffer = buffer[:-1]
if (len(buffer) > 0) and (buffer[-1] == '\r'):
buffer = buffer[:-1]
return buffer
def send_message(self, data):
"""
Write a block of data using the current cipher, as an SSH block.
"""
# encrypt this sucka
data = str(data)
cmd = ord(data[0])
if cmd in MSG_NAMES:
cmd_name = MSG_NAMES[cmd]
else:
cmd_name = '$%x' % cmd
self._log(DEBUG, 'Write packet <%s>, length %d' % (cmd_name, len(data)))
packet = self._build_packet(data)
if self.__dump_packets:
self._log(DEBUG, util.format_binary(packet, 'OUT: '))
if self.__block_engine_out != None:
out = self.__block_engine_out.encrypt(packet)
else:
out = packet
# + mac
self.__write_lock.acquire()
try:
if self.__block_engine_out != None:
payload = struct.pack('>I', self.__sequence_number_out) + packet
out += HMAC.HMAC(self.__mac_key_out, payload, self.__mac_engine_out).digest()[:self.__mac_size_out]
self.__sequence_number_out = (self.__sequence_number_out + 1) & 0xffffffffL
self.write_all(out)
self.__sent_bytes += len(out)
self.__sent_packets += 1
if ((self.__sent_packets >= self.REKEY_PACKETS) or (self.__sent_bytes >= self.REKEY_BYTES)) \
and not self.__need_rekey:
# only ask once for rekeying
self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes sent)' %
(self.__sent_packets, self.__sent_bytes))
self.__received_packets_overflow = 0
self._trigger_rekey()
finally:
self.__write_lock.release()
def read_message(self):
"""
Only one thread should ever be in this function (no other locking is
done).
@throw SSHException: if the packet is mangled
"""
header = self.read_all(self.__block_size_in)
if self.__block_engine_in != None:
header = self.__block_engine_in.decrypt(header)
if self.__dump_packets:
self._log(DEBUG, util.format_binary(header, 'IN: '));
packet_size = struct.unpack('>I', header[:4])[0]
# leftover contains decrypted bytes from the first block (after the length field)
leftover = header[4:]
if (packet_size - len(leftover)) % self.__block_size_in != 0:
raise SSHException('Invalid packet blocking')
buffer = self.read_all(packet_size + self.__mac_size_in - len(leftover))
packet = buffer[:packet_size - len(leftover)]
post_packet = buffer[packet_size - len(leftover):]
if self.__block_engine_in != None:
packet = self.__block_engine_in.decrypt(packet)
if self.__dump_packets:
self._log(DEBUG, util.format_binary(packet, 'IN: '));
packet = leftover + packet
if self.__mac_size_in > 0:
mac = post_packet[:self.__mac_size_in]
mac_payload = struct.pack('>II', self.__sequence_number_in, packet_size) + packet
my_mac = HMAC.HMAC(self.__mac_key_in, mac_payload, self.__mac_engine_in).digest()[:self.__mac_size_in]
if my_mac != mac:
raise SSHException('Mismatched MAC')
padding = ord(packet[0])
payload = packet[1:packet_size - padding + 1]
randpool.add_event(packet[packet_size - padding + 1])
if self.__dump_packets:
self._log(DEBUG, 'Got payload (%d bytes, %d padding)' % (packet_size, padding))
msg = Message(payload[1:])
msg.seqno = self.__sequence_number_in
self.__sequence_number_in = (self.__sequence_number_in + 1) & 0xffffffffL
# check for rekey
self.__received_bytes += packet_size + self.__mac_size_in + 4
self.__received_packets += 1
if self.__need_rekey:
# we've asked to rekey -- give them 20 packets to comply before
# dropping the connection
self.__received_packets_overflow += 1
if self.__received_packets_overflow >= 20:
raise SSHException('Remote transport is ignoring rekey requests')
elif (self.__received_packets >= self.REKEY_PACKETS) or \
(self.__received_bytes >= self.REKEY_BYTES):
# only ask once for rekeying
self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes received)' %
(self.__received_packets, self.__received_bytes))
self.__received_packets_overflow = 0
self._trigger_rekey()
cmd = ord(payload[0])
if cmd in MSG_NAMES:
cmd_name = MSG_NAMES[cmd]
else:
cmd_name = '$%x' % cmd
self._log(DEBUG, 'Read packet <%s>, length %d' % (cmd_name, len(payload)))
return cmd, msg
########## protected
def _log(self, level, msg):
if self.__logger is None:
return
if issubclass(type(msg), list):
for m in msg:
self.__logger.log(level, m)
else:
self.__logger.log(level, msg)
def _check_keepalive(self):
if (not self.__keepalive_interval) or (not self.__block_engine_out) or \
self.__need_rekey:
# wait till we're encrypting, and not in the middle of rekeying
return
now = time.time()
if now > self.__keepalive_last + self.__keepalive_interval:
self.__keepalive_callback()
self.__keepalive_last = now
def _py22_read_all(self, n):
out = ''
while n > 0:
r, w, e = select.select([self.__socket], [], [], 0.1)
if self.__socket not in r:
if self.__closed:
raise EOFError()
self._check_keepalive()
else:
x = self.__socket.recv(n)
if len(x) == 0:
raise EOFError()
out += x
n -= len(x)
return out
def _py22_read_timeout(self, timeout):
start = time.time()
while True:
r, w, e = select.select([self.__socket], [], [], 0.1)
if self.__socket in r:
x = self.__socket.recv(1)
if len(x) == 0:
raise EOFError()
return x
if self.__closed:
raise EOFError()
now = time.time()
if now - start >= timeout:
raise socket.timeout()
def _read_timeout(self, timeout):
if PY22:
return self._py22_read_timeout(n)
start = time.time()
while True:
try:
x = self.__socket.recv(1)
if len(x) == 0:
raise EOFError()
return x
except socket.timeout:
pass
if self.__closed:
raise EOFError()
now = time.time()
if now - start >= timeout:
raise socket.timeout()
def _build_packet(self, payload):
# pad up at least 4 bytes, to nearest block-size (usually 8)
bsize = self.__block_size_out
padding = 3 + bsize - ((len(payload) + 8) % bsize)
packet = struct.pack('>IB', len(payload) + padding + 1, padding)
packet += payload
packet += randpool.get_bytes(padding)
return packet
def _trigger_rekey(self):
# outside code should check for this flag
self.__need_rekey = True

View File

@ -53,7 +53,7 @@ class SFTPClient (BaseSFTP):
transport = self.sock.get_transport()
self.logger = util.get_logger(transport.get_log_channel() + '.' +
self.sock.get_name() + '.sftp')
self.ultra_debug = transport.ultra_debug
self.ultra_debug = transport.get_hexdump()
self._send_version()
def from_transport(selfclass, t):

View File

@ -60,7 +60,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
transport = channel.get_transport()
self.logger = util.get_logger(transport.get_log_channel() + '.' +
channel.get_name() + '.sftp')
self.ultra_debug = transport.ultra_debug
self.ultra_debug = transport.get_hexdump()
self.next_handle = 1
# map of handle-string to SFTPHandle for files & folders:
self.file_table = { }

View File

@ -1,5 +1,3 @@
#!/usr/bin/python
# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
@ -30,6 +28,7 @@ from message import Message
from channel import Channel
from sftp_client import SFTPClient
import util
from packet import Packetizer
from rsakey import RSAKey
from dsskey import DSSKey
from kex_group1 import KexGroup1
@ -49,7 +48,7 @@ from Crypto.Hash import SHA, MD5, HMAC
_active_threads = []
def _join_lingering_threads():
for thr in _active_threads:
thr.active = False
thr.stop_thread()
import atexit
atexit.register(_join_lingering_threads)
@ -162,10 +161,6 @@ class BaseTransport (threading.Thread):
'diffie-hellman-group-exchange-sha1': KexGex,
}
# READ the secsh RFC's before raising these values. if anything,
# they should probably be lower.
REKEY_PACKETS = pow(2, 30)
REKEY_BYTES = pow(2, 30)
_modulus_pack = None
@ -222,42 +217,29 @@ class BaseTransport (threading.Thread):
except AttributeError:
pass
# negotiated crypto parameters
self.packetizer = Packetizer(sock)
self.local_version = 'SSH-' + self._PROTO_ID + '-' + self._CLIENT_ID
self.remote_version = ''
self.block_size_out = self.block_size_in = 8
self.local_mac_len = self.remote_mac_len = 0
self.engine_in = self.engine_out = None
self.local_cipher = self.remote_cipher = ''
self.sequence_number_in = self.sequence_number_out = 0L
self.local_kex_init = self.remote_kex_init = None
self.session_id = None
# /negotiated crypto parameters
self.expected_packet = 0
self.active = False
self.initial_kex_done = False
self.write_lock = threading.RLock() # lock around outbound writes (packet computation)
self.lock = threading.Lock() # synchronization (always higher level than write_lock)
self.channels = { } # (id -> Channel)
self.channel_events = { } # (id -> Event)
self.lock = threading.Lock() # synchronization (always higher level than write_lock)
self.channels = { } # (id -> Channel)
self.channel_events = { } # (id -> Event)
self.channel_counter = 1
self.window_size = 65536
self.max_packet_size = 32768
self.ultra_debug = False
self.saved_exception = None
self.clear_to_send = threading.Event()
self.log_name = 'paramiko.transport'
self.logger = util.get_logger(self.log_name)
# used for noticing when to re-key:
self.received_bytes = 0
self.received_packets = 0
self.received_packets_overflow = 0
self.sent_bytes = 0
self.sent_packets = 0
self.packetizer.set_log(self.logger)
# user-defined event callbacks:
self.completion_event = None
# keepalives:
self.keepalive_interval = 0
self.keepalive_last = time.time()
# server mode:
self.server_mode = False
self.server_object = None
@ -293,7 +275,7 @@ class BaseTransport (threading.Thread):
preference for them.
@return: an object that can be used to change the preferred algorithms
for encryption, digest (hash), public key, and key exchange.
for encryption, digest (hash), public key, and key exchange.
@rtype: L{SecurityOptions}
@since: ivysaur
@ -316,8 +298,8 @@ class BaseTransport (threading.Thread):
@note: L{connect} is a simpler method for connecting as a client.
@note: After calling this method (or L{start_server} or L{connect}),
you should no longer directly read from or write to the original socket
object.
you should no longer directly read from or write to the original
socket object.
@param event: an event to trigger when negotiation is complete.
@type event: threading.Event
@ -350,13 +332,13 @@ class BaseTransport (threading.Thread):
given C{server} object to allow channels to be opened.
@note: After calling this method (or L{start_client} or L{connect}),
you should no longer directly read from or write to the original socket
object.
you should no longer directly read from or write to the original
socket object.
@param event: an event to trigger when negotiation is complete.
@type event: threading.Event
@param server: an object used to perform authentication and create
L{Channel}s.
L{Channel}s.
@type server: L{server.ServerInterface}
"""
if server is None:
@ -376,7 +358,7 @@ class BaseTransport (threading.Thread):
key info, not just the public half.
@param key: the host key to add, usually an L{RSAKey <rsakey.RSAKey>} or
L{DSSKey <dsskey.DSSKey>}.
L{DSSKey <dsskey.DSSKey>}.
@type key: L{PKey <pkey.PKey>}
"""
self.server_key_dict[key.get_name()] = key
@ -418,10 +400,10 @@ class BaseTransport (threading.Thread):
support that method of key negotiation.
@param filename: optional path to the moduli file, if you happen to
know that it's not in a standard location.
know that it's not in a standard location.
@type filename: str
@return: True if a moduli file was successfully loaded; False
otherwise.
otherwise.
@rtype: bool
@since: doduo
@ -449,8 +431,7 @@ class BaseTransport (threading.Thread):
Close this session, and any open channels that are tied to it.
"""
self.active = False
self.engine_in = self.engine_out = None
self.sequence_number_in = self.sequence_number_out = 0L
self.packetizer.close()
for chan in self.channels.values():
chan._unlink()
@ -459,9 +440,9 @@ class BaseTransport (threading.Thread):
Return the host key of the server (in client mode).
@note: Previously this call returned a tuple of (key type, key string).
You can get the same effect by calling
L{PKey.get_name <pkey.PKey.get_name>} for the key type, and C{str(key)}
for the key string.
You can get the same effect by calling
L{PKey.get_name <pkey.PKey.get_name>} for the key type, and
C{str(key)} for the key string.
@raise SSHException: if no session is currently active.
@ -476,7 +457,8 @@ class BaseTransport (threading.Thread):
"""
Return true if this session is active (open).
@return: True if the session is still active (open); False if the session is closed.
@return: True if the session is still active (open); False if the
session is closed.
@rtype: bool
"""
return self.active
@ -487,7 +469,7 @@ class BaseTransport (threading.Thread):
is just an alias for C{open_channel('session')}.
@return: a new L{Channel} on success, or C{None} if the request is
rejected or the session ends prematurely.
rejected or the session ends prematurely.
@rtype: L{Channel}
"""
return self.open_channel('session')
@ -500,17 +482,17 @@ class BaseTransport (threading.Thread):
L{connect} or L{start_client}) and authenticating.
@param kind: the kind of channel requested (usually C{"session"},
C{"forwarded-tcpip"} or C{"direct-tcpip"}).
C{"forwarded-tcpip"} or C{"direct-tcpip"}).
@type kind: str
@param dest_addr: the destination address of this port forwarding,
if C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"} (ignored
for other channel types).
if C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"} (ignored
for other channel types).
@type dest_addr: (str, int)
@param src_addr: the source address of this port forwarding, if
C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"}.
C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"}.
@type src_addr: (str, int)
@return: a new L{Channel} on success, or C{None} if the request is
rejected or the session ends prematurely.
rejected or the session ends prematurely.
@rtype: L{Channel}
"""
chan = None
@ -599,7 +581,7 @@ class BaseTransport (threading.Thread):
session has died mid-negotiation.
@return: True if the renegotiation was successful, and the link is
using new keys; False if the session dropped during renegotiation.
using new keys; False if the session dropped during renegotiation.
@rtype: bool
"""
self.completion_event = threading.Event()
@ -620,12 +602,13 @@ class BaseTransport (threading.Thread):
can be useful to keep connections alive over a NAT, for example.
@param interval: seconds to wait before sending a keepalive packet (or
0 to disable keepalives).
0 to disable keepalives).
@type interval: int
@since: fearow
"""
self.keepalive_interval = interval
self.packetizer.set_keepalive(interval,
lambda x=self: x.global_request('keepalive@lag.net', wait=False))
def global_request(self, kind, data=None, wait=True):
"""
@ -635,14 +618,14 @@ class BaseTransport (threading.Thread):
@param kind: name of the request.
@type kind: str
@param data: an optional tuple containing additional data to attach
to the request.
to the request.
@type data: tuple
@param wait: C{True} if this method should not return until a response
is received; C{False} otherwise.
is received; C{False} otherwise.
@type wait: bool
@return: a L{Message} containing possible additional data if the
request was successful (or an empty L{Message} if C{wait} was
C{False}); C{None} if the request was denied.
request was successful (or an empty L{Message} if C{wait} was
C{False}); C{None} if the request was denied.
@rtype: L{Message}
@since: fearow
@ -833,7 +816,23 @@ class BaseTransport (threading.Thread):
C{False} otherwise.
@type hexdump: bool
"""
self.ultra_debug = hexdump
self.packetizer.set_hexdump(hexdump)
def get_hexdump(self):
"""
Return C{True} if the transport is currently logging hex dumps of
protocol traffic.
@return: C{True} if hex dumps are being logged
@rtype: bool
@since: 1.4
"""
return self.packetizer.get_hexdump()
def stop_thread(self):
self.active = False
self.packetizer.close()
### internals...
@ -859,113 +858,10 @@ class BaseTransport (threading.Thread):
finally:
self.lock.release()
def _check_keepalive(self):
if (not self.keepalive_interval) or (not self.initial_kex_done):
return
now = time.time()
if now > self.keepalive_last + self.keepalive_interval:
self.global_request('keepalive@lag.net', wait=False)
def _py22_read_all(self, n):
out = ''
while n > 0:
r, w, e = select.select([self.sock], [], [], 0.1)
if self.sock not in r:
if not self.active:
raise EOFError()
self._check_keepalive()
else:
x = self.sock.recv(n)
if len(x) == 0:
raise EOFError()
out += x
n -= len(x)
return out
def _read_all(self, n):
if PY22:
return self._py22_read_all(n)
out = ''
while n > 0:
try:
x = self.sock.recv(n)
if len(x) == 0:
raise EOFError()
out += x
n -= len(x)
except socket.timeout:
if not self.active:
raise EOFError()
self._check_keepalive()
return out
def _write_all(self, out):
self.keepalive_last = time.time()
while len(out) > 0:
try:
n = self.sock.send(out)
except socket.timeout:
n = 0
if not self.active:
n = -1
except Exception, x:
# could be: (32, 'Broken pipe')
n = -1
if n < 0:
raise EOFError()
if n == len(out):
return
out = out[n:]
return
def _build_packet(self, payload):
# pad up at least 4 bytes, to nearest block-size (usually 8)
bsize = self.block_size_out
padding = 3 + bsize - ((len(payload) + 8) % bsize)
packet = struct.pack('>I', len(payload) + padding + 1)
packet += chr(padding)
packet += payload
packet += randpool.get_bytes(padding)
return packet
def _send_message(self, data):
# encrypt this sucka
data = str(data)
cmd = ord(data[0])
if cmd in MSG_NAMES:
cmd_name = MSG_NAMES[cmd]
else:
cmd_name = '$%x' % cmd
self._log(DEBUG, 'Write packet <%s>, length %d' % (cmd_name, len(data)))
packet = self._build_packet(data)
if self.ultra_debug:
self._log(DEBUG, util.format_binary(packet, 'OUT: '))
if self.engine_out != None:
out = self.engine_out.encrypt(packet)
else:
out = packet
# + mac
try:
self.write_lock.acquire()
if self.engine_out != None:
payload = struct.pack('>I', self.sequence_number_out) + packet
out += HMAC.HMAC(self.mac_key_out, payload, self.local_mac_engine).digest()[:self.local_mac_len]
self.sequence_number_out += 1L
self.sequence_number_out %= 0x100000000L
self._write_all(out)
self.sent_bytes += len(out)
self.sent_packets += 1
if ((self.sent_packets >= self.REKEY_PACKETS) or (self.sent_bytes >= self.REKEY_BYTES)) \
and (self.local_kex_init is None):
# only ask once for rekeying
self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes sent)' %
(self.sent_packets, self.sent_bytes))
self.received_packets_overflow = 0
# this may do a recursive lock, but that's okay:
self._send_kex_init()
finally:
self.write_lock.release()
self.packetizer.send_message(data)
if self.packetizer.need_rekey():
self._send_kex_init()
def _send_user_message(self, data):
"""
@ -981,65 +877,6 @@ class BaseTransport (threading.Thread):
break
self._send_message(data)
def _read_message(self):
"only one thread will ever be in this function"
header = self._read_all(self.block_size_in)
if self.engine_in != None:
header = self.engine_in.decrypt(header)
if self.ultra_debug:
self._log(DEBUG, util.format_binary(header, 'IN: '));
packet_size = struct.unpack('>I', header[:4])[0]
# leftover contains decrypted bytes from the first block (after the length field)
leftover = header[4:]
if (packet_size - len(leftover)) % self.block_size_in != 0:
raise SSHException('Invalid packet blocking')
buffer = self._read_all(packet_size + self.remote_mac_len - len(leftover))
packet = buffer[:packet_size - len(leftover)]
post_packet = buffer[packet_size - len(leftover):]
if self.engine_in != None:
packet = self.engine_in.decrypt(packet)
if self.ultra_debug:
self._log(DEBUG, util.format_binary(packet, 'IN: '));
packet = leftover + packet
if self.remote_mac_len > 0:
mac = post_packet[:self.remote_mac_len]
mac_payload = struct.pack('>II', self.sequence_number_in, packet_size) + packet
my_mac = HMAC.HMAC(self.mac_key_in, mac_payload, self.remote_mac_engine).digest()[:self.remote_mac_len]
if my_mac != mac:
raise SSHException('Mismatched MAC')
padding = ord(packet[0])
payload = packet[1:packet_size - padding + 1]
randpool.add_event(packet[packet_size - padding + 1])
if self.ultra_debug:
self._log(DEBUG, 'Got payload (%d bytes, %d padding)' % (packet_size, padding))
msg = Message(payload[1:])
msg.seqno = self.sequence_number_in
self.sequence_number_in = (self.sequence_number_in + 1) & 0xffffffffL
# check for rekey
self.received_bytes += packet_size + self.remote_mac_len + 4
self.received_packets += 1
if self.local_kex_init is not None:
# we've asked to rekey -- give them 20 packets to comply before
# dropping the connection
self.received_packets_overflow += 1
if self.received_packets_overflow >= 20:
raise SSHException('Remote transport is ignoring rekey requests')
elif (self.received_packets >= self.REKEY_PACKETS) or \
(self.received_bytes >= self.REKEY_BYTES):
# only ask once for rekeying
self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes received)' %
(self.received_packets, self.received_bytes))
self.received_packets_overflow = 0
self._send_kex_init()
cmd = ord(payload[0])
if cmd in MSG_NAMES:
cmd_name = MSG_NAMES[cmd]
else:
cmd_name = '$%x' % cmd
self._log(DEBUG, 'Read packet <%s>, length %d' % (cmd_name, len(payload)))
return cmd, msg
def _set_K_H(self, k, h):
"used by a kex object to set the K (root key) and H (exchange hash)"
self.K = k
@ -1090,18 +927,21 @@ class BaseTransport (threading.Thread):
else:
self._log(DEBUG, 'starting thread (client mode): %s' % hex(long(id(self)) & 0xffffffffL))
try:
self._write_all(self.local_version + '\r\n')
self.packetizer.write_all(self.local_version + '\r\n')
self._check_banner()
self._send_kex_init()
self.expected_packet = MSG_KEXINIT
while self.active:
ptype, m = self._read_message()
if self.packetizer.need_rekey():
self._send_kex_init()
ptype, m = self.packetizer.read_message()
if ptype == MSG_IGNORE:
continue
elif ptype == MSG_DISCONNECT:
self._parse_disconnect(m)
self.active = False
self.packetizer.close()
break
elif ptype == MSG_DEBUG:
self._parse_debug(m)
@ -1123,6 +963,7 @@ class BaseTransport (threading.Thread):
else:
self._log(ERROR, 'Channel request for unknown channel %d' % chanid)
self.active = False
self.packetizer.close()
else:
self._log(WARNING, 'Oops, unhandled type %d' % ptype)
msg = Message()
@ -1146,6 +987,7 @@ class BaseTransport (threading.Thread):
chan._unlink()
if self.active:
self.active = False
self.packetizer.close()
if self.completion_event != None:
self.completion_event.set()
if self.auth_event != None:
@ -1170,12 +1012,15 @@ class BaseTransport (threading.Thread):
def _check_banner(self):
# this is slow, but we only have to do it once
for i in range(5):
buffer = ''
while not '\n' in buffer:
buffer += self._read_all(1)
buffer = buffer[:-1]
if (len(buffer) > 0) and (buffer[-1] == '\r'):
buffer = buffer[:-1]
# give them 5 seconds for the first line, then just 2 seconds each additional line
if i == 0:
timeout = 5
else:
timeout = 2
try:
buffer = self.packetizer.readline(timeout)
except Exception, x:
raise SSHException('Error reading SSH protocol banner' + str(x))
if buffer[:4] == 'SSH-':
break
self._log(DEBUG, 'Banner: ' + buffer)
@ -1236,13 +1081,6 @@ class BaseTransport (threading.Thread):
self._send_message(m)
def _parse_kex_init(self, m):
# reset counters of when to re-key, since we are now re-keying
self.received_bytes = 0
self.received_packets = 0
self.received_packets_overflow = 0
self.sent_bytes = 0
self.sent_packets = 0
cookie = m.get_bytes(16)
kex_algo_list = m.get_list()
server_key_algo_list = m.get_list()
@ -1334,44 +1172,46 @@ class BaseTransport (threading.Thread):
def _activate_inbound(self):
"switch on newly negotiated encryption parameters for inbound traffic"
self.block_size_in = self._cipher_info[self.remote_cipher]['block-size']
block_size = self._cipher_info[self.remote_cipher]['block-size']
if self.server_mode:
IV_in = self._compute_key('A', self.block_size_in)
IV_in = self._compute_key('A', block_size)
key_in = self._compute_key('C', self._cipher_info[self.remote_cipher]['key-size'])
else:
IV_in = self._compute_key('B', self.block_size_in)
IV_in = self._compute_key('B', block_size)
key_in = self._compute_key('D', self._cipher_info[self.remote_cipher]['key-size'])
self.engine_in = self._get_cipher(self.remote_cipher, key_in, IV_in)
self.remote_mac_len = self._mac_info[self.remote_mac]['size']
self.remote_mac_engine = self._mac_info[self.remote_mac]['class']
engine = self._get_cipher(self.remote_cipher, key_in, IV_in)
mac_size = self._mac_info[self.remote_mac]['size']
mac_engine = self._mac_info[self.remote_mac]['class']
# initial mac keys are done in the hash's natural size (not the potentially truncated
# transmission size)
if self.server_mode:
self.mac_key_in = self._compute_key('E', self.remote_mac_engine.digest_size)
mac_key = self._compute_key('E', mac_engine.digest_size)
else:
self.mac_key_in = self._compute_key('F', self.remote_mac_engine.digest_size)
mac_key = self._compute_key('F', mac_engine.digest_size)
self.packetizer.set_inbound_cipher(engine, block_size, mac_engine, mac_size, mac_key)
def _activate_outbound(self):
"switch on newly negotiated encryption parameters for outbound traffic"
m = Message()
m.add_byte(chr(MSG_NEWKEYS))
self._send_message(m)
self.block_size_out = self._cipher_info[self.local_cipher]['block-size']
block_size = self._cipher_info[self.local_cipher]['block-size']
if self.server_mode:
IV_out = self._compute_key('B', self.block_size_out)
IV_out = self._compute_key('B', block_size)
key_out = self._compute_key('D', self._cipher_info[self.local_cipher]['key-size'])
else:
IV_out = self._compute_key('A', self.block_size_out)
IV_out = self._compute_key('A', block_size)
key_out = self._compute_key('C', self._cipher_info[self.local_cipher]['key-size'])
self.engine_out = self._get_cipher(self.local_cipher, key_out, IV_out)
self.local_mac_len = self._mac_info[self.local_mac]['size']
self.local_mac_engine = self._mac_info[self.local_mac]['class']
engine = self._get_cipher(self.local_cipher, key_out, IV_out)
mac_size = self._mac_info[self.local_mac]['size']
mac_engine = self._mac_info[self.local_mac]['class']
# initial mac keys are done in the hash's natural size (not the potentially truncated
# transmission size)
if self.server_mode:
self.mac_key_out = self._compute_key('F', self.local_mac_engine.digest_size)
mac_key = self._compute_key('F', mac_engine.digest_size)
else:
self.mac_key_out = self._compute_key('E', self.local_mac_engine.digest_size)
mac_key = self._compute_key('E', mac_engine.digest_size)
self.packetizer.set_outbound_cipher(engine, block_size, mac_engine, mac_size, mac_key)
# we always expect to receive NEWKEYS now
self.expected_packet = MSG_NEWKEYS

View File

@ -70,6 +70,9 @@ options, args = parser.parse_args()
if len(args) > 0:
parser.error('unknown argument(s)')
# setup logging
paramiko.util.log_to_file('test.log')
if options.use_sftp:
if options.use_loopback_sftp:
SFTPTest.init_loopback()
@ -78,9 +81,6 @@ if options.use_sftp:
if not options.use_big_file:
SFTPTest.set_big_file_test(False)
# setup logging
paramiko.util.log_to_file('test.log')
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(MessageTest))
suite.addTest(unittest.makeSuite(BufferedFileTest))

View File

@ -1,5 +1,3 @@
#!/usr/bin/python
# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
@ -145,7 +143,7 @@ class SFTPTest (unittest.TestCase):
finally:
sftp.remove(FOLDER + '/test')
def test_1a_close(self):
def test_2_close(self):
"""
verify that closing the sftp session doesn't do anything bad, and that
a new one can be opened.
@ -159,7 +157,7 @@ class SFTPTest (unittest.TestCase):
pass
sftp = paramiko.SFTP.from_transport(tc)
def test_2_write(self):
def test_3_write(self):
"""
verify that a file can be created and written, and the size is correct.
"""
@ -171,7 +169,7 @@ class SFTPTest (unittest.TestCase):
finally:
sftp.remove(FOLDER + '/duck.txt')
def test_3_append(self):
def test_4_append(self):
"""
verify that a file can be opened for append, and tell() still works.
"""
@ -191,7 +189,7 @@ class SFTPTest (unittest.TestCase):
finally:
sftp.remove(FOLDER + '/append.txt')
def test_4_rename(self):
def test_5_rename(self):
"""
verify that renaming a file works.
"""
@ -219,7 +217,7 @@ class SFTPTest (unittest.TestCase):
except:
pass
def test_5_folder(self):
def test_6_folder(self):
"""
create a temporary folder, verify that we can create a file in it, then
remove the folder and verify that we can't create a file in it anymore.
@ -236,7 +234,7 @@ class SFTPTest (unittest.TestCase):
except IOError:
pass
def test_6_listdir(self):
def test_7_listdir(self):
"""
verify that a folder can be created, a bunch of files can be placed in it,
and those files show up in sftp.listdir.
@ -262,7 +260,7 @@ class SFTPTest (unittest.TestCase):
sftp.remove(FOLDER + '/fish.txt')
sftp.remove(FOLDER + '/tertiary.py')
def test_7_setstat(self):
def test_8_setstat(self):
"""
verify that the setstat functions (chown, chmod, utime) work.
"""
@ -285,7 +283,7 @@ class SFTPTest (unittest.TestCase):
finally:
sftp.remove(FOLDER + '/special')
def test_8_readline_seek(self):
def test_9_readline_seek(self):
"""
create a text file and write a bunch of text into it. then count the lines
in the file, and seek around to retreive particular lines. this should
@ -315,7 +313,7 @@ class SFTPTest (unittest.TestCase):
finally:
sftp.remove(FOLDER + '/duck.txt')
def test_9_write_seek(self):
def test_A_write_seek(self):
"""
create a text file, seek back and change part of it, and verify that the
changes worked.
@ -335,7 +333,7 @@ class SFTPTest (unittest.TestCase):
finally:
sftp.remove(FOLDER + '/testing.txt')
def test_A_symlink(self):
def test_B_symlink(self):
"""
create a symlink and then check that lstat doesn't follow it.
"""
@ -378,7 +376,7 @@ class SFTPTest (unittest.TestCase):
except:
pass
def test_B_flush_seek(self):
def test_C_flush_seek(self):
"""
verify that buffered writes are automatically flushed on seek.
"""
@ -400,7 +398,7 @@ class SFTPTest (unittest.TestCase):
except:
pass
def test_C_lots_of_files(self):
def test_D_lots_of_files(self):
"""
create a bunch of files over the same session.
"""
@ -431,7 +429,7 @@ class SFTPTest (unittest.TestCase):
except:
pass
def test_D_big_file(self):
def test_E_big_file(self):
"""
write a 1MB file, with no linefeeds, using line buffering.
FIXME: this is slow! what causes the slowness?
@ -453,7 +451,7 @@ class SFTPTest (unittest.TestCase):
finally:
sftp.remove('%s/hongry.txt' % FOLDER)
def test_E_big_file_big_buffer(self):
def test_F_big_file_big_buffer(self):
"""
write a 1MB file, with no linefeeds, and a big buffer.
"""
@ -470,7 +468,7 @@ class SFTPTest (unittest.TestCase):
finally:
sftp.remove('%s/hongry.txt' % FOLDER)
def test_F_realpath(self):
def test_G_realpath(self):
"""
test that realpath is returning something non-empty and not an
error.
@ -481,7 +479,7 @@ class SFTPTest (unittest.TestCase):
self.assert_(len(f) > 0)
self.assertEquals(os.path.join(pwd, FOLDER), f)
def test_G_mkdir(self):
def test_H_mkdir(self):
"""
verify that mkdir/rmdir work.
"""

View File

@ -22,7 +22,7 @@
Some unit tests for the ssh2 protocol in Transport.
"""
import sys, unittest, threading
import sys, time, threading, unittest
from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, \
SSHException, BadAuthenticationType
from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL
@ -77,6 +77,10 @@ class NullServer (ServerInterface):
def check_channel_shell_request(self, channel):
return True
def check_global_request(self, kind, msg):
self._global_request = kind
return False
class TransportTest (unittest.TestCase):
@ -160,14 +164,38 @@ class TransportTest (unittest.TestCase):
self.assert_(self.ts.is_active())
self.assertEquals('aes256-cbc', self.tc.local_cipher)
self.assertEquals('aes256-cbc', self.tc.remote_cipher)
self.assertEquals(12, self.tc.local_mac_len)
self.assertEquals(12, self.tc.remote_mac_len)
self.assertEquals(12, self.tc.packetizer.get_mac_size_out())
self.assertEquals(12, self.tc.packetizer.get_mac_size_in())
self.tc.send_ignore(1024)
self.assert_(self.tc.renegotiate_keys())
self.ts.send_ignore(1024)
def test_4_bad_auth_type(self):
def test_4_keepalive(self):
"""
verify that the keepalive will be sent.
"""
self.tc.set_hexdump(True)
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.connect(hostkey=public_host_key,
username='slowdive', password='pygmalion')
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
self.assertEquals(None, getattr(server, '_global_request', None))
self.tc.set_keepalive(1)
time.sleep(2)
self.assertEquals('keepalive@lag.net', server._global_request)
def test_5_bad_auth_type(self):
"""
verify that we get the right exception when an unsupported auth
type is requested.
@ -188,7 +216,7 @@ class TransportTest (unittest.TestCase):
self.assertEquals(BadAuthenticationType, etype)
self.assertEquals(['publickey'], evalue.allowed_types)
def test_5_bad_password(self):
def test_6_bad_password(self):
"""
verify that a bad password gets the right exception, and that a retry
with the right password works.
@ -213,7 +241,7 @@ class TransportTest (unittest.TestCase):
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
def test_6_multipart_auth(self):
def test_7_multipart_auth(self):
"""
verify that multipart auth works.
"""
@ -235,7 +263,7 @@ class TransportTest (unittest.TestCase):
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
def test_7_exec_command(self):
def test_8_exec_command(self):
"""
verify that exec_command() does something reasonable.
"""
@ -285,7 +313,7 @@ class TransportTest (unittest.TestCase):
self.assertEquals('This is on stderr.\n', f.readline())
self.assertEquals('', f.readline())
def test_8_invoke_shell(self):
def test_9_invoke_shell(self):
"""
verify that invoke_shell() does something reasonable.
"""
@ -312,7 +340,7 @@ class TransportTest (unittest.TestCase):
chan.close()
self.assertEquals('', f.readline())
def test_9_exit_status(self):
def test_A_exit_status(self):
"""
verify that get_exit_status() works.
"""