[project @ Arch-1:robey@lag.net--2003-public%secsh--dev--1.0--patch-16]
hook up server-side kex-gex; add more documentation group-exchange kex should work now on the server side. it will only be advertised if a "moduli" file has been loaded (see the -gasp- docs) so we don't spend hours (literally. hours.) computing primes. some of the logic was previously wrong, too, since it had never been tested. fixed repr() string for Transport/BaseTransport. moved is_authenticated to Transport where it belongs. added lots of documentation (but still only about 10% documented). lots of methods were made private finally.
This commit is contained in:
		
							parent
							
								
									eb4c279ec4
								
							
						
					
					
						commit
						36d6d95dc6
					
				
							
								
								
									
										4
									
								
								NOTES
								
								
								
								
							
							
						
						
									
										4
									
								
								NOTES
								
								
								
								
							|  | @ -22,14 +22,14 @@ from BaseTransport: | |||
|     get_server_key | ||||
|     close | ||||
|     get_remote_server_key | ||||
|     is_active | ||||
|     is_authenticated | ||||
| *   is_active | ||||
|     open_session | ||||
|     open_channel | ||||
|     renegotiate_keys | ||||
|     check_channel_request | ||||
| 
 | ||||
| from Transport: | ||||
| *   is_authenticated | ||||
|     auth_key | ||||
|     auth_password | ||||
|     get_allowed_auths | ||||
|  |  | |||
							
								
								
									
										2
									
								
								demo.py
								
								
								
								
							
							
						
						
									
										2
									
								
								demo.py
								
								
								
								
							|  | @ -69,7 +69,7 @@ try: | |||
|     t.ultra_debug = 0 | ||||
|     t.start_client(event) | ||||
|     # print repr(t) | ||||
|     event.wait(10) | ||||
|     event.wait(15) | ||||
|     if not t.is_active(): | ||||
|         print '*** SSH negotiation failed.' | ||||
|         sys.exit(1) | ||||
|  |  | |||
|  | @ -70,6 +70,11 @@ print 'Got a connection!' | |||
| try: | ||||
|     event = threading.Event() | ||||
|     t = ServerTransport(client) | ||||
|     try: | ||||
|         t.load_server_moduli() | ||||
|     except: | ||||
|         print '(Failed to load moduli -- gex will be unsupported.)' | ||||
|         raise | ||||
|     t.add_server_key(host_key) | ||||
|     t.ultra_debug = 0 | ||||
|     t.start_server(event) | ||||
|  | @ -81,10 +86,11 @@ try: | |||
|     # print repr(t) | ||||
| 
 | ||||
|     # wait for auth | ||||
|     chan = t.accept(10) | ||||
|     chan = t.accept(20) | ||||
|     if chan is None: | ||||
|         print '*** No channel.' | ||||
|         sys.exit(1) | ||||
|     print 'Authenticated!' | ||||
|     chan.event.wait(10) | ||||
|     if not chan.event.isSet(): | ||||
|         print '*** Client never asked for a shell.' | ||||
|  |  | |||
|  | @ -17,4 +17,4 @@ from rsakey import RSAKey | |||
| from dsskey import DSSKey | ||||
| from util import hexify | ||||
| 
 | ||||
| __all__ = [ 'Transport', 'Channel', 'RSAKey', 'DSSKey', 'hexify' ] | ||||
| #__all__ = [ 'Transport', 'Channel', 'RSAKey', 'DSSKey', 'hexify' ] | ||||
|  |  | |||
|  | @ -19,17 +19,45 @@ class Transport(BaseTransport): | |||
| 
 | ||||
|     def __init__(self, sock): | ||||
|         BaseTransport.__init__(self, sock) | ||||
|         self.authenticated = False | ||||
|         self.auth_event = None | ||||
|         # for server mode: | ||||
|         self.auth_username = None | ||||
|         self.auth_fail_count = 0 | ||||
|         self.auth_complete = 0 | ||||
| 
 | ||||
|     def request_auth(self): | ||||
|     def __repr__(self): | ||||
|         if not self.active: | ||||
|             return '<paramiko.Transport (unconnected)>' | ||||
|         out = '<paramiko.Transport' | ||||
|         if self.local_cipher != '': | ||||
|             out += ' (cipher %s)' % self.local_cipher | ||||
|         if self.authenticated: | ||||
|             if len(self.channels) == 1: | ||||
|                 out += ' (active; 1 open channel)' | ||||
|             else: | ||||
|                 out += ' (active; %d open channels)' % len(self.channels) | ||||
|         elif self.initial_kex_done: | ||||
|             out += ' (connected; awaiting auth)' | ||||
|         else: | ||||
|             out += ' (connecting)' | ||||
|         out += '>' | ||||
|         return out | ||||
| 
 | ||||
|     def is_authenticated(self): | ||||
|         """ | ||||
|         Return true if this session is active and authenticated. | ||||
| 
 | ||||
|         @return: True if the session is still open and has been authenticated successfully; | ||||
|         False if authentication failed and/or the session is closed. | ||||
|         """ | ||||
|         return self.authenticated and self.active | ||||
| 
 | ||||
|     def _request_auth(self): | ||||
|         m = Message() | ||||
|         m.add_byte(chr(MSG_SERVICE_REQUEST)) | ||||
|         m.add_string('ssh-userauth') | ||||
|         self.send_message(m) | ||||
|         self._send_message(m) | ||||
| 
 | ||||
|     def auth_key(self, username, key, event): | ||||
|         if (not self.active) or (not self.initial_kex_done): | ||||
|  | @ -41,7 +69,7 @@ class Transport(BaseTransport): | |||
|             self.auth_method = 'publickey' | ||||
|             self.username = username | ||||
|             self.private_key = key | ||||
|             self.request_auth() | ||||
|             self._request_auth() | ||||
|         finally: | ||||
|             self.lock.release() | ||||
| 
 | ||||
|  | @ -56,7 +84,7 @@ class Transport(BaseTransport): | |||
|             self.auth_method = 'password' | ||||
|             self.username = username | ||||
|             self.password = password | ||||
|             self.request_auth() | ||||
|             self._request_auth() | ||||
|         finally: | ||||
|             self.lock.release() | ||||
| 
 | ||||
|  | @ -66,7 +94,7 @@ class Transport(BaseTransport): | |||
|         m.add_int(DISCONNECT_SERVICE_NOT_AVAILABLE) | ||||
|         m.add_string('Service not available') | ||||
|         m.add_string('en') | ||||
|         self.send_message(m) | ||||
|         self._send_message(m) | ||||
|         self.close() | ||||
| 
 | ||||
|     def disconnect_no_more_auth(self): | ||||
|  | @ -75,7 +103,7 @@ class Transport(BaseTransport): | |||
|         m.add_int(DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE) | ||||
|         m.add_string('No more auth methods available') | ||||
|         m.add_string('en') | ||||
|         self.send_message(m) | ||||
|         self._send_message(m) | ||||
|         self.close() | ||||
| 
 | ||||
|     def parse_service_request(self, m): | ||||
|  | @ -85,7 +113,7 @@ class Transport(BaseTransport): | |||
|             m = Message() | ||||
|             m.add_byte(chr(MSG_SERVICE_ACCEPT)) | ||||
|             m.add_string(service) | ||||
|             self.send_message(m) | ||||
|             self._send_message(m) | ||||
|             return | ||||
|         # dunno this one | ||||
|         self.disconnect_service_not_available() | ||||
|  | @ -93,7 +121,7 @@ class Transport(BaseTransport): | |||
|     def parse_service_accept(self, m): | ||||
|         service = m.get_string() | ||||
|         if service == 'ssh-userauth': | ||||
|             self.log(DEBUG, 'userauth is OK') | ||||
|             self._log(DEBUG, 'userauth is OK') | ||||
|             m = Message() | ||||
|             m.add_byte(chr(MSG_USERAUTH_REQUEST)) | ||||
|             m.add_string(self.username) | ||||
|  | @ -109,9 +137,9 @@ class Transport(BaseTransport): | |||
|                 m.add_string(self.private_key.sign_ssh_session(self.randpool, self.H, self.username)) | ||||
|             else: | ||||
|                 raise SSHException('Unknown auth method "%s"' % self.auth_method) | ||||
|             self.send_message(m) | ||||
|             self._send_message(m) | ||||
|         else: | ||||
|             self.log(DEBUG, 'Service request "%s" accepted (?)' % service) | ||||
|             self._log(DEBUG, 'Service request "%s" accepted (?)' % service) | ||||
| 
 | ||||
|     def get_allowed_auths(self, username): | ||||
|         "override me!" | ||||
|  | @ -136,7 +164,7 @@ class Transport(BaseTransport): | |||
|             m.add_byte(chr(MSG_USERAUTH_FAILURE)) | ||||
|             m.add_string('none') | ||||
|             m.add_boolean(0) | ||||
|             self.send_message(m) | ||||
|             self._send_message(m) | ||||
|             return | ||||
|         if self.auth_complete: | ||||
|             # ignore | ||||
|  | @ -144,12 +172,12 @@ class Transport(BaseTransport): | |||
|         username = m.get_string() | ||||
|         service = m.get_string() | ||||
|         method = m.get_string() | ||||
|         self.log(DEBUG, 'Auth request (type=%s) service=%s, username=%s' % (method, service, username)) | ||||
|         self._log(DEBUG, 'Auth request (type=%s) service=%s, username=%s' % (method, service, username)) | ||||
|         if service != 'ssh-connection': | ||||
|             self.disconnect_service_not_available() | ||||
|             return | ||||
|         if (self.auth_username is not None) and (self.auth_username != username): | ||||
|             self.log(DEBUG, 'Auth rejected because the client attempted to change username in mid-flight') | ||||
|             self._log(DEBUG, 'Auth rejected because the client attempted to change username in mid-flight') | ||||
|             self.disconnect_no_more_auth() | ||||
|             return | ||||
|         if method == 'none': | ||||
|  | @ -160,7 +188,7 @@ class Transport(BaseTransport): | |||
|             if changereq: | ||||
|                 # always treated as failure, since we don't support changing passwords, but collect | ||||
|                 # the list of valid auth types from the callback anyway | ||||
|                 self.log(DEBUG, 'Auth request to change passwords (rejected)') | ||||
|                 self._log(DEBUG, 'Auth request to change passwords (rejected)') | ||||
|                 newpassword = m.get_string().decode('UTF-8') | ||||
|                 result = self.AUTH_FAILED | ||||
|             else: | ||||
|  | @ -173,11 +201,11 @@ class Transport(BaseTransport): | |||
|         # okay, send result | ||||
|         m = Message() | ||||
|         if result == self.AUTH_SUCCESSFUL: | ||||
|             self.log(DEBUG, 'Auth granted.') | ||||
|             self._log(DEBUG, 'Auth granted.') | ||||
|             m.add_byte(chr(MSG_USERAUTH_SUCCESS)) | ||||
|             self.auth_complete = 1 | ||||
|         else: | ||||
|             self.log(DEBUG, 'Auth rejected.') | ||||
|             self._log(DEBUG, 'Auth rejected.') | ||||
|             m.add_byte(chr(MSG_USERAUTH_FAILURE)) | ||||
|             m.add_string(self.get_allowed_auths(username)) | ||||
|             if result == self.AUTH_PARTIALLY_SUCCESSFUL: | ||||
|  | @ -185,13 +213,13 @@ class Transport(BaseTransport): | |||
|             else: | ||||
|                 m.add_boolean(0) | ||||
|             self.auth_fail_count += 1 | ||||
|         self.send_message(m) | ||||
|         self._send_message(m) | ||||
|         if self.auth_fail_count >= 10: | ||||
|             self.disconnect_no_more_auth() | ||||
| 
 | ||||
|     def parse_userauth_success(self, m): | ||||
|         self.log(INFO, 'Authentication successful!') | ||||
|         self.authenticated = 1 | ||||
|         self._log(INFO, 'Authentication successful!') | ||||
|         self.authenticated = True | ||||
|         if self.auth_event != None: | ||||
|             self.auth_event.set() | ||||
| 
 | ||||
|  | @ -199,12 +227,12 @@ class Transport(BaseTransport): | |||
|         authlist = m.get_list() | ||||
|         partial = m.get_boolean() | ||||
|         if partial: | ||||
|             self.log(INFO, 'Authentication continues...') | ||||
|             self.log(DEBUG, 'Methods: ' + str(partial)) | ||||
|             self._log(INFO, 'Authentication continues...') | ||||
|             self._log(DEBUG, 'Methods: ' + str(partial)) | ||||
|             # FIXME - do something | ||||
|             pass | ||||
|         self.log(INFO, 'Authentication failed.') | ||||
|         self.authenticated = 0 | ||||
|         self._log(INFO, 'Authentication failed.') | ||||
|         self.authenticated = False | ||||
|         self.close() | ||||
|         if self.auth_event != None: | ||||
|             self.auth_event.set() | ||||
|  | @ -212,11 +240,11 @@ class Transport(BaseTransport): | |||
|     def parse_userauth_banner(self, m): | ||||
|         banner = m.get_string() | ||||
|         lang = m.get_string() | ||||
|         self.log(INFO, 'Auth banner: ' + banner) | ||||
|         self._log(INFO, 'Auth banner: ' + banner) | ||||
|         # who cares. | ||||
| 
 | ||||
|     handler_table = BaseTransport.handler_table.copy() | ||||
|     handler_table.update({ | ||||
|     _handler_table = BaseTransport._handler_table.copy() | ||||
|     _handler_table.update({ | ||||
|         MSG_SERVICE_REQUEST: parse_service_request, | ||||
|         MSG_SERVICE_ACCEPT: parse_service_accept, | ||||
|         MSG_USERAUTH_REQUEST: parse_userauth_request, | ||||
|  |  | |||
|  | @ -50,10 +50,10 @@ class Channel(object): | |||
|         out += '>' | ||||
|         return out | ||||
| 
 | ||||
|     def set_transport(self, transport): | ||||
|     def _set_transport(self, transport): | ||||
|         self.transport = transport | ||||
| 
 | ||||
|     def log(self, level, msg): | ||||
|     def _log(self, level, msg): | ||||
|         self.logger.log(level, msg) | ||||
| 
 | ||||
|     def set_window(self, window_size, max_packet_size): | ||||
|  | @ -70,7 +70,7 @@ class Channel(object): | |||
|         self.active = 1 | ||||
| 
 | ||||
|     def request_success(self, m): | ||||
|         self.log(DEBUG, 'Sesch channel %d request ok' % self.chanid) | ||||
|         self._log(DEBUG, 'Sesch channel %d request ok' % self.chanid) | ||||
|         return | ||||
| 
 | ||||
|     def request_failed(self, m): | ||||
|  | @ -80,13 +80,13 @@ class Channel(object): | |||
|         s = m.get_string() | ||||
|         try: | ||||
|             self.lock.acquire() | ||||
|             self.log(DEBUG, 'fed %d bytes' % len(s)) | ||||
|             self._log(DEBUG, 'fed %d bytes' % len(s)) | ||||
|             if self.pipe_wfd != None: | ||||
|                 self.feed_pipe(s) | ||||
|             else: | ||||
|                 self.in_buffer += s | ||||
|                 self.in_buffer_cv.notifyAll() | ||||
|             self.log(DEBUG, '(out from feed)') | ||||
|             self._log(DEBUG, '(out from feed)') | ||||
|         finally: | ||||
|             self.lock.release() | ||||
| 
 | ||||
|  | @ -94,7 +94,7 @@ class Channel(object): | |||
|         nbytes = m.get_int() | ||||
|         try: | ||||
|             self.lock.acquire() | ||||
|             self.log(DEBUG, 'window up %d' % nbytes) | ||||
|             self._log(DEBUG, 'window up %d' % nbytes) | ||||
|             self.out_window_size += nbytes | ||||
|             self.out_buffer_cv.notifyAll() | ||||
|         finally: | ||||
|  | @ -146,7 +146,7 @@ class Channel(object): | |||
|             pixelheight = m.get_int() | ||||
|             ok = self.check_window_change_request(width, height, pixelwidth, pixelheight) | ||||
|         else: | ||||
|             self.log(DEBUG, 'Unhandled channel request "%s"' % key) | ||||
|             self._log(DEBUG, 'Unhandled channel request "%s"' % key) | ||||
|             ok = False | ||||
|         if want_reply: | ||||
|             m = Message() | ||||
|  | @ -155,7 +155,7 @@ class Channel(object): | |||
|             else: | ||||
|                 m.add_byte(chr(MSG_CHANNEL_FAILURE)) | ||||
|             m.add_int(self.remote_chanid) | ||||
|             self.transport.send_message(m) | ||||
|             self.transport._send_message(m) | ||||
| 
 | ||||
|     def handle_eof(self, m): | ||||
|         try: | ||||
|  | @ -168,7 +168,7 @@ class Channel(object): | |||
|                     self.pipe_wfd = None | ||||
|         finally: | ||||
|             self.lock.release() | ||||
|         self.log(DEBUG, 'EOF received') | ||||
|         self._log(DEBUG, 'EOF received') | ||||
| 
 | ||||
|     def handle_close(self, m): | ||||
|         self.close() | ||||
|  | @ -199,7 +199,7 @@ class Channel(object): | |||
|         # pixel height, width (usually useless) | ||||
|         m.add_int(0).add_int(0) | ||||
|         m.add_string('') | ||||
|         self.transport.send_message(m) | ||||
|         self.transport._send_message(m) | ||||
| 
 | ||||
|     def invoke_shell(self): | ||||
|         if self.closed or self.eof_received or self.eof_sent or not self.active: | ||||
|  | @ -209,7 +209,7 @@ class Channel(object): | |||
|         m.add_int(self.remote_chanid) | ||||
|         m.add_string('shell') | ||||
|         m.add_boolean(1) | ||||
|         self.transport.send_message(m) | ||||
|         self.transport._send_message(m) | ||||
| 
 | ||||
|     def exec_command(self, command): | ||||
|         if self.closed or self.eof_received or self.eof_sent or not self.active: | ||||
|  | @ -220,7 +220,7 @@ class Channel(object): | |||
|         m.add_string('exec') | ||||
|         m.add_boolean(1) | ||||
|         m.add_string(command) | ||||
|         self.transport.send_message(m) | ||||
|         self.transport._send_message(m) | ||||
| 
 | ||||
|     def invoke_subsystem(self, subsystem): | ||||
|         if self.closed or self.eof_received or self.eof_sent or not self.active: | ||||
|  | @ -231,7 +231,7 @@ class Channel(object): | |||
|         m.add_string('subsystem') | ||||
|         m.add_boolean(1) | ||||
|         m.add_string(subsystem) | ||||
|         self.transport.send_message(m) | ||||
|         self.transport._send_message(m) | ||||
| 
 | ||||
|     def resize_pty(self, width=80, height=24): | ||||
|         if self.closed or self.eof_received or self.eof_sent or not self.active: | ||||
|  | @ -244,7 +244,7 @@ class Channel(object): | |||
|         m.add_int(width) | ||||
|         m.add_int(height) | ||||
|         m.add_int(0).add_int(0) | ||||
|         self.transport.send_message(m) | ||||
|         self.transport._send_message(m) | ||||
| 
 | ||||
|     def get_transport(self): | ||||
|         return self.transport | ||||
|  | @ -262,9 +262,9 @@ class Channel(object): | |||
|         m = Message() | ||||
|         m.add_byte(chr(MSG_CHANNEL_EOF)) | ||||
|         m.add_int(self.remote_chanid) | ||||
|         self.transport.send_message(m) | ||||
|         self.transport._send_message(m) | ||||
|         self.eof_sent = 1 | ||||
|         self.log(DEBUG, 'EOF sent') | ||||
|         self._log(DEBUG, 'EOF sent') | ||||
|         return | ||||
| 
 | ||||
| 
 | ||||
|  | @ -290,9 +290,9 @@ class Channel(object): | |||
|                 m = Message() | ||||
|                 m.add_byte(chr(MSG_CHANNEL_CLOSE)) | ||||
|                 m.add_int(self.remote_chanid) | ||||
|                 self.transport.send_message(m) | ||||
|                 self.transport._send_message(m) | ||||
|                 self.closed = 1 | ||||
|                 self.transport.unlink_channel(self.chanid) | ||||
|                 self.transport._unlink_channel(self.chanid) | ||||
|         finally: | ||||
|             self.lock.release() | ||||
| 
 | ||||
|  | @ -371,7 +371,7 @@ class Channel(object): | |||
|             m.add_byte(chr(MSG_CHANNEL_DATA)) | ||||
|             m.add_int(self.remote_chanid) | ||||
|             m.add_string(s[:size]) | ||||
|             self.transport.send_message(m) | ||||
|             self.transport._send_message(m) | ||||
|             self.out_window_size -= size | ||||
|         finally: | ||||
|             self.lock.release() | ||||
|  | @ -506,25 +506,25 @@ class Channel(object): | |||
|         self.in_buffer = self.in_buffer[nbytes:] | ||||
|         os.write(self.pipd_wfd, x) | ||||
| 
 | ||||
|     def unlink(self): | ||||
|     def _unlink(self): | ||||
|         if self.closed or not self.active: | ||||
|             return | ||||
|         self.closed = 1 | ||||
|         self.transport.unlink_channel(self.chanid) | ||||
|         self.transport._unlink_channel(self.chanid) | ||||
| 
 | ||||
|     def check_add_window(self, n): | ||||
|         # already holding the lock! | ||||
|         if self.closed or self.eof_received or not self.active: | ||||
|             return | ||||
|         self.log(DEBUG, 'addwindow %d' % n) | ||||
|         self._log(DEBUG, 'addwindow %d' % n) | ||||
|         self.in_window_sofar += n | ||||
|         if self.in_window_sofar > self.in_window_threshold: | ||||
|             self.log(DEBUG, 'addwindow send %d' % self.in_window_sofar) | ||||
|             self._log(DEBUG, 'addwindow send %d' % self.in_window_sofar) | ||||
|             m = Message() | ||||
|             m.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST)) | ||||
|             m.add_int(self.remote_chanid) | ||||
|             m.add_int(self.in_window_sofar) | ||||
|             self.transport.send_message(m) | ||||
|             self.transport._send_message(m) | ||||
|             self.in_window_sofar = 0 | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -5,7 +5,7 @@ | |||
| # LOT more on the server side). | ||||
| 
 | ||||
| from message import Message | ||||
| from util import inflate_long, deflate_long, generate_prime, bit_length | ||||
| from util import inflate_long, deflate_long, bit_length | ||||
| from ssh_exception import SSHException | ||||
| from transport import MSG_NEWKEYS | ||||
| from Crypto.Hash import SHA | ||||
|  | @ -27,7 +27,7 @@ class KexGex(object): | |||
| 
 | ||||
|     def start_kex(self): | ||||
|         if self.transport.server_mode: | ||||
|             self.transport.expected_packet = MSG_KEXDH_GEX_REQUEST | ||||
|             self.transport._expect_packet(MSG_KEXDH_GEX_REQUEST) | ||||
|             return | ||||
|         # request a bit range: we accept (min_bits) to (max_bits), but prefer | ||||
|         # (preferred_bits).  according to the spec, we shouldn't pull the | ||||
|  | @ -37,21 +37,21 @@ class KexGex(object): | |||
|         m.add_int(self.min_bits) | ||||
|         m.add_int(self.preferred_bits) | ||||
|         m.add_int(self.max_bits) | ||||
|         self.transport.send_message(m) | ||||
|         self.transport.expected_packet = MSG_KEXDH_GEX_GROUP | ||||
|         self.transport._send_message(m) | ||||
|         self.transport._expect_packet(MSG_KEXDH_GEX_GROUP) | ||||
| 
 | ||||
|     def parse_next(self, ptype, m): | ||||
|         if ptype == MSG_KEXDH_GEX_REQUEST: | ||||
|             return self.parse_kexdh_gex_request(m) | ||||
|             return self._parse_kexdh_gex_request(m) | ||||
|         elif ptype == MSG_KEXDH_GEX_GROUP: | ||||
|             return self.parse_kexdh_gex_group(m) | ||||
|             return self._parse_kexdh_gex_group(m) | ||||
|         elif ptype == MSG_KEXDH_GEX_INIT: | ||||
|             return self.parse_kexdh_gex_init(m) | ||||
|             return self._parse_kexdh_gex_init(m) | ||||
|         elif ptype == MSG_KEXDH_GEX_REPLY: | ||||
|             return self.parse_kexdh_gex_reply(m) | ||||
|             return self._parse_kexdh_gex_reply(m) | ||||
|         raise SSHException('KexGex asked to handle packet type %d' % ptype) | ||||
| 
 | ||||
|     def generate_x(self): | ||||
|     def _generate_x(self): | ||||
|         # generate an "x" (1 < x < (p-1)/2). | ||||
|         q = (self.p - 1) // 2 | ||||
|         qnorm = deflate_long(q, 0) | ||||
|  | @ -70,7 +70,7 @@ class KexGex(object): | |||
|                 break | ||||
|         self.x = x | ||||
| 
 | ||||
|     def parse_kexdh_gex_request(self, m): | ||||
|     def _parse_kexdh_gex_request(self, m): | ||||
|         min = m.get_int() | ||||
|         preferred = m.get_int() | ||||
|         max = m.get_int() | ||||
|  | @ -79,52 +79,53 @@ class KexGex(object): | |||
|             preferred = self.max_bits | ||||
|         if preferred < self.min_bits: | ||||
|             preferred = self.min_bits | ||||
|         # fix min/max if they're inconsistent.  technically, we could just pout | ||||
|         # and hang up, but there's no harm in giving them the benefit of the | ||||
|         # doubt and just picking a bitsize for them. | ||||
|         if min > preferred: | ||||
|             min = preferred | ||||
|         if max < preferred: | ||||
|             max = preferred | ||||
|         # now save a copy | ||||
|         self.min_bits = min | ||||
|         self.preferred_bits = preferred | ||||
|         self.max_bits = max | ||||
|         # generate prime | ||||
|         while 1: | ||||
|             # does not work FIXME | ||||
|             # the problem is that it's too fscking SLOW | ||||
|             self.transport.log(DEBUG, 'stir...') | ||||
|             self.transport.randpool.stir() | ||||
|             self.transport.log(DEBUG, 'get-prime %d...' % preferred) | ||||
|             self.p = generate_prime(preferred, self.transport.randpool) | ||||
|             self.transport.log(DEBUG, 'got ' + repr(self.p)) | ||||
|             if number.isPrime((self.p - 1) // 2): | ||||
|                 break | ||||
|         self.g = 2 | ||||
|         pack = self.transport._get_modulus_pack() | ||||
|         if pack is None: | ||||
|             raise SSHException('Can\'t do server-side gex with no modulus pack') | ||||
|         self.g, self.p = pack.get_modulus(min, preferred, max) | ||||
|         m = Message() | ||||
|         m.add_byte(chr(MSG_KEXDH_GEX_GROUP)) | ||||
|         m.add_mpint(self.p) | ||||
|         m.add_mpint(self.g) | ||||
|         self.transport.send_message(m) | ||||
|         self.transport.expected_packet = MSG_KEXDH_GEX_INIT | ||||
|         self.transport._send_message(m) | ||||
|         self.transport._expect_packet(MSG_KEXDH_GEX_INIT) | ||||
| 
 | ||||
|     def parse_kexdh_gex_group(self, m): | ||||
|     def _parse_kexdh_gex_group(self, m): | ||||
|         self.p = m.get_mpint() | ||||
|         self.g = m.get_mpint() | ||||
|         # reject if p's bit length < 1024 or > 8192 | ||||
|         bitlen = bit_length(self.p) | ||||
|         if (bitlen < 1024) or (bitlen > 8192): | ||||
|             raise SSHException('Server-generated gex p (don\'t ask) is out of range (%d bits)' % bitlen) | ||||
|         self.transport.log(DEBUG, 'Got server p (%d bits)' % bitlen) | ||||
|         self.generate_x() | ||||
|         self.transport._log(DEBUG, 'Got server p (%d bits)' % bitlen) | ||||
|         self._generate_x() | ||||
|         # now compute e = g^x mod p | ||||
|         self.e = pow(self.g, self.x, self.p) | ||||
|         m = Message() | ||||
|         m.add_byte(chr(MSG_KEXDH_GEX_INIT)) | ||||
|         m.add_mpint(self.e) | ||||
|         self.transport.send_message(m) | ||||
|         self.transport.expected_packet = MSG_KEXDH_GEX_REPLY | ||||
|         self.transport._send_message(m) | ||||
|         self.transport._expect_packet(MSG_KEXDH_GEX_REPLY) | ||||
| 
 | ||||
|     def parse_kexdh_gex_init(self, m): | ||||
|     def _parse_kexdh_gex_init(self, m): | ||||
|         self.e = m.get_mpint() | ||||
|         if (self.e < 1) or (self.e > self.p - 1): | ||||
|             raise SSHException('Client kex "e" is out of range') | ||||
|         self.generate_x() | ||||
|         K = pow(self.e, self.x, P) | ||||
|         self._generate_x() | ||||
|         self.f = pow(self.g, self.x, self.p) | ||||
|         K = pow(self.e, self.x, self.p) | ||||
|         key = str(self.transport.get_server_key()) | ||||
|         # okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) | ||||
|         hm = Message().add(self.transport.remote_version).add(self.transport.local_version) | ||||
|  | @ -136,7 +137,7 @@ class KexGex(object): | |||
|         hm.add_mpint(self.g) | ||||
|         hm.add(self.e).add(self.f).add(K) | ||||
|         H = SHA.new(str(hm)).digest() | ||||
|         self.transport.set_K_H(K, H) | ||||
|         self.transport._set_K_H(K, H) | ||||
|         # sign it | ||||
|         sig = self.transport.get_server_key().sign_ssh_data(self.transport.randpool, H) | ||||
|         # send reply | ||||
|  | @ -145,11 +146,10 @@ class KexGex(object): | |||
|         m.add_string(key) | ||||
|         m.add_mpint(self.f) | ||||
|         m.add_string(sig) | ||||
|         self.transport.send_message(m) | ||||
|         self.transport.activate_outbound() | ||||
|         self.transport.expected_packet = MSG_NEWKEYS | ||||
|         self.transport._send_message(m) | ||||
|         self.transport._activate_outbound() | ||||
|          | ||||
|     def parse_kexdh_gex_reply(self, m): | ||||
|     def _parse_kexdh_gex_reply(self, m): | ||||
|         host_key = m.get_string() | ||||
|         self.f = m.get_mpint() | ||||
|         sig = m.get_string() | ||||
|  | @ -165,9 +165,9 @@ class KexGex(object): | |||
|         hm.add_mpint(self.p) | ||||
|         hm.add_mpint(self.g) | ||||
|         hm.add(self.e).add(self.f).add(K) | ||||
|         self.transport.set_K_H(K, SHA.new(str(hm)).digest()) | ||||
|         self.transport.verify_key(host_key, sig) | ||||
|         self.transport.activate_outbound() | ||||
|         self.transport.expected_packet = MSG_NEWKEYS | ||||
|         self.transport._set_K_H(K, SHA.new(str(hm)).digest()) | ||||
|         self.transport._verify_key(host_key, sig) | ||||
|         self.transport._activate_outbound() | ||||
| 
 | ||||
| 
 | ||||
|      | ||||
|  |  | |||
|  | @ -44,15 +44,15 @@ class KexGroup1(object): | |||
|         if self.transport.server_mode: | ||||
|             # compute f = g^x mod p, but don't send it yet | ||||
|             self.f = pow(G, self.x, P) | ||||
|             self.transport.expected_packet = MSG_KEXDH_INIT | ||||
|             self.transport._expect_packet(MSG_KEXDH_INIT) | ||||
|             return | ||||
|         # compute e = g^x mod p (where g=2), and send it | ||||
|         self.e = pow(G, self.x, P) | ||||
|         m = Message() | ||||
|         m.add_byte(chr(MSG_KEXDH_INIT)) | ||||
|         m.add_mpint(self.e) | ||||
|         self.transport.send_message(m) | ||||
|         self.transport.expected_packet = MSG_KEXDH_REPLY | ||||
|         self.transport._send_message(m) | ||||
|         self.transport._expect_packet(MSG_KEXDH_REPLY) | ||||
| 
 | ||||
|     def parse_next(self, ptype, m): | ||||
|         if self.transport.server_mode and (ptype == MSG_KEXDH_INIT): | ||||
|  | @ -73,10 +73,9 @@ class KexGroup1(object): | |||
|         hm = Message().add(self.transport.local_version).add(self.transport.remote_version) | ||||
|         hm.add(self.transport.local_kex_init).add(self.transport.remote_kex_init).add(host_key) | ||||
|         hm.add(self.e).add(self.f).add(K) | ||||
|         self.transport.set_K_H(K, SHA.new(str(hm)).digest()) | ||||
|         self.transport.verify_key(host_key, sig) | ||||
|         self.transport.activate_outbound() | ||||
|         self.transport.expected_packet = MSG_NEWKEYS | ||||
|         self.transport._set_K_H(K, SHA.new(str(hm)).digest()) | ||||
|         self.transport._verify_key(host_key, sig) | ||||
|         self.transport._activate_outbound() | ||||
| 
 | ||||
|     def parse_kexdh_init(self, m): | ||||
|         # server mode | ||||
|  | @ -90,7 +89,7 @@ class KexGroup1(object): | |||
|         hm.add(self.transport.remote_kex_init).add(self.transport.local_kex_init).add(key) | ||||
|         hm.add(self.e).add(self.f).add(K) | ||||
|         H = SHA.new(str(hm)).digest() | ||||
|         self.transport.set_K_H(K, H) | ||||
|         self.transport._set_K_H(K, H) | ||||
|         # sign it | ||||
|         sig = self.transport.get_server_key().sign_ssh_data(self.transport.randpool, H) | ||||
|         # send reply | ||||
|  | @ -99,6 +98,5 @@ class KexGroup1(object): | |||
|         m.add_string(key) | ||||
|         m.add_mpint(self.f) | ||||
|         m.add_string(sig) | ||||
|         self.transport.send_message(m) | ||||
|         self.transport.activate_outbound() | ||||
|         self.transport.expected_packet = MSG_NEWKEYS | ||||
|         self.transport._send_message(m) | ||||
|         self.transport._activate_outbound() | ||||
|  |  | |||
|  | @ -0,0 +1,128 @@ | |||
| 
 | ||||
| # utility functions for dealing with primes | ||||
| 
 | ||||
| from Crypto.Util import number | ||||
| from util import bit_length, inflate_long | ||||
| 
 | ||||
| 
 | ||||
| def generate_prime(bits, randpool): | ||||
|     hbyte_mask = pow(2, bits % 8) - 1 | ||||
|     while 1: | ||||
|         # loop catches the case where we increment n into a higher bit-range | ||||
|         x = randpool.get_bytes((bits+7) // 8) | ||||
|         if hbyte_mask > 0: | ||||
|             x = chr(ord(x[0]) & hbyte_mask) + x[1:] | ||||
|         n = inflate_long(x, 1) | ||||
|         n |= 1 | ||||
|         n |= (1 << (bits - 1)) | ||||
|         while not number.isPrime(n): | ||||
|             n += 2 | ||||
|         if bit_length(n) == bits: | ||||
|             return n | ||||
| 
 | ||||
| def roll_random(randpool, n): | ||||
|     "returns a random # from 0 to N-1" | ||||
|     bits = bit_length(n-1) | ||||
|     bytes = (bits + 7) // 8 | ||||
|     hbyte_mask = pow(2, bits % 8) - 1 | ||||
| 
 | ||||
|     # so here's the plan: | ||||
|     # we fetch as many random bits as we'd need to fit N-1, and if the | ||||
|     # generated number is >= N, we try again.  in the worst case (N-1 is a | ||||
|     # power of 2), we have slightly better than 50% odds of getting one that | ||||
|     # fits, so i can't guarantee that this loop will ever finish, but the odds | ||||
|     # of it looping forever should be infinitesimal. | ||||
|     while 1: | ||||
|         x = randpool.get_bytes(bytes) | ||||
|         if hbyte_mask > 0: | ||||
|             x = chr(ord(x[0]) & hbyte_mask) + x[1:] | ||||
|         num = inflate_long(x, 1) | ||||
|         if num < n: | ||||
|             return num | ||||
| 
 | ||||
| 
 | ||||
| class ModulusPack (object): | ||||
|     """ | ||||
|     convenience object for holding the contents of the /etc/ssh/moduli file, | ||||
|     on systems that have such a file. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, randpool): | ||||
|         # pack is a hash of: bits -> [ (generator, modulus) ... ] | ||||
|         self.pack = {} | ||||
|         self.discarded = [] | ||||
|         self.randpool = randpool | ||||
| 
 | ||||
|     def _parse_modulus(self, line): | ||||
|         timestamp, type, tests, tries, size, generator, modulus = line.split() | ||||
|         type = int(type) | ||||
|         tests = int(tests) | ||||
|         tries = int(tries) | ||||
|         size = int(size) | ||||
|         generator = int(generator) | ||||
|         modulus = long(modulus, 16) | ||||
| 
 | ||||
|         # weed out primes that aren't at least: | ||||
|         # type 2 (meets basic structural requirements) | ||||
|         # test 4 (more than just a small-prime sieve) | ||||
|         # tries < 100 if test & 4 (at least 100 tries of miller-rabin) | ||||
|         if (type < 2) or (tests < 4) or ((tests & 4) and (tests < 8) and (tries < 100)): | ||||
|             self.discarded.append((modulus, 'does not meet basic requirements')) | ||||
|             return | ||||
|         if generator == 0: | ||||
|             generator = 2 | ||||
| 
 | ||||
|         # there's a bug in the ssh "moduli" file (yeah, i know: shock! dismay! | ||||
|         # call cnn!) where it understates the bit lengths of these primes by 1. | ||||
|         # this is okay. | ||||
|         bl = bit_length(modulus) | ||||
|         if (bl != size) and (bl != size + 1): | ||||
|             self.discarded.append((modulus, 'incorrectly reported bit length %d' % size)) | ||||
|             return | ||||
|         if not self.pack.has_key(bl): | ||||
|             self.pack[bl] = [] | ||||
|         self.pack[bl].append((generator, modulus)) | ||||
| 
 | ||||
|     def read_file(self, filename): | ||||
|         """ | ||||
|         @raise IOError: passed from any file operations that fail. | ||||
|         """ | ||||
|         self.pack = {} | ||||
|         f = open(filename, 'r') | ||||
|         for line in f: | ||||
|             line = line.strip() | ||||
|             if (len(line) == 0) or (line[0] == '#'): | ||||
|                 continue | ||||
|             try: | ||||
|                 self._parse_modulus(line) | ||||
|             except: | ||||
|                 continue | ||||
|         f.close() | ||||
| 
 | ||||
|     def get_modulus(self, min, prefer, max): | ||||
|         bitsizes = self.pack.keys() | ||||
|         bitsizes.sort() | ||||
|         if len(bitsizes) == 0: | ||||
|             raise SSHException('no moduli available') | ||||
|         good = -1 | ||||
|         # find nearest bitsize >= preferred | ||||
|         for b in bitsizes: | ||||
|             if (b >= prefer) and (b < max) and ((b < good) or (good == -1)): | ||||
|                 good = b | ||||
|         # if that failed, find greatest bitsize >= min | ||||
|         if good == -1: | ||||
|             for b in bitsizes: | ||||
|                 if (b >= min) and (b < max) and  (b > good): | ||||
|                     good = b | ||||
|         if good == -1: | ||||
|             # their entire (min, max) range has no intersection with our range. | ||||
|             # if their range is below ours, pick the smallest.  otherwise pick | ||||
|             # the largest.  it'll be out of their range requirement either way, | ||||
|             # but we'll be sending them the closest one we have. | ||||
|             good = bitsizes[0] | ||||
|             if min > good: | ||||
|                 good = bitsizes[-1] | ||||
|         # now pick a random modulus of this bitsize | ||||
|         n = roll_random(self.randpool, len(self.pack[good])) | ||||
|         return self.pack[good][n] | ||||
| 
 | ||||
|  | @ -12,6 +12,7 @@ MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_SUCCESS, MSG_CHANNEL_OPEN_FAILURE, \ | |||
| 	MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101) | ||||
| 
 | ||||
| import sys, os, string, threading, socket, logging, struct | ||||
| from ssh_exception import SSHException | ||||
| from message import Message | ||||
| from channel import Channel | ||||
| from util import format_binary, safe_string, inflate_long, deflate_long, tb_strings | ||||
|  | @ -19,6 +20,7 @@ from rsakey import RSAKey | |||
| from dsskey import DSSKey | ||||
| from kex_group1 import KexGroup1 | ||||
| from kex_gex import KexGex | ||||
| from primes import ModulusPack | ||||
| 
 | ||||
| # these come from PyCrypt | ||||
| #     http://www.amk.ca/python/writing/pycrypt/ | ||||
|  | @ -81,21 +83,21 @@ class BaseTransport(threading.Thread): | |||
|     preferred_keys = [ 'ssh-rsa', 'ssh-dss' ] | ||||
|     preferred_kex = [ 'diffie-hellman-group1-sha1', 'diffie-hellman-group-exchange-sha1' ] | ||||
| 
 | ||||
|     cipher_info = { | ||||
|     _cipher_info = { | ||||
|         'blowfish-cbc': { 'class': Blowfish, 'mode': Blowfish.MODE_CBC, 'block-size': 8, 'key-size': 16 }, | ||||
|         'aes128-cbc': { 'class': AES, 'mode': AES.MODE_CBC, 'block-size': 16, 'key-size': 16 }, | ||||
|         'aes256-cbc': { 'class': AES, 'mode': AES.MODE_CBC, 'block-size': 16, 'key-size': 32 }, | ||||
|         '3des-cbc': { 'class': DES3, 'mode': DES3.MODE_CBC, 'block-size': 8, 'key-size': 24 }, | ||||
|         } | ||||
| 
 | ||||
|     mac_info = { | ||||
|     _mac_info = { | ||||
|         'hmac-sha1': { 'class': SHA, 'size': 20 }, | ||||
|         'hmac-sha1-96': { 'class': SHA, 'size': 12 }, | ||||
|         'hmac-md5': { 'class': MD5, 'size': 16 }, | ||||
|         'hmac-md5-96': { 'class': MD5, 'size': 12 }, | ||||
|         } | ||||
| 
 | ||||
|     kex_info = { | ||||
|     _kex_info = { | ||||
|         'diffie-hellman-group1-sha1': KexGroup1, | ||||
|         'diffie-hellman-group-exchange-sha1': KexGex, | ||||
|         } | ||||
|  | @ -107,7 +109,7 @@ class BaseTransport(threading.Thread): | |||
| 	OPEN_FAILED_RESOURCE_SHORTAGE = range(1, 5) | ||||
| 
 | ||||
|     def __init__(self, sock): | ||||
|         threading.Thread.__init__(self) | ||||
|         threading.Thread.__init__(self, target=self._run) | ||||
|         self.randpool = randpool | ||||
|         self.sock = sock | ||||
|         self.sock.settimeout(0.1) | ||||
|  | @ -123,11 +125,10 @@ class BaseTransport(threading.Thread): | |||
|         self.session_id = None | ||||
|         # /negotiated crypto parameters | ||||
|         self.expected_packet = 0 | ||||
|         self.active = 0 | ||||
|         self.active = False | ||||
|         self.initial_kex_done = 0 | ||||
|         self.write_lock = threading.Lock()	# lock around outbound writes (packet computation) | ||||
|         self.lock = threading.Lock()		# synchronization (always higher level than write_lock) | ||||
|         self.authenticated = 0 | ||||
|         self.channels = { }			# (id -> Channel) | ||||
|         self.channel_events = { }		# (id -> Event) | ||||
|         self.channel_counter = 1 | ||||
|  | @ -135,6 +136,7 @@ class BaseTransport(threading.Thread): | |||
|         self.window_size = 65536 | ||||
|         self.max_packet_size = 2048 | ||||
|         self.ultra_debug = 0 | ||||
|         self.modulus_pack = None | ||||
|         # used for noticing when to re-key: | ||||
|         self.received_bytes = 0 | ||||
|         self.received_packets = 0 | ||||
|  | @ -165,27 +167,69 @@ class BaseTransport(threading.Thread): | |||
|         except KeyError: | ||||
|             return None | ||||
| 
 | ||||
|     def load_server_moduli(self, filename=None): | ||||
|         """ | ||||
|         I{(optional)} | ||||
|         Load a file of prime moduli for use in doing group-exchange key | ||||
|         negotiation in server mode.  It's a rather obscure option and can be | ||||
|         safely ignored. | ||||
| 
 | ||||
|         In server mode, the remote client may request "group-exchange" key | ||||
|         negotiation, which asks the server to send a random prime number that | ||||
|         fits certain criteria.  These primes are pretty difficult to compute, | ||||
|         so they can't be generated on demand.  But many systems contain a file | ||||
|         of suitable primes (usually named something like C{/etc/ssh/moduli}). | ||||
|         If you call C{load_server_moduli} and it returns C{True}, then this | ||||
|         file of primes has been loaded and we will support "group-exchange" in | ||||
|         server mode.  Otherwise server mode will just claim that it doesn't | ||||
|         support that method of key negotiation. | ||||
| 
 | ||||
|         @param filename: optional path to the moduli file, if you happen to | ||||
|         know that it's not in a standard location. | ||||
|         @type filename: string | ||||
|         @return: True if a moduli file was successfully loaded; False | ||||
|         otherwise. | ||||
|         @rtype: boolean | ||||
| 
 | ||||
|         @since: doduo | ||||
|          | ||||
|         @note: This has no effect when used in client mode. | ||||
|         """ | ||||
|         self.modulus_pack = ModulusPack(self.randpool) | ||||
|         # places to look for the openssh "moduli" file | ||||
|         file_list = [ '/etc/ssh/moduli', '/usr/local/etc/moduli' ] | ||||
|         if filename is not None: | ||||
|             file_list.insert(0, filename) | ||||
|         for fn in file_list: | ||||
|             try: | ||||
|                 self.modulus_pack.read_file(fn) | ||||
|                 return True | ||||
|             except IOError: | ||||
|                 pass | ||||
|         # none succeeded | ||||
|         self.modulus_pack = None | ||||
|         return False | ||||
| 
 | ||||
|     def _get_modulus_pack(self): | ||||
|         "used by KexGex to find primes for group exchange" | ||||
|         return self.modulus_pack | ||||
| 
 | ||||
|     def __repr__(self): | ||||
|         if not self.active: | ||||
|             return '<paramiko.Transport (unconnected)>' | ||||
|         out = '<sesch.Transport' | ||||
|             return '<paramiko.BaseTransport (unconnected)>' | ||||
|         out = '<paramiko.BaseTransport' | ||||
|         #if self.remote_version != '': | ||||
|         #    out += ' (server version "%s")' % self.remote_version | ||||
|         if self.local_cipher != '': | ||||
|             out += ' (cipher %s)' % self.local_cipher | ||||
|         if self.authenticated: | ||||
|             if len(self.channels) == 1: | ||||
|                 out += ' (active; 1 open channel)' | ||||
|             else: | ||||
|                 out += ' (active; %d open channels)' % len(self.channels) | ||||
|         elif self.initial_kex_done: | ||||
|             out += ' (connected; awaiting auth)' | ||||
|         if len(self.channels) == 1: | ||||
|             out += ' (active; 1 open channel)' | ||||
|         else: | ||||
|             out += ' (connecting)' | ||||
|             out += ' (active; %d open channels)' % len(self.channels) | ||||
|         out += '>' | ||||
|         return out | ||||
| 
 | ||||
|     def log(self, level, msg): | ||||
|     def _log(self, level, msg): | ||||
|         if type(msg) == type([]): | ||||
|             for m in msg: | ||||
|                 self.logger.log(level, m) | ||||
|  | @ -193,14 +237,29 @@ class BaseTransport(threading.Thread): | |||
|             self.logger.log(level, msg) | ||||
| 
 | ||||
|     def close(self): | ||||
|         self.active = 0 | ||||
|         """ | ||||
|         Close this session, and any open channels that are tied to it. | ||||
|         """ | ||||
|         self.active = False | ||||
|         self.engine_in = self.engine_out = None | ||||
|         self.sequence_number_in = self.sequence_number_out = 0L | ||||
|         for chan in self.channels.values(): | ||||
|             chan.unlink() | ||||
|             chan._unlink() | ||||
| 
 | ||||
|     def get_remote_server_key(self): | ||||
|         'returns (type, key) where type is like "ssh-rsa" and key is an opaque string' | ||||
|         """ | ||||
|         Return the host key of the server (in client mode). | ||||
|         The type string is usually either C{"ssh-rsa"} or C{"ssh-dss"} and the | ||||
|         key is an opaque string, which may be saved or used for comparison with | ||||
|         previously-seen keys.  (In other words, you don't need to worry about | ||||
|         the content of the key, only that it compares equal to the key you | ||||
|         expected to see.) | ||||
| 
 | ||||
|         @raise SSHException: if no session is currently active. | ||||
|          | ||||
|         @return: tuple of (key type, key) | ||||
|         @rtype: (string, string) | ||||
|         """ | ||||
|         if (not self.active) or (not self.initial_kex_done): | ||||
|             raise SSHException('No existing session') | ||||
|         key_msg = Message(self.host_key) | ||||
|  | @ -208,10 +267,13 @@ class BaseTransport(threading.Thread): | |||
|         return key_type, self.host_key | ||||
| 
 | ||||
|     def is_active(self): | ||||
|         return self.active | ||||
|         """ | ||||
|         Return true if this session is active (open). | ||||
| 
 | ||||
|     def is_authenticated(self): | ||||
|         return self.authenticated and self.active | ||||
|         @return: True if the session is still active (open); False if the session is closed. | ||||
|         @rtype: boolean | ||||
|         """ | ||||
|         return self.active | ||||
| 
 | ||||
|     def open_session(self): | ||||
|         return self.open_channel('session') | ||||
|  | @ -230,9 +292,9 @@ class BaseTransport(threading.Thread): | |||
|             m.add_int(self.max_packet_size) | ||||
|             self.channels[chanid] = chan = Channel(chanid) | ||||
|             self.channel_events[chanid] = event = threading.Event() | ||||
|             chan.set_transport(self) | ||||
|             chan._set_transport(self) | ||||
|             chan.set_window(self.window_size, self.max_packet_size) | ||||
|             self.send_message(m) | ||||
|             self._send_message(m) | ||||
|         finally: | ||||
|             self.lock.release() | ||||
|         while 1: | ||||
|  | @ -249,7 +311,8 @@ class BaseTransport(threading.Thread): | |||
|             self.lock.release() | ||||
|         return chan | ||||
| 
 | ||||
|     def unlink_channel(self, chanid): | ||||
|     def _unlink_channel(self, chanid): | ||||
|         "used by a Channel to remove itself from the active channel list" | ||||
|         try: | ||||
|             self.lock.acquire() | ||||
|             if self.channels.has_key(chanid): | ||||
|  | @ -257,7 +320,7 @@ class BaseTransport(threading.Thread): | |||
|         finally: | ||||
|             self.lock.release() | ||||
| 
 | ||||
|     def read_all(self, n): | ||||
|     def _read_all(self, n): | ||||
|         out = '' | ||||
|         while n > 0: | ||||
|             try: | ||||
|  | @ -271,7 +334,7 @@ class BaseTransport(threading.Thread): | |||
|                     raise EOFError() | ||||
|         return out | ||||
| 
 | ||||
|     def write_all(self, out): | ||||
|     def _write_all(self, out): | ||||
|         while len(out) > 0: | ||||
|             n = self.sock.send(out) | ||||
|             if n <= 0: | ||||
|  | @ -281,7 +344,7 @@ class BaseTransport(threading.Thread): | |||
|             out = out[n:] | ||||
|         return | ||||
| 
 | ||||
|     def build_packet(self, payload): | ||||
|     def _build_packet(self, payload): | ||||
|         # pad up at least 4 bytes, to nearest block-size (usually 8) | ||||
|         bsize = self.block_size_out | ||||
|         padding = 3 + bsize - ((len(payload) + 8) % bsize) | ||||
|  | @ -291,11 +354,12 @@ class BaseTransport(threading.Thread): | |||
|         packet += randpool.get_bytes(padding) | ||||
|         return packet | ||||
| 
 | ||||
|     def send_message(self, data): | ||||
|     def _send_message(self, data): | ||||
|         # FIXME: should we check for rekeying here too? | ||||
|         # encrypt this sucka | ||||
|         packet = self.build_packet(str(data)) | ||||
|         packet = self._build_packet(str(data)) | ||||
|         if self.ultra_debug: | ||||
|             self.log(DEBUG, format_binary(packet, 'OUT: ')) | ||||
|             self._log(DEBUG, format_binary(packet, 'OUT: ')) | ||||
|         if self.engine_out != None: | ||||
|             out = self.engine_out.encrypt(packet) | ||||
|         else: | ||||
|  | @ -308,29 +372,29 @@ class BaseTransport(threading.Thread): | |||
|                 out += HMAC.HMAC(self.mac_key_out, payload, self.local_mac_engine).digest()[:self.local_mac_len] | ||||
|             self.sequence_number_out += 1L | ||||
|             self.sequence_number_out %= 0x100000000L | ||||
|             self.write_all(out) | ||||
|             self._write_all(out) | ||||
|         finally: | ||||
|             self.write_lock.release() | ||||
| 
 | ||||
|     def read_message(self): | ||||
|     def _read_message(self): | ||||
|         "only one thread will ever be in this function" | ||||
|         header = self.read_all(self.block_size_in) | ||||
|         header = self._read_all(self.block_size_in) | ||||
|         if self.engine_in != None: | ||||
|             header = self.engine_in.decrypt(header) | ||||
|         if self.ultra_debug: | ||||
|             self.log(DEBUG, format_binary(header, 'IN: ')); | ||||
|             self._log(DEBUG, format_binary(header, 'IN: ')); | ||||
|         packet_size = struct.unpack('>I', header[:4])[0] | ||||
|         # leftover contains decrypted bytes from the first block (after the length field) | ||||
|         leftover = header[4:] | ||||
|         if (packet_size - len(leftover)) % self.block_size_in != 0: | ||||
|             raise SSHException('Invalid packet blocking') | ||||
|         buffer = self.read_all(packet_size + self.remote_mac_len - len(leftover)) | ||||
|         buffer = self._read_all(packet_size + self.remote_mac_len - len(leftover)) | ||||
|         packet = buffer[:packet_size - len(leftover)] | ||||
|         post_packet = buffer[packet_size - len(leftover):] | ||||
|         if self.engine_in != None: | ||||
|             packet = self.engine_in.decrypt(packet) | ||||
|         if self.ultra_debug: | ||||
|             self.log(DEBUG, format_binary(packet, 'IN: ')); | ||||
|             self._log(DEBUG, format_binary(packet, 'IN: ')); | ||||
|         packet = leftover + packet | ||||
|         if self.remote_mac_len > 0: | ||||
|             mac = post_packet[:self.remote_mac_len] | ||||
|  | @ -341,7 +405,7 @@ class BaseTransport(threading.Thread): | |||
|         padding = ord(packet[0]) | ||||
|         payload = packet[1:packet_size - padding + 1] | ||||
|         randpool.add_event(packet[packet_size - padding + 1]) | ||||
|         #self.log(DEBUG, 'Got payload (%d bytes, %d padding)' % (packet_size, padding)) | ||||
|         #self._log(DEBUG, 'Got payload (%d bytes, %d padding)' % (packet_size, padding)) | ||||
|         msg = Message(payload[1:]) | ||||
|         msg.seqno = self.sequence_number_in | ||||
|         self.sequence_number_in = (self.sequence_number_in + 1) & 0xffffffffL | ||||
|  | @ -351,10 +415,10 @@ class BaseTransport(threading.Thread): | |||
|         if (self.received_packets >= self.REKEY_PACKETS) or (self.received_bytes >= self.REKEY_BYTES): | ||||
|             # only ask once for rekeying | ||||
|             if self.local_kex_init is None: | ||||
|                 self.log(DEBUG, 'Rekeying (hit %d packets, %d bytes)' % (self.received_packets, | ||||
|                                                                          self.received_bytes)) | ||||
|                 self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes)' % (self.received_packets, | ||||
|                                                                           self.received_bytes)) | ||||
|                 self.received_packets_overflow = 0 | ||||
|                 self.send_kex_init() | ||||
|                 self._send_kex_init() | ||||
|             else: | ||||
|                 # we've asked to rekey already -- give them 20 packets to | ||||
|                 # comply, then just drop the connection | ||||
|  | @ -364,14 +428,18 @@ class BaseTransport(threading.Thread): | |||
|                  | ||||
|         return ord(payload[0]), msg | ||||
| 
 | ||||
|     def set_K_H(self, k, h): | ||||
|     def _set_K_H(self, k, h): | ||||
|         "used by a kex object to set the K (root key) and H (exchange hash)" | ||||
|         self.K = k | ||||
|         self.H = h | ||||
|         if self.session_id == None: | ||||
|             self.session_id = h | ||||
| 
 | ||||
|     def verify_key(self, host_key, sig): | ||||
|     def _expect_packet(self, type): | ||||
|         "used by a kex object to register the next packet type it expects to see" | ||||
|         self.expected_packet = type | ||||
| 
 | ||||
|     def _verify_key(self, host_key, sig): | ||||
|         if self.host_key_type == 'ssh-rsa': | ||||
|             key = RSAKey(Message(host_key)) | ||||
|         elif self.host_key_type == 'ssh-dss': | ||||
|  | @ -384,7 +452,7 @@ class BaseTransport(threading.Thread): | |||
|             raise SSHException('Signature verification (%s) failed.  Boo.  Robey should debug this.' % self.host_key_type) | ||||
|         self.host_key = host_key | ||||
| 
 | ||||
|     def compute_key(self, id, nbytes): | ||||
|     def _compute_key(self, id, nbytes): | ||||
|         "id is 'A' - 'F' for the various keys used by ssh" | ||||
|         m = Message() | ||||
|         m.add_mpint(self.K) | ||||
|  | @ -402,30 +470,30 @@ class BaseTransport(threading.Thread): | |||
|             sofar += hash | ||||
|         return out[:nbytes] | ||||
| 
 | ||||
|     def get_cipher(self, name, key, iv): | ||||
|         if not self.cipher_info.has_key(name): | ||||
|     def _get_cipher(self, name, key, iv): | ||||
|         if not self._cipher_info.has_key(name): | ||||
|             raise SSHException('Unknown client cipher ' + name) | ||||
|         return self.cipher_info[name]['class'].new(key, self.cipher_info[name]['mode'], iv) | ||||
|         return self._cipher_info[name]['class'].new(key, self._cipher_info[name]['mode'], iv) | ||||
| 
 | ||||
|     def run(self): | ||||
|         self.active = 1 | ||||
|     def _run(self): | ||||
|         self.active = True | ||||
|         try: | ||||
|             # SSH-1.99-OpenSSH_2.9p2 | ||||
|             self.write_all(self.local_version + '\r\n') | ||||
|             self.check_banner() | ||||
|             self.send_kex_init() | ||||
|             self._write_all(self.local_version + '\r\n') | ||||
|             self._check_banner() | ||||
|             self._send_kex_init() | ||||
|             self.expected_packet = MSG_KEXINIT | ||||
| 
 | ||||
|             while self.active: | ||||
|                 ptype, m = self.read_message() | ||||
|                 ptype, m = self._read_message() | ||||
|                 if ptype == MSG_IGNORE: | ||||
|                     continue | ||||
|                 elif ptype == MSG_DISCONNECT: | ||||
|                     self.parse_disconnect(m) | ||||
|                     self.active = 0 | ||||
|                     self._parse_disconnect(m) | ||||
|                     self.active = False | ||||
|                     break | ||||
|                 elif ptype == MSG_DEBUG: | ||||
|                     self.parse_debug(m) | ||||
|                     self._parse_debug(m) | ||||
|                     continue | ||||
|                 if self.expected_packet != 0: | ||||
|                     if ptype != self.expected_packet: | ||||
|  | @ -435,28 +503,29 @@ class BaseTransport(threading.Thread): | |||
|                         self.kex_engine.parse_next(ptype, m) | ||||
|                         continue | ||||
| 
 | ||||
|                 if self.handler_table.has_key(ptype): | ||||
|                     self.handler_table[ptype](self, m) | ||||
|                 elif self.channel_handler_table.has_key(ptype): | ||||
|                 if self._handler_table.has_key(ptype): | ||||
|                     self._handler_table[ptype](self, m) | ||||
|                 elif self._channel_handler_table.has_key(ptype): | ||||
|                     chanid = m.get_int() | ||||
|                     if self.channels.has_key(chanid): | ||||
|                         self.channel_handler_table[ptype](self.channels[chanid], m) | ||||
|                         self._channel_handler_table[ptype](self.channels[chanid], m) | ||||
|                 else: | ||||
|                     self.log(WARNING, 'Oops, unhandled type %d' % ptype) | ||||
|                     self._log(WARNING, 'Oops, unhandled type %d' % ptype) | ||||
|                     msg = Message() | ||||
|                     msg.add_byte(chr(MSG_UNIMPLEMENTED)) | ||||
|                     msg.add_int(m.seqno) | ||||
|                     self.send_message(msg) | ||||
|                     self._send_message(msg) | ||||
|         except SSHException, e: | ||||
|             self.log(DEBUG, 'Exception: ' + str(e)) | ||||
|             self.log(DEBUG, tb_strings()) | ||||
|             self._log(DEBUG, 'Exception: ' + str(e)) | ||||
|             self._log(DEBUG, tb_strings()) | ||||
|         except EOFError, e: | ||||
|             self.log(DEBUG, 'EOF') | ||||
|             self._log(DEBUG, 'EOF') | ||||
|             self._log(DEBUG, tb_strings()) | ||||
|         except Exception, e: | ||||
|             self.log(DEBUG, 'Unknown exception: ' + str(e)) | ||||
|             self.log(DEBUG, tb_strings()) | ||||
|             self._log(DEBUG, 'Unknown exception: ' + str(e)) | ||||
|             self._log(DEBUG, tb_strings()) | ||||
|         if self.active: | ||||
|             self.active = 0 | ||||
|             self.active = False | ||||
|             if self.completion_event != None: | ||||
|                 self.completion_event.set() | ||||
|             if self.auth_event != None: | ||||
|  | @ -468,36 +537,49 @@ class BaseTransport(threading.Thread): | |||
|     ###  protocol stages | ||||
| 
 | ||||
|     def renegotiate_keys(self): | ||||
|         """ | ||||
|         Force this session to switch to new keys.  Normally this is done | ||||
|         automatically after the session hits a certain number of packets or | ||||
|         bytes sent or received, but this method gives you the option of forcing | ||||
|         new keys whenever you want.  Negotiating new keys causes a pause in | ||||
|         traffic both ways as the two sides swap keys and do computations.  This | ||||
|         method returns when the session has switched to new keys, or the | ||||
|         session has died mid-negotiation. | ||||
| 
 | ||||
|         @return: True if the renegotiation was successful, and the link is | ||||
|         using new keys; False if the session dropped during renegotiation. | ||||
|         @rtype: boolean | ||||
|         """ | ||||
|         self.completion_event = threading.Event() | ||||
|         self.send_kex_init() | ||||
|         self._send_kex_init() | ||||
|         while 1: | ||||
|             self.completion_event.wait(0.1); | ||||
|             if not self.active: | ||||
|                 return 0 | ||||
|                 return False | ||||
|             if self.completion_event.isSet(): | ||||
|                 break | ||||
|         return 1 | ||||
|         return True | ||||
| 
 | ||||
|     def negotiate_keys(self, m): | ||||
|     def _negotiate_keys(self, m): | ||||
|         # throws SSHException on anything unusual | ||||
|         if self.local_kex_init == None: | ||||
|             # remote side wants to renegotiate | ||||
|             self.send_kex_init() | ||||
|         self.parse_kex_init(m) | ||||
|             self._send_kex_init() | ||||
|         self._parse_kex_init(m) | ||||
|         self.kex_engine.start_kex() | ||||
| 
 | ||||
|     def check_banner(self): | ||||
|     def _check_banner(self): | ||||
|         # this is slow, but we only have to do it once | ||||
|         for i in range(5): | ||||
|             buffer = '' | ||||
|             while not '\n' in buffer: | ||||
|                 buffer += self.read_all(1) | ||||
|                 buffer += self._read_all(1) | ||||
|             buffer = buffer[:-1] | ||||
|             if (len(buffer) > 0) and (buffer[-1] == '\r'): | ||||
|                 buffer = buffer[:-1] | ||||
|             if buffer[:4] == 'SSH-': | ||||
|                 break | ||||
|             self.log(DEBUG, 'Banner: ' + buffer) | ||||
|             self._log(DEBUG, 'Banner: ' + buffer) | ||||
|         if buffer[:4] != 'SSH-': | ||||
|             raise SSHException('Indecipherable protocol version "' + buffer + '"') | ||||
|         # save this server version string for later | ||||
|  | @ -516,17 +598,21 @@ class BaseTransport(threading.Thread): | |||
|         client = segs[2] | ||||
|         if version != '1.99' and version != '2.0': | ||||
|             raise SSHException('Incompatible version (%s instead of 2.0)' % (version,)) | ||||
|         self.log(INFO, 'Connected (version %s, client %s)' % (version, client)) | ||||
|         self._log(INFO, 'Connected (version %s, client %s)' % (version, client)) | ||||
| 
 | ||||
|     def send_kex_init(self): | ||||
|         # send a really wimpy kex-init packet that says we're a bare-bones ssh client | ||||
|     def _send_kex_init(self): | ||||
|         """ | ||||
|         announce to the other side that we'd like to negotiate keys, and what | ||||
|         kind of key negotiation we support. | ||||
|         """ | ||||
|         if self.server_mode: | ||||
|             # FIXME: can't do group-exchange (gex) yet -- too slow | ||||
|             if 'diffie-hellman-group-exchange-sha1' in self.preferred_kex: | ||||
|             if (self.modulus_pack is None) and ('diffie-hellman-group-exchange-sha1' in self.preferred_kex): | ||||
|                 # can't do group-exchange if we don't have a pack of potential primes | ||||
|                 self.preferred_kex.remove('diffie-hellman-group-exchange-sha1') | ||||
| 
 | ||||
|         available_server_keys = filter(self.server_key_dict.keys().__contains__, | ||||
|                                        self.preferred_keys) | ||||
|             available_server_keys = filter(self.server_key_dict.keys().__contains__, | ||||
|                                            self.preferred_keys) | ||||
|         else: | ||||
|             available_server_keys = self.preferred_keys | ||||
| 
 | ||||
|         m = Message() | ||||
|         m.add_byte(chr(MSG_KEXINIT)) | ||||
|  | @ -545,9 +631,9 @@ class BaseTransport(threading.Thread): | |||
|         m.add_int(0) | ||||
|         # save a copy for later (needed to compute a hash) | ||||
|         self.local_kex_init = str(m) | ||||
|         self.send_message(m) | ||||
|         self._send_message(m) | ||||
| 
 | ||||
|     def parse_kex_init(self, m): | ||||
|     def _parse_kex_init(self, m): | ||||
|         # reset counters of when to re-key, since we are now re-keying | ||||
|         self.received_bytes = 0 | ||||
|         self.received_packets = 0 | ||||
|  | @ -580,7 +666,7 @@ class BaseTransport(threading.Thread): | |||
|             agreed_kex = filter(kex_algo_list.__contains__, self.preferred_kex) | ||||
|         if len(agreed_kex) == 0: | ||||
|             raise SSHException('Incompatible ssh peer (no acceptable kex algorithm)') | ||||
|         self.kex_engine = self.kex_info[agreed_kex[0]](self) | ||||
|         self.kex_engine = self._kex_info[agreed_kex[0]](self) | ||||
| 
 | ||||
|         if self.server_mode: | ||||
|             available_server_keys = filter(self.server_key_dict.keys().__contains__, | ||||
|  | @ -608,7 +694,7 @@ class BaseTransport(threading.Thread): | |||
|             raise SSHException('Incompatible ssh server (no acceptable ciphers)') | ||||
|         self.local_cipher = agreed_local_ciphers[0] | ||||
|         self.remote_cipher = agreed_remote_ciphers[0] | ||||
|         self.log(DEBUG, 'Ciphers agreed: local=%s, remote=%s' % (self.local_cipher, self.remote_cipher)) | ||||
|         self._log(DEBUG, 'Ciphers agreed: local=%s, remote=%s' % (self.local_cipher, self.remote_cipher)) | ||||
| 
 | ||||
|         if self.server_mode: | ||||
|             agreed_remote_macs = filter(self.preferred_macs.__contains__, client_mac_algo_list) | ||||
|  | @ -621,19 +707,19 @@ class BaseTransport(threading.Thread): | |||
|         self.local_mac = agreed_local_macs[0] | ||||
|         self.remote_mac = agreed_remote_macs[0] | ||||
| 
 | ||||
|         self.log(DEBUG, 'kex algos:' + str(kex_algo_list) + ' server key:' + str(server_key_algo_list) + \ | ||||
|                  ' client encrypt:' + str(client_encrypt_algo_list) + \ | ||||
|                  ' server encrypt:' + str(server_encrypt_algo_list) + \ | ||||
|                  ' client mac:' + str(client_mac_algo_list) + \ | ||||
|                  ' server mac:' + str(server_mac_algo_list) + \ | ||||
|                  ' client compress:' + str(client_compress_algo_list) + \ | ||||
|                  ' server compress:' + str(server_compress_algo_list) + \ | ||||
|                  ' client lang:' + str(client_lang_list) + \ | ||||
|                  ' server lang:' + str(server_lang_list) + \ | ||||
|                  ' kex follows?' + str(kex_follows)) | ||||
|         self.log(DEBUG, 'using kex %s; server key type %s; cipher: local %s, remote %s; mac: local %s, remote %s' % | ||||
|                  (agreed_kex[0], self.host_key_type, self.local_cipher, self.remote_cipher, self.local_mac, | ||||
|                   self.remote_mac)) | ||||
|         self._log(DEBUG, 'kex algos:' + str(kex_algo_list) + ' server key:' + str(server_key_algo_list) + \ | ||||
|                   ' client encrypt:' + str(client_encrypt_algo_list) + \ | ||||
|                   ' server encrypt:' + str(server_encrypt_algo_list) + \ | ||||
|                   ' client mac:' + str(client_mac_algo_list) + \ | ||||
|                   ' server mac:' + str(server_mac_algo_list) + \ | ||||
|                   ' client compress:' + str(client_compress_algo_list) + \ | ||||
|                   ' server compress:' + str(server_compress_algo_list) + \ | ||||
|                   ' client lang:' + str(client_lang_list) + \ | ||||
|                   ' server lang:' + str(server_lang_list) + \ | ||||
|                   ' kex follows?' + str(kex_follows)) | ||||
|         self._log(DEBUG, 'using kex %s; server key type %s; cipher: local %s, remote %s; mac: local %s, remote %s' % | ||||
|                   (agreed_kex[0], self.host_key_type, self.local_cipher, self.remote_cipher, self.local_mac, | ||||
|                    self.remote_mac)) | ||||
| 
 | ||||
|         # save for computing hash later... | ||||
|         # now wait!  openssh has a bug (and others might too) where there are | ||||
|  | @ -642,50 +728,52 @@ class BaseTransport(threading.Thread): | |||
|         # away those bytes because they aren't part of the hash. | ||||
|         self.remote_kex_init = chr(MSG_KEXINIT) + m.get_so_far() | ||||
| 
 | ||||
|     def activate_inbound(self): | ||||
|     def _activate_inbound(self): | ||||
|         "switch on newly negotiated encryption parameters for inbound traffic" | ||||
|         self.block_size_in = self.cipher_info[self.remote_cipher]['block-size'] | ||||
|         self.block_size_in = self._cipher_info[self.remote_cipher]['block-size'] | ||||
|         if self.server_mode: | ||||
|             IV_in = self.compute_key('A', self.block_size_in) | ||||
|             key_in = self.compute_key('C', self.cipher_info[self.remote_cipher]['key-size']) | ||||
|             IV_in = self._compute_key('A', self.block_size_in) | ||||
|             key_in = self._compute_key('C', self._cipher_info[self.remote_cipher]['key-size']) | ||||
|         else: | ||||
|             IV_in = self.compute_key('B', self.block_size_in) | ||||
|             key_in = self.compute_key('D', self.cipher_info[self.remote_cipher]['key-size']) | ||||
|         self.engine_in = self.get_cipher(self.remote_cipher, key_in, IV_in) | ||||
|         self.remote_mac_len = self.mac_info[self.remote_mac]['size'] | ||||
|         self.remote_mac_engine = self.mac_info[self.remote_mac]['class'] | ||||
|             IV_in = self._compute_key('B', self.block_size_in) | ||||
|             key_in = self._compute_key('D', self._cipher_info[self.remote_cipher]['key-size']) | ||||
|         self.engine_in = self._get_cipher(self.remote_cipher, key_in, IV_in) | ||||
|         self.remote_mac_len = self._mac_info[self.remote_mac]['size'] | ||||
|         self.remote_mac_engine = self._mac_info[self.remote_mac]['class'] | ||||
|         # initial mac keys are done in the hash's natural size (not the potentially truncated | ||||
|         # transmission size) | ||||
|         if self.server_mode: | ||||
|             self.mac_key_in = self.compute_key('E', self.remote_mac_engine.digest_size) | ||||
|             self.mac_key_in = self._compute_key('E', self.remote_mac_engine.digest_size) | ||||
|         else: | ||||
|             self.mac_key_in = self.compute_key('F', self.remote_mac_engine.digest_size) | ||||
|             self.mac_key_in = self._compute_key('F', self.remote_mac_engine.digest_size) | ||||
| 
 | ||||
|     def activate_outbound(self): | ||||
|     def _activate_outbound(self): | ||||
|         "switch on newly negotiated encryption parameters for outbound traffic" | ||||
|         m = Message() | ||||
|         m.add_byte(chr(MSG_NEWKEYS)) | ||||
|         self.send_message(m) | ||||
|         self.block_size_out = self.cipher_info[self.local_cipher]['block-size'] | ||||
|         self._send_message(m) | ||||
|         self.block_size_out = self._cipher_info[self.local_cipher]['block-size'] | ||||
|         if self.server_mode: | ||||
|             IV_out = self.compute_key('B', self.block_size_out) | ||||
|             key_out = self.compute_key('D', self.cipher_info[self.local_cipher]['key-size']) | ||||
|             IV_out = self._compute_key('B', self.block_size_out) | ||||
|             key_out = self._compute_key('D', self._cipher_info[self.local_cipher]['key-size']) | ||||
|         else: | ||||
|             IV_out = self.compute_key('A', self.block_size_out) | ||||
|             key_out = self.compute_key('C', self.cipher_info[self.local_cipher]['key-size']) | ||||
|         self.engine_out = self.get_cipher(self.local_cipher, key_out, IV_out) | ||||
|         self.local_mac_len = self.mac_info[self.local_mac]['size'] | ||||
|         self.local_mac_engine = self.mac_info[self.local_mac]['class'] | ||||
|             IV_out = self._compute_key('A', self.block_size_out) | ||||
|             key_out = self._compute_key('C', self._cipher_info[self.local_cipher]['key-size']) | ||||
|         self.engine_out = self._get_cipher(self.local_cipher, key_out, IV_out) | ||||
|         self.local_mac_len = self._mac_info[self.local_mac]['size'] | ||||
|         self.local_mac_engine = self._mac_info[self.local_mac]['class'] | ||||
|         # initial mac keys are done in the hash's natural size (not the potentially truncated | ||||
|         # transmission size) | ||||
|         if self.server_mode: | ||||
|             self.mac_key_out = self.compute_key('F', self.local_mac_engine.digest_size) | ||||
|             self.mac_key_out = self._compute_key('F', self.local_mac_engine.digest_size) | ||||
|         else: | ||||
|             self.mac_key_out = self.compute_key('E', self.local_mac_engine.digest_size) | ||||
|             self.mac_key_out = self._compute_key('E', self.local_mac_engine.digest_size) | ||||
|         # we always expect to receive NEWKEYS now | ||||
|         self.expected_packet = MSG_NEWKEYS | ||||
| 
 | ||||
|     def parse_newkeys(self, m): | ||||
|         self.log(DEBUG, 'Switch to new keys ...') | ||||
|         self.activate_inbound() | ||||
|     def _parse_newkeys(self, m): | ||||
|         self._log(DEBUG, 'Switch to new keys ...') | ||||
|         self._activate_inbound() | ||||
|         # can also free a bunch of stuff here | ||||
|         self.local_kex_init = self.remote_kex_init = None | ||||
|         self.e = self.f = self.K = self.x = None | ||||
|  | @ -697,24 +785,24 @@ class BaseTransport(threading.Thread): | |||
|             self.completion_event.set() | ||||
|         return | ||||
| 
 | ||||
|     def parse_disconnect(self, m): | ||||
|     def _parse_disconnect(self, m): | ||||
|         code = m.get_int() | ||||
|         desc = m.get_string() | ||||
|         self.log(INFO, 'Disconnect (code %d): %s' % (code, desc)) | ||||
|         self._log(INFO, 'Disconnect (code %d): %s' % (code, desc)) | ||||
| 
 | ||||
|     def parse_channel_open_success(self, m): | ||||
|     def _parse_channel_open_success(self, m): | ||||
|         chanid = m.get_int() | ||||
|         server_chanid = m.get_int() | ||||
|         server_window_size = m.get_int() | ||||
|         server_max_packet_size = m.get_int() | ||||
|         if not self.channels.has_key(chanid): | ||||
|             self.log(WARNING, 'Success for unrequested channel! [??]') | ||||
|             self._log(WARNING, 'Success for unrequested channel! [??]') | ||||
|             return | ||||
|         try: | ||||
|             self.lock.acquire() | ||||
|             chan = self.channels[chanid] | ||||
|             chan.set_remote_channel(server_chanid, server_window_size, server_max_packet_size) | ||||
|             self.log(INFO, 'Secsh channel %d opened.' % chanid) | ||||
|             self._log(INFO, 'Secsh channel %d opened.' % chanid) | ||||
|             if self.channel_events.has_key(chanid): | ||||
|                 self.channel_events[chanid].set() | ||||
|                 del self.channel_events[chanid] | ||||
|  | @ -722,7 +810,7 @@ class BaseTransport(threading.Thread): | |||
|             self.lock.release() | ||||
|         return | ||||
| 
 | ||||
|     def parse_channel_open_failure(self, m): | ||||
|     def _parse_channel_open_failure(self, m): | ||||
|         chanid = m.get_int() | ||||
|         reason = m.get_int() | ||||
|         reason_str = m.get_string() | ||||
|  | @ -731,7 +819,7 @@ class BaseTransport(threading.Thread): | |||
|             reason_text = CONNECTION_FAILED_CODE[reason] | ||||
|         else: | ||||
|             reason_text = '(unknown code)' | ||||
|         self.log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text)) | ||||
|         self._log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text)) | ||||
|         try: | ||||
|             self.lock.aquire() | ||||
|             if self.channels.has_key(chanid): | ||||
|  | @ -747,14 +835,14 @@ class BaseTransport(threading.Thread): | |||
|         "override me!  return object descended from Channel to allow, or None to reject" | ||||
|         return None | ||||
| 
 | ||||
|     def parse_channel_open(self, m): | ||||
|     def _parse_channel_open(self, m): | ||||
|         kind = m.get_string() | ||||
|         chanid = m.get_int() | ||||
|         initial_window_size = m.get_int() | ||||
|         max_packet_size = m.get_int() | ||||
|         reject = False | ||||
|         if not self.server_mode: | ||||
|             self.log(DEBUG, 'Rejecting "%s" channel request from server.' % kind) | ||||
|             self._log(DEBUG, 'Rejecting "%s" channel request from server.' % kind) | ||||
|             reject = True | ||||
|             reason = self.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED | ||||
|         else: | ||||
|  | @ -766,7 +854,7 @@ class BaseTransport(threading.Thread): | |||
|                 self.lock.release() | ||||
|             chan = self.check_channel_request(kind, my_chanid) | ||||
|             if (chan is None) or (type(chan) is int): | ||||
|                 self.log(DEBUG, 'Rejecting "%s" channel request from client.' % kind) | ||||
|                 self._log(DEBUG, 'Rejecting "%s" channel request from client.' % kind) | ||||
|                 reject = True | ||||
|                 if type(chan) is int: | ||||
|                     reason = chan | ||||
|  | @ -779,12 +867,12 @@ class BaseTransport(threading.Thread): | |||
|             msg.add_int(reason) | ||||
|             msg.add_string('') | ||||
|             msg.add_string('en') | ||||
|             self.send_message(msg) | ||||
|             self._send_message(msg) | ||||
|             return | ||||
|         try: | ||||
|             self.lock.acquire() | ||||
|             self.channels[my_chanid] = chan | ||||
|             chan.set_transport(self) | ||||
|             chan._set_transport(self) | ||||
|             chan.set_window(self.window_size, self.max_packet_size) | ||||
|             chan.set_remote_channel(chanid, initial_window_size, max_packet_size) | ||||
|         finally: | ||||
|  | @ -795,8 +883,8 @@ class BaseTransport(threading.Thread): | |||
|         m.add_int(my_chanid) | ||||
|         m.add_int(self.window_size) | ||||
|         m.add_int(self.max_packet_size) | ||||
|         self.send_message(m) | ||||
|         self.log(INFO, 'Secsh channel %d opened.' % my_chanid) | ||||
|         self._send_message(m) | ||||
|         self._log(INFO, 'Secsh channel %d opened.' % my_chanid) | ||||
|         try: | ||||
|             self.lock.acquire() | ||||
|             self.server_accepts.append(chan) | ||||
|  | @ -820,21 +908,21 @@ class BaseTransport(threading.Thread): | |||
|             self.lock.release() | ||||
|         return chan | ||||
| 
 | ||||
|     def parse_debug(self, m): | ||||
|     def _parse_debug(self, m): | ||||
|         always_display = m.get_boolean() | ||||
|         msg = m.get_string() | ||||
|         lang = m.get_string() | ||||
|         self.log(DEBUG, 'Debug msg: ' + safe_string(msg)) | ||||
|         self._log(DEBUG, 'Debug msg: ' + safe_string(msg)) | ||||
| 
 | ||||
|     handler_table = { | ||||
|         MSG_NEWKEYS: parse_newkeys, | ||||
|         MSG_CHANNEL_OPEN_SUCCESS: parse_channel_open_success, | ||||
|         MSG_CHANNEL_OPEN_FAILURE: parse_channel_open_failure, | ||||
|         MSG_CHANNEL_OPEN: parse_channel_open, | ||||
|         MSG_KEXINIT: negotiate_keys, | ||||
|     _handler_table = { | ||||
|         MSG_NEWKEYS: _parse_newkeys, | ||||
|         MSG_CHANNEL_OPEN_SUCCESS: _parse_channel_open_success, | ||||
|         MSG_CHANNEL_OPEN_FAILURE: _parse_channel_open_failure, | ||||
|         MSG_CHANNEL_OPEN: _parse_channel_open, | ||||
|         MSG_KEXINIT: _negotiate_keys, | ||||
|         } | ||||
| 
 | ||||
|     channel_handler_table = { | ||||
|     _channel_handler_table = { | ||||
|         MSG_CHANNEL_SUCCESS: Channel.request_success, | ||||
|         MSG_CHANNEL_FAILURE: Channel.request_failed, | ||||
|         MSG_CHANNEL_DATA: Channel.feed, | ||||
|  |  | |||
|  | @ -1,7 +1,6 @@ | |||
| #!/usr/bin/python | ||||
| 
 | ||||
| import sys, struct, traceback | ||||
| from Crypto.Util import number | ||||
| 
 | ||||
| def inflate_long(s, always_positive=0): | ||||
|     "turns a normalized byte string into a long-int (adapted from Crypto.Util.number)" | ||||
|  | @ -98,20 +97,5 @@ def bit_length(n): | |||
|         bitlen -= 1 | ||||
|     return bitlen | ||||
| 
 | ||||
| def generate_prime(bits, randpool): | ||||
|     hbyte_mask = pow(2, bits % 8) - 1 | ||||
|     x = randpool.get_bytes((bits+7) // 8) | ||||
|     if hbyte_mask > 0: | ||||
|         x = chr(ord(x[0]) & hbyte_mask) + x[1:] | ||||
|     n = inflate_long(x, 1) | ||||
|     n |= 1 | ||||
|     n |= (1 << (bits - 1)) | ||||
|     while 1: | ||||
|         # loop catches the case where we increment n into a higher bit-range | ||||
|         while not number.isPrime(n): | ||||
|             n += 2 | ||||
|         if bit_length(n) == bits: | ||||
|             return n | ||||
| 
 | ||||
| def tb_strings(): | ||||
|     return ''.join(traceback.format_exception(*sys.exc_info())).split('\n') | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue