[project @ Arch-1:robey@lag.net--2005-master-shake%paramiko--dev--1--patch-5]

split out Packetizer, fix banner detection bug, new unit test
split out a chunk of BaseTransport into a Packetizer class, which handles
the in/out packet data, ciphers, etc.  it didn't make the code any smaller
(transport.py is still close to 1500 lines, which is awful) but it did split
out a coherent chunk of functionality into a discrete unit.

in the process, fixed a bug that alain spineux pointed out: the banner
check was too forgiving and would block forever waiting for an SSH banner.
now it waits 5 seconds for the first line, and 2 seconds for each subsequent
line, before giving up.

added a unit test to test keepalive, since i wasn't sure that was still
working after pulling out Packetizer.
This commit is contained in:
Robey Pointer 2005-05-01 08:04:59 +00:00
parent 2f2d7bdee8
commit 36055c5ac2
9 changed files with 553 additions and 285 deletions

2
README
View File

@ -231,3 +231,5 @@ v0.9 FEAROW
* would be nice to have an ftp-like interface to sftp (put, get, chdir...) * would be nice to have an ftp-like interface to sftp (put, get, chdir...)
* speed up file transfers!
* what is psyco?

View File

@ -298,7 +298,6 @@ class Channel (object):
m.add_boolean(0) m.add_boolean(0)
m.add_int(status) m.add_int(status)
self.transport._send_user_message(m) self.transport._send_user_message(m)
self._log(DEBUG, 'EXIT-STATUS')
def get_transport(self): def get_transport(self):
""" """
@ -468,7 +467,7 @@ class Channel (object):
it means you may need to wait before more data arrives. it means you may need to wait before more data arrives.
@return: C{True} if a L{recv} call on this channel would immediately @return: C{True} if a L{recv} call on this channel would immediately
return at least one byte; C{False} otherwise. return at least one byte; C{False} otherwise.
@rtype: boolean @rtype: boolean
""" """
self.lock.acquire() self.lock.acquire()
@ -492,7 +491,7 @@ class Channel (object):
@rtype: str @rtype: str
@raise socket.timeout: if no data is ready before the timeout set by @raise socket.timeout: if no data is ready before the timeout set by
L{settimeout}. L{settimeout}.
""" """
out = '' out = ''
self.lock.acquire() self.lock.acquire()

401
paramiko/packet.py Normal file
View File

@ -0,0 +1,401 @@
#!/usr/bin/python
# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
#
# This file is part of paramiko.
#
# Paramiko is free software; you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 2.1 of the License, or (at your option)
# any later version.
#
# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
"""
Packetizer.
"""
import select, socket, struct, threading, time
from Crypto.Hash import HMAC
from common import *
from message import Message
import util
class Packetizer (object):
"""
Implementation of the base SSH packet protocol.
"""
# READ the secsh RFC's before raising these values. if anything,
# they should probably be lower.
REKEY_PACKETS = pow(2, 30)
REKEY_BYTES = pow(2, 30)
def __init__(self, socket):
self.__socket = socket
self.__logger = None
self.__closed = False
self.__dump_packets = False
self.__need_rekey = False
# used for noticing when to re-key:
self.__sent_bytes = 0
self.__sent_packets = 0
self.__received_bytes = 0
self.__received_packets = 0
self.__received_packets_overflow = 0
# current inbound/outbound ciphering:
self.__block_size_out = 8
self.__block_size_in = 8
self.__mac_size_out = 0
self.__mac_size_in = 0
self.__block_engine_out = None
self.__block_engine_in = None
self.__mac_engine_out = None
self.__mac_engine_in = None
self.__mac_key_out = ''
self.__mac_key_in = ''
self.__sequence_number_out = 0L
self.__sequence_number_in = 0L
# lock around outbound writes (packet computation)
self.__write_lock = threading.RLock()
# keepalives:
self.__keepalive_interval = 0
self.__keepalive_last = time.time()
self.__keepalive_callback = None
def set_log(self, log):
"""
Set the python log object to use for logging.
"""
self.__logger = log
def set_outbound_cipher(self, block_engine, block_size, mac_engine, mac_size, mac_key):
"""
Switch outbound data cipher.
"""
self.__block_engine_out = block_engine
self.__block_size_out = block_size
self.__mac_engine_out = mac_engine
self.__mac_size_out = mac_size
self.__mac_key_out = mac_key
self.__sent_bytes = 0
self.__sent_packets = 0
self.__need_rekey = False
def set_inbound_cipher(self, block_engine, block_size, mac_engine, mac_size, mac_key):
"""
Switch inbound data cipher.
"""
self.__block_engine_in = block_engine
self.__block_size_in = block_size
self.__mac_engine_in = mac_engine
self.__mac_size_in = mac_size
self.__mac_key_in = mac_key
self.__received_bytes = 0
self.__received_packets = 0
self.__received_packets_overflow = 0
self.__need_rekey = False
def close(self):
self.__closed = True
self.__block_engine_in = None
self.__block_engine_out = None
self.__socket.close()
def set_hexdump(self, hexdump):
self.__dump_packets = hexdump
def get_hexdump(self):
return self.__dump_packets
def get_mac_size_in(self):
return self.__mac_size_in
def get_mac_size_out(self):
return self.__mac_size_out
def need_rekey(self):
"""
Returns C{True} if a new set of keys needs to be negotiated. This
will be triggered during a packet read or write, so it should be
checked after every read or write, or at least after every few.
@return: C{True} if a new set of keys needs to be negotiated
"""
return self.__need_rekey
def set_keepalive(self, interval, callback):
"""
Turn on/off the callback keepalive. If C{interval} seconds pass with
no data read from or written to the socket, the callback will be
executed and the timer will be reset.
"""
self.__keepalive_interval = interval
self.__keepalive_callback = callback
self.__keepalive_last = time.time()
self._log(DEBUG, 'SET KEEPALIVE %r' % interval)
def read_all(self, n):
"""
Read as close to N bytes as possible, blocking as long as necessary.
@param n: number of bytes to read
@type n: int
@return: the data read
@rtype: str
@throw EOFError: if the socket was closed before all the bytes could
be read
"""
if PY22:
return self._py22_read_all(n)
out = ''
while n > 0:
try:
x = self.__socket.recv(n)
if len(x) == 0:
raise EOFError()
out += x
n -= len(x)
except socket.timeout:
if self.__closed:
raise EOFError()
self._check_keepalive()
return out
def write_all(self, out):
self.__keepalive_last = time.time()
while len(out) > 0:
try:
n = self.__socket.send(out)
except socket.timeout:
n = 0
if self.__closed:
n = -1
except Exception, x:
# could be: (32, 'Broken pipe')
n = -1
if n < 0:
raise EOFError()
if n == len(out):
return
out = out[n:]
return
def readline(self, timeout):
"""
Read a line from the socket. This is done in a fairly inefficient
way, but is only used for initial banner negotiation so it's not worth
optimising.
"""
buffer = ''
while not '\n' in buffer:
buffer += self._read_timeout(timeout)
buffer = buffer[:-1]
if (len(buffer) > 0) and (buffer[-1] == '\r'):
buffer = buffer[:-1]
return buffer
def send_message(self, data):
"""
Write a block of data using the current cipher, as an SSH block.
"""
# encrypt this sucka
data = str(data)
cmd = ord(data[0])
if cmd in MSG_NAMES:
cmd_name = MSG_NAMES[cmd]
else:
cmd_name = '$%x' % cmd
self._log(DEBUG, 'Write packet <%s>, length %d' % (cmd_name, len(data)))
packet = self._build_packet(data)
if self.__dump_packets:
self._log(DEBUG, util.format_binary(packet, 'OUT: '))
if self.__block_engine_out != None:
out = self.__block_engine_out.encrypt(packet)
else:
out = packet
# + mac
self.__write_lock.acquire()
try:
if self.__block_engine_out != None:
payload = struct.pack('>I', self.__sequence_number_out) + packet
out += HMAC.HMAC(self.__mac_key_out, payload, self.__mac_engine_out).digest()[:self.__mac_size_out]
self.__sequence_number_out = (self.__sequence_number_out + 1) & 0xffffffffL
self.write_all(out)
self.__sent_bytes += len(out)
self.__sent_packets += 1
if ((self.__sent_packets >= self.REKEY_PACKETS) or (self.__sent_bytes >= self.REKEY_BYTES)) \
and not self.__need_rekey:
# only ask once for rekeying
self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes sent)' %
(self.__sent_packets, self.__sent_bytes))
self.__received_packets_overflow = 0
self._trigger_rekey()
finally:
self.__write_lock.release()
def read_message(self):
"""
Only one thread should ever be in this function (no other locking is
done).
@throw SSHException: if the packet is mangled
"""
header = self.read_all(self.__block_size_in)
if self.__block_engine_in != None:
header = self.__block_engine_in.decrypt(header)
if self.__dump_packets:
self._log(DEBUG, util.format_binary(header, 'IN: '));
packet_size = struct.unpack('>I', header[:4])[0]
# leftover contains decrypted bytes from the first block (after the length field)
leftover = header[4:]
if (packet_size - len(leftover)) % self.__block_size_in != 0:
raise SSHException('Invalid packet blocking')
buffer = self.read_all(packet_size + self.__mac_size_in - len(leftover))
packet = buffer[:packet_size - len(leftover)]
post_packet = buffer[packet_size - len(leftover):]
if self.__block_engine_in != None:
packet = self.__block_engine_in.decrypt(packet)
if self.__dump_packets:
self._log(DEBUG, util.format_binary(packet, 'IN: '));
packet = leftover + packet
if self.__mac_size_in > 0:
mac = post_packet[:self.__mac_size_in]
mac_payload = struct.pack('>II', self.__sequence_number_in, packet_size) + packet
my_mac = HMAC.HMAC(self.__mac_key_in, mac_payload, self.__mac_engine_in).digest()[:self.__mac_size_in]
if my_mac != mac:
raise SSHException('Mismatched MAC')
padding = ord(packet[0])
payload = packet[1:packet_size - padding + 1]
randpool.add_event(packet[packet_size - padding + 1])
if self.__dump_packets:
self._log(DEBUG, 'Got payload (%d bytes, %d padding)' % (packet_size, padding))
msg = Message(payload[1:])
msg.seqno = self.__sequence_number_in
self.__sequence_number_in = (self.__sequence_number_in + 1) & 0xffffffffL
# check for rekey
self.__received_bytes += packet_size + self.__mac_size_in + 4
self.__received_packets += 1
if self.__need_rekey:
# we've asked to rekey -- give them 20 packets to comply before
# dropping the connection
self.__received_packets_overflow += 1
if self.__received_packets_overflow >= 20:
raise SSHException('Remote transport is ignoring rekey requests')
elif (self.__received_packets >= self.REKEY_PACKETS) or \
(self.__received_bytes >= self.REKEY_BYTES):
# only ask once for rekeying
self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes received)' %
(self.__received_packets, self.__received_bytes))
self.__received_packets_overflow = 0
self._trigger_rekey()
cmd = ord(payload[0])
if cmd in MSG_NAMES:
cmd_name = MSG_NAMES[cmd]
else:
cmd_name = '$%x' % cmd
self._log(DEBUG, 'Read packet <%s>, length %d' % (cmd_name, len(payload)))
return cmd, msg
########## protected
def _log(self, level, msg):
if self.__logger is None:
return
if issubclass(type(msg), list):
for m in msg:
self.__logger.log(level, m)
else:
self.__logger.log(level, msg)
def _check_keepalive(self):
if (not self.__keepalive_interval) or (not self.__block_engine_out) or \
self.__need_rekey:
# wait till we're encrypting, and not in the middle of rekeying
return
now = time.time()
if now > self.__keepalive_last + self.__keepalive_interval:
self.__keepalive_callback()
self.__keepalive_last = now
def _py22_read_all(self, n):
out = ''
while n > 0:
r, w, e = select.select([self.__socket], [], [], 0.1)
if self.__socket not in r:
if self.__closed:
raise EOFError()
self._check_keepalive()
else:
x = self.__socket.recv(n)
if len(x) == 0:
raise EOFError()
out += x
n -= len(x)
return out
def _py22_read_timeout(self, timeout):
start = time.time()
while True:
r, w, e = select.select([self.__socket], [], [], 0.1)
if self.__socket in r:
x = self.__socket.recv(1)
if len(x) == 0:
raise EOFError()
return x
if self.__closed:
raise EOFError()
now = time.time()
if now - start >= timeout:
raise socket.timeout()
def _read_timeout(self, timeout):
if PY22:
return self._py22_read_timeout(n)
start = time.time()
while True:
try:
x = self.__socket.recv(1)
if len(x) == 0:
raise EOFError()
return x
except socket.timeout:
pass
if self.__closed:
raise EOFError()
now = time.time()
if now - start >= timeout:
raise socket.timeout()
def _build_packet(self, payload):
# pad up at least 4 bytes, to nearest block-size (usually 8)
bsize = self.__block_size_out
padding = 3 + bsize - ((len(payload) + 8) % bsize)
packet = struct.pack('>IB', len(payload) + padding + 1, padding)
packet += payload
packet += randpool.get_bytes(padding)
return packet
def _trigger_rekey(self):
# outside code should check for this flag
self.__need_rekey = True

View File

@ -53,7 +53,7 @@ class SFTPClient (BaseSFTP):
transport = self.sock.get_transport() transport = self.sock.get_transport()
self.logger = util.get_logger(transport.get_log_channel() + '.' + self.logger = util.get_logger(transport.get_log_channel() + '.' +
self.sock.get_name() + '.sftp') self.sock.get_name() + '.sftp')
self.ultra_debug = transport.ultra_debug self.ultra_debug = transport.get_hexdump()
self._send_version() self._send_version()
def from_transport(selfclass, t): def from_transport(selfclass, t):

View File

@ -60,7 +60,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
transport = channel.get_transport() transport = channel.get_transport()
self.logger = util.get_logger(transport.get_log_channel() + '.' + self.logger = util.get_logger(transport.get_log_channel() + '.' +
channel.get_name() + '.sftp') channel.get_name() + '.sftp')
self.ultra_debug = transport.ultra_debug self.ultra_debug = transport.get_hexdump()
self.next_handle = 1 self.next_handle = 1
# map of handle-string to SFTPHandle for files & folders: # map of handle-string to SFTPHandle for files & folders:
self.file_table = { } self.file_table = { }

View File

@ -1,5 +1,3 @@
#!/usr/bin/python
# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> # Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
# #
# This file is part of paramiko. # This file is part of paramiko.
@ -30,6 +28,7 @@ from message import Message
from channel import Channel from channel import Channel
from sftp_client import SFTPClient from sftp_client import SFTPClient
import util import util
from packet import Packetizer
from rsakey import RSAKey from rsakey import RSAKey
from dsskey import DSSKey from dsskey import DSSKey
from kex_group1 import KexGroup1 from kex_group1 import KexGroup1
@ -49,7 +48,7 @@ from Crypto.Hash import SHA, MD5, HMAC
_active_threads = [] _active_threads = []
def _join_lingering_threads(): def _join_lingering_threads():
for thr in _active_threads: for thr in _active_threads:
thr.active = False thr.stop_thread()
import atexit import atexit
atexit.register(_join_lingering_threads) atexit.register(_join_lingering_threads)
@ -162,10 +161,6 @@ class BaseTransport (threading.Thread):
'diffie-hellman-group-exchange-sha1': KexGex, 'diffie-hellman-group-exchange-sha1': KexGex,
} }
# READ the secsh RFC's before raising these values. if anything,
# they should probably be lower.
REKEY_PACKETS = pow(2, 30)
REKEY_BYTES = pow(2, 30)
_modulus_pack = None _modulus_pack = None
@ -222,42 +217,29 @@ class BaseTransport (threading.Thread):
except AttributeError: except AttributeError:
pass pass
# negotiated crypto parameters # negotiated crypto parameters
self.packetizer = Packetizer(sock)
self.local_version = 'SSH-' + self._PROTO_ID + '-' + self._CLIENT_ID self.local_version = 'SSH-' + self._PROTO_ID + '-' + self._CLIENT_ID
self.remote_version = '' self.remote_version = ''
self.block_size_out = self.block_size_in = 8
self.local_mac_len = self.remote_mac_len = 0
self.engine_in = self.engine_out = None
self.local_cipher = self.remote_cipher = '' self.local_cipher = self.remote_cipher = ''
self.sequence_number_in = self.sequence_number_out = 0L
self.local_kex_init = self.remote_kex_init = None self.local_kex_init = self.remote_kex_init = None
self.session_id = None self.session_id = None
# /negotiated crypto parameters # /negotiated crypto parameters
self.expected_packet = 0 self.expected_packet = 0
self.active = False self.active = False
self.initial_kex_done = False self.initial_kex_done = False
self.write_lock = threading.RLock() # lock around outbound writes (packet computation) self.lock = threading.Lock() # synchronization (always higher level than write_lock)
self.lock = threading.Lock() # synchronization (always higher level than write_lock) self.channels = { } # (id -> Channel)
self.channels = { } # (id -> Channel) self.channel_events = { } # (id -> Event)
self.channel_events = { } # (id -> Event)
self.channel_counter = 1 self.channel_counter = 1
self.window_size = 65536 self.window_size = 65536
self.max_packet_size = 32768 self.max_packet_size = 32768
self.ultra_debug = False
self.saved_exception = None self.saved_exception = None
self.clear_to_send = threading.Event() self.clear_to_send = threading.Event()
self.log_name = 'paramiko.transport' self.log_name = 'paramiko.transport'
self.logger = util.get_logger(self.log_name) self.logger = util.get_logger(self.log_name)
# used for noticing when to re-key: self.packetizer.set_log(self.logger)
self.received_bytes = 0
self.received_packets = 0
self.received_packets_overflow = 0
self.sent_bytes = 0
self.sent_packets = 0
# user-defined event callbacks: # user-defined event callbacks:
self.completion_event = None self.completion_event = None
# keepalives:
self.keepalive_interval = 0
self.keepalive_last = time.time()
# server mode: # server mode:
self.server_mode = False self.server_mode = False
self.server_object = None self.server_object = None
@ -293,7 +275,7 @@ class BaseTransport (threading.Thread):
preference for them. preference for them.
@return: an object that can be used to change the preferred algorithms @return: an object that can be used to change the preferred algorithms
for encryption, digest (hash), public key, and key exchange. for encryption, digest (hash), public key, and key exchange.
@rtype: L{SecurityOptions} @rtype: L{SecurityOptions}
@since: ivysaur @since: ivysaur
@ -316,8 +298,8 @@ class BaseTransport (threading.Thread):
@note: L{connect} is a simpler method for connecting as a client. @note: L{connect} is a simpler method for connecting as a client.
@note: After calling this method (or L{start_server} or L{connect}), @note: After calling this method (or L{start_server} or L{connect}),
you should no longer directly read from or write to the original socket you should no longer directly read from or write to the original
object. socket object.
@param event: an event to trigger when negotiation is complete. @param event: an event to trigger when negotiation is complete.
@type event: threading.Event @type event: threading.Event
@ -350,13 +332,13 @@ class BaseTransport (threading.Thread):
given C{server} object to allow channels to be opened. given C{server} object to allow channels to be opened.
@note: After calling this method (or L{start_client} or L{connect}), @note: After calling this method (or L{start_client} or L{connect}),
you should no longer directly read from or write to the original socket you should no longer directly read from or write to the original
object. socket object.
@param event: an event to trigger when negotiation is complete. @param event: an event to trigger when negotiation is complete.
@type event: threading.Event @type event: threading.Event
@param server: an object used to perform authentication and create @param server: an object used to perform authentication and create
L{Channel}s. L{Channel}s.
@type server: L{server.ServerInterface} @type server: L{server.ServerInterface}
""" """
if server is None: if server is None:
@ -376,7 +358,7 @@ class BaseTransport (threading.Thread):
key info, not just the public half. key info, not just the public half.
@param key: the host key to add, usually an L{RSAKey <rsakey.RSAKey>} or @param key: the host key to add, usually an L{RSAKey <rsakey.RSAKey>} or
L{DSSKey <dsskey.DSSKey>}. L{DSSKey <dsskey.DSSKey>}.
@type key: L{PKey <pkey.PKey>} @type key: L{PKey <pkey.PKey>}
""" """
self.server_key_dict[key.get_name()] = key self.server_key_dict[key.get_name()] = key
@ -418,10 +400,10 @@ class BaseTransport (threading.Thread):
support that method of key negotiation. support that method of key negotiation.
@param filename: optional path to the moduli file, if you happen to @param filename: optional path to the moduli file, if you happen to
know that it's not in a standard location. know that it's not in a standard location.
@type filename: str @type filename: str
@return: True if a moduli file was successfully loaded; False @return: True if a moduli file was successfully loaded; False
otherwise. otherwise.
@rtype: bool @rtype: bool
@since: doduo @since: doduo
@ -449,8 +431,7 @@ class BaseTransport (threading.Thread):
Close this session, and any open channels that are tied to it. Close this session, and any open channels that are tied to it.
""" """
self.active = False self.active = False
self.engine_in = self.engine_out = None self.packetizer.close()
self.sequence_number_in = self.sequence_number_out = 0L
for chan in self.channels.values(): for chan in self.channels.values():
chan._unlink() chan._unlink()
@ -459,9 +440,9 @@ class BaseTransport (threading.Thread):
Return the host key of the server (in client mode). Return the host key of the server (in client mode).
@note: Previously this call returned a tuple of (key type, key string). @note: Previously this call returned a tuple of (key type, key string).
You can get the same effect by calling You can get the same effect by calling
L{PKey.get_name <pkey.PKey.get_name>} for the key type, and C{str(key)} L{PKey.get_name <pkey.PKey.get_name>} for the key type, and
for the key string. C{str(key)} for the key string.
@raise SSHException: if no session is currently active. @raise SSHException: if no session is currently active.
@ -476,7 +457,8 @@ class BaseTransport (threading.Thread):
""" """
Return true if this session is active (open). Return true if this session is active (open).
@return: True if the session is still active (open); False if the session is closed. @return: True if the session is still active (open); False if the
session is closed.
@rtype: bool @rtype: bool
""" """
return self.active return self.active
@ -487,7 +469,7 @@ class BaseTransport (threading.Thread):
is just an alias for C{open_channel('session')}. is just an alias for C{open_channel('session')}.
@return: a new L{Channel} on success, or C{None} if the request is @return: a new L{Channel} on success, or C{None} if the request is
rejected or the session ends prematurely. rejected or the session ends prematurely.
@rtype: L{Channel} @rtype: L{Channel}
""" """
return self.open_channel('session') return self.open_channel('session')
@ -500,17 +482,17 @@ class BaseTransport (threading.Thread):
L{connect} or L{start_client}) and authenticating. L{connect} or L{start_client}) and authenticating.
@param kind: the kind of channel requested (usually C{"session"}, @param kind: the kind of channel requested (usually C{"session"},
C{"forwarded-tcpip"} or C{"direct-tcpip"}). C{"forwarded-tcpip"} or C{"direct-tcpip"}).
@type kind: str @type kind: str
@param dest_addr: the destination address of this port forwarding, @param dest_addr: the destination address of this port forwarding,
if C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"} (ignored if C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"} (ignored
for other channel types). for other channel types).
@type dest_addr: (str, int) @type dest_addr: (str, int)
@param src_addr: the source address of this port forwarding, if @param src_addr: the source address of this port forwarding, if
C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"}. C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"}.
@type src_addr: (str, int) @type src_addr: (str, int)
@return: a new L{Channel} on success, or C{None} if the request is @return: a new L{Channel} on success, or C{None} if the request is
rejected or the session ends prematurely. rejected or the session ends prematurely.
@rtype: L{Channel} @rtype: L{Channel}
""" """
chan = None chan = None
@ -599,7 +581,7 @@ class BaseTransport (threading.Thread):
session has died mid-negotiation. session has died mid-negotiation.
@return: True if the renegotiation was successful, and the link is @return: True if the renegotiation was successful, and the link is
using new keys; False if the session dropped during renegotiation. using new keys; False if the session dropped during renegotiation.
@rtype: bool @rtype: bool
""" """
self.completion_event = threading.Event() self.completion_event = threading.Event()
@ -620,12 +602,13 @@ class BaseTransport (threading.Thread):
can be useful to keep connections alive over a NAT, for example. can be useful to keep connections alive over a NAT, for example.
@param interval: seconds to wait before sending a keepalive packet (or @param interval: seconds to wait before sending a keepalive packet (or
0 to disable keepalives). 0 to disable keepalives).
@type interval: int @type interval: int
@since: fearow @since: fearow
""" """
self.keepalive_interval = interval self.packetizer.set_keepalive(interval,
lambda x=self: x.global_request('keepalive@lag.net', wait=False))
def global_request(self, kind, data=None, wait=True): def global_request(self, kind, data=None, wait=True):
""" """
@ -635,14 +618,14 @@ class BaseTransport (threading.Thread):
@param kind: name of the request. @param kind: name of the request.
@type kind: str @type kind: str
@param data: an optional tuple containing additional data to attach @param data: an optional tuple containing additional data to attach
to the request. to the request.
@type data: tuple @type data: tuple
@param wait: C{True} if this method should not return until a response @param wait: C{True} if this method should not return until a response
is received; C{False} otherwise. is received; C{False} otherwise.
@type wait: bool @type wait: bool
@return: a L{Message} containing possible additional data if the @return: a L{Message} containing possible additional data if the
request was successful (or an empty L{Message} if C{wait} was request was successful (or an empty L{Message} if C{wait} was
C{False}); C{None} if the request was denied. C{False}); C{None} if the request was denied.
@rtype: L{Message} @rtype: L{Message}
@since: fearow @since: fearow
@ -833,7 +816,23 @@ class BaseTransport (threading.Thread):
C{False} otherwise. C{False} otherwise.
@type hexdump: bool @type hexdump: bool
""" """
self.ultra_debug = hexdump self.packetizer.set_hexdump(hexdump)
def get_hexdump(self):
"""
Return C{True} if the transport is currently logging hex dumps of
protocol traffic.
@return: C{True} if hex dumps are being logged
@rtype: bool
@since: 1.4
"""
return self.packetizer.get_hexdump()
def stop_thread(self):
self.active = False
self.packetizer.close()
### internals... ### internals...
@ -859,113 +858,10 @@ class BaseTransport (threading.Thread):
finally: finally:
self.lock.release() self.lock.release()
def _check_keepalive(self):
if (not self.keepalive_interval) or (not self.initial_kex_done):
return
now = time.time()
if now > self.keepalive_last + self.keepalive_interval:
self.global_request('keepalive@lag.net', wait=False)
def _py22_read_all(self, n):
out = ''
while n > 0:
r, w, e = select.select([self.sock], [], [], 0.1)
if self.sock not in r:
if not self.active:
raise EOFError()
self._check_keepalive()
else:
x = self.sock.recv(n)
if len(x) == 0:
raise EOFError()
out += x
n -= len(x)
return out
def _read_all(self, n):
if PY22:
return self._py22_read_all(n)
out = ''
while n > 0:
try:
x = self.sock.recv(n)
if len(x) == 0:
raise EOFError()
out += x
n -= len(x)
except socket.timeout:
if not self.active:
raise EOFError()
self._check_keepalive()
return out
def _write_all(self, out):
self.keepalive_last = time.time()
while len(out) > 0:
try:
n = self.sock.send(out)
except socket.timeout:
n = 0
if not self.active:
n = -1
except Exception, x:
# could be: (32, 'Broken pipe')
n = -1
if n < 0:
raise EOFError()
if n == len(out):
return
out = out[n:]
return
def _build_packet(self, payload):
# pad up at least 4 bytes, to nearest block-size (usually 8)
bsize = self.block_size_out
padding = 3 + bsize - ((len(payload) + 8) % bsize)
packet = struct.pack('>I', len(payload) + padding + 1)
packet += chr(padding)
packet += payload
packet += randpool.get_bytes(padding)
return packet
def _send_message(self, data): def _send_message(self, data):
# encrypt this sucka self.packetizer.send_message(data)
data = str(data) if self.packetizer.need_rekey():
cmd = ord(data[0]) self._send_kex_init()
if cmd in MSG_NAMES:
cmd_name = MSG_NAMES[cmd]
else:
cmd_name = '$%x' % cmd
self._log(DEBUG, 'Write packet <%s>, length %d' % (cmd_name, len(data)))
packet = self._build_packet(data)
if self.ultra_debug:
self._log(DEBUG, util.format_binary(packet, 'OUT: '))
if self.engine_out != None:
out = self.engine_out.encrypt(packet)
else:
out = packet
# + mac
try:
self.write_lock.acquire()
if self.engine_out != None:
payload = struct.pack('>I', self.sequence_number_out) + packet
out += HMAC.HMAC(self.mac_key_out, payload, self.local_mac_engine).digest()[:self.local_mac_len]
self.sequence_number_out += 1L
self.sequence_number_out %= 0x100000000L
self._write_all(out)
self.sent_bytes += len(out)
self.sent_packets += 1
if ((self.sent_packets >= self.REKEY_PACKETS) or (self.sent_bytes >= self.REKEY_BYTES)) \
and (self.local_kex_init is None):
# only ask once for rekeying
self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes sent)' %
(self.sent_packets, self.sent_bytes))
self.received_packets_overflow = 0
# this may do a recursive lock, but that's okay:
self._send_kex_init()
finally:
self.write_lock.release()
def _send_user_message(self, data): def _send_user_message(self, data):
""" """
@ -981,65 +877,6 @@ class BaseTransport (threading.Thread):
break break
self._send_message(data) self._send_message(data)
def _read_message(self):
"only one thread will ever be in this function"
header = self._read_all(self.block_size_in)
if self.engine_in != None:
header = self.engine_in.decrypt(header)
if self.ultra_debug:
self._log(DEBUG, util.format_binary(header, 'IN: '));
packet_size = struct.unpack('>I', header[:4])[0]
# leftover contains decrypted bytes from the first block (after the length field)
leftover = header[4:]
if (packet_size - len(leftover)) % self.block_size_in != 0:
raise SSHException('Invalid packet blocking')
buffer = self._read_all(packet_size + self.remote_mac_len - len(leftover))
packet = buffer[:packet_size - len(leftover)]
post_packet = buffer[packet_size - len(leftover):]
if self.engine_in != None:
packet = self.engine_in.decrypt(packet)
if self.ultra_debug:
self._log(DEBUG, util.format_binary(packet, 'IN: '));
packet = leftover + packet
if self.remote_mac_len > 0:
mac = post_packet[:self.remote_mac_len]
mac_payload = struct.pack('>II', self.sequence_number_in, packet_size) + packet
my_mac = HMAC.HMAC(self.mac_key_in, mac_payload, self.remote_mac_engine).digest()[:self.remote_mac_len]
if my_mac != mac:
raise SSHException('Mismatched MAC')
padding = ord(packet[0])
payload = packet[1:packet_size - padding + 1]
randpool.add_event(packet[packet_size - padding + 1])
if self.ultra_debug:
self._log(DEBUG, 'Got payload (%d bytes, %d padding)' % (packet_size, padding))
msg = Message(payload[1:])
msg.seqno = self.sequence_number_in
self.sequence_number_in = (self.sequence_number_in + 1) & 0xffffffffL
# check for rekey
self.received_bytes += packet_size + self.remote_mac_len + 4
self.received_packets += 1
if self.local_kex_init is not None:
# we've asked to rekey -- give them 20 packets to comply before
# dropping the connection
self.received_packets_overflow += 1
if self.received_packets_overflow >= 20:
raise SSHException('Remote transport is ignoring rekey requests')
elif (self.received_packets >= self.REKEY_PACKETS) or \
(self.received_bytes >= self.REKEY_BYTES):
# only ask once for rekeying
self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes received)' %
(self.received_packets, self.received_bytes))
self.received_packets_overflow = 0
self._send_kex_init()
cmd = ord(payload[0])
if cmd in MSG_NAMES:
cmd_name = MSG_NAMES[cmd]
else:
cmd_name = '$%x' % cmd
self._log(DEBUG, 'Read packet <%s>, length %d' % (cmd_name, len(payload)))
return cmd, msg
def _set_K_H(self, k, h): def _set_K_H(self, k, h):
"used by a kex object to set the K (root key) and H (exchange hash)" "used by a kex object to set the K (root key) and H (exchange hash)"
self.K = k self.K = k
@ -1090,18 +927,21 @@ class BaseTransport (threading.Thread):
else: else:
self._log(DEBUG, 'starting thread (client mode): %s' % hex(long(id(self)) & 0xffffffffL)) self._log(DEBUG, 'starting thread (client mode): %s' % hex(long(id(self)) & 0xffffffffL))
try: try:
self._write_all(self.local_version + '\r\n') self.packetizer.write_all(self.local_version + '\r\n')
self._check_banner() self._check_banner()
self._send_kex_init() self._send_kex_init()
self.expected_packet = MSG_KEXINIT self.expected_packet = MSG_KEXINIT
while self.active: while self.active:
ptype, m = self._read_message() if self.packetizer.need_rekey():
self._send_kex_init()
ptype, m = self.packetizer.read_message()
if ptype == MSG_IGNORE: if ptype == MSG_IGNORE:
continue continue
elif ptype == MSG_DISCONNECT: elif ptype == MSG_DISCONNECT:
self._parse_disconnect(m) self._parse_disconnect(m)
self.active = False self.active = False
self.packetizer.close()
break break
elif ptype == MSG_DEBUG: elif ptype == MSG_DEBUG:
self._parse_debug(m) self._parse_debug(m)
@ -1123,6 +963,7 @@ class BaseTransport (threading.Thread):
else: else:
self._log(ERROR, 'Channel request for unknown channel %d' % chanid) self._log(ERROR, 'Channel request for unknown channel %d' % chanid)
self.active = False self.active = False
self.packetizer.close()
else: else:
self._log(WARNING, 'Oops, unhandled type %d' % ptype) self._log(WARNING, 'Oops, unhandled type %d' % ptype)
msg = Message() msg = Message()
@ -1146,6 +987,7 @@ class BaseTransport (threading.Thread):
chan._unlink() chan._unlink()
if self.active: if self.active:
self.active = False self.active = False
self.packetizer.close()
if self.completion_event != None: if self.completion_event != None:
self.completion_event.set() self.completion_event.set()
if self.auth_event != None: if self.auth_event != None:
@ -1170,12 +1012,15 @@ class BaseTransport (threading.Thread):
def _check_banner(self): def _check_banner(self):
# this is slow, but we only have to do it once # this is slow, but we only have to do it once
for i in range(5): for i in range(5):
buffer = '' # give them 5 seconds for the first line, then just 2 seconds each additional line
while not '\n' in buffer: if i == 0:
buffer += self._read_all(1) timeout = 5
buffer = buffer[:-1] else:
if (len(buffer) > 0) and (buffer[-1] == '\r'): timeout = 2
buffer = buffer[:-1] try:
buffer = self.packetizer.readline(timeout)
except Exception, x:
raise SSHException('Error reading SSH protocol banner' + str(x))
if buffer[:4] == 'SSH-': if buffer[:4] == 'SSH-':
break break
self._log(DEBUG, 'Banner: ' + buffer) self._log(DEBUG, 'Banner: ' + buffer)
@ -1236,13 +1081,6 @@ class BaseTransport (threading.Thread):
self._send_message(m) self._send_message(m)
def _parse_kex_init(self, m): def _parse_kex_init(self, m):
# reset counters of when to re-key, since we are now re-keying
self.received_bytes = 0
self.received_packets = 0
self.received_packets_overflow = 0
self.sent_bytes = 0
self.sent_packets = 0
cookie = m.get_bytes(16) cookie = m.get_bytes(16)
kex_algo_list = m.get_list() kex_algo_list = m.get_list()
server_key_algo_list = m.get_list() server_key_algo_list = m.get_list()
@ -1334,44 +1172,46 @@ class BaseTransport (threading.Thread):
def _activate_inbound(self): def _activate_inbound(self):
"switch on newly negotiated encryption parameters for inbound traffic" "switch on newly negotiated encryption parameters for inbound traffic"
self.block_size_in = self._cipher_info[self.remote_cipher]['block-size'] block_size = self._cipher_info[self.remote_cipher]['block-size']
if self.server_mode: if self.server_mode:
IV_in = self._compute_key('A', self.block_size_in) IV_in = self._compute_key('A', block_size)
key_in = self._compute_key('C', self._cipher_info[self.remote_cipher]['key-size']) key_in = self._compute_key('C', self._cipher_info[self.remote_cipher]['key-size'])
else: else:
IV_in = self._compute_key('B', self.block_size_in) IV_in = self._compute_key('B', block_size)
key_in = self._compute_key('D', self._cipher_info[self.remote_cipher]['key-size']) key_in = self._compute_key('D', self._cipher_info[self.remote_cipher]['key-size'])
self.engine_in = self._get_cipher(self.remote_cipher, key_in, IV_in) engine = self._get_cipher(self.remote_cipher, key_in, IV_in)
self.remote_mac_len = self._mac_info[self.remote_mac]['size'] mac_size = self._mac_info[self.remote_mac]['size']
self.remote_mac_engine = self._mac_info[self.remote_mac]['class'] mac_engine = self._mac_info[self.remote_mac]['class']
# initial mac keys are done in the hash's natural size (not the potentially truncated # initial mac keys are done in the hash's natural size (not the potentially truncated
# transmission size) # transmission size)
if self.server_mode: if self.server_mode:
self.mac_key_in = self._compute_key('E', self.remote_mac_engine.digest_size) mac_key = self._compute_key('E', mac_engine.digest_size)
else: else:
self.mac_key_in = self._compute_key('F', self.remote_mac_engine.digest_size) mac_key = self._compute_key('F', mac_engine.digest_size)
self.packetizer.set_inbound_cipher(engine, block_size, mac_engine, mac_size, mac_key)
def _activate_outbound(self): def _activate_outbound(self):
"switch on newly negotiated encryption parameters for outbound traffic" "switch on newly negotiated encryption parameters for outbound traffic"
m = Message() m = Message()
m.add_byte(chr(MSG_NEWKEYS)) m.add_byte(chr(MSG_NEWKEYS))
self._send_message(m) self._send_message(m)
self.block_size_out = self._cipher_info[self.local_cipher]['block-size'] block_size = self._cipher_info[self.local_cipher]['block-size']
if self.server_mode: if self.server_mode:
IV_out = self._compute_key('B', self.block_size_out) IV_out = self._compute_key('B', block_size)
key_out = self._compute_key('D', self._cipher_info[self.local_cipher]['key-size']) key_out = self._compute_key('D', self._cipher_info[self.local_cipher]['key-size'])
else: else:
IV_out = self._compute_key('A', self.block_size_out) IV_out = self._compute_key('A', block_size)
key_out = self._compute_key('C', self._cipher_info[self.local_cipher]['key-size']) key_out = self._compute_key('C', self._cipher_info[self.local_cipher]['key-size'])
self.engine_out = self._get_cipher(self.local_cipher, key_out, IV_out) engine = self._get_cipher(self.local_cipher, key_out, IV_out)
self.local_mac_len = self._mac_info[self.local_mac]['size'] mac_size = self._mac_info[self.local_mac]['size']
self.local_mac_engine = self._mac_info[self.local_mac]['class'] mac_engine = self._mac_info[self.local_mac]['class']
# initial mac keys are done in the hash's natural size (not the potentially truncated # initial mac keys are done in the hash's natural size (not the potentially truncated
# transmission size) # transmission size)
if self.server_mode: if self.server_mode:
self.mac_key_out = self._compute_key('F', self.local_mac_engine.digest_size) mac_key = self._compute_key('F', mac_engine.digest_size)
else: else:
self.mac_key_out = self._compute_key('E', self.local_mac_engine.digest_size) mac_key = self._compute_key('E', mac_engine.digest_size)
self.packetizer.set_outbound_cipher(engine, block_size, mac_engine, mac_size, mac_key)
# we always expect to receive NEWKEYS now # we always expect to receive NEWKEYS now
self.expected_packet = MSG_NEWKEYS self.expected_packet = MSG_NEWKEYS

View File

@ -70,6 +70,9 @@ options, args = parser.parse_args()
if len(args) > 0: if len(args) > 0:
parser.error('unknown argument(s)') parser.error('unknown argument(s)')
# setup logging
paramiko.util.log_to_file('test.log')
if options.use_sftp: if options.use_sftp:
if options.use_loopback_sftp: if options.use_loopback_sftp:
SFTPTest.init_loopback() SFTPTest.init_loopback()
@ -78,9 +81,6 @@ if options.use_sftp:
if not options.use_big_file: if not options.use_big_file:
SFTPTest.set_big_file_test(False) SFTPTest.set_big_file_test(False)
# setup logging
paramiko.util.log_to_file('test.log')
suite = unittest.TestSuite() suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(MessageTest)) suite.addTest(unittest.makeSuite(MessageTest))
suite.addTest(unittest.makeSuite(BufferedFileTest)) suite.addTest(unittest.makeSuite(BufferedFileTest))

View File

@ -1,5 +1,3 @@
#!/usr/bin/python
# Copyright (C) 2003-2005 Robey Pointer <robey@lag.net> # Copyright (C) 2003-2005 Robey Pointer <robey@lag.net>
# #
# This file is part of paramiko. # This file is part of paramiko.
@ -145,7 +143,7 @@ class SFTPTest (unittest.TestCase):
finally: finally:
sftp.remove(FOLDER + '/test') sftp.remove(FOLDER + '/test')
def test_1a_close(self): def test_2_close(self):
""" """
verify that closing the sftp session doesn't do anything bad, and that verify that closing the sftp session doesn't do anything bad, and that
a new one can be opened. a new one can be opened.
@ -159,7 +157,7 @@ class SFTPTest (unittest.TestCase):
pass pass
sftp = paramiko.SFTP.from_transport(tc) sftp = paramiko.SFTP.from_transport(tc)
def test_2_write(self): def test_3_write(self):
""" """
verify that a file can be created and written, and the size is correct. verify that a file can be created and written, and the size is correct.
""" """
@ -171,7 +169,7 @@ class SFTPTest (unittest.TestCase):
finally: finally:
sftp.remove(FOLDER + '/duck.txt') sftp.remove(FOLDER + '/duck.txt')
def test_3_append(self): def test_4_append(self):
""" """
verify that a file can be opened for append, and tell() still works. verify that a file can be opened for append, and tell() still works.
""" """
@ -191,7 +189,7 @@ class SFTPTest (unittest.TestCase):
finally: finally:
sftp.remove(FOLDER + '/append.txt') sftp.remove(FOLDER + '/append.txt')
def test_4_rename(self): def test_5_rename(self):
""" """
verify that renaming a file works. verify that renaming a file works.
""" """
@ -219,7 +217,7 @@ class SFTPTest (unittest.TestCase):
except: except:
pass pass
def test_5_folder(self): def test_6_folder(self):
""" """
create a temporary folder, verify that we can create a file in it, then create a temporary folder, verify that we can create a file in it, then
remove the folder and verify that we can't create a file in it anymore. remove the folder and verify that we can't create a file in it anymore.
@ -236,7 +234,7 @@ class SFTPTest (unittest.TestCase):
except IOError: except IOError:
pass pass
def test_6_listdir(self): def test_7_listdir(self):
""" """
verify that a folder can be created, a bunch of files can be placed in it, verify that a folder can be created, a bunch of files can be placed in it,
and those files show up in sftp.listdir. and those files show up in sftp.listdir.
@ -262,7 +260,7 @@ class SFTPTest (unittest.TestCase):
sftp.remove(FOLDER + '/fish.txt') sftp.remove(FOLDER + '/fish.txt')
sftp.remove(FOLDER + '/tertiary.py') sftp.remove(FOLDER + '/tertiary.py')
def test_7_setstat(self): def test_8_setstat(self):
""" """
verify that the setstat functions (chown, chmod, utime) work. verify that the setstat functions (chown, chmod, utime) work.
""" """
@ -285,7 +283,7 @@ class SFTPTest (unittest.TestCase):
finally: finally:
sftp.remove(FOLDER + '/special') sftp.remove(FOLDER + '/special')
def test_8_readline_seek(self): def test_9_readline_seek(self):
""" """
create a text file and write a bunch of text into it. then count the lines create a text file and write a bunch of text into it. then count the lines
in the file, and seek around to retreive particular lines. this should in the file, and seek around to retreive particular lines. this should
@ -315,7 +313,7 @@ class SFTPTest (unittest.TestCase):
finally: finally:
sftp.remove(FOLDER + '/duck.txt') sftp.remove(FOLDER + '/duck.txt')
def test_9_write_seek(self): def test_A_write_seek(self):
""" """
create a text file, seek back and change part of it, and verify that the create a text file, seek back and change part of it, and verify that the
changes worked. changes worked.
@ -335,7 +333,7 @@ class SFTPTest (unittest.TestCase):
finally: finally:
sftp.remove(FOLDER + '/testing.txt') sftp.remove(FOLDER + '/testing.txt')
def test_A_symlink(self): def test_B_symlink(self):
""" """
create a symlink and then check that lstat doesn't follow it. create a symlink and then check that lstat doesn't follow it.
""" """
@ -378,7 +376,7 @@ class SFTPTest (unittest.TestCase):
except: except:
pass pass
def test_B_flush_seek(self): def test_C_flush_seek(self):
""" """
verify that buffered writes are automatically flushed on seek. verify that buffered writes are automatically flushed on seek.
""" """
@ -400,7 +398,7 @@ class SFTPTest (unittest.TestCase):
except: except:
pass pass
def test_C_lots_of_files(self): def test_D_lots_of_files(self):
""" """
create a bunch of files over the same session. create a bunch of files over the same session.
""" """
@ -431,7 +429,7 @@ class SFTPTest (unittest.TestCase):
except: except:
pass pass
def test_D_big_file(self): def test_E_big_file(self):
""" """
write a 1MB file, with no linefeeds, using line buffering. write a 1MB file, with no linefeeds, using line buffering.
FIXME: this is slow! what causes the slowness? FIXME: this is slow! what causes the slowness?
@ -453,7 +451,7 @@ class SFTPTest (unittest.TestCase):
finally: finally:
sftp.remove('%s/hongry.txt' % FOLDER) sftp.remove('%s/hongry.txt' % FOLDER)
def test_E_big_file_big_buffer(self): def test_F_big_file_big_buffer(self):
""" """
write a 1MB file, with no linefeeds, and a big buffer. write a 1MB file, with no linefeeds, and a big buffer.
""" """
@ -470,7 +468,7 @@ class SFTPTest (unittest.TestCase):
finally: finally:
sftp.remove('%s/hongry.txt' % FOLDER) sftp.remove('%s/hongry.txt' % FOLDER)
def test_F_realpath(self): def test_G_realpath(self):
""" """
test that realpath is returning something non-empty and not an test that realpath is returning something non-empty and not an
error. error.
@ -481,7 +479,7 @@ class SFTPTest (unittest.TestCase):
self.assert_(len(f) > 0) self.assert_(len(f) > 0)
self.assertEquals(os.path.join(pwd, FOLDER), f) self.assertEquals(os.path.join(pwd, FOLDER), f)
def test_G_mkdir(self): def test_H_mkdir(self):
""" """
verify that mkdir/rmdir work. verify that mkdir/rmdir work.
""" """

View File

@ -22,7 +22,7 @@
Some unit tests for the ssh2 protocol in Transport. Some unit tests for the ssh2 protocol in Transport.
""" """
import sys, unittest, threading import sys, time, threading, unittest
from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, \ from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, \
SSHException, BadAuthenticationType SSHException, BadAuthenticationType
from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL
@ -77,6 +77,10 @@ class NullServer (ServerInterface):
def check_channel_shell_request(self, channel): def check_channel_shell_request(self, channel):
return True return True
def check_global_request(self, kind, msg):
self._global_request = kind
return False
class TransportTest (unittest.TestCase): class TransportTest (unittest.TestCase):
@ -160,14 +164,38 @@ class TransportTest (unittest.TestCase):
self.assert_(self.ts.is_active()) self.assert_(self.ts.is_active())
self.assertEquals('aes256-cbc', self.tc.local_cipher) self.assertEquals('aes256-cbc', self.tc.local_cipher)
self.assertEquals('aes256-cbc', self.tc.remote_cipher) self.assertEquals('aes256-cbc', self.tc.remote_cipher)
self.assertEquals(12, self.tc.local_mac_len) self.assertEquals(12, self.tc.packetizer.get_mac_size_out())
self.assertEquals(12, self.tc.remote_mac_len) self.assertEquals(12, self.tc.packetizer.get_mac_size_in())
self.tc.send_ignore(1024) self.tc.send_ignore(1024)
self.assert_(self.tc.renegotiate_keys()) self.assert_(self.tc.renegotiate_keys())
self.ts.send_ignore(1024) self.ts.send_ignore(1024)
def test_4_bad_auth_type(self): def test_4_keepalive(self):
"""
verify that the keepalive will be sent.
"""
self.tc.set_hexdump(True)
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.connect(hostkey=public_host_key,
username='slowdive', password='pygmalion')
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
self.assertEquals(None, getattr(server, '_global_request', None))
self.tc.set_keepalive(1)
time.sleep(2)
self.assertEquals('keepalive@lag.net', server._global_request)
def test_5_bad_auth_type(self):
""" """
verify that we get the right exception when an unsupported auth verify that we get the right exception when an unsupported auth
type is requested. type is requested.
@ -188,7 +216,7 @@ class TransportTest (unittest.TestCase):
self.assertEquals(BadAuthenticationType, etype) self.assertEquals(BadAuthenticationType, etype)
self.assertEquals(['publickey'], evalue.allowed_types) self.assertEquals(['publickey'], evalue.allowed_types)
def test_5_bad_password(self): def test_6_bad_password(self):
""" """
verify that a bad password gets the right exception, and that a retry verify that a bad password gets the right exception, and that a retry
with the right password works. with the right password works.
@ -213,7 +241,7 @@ class TransportTest (unittest.TestCase):
self.assert_(event.isSet()) self.assert_(event.isSet())
self.assert_(self.ts.is_active()) self.assert_(self.ts.is_active())
def test_6_multipart_auth(self): def test_7_multipart_auth(self):
""" """
verify that multipart auth works. verify that multipart auth works.
""" """
@ -235,7 +263,7 @@ class TransportTest (unittest.TestCase):
self.assert_(event.isSet()) self.assert_(event.isSet())
self.assert_(self.ts.is_active()) self.assert_(self.ts.is_active())
def test_7_exec_command(self): def test_8_exec_command(self):
""" """
verify that exec_command() does something reasonable. verify that exec_command() does something reasonable.
""" """
@ -285,7 +313,7 @@ class TransportTest (unittest.TestCase):
self.assertEquals('This is on stderr.\n', f.readline()) self.assertEquals('This is on stderr.\n', f.readline())
self.assertEquals('', f.readline()) self.assertEquals('', f.readline())
def test_8_invoke_shell(self): def test_9_invoke_shell(self):
""" """
verify that invoke_shell() does something reasonable. verify that invoke_shell() does something reasonable.
""" """
@ -312,7 +340,7 @@ class TransportTest (unittest.TestCase):
chan.close() chan.close()
self.assertEquals('', f.readline()) self.assertEquals('', f.readline())
def test_9_exit_status(self): def test_A_exit_status(self):
""" """
verify that get_exit_status() works. verify that get_exit_status() works.
""" """