# Copyright (C) 2011  Jeff Forcier <jeff@bitprophet.org>
#
# This file is part of ssh.
#
# 'ssh' is free software; you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 2.1 of the License, or (at your option)
# any later version.
#
# 'ssh' is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with 'ssh'; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Suite 500, Boston, MA  02110-1335  USA.

"""
Packetizer.
"""

import errno
import select
import socket
import struct
import threading
import time

from ssh.common import *
from ssh import util
from ssh.ssh_exception import SSHException
from ssh.message import Message


# Probe for the optional r_hmac module once at import time; compute_hmac
# falls back to PyCrypto's HMAC when it is not installed.
try:
    import r_hmac
except ImportError:
    got_r_hmac = False
else:
    got_r_hmac = True


def compute_hmac(key, message, digest_class):
    """
    Compute the HMAC digest of C{message} keyed with C{key}.

    Uses C{r_hmac.HMAC} when that module was importable, otherwise
    C{Crypto.Hash.HMAC.HMAC} (imported lazily on each call).

    @param key: the MAC key
    @param message: the data to authenticate
    @param digest_class: digest implementation handed straight to the HMAC
        constructor
    @return: the value returned by the HMAC object's C{digest()} method
    """
    if not got_r_hmac:
        from Crypto.Hash import HMAC
        return HMAC.HMAC(key, message, digest_class).digest()
    return r_hmac.HMAC(key, message, digest_class).digest()


class NeedRekeyException (Exception):
    """
    Raised by L{Packetizer.read_all} when it was asked to check for a pending
    rekey, timed out before receiving any data, and the rekey flag is set.
    """


class Packetizer (object):
    """
    Implementation of the base SSH packet protocol.
    """

    # READ the secsh RFC's before raising these values.  if anything,
    # they should probably be lower.
    REKEY_PACKETS = pow(2, 30)
    REKEY_BYTES = pow(2, 30)

    def __init__(self, socket):
        self.__socket = socket
        self.__logger = None
        self.__closed = False
        self.__dump_packets = False
        self.__need_rekey = False
        self.__init_count = 0
        self.__remainder = ''

        # used for noticing when to re-key:
        self.__sent_bytes = 0
        self.__sent_packets = 0
        self.__received_bytes = 0
        self.__received_packets = 0
        self.__received_packets_overflow = 0

        # current inbound/outbound ciphering:
        self.__block_size_out = 8
        self.__block_size_in = 8
        self.__mac_size_out = 0
        self.__mac_size_in = 0
        self.__block_engine_out = None
        self.__block_engine_in = None
        self.__mac_engine_out = None
        self.__mac_engine_in = None
        self.__mac_key_out = ''
        self.__mac_key_in = ''
        self.__compress_engine_out = None
        self.__compress_engine_in = None
        self.__sequence_number_out = 0L
        self.__sequence_number_in = 0L

        # lock around outbound writes (packet computation)
        self.__write_lock = threading.RLock()

        # keepalives:
        self.__keepalive_interval = 0
        self.__keepalive_last = time.time()
        self.__keepalive_callback = None

    def set_log(self, log):
        """
        Set the python log object to use for logging.

        @param log: object whose C{log(level, msg)} method receives this
            packetizer's log output; while it is C{None}, nothing is logged
        """
        self.__logger = log

    def set_outbound_cipher(self, block_engine, block_size, mac_engine, mac_size, mac_key):
        """
        Switch outbound data cipher.
        """
        self.__block_engine_out = block_engine
        self.__block_size_out = block_size
        self.__mac_engine_out = mac_engine
        self.__mac_size_out = mac_size
        self.__mac_key_out = mac_key
        self.__sent_bytes = 0
        self.__sent_packets = 0
        # wait until the reset happens in both directions before clearing rekey flag
        self.__init_count |= 1
        if self.__init_count == 3:
            self.__init_count = 0
            self.__need_rekey = False

    def set_inbound_cipher(self, block_engine, block_size, mac_engine, mac_size, mac_key):
        """
        Switch inbound data cipher.
        """
        self.__block_engine_in = block_engine
        self.__block_size_in = block_size
        self.__mac_engine_in = mac_engine
        self.__mac_size_in = mac_size
        self.__mac_key_in = mac_key
        self.__received_bytes = 0
        self.__received_packets = 0
        self.__received_packets_overflow = 0
        # wait until the reset happens in both directions before clearing rekey flag
        self.__init_count |= 2
        if self.__init_count == 3:
            self.__init_count = 0
            self.__need_rekey = False

    def set_outbound_compressor(self, compressor):
        """
        Set the compressor for outgoing payloads.

        @param compressor: callable applied to each payload in
            L{send_message} before the packet is built; C{None} disables
            outbound compression
        """
        self.__compress_engine_out = compressor

    def set_inbound_compressor(self, compressor):
        """
        Set the decompressor for incoming payloads.

        @param compressor: callable applied to each payload in
            L{read_message} after the padding is stripped; C{None} disables
            inbound decompression
        """
        self.__compress_engine_in = compressor

    def close(self):
        """
        Mark the packetizer closed and close the underlying socket.
        """
        # set the flag first: the read and write loops check it after a
        # timeout or socket error and turn that into EOFError
        self.__closed = True
        self.__socket.close()

    def set_hexdump(self, hexdump):
        """
        Turn packet dumping on or off.

        @param hexdump: when true, L{send_message} and L{read_message} log
            the contents of every packet at DEBUG level
        """
        self.__dump_packets = hexdump

    def get_hexdump(self):
        """
        Return the packet-dumping setting last given to L{set_hexdump}
        (C{False} until it is called).
        """
        return self.__dump_packets

    def get_mac_size_in(self):
        """
        Return the number of MAC bytes expected after each incoming packet
        (0 until an inbound cipher is set).
        """
        return self.__mac_size_in

    def get_mac_size_out(self):
        """
        Return the number of MAC bytes appended to each outgoing packet
        (0 until an outbound cipher is set).
        """
        return self.__mac_size_out

    def need_rekey(self):
        """
        Returns C{True} if a new set of keys needs to be negotiated.  This
        will be triggered during a packet read or write, so it should be
        checked after every read or write, or at least after every few.

        The flag is cleared again once both L{set_outbound_cipher} and
        L{set_inbound_cipher} have been called.

        @return: C{True} if a new set of keys needs to be negotiated
        @rtype: bool
        """
        return self.__need_rekey

    def set_keepalive(self, interval, callback):
        """
        Turn on/off the callback keepalive.

        While a read is waiting for data, the callback is executed whenever
        more than C{interval} seconds have passed since the timer was last
        reset.  The timer is reset by this call, by every write to the
        socket, and after each callback.  Keepalives are suppressed until an
        outbound cipher is set and while a rekey is pending.

        @param interval: seconds of write inactivity before the callback
            fires; 0 (or any false value) turns keepalives off
        @param callback: callable invoked with no arguments
        """
        self.__keepalive_interval = interval
        self.__keepalive_callback = callback
        self.__keepalive_last = time.time()

    def read_all(self, n, check_rekey=False):
        """
        Read C{n} bytes, blocking (across socket timeouts) until all of them
        have arrived.  Bytes left over from L{readline} are used first.

        @param n: number of bytes to read
        @type n: int
        @param check_rekey: if true, a timeout that occurs before any byte
            has been collected raises L{NeedRekeyException} when a rekey is
            pending
        @type check_rekey: bool
        @return: the data read
        @rtype: str
        @raise EOFError: if the socket was closed before all the bytes could
            be read
        @raise NeedRekeyException: see C{check_rekey}
        """
        data = ''
        # the banner readline may have read past the end of its line; hand
        # out those buffered bytes before going back to the socket
        if len(self.__remainder) > 0:
            data = self.__remainder[:n]
            self.__remainder = self.__remainder[n:]
            n -= len(data)
        if PY22:
            return self._py22_read_all(n, data)
        while n > 0:
            timed_out = False
            try:
                chunk = self.__socket.recv(n)
                if len(chunk) == 0:
                    raise EOFError()
                data += chunk
                n -= len(chunk)
            except socket.timeout:
                timed_out = True
            except socket.error, e:
                errcode = None
                if (type(e.args) is tuple) and (len(e.args) > 0):
                    errcode = e.args[0]
                if errcode == errno.EAGAIN:
                    # on Linux, sometimes instead of socket.timeout, we get
                    # EAGAIN.  this is a bug in recent (> 2.6.9) kernels but
                    # we need to work around it.
                    timed_out = True
                elif errcode == errno.EINTR:
                    # syscall interrupted; try again
                    pass
                elif self.__closed:
                    raise EOFError()
                else:
                    raise
            if not timed_out:
                continue
            # idle: use the pause to notice a close, a pending rekey, or an
            # overdue keepalive
            if self.__closed:
                raise EOFError()
            if check_rekey and (len(data) == 0) and self.__need_rekey:
                raise NeedRekeyException()
            self._check_keepalive()
        return data

    def write_all(self, out):
        """
        Write all of C{out} to the socket, retrying after timeouts, EAGAIN
        and interrupted system calls until every byte has been sent.  Also
        resets the keepalive timer.

        @param out: the data to send
        @type out: str
        @raise EOFError: if the socket fails with any other error, or if the
            packetizer has been closed by the time a retry is needed
        """
        self.__keepalive_last = time.time()
        while len(out) > 0:
            retry_write = False
            try:
                n = self.__socket.send(out)
            except socket.timeout:
                retry_write = True
            except socket.error, e:
                if (type(e.args) is tuple) and (len(e.args) > 0) and \
                   (e.args[0] in (errno.EAGAIN, errno.EINTR)):
                    # EAGAIN: treated like a timeout.  EINTR: syscall
                    # interrupted.  either way nothing was sent; try again
                    retry_write = True
                else:
                    n = -1
            except Exception:
                # could be: (32, 'Broken pipe')
                n = -1
            if retry_write:
                # nothing went out on this attempt.  n must be reset here:
                # otherwise it is unbound (first pass) or still holds the
                # previous send's count, which would drop unsent bytes below
                n = 0
                if self.__closed:
                    n = -1
            if n < 0:
                raise EOFError()
            if n == len(out):
                break
            out = out[n:]

    def readline(self, timeout):
        """
        Read a line from the socket.  We assume no data is pending after the
        line, so it's okay to attempt large reads.
        """
        buf = self.__remainder
        while not '\n' in buf:
            buf += self._read_timeout(timeout)
        n = buf.index('\n')
        self.__remainder = buf[n+1:]
        buf = buf[:n]
        if (len(buf) > 0) and (buf[-1] == '\r'):
            buf = buf[:-1]
        return buf

    def send_message(self, data):
        """
        Write a block of data using the current cipher, as an SSH block.

        The payload is optionally compressed, wrapped into a padded packet,
        encrypted and MAC'd (when an outbound cipher is set) and written to
        the socket, all under the write lock.  Afterwards the sent counters
        are updated and a rekey is flagged if a limit has been reached.

        @param data: the message to send; converted with C{str()}, and its
            first byte is taken as the message type for logging
        @raise EOFError: if the socket write fails (from L{write_all})
        """
        # encrypt this sucka
        data = str(data)
        # the message type and original length are only used for the dump log
        cmd = ord(data[0])
        if cmd in MSG_NAMES:
            cmd_name = MSG_NAMES[cmd]
        else:
            cmd_name = '$%x' % cmd
        orig_len = len(data)
        self.__write_lock.acquire()
        try:
            if self.__compress_engine_out is not None:
                data = self.__compress_engine_out(data)
            packet = self._build_packet(data)
            if self.__dump_packets:
                self._log(DEBUG, 'Write packet <%s>, length %d' % (cmd_name, orig_len))
                self._log(DEBUG, util.format_binary(packet, 'OUT: '))
            if self.__block_engine_out != None:
                out = self.__block_engine_out.encrypt(packet)
            else:
                out = packet
            # + mac
            # (computed over the sequence number plus the unencrypted packet,
            # and only added once an outbound cipher is in place)
            if self.__block_engine_out != None:
                payload = struct.pack('>I', self.__sequence_number_out) + packet
                out += compute_hmac(self.__mac_key_out, payload, self.__mac_engine_out)[:self.__mac_size_out]
            # the sequence number counts every packet and wraps at 32 bits
            self.__sequence_number_out = (self.__sequence_number_out + 1) & 0xffffffffL
            self.write_all(out)

            self.__sent_bytes += len(out)
            self.__sent_packets += 1
            if ((self.__sent_packets >= self.REKEY_PACKETS) or (self.__sent_bytes >= self.REKEY_BYTES)) \
                   and not self.__need_rekey:
                # only ask once for rekeying
                self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes sent)' %
                          (self.__sent_packets, self.__sent_bytes))
                # restart the count of packets the peer may still send before
                # read_message gives up on it
                self.__received_packets_overflow = 0
                self._trigger_rekey()
        finally:
            self.__write_lock.release()

    def read_message(self):
        """
        Only one thread should ever be in this function (no other locking is
        done).

        @raise SSHException: if the packet is mangled
        @raise NeedRekeyException: if the transport should rekey
        """
        header = self.read_all(self.__block_size_in, check_rekey=True)
        if self.__block_engine_in != None:
            header = self.__block_engine_in.decrypt(header)
        if self.__dump_packets:
            self._log(DEBUG, util.format_binary(header, 'IN: '));
        packet_size = struct.unpack('>I', header[:4])[0]
        # leftover contains decrypted bytes from the first block (after the length field)
        leftover = header[4:]
        if (packet_size - len(leftover)) % self.__block_size_in != 0:
            raise SSHException('Invalid packet blocking')
        buf = self.read_all(packet_size + self.__mac_size_in - len(leftover))
        packet = buf[:packet_size - len(leftover)]
        post_packet = buf[packet_size - len(leftover):]
        if self.__block_engine_in != None:
            packet = self.__block_engine_in.decrypt(packet)
        if self.__dump_packets:
            self._log(DEBUG, util.format_binary(packet, 'IN: '));
        packet = leftover + packet

        if self.__mac_size_in > 0:
            mac = post_packet[:self.__mac_size_in]
            mac_payload = struct.pack('>II', self.__sequence_number_in, packet_size) + packet
            my_mac = compute_hmac(self.__mac_key_in, mac_payload, self.__mac_engine_in)[:self.__mac_size_in]
            if my_mac != mac:
                raise SSHException('Mismatched MAC')
        padding = ord(packet[0])
        payload = packet[1:packet_size - padding]
        
        if self.__dump_packets:
            self._log(DEBUG, 'Got payload (%d bytes, %d padding)' % (packet_size, padding))

        if self.__compress_engine_in is not None:
            payload = self.__compress_engine_in(payload)

        msg = Message(payload[1:])
        msg.seqno = self.__sequence_number_in
        self.__sequence_number_in = (self.__sequence_number_in + 1) & 0xffffffffL

        # check for rekey
        self.__received_bytes += packet_size + self.__mac_size_in + 4
        self.__received_packets += 1
        if self.__need_rekey:
            # we've asked to rekey -- give them 20 packets to comply before
            # dropping the connection
            self.__received_packets_overflow += 1
            if self.__received_packets_overflow >= 20:
                raise SSHException('Remote transport is ignoring rekey requests')
        elif (self.__received_packets >= self.REKEY_PACKETS) or \
             (self.__received_bytes >= self.REKEY_BYTES):
            # only ask once for rekeying
            self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes received)' %
                      (self.__received_packets, self.__received_bytes))
            self.__received_packets_overflow = 0
            self._trigger_rekey()

        cmd = ord(payload[0])
        if cmd in MSG_NAMES:
            cmd_name = MSG_NAMES[cmd]
        else:
            cmd_name = '$%x' % cmd
        if self.__dump_packets:
            self._log(DEBUG, 'Read packet <%s>, length %d' % (cmd_name, len(payload)))
        return cmd, msg


    ##########  protected


    def _log(self, level, msg):
        if self.__logger is None:
            return
        if issubclass(type(msg), list):
            for m in msg:
                self.__logger.log(level, m)
        else:
            self.__logger.log(level, msg)

    def _check_keepalive(self):
        if (not self.__keepalive_interval) or (not self.__block_engine_out) or \
            self.__need_rekey:
            # wait till we're encrypting, and not in the middle of rekeying
            return
        now = time.time()
        if now > self.__keepalive_last + self.__keepalive_interval:
            self.__keepalive_callback()
            self.__keepalive_last = now

    def _py22_read_all(self, n, out):
        while n > 0:
            r, w, e = select.select([self.__socket], [], [], 0.1)
            if self.__socket not in r:
                if self.__closed:
                    raise EOFError()
                self._check_keepalive()
            else:
                x = self.__socket.recv(n)
                if len(x) == 0:
                    raise EOFError()
                out += x
                n -= len(x)
        return out

    def _py22_read_timeout(self, timeout):
        start = time.time()
        while True:
            r, w, e = select.select([self.__socket], [], [], 0.1)
            if self.__socket in r:
                x = self.__socket.recv(1)
                if len(x) == 0:
                    raise EOFError()
                break
            if self.__closed:
                raise EOFError()
            now = time.time()
            if now - start >= timeout:
                raise socket.timeout()
        return x

    def _read_timeout(self, timeout):
        if PY22:
            return self._py22_read_timeout(timeout)
        start = time.time()
        while True:
            try:
                x = self.__socket.recv(128)
                if len(x) == 0:
                    raise EOFError()
                break
            except socket.timeout:
                pass
            if self.__closed:
                raise EOFError()
            now = time.time()
            if now - start >= timeout:
                raise socket.timeout()
        return x

    def _build_packet(self, payload):
        # pad up at least 4 bytes, to nearest block-size (usually 8)
        bsize = self.__block_size_out
        padding = 3 + bsize - ((len(payload) + 8) % bsize)
        packet = struct.pack('>IB', len(payload) + padding + 1, padding)
        packet += payload
        if self.__block_engine_out is not None:
            packet += rng.read(padding)
        else:
            # cute trick i caught openssh doing: if we're not encrypting,
            # don't waste random bytes for the padding
            packet += (chr(0) * padding)
        return packet

    def _trigger_rekey(self):
        """
        Flag that a new set of keys should be negotiated.

        This only sets the flag reported by L{need_rekey}; nothing in this
        method starts the negotiation itself.
        """
        # outside code should check for this flag
        self.__need_rekey = True