486 lines
17 KiB
Python
486 lines
17 KiB
Python
# Copyright (C) 2011 Jeff Forcier <jeff@bitprophet.org>
|
|
#
|
|
# This file is part of ssh.
|
|
#
|
|
# 'ssh' is free software; you can redistribute it and/or modify it under the
|
|
# terms of the GNU Lesser General Public License as published by the Free
|
|
# Software Foundation; either version 2.1 of the License, or (at your option)
|
|
# any later version.
|
|
#
|
|
# 'ssh' is distributed in the hope that it will be useful, but WITHOUT ANY
|
|
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
|
|
# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
|
|
# details.
|
|
#
|
|
# You should have received a copy of the GNU Lesser General Public License
|
|
# along with 'ssh'; if not, write to the Free Software Foundation, Inc.,
|
|
# 51 Franklin Street, Suite 500, Boston, MA 02110-1335 USA.
|
|
|
|
"""
|
|
Packetizer.
|
|
"""
|
|
|
|
import errno
|
|
import select
|
|
import socket
|
|
import struct
|
|
import threading
|
|
import time
|
|
|
|
from ssh.common import *
|
|
from ssh import util
|
|
from ssh.ssh_exception import SSHException
|
|
from ssh.message import Message
|
|
|
|
|
|
# Optional accelerated HMAC implementation; compute_hmac() falls back to
# PyCrypto when this import is unavailable.
try:
    import r_hmac
except ImportError:
    got_r_hmac = False
else:
    got_r_hmac = True
|
|
def compute_hmac(key, message, digest_class):
    """
    Compute the HMAC of C{message} keyed with C{key}, using C{digest_class}
    as the underlying hash.

    The optional C{r_hmac} module is used when it imported successfully;
    otherwise PyCrypto's C{Crypto.Hash.HMAC} is imported lazily and used.

    @return: the raw (binary) digest
    """
    if not got_r_hmac:
        from Crypto.Hash import HMAC
        mac = HMAC.HMAC(key, message, digest_class)
    else:
        mac = r_hmac.HMAC(key, message, digest_class)
    return mac.digest()
|
|
|
|
|
|
class NeedRekeyException (Exception):
    """
    Raised by L{Packetizer.read_all} (when called with C{check_rekey=True},
    as L{Packetizer.read_message} does) if a read times out with no data
    and a rekey has been requested.
    """
|
|
|
|
|
|
class Packetizer (object):
|
|
"""
|
|
Implementation of the base SSH packet protocol.
|
|
"""
|
|
|
|
# READ the secsh RFC's before raising these values. if anything,
|
|
# they should probably be lower.
|
|
REKEY_PACKETS = pow(2, 30)
|
|
REKEY_BYTES = pow(2, 30)
|
|
|
|
def __init__(self, socket):
|
|
self.__socket = socket
|
|
self.__logger = None
|
|
self.__closed = False
|
|
self.__dump_packets = False
|
|
self.__need_rekey = False
|
|
self.__init_count = 0
|
|
self.__remainder = ''
|
|
|
|
# used for noticing when to re-key:
|
|
self.__sent_bytes = 0
|
|
self.__sent_packets = 0
|
|
self.__received_bytes = 0
|
|
self.__received_packets = 0
|
|
self.__received_packets_overflow = 0
|
|
|
|
# current inbound/outbound ciphering:
|
|
self.__block_size_out = 8
|
|
self.__block_size_in = 8
|
|
self.__mac_size_out = 0
|
|
self.__mac_size_in = 0
|
|
self.__block_engine_out = None
|
|
self.__block_engine_in = None
|
|
self.__mac_engine_out = None
|
|
self.__mac_engine_in = None
|
|
self.__mac_key_out = ''
|
|
self.__mac_key_in = ''
|
|
self.__compress_engine_out = None
|
|
self.__compress_engine_in = None
|
|
self.__sequence_number_out = 0L
|
|
self.__sequence_number_in = 0L
|
|
|
|
# lock around outbound writes (packet computation)
|
|
self.__write_lock = threading.RLock()
|
|
|
|
# keepalives:
|
|
self.__keepalive_interval = 0
|
|
self.__keepalive_last = time.time()
|
|
self.__keepalive_callback = None
|
|
|
|
def set_log(self, log):
|
|
"""
|
|
Set the python log object to use for logging.
|
|
"""
|
|
self.__logger = log
|
|
|
|
def set_outbound_cipher(self, block_engine, block_size, mac_engine, mac_size, mac_key):
|
|
"""
|
|
Switch outbound data cipher.
|
|
"""
|
|
self.__block_engine_out = block_engine
|
|
self.__block_size_out = block_size
|
|
self.__mac_engine_out = mac_engine
|
|
self.__mac_size_out = mac_size
|
|
self.__mac_key_out = mac_key
|
|
self.__sent_bytes = 0
|
|
self.__sent_packets = 0
|
|
# wait until the reset happens in both directions before clearing rekey flag
|
|
self.__init_count |= 1
|
|
if self.__init_count == 3:
|
|
self.__init_count = 0
|
|
self.__need_rekey = False
|
|
|
|
def set_inbound_cipher(self, block_engine, block_size, mac_engine, mac_size, mac_key):
|
|
"""
|
|
Switch inbound data cipher.
|
|
"""
|
|
self.__block_engine_in = block_engine
|
|
self.__block_size_in = block_size
|
|
self.__mac_engine_in = mac_engine
|
|
self.__mac_size_in = mac_size
|
|
self.__mac_key_in = mac_key
|
|
self.__received_bytes = 0
|
|
self.__received_packets = 0
|
|
self.__received_packets_overflow = 0
|
|
# wait until the reset happens in both directions before clearing rekey flag
|
|
self.__init_count |= 2
|
|
if self.__init_count == 3:
|
|
self.__init_count = 0
|
|
self.__need_rekey = False
|
|
|
|
def set_outbound_compressor(self, compressor):
|
|
self.__compress_engine_out = compressor
|
|
|
|
def set_inbound_compressor(self, compressor):
|
|
self.__compress_engine_in = compressor
|
|
|
|
def close(self):
|
|
self.__closed = True
|
|
self.__socket.close()
|
|
|
|
def set_hexdump(self, hexdump):
|
|
self.__dump_packets = hexdump
|
|
|
|
def get_hexdump(self):
|
|
return self.__dump_packets
|
|
|
|
def get_mac_size_in(self):
|
|
return self.__mac_size_in
|
|
|
|
def get_mac_size_out(self):
|
|
return self.__mac_size_out
|
|
|
|
def need_rekey(self):
|
|
"""
|
|
Returns C{True} if a new set of keys needs to be negotiated. This
|
|
will be triggered during a packet read or write, so it should be
|
|
checked after every read or write, or at least after every few.
|
|
|
|
@return: C{True} if a new set of keys needs to be negotiated
|
|
"""
|
|
return self.__need_rekey
|
|
|
|
def set_keepalive(self, interval, callback):
|
|
"""
|
|
Turn on/off the callback keepalive. If C{interval} seconds pass with
|
|
no data read from or written to the socket, the callback will be
|
|
executed and the timer will be reset.
|
|
"""
|
|
self.__keepalive_interval = interval
|
|
self.__keepalive_callback = callback
|
|
self.__keepalive_last = time.time()
|
|
|
|
def read_all(self, n, check_rekey=False):
|
|
"""
|
|
Read as close to N bytes as possible, blocking as long as necessary.
|
|
|
|
@param n: number of bytes to read
|
|
@type n: int
|
|
@return: the data read
|
|
@rtype: str
|
|
@raise EOFError: if the socket was closed before all the bytes could
|
|
be read
|
|
"""
|
|
out = ''
|
|
# handle over-reading from reading the banner line
|
|
if len(self.__remainder) > 0:
|
|
out = self.__remainder[:n]
|
|
self.__remainder = self.__remainder[n:]
|
|
n -= len(out)
|
|
if PY22:
|
|
return self._py22_read_all(n, out)
|
|
while n > 0:
|
|
got_timeout = False
|
|
try:
|
|
x = self.__socket.recv(n)
|
|
if len(x) == 0:
|
|
raise EOFError()
|
|
out += x
|
|
n -= len(x)
|
|
except socket.timeout:
|
|
got_timeout = True
|
|
except socket.error, e:
|
|
# on Linux, sometimes instead of socket.timeout, we get
|
|
# EAGAIN. this is a bug in recent (> 2.6.9) kernels but
|
|
# we need to work around it.
|
|
if (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EAGAIN):
|
|
got_timeout = True
|
|
elif (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EINTR):
|
|
# syscall interrupted; try again
|
|
pass
|
|
elif self.__closed:
|
|
raise EOFError()
|
|
else:
|
|
raise
|
|
if got_timeout:
|
|
if self.__closed:
|
|
raise EOFError()
|
|
if check_rekey and (len(out) == 0) and self.__need_rekey:
|
|
raise NeedRekeyException()
|
|
self._check_keepalive()
|
|
return out
|
|
|
|
def write_all(self, out):
|
|
self.__keepalive_last = time.time()
|
|
while len(out) > 0:
|
|
got_timeout = False
|
|
try:
|
|
n = self.__socket.send(out)
|
|
except socket.timeout:
|
|
got_timeout = True
|
|
except socket.error, e:
|
|
if (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EAGAIN):
|
|
got_timeout = True
|
|
elif (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EINTR):
|
|
# syscall interrupted; try again
|
|
pass
|
|
else:
|
|
n = -1
|
|
except Exception:
|
|
# could be: (32, 'Broken pipe')
|
|
n = -1
|
|
if got_timeout:
|
|
n = 0
|
|
if self.__closed:
|
|
n = -1
|
|
if n < 0:
|
|
raise EOFError()
|
|
if n == len(out):
|
|
break
|
|
out = out[n:]
|
|
return
|
|
|
|
def readline(self, timeout):
|
|
"""
|
|
Read a line from the socket. We assume no data is pending after the
|
|
line, so it's okay to attempt large reads.
|
|
"""
|
|
buf = self.__remainder
|
|
while not '\n' in buf:
|
|
buf += self._read_timeout(timeout)
|
|
n = buf.index('\n')
|
|
self.__remainder = buf[n+1:]
|
|
buf = buf[:n]
|
|
if (len(buf) > 0) and (buf[-1] == '\r'):
|
|
buf = buf[:-1]
|
|
return buf
|
|
|
|
def send_message(self, data):
|
|
"""
|
|
Write a block of data using the current cipher, as an SSH block.
|
|
"""
|
|
# encrypt this sucka
|
|
data = str(data)
|
|
cmd = ord(data[0])
|
|
if cmd in MSG_NAMES:
|
|
cmd_name = MSG_NAMES[cmd]
|
|
else:
|
|
cmd_name = '$%x' % cmd
|
|
orig_len = len(data)
|
|
self.__write_lock.acquire()
|
|
try:
|
|
if self.__compress_engine_out is not None:
|
|
data = self.__compress_engine_out(data)
|
|
packet = self._build_packet(data)
|
|
if self.__dump_packets:
|
|
self._log(DEBUG, 'Write packet <%s>, length %d' % (cmd_name, orig_len))
|
|
self._log(DEBUG, util.format_binary(packet, 'OUT: '))
|
|
if self.__block_engine_out != None:
|
|
out = self.__block_engine_out.encrypt(packet)
|
|
else:
|
|
out = packet
|
|
# + mac
|
|
if self.__block_engine_out != None:
|
|
payload = struct.pack('>I', self.__sequence_number_out) + packet
|
|
out += compute_hmac(self.__mac_key_out, payload, self.__mac_engine_out)[:self.__mac_size_out]
|
|
self.__sequence_number_out = (self.__sequence_number_out + 1) & 0xffffffffL
|
|
self.write_all(out)
|
|
|
|
self.__sent_bytes += len(out)
|
|
self.__sent_packets += 1
|
|
if ((self.__sent_packets >= self.REKEY_PACKETS) or (self.__sent_bytes >= self.REKEY_BYTES)) \
|
|
and not self.__need_rekey:
|
|
# only ask once for rekeying
|
|
self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes sent)' %
|
|
(self.__sent_packets, self.__sent_bytes))
|
|
self.__received_packets_overflow = 0
|
|
self._trigger_rekey()
|
|
finally:
|
|
self.__write_lock.release()
|
|
|
|
def read_message(self):
|
|
"""
|
|
Only one thread should ever be in this function (no other locking is
|
|
done).
|
|
|
|
@raise SSHException: if the packet is mangled
|
|
@raise NeedRekeyException: if the transport should rekey
|
|
"""
|
|
header = self.read_all(self.__block_size_in, check_rekey=True)
|
|
if self.__block_engine_in != None:
|
|
header = self.__block_engine_in.decrypt(header)
|
|
if self.__dump_packets:
|
|
self._log(DEBUG, util.format_binary(header, 'IN: '));
|
|
packet_size = struct.unpack('>I', header[:4])[0]
|
|
# leftover contains decrypted bytes from the first block (after the length field)
|
|
leftover = header[4:]
|
|
if (packet_size - len(leftover)) % self.__block_size_in != 0:
|
|
raise SSHException('Invalid packet blocking')
|
|
buf = self.read_all(packet_size + self.__mac_size_in - len(leftover))
|
|
packet = buf[:packet_size - len(leftover)]
|
|
post_packet = buf[packet_size - len(leftover):]
|
|
if self.__block_engine_in != None:
|
|
packet = self.__block_engine_in.decrypt(packet)
|
|
if self.__dump_packets:
|
|
self._log(DEBUG, util.format_binary(packet, 'IN: '));
|
|
packet = leftover + packet
|
|
|
|
if self.__mac_size_in > 0:
|
|
mac = post_packet[:self.__mac_size_in]
|
|
mac_payload = struct.pack('>II', self.__sequence_number_in, packet_size) + packet
|
|
my_mac = compute_hmac(self.__mac_key_in, mac_payload, self.__mac_engine_in)[:self.__mac_size_in]
|
|
if my_mac != mac:
|
|
raise SSHException('Mismatched MAC')
|
|
padding = ord(packet[0])
|
|
payload = packet[1:packet_size - padding]
|
|
|
|
if self.__dump_packets:
|
|
self._log(DEBUG, 'Got payload (%d bytes, %d padding)' % (packet_size, padding))
|
|
|
|
if self.__compress_engine_in is not None:
|
|
payload = self.__compress_engine_in(payload)
|
|
|
|
msg = Message(payload[1:])
|
|
msg.seqno = self.__sequence_number_in
|
|
self.__sequence_number_in = (self.__sequence_number_in + 1) & 0xffffffffL
|
|
|
|
# check for rekey
|
|
self.__received_bytes += packet_size + self.__mac_size_in + 4
|
|
self.__received_packets += 1
|
|
if self.__need_rekey:
|
|
# we've asked to rekey -- give them 20 packets to comply before
|
|
# dropping the connection
|
|
self.__received_packets_overflow += 1
|
|
if self.__received_packets_overflow >= 20:
|
|
raise SSHException('Remote transport is ignoring rekey requests')
|
|
elif (self.__received_packets >= self.REKEY_PACKETS) or \
|
|
(self.__received_bytes >= self.REKEY_BYTES):
|
|
# only ask once for rekeying
|
|
self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes received)' %
|
|
(self.__received_packets, self.__received_bytes))
|
|
self.__received_packets_overflow = 0
|
|
self._trigger_rekey()
|
|
|
|
cmd = ord(payload[0])
|
|
if cmd in MSG_NAMES:
|
|
cmd_name = MSG_NAMES[cmd]
|
|
else:
|
|
cmd_name = '$%x' % cmd
|
|
if self.__dump_packets:
|
|
self._log(DEBUG, 'Read packet <%s>, length %d' % (cmd_name, len(payload)))
|
|
return cmd, msg
|
|
|
|
|
|
########## protected
|
|
|
|
|
|
def _log(self, level, msg):
|
|
if self.__logger is None:
|
|
return
|
|
if issubclass(type(msg), list):
|
|
for m in msg:
|
|
self.__logger.log(level, m)
|
|
else:
|
|
self.__logger.log(level, msg)
|
|
|
|
def _check_keepalive(self):
|
|
if (not self.__keepalive_interval) or (not self.__block_engine_out) or \
|
|
self.__need_rekey:
|
|
# wait till we're encrypting, and not in the middle of rekeying
|
|
return
|
|
now = time.time()
|
|
if now > self.__keepalive_last + self.__keepalive_interval:
|
|
self.__keepalive_callback()
|
|
self.__keepalive_last = now
|
|
|
|
def _py22_read_all(self, n, out):
|
|
while n > 0:
|
|
r, w, e = select.select([self.__socket], [], [], 0.1)
|
|
if self.__socket not in r:
|
|
if self.__closed:
|
|
raise EOFError()
|
|
self._check_keepalive()
|
|
else:
|
|
x = self.__socket.recv(n)
|
|
if len(x) == 0:
|
|
raise EOFError()
|
|
out += x
|
|
n -= len(x)
|
|
return out
|
|
|
|
def _py22_read_timeout(self, timeout):
|
|
start = time.time()
|
|
while True:
|
|
r, w, e = select.select([self.__socket], [], [], 0.1)
|
|
if self.__socket in r:
|
|
x = self.__socket.recv(1)
|
|
if len(x) == 0:
|
|
raise EOFError()
|
|
break
|
|
if self.__closed:
|
|
raise EOFError()
|
|
now = time.time()
|
|
if now - start >= timeout:
|
|
raise socket.timeout()
|
|
return x
|
|
|
|
def _read_timeout(self, timeout):
|
|
if PY22:
|
|
return self._py22_read_timeout(timeout)
|
|
start = time.time()
|
|
while True:
|
|
try:
|
|
x = self.__socket.recv(128)
|
|
if len(x) == 0:
|
|
raise EOFError()
|
|
break
|
|
except socket.timeout:
|
|
pass
|
|
if self.__closed:
|
|
raise EOFError()
|
|
now = time.time()
|
|
if now - start >= timeout:
|
|
raise socket.timeout()
|
|
return x
|
|
|
|
def _build_packet(self, payload):
|
|
# pad up at least 4 bytes, to nearest block-size (usually 8)
|
|
bsize = self.__block_size_out
|
|
padding = 3 + bsize - ((len(payload) + 8) % bsize)
|
|
packet = struct.pack('>IB', len(payload) + padding + 1, padding)
|
|
packet += payload
|
|
if self.__block_engine_out is not None:
|
|
packet += rng.read(padding)
|
|
else:
|
|
# cute trick i caught openssh doing: if we're not encrypting,
|
|
# don't waste random bytes for the padding
|
|
packet += (chr(0) * padding)
|
|
return packet
|
|
|
|
def _trigger_rekey(self):
|
|
# outside code should check for this flag
|
|
self.__need_rekey = True
|