[project @ Arch-1:robey@lag.net--2003-public%secsh--dev--1.0--patch-16]

hook up server-side kex-gex; add more documentation
group-exchange kex should work now on the server side.  it will only be
advertised if a "moduli" file has been loaded (see the -gasp- docs) so we
don't spend hours (literally. hours.) computing primes.  some of the logic
was previously wrong, too, since it had never been tested.

fixed repr() string for Transport/BaseTransport.  moved is_authenticated to
Transport where it belongs.

added lots of documentation (but still only about 10% documented).  lots of
methods were made private finally.
This commit is contained in:
Robey Pointer 2003-12-28 03:20:42 +00:00
parent eb4c279ec4
commit 36d6d95dc6
11 changed files with 506 additions and 274 deletions

4
NOTES
View File

@ -22,14 +22,14 @@ from BaseTransport:
get_server_key
close
get_remote_server_key
is_active
is_authenticated
* is_active
open_session
open_channel
renegotiate_keys
check_channel_request
from Transport:
* is_authenticated
auth_key
auth_password
get_allowed_auths

View File

@ -69,7 +69,7 @@ try:
t.ultra_debug = 0
t.start_client(event)
# print repr(t)
event.wait(10)
event.wait(15)
if not t.is_active():
print '*** SSH negotiation failed.'
sys.exit(1)

View File

@ -70,6 +70,11 @@ print 'Got a connection!'
try:
event = threading.Event()
t = ServerTransport(client)
try:
t.load_server_moduli()
except:
print '(Failed to load moduli -- gex will be unsupported.)'
raise
t.add_server_key(host_key)
t.ultra_debug = 0
t.start_server(event)
@ -81,10 +86,11 @@ try:
# print repr(t)
# wait for auth
chan = t.accept(10)
chan = t.accept(20)
if chan is None:
print '*** No channel.'
sys.exit(1)
print 'Authenticated!'
chan.event.wait(10)
if not chan.event.isSet():
print '*** Client never asked for a shell.'

View File

@ -17,4 +17,4 @@ from rsakey import RSAKey
from dsskey import DSSKey
from util import hexify
__all__ = [ 'Transport', 'Channel', 'RSAKey', 'DSSKey', 'hexify' ]
#__all__ = [ 'Transport', 'Channel', 'RSAKey', 'DSSKey', 'hexify' ]

View File

@ -19,17 +19,45 @@ class Transport(BaseTransport):
def __init__(self, sock):
BaseTransport.__init__(self, sock)
self.authenticated = False
self.auth_event = None
# for server mode:
self.auth_username = None
self.auth_fail_count = 0
self.auth_complete = 0
def request_auth(self):
def __repr__(self):
if not self.active:
return '<paramiko.Transport (unconnected)>'
out = '<paramiko.Transport'
if self.local_cipher != '':
out += ' (cipher %s)' % self.local_cipher
if self.authenticated:
if len(self.channels) == 1:
out += ' (active; 1 open channel)'
else:
out += ' (active; %d open channels)' % len(self.channels)
elif self.initial_kex_done:
out += ' (connected; awaiting auth)'
else:
out += ' (connecting)'
out += '>'
return out
def is_authenticated(self):
"""
Return true if this session is active and authenticated.
@return: True if the session is still open and has been authenticated successfully;
False if authentication failed and/or the session is closed.
"""
return self.authenticated and self.active
def _request_auth(self):
m = Message()
m.add_byte(chr(MSG_SERVICE_REQUEST))
m.add_string('ssh-userauth')
self.send_message(m)
self._send_message(m)
def auth_key(self, username, key, event):
if (not self.active) or (not self.initial_kex_done):
@ -41,7 +69,7 @@ class Transport(BaseTransport):
self.auth_method = 'publickey'
self.username = username
self.private_key = key
self.request_auth()
self._request_auth()
finally:
self.lock.release()
@ -56,7 +84,7 @@ class Transport(BaseTransport):
self.auth_method = 'password'
self.username = username
self.password = password
self.request_auth()
self._request_auth()
finally:
self.lock.release()
@ -66,7 +94,7 @@ class Transport(BaseTransport):
m.add_int(DISCONNECT_SERVICE_NOT_AVAILABLE)
m.add_string('Service not available')
m.add_string('en')
self.send_message(m)
self._send_message(m)
self.close()
def disconnect_no_more_auth(self):
@ -75,7 +103,7 @@ class Transport(BaseTransport):
m.add_int(DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE)
m.add_string('No more auth methods available')
m.add_string('en')
self.send_message(m)
self._send_message(m)
self.close()
def parse_service_request(self, m):
@ -85,7 +113,7 @@ class Transport(BaseTransport):
m = Message()
m.add_byte(chr(MSG_SERVICE_ACCEPT))
m.add_string(service)
self.send_message(m)
self._send_message(m)
return
# dunno this one
self.disconnect_service_not_available()
@ -93,7 +121,7 @@ class Transport(BaseTransport):
def parse_service_accept(self, m):
service = m.get_string()
if service == 'ssh-userauth':
self.log(DEBUG, 'userauth is OK')
self._log(DEBUG, 'userauth is OK')
m = Message()
m.add_byte(chr(MSG_USERAUTH_REQUEST))
m.add_string(self.username)
@ -109,9 +137,9 @@ class Transport(BaseTransport):
m.add_string(self.private_key.sign_ssh_session(self.randpool, self.H, self.username))
else:
raise SSHException('Unknown auth method "%s"' % self.auth_method)
self.send_message(m)
self._send_message(m)
else:
self.log(DEBUG, 'Service request "%s" accepted (?)' % service)
self._log(DEBUG, 'Service request "%s" accepted (?)' % service)
def get_allowed_auths(self, username):
"override me!"
@ -136,7 +164,7 @@ class Transport(BaseTransport):
m.add_byte(chr(MSG_USERAUTH_FAILURE))
m.add_string('none')
m.add_boolean(0)
self.send_message(m)
self._send_message(m)
return
if self.auth_complete:
# ignore
@ -144,12 +172,12 @@ class Transport(BaseTransport):
username = m.get_string()
service = m.get_string()
method = m.get_string()
self.log(DEBUG, 'Auth request (type=%s) service=%s, username=%s' % (method, service, username))
self._log(DEBUG, 'Auth request (type=%s) service=%s, username=%s' % (method, service, username))
if service != 'ssh-connection':
self.disconnect_service_not_available()
return
if (self.auth_username is not None) and (self.auth_username != username):
self.log(DEBUG, 'Auth rejected because the client attempted to change username in mid-flight')
self._log(DEBUG, 'Auth rejected because the client attempted to change username in mid-flight')
self.disconnect_no_more_auth()
return
if method == 'none':
@ -160,7 +188,7 @@ class Transport(BaseTransport):
if changereq:
# always treated as failure, since we don't support changing passwords, but collect
# the list of valid auth types from the callback anyway
self.log(DEBUG, 'Auth request to change passwords (rejected)')
self._log(DEBUG, 'Auth request to change passwords (rejected)')
newpassword = m.get_string().decode('UTF-8')
result = self.AUTH_FAILED
else:
@ -173,11 +201,11 @@ class Transport(BaseTransport):
# okay, send result
m = Message()
if result == self.AUTH_SUCCESSFUL:
self.log(DEBUG, 'Auth granted.')
self._log(DEBUG, 'Auth granted.')
m.add_byte(chr(MSG_USERAUTH_SUCCESS))
self.auth_complete = 1
else:
self.log(DEBUG, 'Auth rejected.')
self._log(DEBUG, 'Auth rejected.')
m.add_byte(chr(MSG_USERAUTH_FAILURE))
m.add_string(self.get_allowed_auths(username))
if result == self.AUTH_PARTIALLY_SUCCESSFUL:
@ -185,13 +213,13 @@ class Transport(BaseTransport):
else:
m.add_boolean(0)
self.auth_fail_count += 1
self.send_message(m)
self._send_message(m)
if self.auth_fail_count >= 10:
self.disconnect_no_more_auth()
def parse_userauth_success(self, m):
self.log(INFO, 'Authentication successful!')
self.authenticated = 1
self._log(INFO, 'Authentication successful!')
self.authenticated = True
if self.auth_event != None:
self.auth_event.set()
@ -199,12 +227,12 @@ class Transport(BaseTransport):
authlist = m.get_list()
partial = m.get_boolean()
if partial:
self.log(INFO, 'Authentication continues...')
self.log(DEBUG, 'Methods: ' + str(partial))
self._log(INFO, 'Authentication continues...')
self._log(DEBUG, 'Methods: ' + str(partial))
# FIXME - do something
pass
self.log(INFO, 'Authentication failed.')
self.authenticated = 0
self._log(INFO, 'Authentication failed.')
self.authenticated = False
self.close()
if self.auth_event != None:
self.auth_event.set()
@ -212,11 +240,11 @@ class Transport(BaseTransport):
def parse_userauth_banner(self, m):
banner = m.get_string()
lang = m.get_string()
self.log(INFO, 'Auth banner: ' + banner)
self._log(INFO, 'Auth banner: ' + banner)
# who cares.
handler_table = BaseTransport.handler_table.copy()
handler_table.update({
_handler_table = BaseTransport._handler_table.copy()
_handler_table.update({
MSG_SERVICE_REQUEST: parse_service_request,
MSG_SERVICE_ACCEPT: parse_service_accept,
MSG_USERAUTH_REQUEST: parse_userauth_request,

View File

@ -50,10 +50,10 @@ class Channel(object):
out += '>'
return out
def set_transport(self, transport):
def _set_transport(self, transport):
self.transport = transport
def log(self, level, msg):
def _log(self, level, msg):
self.logger.log(level, msg)
def set_window(self, window_size, max_packet_size):
@ -70,7 +70,7 @@ class Channel(object):
self.active = 1
def request_success(self, m):
self.log(DEBUG, 'Sesch channel %d request ok' % self.chanid)
self._log(DEBUG, 'Sesch channel %d request ok' % self.chanid)
return
def request_failed(self, m):
@ -80,13 +80,13 @@ class Channel(object):
s = m.get_string()
try:
self.lock.acquire()
self.log(DEBUG, 'fed %d bytes' % len(s))
self._log(DEBUG, 'fed %d bytes' % len(s))
if self.pipe_wfd != None:
self.feed_pipe(s)
else:
self.in_buffer += s
self.in_buffer_cv.notifyAll()
self.log(DEBUG, '(out from feed)')
self._log(DEBUG, '(out from feed)')
finally:
self.lock.release()
@ -94,7 +94,7 @@ class Channel(object):
nbytes = m.get_int()
try:
self.lock.acquire()
self.log(DEBUG, 'window up %d' % nbytes)
self._log(DEBUG, 'window up %d' % nbytes)
self.out_window_size += nbytes
self.out_buffer_cv.notifyAll()
finally:
@ -146,7 +146,7 @@ class Channel(object):
pixelheight = m.get_int()
ok = self.check_window_change_request(width, height, pixelwidth, pixelheight)
else:
self.log(DEBUG, 'Unhandled channel request "%s"' % key)
self._log(DEBUG, 'Unhandled channel request "%s"' % key)
ok = False
if want_reply:
m = Message()
@ -155,7 +155,7 @@ class Channel(object):
else:
m.add_byte(chr(MSG_CHANNEL_FAILURE))
m.add_int(self.remote_chanid)
self.transport.send_message(m)
self.transport._send_message(m)
def handle_eof(self, m):
try:
@ -168,7 +168,7 @@ class Channel(object):
self.pipe_wfd = None
finally:
self.lock.release()
self.log(DEBUG, 'EOF received')
self._log(DEBUG, 'EOF received')
def handle_close(self, m):
self.close()
@ -199,7 +199,7 @@ class Channel(object):
# pixel height, width (usually useless)
m.add_int(0).add_int(0)
m.add_string('')
self.transport.send_message(m)
self.transport._send_message(m)
def invoke_shell(self):
if self.closed or self.eof_received or self.eof_sent or not self.active:
@ -209,7 +209,7 @@ class Channel(object):
m.add_int(self.remote_chanid)
m.add_string('shell')
m.add_boolean(1)
self.transport.send_message(m)
self.transport._send_message(m)
def exec_command(self, command):
if self.closed or self.eof_received or self.eof_sent or not self.active:
@ -220,7 +220,7 @@ class Channel(object):
m.add_string('exec')
m.add_boolean(1)
m.add_string(command)
self.transport.send_message(m)
self.transport._send_message(m)
def invoke_subsystem(self, subsystem):
if self.closed or self.eof_received or self.eof_sent or not self.active:
@ -231,7 +231,7 @@ class Channel(object):
m.add_string('subsystem')
m.add_boolean(1)
m.add_string(subsystem)
self.transport.send_message(m)
self.transport._send_message(m)
def resize_pty(self, width=80, height=24):
if self.closed or self.eof_received or self.eof_sent or not self.active:
@ -244,7 +244,7 @@ class Channel(object):
m.add_int(width)
m.add_int(height)
m.add_int(0).add_int(0)
self.transport.send_message(m)
self.transport._send_message(m)
def get_transport(self):
return self.transport
@ -262,9 +262,9 @@ class Channel(object):
m = Message()
m.add_byte(chr(MSG_CHANNEL_EOF))
m.add_int(self.remote_chanid)
self.transport.send_message(m)
self.transport._send_message(m)
self.eof_sent = 1
self.log(DEBUG, 'EOF sent')
self._log(DEBUG, 'EOF sent')
return
@ -290,9 +290,9 @@ class Channel(object):
m = Message()
m.add_byte(chr(MSG_CHANNEL_CLOSE))
m.add_int(self.remote_chanid)
self.transport.send_message(m)
self.transport._send_message(m)
self.closed = 1
self.transport.unlink_channel(self.chanid)
self.transport._unlink_channel(self.chanid)
finally:
self.lock.release()
@ -371,7 +371,7 @@ class Channel(object):
m.add_byte(chr(MSG_CHANNEL_DATA))
m.add_int(self.remote_chanid)
m.add_string(s[:size])
self.transport.send_message(m)
self.transport._send_message(m)
self.out_window_size -= size
finally:
self.lock.release()
@ -506,25 +506,25 @@ class Channel(object):
self.in_buffer = self.in_buffer[nbytes:]
os.write(self.pipd_wfd, x)
def unlink(self):
def _unlink(self):
if self.closed or not self.active:
return
self.closed = 1
self.transport.unlink_channel(self.chanid)
self.transport._unlink_channel(self.chanid)
def check_add_window(self, n):
# already holding the lock!
if self.closed or self.eof_received or not self.active:
return
self.log(DEBUG, 'addwindow %d' % n)
self._log(DEBUG, 'addwindow %d' % n)
self.in_window_sofar += n
if self.in_window_sofar > self.in_window_threshold:
self.log(DEBUG, 'addwindow send %d' % self.in_window_sofar)
self._log(DEBUG, 'addwindow send %d' % self.in_window_sofar)
m = Message()
m.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST))
m.add_int(self.remote_chanid)
m.add_int(self.in_window_sofar)
self.transport.send_message(m)
self.transport._send_message(m)
self.in_window_sofar = 0

View File

@ -5,7 +5,7 @@
# LOT more on the server side).
from message import Message
from util import inflate_long, deflate_long, generate_prime, bit_length
from util import inflate_long, deflate_long, bit_length
from ssh_exception import SSHException
from transport import MSG_NEWKEYS
from Crypto.Hash import SHA
@ -27,7 +27,7 @@ class KexGex(object):
def start_kex(self):
if self.transport.server_mode:
self.transport.expected_packet = MSG_KEXDH_GEX_REQUEST
self.transport._expect_packet(MSG_KEXDH_GEX_REQUEST)
return
# request a bit range: we accept (min_bits) to (max_bits), but prefer
# (preferred_bits). according to the spec, we shouldn't pull the
@ -37,21 +37,21 @@ class KexGex(object):
m.add_int(self.min_bits)
m.add_int(self.preferred_bits)
m.add_int(self.max_bits)
self.transport.send_message(m)
self.transport.expected_packet = MSG_KEXDH_GEX_GROUP
self.transport._send_message(m)
self.transport._expect_packet(MSG_KEXDH_GEX_GROUP)
def parse_next(self, ptype, m):
if ptype == MSG_KEXDH_GEX_REQUEST:
return self.parse_kexdh_gex_request(m)
return self._parse_kexdh_gex_request(m)
elif ptype == MSG_KEXDH_GEX_GROUP:
return self.parse_kexdh_gex_group(m)
return self._parse_kexdh_gex_group(m)
elif ptype == MSG_KEXDH_GEX_INIT:
return self.parse_kexdh_gex_init(m)
return self._parse_kexdh_gex_init(m)
elif ptype == MSG_KEXDH_GEX_REPLY:
return self.parse_kexdh_gex_reply(m)
return self._parse_kexdh_gex_reply(m)
raise SSHException('KexGex asked to handle packet type %d' % ptype)
def generate_x(self):
def _generate_x(self):
# generate an "x" (1 < x < (p-1)/2).
q = (self.p - 1) // 2
qnorm = deflate_long(q, 0)
@ -70,7 +70,7 @@ class KexGex(object):
break
self.x = x
def parse_kexdh_gex_request(self, m):
def _parse_kexdh_gex_request(self, m):
min = m.get_int()
preferred = m.get_int()
max = m.get_int()
@ -79,52 +79,53 @@ class KexGex(object):
preferred = self.max_bits
if preferred < self.min_bits:
preferred = self.min_bits
# fix min/max if they're inconsistent. technically, we could just pout
# and hang up, but there's no harm in giving them the benefit of the
# doubt and just picking a bitsize for them.
if min > preferred:
min = preferred
if max < preferred:
max = preferred
# now save a copy
self.min_bits = min
self.preferred_bits = preferred
self.max_bits = max
# generate prime
while 1:
# does not work FIXME
# the problem is that it's too fscking SLOW
self.transport.log(DEBUG, 'stir...')
self.transport.randpool.stir()
self.transport.log(DEBUG, 'get-prime %d...' % preferred)
self.p = generate_prime(preferred, self.transport.randpool)
self.transport.log(DEBUG, 'got ' + repr(self.p))
if number.isPrime((self.p - 1) // 2):
break
self.g = 2
pack = self.transport._get_modulus_pack()
if pack is None:
raise SSHException('Can\'t do server-side gex with no modulus pack')
self.g, self.p = pack.get_modulus(min, preferred, max)
m = Message()
m.add_byte(chr(MSG_KEXDH_GEX_GROUP))
m.add_mpint(self.p)
m.add_mpint(self.g)
self.transport.send_message(m)
self.transport.expected_packet = MSG_KEXDH_GEX_INIT
self.transport._send_message(m)
self.transport._expect_packet(MSG_KEXDH_GEX_INIT)
def parse_kexdh_gex_group(self, m):
def _parse_kexdh_gex_group(self, m):
self.p = m.get_mpint()
self.g = m.get_mpint()
# reject if p's bit length < 1024 or > 8192
bitlen = bit_length(self.p)
if (bitlen < 1024) or (bitlen > 8192):
raise SSHException('Server-generated gex p (don\'t ask) is out of range (%d bits)' % bitlen)
self.transport.log(DEBUG, 'Got server p (%d bits)' % bitlen)
self.generate_x()
self.transport._log(DEBUG, 'Got server p (%d bits)' % bitlen)
self._generate_x()
# now compute e = g^x mod p
self.e = pow(self.g, self.x, self.p)
m = Message()
m.add_byte(chr(MSG_KEXDH_GEX_INIT))
m.add_mpint(self.e)
self.transport.send_message(m)
self.transport.expected_packet = MSG_KEXDH_GEX_REPLY
self.transport._send_message(m)
self.transport._expect_packet(MSG_KEXDH_GEX_REPLY)
def parse_kexdh_gex_init(self, m):
def _parse_kexdh_gex_init(self, m):
self.e = m.get_mpint()
if (self.e < 1) or (self.e > self.p - 1):
raise SSHException('Client kex "e" is out of range')
self.generate_x()
K = pow(self.e, self.x, P)
self._generate_x()
self.f = pow(self.g, self.x, self.p)
K = pow(self.e, self.x, self.p)
key = str(self.transport.get_server_key())
# okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K)
hm = Message().add(self.transport.remote_version).add(self.transport.local_version)
@ -136,7 +137,7 @@ class KexGex(object):
hm.add_mpint(self.g)
hm.add(self.e).add(self.f).add(K)
H = SHA.new(str(hm)).digest()
self.transport.set_K_H(K, H)
self.transport._set_K_H(K, H)
# sign it
sig = self.transport.get_server_key().sign_ssh_data(self.transport.randpool, H)
# send reply
@ -145,11 +146,10 @@ class KexGex(object):
m.add_string(key)
m.add_mpint(self.f)
m.add_string(sig)
self.transport.send_message(m)
self.transport.activate_outbound()
self.transport.expected_packet = MSG_NEWKEYS
self.transport._send_message(m)
self.transport._activate_outbound()
def parse_kexdh_gex_reply(self, m):
def _parse_kexdh_gex_reply(self, m):
host_key = m.get_string()
self.f = m.get_mpint()
sig = m.get_string()
@ -165,9 +165,9 @@ class KexGex(object):
hm.add_mpint(self.p)
hm.add_mpint(self.g)
hm.add(self.e).add(self.f).add(K)
self.transport.set_K_H(K, SHA.new(str(hm)).digest())
self.transport.verify_key(host_key, sig)
self.transport.activate_outbound()
self.transport.expected_packet = MSG_NEWKEYS
self.transport._set_K_H(K, SHA.new(str(hm)).digest())
self.transport._verify_key(host_key, sig)
self.transport._activate_outbound()

View File

@ -44,15 +44,15 @@ class KexGroup1(object):
if self.transport.server_mode:
# compute f = g^x mod p, but don't send it yet
self.f = pow(G, self.x, P)
self.transport.expected_packet = MSG_KEXDH_INIT
self.transport._expect_packet(MSG_KEXDH_INIT)
return
# compute e = g^x mod p (where g=2), and send it
self.e = pow(G, self.x, P)
m = Message()
m.add_byte(chr(MSG_KEXDH_INIT))
m.add_mpint(self.e)
self.transport.send_message(m)
self.transport.expected_packet = MSG_KEXDH_REPLY
self.transport._send_message(m)
self.transport._expect_packet(MSG_KEXDH_REPLY)
def parse_next(self, ptype, m):
if self.transport.server_mode and (ptype == MSG_KEXDH_INIT):
@ -73,10 +73,9 @@ class KexGroup1(object):
hm = Message().add(self.transport.local_version).add(self.transport.remote_version)
hm.add(self.transport.local_kex_init).add(self.transport.remote_kex_init).add(host_key)
hm.add(self.e).add(self.f).add(K)
self.transport.set_K_H(K, SHA.new(str(hm)).digest())
self.transport.verify_key(host_key, sig)
self.transport.activate_outbound()
self.transport.expected_packet = MSG_NEWKEYS
self.transport._set_K_H(K, SHA.new(str(hm)).digest())
self.transport._verify_key(host_key, sig)
self.transport._activate_outbound()
def parse_kexdh_init(self, m):
# server mode
@ -90,7 +89,7 @@ class KexGroup1(object):
hm.add(self.transport.remote_kex_init).add(self.transport.local_kex_init).add(key)
hm.add(self.e).add(self.f).add(K)
H = SHA.new(str(hm)).digest()
self.transport.set_K_H(K, H)
self.transport._set_K_H(K, H)
# sign it
sig = self.transport.get_server_key().sign_ssh_data(self.transport.randpool, H)
# send reply
@ -99,6 +98,5 @@ class KexGroup1(object):
m.add_string(key)
m.add_mpint(self.f)
m.add_string(sig)
self.transport.send_message(m)
self.transport.activate_outbound()
self.transport.expected_packet = MSG_NEWKEYS
self.transport._send_message(m)
self.transport._activate_outbound()

128
paramiko/primes.py Normal file
View File

@ -0,0 +1,128 @@
# utility functions for dealing with primes
from Crypto.Util import number
from util import bit_length, inflate_long
def generate_prime(bits, randpool):
hbyte_mask = pow(2, bits % 8) - 1
while 1:
# loop catches the case where we increment n into a higher bit-range
x = randpool.get_bytes((bits+7) // 8)
if hbyte_mask > 0:
x = chr(ord(x[0]) & hbyte_mask) + x[1:]
n = inflate_long(x, 1)
n |= 1
n |= (1 << (bits - 1))
while not number.isPrime(n):
n += 2
if bit_length(n) == bits:
return n
def roll_random(randpool, n):
"returns a random # from 0 to N-1"
bits = bit_length(n-1)
bytes = (bits + 7) // 8
hbyte_mask = pow(2, bits % 8) - 1
# so here's the plan:
# we fetch as many random bits as we'd need to fit N-1, and if the
# generated number is >= N, we try again. in the worst case (N-1 is a
# power of 2), we have slightly better than 50% odds of getting one that
# fits, so i can't guarantee that this loop will ever finish, but the odds
# of it looping forever should be infinitesimal.
while 1:
x = randpool.get_bytes(bytes)
if hbyte_mask > 0:
x = chr(ord(x[0]) & hbyte_mask) + x[1:]
num = inflate_long(x, 1)
if num < n:
return num
class ModulusPack (object):
"""
convenience object for holding the contents of the /etc/ssh/moduli file,
on systems that have such a file.
"""
def __init__(self, randpool):
# pack is a hash of: bits -> [ (generator, modulus) ... ]
self.pack = {}
self.discarded = []
self.randpool = randpool
def _parse_modulus(self, line):
timestamp, type, tests, tries, size, generator, modulus = line.split()
type = int(type)
tests = int(tests)
tries = int(tries)
size = int(size)
generator = int(generator)
modulus = long(modulus, 16)
# weed out primes that aren't at least:
# type 2 (meets basic structural requirements)
# test 4 (more than just a small-prime sieve)
# tries < 100 if test & 4 (at least 100 tries of miller-rabin)
if (type < 2) or (tests < 4) or ((tests & 4) and (tests < 8) and (tries < 100)):
self.discarded.append((modulus, 'does not meet basic requirements'))
return
if generator == 0:
generator = 2
# there's a bug in the ssh "moduli" file (yeah, i know: shock! dismay!
# call cnn!) where it understates the bit lengths of these primes by 1.
# this is okay.
bl = bit_length(modulus)
if (bl != size) and (bl != size + 1):
self.discarded.append((modulus, 'incorrectly reported bit length %d' % size))
return
if not self.pack.has_key(bl):
self.pack[bl] = []
self.pack[bl].append((generator, modulus))
def read_file(self, filename):
"""
@raise IOError: passed from any file operations that fail.
"""
self.pack = {}
f = open(filename, 'r')
for line in f:
line = line.strip()
if (len(line) == 0) or (line[0] == '#'):
continue
try:
self._parse_modulus(line)
except:
continue
f.close()
def get_modulus(self, min, prefer, max):
bitsizes = self.pack.keys()
bitsizes.sort()
if len(bitsizes) == 0:
raise SSHException('no moduli available')
good = -1
# find nearest bitsize >= preferred
for b in bitsizes:
if (b >= prefer) and (b < max) and ((b < good) or (good == -1)):
good = b
# if that failed, find greatest bitsize >= min
if good == -1:
for b in bitsizes:
if (b >= min) and (b < max) and (b > good):
good = b
if good == -1:
# their entire (min, max) range has no intersection with our range.
# if their range is below ours, pick the smallest. otherwise pick
# the largest. it'll be out of their range requirement either way,
# but we'll be sending them the closest one we have.
good = bitsizes[0]
if min > good:
good = bitsizes[-1]
# now pick a random modulus of this bitsize
n = roll_random(self.randpool, len(self.pack[good]))
return self.pack[good][n]

View File

@ -12,6 +12,7 @@ MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_SUCCESS, MSG_CHANNEL_OPEN_FAILURE, \
MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101)
import sys, os, string, threading, socket, logging, struct
from ssh_exception import SSHException
from message import Message
from channel import Channel
from util import format_binary, safe_string, inflate_long, deflate_long, tb_strings
@ -19,6 +20,7 @@ from rsakey import RSAKey
from dsskey import DSSKey
from kex_group1 import KexGroup1
from kex_gex import KexGex
from primes import ModulusPack
# these come from PyCrypt
# http://www.amk.ca/python/writing/pycrypt/
@ -81,21 +83,21 @@ class BaseTransport(threading.Thread):
preferred_keys = [ 'ssh-rsa', 'ssh-dss' ]
preferred_kex = [ 'diffie-hellman-group1-sha1', 'diffie-hellman-group-exchange-sha1' ]
cipher_info = {
_cipher_info = {
'blowfish-cbc': { 'class': Blowfish, 'mode': Blowfish.MODE_CBC, 'block-size': 8, 'key-size': 16 },
'aes128-cbc': { 'class': AES, 'mode': AES.MODE_CBC, 'block-size': 16, 'key-size': 16 },
'aes256-cbc': { 'class': AES, 'mode': AES.MODE_CBC, 'block-size': 16, 'key-size': 32 },
'3des-cbc': { 'class': DES3, 'mode': DES3.MODE_CBC, 'block-size': 8, 'key-size': 24 },
}
mac_info = {
_mac_info = {
'hmac-sha1': { 'class': SHA, 'size': 20 },
'hmac-sha1-96': { 'class': SHA, 'size': 12 },
'hmac-md5': { 'class': MD5, 'size': 16 },
'hmac-md5-96': { 'class': MD5, 'size': 12 },
}
kex_info = {
_kex_info = {
'diffie-hellman-group1-sha1': KexGroup1,
'diffie-hellman-group-exchange-sha1': KexGex,
}
@ -107,7 +109,7 @@ class BaseTransport(threading.Thread):
OPEN_FAILED_RESOURCE_SHORTAGE = range(1, 5)
def __init__(self, sock):
threading.Thread.__init__(self)
threading.Thread.__init__(self, target=self._run)
self.randpool = randpool
self.sock = sock
self.sock.settimeout(0.1)
@ -123,11 +125,10 @@ class BaseTransport(threading.Thread):
self.session_id = None
# /negotiated crypto parameters
self.expected_packet = 0
self.active = 0
self.active = False
self.initial_kex_done = 0
self.write_lock = threading.Lock() # lock around outbound writes (packet computation)
self.lock = threading.Lock() # synchronization (always higher level than write_lock)
self.authenticated = 0
self.channels = { } # (id -> Channel)
self.channel_events = { } # (id -> Event)
self.channel_counter = 1
@ -135,6 +136,7 @@ class BaseTransport(threading.Thread):
self.window_size = 65536
self.max_packet_size = 2048
self.ultra_debug = 0
self.modulus_pack = None
# used for noticing when to re-key:
self.received_bytes = 0
self.received_packets = 0
@ -165,27 +167,69 @@ class BaseTransport(threading.Thread):
except KeyError:
return None
def load_server_moduli(self, filename=None):
"""
I{(optional)}
Load a file of prime moduli for use in doing group-exchange key
negotiation in server mode. It's a rather obscure option and can be
safely ignored.
In server mode, the remote client may request "group-exchange" key
negotiation, which asks the server to send a random prime number that
fits certain criteria. These primes are pretty difficult to compute,
so they can't be generated on demand. But many systems contain a file
of suitable primes (usually named something like C{/etc/ssh/moduli}).
If you call C{load_server_moduli} and it returns C{True}, then this
file of primes has been loaded and we will support "group-exchange" in
server mode. Otherwise server mode will just claim that it doesn't
support that method of key negotiation.
@param filename: optional path to the moduli file, if you happen to
know that it's not in a standard location.
@type filename: string
@return: True if a moduli file was successfully loaded; False
otherwise.
@rtype: boolean
@since: doduo
@note: This has no effect when used in client mode.
"""
self.modulus_pack = ModulusPack(self.randpool)
# places to look for the openssh "moduli" file
file_list = [ '/etc/ssh/moduli', '/usr/local/etc/moduli' ]
if filename is not None:
file_list.insert(0, filename)
for fn in file_list:
try:
self.modulus_pack.read_file(fn)
return True
except IOError:
pass
# none succeeded
self.modulus_pack = None
return False
def _get_modulus_pack(self):
"used by KexGex to find primes for group exchange"
return self.modulus_pack
def __repr__(self):
if not self.active:
return '<paramiko.Transport (unconnected)>'
out = '<sesch.Transport'
return '<paramiko.BaseTransport (unconnected)>'
out = '<paramiko.BaseTransport'
#if self.remote_version != '':
# out += ' (server version "%s")' % self.remote_version
if self.local_cipher != '':
out += ' (cipher %s)' % self.local_cipher
if self.authenticated:
if len(self.channels) == 1:
out += ' (active; 1 open channel)'
else:
out += ' (active; %d open channels)' % len(self.channels)
elif self.initial_kex_done:
out += ' (connected; awaiting auth)'
else:
out += ' (connecting)'
out += '>'
return out
def log(self, level, msg):
def _log(self, level, msg):
if type(msg) == type([]):
for m in msg:
self.logger.log(level, m)
@ -193,14 +237,29 @@ class BaseTransport(threading.Thread):
self.logger.log(level, msg)
def close(self):
self.active = 0
"""
Close this session, and any open channels that are tied to it.
"""
self.active = False
self.engine_in = self.engine_out = None
self.sequence_number_in = self.sequence_number_out = 0L
for chan in self.channels.values():
chan.unlink()
chan._unlink()
def get_remote_server_key(self):
'returns (type, key) where type is like "ssh-rsa" and key is an opaque string'
"""
Return the host key of the server (in client mode).
The type string is usually either C{"ssh-rsa"} or C{"ssh-dss"} and the
key is an opaque string, which may be saved or used for comparison with
previously-seen keys. (In other words, you don't need to worry about
the content of the key, only that it compares equal to the key you
expected to see.)
@raise SSHException: if no session is currently active.
@return: tuple of (key type, key)
@rtype: (string, string)
"""
if (not self.active) or (not self.initial_kex_done):
raise SSHException('No existing session')
key_msg = Message(self.host_key)
@ -208,10 +267,13 @@ class BaseTransport(threading.Thread):
return key_type, self.host_key
def is_active(self):
return self.active
"""
Return true if this session is active (open).
def is_authenticated(self):
return self.authenticated and self.active
@return: True if the session is still active (open); False if the session is closed.
@rtype: boolean
"""
return self.active
def open_session(self):
return self.open_channel('session')
@ -230,9 +292,9 @@ class BaseTransport(threading.Thread):
m.add_int(self.max_packet_size)
self.channels[chanid] = chan = Channel(chanid)
self.channel_events[chanid] = event = threading.Event()
chan.set_transport(self)
chan._set_transport(self)
chan.set_window(self.window_size, self.max_packet_size)
self.send_message(m)
self._send_message(m)
finally:
self.lock.release()
while 1:
@ -249,7 +311,8 @@ class BaseTransport(threading.Thread):
self.lock.release()
return chan
def unlink_channel(self, chanid):
def _unlink_channel(self, chanid):
"used by a Channel to remove itself from the active channel list"
try:
self.lock.acquire()
if self.channels.has_key(chanid):
@ -257,7 +320,7 @@ class BaseTransport(threading.Thread):
finally:
self.lock.release()
def read_all(self, n):
def _read_all(self, n):
out = ''
while n > 0:
try:
@ -271,7 +334,7 @@ class BaseTransport(threading.Thread):
raise EOFError()
return out
def write_all(self, out):
def _write_all(self, out):
while len(out) > 0:
n = self.sock.send(out)
if n <= 0:
@ -281,7 +344,7 @@ class BaseTransport(threading.Thread):
out = out[n:]
return
def build_packet(self, payload):
def _build_packet(self, payload):
# pad up at least 4 bytes, to nearest block-size (usually 8)
bsize = self.block_size_out
padding = 3 + bsize - ((len(payload) + 8) % bsize)
@ -291,11 +354,12 @@ class BaseTransport(threading.Thread):
packet += randpool.get_bytes(padding)
return packet
def send_message(self, data):
def _send_message(self, data):
# FIXME: should we check for rekeying here too?
# encrypt this sucka
packet = self.build_packet(str(data))
packet = self._build_packet(str(data))
if self.ultra_debug:
self.log(DEBUG, format_binary(packet, 'OUT: '))
self._log(DEBUG, format_binary(packet, 'OUT: '))
if self.engine_out != None:
out = self.engine_out.encrypt(packet)
else:
@ -308,29 +372,29 @@ class BaseTransport(threading.Thread):
out += HMAC.HMAC(self.mac_key_out, payload, self.local_mac_engine).digest()[:self.local_mac_len]
self.sequence_number_out += 1L
self.sequence_number_out %= 0x100000000L
self.write_all(out)
self._write_all(out)
finally:
self.write_lock.release()
def read_message(self):
def _read_message(self):
"only one thread will ever be in this function"
header = self.read_all(self.block_size_in)
header = self._read_all(self.block_size_in)
if self.engine_in != None:
header = self.engine_in.decrypt(header)
if self.ultra_debug:
self.log(DEBUG, format_binary(header, 'IN: '));
self._log(DEBUG, format_binary(header, 'IN: '));
packet_size = struct.unpack('>I', header[:4])[0]
# leftover contains decrypted bytes from the first block (after the length field)
leftover = header[4:]
if (packet_size - len(leftover)) % self.block_size_in != 0:
raise SSHException('Invalid packet blocking')
buffer = self.read_all(packet_size + self.remote_mac_len - len(leftover))
buffer = self._read_all(packet_size + self.remote_mac_len - len(leftover))
packet = buffer[:packet_size - len(leftover)]
post_packet = buffer[packet_size - len(leftover):]
if self.engine_in != None:
packet = self.engine_in.decrypt(packet)
if self.ultra_debug:
self.log(DEBUG, format_binary(packet, 'IN: '));
self._log(DEBUG, format_binary(packet, 'IN: '));
packet = leftover + packet
if self.remote_mac_len > 0:
mac = post_packet[:self.remote_mac_len]
@ -341,7 +405,7 @@ class BaseTransport(threading.Thread):
padding = ord(packet[0])
payload = packet[1:packet_size - padding + 1]
randpool.add_event(packet[packet_size - padding + 1])
#self.log(DEBUG, 'Got payload (%d bytes, %d padding)' % (packet_size, padding))
#self._log(DEBUG, 'Got payload (%d bytes, %d padding)' % (packet_size, padding))
msg = Message(payload[1:])
msg.seqno = self.sequence_number_in
self.sequence_number_in = (self.sequence_number_in + 1) & 0xffffffffL
@ -351,10 +415,10 @@ class BaseTransport(threading.Thread):
if (self.received_packets >= self.REKEY_PACKETS) or (self.received_bytes >= self.REKEY_BYTES):
# only ask once for rekeying
if self.local_kex_init is None:
self.log(DEBUG, 'Rekeying (hit %d packets, %d bytes)' % (self.received_packets,
self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes)' % (self.received_packets,
self.received_bytes))
self.received_packets_overflow = 0
self.send_kex_init()
self._send_kex_init()
else:
# we've asked to rekey already -- give them 20 packets to
# comply, then just drop the connection
@ -364,14 +428,18 @@ class BaseTransport(threading.Thread):
return ord(payload[0]), msg
def set_K_H(self, k, h):
def _set_K_H(self, k, h):
"used by a kex object to set the K (root key) and H (exchange hash)"
self.K = k
self.H = h
if self.session_id == None:
self.session_id = h
def verify_key(self, host_key, sig):
def _expect_packet(self, type):
"used by a kex object to register the next packet type it expects to see"
self.expected_packet = type
def _verify_key(self, host_key, sig):
if self.host_key_type == 'ssh-rsa':
key = RSAKey(Message(host_key))
elif self.host_key_type == 'ssh-dss':
@ -384,7 +452,7 @@ class BaseTransport(threading.Thread):
raise SSHException('Signature verification (%s) failed. Boo. Robey should debug this.' % self.host_key_type)
self.host_key = host_key
def compute_key(self, id, nbytes):
def _compute_key(self, id, nbytes):
"id is 'A' - 'F' for the various keys used by ssh"
m = Message()
m.add_mpint(self.K)
@ -402,30 +470,30 @@ class BaseTransport(threading.Thread):
sofar += hash
return out[:nbytes]
def get_cipher(self, name, key, iv):
if not self.cipher_info.has_key(name):
def _get_cipher(self, name, key, iv):
if not self._cipher_info.has_key(name):
raise SSHException('Unknown client cipher ' + name)
return self.cipher_info[name]['class'].new(key, self.cipher_info[name]['mode'], iv)
return self._cipher_info[name]['class'].new(key, self._cipher_info[name]['mode'], iv)
def run(self):
self.active = 1
def _run(self):
self.active = True
try:
# SSH-1.99-OpenSSH_2.9p2
self.write_all(self.local_version + '\r\n')
self.check_banner()
self.send_kex_init()
self._write_all(self.local_version + '\r\n')
self._check_banner()
self._send_kex_init()
self.expected_packet = MSG_KEXINIT
while self.active:
ptype, m = self.read_message()
ptype, m = self._read_message()
if ptype == MSG_IGNORE:
continue
elif ptype == MSG_DISCONNECT:
self.parse_disconnect(m)
self.active = 0
self._parse_disconnect(m)
self.active = False
break
elif ptype == MSG_DEBUG:
self.parse_debug(m)
self._parse_debug(m)
continue
if self.expected_packet != 0:
if ptype != self.expected_packet:
@ -435,28 +503,29 @@ class BaseTransport(threading.Thread):
self.kex_engine.parse_next(ptype, m)
continue
if self.handler_table.has_key(ptype):
self.handler_table[ptype](self, m)
elif self.channel_handler_table.has_key(ptype):
if self._handler_table.has_key(ptype):
self._handler_table[ptype](self, m)
elif self._channel_handler_table.has_key(ptype):
chanid = m.get_int()
if self.channels.has_key(chanid):
self.channel_handler_table[ptype](self.channels[chanid], m)
self._channel_handler_table[ptype](self.channels[chanid], m)
else:
self.log(WARNING, 'Oops, unhandled type %d' % ptype)
self._log(WARNING, 'Oops, unhandled type %d' % ptype)
msg = Message()
msg.add_byte(chr(MSG_UNIMPLEMENTED))
msg.add_int(m.seqno)
self.send_message(msg)
self._send_message(msg)
except SSHException, e:
self.log(DEBUG, 'Exception: ' + str(e))
self.log(DEBUG, tb_strings())
self._log(DEBUG, 'Exception: ' + str(e))
self._log(DEBUG, tb_strings())
except EOFError, e:
self.log(DEBUG, 'EOF')
self._log(DEBUG, 'EOF')
self._log(DEBUG, tb_strings())
except Exception, e:
self.log(DEBUG, 'Unknown exception: ' + str(e))
self.log(DEBUG, tb_strings())
self._log(DEBUG, 'Unknown exception: ' + str(e))
self._log(DEBUG, tb_strings())
if self.active:
self.active = 0
self.active = False
if self.completion_event != None:
self.completion_event.set()
if self.auth_event != None:
@ -468,36 +537,49 @@ class BaseTransport(threading.Thread):
### protocol stages
def renegotiate_keys(self):
"""
Force this session to switch to new keys. Normally this is done
automatically after the session hits a certain number of packets or
bytes sent or received, but this method gives you the option of forcing
new keys whenever you want. Negotiating new keys causes a pause in
traffic both ways as the two sides swap keys and do computations. This
method returns when the session has switched to new keys, or the
session has died mid-negotiation.
@return: True if the renegotiation was successful, and the link is
using new keys; False if the session dropped during renegotiation.
@rtype: boolean
"""
self.completion_event = threading.Event()
self.send_kex_init()
self._send_kex_init()
while 1:
self.completion_event.wait(0.1);
if not self.active:
return 0
return False
if self.completion_event.isSet():
break
return 1
return True
def negotiate_keys(self, m):
def _negotiate_keys(self, m):
# throws SSHException on anything unusual
if self.local_kex_init == None:
# remote side wants to renegotiate
self.send_kex_init()
self.parse_kex_init(m)
self._send_kex_init()
self._parse_kex_init(m)
self.kex_engine.start_kex()
def check_banner(self):
def _check_banner(self):
# this is slow, but we only have to do it once
for i in range(5):
buffer = ''
while not '\n' in buffer:
buffer += self.read_all(1)
buffer += self._read_all(1)
buffer = buffer[:-1]
if (len(buffer) > 0) and (buffer[-1] == '\r'):
buffer = buffer[:-1]
if buffer[:4] == 'SSH-':
break
self.log(DEBUG, 'Banner: ' + buffer)
self._log(DEBUG, 'Banner: ' + buffer)
if buffer[:4] != 'SSH-':
raise SSHException('Indecipherable protocol version "' + buffer + '"')
# save this server version string for later
@ -516,17 +598,21 @@ class BaseTransport(threading.Thread):
client = segs[2]
if version != '1.99' and version != '2.0':
raise SSHException('Incompatible version (%s instead of 2.0)' % (version,))
self.log(INFO, 'Connected (version %s, client %s)' % (version, client))
self._log(INFO, 'Connected (version %s, client %s)' % (version, client))
def send_kex_init(self):
# send a really wimpy kex-init packet that says we're a bare-bones ssh client
def _send_kex_init(self):
"""
announce to the other side that we'd like to negotiate keys, and what
kind of key negotiation we support.
"""
if self.server_mode:
# FIXME: can't do group-exchange (gex) yet -- too slow
if 'diffie-hellman-group-exchange-sha1' in self.preferred_kex:
if (self.modulus_pack is None) and ('diffie-hellman-group-exchange-sha1' in self.preferred_kex):
# can't do group-exchange if we don't have a pack of potential primes
self.preferred_kex.remove('diffie-hellman-group-exchange-sha1')
available_server_keys = filter(self.server_key_dict.keys().__contains__,
self.preferred_keys)
else:
available_server_keys = self.preferred_keys
m = Message()
m.add_byte(chr(MSG_KEXINIT))
@ -545,9 +631,9 @@ class BaseTransport(threading.Thread):
m.add_int(0)
# save a copy for later (needed to compute a hash)
self.local_kex_init = str(m)
self.send_message(m)
self._send_message(m)
def parse_kex_init(self, m):
def _parse_kex_init(self, m):
# reset counters of when to re-key, since we are now re-keying
self.received_bytes = 0
self.received_packets = 0
@ -580,7 +666,7 @@ class BaseTransport(threading.Thread):
agreed_kex = filter(kex_algo_list.__contains__, self.preferred_kex)
if len(agreed_kex) == 0:
raise SSHException('Incompatible ssh peer (no acceptable kex algorithm)')
self.kex_engine = self.kex_info[agreed_kex[0]](self)
self.kex_engine = self._kex_info[agreed_kex[0]](self)
if self.server_mode:
available_server_keys = filter(self.server_key_dict.keys().__contains__,
@ -608,7 +694,7 @@ class BaseTransport(threading.Thread):
raise SSHException('Incompatible ssh server (no acceptable ciphers)')
self.local_cipher = agreed_local_ciphers[0]
self.remote_cipher = agreed_remote_ciphers[0]
self.log(DEBUG, 'Ciphers agreed: local=%s, remote=%s' % (self.local_cipher, self.remote_cipher))
self._log(DEBUG, 'Ciphers agreed: local=%s, remote=%s' % (self.local_cipher, self.remote_cipher))
if self.server_mode:
agreed_remote_macs = filter(self.preferred_macs.__contains__, client_mac_algo_list)
@ -621,7 +707,7 @@ class BaseTransport(threading.Thread):
self.local_mac = agreed_local_macs[0]
self.remote_mac = agreed_remote_macs[0]
self.log(DEBUG, 'kex algos:' + str(kex_algo_list) + ' server key:' + str(server_key_algo_list) + \
self._log(DEBUG, 'kex algos:' + str(kex_algo_list) + ' server key:' + str(server_key_algo_list) + \
' client encrypt:' + str(client_encrypt_algo_list) + \
' server encrypt:' + str(server_encrypt_algo_list) + \
' client mac:' + str(client_mac_algo_list) + \
@ -631,7 +717,7 @@ class BaseTransport(threading.Thread):
' client lang:' + str(client_lang_list) + \
' server lang:' + str(server_lang_list) + \
' kex follows?' + str(kex_follows))
self.log(DEBUG, 'using kex %s; server key type %s; cipher: local %s, remote %s; mac: local %s, remote %s' %
self._log(DEBUG, 'using kex %s; server key type %s; cipher: local %s, remote %s; mac: local %s, remote %s' %
(agreed_kex[0], self.host_key_type, self.local_cipher, self.remote_cipher, self.local_mac,
self.remote_mac))
@ -642,50 +728,52 @@ class BaseTransport(threading.Thread):
# away those bytes because they aren't part of the hash.
self.remote_kex_init = chr(MSG_KEXINIT) + m.get_so_far()
def activate_inbound(self):
def _activate_inbound(self):
"switch on newly negotiated encryption parameters for inbound traffic"
self.block_size_in = self.cipher_info[self.remote_cipher]['block-size']
self.block_size_in = self._cipher_info[self.remote_cipher]['block-size']
if self.server_mode:
IV_in = self.compute_key('A', self.block_size_in)
key_in = self.compute_key('C', self.cipher_info[self.remote_cipher]['key-size'])
IV_in = self._compute_key('A', self.block_size_in)
key_in = self._compute_key('C', self._cipher_info[self.remote_cipher]['key-size'])
else:
IV_in = self.compute_key('B', self.block_size_in)
key_in = self.compute_key('D', self.cipher_info[self.remote_cipher]['key-size'])
self.engine_in = self.get_cipher(self.remote_cipher, key_in, IV_in)
self.remote_mac_len = self.mac_info[self.remote_mac]['size']
self.remote_mac_engine = self.mac_info[self.remote_mac]['class']
IV_in = self._compute_key('B', self.block_size_in)
key_in = self._compute_key('D', self._cipher_info[self.remote_cipher]['key-size'])
self.engine_in = self._get_cipher(self.remote_cipher, key_in, IV_in)
self.remote_mac_len = self._mac_info[self.remote_mac]['size']
self.remote_mac_engine = self._mac_info[self.remote_mac]['class']
# initial mac keys are done in the hash's natural size (not the potentially truncated
# transmission size)
if self.server_mode:
self.mac_key_in = self.compute_key('E', self.remote_mac_engine.digest_size)
self.mac_key_in = self._compute_key('E', self.remote_mac_engine.digest_size)
else:
self.mac_key_in = self.compute_key('F', self.remote_mac_engine.digest_size)
self.mac_key_in = self._compute_key('F', self.remote_mac_engine.digest_size)
def activate_outbound(self):
def _activate_outbound(self):
"switch on newly negotiated encryption parameters for outbound traffic"
m = Message()
m.add_byte(chr(MSG_NEWKEYS))
self.send_message(m)
self.block_size_out = self.cipher_info[self.local_cipher]['block-size']
self._send_message(m)
self.block_size_out = self._cipher_info[self.local_cipher]['block-size']
if self.server_mode:
IV_out = self.compute_key('B', self.block_size_out)
key_out = self.compute_key('D', self.cipher_info[self.local_cipher]['key-size'])
IV_out = self._compute_key('B', self.block_size_out)
key_out = self._compute_key('D', self._cipher_info[self.local_cipher]['key-size'])
else:
IV_out = self.compute_key('A', self.block_size_out)
key_out = self.compute_key('C', self.cipher_info[self.local_cipher]['key-size'])
self.engine_out = self.get_cipher(self.local_cipher, key_out, IV_out)
self.local_mac_len = self.mac_info[self.local_mac]['size']
self.local_mac_engine = self.mac_info[self.local_mac]['class']
IV_out = self._compute_key('A', self.block_size_out)
key_out = self._compute_key('C', self._cipher_info[self.local_cipher]['key-size'])
self.engine_out = self._get_cipher(self.local_cipher, key_out, IV_out)
self.local_mac_len = self._mac_info[self.local_mac]['size']
self.local_mac_engine = self._mac_info[self.local_mac]['class']
# initial mac keys are done in the hash's natural size (not the potentially truncated
# transmission size)
if self.server_mode:
self.mac_key_out = self.compute_key('F', self.local_mac_engine.digest_size)
self.mac_key_out = self._compute_key('F', self.local_mac_engine.digest_size)
else:
self.mac_key_out = self.compute_key('E', self.local_mac_engine.digest_size)
self.mac_key_out = self._compute_key('E', self.local_mac_engine.digest_size)
# we always expect to receive NEWKEYS now
self.expected_packet = MSG_NEWKEYS
def parse_newkeys(self, m):
self.log(DEBUG, 'Switch to new keys ...')
self.activate_inbound()
def _parse_newkeys(self, m):
self._log(DEBUG, 'Switch to new keys ...')
self._activate_inbound()
# can also free a bunch of stuff here
self.local_kex_init = self.remote_kex_init = None
self.e = self.f = self.K = self.x = None
@ -697,24 +785,24 @@ class BaseTransport(threading.Thread):
self.completion_event.set()
return
def parse_disconnect(self, m):
def _parse_disconnect(self, m):
code = m.get_int()
desc = m.get_string()
self.log(INFO, 'Disconnect (code %d): %s' % (code, desc))
self._log(INFO, 'Disconnect (code %d): %s' % (code, desc))
def parse_channel_open_success(self, m):
def _parse_channel_open_success(self, m):
chanid = m.get_int()
server_chanid = m.get_int()
server_window_size = m.get_int()
server_max_packet_size = m.get_int()
if not self.channels.has_key(chanid):
self.log(WARNING, 'Success for unrequested channel! [??]')
self._log(WARNING, 'Success for unrequested channel! [??]')
return
try:
self.lock.acquire()
chan = self.channels[chanid]
chan.set_remote_channel(server_chanid, server_window_size, server_max_packet_size)
self.log(INFO, 'Secsh channel %d opened.' % chanid)
self._log(INFO, 'Secsh channel %d opened.' % chanid)
if self.channel_events.has_key(chanid):
self.channel_events[chanid].set()
del self.channel_events[chanid]
@ -722,7 +810,7 @@ class BaseTransport(threading.Thread):
self.lock.release()
return
def parse_channel_open_failure(self, m):
def _parse_channel_open_failure(self, m):
chanid = m.get_int()
reason = m.get_int()
reason_str = m.get_string()
@ -731,7 +819,7 @@ class BaseTransport(threading.Thread):
reason_text = CONNECTION_FAILED_CODE[reason]
else:
reason_text = '(unknown code)'
self.log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text))
self._log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text))
try:
self.lock.aquire()
if self.channels.has_key(chanid):
@ -747,14 +835,14 @@ class BaseTransport(threading.Thread):
"override me! return object descended from Channel to allow, or None to reject"
return None
def parse_channel_open(self, m):
def _parse_channel_open(self, m):
kind = m.get_string()
chanid = m.get_int()
initial_window_size = m.get_int()
max_packet_size = m.get_int()
reject = False
if not self.server_mode:
self.log(DEBUG, 'Rejecting "%s" channel request from server.' % kind)
self._log(DEBUG, 'Rejecting "%s" channel request from server.' % kind)
reject = True
reason = self.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
else:
@ -766,7 +854,7 @@ class BaseTransport(threading.Thread):
self.lock.release()
chan = self.check_channel_request(kind, my_chanid)
if (chan is None) or (type(chan) is int):
self.log(DEBUG, 'Rejecting "%s" channel request from client.' % kind)
self._log(DEBUG, 'Rejecting "%s" channel request from client.' % kind)
reject = True
if type(chan) is int:
reason = chan
@ -779,12 +867,12 @@ class BaseTransport(threading.Thread):
msg.add_int(reason)
msg.add_string('')
msg.add_string('en')
self.send_message(msg)
self._send_message(msg)
return
try:
self.lock.acquire()
self.channels[my_chanid] = chan
chan.set_transport(self)
chan._set_transport(self)
chan.set_window(self.window_size, self.max_packet_size)
chan.set_remote_channel(chanid, initial_window_size, max_packet_size)
finally:
@ -795,8 +883,8 @@ class BaseTransport(threading.Thread):
m.add_int(my_chanid)
m.add_int(self.window_size)
m.add_int(self.max_packet_size)
self.send_message(m)
self.log(INFO, 'Secsh channel %d opened.' % my_chanid)
self._send_message(m)
self._log(INFO, 'Secsh channel %d opened.' % my_chanid)
try:
self.lock.acquire()
self.server_accepts.append(chan)
@ -820,21 +908,21 @@ class BaseTransport(threading.Thread):
self.lock.release()
return chan
def parse_debug(self, m):
def _parse_debug(self, m):
always_display = m.get_boolean()
msg = m.get_string()
lang = m.get_string()
self.log(DEBUG, 'Debug msg: ' + safe_string(msg))
self._log(DEBUG, 'Debug msg: ' + safe_string(msg))
handler_table = {
MSG_NEWKEYS: parse_newkeys,
MSG_CHANNEL_OPEN_SUCCESS: parse_channel_open_success,
MSG_CHANNEL_OPEN_FAILURE: parse_channel_open_failure,
MSG_CHANNEL_OPEN: parse_channel_open,
MSG_KEXINIT: negotiate_keys,
_handler_table = {
MSG_NEWKEYS: _parse_newkeys,
MSG_CHANNEL_OPEN_SUCCESS: _parse_channel_open_success,
MSG_CHANNEL_OPEN_FAILURE: _parse_channel_open_failure,
MSG_CHANNEL_OPEN: _parse_channel_open,
MSG_KEXINIT: _negotiate_keys,
}
channel_handler_table = {
_channel_handler_table = {
MSG_CHANNEL_SUCCESS: Channel.request_success,
MSG_CHANNEL_FAILURE: Channel.request_failed,
MSG_CHANNEL_DATA: Channel.feed,

View File

@ -1,7 +1,6 @@
#!/usr/bin/python
import sys, struct, traceback
from Crypto.Util import number
def inflate_long(s, always_positive=0):
"turns a normalized byte string into a long-int (adapted from Crypto.Util.number)"
@ -98,20 +97,5 @@ def bit_length(n):
bitlen -= 1
return bitlen
def generate_prime(bits, randpool):
hbyte_mask = pow(2, bits % 8) - 1
x = randpool.get_bytes((bits+7) // 8)
if hbyte_mask > 0:
x = chr(ord(x[0]) & hbyte_mask) + x[1:]
n = inflate_long(x, 1)
n |= 1
n |= (1 << (bits - 1))
while 1:
# loop catches the case where we increment n into a higher bit-range
while not number.isPrime(n):
n += 2
if bit_length(n) == bits:
return n
def tb_strings():
return ''.join(traceback.format_exception(*sys.exc_info())).split('\n')