[project @ Arch-1:robey@lag.net--2003-public%secsh--dev--1.0--patch-16]

hook up server-side kex-gex; add more documentation
group-exchange kex should work now on the server side.  it will only be
advertised if a "moduli" file has been loaded (see the -gasp- docs) so we
don't spend hours (literally. hours.) computing primes.  some of the logic
was previously wrong, too, since it had never been tested.

fixed repr() string for Transport/BaseTransport.  moved is_authenticated to
Transport where it belongs.

added lots of documentation (but still only about 10% documented).  lots of
methods were made private finally.
This commit is contained in:
Robey Pointer 2003-12-28 03:20:42 +00:00
parent eb4c279ec4
commit 36d6d95dc6
11 changed files with 506 additions and 274 deletions

4
NOTES
View File

@ -22,14 +22,14 @@ from BaseTransport:
get_server_key get_server_key
close close
get_remote_server_key get_remote_server_key
is_active * is_active
is_authenticated
open_session open_session
open_channel open_channel
renegotiate_keys renegotiate_keys
check_channel_request check_channel_request
from Transport: from Transport:
* is_authenticated
auth_key auth_key
auth_password auth_password
get_allowed_auths get_allowed_auths

View File

@ -69,7 +69,7 @@ try:
t.ultra_debug = 0 t.ultra_debug = 0
t.start_client(event) t.start_client(event)
# print repr(t) # print repr(t)
event.wait(10) event.wait(15)
if not t.is_active(): if not t.is_active():
print '*** SSH negotiation failed.' print '*** SSH negotiation failed.'
sys.exit(1) sys.exit(1)

View File

@ -70,6 +70,11 @@ print 'Got a connection!'
try: try:
event = threading.Event() event = threading.Event()
t = ServerTransport(client) t = ServerTransport(client)
try:
t.load_server_moduli()
except:
print '(Failed to load moduli -- gex will be unsupported.)'
raise
t.add_server_key(host_key) t.add_server_key(host_key)
t.ultra_debug = 0 t.ultra_debug = 0
t.start_server(event) t.start_server(event)
@ -81,10 +86,11 @@ try:
# print repr(t) # print repr(t)
# wait for auth # wait for auth
chan = t.accept(10) chan = t.accept(20)
if chan is None: if chan is None:
print '*** No channel.' print '*** No channel.'
sys.exit(1) sys.exit(1)
print 'Authenticated!'
chan.event.wait(10) chan.event.wait(10)
if not chan.event.isSet(): if not chan.event.isSet():
print '*** Client never asked for a shell.' print '*** Client never asked for a shell.'

View File

@ -17,4 +17,4 @@ from rsakey import RSAKey
from dsskey import DSSKey from dsskey import DSSKey
from util import hexify from util import hexify
__all__ = [ 'Transport', 'Channel', 'RSAKey', 'DSSKey', 'hexify' ] #__all__ = [ 'Transport', 'Channel', 'RSAKey', 'DSSKey', 'hexify' ]

View File

@ -19,17 +19,45 @@ class Transport(BaseTransport):
def __init__(self, sock): def __init__(self, sock):
BaseTransport.__init__(self, sock) BaseTransport.__init__(self, sock)
self.authenticated = False
self.auth_event = None self.auth_event = None
# for server mode: # for server mode:
self.auth_username = None self.auth_username = None
self.auth_fail_count = 0 self.auth_fail_count = 0
self.auth_complete = 0 self.auth_complete = 0
def request_auth(self): def __repr__(self):
if not self.active:
return '<paramiko.Transport (unconnected)>'
out = '<paramiko.Transport'
if self.local_cipher != '':
out += ' (cipher %s)' % self.local_cipher
if self.authenticated:
if len(self.channels) == 1:
out += ' (active; 1 open channel)'
else:
out += ' (active; %d open channels)' % len(self.channels)
elif self.initial_kex_done:
out += ' (connected; awaiting auth)'
else:
out += ' (connecting)'
out += '>'
return out
def is_authenticated(self):
"""
Return true if this session is active and authenticated.
@return: True if the session is still open and has been authenticated successfully;
False if authentication failed and/or the session is closed.
"""
return self.authenticated and self.active
def _request_auth(self):
m = Message() m = Message()
m.add_byte(chr(MSG_SERVICE_REQUEST)) m.add_byte(chr(MSG_SERVICE_REQUEST))
m.add_string('ssh-userauth') m.add_string('ssh-userauth')
self.send_message(m) self._send_message(m)
def auth_key(self, username, key, event): def auth_key(self, username, key, event):
if (not self.active) or (not self.initial_kex_done): if (not self.active) or (not self.initial_kex_done):
@ -41,7 +69,7 @@ class Transport(BaseTransport):
self.auth_method = 'publickey' self.auth_method = 'publickey'
self.username = username self.username = username
self.private_key = key self.private_key = key
self.request_auth() self._request_auth()
finally: finally:
self.lock.release() self.lock.release()
@ -56,7 +84,7 @@ class Transport(BaseTransport):
self.auth_method = 'password' self.auth_method = 'password'
self.username = username self.username = username
self.password = password self.password = password
self.request_auth() self._request_auth()
finally: finally:
self.lock.release() self.lock.release()
@ -66,7 +94,7 @@ class Transport(BaseTransport):
m.add_int(DISCONNECT_SERVICE_NOT_AVAILABLE) m.add_int(DISCONNECT_SERVICE_NOT_AVAILABLE)
m.add_string('Service not available') m.add_string('Service not available')
m.add_string('en') m.add_string('en')
self.send_message(m) self._send_message(m)
self.close() self.close()
def disconnect_no_more_auth(self): def disconnect_no_more_auth(self):
@ -75,7 +103,7 @@ class Transport(BaseTransport):
m.add_int(DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE) m.add_int(DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE)
m.add_string('No more auth methods available') m.add_string('No more auth methods available')
m.add_string('en') m.add_string('en')
self.send_message(m) self._send_message(m)
self.close() self.close()
def parse_service_request(self, m): def parse_service_request(self, m):
@ -85,7 +113,7 @@ class Transport(BaseTransport):
m = Message() m = Message()
m.add_byte(chr(MSG_SERVICE_ACCEPT)) m.add_byte(chr(MSG_SERVICE_ACCEPT))
m.add_string(service) m.add_string(service)
self.send_message(m) self._send_message(m)
return return
# dunno this one # dunno this one
self.disconnect_service_not_available() self.disconnect_service_not_available()
@ -93,7 +121,7 @@ class Transport(BaseTransport):
def parse_service_accept(self, m): def parse_service_accept(self, m):
service = m.get_string() service = m.get_string()
if service == 'ssh-userauth': if service == 'ssh-userauth':
self.log(DEBUG, 'userauth is OK') self._log(DEBUG, 'userauth is OK')
m = Message() m = Message()
m.add_byte(chr(MSG_USERAUTH_REQUEST)) m.add_byte(chr(MSG_USERAUTH_REQUEST))
m.add_string(self.username) m.add_string(self.username)
@ -109,9 +137,9 @@ class Transport(BaseTransport):
m.add_string(self.private_key.sign_ssh_session(self.randpool, self.H, self.username)) m.add_string(self.private_key.sign_ssh_session(self.randpool, self.H, self.username))
else: else:
raise SSHException('Unknown auth method "%s"' % self.auth_method) raise SSHException('Unknown auth method "%s"' % self.auth_method)
self.send_message(m) self._send_message(m)
else: else:
self.log(DEBUG, 'Service request "%s" accepted (?)' % service) self._log(DEBUG, 'Service request "%s" accepted (?)' % service)
def get_allowed_auths(self, username): def get_allowed_auths(self, username):
"override me!" "override me!"
@ -136,7 +164,7 @@ class Transport(BaseTransport):
m.add_byte(chr(MSG_USERAUTH_FAILURE)) m.add_byte(chr(MSG_USERAUTH_FAILURE))
m.add_string('none') m.add_string('none')
m.add_boolean(0) m.add_boolean(0)
self.send_message(m) self._send_message(m)
return return
if self.auth_complete: if self.auth_complete:
# ignore # ignore
@ -144,12 +172,12 @@ class Transport(BaseTransport):
username = m.get_string() username = m.get_string()
service = m.get_string() service = m.get_string()
method = m.get_string() method = m.get_string()
self.log(DEBUG, 'Auth request (type=%s) service=%s, username=%s' % (method, service, username)) self._log(DEBUG, 'Auth request (type=%s) service=%s, username=%s' % (method, service, username))
if service != 'ssh-connection': if service != 'ssh-connection':
self.disconnect_service_not_available() self.disconnect_service_not_available()
return return
if (self.auth_username is not None) and (self.auth_username != username): if (self.auth_username is not None) and (self.auth_username != username):
self.log(DEBUG, 'Auth rejected because the client attempted to change username in mid-flight') self._log(DEBUG, 'Auth rejected because the client attempted to change username in mid-flight')
self.disconnect_no_more_auth() self.disconnect_no_more_auth()
return return
if method == 'none': if method == 'none':
@ -160,7 +188,7 @@ class Transport(BaseTransport):
if changereq: if changereq:
# always treated as failure, since we don't support changing passwords, but collect # always treated as failure, since we don't support changing passwords, but collect
# the list of valid auth types from the callback anyway # the list of valid auth types from the callback anyway
self.log(DEBUG, 'Auth request to change passwords (rejected)') self._log(DEBUG, 'Auth request to change passwords (rejected)')
newpassword = m.get_string().decode('UTF-8') newpassword = m.get_string().decode('UTF-8')
result = self.AUTH_FAILED result = self.AUTH_FAILED
else: else:
@ -173,11 +201,11 @@ class Transport(BaseTransport):
# okay, send result # okay, send result
m = Message() m = Message()
if result == self.AUTH_SUCCESSFUL: if result == self.AUTH_SUCCESSFUL:
self.log(DEBUG, 'Auth granted.') self._log(DEBUG, 'Auth granted.')
m.add_byte(chr(MSG_USERAUTH_SUCCESS)) m.add_byte(chr(MSG_USERAUTH_SUCCESS))
self.auth_complete = 1 self.auth_complete = 1
else: else:
self.log(DEBUG, 'Auth rejected.') self._log(DEBUG, 'Auth rejected.')
m.add_byte(chr(MSG_USERAUTH_FAILURE)) m.add_byte(chr(MSG_USERAUTH_FAILURE))
m.add_string(self.get_allowed_auths(username)) m.add_string(self.get_allowed_auths(username))
if result == self.AUTH_PARTIALLY_SUCCESSFUL: if result == self.AUTH_PARTIALLY_SUCCESSFUL:
@ -185,13 +213,13 @@ class Transport(BaseTransport):
else: else:
m.add_boolean(0) m.add_boolean(0)
self.auth_fail_count += 1 self.auth_fail_count += 1
self.send_message(m) self._send_message(m)
if self.auth_fail_count >= 10: if self.auth_fail_count >= 10:
self.disconnect_no_more_auth() self.disconnect_no_more_auth()
def parse_userauth_success(self, m): def parse_userauth_success(self, m):
self.log(INFO, 'Authentication successful!') self._log(INFO, 'Authentication successful!')
self.authenticated = 1 self.authenticated = True
if self.auth_event != None: if self.auth_event != None:
self.auth_event.set() self.auth_event.set()
@ -199,12 +227,12 @@ class Transport(BaseTransport):
authlist = m.get_list() authlist = m.get_list()
partial = m.get_boolean() partial = m.get_boolean()
if partial: if partial:
self.log(INFO, 'Authentication continues...') self._log(INFO, 'Authentication continues...')
self.log(DEBUG, 'Methods: ' + str(partial)) self._log(DEBUG, 'Methods: ' + str(partial))
# FIXME - do something # FIXME - do something
pass pass
self.log(INFO, 'Authentication failed.') self._log(INFO, 'Authentication failed.')
self.authenticated = 0 self.authenticated = False
self.close() self.close()
if self.auth_event != None: if self.auth_event != None:
self.auth_event.set() self.auth_event.set()
@ -212,11 +240,11 @@ class Transport(BaseTransport):
def parse_userauth_banner(self, m): def parse_userauth_banner(self, m):
banner = m.get_string() banner = m.get_string()
lang = m.get_string() lang = m.get_string()
self.log(INFO, 'Auth banner: ' + banner) self._log(INFO, 'Auth banner: ' + banner)
# who cares. # who cares.
handler_table = BaseTransport.handler_table.copy() _handler_table = BaseTransport._handler_table.copy()
handler_table.update({ _handler_table.update({
MSG_SERVICE_REQUEST: parse_service_request, MSG_SERVICE_REQUEST: parse_service_request,
MSG_SERVICE_ACCEPT: parse_service_accept, MSG_SERVICE_ACCEPT: parse_service_accept,
MSG_USERAUTH_REQUEST: parse_userauth_request, MSG_USERAUTH_REQUEST: parse_userauth_request,

View File

@ -50,10 +50,10 @@ class Channel(object):
out += '>' out += '>'
return out return out
def set_transport(self, transport): def _set_transport(self, transport):
self.transport = transport self.transport = transport
def log(self, level, msg): def _log(self, level, msg):
self.logger.log(level, msg) self.logger.log(level, msg)
def set_window(self, window_size, max_packet_size): def set_window(self, window_size, max_packet_size):
@ -70,7 +70,7 @@ class Channel(object):
self.active = 1 self.active = 1
def request_success(self, m): def request_success(self, m):
self.log(DEBUG, 'Sesch channel %d request ok' % self.chanid) self._log(DEBUG, 'Sesch channel %d request ok' % self.chanid)
return return
def request_failed(self, m): def request_failed(self, m):
@ -80,13 +80,13 @@ class Channel(object):
s = m.get_string() s = m.get_string()
try: try:
self.lock.acquire() self.lock.acquire()
self.log(DEBUG, 'fed %d bytes' % len(s)) self._log(DEBUG, 'fed %d bytes' % len(s))
if self.pipe_wfd != None: if self.pipe_wfd != None:
self.feed_pipe(s) self.feed_pipe(s)
else: else:
self.in_buffer += s self.in_buffer += s
self.in_buffer_cv.notifyAll() self.in_buffer_cv.notifyAll()
self.log(DEBUG, '(out from feed)') self._log(DEBUG, '(out from feed)')
finally: finally:
self.lock.release() self.lock.release()
@ -94,7 +94,7 @@ class Channel(object):
nbytes = m.get_int() nbytes = m.get_int()
try: try:
self.lock.acquire() self.lock.acquire()
self.log(DEBUG, 'window up %d' % nbytes) self._log(DEBUG, 'window up %d' % nbytes)
self.out_window_size += nbytes self.out_window_size += nbytes
self.out_buffer_cv.notifyAll() self.out_buffer_cv.notifyAll()
finally: finally:
@ -146,7 +146,7 @@ class Channel(object):
pixelheight = m.get_int() pixelheight = m.get_int()
ok = self.check_window_change_request(width, height, pixelwidth, pixelheight) ok = self.check_window_change_request(width, height, pixelwidth, pixelheight)
else: else:
self.log(DEBUG, 'Unhandled channel request "%s"' % key) self._log(DEBUG, 'Unhandled channel request "%s"' % key)
ok = False ok = False
if want_reply: if want_reply:
m = Message() m = Message()
@ -155,7 +155,7 @@ class Channel(object):
else: else:
m.add_byte(chr(MSG_CHANNEL_FAILURE)) m.add_byte(chr(MSG_CHANNEL_FAILURE))
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
self.transport.send_message(m) self.transport._send_message(m)
def handle_eof(self, m): def handle_eof(self, m):
try: try:
@ -168,7 +168,7 @@ class Channel(object):
self.pipe_wfd = None self.pipe_wfd = None
finally: finally:
self.lock.release() self.lock.release()
self.log(DEBUG, 'EOF received') self._log(DEBUG, 'EOF received')
def handle_close(self, m): def handle_close(self, m):
self.close() self.close()
@ -199,7 +199,7 @@ class Channel(object):
# pixel height, width (usually useless) # pixel height, width (usually useless)
m.add_int(0).add_int(0) m.add_int(0).add_int(0)
m.add_string('') m.add_string('')
self.transport.send_message(m) self.transport._send_message(m)
def invoke_shell(self): def invoke_shell(self):
if self.closed or self.eof_received or self.eof_sent or not self.active: if self.closed or self.eof_received or self.eof_sent or not self.active:
@ -209,7 +209,7 @@ class Channel(object):
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_string('shell') m.add_string('shell')
m.add_boolean(1) m.add_boolean(1)
self.transport.send_message(m) self.transport._send_message(m)
def exec_command(self, command): def exec_command(self, command):
if self.closed or self.eof_received or self.eof_sent or not self.active: if self.closed or self.eof_received or self.eof_sent or not self.active:
@ -220,7 +220,7 @@ class Channel(object):
m.add_string('exec') m.add_string('exec')
m.add_boolean(1) m.add_boolean(1)
m.add_string(command) m.add_string(command)
self.transport.send_message(m) self.transport._send_message(m)
def invoke_subsystem(self, subsystem): def invoke_subsystem(self, subsystem):
if self.closed or self.eof_received or self.eof_sent or not self.active: if self.closed or self.eof_received or self.eof_sent or not self.active:
@ -231,7 +231,7 @@ class Channel(object):
m.add_string('subsystem') m.add_string('subsystem')
m.add_boolean(1) m.add_boolean(1)
m.add_string(subsystem) m.add_string(subsystem)
self.transport.send_message(m) self.transport._send_message(m)
def resize_pty(self, width=80, height=24): def resize_pty(self, width=80, height=24):
if self.closed or self.eof_received or self.eof_sent or not self.active: if self.closed or self.eof_received or self.eof_sent or not self.active:
@ -244,7 +244,7 @@ class Channel(object):
m.add_int(width) m.add_int(width)
m.add_int(height) m.add_int(height)
m.add_int(0).add_int(0) m.add_int(0).add_int(0)
self.transport.send_message(m) self.transport._send_message(m)
def get_transport(self): def get_transport(self):
return self.transport return self.transport
@ -262,9 +262,9 @@ class Channel(object):
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_EOF)) m.add_byte(chr(MSG_CHANNEL_EOF))
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
self.transport.send_message(m) self.transport._send_message(m)
self.eof_sent = 1 self.eof_sent = 1
self.log(DEBUG, 'EOF sent') self._log(DEBUG, 'EOF sent')
return return
@ -290,9 +290,9 @@ class Channel(object):
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_CLOSE)) m.add_byte(chr(MSG_CHANNEL_CLOSE))
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
self.transport.send_message(m) self.transport._send_message(m)
self.closed = 1 self.closed = 1
self.transport.unlink_channel(self.chanid) self.transport._unlink_channel(self.chanid)
finally: finally:
self.lock.release() self.lock.release()
@ -371,7 +371,7 @@ class Channel(object):
m.add_byte(chr(MSG_CHANNEL_DATA)) m.add_byte(chr(MSG_CHANNEL_DATA))
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_string(s[:size]) m.add_string(s[:size])
self.transport.send_message(m) self.transport._send_message(m)
self.out_window_size -= size self.out_window_size -= size
finally: finally:
self.lock.release() self.lock.release()
@ -506,25 +506,25 @@ class Channel(object):
self.in_buffer = self.in_buffer[nbytes:] self.in_buffer = self.in_buffer[nbytes:]
os.write(self.pipd_wfd, x) os.write(self.pipd_wfd, x)
def unlink(self): def _unlink(self):
if self.closed or not self.active: if self.closed or not self.active:
return return
self.closed = 1 self.closed = 1
self.transport.unlink_channel(self.chanid) self.transport._unlink_channel(self.chanid)
def check_add_window(self, n): def check_add_window(self, n):
# already holding the lock! # already holding the lock!
if self.closed or self.eof_received or not self.active: if self.closed or self.eof_received or not self.active:
return return
self.log(DEBUG, 'addwindow %d' % n) self._log(DEBUG, 'addwindow %d' % n)
self.in_window_sofar += n self.in_window_sofar += n
if self.in_window_sofar > self.in_window_threshold: if self.in_window_sofar > self.in_window_threshold:
self.log(DEBUG, 'addwindow send %d' % self.in_window_sofar) self._log(DEBUG, 'addwindow send %d' % self.in_window_sofar)
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST)) m.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST))
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_int(self.in_window_sofar) m.add_int(self.in_window_sofar)
self.transport.send_message(m) self.transport._send_message(m)
self.in_window_sofar = 0 self.in_window_sofar = 0

View File

@ -5,7 +5,7 @@
# LOT more on the server side). # LOT more on the server side).
from message import Message from message import Message
from util import inflate_long, deflate_long, generate_prime, bit_length from util import inflate_long, deflate_long, bit_length
from ssh_exception import SSHException from ssh_exception import SSHException
from transport import MSG_NEWKEYS from transport import MSG_NEWKEYS
from Crypto.Hash import SHA from Crypto.Hash import SHA
@ -27,7 +27,7 @@ class KexGex(object):
def start_kex(self): def start_kex(self):
if self.transport.server_mode: if self.transport.server_mode:
self.transport.expected_packet = MSG_KEXDH_GEX_REQUEST self.transport._expect_packet(MSG_KEXDH_GEX_REQUEST)
return return
# request a bit range: we accept (min_bits) to (max_bits), but prefer # request a bit range: we accept (min_bits) to (max_bits), but prefer
# (preferred_bits). according to the spec, we shouldn't pull the # (preferred_bits). according to the spec, we shouldn't pull the
@ -37,21 +37,21 @@ class KexGex(object):
m.add_int(self.min_bits) m.add_int(self.min_bits)
m.add_int(self.preferred_bits) m.add_int(self.preferred_bits)
m.add_int(self.max_bits) m.add_int(self.max_bits)
self.transport.send_message(m) self.transport._send_message(m)
self.transport.expected_packet = MSG_KEXDH_GEX_GROUP self.transport._expect_packet(MSG_KEXDH_GEX_GROUP)
def parse_next(self, ptype, m): def parse_next(self, ptype, m):
if ptype == MSG_KEXDH_GEX_REQUEST: if ptype == MSG_KEXDH_GEX_REQUEST:
return self.parse_kexdh_gex_request(m) return self._parse_kexdh_gex_request(m)
elif ptype == MSG_KEXDH_GEX_GROUP: elif ptype == MSG_KEXDH_GEX_GROUP:
return self.parse_kexdh_gex_group(m) return self._parse_kexdh_gex_group(m)
elif ptype == MSG_KEXDH_GEX_INIT: elif ptype == MSG_KEXDH_GEX_INIT:
return self.parse_kexdh_gex_init(m) return self._parse_kexdh_gex_init(m)
elif ptype == MSG_KEXDH_GEX_REPLY: elif ptype == MSG_KEXDH_GEX_REPLY:
return self.parse_kexdh_gex_reply(m) return self._parse_kexdh_gex_reply(m)
raise SSHException('KexGex asked to handle packet type %d' % ptype) raise SSHException('KexGex asked to handle packet type %d' % ptype)
def generate_x(self): def _generate_x(self):
# generate an "x" (1 < x < (p-1)/2). # generate an "x" (1 < x < (p-1)/2).
q = (self.p - 1) // 2 q = (self.p - 1) // 2
qnorm = deflate_long(q, 0) qnorm = deflate_long(q, 0)
@ -70,7 +70,7 @@ class KexGex(object):
break break
self.x = x self.x = x
def parse_kexdh_gex_request(self, m): def _parse_kexdh_gex_request(self, m):
min = m.get_int() min = m.get_int()
preferred = m.get_int() preferred = m.get_int()
max = m.get_int() max = m.get_int()
@ -79,52 +79,53 @@ class KexGex(object):
preferred = self.max_bits preferred = self.max_bits
if preferred < self.min_bits: if preferred < self.min_bits:
preferred = self.min_bits preferred = self.min_bits
# fix min/max if they're inconsistent. technically, we could just pout
# and hang up, but there's no harm in giving them the benefit of the
# doubt and just picking a bitsize for them.
if min > preferred:
min = preferred
if max < preferred:
max = preferred
# now save a copy # now save a copy
self.min_bits = min self.min_bits = min
self.preferred_bits = preferred self.preferred_bits = preferred
self.max_bits = max self.max_bits = max
# generate prime # generate prime
while 1: pack = self.transport._get_modulus_pack()
# does not work FIXME if pack is None:
# the problem is that it's too fscking SLOW raise SSHException('Can\'t do server-side gex with no modulus pack')
self.transport.log(DEBUG, 'stir...') self.g, self.p = pack.get_modulus(min, preferred, max)
self.transport.randpool.stir()
self.transport.log(DEBUG, 'get-prime %d...' % preferred)
self.p = generate_prime(preferred, self.transport.randpool)
self.transport.log(DEBUG, 'got ' + repr(self.p))
if number.isPrime((self.p - 1) // 2):
break
self.g = 2
m = Message() m = Message()
m.add_byte(chr(MSG_KEXDH_GEX_GROUP)) m.add_byte(chr(MSG_KEXDH_GEX_GROUP))
m.add_mpint(self.p) m.add_mpint(self.p)
m.add_mpint(self.g) m.add_mpint(self.g)
self.transport.send_message(m) self.transport._send_message(m)
self.transport.expected_packet = MSG_KEXDH_GEX_INIT self.transport._expect_packet(MSG_KEXDH_GEX_INIT)
def parse_kexdh_gex_group(self, m): def _parse_kexdh_gex_group(self, m):
self.p = m.get_mpint() self.p = m.get_mpint()
self.g = m.get_mpint() self.g = m.get_mpint()
# reject if p's bit length < 1024 or > 8192 # reject if p's bit length < 1024 or > 8192
bitlen = bit_length(self.p) bitlen = bit_length(self.p)
if (bitlen < 1024) or (bitlen > 8192): if (bitlen < 1024) or (bitlen > 8192):
raise SSHException('Server-generated gex p (don\'t ask) is out of range (%d bits)' % bitlen) raise SSHException('Server-generated gex p (don\'t ask) is out of range (%d bits)' % bitlen)
self.transport.log(DEBUG, 'Got server p (%d bits)' % bitlen) self.transport._log(DEBUG, 'Got server p (%d bits)' % bitlen)
self.generate_x() self._generate_x()
# now compute e = g^x mod p # now compute e = g^x mod p
self.e = pow(self.g, self.x, self.p) self.e = pow(self.g, self.x, self.p)
m = Message() m = Message()
m.add_byte(chr(MSG_KEXDH_GEX_INIT)) m.add_byte(chr(MSG_KEXDH_GEX_INIT))
m.add_mpint(self.e) m.add_mpint(self.e)
self.transport.send_message(m) self.transport._send_message(m)
self.transport.expected_packet = MSG_KEXDH_GEX_REPLY self.transport._expect_packet(MSG_KEXDH_GEX_REPLY)
def parse_kexdh_gex_init(self, m): def _parse_kexdh_gex_init(self, m):
self.e = m.get_mpint() self.e = m.get_mpint()
if (self.e < 1) or (self.e > self.p - 1): if (self.e < 1) or (self.e > self.p - 1):
raise SSHException('Client kex "e" is out of range') raise SSHException('Client kex "e" is out of range')
self.generate_x() self._generate_x()
K = pow(self.e, self.x, P) self.f = pow(self.g, self.x, self.p)
K = pow(self.e, self.x, self.p)
key = str(self.transport.get_server_key()) key = str(self.transport.get_server_key())
# okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) # okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K)
hm = Message().add(self.transport.remote_version).add(self.transport.local_version) hm = Message().add(self.transport.remote_version).add(self.transport.local_version)
@ -136,7 +137,7 @@ class KexGex(object):
hm.add_mpint(self.g) hm.add_mpint(self.g)
hm.add(self.e).add(self.f).add(K) hm.add(self.e).add(self.f).add(K)
H = SHA.new(str(hm)).digest() H = SHA.new(str(hm)).digest()
self.transport.set_K_H(K, H) self.transport._set_K_H(K, H)
# sign it # sign it
sig = self.transport.get_server_key().sign_ssh_data(self.transport.randpool, H) sig = self.transport.get_server_key().sign_ssh_data(self.transport.randpool, H)
# send reply # send reply
@ -145,11 +146,10 @@ class KexGex(object):
m.add_string(key) m.add_string(key)
m.add_mpint(self.f) m.add_mpint(self.f)
m.add_string(sig) m.add_string(sig)
self.transport.send_message(m) self.transport._send_message(m)
self.transport.activate_outbound() self.transport._activate_outbound()
self.transport.expected_packet = MSG_NEWKEYS
def parse_kexdh_gex_reply(self, m): def _parse_kexdh_gex_reply(self, m):
host_key = m.get_string() host_key = m.get_string()
self.f = m.get_mpint() self.f = m.get_mpint()
sig = m.get_string() sig = m.get_string()
@ -165,9 +165,9 @@ class KexGex(object):
hm.add_mpint(self.p) hm.add_mpint(self.p)
hm.add_mpint(self.g) hm.add_mpint(self.g)
hm.add(self.e).add(self.f).add(K) hm.add(self.e).add(self.f).add(K)
self.transport.set_K_H(K, SHA.new(str(hm)).digest()) self.transport._set_K_H(K, SHA.new(str(hm)).digest())
self.transport.verify_key(host_key, sig) self.transport._verify_key(host_key, sig)
self.transport.activate_outbound() self.transport._activate_outbound()
self.transport.expected_packet = MSG_NEWKEYS

View File

@ -44,15 +44,15 @@ class KexGroup1(object):
if self.transport.server_mode: if self.transport.server_mode:
# compute f = g^x mod p, but don't send it yet # compute f = g^x mod p, but don't send it yet
self.f = pow(G, self.x, P) self.f = pow(G, self.x, P)
self.transport.expected_packet = MSG_KEXDH_INIT self.transport._expect_packet(MSG_KEXDH_INIT)
return return
# compute e = g^x mod p (where g=2), and send it # compute e = g^x mod p (where g=2), and send it
self.e = pow(G, self.x, P) self.e = pow(G, self.x, P)
m = Message() m = Message()
m.add_byte(chr(MSG_KEXDH_INIT)) m.add_byte(chr(MSG_KEXDH_INIT))
m.add_mpint(self.e) m.add_mpint(self.e)
self.transport.send_message(m) self.transport._send_message(m)
self.transport.expected_packet = MSG_KEXDH_REPLY self.transport._expect_packet(MSG_KEXDH_REPLY)
def parse_next(self, ptype, m): def parse_next(self, ptype, m):
if self.transport.server_mode and (ptype == MSG_KEXDH_INIT): if self.transport.server_mode and (ptype == MSG_KEXDH_INIT):
@ -73,10 +73,9 @@ class KexGroup1(object):
hm = Message().add(self.transport.local_version).add(self.transport.remote_version) hm = Message().add(self.transport.local_version).add(self.transport.remote_version)
hm.add(self.transport.local_kex_init).add(self.transport.remote_kex_init).add(host_key) hm.add(self.transport.local_kex_init).add(self.transport.remote_kex_init).add(host_key)
hm.add(self.e).add(self.f).add(K) hm.add(self.e).add(self.f).add(K)
self.transport.set_K_H(K, SHA.new(str(hm)).digest()) self.transport._set_K_H(K, SHA.new(str(hm)).digest())
self.transport.verify_key(host_key, sig) self.transport._verify_key(host_key, sig)
self.transport.activate_outbound() self.transport._activate_outbound()
self.transport.expected_packet = MSG_NEWKEYS
def parse_kexdh_init(self, m): def parse_kexdh_init(self, m):
# server mode # server mode
@ -90,7 +89,7 @@ class KexGroup1(object):
hm.add(self.transport.remote_kex_init).add(self.transport.local_kex_init).add(key) hm.add(self.transport.remote_kex_init).add(self.transport.local_kex_init).add(key)
hm.add(self.e).add(self.f).add(K) hm.add(self.e).add(self.f).add(K)
H = SHA.new(str(hm)).digest() H = SHA.new(str(hm)).digest()
self.transport.set_K_H(K, H) self.transport._set_K_H(K, H)
# sign it # sign it
sig = self.transport.get_server_key().sign_ssh_data(self.transport.randpool, H) sig = self.transport.get_server_key().sign_ssh_data(self.transport.randpool, H)
# send reply # send reply
@ -99,6 +98,5 @@ class KexGroup1(object):
m.add_string(key) m.add_string(key)
m.add_mpint(self.f) m.add_mpint(self.f)
m.add_string(sig) m.add_string(sig)
self.transport.send_message(m) self.transport._send_message(m)
self.transport.activate_outbound() self.transport._activate_outbound()
self.transport.expected_packet = MSG_NEWKEYS

128
paramiko/primes.py Normal file
View File

@ -0,0 +1,128 @@
# utility functions for dealing with primes
from Crypto.Util import number
from util import bit_length, inflate_long
def generate_prime(bits, randpool):
hbyte_mask = pow(2, bits % 8) - 1
while 1:
# loop catches the case where we increment n into a higher bit-range
x = randpool.get_bytes((bits+7) // 8)
if hbyte_mask > 0:
x = chr(ord(x[0]) & hbyte_mask) + x[1:]
n = inflate_long(x, 1)
n |= 1
n |= (1 << (bits - 1))
while not number.isPrime(n):
n += 2
if bit_length(n) == bits:
return n
def roll_random(randpool, n):
"returns a random # from 0 to N-1"
bits = bit_length(n-1)
bytes = (bits + 7) // 8
hbyte_mask = pow(2, bits % 8) - 1
# so here's the plan:
# we fetch as many random bits as we'd need to fit N-1, and if the
# generated number is >= N, we try again. in the worst case (N-1 is a
# power of 2), we have slightly better than 50% odds of getting one that
# fits, so i can't guarantee that this loop will ever finish, but the odds
# of it looping forever should be infinitesimal.
while 1:
x = randpool.get_bytes(bytes)
if hbyte_mask > 0:
x = chr(ord(x[0]) & hbyte_mask) + x[1:]
num = inflate_long(x, 1)
if num < n:
return num
class ModulusPack (object):
"""
convenience object for holding the contents of the /etc/ssh/moduli file,
on systems that have such a file.
"""
def __init__(self, randpool):
# pack is a hash of: bits -> [ (generator, modulus) ... ]
self.pack = {}
self.discarded = []
self.randpool = randpool
def _parse_modulus(self, line):
timestamp, type, tests, tries, size, generator, modulus = line.split()
type = int(type)
tests = int(tests)
tries = int(tries)
size = int(size)
generator = int(generator)
modulus = long(modulus, 16)
# weed out primes that aren't at least:
# type 2 (meets basic structural requirements)
# test 4 (more than just a small-prime sieve)
# tries < 100 if test & 4 (at least 100 tries of miller-rabin)
if (type < 2) or (tests < 4) or ((tests & 4) and (tests < 8) and (tries < 100)):
self.discarded.append((modulus, 'does not meet basic requirements'))
return
if generator == 0:
generator = 2
# there's a bug in the ssh "moduli" file (yeah, i know: shock! dismay!
# call cnn!) where it understates the bit lengths of these primes by 1.
# this is okay.
bl = bit_length(modulus)
if (bl != size) and (bl != size + 1):
self.discarded.append((modulus, 'incorrectly reported bit length %d' % size))
return
if not self.pack.has_key(bl):
self.pack[bl] = []
self.pack[bl].append((generator, modulus))
def read_file(self, filename):
"""
@raise IOError: passed from any file operations that fail.
"""
self.pack = {}
f = open(filename, 'r')
for line in f:
line = line.strip()
if (len(line) == 0) or (line[0] == '#'):
continue
try:
self._parse_modulus(line)
except:
continue
f.close()
def get_modulus(self, min, prefer, max):
bitsizes = self.pack.keys()
bitsizes.sort()
if len(bitsizes) == 0:
raise SSHException('no moduli available')
good = -1
# find nearest bitsize >= preferred
for b in bitsizes:
if (b >= prefer) and (b < max) and ((b < good) or (good == -1)):
good = b
# if that failed, find greatest bitsize >= min
if good == -1:
for b in bitsizes:
if (b >= min) and (b < max) and (b > good):
good = b
if good == -1:
# their entire (min, max) range has no intersection with our range.
# if their range is below ours, pick the smallest. otherwise pick
# the largest. it'll be out of their range requirement either way,
# but we'll be sending them the closest one we have.
good = bitsizes[0]
if min > good:
good = bitsizes[-1]
# now pick a random modulus of this bitsize
n = roll_random(self.randpool, len(self.pack[good]))
return self.pack[good][n]

View File

@ -12,6 +12,7 @@ MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_SUCCESS, MSG_CHANNEL_OPEN_FAILURE, \
MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101) MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101)
import sys, os, string, threading, socket, logging, struct import sys, os, string, threading, socket, logging, struct
from ssh_exception import SSHException
from message import Message from message import Message
from channel import Channel from channel import Channel
from util import format_binary, safe_string, inflate_long, deflate_long, tb_strings from util import format_binary, safe_string, inflate_long, deflate_long, tb_strings
@ -19,6 +20,7 @@ from rsakey import RSAKey
from dsskey import DSSKey from dsskey import DSSKey
from kex_group1 import KexGroup1 from kex_group1 import KexGroup1
from kex_gex import KexGex from kex_gex import KexGex
from primes import ModulusPack
# these come from PyCrypt # these come from PyCrypt
# http://www.amk.ca/python/writing/pycrypt/ # http://www.amk.ca/python/writing/pycrypt/
@ -81,21 +83,21 @@ class BaseTransport(threading.Thread):
preferred_keys = [ 'ssh-rsa', 'ssh-dss' ] preferred_keys = [ 'ssh-rsa', 'ssh-dss' ]
preferred_kex = [ 'diffie-hellman-group1-sha1', 'diffie-hellman-group-exchange-sha1' ] preferred_kex = [ 'diffie-hellman-group1-sha1', 'diffie-hellman-group-exchange-sha1' ]
cipher_info = { _cipher_info = {
'blowfish-cbc': { 'class': Blowfish, 'mode': Blowfish.MODE_CBC, 'block-size': 8, 'key-size': 16 }, 'blowfish-cbc': { 'class': Blowfish, 'mode': Blowfish.MODE_CBC, 'block-size': 8, 'key-size': 16 },
'aes128-cbc': { 'class': AES, 'mode': AES.MODE_CBC, 'block-size': 16, 'key-size': 16 }, 'aes128-cbc': { 'class': AES, 'mode': AES.MODE_CBC, 'block-size': 16, 'key-size': 16 },
'aes256-cbc': { 'class': AES, 'mode': AES.MODE_CBC, 'block-size': 16, 'key-size': 32 }, 'aes256-cbc': { 'class': AES, 'mode': AES.MODE_CBC, 'block-size': 16, 'key-size': 32 },
'3des-cbc': { 'class': DES3, 'mode': DES3.MODE_CBC, 'block-size': 8, 'key-size': 24 }, '3des-cbc': { 'class': DES3, 'mode': DES3.MODE_CBC, 'block-size': 8, 'key-size': 24 },
} }
mac_info = { _mac_info = {
'hmac-sha1': { 'class': SHA, 'size': 20 }, 'hmac-sha1': { 'class': SHA, 'size': 20 },
'hmac-sha1-96': { 'class': SHA, 'size': 12 }, 'hmac-sha1-96': { 'class': SHA, 'size': 12 },
'hmac-md5': { 'class': MD5, 'size': 16 }, 'hmac-md5': { 'class': MD5, 'size': 16 },
'hmac-md5-96': { 'class': MD5, 'size': 12 }, 'hmac-md5-96': { 'class': MD5, 'size': 12 },
} }
kex_info = { _kex_info = {
'diffie-hellman-group1-sha1': KexGroup1, 'diffie-hellman-group1-sha1': KexGroup1,
'diffie-hellman-group-exchange-sha1': KexGex, 'diffie-hellman-group-exchange-sha1': KexGex,
} }
@ -107,7 +109,7 @@ class BaseTransport(threading.Thread):
OPEN_FAILED_RESOURCE_SHORTAGE = range(1, 5) OPEN_FAILED_RESOURCE_SHORTAGE = range(1, 5)
def __init__(self, sock): def __init__(self, sock):
threading.Thread.__init__(self) threading.Thread.__init__(self, target=self._run)
self.randpool = randpool self.randpool = randpool
self.sock = sock self.sock = sock
self.sock.settimeout(0.1) self.sock.settimeout(0.1)
@ -123,11 +125,10 @@ class BaseTransport(threading.Thread):
self.session_id = None self.session_id = None
# /negotiated crypto parameters # /negotiated crypto parameters
self.expected_packet = 0 self.expected_packet = 0
self.active = 0 self.active = False
self.initial_kex_done = 0 self.initial_kex_done = 0
self.write_lock = threading.Lock() # lock around outbound writes (packet computation) self.write_lock = threading.Lock() # lock around outbound writes (packet computation)
self.lock = threading.Lock() # synchronization (always higher level than write_lock) self.lock = threading.Lock() # synchronization (always higher level than write_lock)
self.authenticated = 0
self.channels = { } # (id -> Channel) self.channels = { } # (id -> Channel)
self.channel_events = { } # (id -> Event) self.channel_events = { } # (id -> Event)
self.channel_counter = 1 self.channel_counter = 1
@ -135,6 +136,7 @@ class BaseTransport(threading.Thread):
self.window_size = 65536 self.window_size = 65536
self.max_packet_size = 2048 self.max_packet_size = 2048
self.ultra_debug = 0 self.ultra_debug = 0
self.modulus_pack = None
# used for noticing when to re-key: # used for noticing when to re-key:
self.received_bytes = 0 self.received_bytes = 0
self.received_packets = 0 self.received_packets = 0
@ -165,27 +167,69 @@ class BaseTransport(threading.Thread):
except KeyError: except KeyError:
return None return None
def load_server_moduli(self, filename=None):
"""
I{(optional)}
Load a file of prime moduli for use in doing group-exchange key
negotiation in server mode. It's a rather obscure option and can be
safely ignored.
In server mode, the remote client may request "group-exchange" key
negotiation, which asks the server to send a random prime number that
fits certain criteria. These primes are pretty difficult to compute,
so they can't be generated on demand. But many systems contain a file
of suitable primes (usually named something like C{/etc/ssh/moduli}).
If you call C{load_server_moduli} and it returns C{True}, then this
file of primes has been loaded and we will support "group-exchange" in
server mode. Otherwise server mode will just claim that it doesn't
support that method of key negotiation.
@param filename: optional path to the moduli file, if you happen to
know that it's not in a standard location.
@type filename: string
@return: True if a moduli file was successfully loaded; False
otherwise.
@rtype: boolean
@since: doduo
@note: This has no effect when used in client mode.
"""
self.modulus_pack = ModulusPack(self.randpool)
# places to look for the openssh "moduli" file
file_list = [ '/etc/ssh/moduli', '/usr/local/etc/moduli' ]
if filename is not None:
file_list.insert(0, filename)
for fn in file_list:
try:
self.modulus_pack.read_file(fn)
return True
except IOError:
pass
# none succeeded
self.modulus_pack = None
return False
def _get_modulus_pack(self):
"used by KexGex to find primes for group exchange"
return self.modulus_pack
def __repr__(self): def __repr__(self):
if not self.active: if not self.active:
return '<paramiko.Transport (unconnected)>' return '<paramiko.BaseTransport (unconnected)>'
out = '<sesch.Transport' out = '<paramiko.BaseTransport'
#if self.remote_version != '': #if self.remote_version != '':
# out += ' (server version "%s")' % self.remote_version # out += ' (server version "%s")' % self.remote_version
if self.local_cipher != '': if self.local_cipher != '':
out += ' (cipher %s)' % self.local_cipher out += ' (cipher %s)' % self.local_cipher
if self.authenticated: if len(self.channels) == 1:
if len(self.channels) == 1: out += ' (active; 1 open channel)'
out += ' (active; 1 open channel)'
else:
out += ' (active; %d open channels)' % len(self.channels)
elif self.initial_kex_done:
out += ' (connected; awaiting auth)'
else: else:
out += ' (connecting)' out += ' (active; %d open channels)' % len(self.channels)
out += '>' out += '>'
return out return out
def log(self, level, msg): def _log(self, level, msg):
if type(msg) == type([]): if type(msg) == type([]):
for m in msg: for m in msg:
self.logger.log(level, m) self.logger.log(level, m)
@ -193,14 +237,29 @@ class BaseTransport(threading.Thread):
self.logger.log(level, msg) self.logger.log(level, msg)
def close(self): def close(self):
self.active = 0 """
Close this session, and any open channels that are tied to it.
"""
self.active = False
self.engine_in = self.engine_out = None self.engine_in = self.engine_out = None
self.sequence_number_in = self.sequence_number_out = 0L self.sequence_number_in = self.sequence_number_out = 0L
for chan in self.channels.values(): for chan in self.channels.values():
chan.unlink() chan._unlink()
def get_remote_server_key(self): def get_remote_server_key(self):
'returns (type, key) where type is like "ssh-rsa" and key is an opaque string' """
Return the host key of the server (in client mode).
The type string is usually either C{"ssh-rsa"} or C{"ssh-dss"} and the
key is an opaque string, which may be saved or used for comparison with
previously-seen keys. (In other words, you don't need to worry about
the content of the key, only that it compares equal to the key you
expected to see.)
@raise SSHException: if no session is currently active.
@return: tuple of (key type, key)
@rtype: (string, string)
"""
if (not self.active) or (not self.initial_kex_done): if (not self.active) or (not self.initial_kex_done):
raise SSHException('No existing session') raise SSHException('No existing session')
key_msg = Message(self.host_key) key_msg = Message(self.host_key)
@ -208,10 +267,13 @@ class BaseTransport(threading.Thread):
return key_type, self.host_key return key_type, self.host_key
def is_active(self): def is_active(self):
return self.active """
Return true if this session is active (open).
def is_authenticated(self): @return: True if the session is still active (open); False if the session is closed.
return self.authenticated and self.active @rtype: boolean
"""
return self.active
def open_session(self): def open_session(self):
return self.open_channel('session') return self.open_channel('session')
@ -230,9 +292,9 @@ class BaseTransport(threading.Thread):
m.add_int(self.max_packet_size) m.add_int(self.max_packet_size)
self.channels[chanid] = chan = Channel(chanid) self.channels[chanid] = chan = Channel(chanid)
self.channel_events[chanid] = event = threading.Event() self.channel_events[chanid] = event = threading.Event()
chan.set_transport(self) chan._set_transport(self)
chan.set_window(self.window_size, self.max_packet_size) chan.set_window(self.window_size, self.max_packet_size)
self.send_message(m) self._send_message(m)
finally: finally:
self.lock.release() self.lock.release()
while 1: while 1:
@ -249,7 +311,8 @@ class BaseTransport(threading.Thread):
self.lock.release() self.lock.release()
return chan return chan
def unlink_channel(self, chanid): def _unlink_channel(self, chanid):
"used by a Channel to remove itself from the active channel list"
try: try:
self.lock.acquire() self.lock.acquire()
if self.channels.has_key(chanid): if self.channels.has_key(chanid):
@ -257,7 +320,7 @@ class BaseTransport(threading.Thread):
finally: finally:
self.lock.release() self.lock.release()
def read_all(self, n): def _read_all(self, n):
out = '' out = ''
while n > 0: while n > 0:
try: try:
@ -271,7 +334,7 @@ class BaseTransport(threading.Thread):
raise EOFError() raise EOFError()
return out return out
def write_all(self, out): def _write_all(self, out):
while len(out) > 0: while len(out) > 0:
n = self.sock.send(out) n = self.sock.send(out)
if n <= 0: if n <= 0:
@ -281,7 +344,7 @@ class BaseTransport(threading.Thread):
out = out[n:] out = out[n:]
return return
def build_packet(self, payload): def _build_packet(self, payload):
# pad up at least 4 bytes, to nearest block-size (usually 8) # pad up at least 4 bytes, to nearest block-size (usually 8)
bsize = self.block_size_out bsize = self.block_size_out
padding = 3 + bsize - ((len(payload) + 8) % bsize) padding = 3 + bsize - ((len(payload) + 8) % bsize)
@ -291,11 +354,12 @@ class BaseTransport(threading.Thread):
packet += randpool.get_bytes(padding) packet += randpool.get_bytes(padding)
return packet return packet
def send_message(self, data): def _send_message(self, data):
# FIXME: should we check for rekeying here too?
# encrypt this sucka # encrypt this sucka
packet = self.build_packet(str(data)) packet = self._build_packet(str(data))
if self.ultra_debug: if self.ultra_debug:
self.log(DEBUG, format_binary(packet, 'OUT: ')) self._log(DEBUG, format_binary(packet, 'OUT: '))
if self.engine_out != None: if self.engine_out != None:
out = self.engine_out.encrypt(packet) out = self.engine_out.encrypt(packet)
else: else:
@ -308,29 +372,29 @@ class BaseTransport(threading.Thread):
out += HMAC.HMAC(self.mac_key_out, payload, self.local_mac_engine).digest()[:self.local_mac_len] out += HMAC.HMAC(self.mac_key_out, payload, self.local_mac_engine).digest()[:self.local_mac_len]
self.sequence_number_out += 1L self.sequence_number_out += 1L
self.sequence_number_out %= 0x100000000L self.sequence_number_out %= 0x100000000L
self.write_all(out) self._write_all(out)
finally: finally:
self.write_lock.release() self.write_lock.release()
def read_message(self): def _read_message(self):
"only one thread will ever be in this function" "only one thread will ever be in this function"
header = self.read_all(self.block_size_in) header = self._read_all(self.block_size_in)
if self.engine_in != None: if self.engine_in != None:
header = self.engine_in.decrypt(header) header = self.engine_in.decrypt(header)
if self.ultra_debug: if self.ultra_debug:
self.log(DEBUG, format_binary(header, 'IN: ')); self._log(DEBUG, format_binary(header, 'IN: '));
packet_size = struct.unpack('>I', header[:4])[0] packet_size = struct.unpack('>I', header[:4])[0]
# leftover contains decrypted bytes from the first block (after the length field) # leftover contains decrypted bytes from the first block (after the length field)
leftover = header[4:] leftover = header[4:]
if (packet_size - len(leftover)) % self.block_size_in != 0: if (packet_size - len(leftover)) % self.block_size_in != 0:
raise SSHException('Invalid packet blocking') raise SSHException('Invalid packet blocking')
buffer = self.read_all(packet_size + self.remote_mac_len - len(leftover)) buffer = self._read_all(packet_size + self.remote_mac_len - len(leftover))
packet = buffer[:packet_size - len(leftover)] packet = buffer[:packet_size - len(leftover)]
post_packet = buffer[packet_size - len(leftover):] post_packet = buffer[packet_size - len(leftover):]
if self.engine_in != None: if self.engine_in != None:
packet = self.engine_in.decrypt(packet) packet = self.engine_in.decrypt(packet)
if self.ultra_debug: if self.ultra_debug:
self.log(DEBUG, format_binary(packet, 'IN: ')); self._log(DEBUG, format_binary(packet, 'IN: '));
packet = leftover + packet packet = leftover + packet
if self.remote_mac_len > 0: if self.remote_mac_len > 0:
mac = post_packet[:self.remote_mac_len] mac = post_packet[:self.remote_mac_len]
@ -341,7 +405,7 @@ class BaseTransport(threading.Thread):
padding = ord(packet[0]) padding = ord(packet[0])
payload = packet[1:packet_size - padding + 1] payload = packet[1:packet_size - padding + 1]
randpool.add_event(packet[packet_size - padding + 1]) randpool.add_event(packet[packet_size - padding + 1])
#self.log(DEBUG, 'Got payload (%d bytes, %d padding)' % (packet_size, padding)) #self._log(DEBUG, 'Got payload (%d bytes, %d padding)' % (packet_size, padding))
msg = Message(payload[1:]) msg = Message(payload[1:])
msg.seqno = self.sequence_number_in msg.seqno = self.sequence_number_in
self.sequence_number_in = (self.sequence_number_in + 1) & 0xffffffffL self.sequence_number_in = (self.sequence_number_in + 1) & 0xffffffffL
@ -351,10 +415,10 @@ class BaseTransport(threading.Thread):
if (self.received_packets >= self.REKEY_PACKETS) or (self.received_bytes >= self.REKEY_BYTES): if (self.received_packets >= self.REKEY_PACKETS) or (self.received_bytes >= self.REKEY_BYTES):
# only ask once for rekeying # only ask once for rekeying
if self.local_kex_init is None: if self.local_kex_init is None:
self.log(DEBUG, 'Rekeying (hit %d packets, %d bytes)' % (self.received_packets, self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes)' % (self.received_packets,
self.received_bytes)) self.received_bytes))
self.received_packets_overflow = 0 self.received_packets_overflow = 0
self.send_kex_init() self._send_kex_init()
else: else:
# we've asked to rekey already -- give them 20 packets to # we've asked to rekey already -- give them 20 packets to
# comply, then just drop the connection # comply, then just drop the connection
@ -364,14 +428,18 @@ class BaseTransport(threading.Thread):
return ord(payload[0]), msg return ord(payload[0]), msg
def set_K_H(self, k, h): def _set_K_H(self, k, h):
"used by a kex object to set the K (root key) and H (exchange hash)" "used by a kex object to set the K (root key) and H (exchange hash)"
self.K = k self.K = k
self.H = h self.H = h
if self.session_id == None: if self.session_id == None:
self.session_id = h self.session_id = h
def verify_key(self, host_key, sig): def _expect_packet(self, type):
"used by a kex object to register the next packet type it expects to see"
self.expected_packet = type
def _verify_key(self, host_key, sig):
if self.host_key_type == 'ssh-rsa': if self.host_key_type == 'ssh-rsa':
key = RSAKey(Message(host_key)) key = RSAKey(Message(host_key))
elif self.host_key_type == 'ssh-dss': elif self.host_key_type == 'ssh-dss':
@ -384,7 +452,7 @@ class BaseTransport(threading.Thread):
raise SSHException('Signature verification (%s) failed. Boo. Robey should debug this.' % self.host_key_type) raise SSHException('Signature verification (%s) failed. Boo. Robey should debug this.' % self.host_key_type)
self.host_key = host_key self.host_key = host_key
def compute_key(self, id, nbytes): def _compute_key(self, id, nbytes):
"id is 'A' - 'F' for the various keys used by ssh" "id is 'A' - 'F' for the various keys used by ssh"
m = Message() m = Message()
m.add_mpint(self.K) m.add_mpint(self.K)
@ -402,30 +470,30 @@ class BaseTransport(threading.Thread):
sofar += hash sofar += hash
return out[:nbytes] return out[:nbytes]
def get_cipher(self, name, key, iv): def _get_cipher(self, name, key, iv):
if not self.cipher_info.has_key(name): if not self._cipher_info.has_key(name):
raise SSHException('Unknown client cipher ' + name) raise SSHException('Unknown client cipher ' + name)
return self.cipher_info[name]['class'].new(key, self.cipher_info[name]['mode'], iv) return self._cipher_info[name]['class'].new(key, self._cipher_info[name]['mode'], iv)
def run(self): def _run(self):
self.active = 1 self.active = True
try: try:
# SSH-1.99-OpenSSH_2.9p2 # SSH-1.99-OpenSSH_2.9p2
self.write_all(self.local_version + '\r\n') self._write_all(self.local_version + '\r\n')
self.check_banner() self._check_banner()
self.send_kex_init() self._send_kex_init()
self.expected_packet = MSG_KEXINIT self.expected_packet = MSG_KEXINIT
while self.active: while self.active:
ptype, m = self.read_message() ptype, m = self._read_message()
if ptype == MSG_IGNORE: if ptype == MSG_IGNORE:
continue continue
elif ptype == MSG_DISCONNECT: elif ptype == MSG_DISCONNECT:
self.parse_disconnect(m) self._parse_disconnect(m)
self.active = 0 self.active = False
break break
elif ptype == MSG_DEBUG: elif ptype == MSG_DEBUG:
self.parse_debug(m) self._parse_debug(m)
continue continue
if self.expected_packet != 0: if self.expected_packet != 0:
if ptype != self.expected_packet: if ptype != self.expected_packet:
@ -435,28 +503,29 @@ class BaseTransport(threading.Thread):
self.kex_engine.parse_next(ptype, m) self.kex_engine.parse_next(ptype, m)
continue continue
if self.handler_table.has_key(ptype): if self._handler_table.has_key(ptype):
self.handler_table[ptype](self, m) self._handler_table[ptype](self, m)
elif self.channel_handler_table.has_key(ptype): elif self._channel_handler_table.has_key(ptype):
chanid = m.get_int() chanid = m.get_int()
if self.channels.has_key(chanid): if self.channels.has_key(chanid):
self.channel_handler_table[ptype](self.channels[chanid], m) self._channel_handler_table[ptype](self.channels[chanid], m)
else: else:
self.log(WARNING, 'Oops, unhandled type %d' % ptype) self._log(WARNING, 'Oops, unhandled type %d' % ptype)
msg = Message() msg = Message()
msg.add_byte(chr(MSG_UNIMPLEMENTED)) msg.add_byte(chr(MSG_UNIMPLEMENTED))
msg.add_int(m.seqno) msg.add_int(m.seqno)
self.send_message(msg) self._send_message(msg)
except SSHException, e: except SSHException, e:
self.log(DEBUG, 'Exception: ' + str(e)) self._log(DEBUG, 'Exception: ' + str(e))
self.log(DEBUG, tb_strings()) self._log(DEBUG, tb_strings())
except EOFError, e: except EOFError, e:
self.log(DEBUG, 'EOF') self._log(DEBUG, 'EOF')
self._log(DEBUG, tb_strings())
except Exception, e: except Exception, e:
self.log(DEBUG, 'Unknown exception: ' + str(e)) self._log(DEBUG, 'Unknown exception: ' + str(e))
self.log(DEBUG, tb_strings()) self._log(DEBUG, tb_strings())
if self.active: if self.active:
self.active = 0 self.active = False
if self.completion_event != None: if self.completion_event != None:
self.completion_event.set() self.completion_event.set()
if self.auth_event != None: if self.auth_event != None:
@ -468,36 +537,49 @@ class BaseTransport(threading.Thread):
### protocol stages ### protocol stages
def renegotiate_keys(self): def renegotiate_keys(self):
"""
Force this session to switch to new keys. Normally this is done
automatically after the session hits a certain number of packets or
bytes sent or received, but this method gives you the option of forcing
new keys whenever you want. Negotiating new keys causes a pause in
traffic both ways as the two sides swap keys and do computations. This
method returns when the session has switched to new keys, or the
session has died mid-negotiation.
@return: True if the renegotiation was successful, and the link is
using new keys; False if the session dropped during renegotiation.
@rtype: boolean
"""
self.completion_event = threading.Event() self.completion_event = threading.Event()
self.send_kex_init() self._send_kex_init()
while 1: while 1:
self.completion_event.wait(0.1); self.completion_event.wait(0.1);
if not self.active: if not self.active:
return 0 return False
if self.completion_event.isSet(): if self.completion_event.isSet():
break break
return 1 return True
def negotiate_keys(self, m): def _negotiate_keys(self, m):
# throws SSHException on anything unusual # throws SSHException on anything unusual
if self.local_kex_init == None: if self.local_kex_init == None:
# remote side wants to renegotiate # remote side wants to renegotiate
self.send_kex_init() self._send_kex_init()
self.parse_kex_init(m) self._parse_kex_init(m)
self.kex_engine.start_kex() self.kex_engine.start_kex()
def check_banner(self): def _check_banner(self):
# this is slow, but we only have to do it once # this is slow, but we only have to do it once
for i in range(5): for i in range(5):
buffer = '' buffer = ''
while not '\n' in buffer: while not '\n' in buffer:
buffer += self.read_all(1) buffer += self._read_all(1)
buffer = buffer[:-1] buffer = buffer[:-1]
if (len(buffer) > 0) and (buffer[-1] == '\r'): if (len(buffer) > 0) and (buffer[-1] == '\r'):
buffer = buffer[:-1] buffer = buffer[:-1]
if buffer[:4] == 'SSH-': if buffer[:4] == 'SSH-':
break break
self.log(DEBUG, 'Banner: ' + buffer) self._log(DEBUG, 'Banner: ' + buffer)
if buffer[:4] != 'SSH-': if buffer[:4] != 'SSH-':
raise SSHException('Indecipherable protocol version "' + buffer + '"') raise SSHException('Indecipherable protocol version "' + buffer + '"')
# save this server version string for later # save this server version string for later
@ -516,17 +598,21 @@ class BaseTransport(threading.Thread):
client = segs[2] client = segs[2]
if version != '1.99' and version != '2.0': if version != '1.99' and version != '2.0':
raise SSHException('Incompatible version (%s instead of 2.0)' % (version,)) raise SSHException('Incompatible version (%s instead of 2.0)' % (version,))
self.log(INFO, 'Connected (version %s, client %s)' % (version, client)) self._log(INFO, 'Connected (version %s, client %s)' % (version, client))
def send_kex_init(self): def _send_kex_init(self):
# send a really wimpy kex-init packet that says we're a bare-bones ssh client """
announce to the other side that we'd like to negotiate keys, and what
kind of key negotiation we support.
"""
if self.server_mode: if self.server_mode:
# FIXME: can't do group-exchange (gex) yet -- too slow if (self.modulus_pack is None) and ('diffie-hellman-group-exchange-sha1' in self.preferred_kex):
if 'diffie-hellman-group-exchange-sha1' in self.preferred_kex: # can't do group-exchange if we don't have a pack of potential primes
self.preferred_kex.remove('diffie-hellman-group-exchange-sha1') self.preferred_kex.remove('diffie-hellman-group-exchange-sha1')
available_server_keys = filter(self.server_key_dict.keys().__contains__,
available_server_keys = filter(self.server_key_dict.keys().__contains__, self.preferred_keys)
self.preferred_keys) else:
available_server_keys = self.preferred_keys
m = Message() m = Message()
m.add_byte(chr(MSG_KEXINIT)) m.add_byte(chr(MSG_KEXINIT))
@ -545,9 +631,9 @@ class BaseTransport(threading.Thread):
m.add_int(0) m.add_int(0)
# save a copy for later (needed to compute a hash) # save a copy for later (needed to compute a hash)
self.local_kex_init = str(m) self.local_kex_init = str(m)
self.send_message(m) self._send_message(m)
def parse_kex_init(self, m): def _parse_kex_init(self, m):
# reset counters of when to re-key, since we are now re-keying # reset counters of when to re-key, since we are now re-keying
self.received_bytes = 0 self.received_bytes = 0
self.received_packets = 0 self.received_packets = 0
@ -580,7 +666,7 @@ class BaseTransport(threading.Thread):
agreed_kex = filter(kex_algo_list.__contains__, self.preferred_kex) agreed_kex = filter(kex_algo_list.__contains__, self.preferred_kex)
if len(agreed_kex) == 0: if len(agreed_kex) == 0:
raise SSHException('Incompatible ssh peer (no acceptable kex algorithm)') raise SSHException('Incompatible ssh peer (no acceptable kex algorithm)')
self.kex_engine = self.kex_info[agreed_kex[0]](self) self.kex_engine = self._kex_info[agreed_kex[0]](self)
if self.server_mode: if self.server_mode:
available_server_keys = filter(self.server_key_dict.keys().__contains__, available_server_keys = filter(self.server_key_dict.keys().__contains__,
@ -608,7 +694,7 @@ class BaseTransport(threading.Thread):
raise SSHException('Incompatible ssh server (no acceptable ciphers)') raise SSHException('Incompatible ssh server (no acceptable ciphers)')
self.local_cipher = agreed_local_ciphers[0] self.local_cipher = agreed_local_ciphers[0]
self.remote_cipher = agreed_remote_ciphers[0] self.remote_cipher = agreed_remote_ciphers[0]
self.log(DEBUG, 'Ciphers agreed: local=%s, remote=%s' % (self.local_cipher, self.remote_cipher)) self._log(DEBUG, 'Ciphers agreed: local=%s, remote=%s' % (self.local_cipher, self.remote_cipher))
if self.server_mode: if self.server_mode:
agreed_remote_macs = filter(self.preferred_macs.__contains__, client_mac_algo_list) agreed_remote_macs = filter(self.preferred_macs.__contains__, client_mac_algo_list)
@ -621,19 +707,19 @@ class BaseTransport(threading.Thread):
self.local_mac = agreed_local_macs[0] self.local_mac = agreed_local_macs[0]
self.remote_mac = agreed_remote_macs[0] self.remote_mac = agreed_remote_macs[0]
self.log(DEBUG, 'kex algos:' + str(kex_algo_list) + ' server key:' + str(server_key_algo_list) + \ self._log(DEBUG, 'kex algos:' + str(kex_algo_list) + ' server key:' + str(server_key_algo_list) + \
' client encrypt:' + str(client_encrypt_algo_list) + \ ' client encrypt:' + str(client_encrypt_algo_list) + \
' server encrypt:' + str(server_encrypt_algo_list) + \ ' server encrypt:' + str(server_encrypt_algo_list) + \
' client mac:' + str(client_mac_algo_list) + \ ' client mac:' + str(client_mac_algo_list) + \
' server mac:' + str(server_mac_algo_list) + \ ' server mac:' + str(server_mac_algo_list) + \
' client compress:' + str(client_compress_algo_list) + \ ' client compress:' + str(client_compress_algo_list) + \
' server compress:' + str(server_compress_algo_list) + \ ' server compress:' + str(server_compress_algo_list) + \
' client lang:' + str(client_lang_list) + \ ' client lang:' + str(client_lang_list) + \
' server lang:' + str(server_lang_list) + \ ' server lang:' + str(server_lang_list) + \
' kex follows?' + str(kex_follows)) ' kex follows?' + str(kex_follows))
self.log(DEBUG, 'using kex %s; server key type %s; cipher: local %s, remote %s; mac: local %s, remote %s' % self._log(DEBUG, 'using kex %s; server key type %s; cipher: local %s, remote %s; mac: local %s, remote %s' %
(agreed_kex[0], self.host_key_type, self.local_cipher, self.remote_cipher, self.local_mac, (agreed_kex[0], self.host_key_type, self.local_cipher, self.remote_cipher, self.local_mac,
self.remote_mac)) self.remote_mac))
# save for computing hash later... # save for computing hash later...
# now wait! openssh has a bug (and others might too) where there are # now wait! openssh has a bug (and others might too) where there are
@ -642,50 +728,52 @@ class BaseTransport(threading.Thread):
# away those bytes because they aren't part of the hash. # away those bytes because they aren't part of the hash.
self.remote_kex_init = chr(MSG_KEXINIT) + m.get_so_far() self.remote_kex_init = chr(MSG_KEXINIT) + m.get_so_far()
def activate_inbound(self): def _activate_inbound(self):
"switch on newly negotiated encryption parameters for inbound traffic" "switch on newly negotiated encryption parameters for inbound traffic"
self.block_size_in = self.cipher_info[self.remote_cipher]['block-size'] self.block_size_in = self._cipher_info[self.remote_cipher]['block-size']
if self.server_mode: if self.server_mode:
IV_in = self.compute_key('A', self.block_size_in) IV_in = self._compute_key('A', self.block_size_in)
key_in = self.compute_key('C', self.cipher_info[self.remote_cipher]['key-size']) key_in = self._compute_key('C', self._cipher_info[self.remote_cipher]['key-size'])
else: else:
IV_in = self.compute_key('B', self.block_size_in) IV_in = self._compute_key('B', self.block_size_in)
key_in = self.compute_key('D', self.cipher_info[self.remote_cipher]['key-size']) key_in = self._compute_key('D', self._cipher_info[self.remote_cipher]['key-size'])
self.engine_in = self.get_cipher(self.remote_cipher, key_in, IV_in) self.engine_in = self._get_cipher(self.remote_cipher, key_in, IV_in)
self.remote_mac_len = self.mac_info[self.remote_mac]['size'] self.remote_mac_len = self._mac_info[self.remote_mac]['size']
self.remote_mac_engine = self.mac_info[self.remote_mac]['class'] self.remote_mac_engine = self._mac_info[self.remote_mac]['class']
# initial mac keys are done in the hash's natural size (not the potentially truncated # initial mac keys are done in the hash's natural size (not the potentially truncated
# transmission size) # transmission size)
if self.server_mode: if self.server_mode:
self.mac_key_in = self.compute_key('E', self.remote_mac_engine.digest_size) self.mac_key_in = self._compute_key('E', self.remote_mac_engine.digest_size)
else: else:
self.mac_key_in = self.compute_key('F', self.remote_mac_engine.digest_size) self.mac_key_in = self._compute_key('F', self.remote_mac_engine.digest_size)
def activate_outbound(self): def _activate_outbound(self):
"switch on newly negotiated encryption parameters for outbound traffic" "switch on newly negotiated encryption parameters for outbound traffic"
m = Message() m = Message()
m.add_byte(chr(MSG_NEWKEYS)) m.add_byte(chr(MSG_NEWKEYS))
self.send_message(m) self._send_message(m)
self.block_size_out = self.cipher_info[self.local_cipher]['block-size'] self.block_size_out = self._cipher_info[self.local_cipher]['block-size']
if self.server_mode: if self.server_mode:
IV_out = self.compute_key('B', self.block_size_out) IV_out = self._compute_key('B', self.block_size_out)
key_out = self.compute_key('D', self.cipher_info[self.local_cipher]['key-size']) key_out = self._compute_key('D', self._cipher_info[self.local_cipher]['key-size'])
else: else:
IV_out = self.compute_key('A', self.block_size_out) IV_out = self._compute_key('A', self.block_size_out)
key_out = self.compute_key('C', self.cipher_info[self.local_cipher]['key-size']) key_out = self._compute_key('C', self._cipher_info[self.local_cipher]['key-size'])
self.engine_out = self.get_cipher(self.local_cipher, key_out, IV_out) self.engine_out = self._get_cipher(self.local_cipher, key_out, IV_out)
self.local_mac_len = self.mac_info[self.local_mac]['size'] self.local_mac_len = self._mac_info[self.local_mac]['size']
self.local_mac_engine = self.mac_info[self.local_mac]['class'] self.local_mac_engine = self._mac_info[self.local_mac]['class']
# initial mac keys are done in the hash's natural size (not the potentially truncated # initial mac keys are done in the hash's natural size (not the potentially truncated
# transmission size) # transmission size)
if self.server_mode: if self.server_mode:
self.mac_key_out = self.compute_key('F', self.local_mac_engine.digest_size) self.mac_key_out = self._compute_key('F', self.local_mac_engine.digest_size)
else: else:
self.mac_key_out = self.compute_key('E', self.local_mac_engine.digest_size) self.mac_key_out = self._compute_key('E', self.local_mac_engine.digest_size)
# we always expect to receive NEWKEYS now
self.expected_packet = MSG_NEWKEYS
def parse_newkeys(self, m): def _parse_newkeys(self, m):
self.log(DEBUG, 'Switch to new keys ...') self._log(DEBUG, 'Switch to new keys ...')
self.activate_inbound() self._activate_inbound()
# can also free a bunch of stuff here # can also free a bunch of stuff here
self.local_kex_init = self.remote_kex_init = None self.local_kex_init = self.remote_kex_init = None
self.e = self.f = self.K = self.x = None self.e = self.f = self.K = self.x = None
@ -697,24 +785,24 @@ class BaseTransport(threading.Thread):
self.completion_event.set() self.completion_event.set()
return return
def parse_disconnect(self, m): def _parse_disconnect(self, m):
code = m.get_int() code = m.get_int()
desc = m.get_string() desc = m.get_string()
self.log(INFO, 'Disconnect (code %d): %s' % (code, desc)) self._log(INFO, 'Disconnect (code %d): %s' % (code, desc))
def parse_channel_open_success(self, m): def _parse_channel_open_success(self, m):
chanid = m.get_int() chanid = m.get_int()
server_chanid = m.get_int() server_chanid = m.get_int()
server_window_size = m.get_int() server_window_size = m.get_int()
server_max_packet_size = m.get_int() server_max_packet_size = m.get_int()
if not self.channels.has_key(chanid): if not self.channels.has_key(chanid):
self.log(WARNING, 'Success for unrequested channel! [??]') self._log(WARNING, 'Success for unrequested channel! [??]')
return return
try: try:
self.lock.acquire() self.lock.acquire()
chan = self.channels[chanid] chan = self.channels[chanid]
chan.set_remote_channel(server_chanid, server_window_size, server_max_packet_size) chan.set_remote_channel(server_chanid, server_window_size, server_max_packet_size)
self.log(INFO, 'Secsh channel %d opened.' % chanid) self._log(INFO, 'Secsh channel %d opened.' % chanid)
if self.channel_events.has_key(chanid): if self.channel_events.has_key(chanid):
self.channel_events[chanid].set() self.channel_events[chanid].set()
del self.channel_events[chanid] del self.channel_events[chanid]
@ -722,7 +810,7 @@ class BaseTransport(threading.Thread):
self.lock.release() self.lock.release()
return return
def parse_channel_open_failure(self, m): def _parse_channel_open_failure(self, m):
chanid = m.get_int() chanid = m.get_int()
reason = m.get_int() reason = m.get_int()
reason_str = m.get_string() reason_str = m.get_string()
@ -731,7 +819,7 @@ class BaseTransport(threading.Thread):
reason_text = CONNECTION_FAILED_CODE[reason] reason_text = CONNECTION_FAILED_CODE[reason]
else: else:
reason_text = '(unknown code)' reason_text = '(unknown code)'
self.log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text)) self._log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text))
try: try:
self.lock.aquire() self.lock.aquire()
if self.channels.has_key(chanid): if self.channels.has_key(chanid):
@ -747,14 +835,14 @@ class BaseTransport(threading.Thread):
"override me! return object descended from Channel to allow, or None to reject" "override me! return object descended from Channel to allow, or None to reject"
return None return None
def parse_channel_open(self, m): def _parse_channel_open(self, m):
kind = m.get_string() kind = m.get_string()
chanid = m.get_int() chanid = m.get_int()
initial_window_size = m.get_int() initial_window_size = m.get_int()
max_packet_size = m.get_int() max_packet_size = m.get_int()
reject = False reject = False
if not self.server_mode: if not self.server_mode:
self.log(DEBUG, 'Rejecting "%s" channel request from server.' % kind) self._log(DEBUG, 'Rejecting "%s" channel request from server.' % kind)
reject = True reject = True
reason = self.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED reason = self.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
else: else:
@ -766,7 +854,7 @@ class BaseTransport(threading.Thread):
self.lock.release() self.lock.release()
chan = self.check_channel_request(kind, my_chanid) chan = self.check_channel_request(kind, my_chanid)
if (chan is None) or (type(chan) is int): if (chan is None) or (type(chan) is int):
self.log(DEBUG, 'Rejecting "%s" channel request from client.' % kind) self._log(DEBUG, 'Rejecting "%s" channel request from client.' % kind)
reject = True reject = True
if type(chan) is int: if type(chan) is int:
reason = chan reason = chan
@ -779,12 +867,12 @@ class BaseTransport(threading.Thread):
msg.add_int(reason) msg.add_int(reason)
msg.add_string('') msg.add_string('')
msg.add_string('en') msg.add_string('en')
self.send_message(msg) self._send_message(msg)
return return
try: try:
self.lock.acquire() self.lock.acquire()
self.channels[my_chanid] = chan self.channels[my_chanid] = chan
chan.set_transport(self) chan._set_transport(self)
chan.set_window(self.window_size, self.max_packet_size) chan.set_window(self.window_size, self.max_packet_size)
chan.set_remote_channel(chanid, initial_window_size, max_packet_size) chan.set_remote_channel(chanid, initial_window_size, max_packet_size)
finally: finally:
@ -795,8 +883,8 @@ class BaseTransport(threading.Thread):
m.add_int(my_chanid) m.add_int(my_chanid)
m.add_int(self.window_size) m.add_int(self.window_size)
m.add_int(self.max_packet_size) m.add_int(self.max_packet_size)
self.send_message(m) self._send_message(m)
self.log(INFO, 'Secsh channel %d opened.' % my_chanid) self._log(INFO, 'Secsh channel %d opened.' % my_chanid)
try: try:
self.lock.acquire() self.lock.acquire()
self.server_accepts.append(chan) self.server_accepts.append(chan)
@ -820,21 +908,21 @@ class BaseTransport(threading.Thread):
self.lock.release() self.lock.release()
return chan return chan
def parse_debug(self, m): def _parse_debug(self, m):
always_display = m.get_boolean() always_display = m.get_boolean()
msg = m.get_string() msg = m.get_string()
lang = m.get_string() lang = m.get_string()
self.log(DEBUG, 'Debug msg: ' + safe_string(msg)) self._log(DEBUG, 'Debug msg: ' + safe_string(msg))
handler_table = { _handler_table = {
MSG_NEWKEYS: parse_newkeys, MSG_NEWKEYS: _parse_newkeys,
MSG_CHANNEL_OPEN_SUCCESS: parse_channel_open_success, MSG_CHANNEL_OPEN_SUCCESS: _parse_channel_open_success,
MSG_CHANNEL_OPEN_FAILURE: parse_channel_open_failure, MSG_CHANNEL_OPEN_FAILURE: _parse_channel_open_failure,
MSG_CHANNEL_OPEN: parse_channel_open, MSG_CHANNEL_OPEN: _parse_channel_open,
MSG_KEXINIT: negotiate_keys, MSG_KEXINIT: _negotiate_keys,
} }
channel_handler_table = { _channel_handler_table = {
MSG_CHANNEL_SUCCESS: Channel.request_success, MSG_CHANNEL_SUCCESS: Channel.request_success,
MSG_CHANNEL_FAILURE: Channel.request_failed, MSG_CHANNEL_FAILURE: Channel.request_failed,
MSG_CHANNEL_DATA: Channel.feed, MSG_CHANNEL_DATA: Channel.feed,

View File

@ -1,7 +1,6 @@
#!/usr/bin/python #!/usr/bin/python
import sys, struct, traceback import sys, struct, traceback
from Crypto.Util import number
def inflate_long(s, always_positive=0): def inflate_long(s, always_positive=0):
"turns a normalized byte string into a long-int (adapted from Crypto.Util.number)" "turns a normalized byte string into a long-int (adapted from Crypto.Util.number)"
@ -98,20 +97,5 @@ def bit_length(n):
bitlen -= 1 bitlen -= 1
return bitlen return bitlen
def generate_prime(bits, randpool):
hbyte_mask = pow(2, bits % 8) - 1
x = randpool.get_bytes((bits+7) // 8)
if hbyte_mask > 0:
x = chr(ord(x[0]) & hbyte_mask) + x[1:]
n = inflate_long(x, 1)
n |= 1
n |= (1 << (bits - 1))
while 1:
# loop catches the case where we increment n into a higher bit-range
while not number.isPrime(n):
n += 2
if bit_length(n) == bits:
return n
def tb_strings(): def tb_strings():
return ''.join(traceback.format_exception(*sys.exc_info())).split('\n') return ''.join(traceback.format_exception(*sys.exc_info())).split('\n')