From fec76c51b161dbf823e0b3d58d0fd0ab324d2443 Mon Sep 17 00:00:00 2001 From: Robey Pointer Date: Mon, 20 Nov 2006 16:11:17 -0800 Subject: [PATCH] [project @ robey@lag.net-20061121001117-8mf8zzltvfvzzrv7] add support for opening x11 channels, and a unit test --- paramiko/channel.py | 109 +++++++++++++++++++++++++++----- paramiko/server.py | 24 +++++++ paramiko/transport.py | 92 ++++++++++++++++++++------- tests/test_transport.py | 137 ++++++++++++++++------------------------ 4 files changed, 242 insertions(+), 120 deletions(-) diff --git a/paramiko/channel.py b/paramiko/channel.py index 48b4487..51f8e78 100644 --- a/paramiko/channel.py +++ b/paramiko/channel.py @@ -20,6 +20,7 @@ Abstraction for an SSH2 channel. """ +import binascii import sys import time import threading @@ -35,6 +36,10 @@ from paramiko.buffered_pipe import BufferedPipe, PipeTimeout from paramiko import pipe +# lower bound on the max packet size we'll accept from the remote host +MIN_PACKET_SIZE = 1024 + + class Channel (object): """ A secure tunnel across an SSH L{Transport}. A Channel is meant to behave @@ -50,9 +55,6 @@ class Channel (object): is exactly like a normal network socket, so it shouldn't be too surprising. """ - # lower bound on the max packet size we'll accept from the remote host - MIN_PACKET_SIZE = 1024 - def __init__(self, chanid): """ Create a new channel. The channel is not associated with any @@ -84,7 +86,7 @@ class Channel (object): self.in_window_threshold = 0 self.in_window_sofar = 0 self.status_event = threading.Event() - self.name = str(chanid) + self._name = str(chanid) self.logger = util.get_logger('paramiko.chan.' + str(chanid)) self._pipe = None self.event = threading.Event() @@ -202,7 +204,7 @@ class Channel (object): m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_int(self.remote_chanid) m.add_string('exec') - m.add_boolean(1) + m.add_boolean(True) m.add_string(command) self.event.clear() self.transport._send_user_message(m) @@ -229,7 +231,7 @@ class Channel (object): m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_int(self.remote_chanid) m.add_string('subsystem') - m.add_boolean(1) + m.add_boolean(True) m.add_string(subsystem) self.event.clear() self.transport._send_user_message(m) @@ -254,7 +256,7 @@ class Channel (object): m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_int(self.remote_chanid) m.add_string('window-change') - m.add_boolean(1) + m.add_boolean(True) m.add_int(width) m.add_int(height) m.add_int(0).add_int(0) @@ -299,10 +301,73 @@ class Channel (object): m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_int(self.remote_chanid) m.add_string('exit-status') - m.add_boolean(0) + m.add_boolean(False) m.add_int(status) self.transport._send_user_message(m) + + def request_x11(self, screen_number=0, auth_protocol=None, auth_cookie=None, + single_connection=False, handler=None): + """ + Request an x11 session on this channel. If the server allows it, + further x11 requests can be made from the server to the client, + when an x11 application is run in a shell session. + From RFC4254:: + + It is RECOMMENDED that the 'x11 authentication cookie' that is + sent be a fake, random cookie, and that the cookie be checked and + replaced by the real cookie when a connection request is received. + + If you omit the auth_cookie, a new secure random 128-bit value will be + generated, used, and returned. You will need to use this value to + verify incoming x11 requests and replace them with the actual local + x11 cookie (which requires some knoweldge of the x11 protocol). + + If a handler is passed in, the handler is called from another thread + whenever a new x11 connection arrives. The default handler queues up + incoming x11 connections, which may be retrieved using + L{Transport.accept}. The handler's calling signature is:: + + handler(channel: Channel, (address: str, port: int)) + + @param screen_number: the x11 screen number (0, 10, etc) + @type screen_number: int + @param auth_protocol: the name of the X11 authentication method used; + if none is given, C{"MIT-MAGIC-COOKIE-1"} is used + @type auth_proto: str + @param auth_cookie: hexadecimal string containing the x11 auth cookie; + if none is given, a secure random 128-bit value is generated + @type auth_cookie: str + @param single_connection: if True, only a single x11 connection will be + forwarded (by default, any number of x11 connections can arrive + over this session) + @type single_connection: bool + @param handler: an optional handler to use for incoming X11 connections + @type handler: function + @return: the auth_cookie used + """ + if self.closed or self.eof_received or self.eof_sent or not self.active: + raise SSHException('Channel is not open') + if auth_protocol is None: + auth_protocol = 'MIT-MAGIC-COOKIE-1' + if auth_cookie is None: + auth_cookie = binascii.hexlify(self.transport.randpool.get_bytes(16)) + + m = Message() + m.add_byte(chr(MSG_CHANNEL_REQUEST)) + m.add_int(self.remote_chanid) + m.add_string('x11-req') + m.add_boolean(True) + m.add_boolean(single_connection) + m.add_string(auth_protocol) + m.add_string(auth_cookie) + m.add_int(screen_number) + self.event.clear() + self.transport._send_user_message(m) + self._wait_for_event() + self.transport._set_x11_handler(handler) + return auth_cookie + def get_transport(self): """ Return the L{Transport} associated with this channel. @@ -321,8 +386,8 @@ class Channel (object): @param name: new channel name. @type name: str """ - self.name = name - self.logger = util.get_logger(self.transport.get_log_channel() + '.' + self.name) + self._name = name + self.logger = util.get_logger(self.transport.get_log_channel() + '.' + self._name) def get_name(self): """ @@ -331,7 +396,7 @@ class Channel (object): @return: the name of this channel. @rtype: str """ - return self.name + return self._name def get_id(self): """ @@ -796,7 +861,7 @@ class Channel (object): def _set_transport(self, transport): self.transport = transport - self.logger = util.get_logger(self.transport.get_log_channel() + '.' + self.name) + self.logger = util.get_logger(self.transport.get_log_channel() + '.' + self._name) def _set_window(self, window_size, max_packet_size): self.in_window_size = window_size @@ -809,7 +874,7 @@ class Channel (object): def _set_remote_channel(self, chanid, window_size, max_packet_size): self.remote_chanid = chanid self.out_window_size = window_size - self.out_max_packet_size = max(max_packet_size, self.MIN_PACKET_SIZE) + self.out_max_packet_size = max(max_packet_size, MIN_PACKET_SIZE) self.active = 1 self._log(DEBUG, 'Max packet out: %d bytes' % max_packet_size) @@ -909,6 +974,16 @@ class Channel (object): else: ok = server.check_channel_window_change_request(self, width, height, pixelwidth, pixelheight) + elif key == 'x11-req': + single_connection = m.get_boolean() + auth_proto = m.get_string() + auth_cookie = m.get_string() + screen_number = m.get_int() + if server is None: + ok = False + else: + ok = server.check_channel_x11_request(self, single_connection, + auth_proto, auth_cookie, screen_number) else: self._log(DEBUG, 'Unhandled channel request "%s"' % key) ok = False @@ -932,7 +1007,7 @@ class Channel (object): self._pipe.set_forever() finally: self.lock.release() - self._log(DEBUG, 'EOF received') + self._log(DEBUG, 'EOF received (%s)', self._name) def _handle_close(self, m): self.lock.acquire() @@ -949,8 +1024,8 @@ class Channel (object): ### internals... - def _log(self, level, msg): - self.logger.log(level, msg) + def _log(self, level, msg, *args): + self.logger.log(level, msg, *args) def _wait_for_event(self): while True: @@ -981,7 +1056,7 @@ class Channel (object): m.add_byte(chr(MSG_CHANNEL_EOF)) m.add_int(self.remote_chanid) self.eof_sent = True - self._log(DEBUG, 'EOF sent') + self._log(DEBUG, 'EOF sent (%s)', self._name) return m def _close_internal(self): diff --git a/paramiko/server.py b/paramiko/server.py index 3f08e47..31cc88c 100644 --- a/paramiko/server.py +++ b/paramiko/server.py @@ -426,6 +426,30 @@ class ServerInterface (object): @rtype: bool """ return False + + def check_channel_x11_request(self, channel, single_connection, auth_protocol, auth_cookie, screen_number): + """ + Determine if the client will be provided with an X11 session. If this + method returns C{True}, X11 applications should be routed through new + SSH channels, using L{Transport.open_x11_channel}. + + The default implementation always returns C{False}. + + @param channel: the L{Channel} the X11 request arrived on + @type channel: L{Channel} + @param single_connection: C{True} if only a single X11 channel should + be opened + @type single_connection: bool + @param auth_protocol: the protocol used for X11 authentication + @type auth_protocol: str + @param auth_cookie: the cookie used to authenticate to X11 + @type auth_cookie: str + @param screen_number: the number of the X11 screen to connect to + @type screen_number: int + @return: C{True} if the X11 session was opened; C{False} if not + @rtype: bool + """ + return False class SubsystemHandler (threading.Thread): diff --git a/paramiko/transport.py b/paramiko/transport.py index 762781b..0a1daf3 100644 --- a/paramiko/transport.py +++ b/paramiko/transport.py @@ -274,9 +274,10 @@ class Transport (threading.Thread): self.channels = weakref.WeakValueDictionary() # (id -> Channel) self.channel_events = { } # (id -> Event) self.channels_seen = { } # (id -> True) - self.channel_counter = 1 + self._channel_counter = 1 self.window_size = 65536 self.max_packet_size = 34816 + self._x11_handler = None self.saved_exception = None self.clear_to_send = threading.Event() @@ -592,6 +593,22 @@ class Transport (threading.Thread): """ return self.open_channel('session') + def open_x11_channel(self, src_addr=None): + """ + Request a new channel to the client, of type C{"x11"}. This + is just an alias for C{open_channel('x11', src_addr=src_addr)}. + + @param src_addr: the source address of the x11 server (port is the + x11 port, ie. 6010) + @type src_addr: (str, int) + @return: a new L{Channel} + @rtype: L{Channel} + + @raise SSHException: if the request is rejected or the session ends + prematurely + """ + return self.open_channel('x11', src_addr=src_addr) + def open_channel(self, kind, dest_addr=None, src_addr=None): """ Request a new channel to the server. L{Channel}s are socket-like @@ -621,11 +638,7 @@ class Transport (threading.Thread): return None self.lock.acquire() try: - chanid = self.channel_counter - while chanid in self.channels: - self.channel_counter = (self.channel_counter + 1) & 0xffffff - chanid = self.channel_counter - self.channel_counter = (self.channel_counter + 1) & 0xffffff + chanid = self._next_channel() m = Message() m.add_byte(chr(MSG_CHANNEL_OPEN)) m.add_string(kind) @@ -637,6 +650,9 @@ class Transport (threading.Thread): m.add_int(dest_addr[1]) m.add_string(src_addr[0]) m.add_int(src_addr[1]) + elif kind == 'x11': + m.add_string(src_addr[0]) + m.add_int(src_addr[1]) self.channels[chanid] = chan = Channel(chanid) self.channel_events[chanid] = event = threading.Event() self.channels_seen[chanid] = True @@ -1230,17 +1246,26 @@ class Transport (threading.Thread): ### internals... - def _log(self, level, msg): + def _log(self, level, msg, *args): if issubclass(type(msg), list): for m in msg: self.logger.log(level, m) else: - self.logger.log(level, msg) + self.logger.log(level, msg, *args) def _get_modulus_pack(self): "used by KexGex to find primes for group exchange" return self._modulus_pack + def _next_channel(self): + "you are holding the lock" + chanid = self._channel_counter + while chanid in self.channels: + self._channel_counter = (self._channel_counter + 1) & 0xffffff + chanid = self._channel_counter + self._channel_counter = (self._channel_counter + 1) & 0xffffff + return chanid + def _unlink_channel(self, chanid): "used by a Channel to remove itself from the active channel list" try: @@ -1314,6 +1339,25 @@ class Transport (threading.Thread): raise SSHException('Unknown client cipher ' + name) return self._cipher_info[name]['class'].new(key, self._cipher_info[name]['mode'], iv) + def _set_x11_handler(self, handler): + # only called if a channel has turned on x11 forwarding + if handler is None: + # by default, use the same mechanism as accept() + self._x11_handler = self._default_x11_handler + else: + self._x11_hanlder = handler + + def _default_x11_handler(self, channel, (src_addr, src_port)): + self._queue_incoming_channel(channel) + + def _queue_incoming_channel(self, channel): + self.lock.acquire() + try: + self.server_accepts.append(channel) + self.server_accept_cv.notify() + finally: + self.lock.release() + def run(self): # (use the exposed "run" method, because if we specify a thread target # of a private method, threading.Thread will keep a reference to it @@ -1710,7 +1754,7 @@ class Transport (threading.Thread): self._log(DEBUG, 'Received global request "%s"' % kind) want_reply = m.get_boolean() if not self.server_mode: - self._log(DEBUG, 'Rejecting "%s" channel request from server.' % kind) + self._log(DEBUG, 'Rejecting "%s" global request from server.' % kind) ok = False else: ok = self.server_object.check_global_request(kind, m) @@ -1784,18 +1828,23 @@ class Transport (threading.Thread): initial_window_size = m.get_int() max_packet_size = m.get_int() reject = False - if not self.server_mode: + if (kind == 'x11') and (self._x11_handler is not None): + origin_addr = m.get_string() + origin_port = m.get_int() + self._log(DEBUG, 'Incoming x11 connection from %s:%d' % (origin_addr, origin_port)) + self.lock.acquire() + try: + my_chanid = self._next_channel() + finally: + self.lock.release() + elif not self.server_mode: self._log(DEBUG, 'Rejecting "%s" channel request from server.' % kind) reject = True reason = OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED else: self.lock.acquire() try: - my_chanid = self.channel_counter - while my_chanid in self.channels: - self.channel_counter = (self.channel_counter + 1) & 0xffffff - my_chanid = self.channel_counter - self.channel_counter = (self.channel_counter + 1) & 0xffffff + my_chanid = self._next_channel() finally: self.lock.release() reason = self.server_object.check_channel_request(kind, my_chanid) @@ -1811,6 +1860,7 @@ class Transport (threading.Thread): msg.add_string('en') self._send_message(msg) return + chan = Channel(my_chanid) try: self.lock.acquire() @@ -1828,13 +1878,11 @@ class Transport (threading.Thread): m.add_int(self.window_size) m.add_int(self.max_packet_size) self._send_message(m) - self._log(INFO, 'Secsh channel %d opened.' % my_chanid) - try: - self.lock.acquire() - self.server_accepts.append(chan) - self.server_accept_cv.notify() - finally: - self.lock.release() + self._log(INFO, 'Secsh channel %d (%s) opened.', my_chanid, kind) + if kind == 'x11': + self._x11_handler(chan, (origin_addr, origin_port)) + else: + self._queue_incoming_channel(chan) def _parse_debug(self, m): always_display = m.get_boolean() diff --git a/tests/test_transport.py b/tests/test_transport.py index e3763ee..e94b862 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -101,6 +101,13 @@ class NullServer (ServerInterface): def check_global_request(self, kind, msg): self._global_request = kind return False + + def check_channel_x11_request(self, channel, single_connection, auth_protocol, auth_cookie, screen_number): + self._x11_single_connection = single_connection + self._x11_auth_protocol = auth_protocol + self._x11_auth_cookie = auth_cookie + self._x11_screen_number = screen_number + return True class TransportTest (unittest.TestCase): @@ -118,6 +125,20 @@ class TransportTest (unittest.TestCase): self.socks.close() self.sockc.close() + def setup_test_server(self): + host_key = RSAKey.from_private_key_file('tests/test_rsa.key') + public_host_key = RSAKey(data=str(host_key)) + self.ts.add_server_key(host_key) + event = threading.Event() + self.server = NullServer() + self.assert_(not event.isSet()) + self.ts.start_server(event, self.server) + self.tc.connect(hostkey=public_host_key) + self.tc.auth_password(username='slowdive', password='pygmalion') + event.wait(1.0) + self.assert_(event.isSet()) + self.assert_(self.ts.is_active()) + def test_1_security_options(self): o = self.tc.get_security_options() self.assertEquals(type(o), SecurityOptions) @@ -342,19 +363,7 @@ class TransportTest (unittest.TestCase): """ verify that exec_command() does something reasonable. """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - self.tc.auth_password(username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) + self.setup_test_server() chan = self.tc.open_session() schan = self.ts.accept(1.0) @@ -396,20 +405,7 @@ class TransportTest (unittest.TestCase): """ verify that invoke_shell() does something reasonable. """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - self.tc.auth_password(username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - + self.setup_test_server() chan = self.tc.open_session() chan.invoke_shell() schan = self.ts.accept(1.0) @@ -423,20 +419,7 @@ class TransportTest (unittest.TestCase): """ verify that ChannelException is thrown for a bad open-channel request. """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - self.tc.auth_password(username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - + self.setup_test_server() try: chan = self.tc.open_channel('bogus') self.fail('expected exception') @@ -447,19 +430,7 @@ class TransportTest (unittest.TestCase): """ verify that get_exit_status() works. """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - self.tc.auth_password(username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) + self.setup_test_server() chan = self.tc.open_session() schan = self.ts.accept(1.0) @@ -481,20 +452,7 @@ class TransportTest (unittest.TestCase): """ verify that select() on a channel works. """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.assert_(not event.isSet()) - self.ts.start_server(event, server) - self.tc.ultra_debug = True - self.tc.connect(hostkey=public_host_key) - self.tc.auth_password(username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - + self.setup_test_server() chan = self.tc.open_session() chan.invoke_shell() schan = self.ts.accept(1.0) @@ -549,20 +507,8 @@ class TransportTest (unittest.TestCase): """ verify that a transport can correctly renegotiate mid-stream. """ - host_key = RSAKey.from_private_key_file('tests/test_rsa.key') - public_host_key = RSAKey(data=str(host_key)) - self.ts.add_server_key(host_key) - event = threading.Event() - server = NullServer() - self.ts.start_server(event, server) - self.tc.connect(hostkey=public_host_key, - username='slowdive', password='pygmalion') - event.wait(1.0) - self.assert_(event.isSet()) - self.assert_(self.ts.is_active()) - + self.setup_test_server() self.tc.packetizer.REKEY_BYTES = 16384 - chan = self.tc.open_session() chan.exec_command('yes') schan = self.ts.accept(1.0) @@ -612,3 +558,32 @@ class TransportTest (unittest.TestCase): chan.close() schan.close() + + def test_I_x11(self): + """ + verify that an x11 port can be requested and opened. + """ + self.setup_test_server() + chan = self.tc.open_session() + chan.exec_command('yes') + schan = self.ts.accept(1.0) + + self.assertEquals(None, getattr(self.server, '_x11_screen_number', None)) + cookie = chan.request_x11(0, single_connection=True) + self.assertEquals(0, self.server._x11_screen_number) + self.assertEquals('MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol) + self.assertEquals(cookie, self.server._x11_auth_cookie) + self.assertEquals(True, self.server._x11_single_connection) + + x11_server = self.ts.open_x11_channel(('localhost', 6093)) + x11_client = self.tc.accept() + + x11_server.send('hello') + self.assertEquals('hello', x11_client.recv(5)) + + x11_server.close() + x11_client.close() + chan.close() + schan.close() + + \ No newline at end of file