add support for opening x11 channels, and a unit test
This commit is contained in:
Robey Pointer 2006-11-20 16:11:17 -08:00
parent 48bb10694b
commit fec76c51b1
4 changed files with 242 additions and 120 deletions

View File

@ -20,6 +20,7 @@
Abstraction for an SSH2 channel.
"""
import binascii
import sys
import time
import threading
@ -35,6 +36,10 @@ from paramiko.buffered_pipe import BufferedPipe, PipeTimeout
from paramiko import pipe
# lower bound on the max packet size we'll accept from the remote host
MIN_PACKET_SIZE = 1024
class Channel (object):
"""
A secure tunnel across an SSH L{Transport}. A Channel is meant to behave
@ -50,9 +55,6 @@ class Channel (object):
is exactly like a normal network socket, so it shouldn't be too surprising.
"""
# lower bound on the max packet size we'll accept from the remote host
MIN_PACKET_SIZE = 1024
def __init__(self, chanid):
"""
Create a new channel. The channel is not associated with any
@ -84,7 +86,7 @@ class Channel (object):
self.in_window_threshold = 0
self.in_window_sofar = 0
self.status_event = threading.Event()
self.name = str(chanid)
self._name = str(chanid)
self.logger = util.get_logger('paramiko.chan.' + str(chanid))
self._pipe = None
self.event = threading.Event()
@ -202,7 +204,7 @@ class Channel (object):
m.add_byte(chr(MSG_CHANNEL_REQUEST))
m.add_int(self.remote_chanid)
m.add_string('exec')
m.add_boolean(1)
m.add_boolean(True)
m.add_string(command)
self.event.clear()
self.transport._send_user_message(m)
@ -229,7 +231,7 @@ class Channel (object):
m.add_byte(chr(MSG_CHANNEL_REQUEST))
m.add_int(self.remote_chanid)
m.add_string('subsystem')
m.add_boolean(1)
m.add_boolean(True)
m.add_string(subsystem)
self.event.clear()
self.transport._send_user_message(m)
@ -254,7 +256,7 @@ class Channel (object):
m.add_byte(chr(MSG_CHANNEL_REQUEST))
m.add_int(self.remote_chanid)
m.add_string('window-change')
m.add_boolean(1)
m.add_boolean(True)
m.add_int(width)
m.add_int(height)
m.add_int(0).add_int(0)
@ -299,10 +301,73 @@ class Channel (object):
m.add_byte(chr(MSG_CHANNEL_REQUEST))
m.add_int(self.remote_chanid)
m.add_string('exit-status')
m.add_boolean(0)
m.add_boolean(False)
m.add_int(status)
self.transport._send_user_message(m)
def request_x11(self, screen_number=0, auth_protocol=None, auth_cookie=None,
single_connection=False, handler=None):
"""
Request an x11 session on this channel. If the server allows it,
further x11 requests can be made from the server to the client,
when an x11 application is run in a shell session.
From RFC4254::
It is RECOMMENDED that the 'x11 authentication cookie' that is
sent be a fake, random cookie, and that the cookie be checked and
replaced by the real cookie when a connection request is received.
If you omit the auth_cookie, a new secure random 128-bit value will be
generated, used, and returned. You will need to use this value to
verify incoming x11 requests and replace them with the actual local
x11 cookie (which requires some knoweldge of the x11 protocol).
If a handler is passed in, the handler is called from another thread
whenever a new x11 connection arrives. The default handler queues up
incoming x11 connections, which may be retrieved using
L{Transport.accept}. The handler's calling signature is::
handler(channel: Channel, (address: str, port: int))
@param screen_number: the x11 screen number (0, 10, etc)
@type screen_number: int
@param auth_protocol: the name of the X11 authentication method used;
if none is given, C{"MIT-MAGIC-COOKIE-1"} is used
@type auth_proto: str
@param auth_cookie: hexadecimal string containing the x11 auth cookie;
if none is given, a secure random 128-bit value is generated
@type auth_cookie: str
@param single_connection: if True, only a single x11 connection will be
forwarded (by default, any number of x11 connections can arrive
over this session)
@type single_connection: bool
@param handler: an optional handler to use for incoming X11 connections
@type handler: function
@return: the auth_cookie used
"""
if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open')
if auth_protocol is None:
auth_protocol = 'MIT-MAGIC-COOKIE-1'
if auth_cookie is None:
auth_cookie = binascii.hexlify(self.transport.randpool.get_bytes(16))
m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST))
m.add_int(self.remote_chanid)
m.add_string('x11-req')
m.add_boolean(True)
m.add_boolean(single_connection)
m.add_string(auth_protocol)
m.add_string(auth_cookie)
m.add_int(screen_number)
self.event.clear()
self.transport._send_user_message(m)
self._wait_for_event()
self.transport._set_x11_handler(handler)
return auth_cookie
def get_transport(self):
"""
Return the L{Transport} associated with this channel.
@ -321,8 +386,8 @@ class Channel (object):
@param name: new channel name.
@type name: str
"""
self.name = name
self.logger = util.get_logger(self.transport.get_log_channel() + '.' + self.name)
self._name = name
self.logger = util.get_logger(self.transport.get_log_channel() + '.' + self._name)
def get_name(self):
"""
@ -331,7 +396,7 @@ class Channel (object):
@return: the name of this channel.
@rtype: str
"""
return self.name
return self._name
def get_id(self):
"""
@ -796,7 +861,7 @@ class Channel (object):
def _set_transport(self, transport):
self.transport = transport
self.logger = util.get_logger(self.transport.get_log_channel() + '.' + self.name)
self.logger = util.get_logger(self.transport.get_log_channel() + '.' + self._name)
def _set_window(self, window_size, max_packet_size):
self.in_window_size = window_size
@ -809,7 +874,7 @@ class Channel (object):
def _set_remote_channel(self, chanid, window_size, max_packet_size):
self.remote_chanid = chanid
self.out_window_size = window_size
self.out_max_packet_size = max(max_packet_size, self.MIN_PACKET_SIZE)
self.out_max_packet_size = max(max_packet_size, MIN_PACKET_SIZE)
self.active = 1
self._log(DEBUG, 'Max packet out: %d bytes' % max_packet_size)
@ -909,6 +974,16 @@ class Channel (object):
else:
ok = server.check_channel_window_change_request(self, width, height, pixelwidth,
pixelheight)
elif key == 'x11-req':
single_connection = m.get_boolean()
auth_proto = m.get_string()
auth_cookie = m.get_string()
screen_number = m.get_int()
if server is None:
ok = False
else:
ok = server.check_channel_x11_request(self, single_connection,
auth_proto, auth_cookie, screen_number)
else:
self._log(DEBUG, 'Unhandled channel request "%s"' % key)
ok = False
@ -932,7 +1007,7 @@ class Channel (object):
self._pipe.set_forever()
finally:
self.lock.release()
self._log(DEBUG, 'EOF received')
self._log(DEBUG, 'EOF received (%s)', self._name)
def _handle_close(self, m):
self.lock.acquire()
@ -949,8 +1024,8 @@ class Channel (object):
### internals...
def _log(self, level, msg):
self.logger.log(level, msg)
def _log(self, level, msg, *args):
self.logger.log(level, msg, *args)
def _wait_for_event(self):
while True:
@ -981,7 +1056,7 @@ class Channel (object):
m.add_byte(chr(MSG_CHANNEL_EOF))
m.add_int(self.remote_chanid)
self.eof_sent = True
self._log(DEBUG, 'EOF sent')
self._log(DEBUG, 'EOF sent (%s)', self._name)
return m
def _close_internal(self):

View File

@ -426,6 +426,30 @@ class ServerInterface (object):
@rtype: bool
"""
return False
def check_channel_x11_request(self, channel, single_connection, auth_protocol, auth_cookie, screen_number):
"""
Determine if the client will be provided with an X11 session. If this
method returns C{True}, X11 applications should be routed through new
SSH channels, using L{Transport.open_x11_channel}.
The default implementation always returns C{False}.
@param channel: the L{Channel} the X11 request arrived on
@type channel: L{Channel}
@param single_connection: C{True} if only a single X11 channel should
be opened
@type single_connection: bool
@param auth_protocol: the protocol used for X11 authentication
@type auth_protocol: str
@param auth_cookie: the cookie used to authenticate to X11
@type auth_cookie: str
@param screen_number: the number of the X11 screen to connect to
@type screen_number: int
@return: C{True} if the X11 session was opened; C{False} if not
@rtype: bool
"""
return False
class SubsystemHandler (threading.Thread):

View File

@ -274,9 +274,10 @@ class Transport (threading.Thread):
self.channels = weakref.WeakValueDictionary() # (id -> Channel)
self.channel_events = { } # (id -> Event)
self.channels_seen = { } # (id -> True)
self.channel_counter = 1
self._channel_counter = 1
self.window_size = 65536
self.max_packet_size = 34816
self._x11_handler = None
self.saved_exception = None
self.clear_to_send = threading.Event()
@ -592,6 +593,22 @@ class Transport (threading.Thread):
"""
return self.open_channel('session')
def open_x11_channel(self, src_addr=None):
"""
Request a new channel to the client, of type C{"x11"}. This
is just an alias for C{open_channel('x11', src_addr=src_addr)}.
@param src_addr: the source address of the x11 server (port is the
x11 port, ie. 6010)
@type src_addr: (str, int)
@return: a new L{Channel}
@rtype: L{Channel}
@raise SSHException: if the request is rejected or the session ends
prematurely
"""
return self.open_channel('x11', src_addr=src_addr)
def open_channel(self, kind, dest_addr=None, src_addr=None):
"""
Request a new channel to the server. L{Channel}s are socket-like
@ -621,11 +638,7 @@ class Transport (threading.Thread):
return None
self.lock.acquire()
try:
chanid = self.channel_counter
while chanid in self.channels:
self.channel_counter = (self.channel_counter + 1) & 0xffffff
chanid = self.channel_counter
self.channel_counter = (self.channel_counter + 1) & 0xffffff
chanid = self._next_channel()
m = Message()
m.add_byte(chr(MSG_CHANNEL_OPEN))
m.add_string(kind)
@ -637,6 +650,9 @@ class Transport (threading.Thread):
m.add_int(dest_addr[1])
m.add_string(src_addr[0])
m.add_int(src_addr[1])
elif kind == 'x11':
m.add_string(src_addr[0])
m.add_int(src_addr[1])
self.channels[chanid] = chan = Channel(chanid)
self.channel_events[chanid] = event = threading.Event()
self.channels_seen[chanid] = True
@ -1230,17 +1246,26 @@ class Transport (threading.Thread):
### internals...
def _log(self, level, msg):
def _log(self, level, msg, *args):
if issubclass(type(msg), list):
for m in msg:
self.logger.log(level, m)
else:
self.logger.log(level, msg)
self.logger.log(level, msg, *args)
def _get_modulus_pack(self):
"used by KexGex to find primes for group exchange"
return self._modulus_pack
def _next_channel(self):
"you are holding the lock"
chanid = self._channel_counter
while chanid in self.channels:
self._channel_counter = (self._channel_counter + 1) & 0xffffff
chanid = self._channel_counter
self._channel_counter = (self._channel_counter + 1) & 0xffffff
return chanid
def _unlink_channel(self, chanid):
"used by a Channel to remove itself from the active channel list"
try:
@ -1314,6 +1339,25 @@ class Transport (threading.Thread):
raise SSHException('Unknown client cipher ' + name)
return self._cipher_info[name]['class'].new(key, self._cipher_info[name]['mode'], iv)
def _set_x11_handler(self, handler):
# only called if a channel has turned on x11 forwarding
if handler is None:
# by default, use the same mechanism as accept()
self._x11_handler = self._default_x11_handler
else:
self._x11_hanlder = handler
def _default_x11_handler(self, channel, (src_addr, src_port)):
self._queue_incoming_channel(channel)
def _queue_incoming_channel(self, channel):
self.lock.acquire()
try:
self.server_accepts.append(channel)
self.server_accept_cv.notify()
finally:
self.lock.release()
def run(self):
# (use the exposed "run" method, because if we specify a thread target
# of a private method, threading.Thread will keep a reference to it
@ -1710,7 +1754,7 @@ class Transport (threading.Thread):
self._log(DEBUG, 'Received global request "%s"' % kind)
want_reply = m.get_boolean()
if not self.server_mode:
self._log(DEBUG, 'Rejecting "%s" channel request from server.' % kind)
self._log(DEBUG, 'Rejecting "%s" global request from server.' % kind)
ok = False
else:
ok = self.server_object.check_global_request(kind, m)
@ -1784,18 +1828,23 @@ class Transport (threading.Thread):
initial_window_size = m.get_int()
max_packet_size = m.get_int()
reject = False
if not self.server_mode:
if (kind == 'x11') and (self._x11_handler is not None):
origin_addr = m.get_string()
origin_port = m.get_int()
self._log(DEBUG, 'Incoming x11 connection from %s:%d' % (origin_addr, origin_port))
self.lock.acquire()
try:
my_chanid = self._next_channel()
finally:
self.lock.release()
elif not self.server_mode:
self._log(DEBUG, 'Rejecting "%s" channel request from server.' % kind)
reject = True
reason = OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
else:
self.lock.acquire()
try:
my_chanid = self.channel_counter
while my_chanid in self.channels:
self.channel_counter = (self.channel_counter + 1) & 0xffffff
my_chanid = self.channel_counter
self.channel_counter = (self.channel_counter + 1) & 0xffffff
my_chanid = self._next_channel()
finally:
self.lock.release()
reason = self.server_object.check_channel_request(kind, my_chanid)
@ -1811,6 +1860,7 @@ class Transport (threading.Thread):
msg.add_string('en')
self._send_message(msg)
return
chan = Channel(my_chanid)
try:
self.lock.acquire()
@ -1828,13 +1878,11 @@ class Transport (threading.Thread):
m.add_int(self.window_size)
m.add_int(self.max_packet_size)
self._send_message(m)
self._log(INFO, 'Secsh channel %d opened.' % my_chanid)
try:
self.lock.acquire()
self.server_accepts.append(chan)
self.server_accept_cv.notify()
finally:
self.lock.release()
self._log(INFO, 'Secsh channel %d (%s) opened.', my_chanid, kind)
if kind == 'x11':
self._x11_handler(chan, (origin_addr, origin_port))
else:
self._queue_incoming_channel(chan)
def _parse_debug(self, m):
always_display = m.get_boolean()

View File

@ -101,6 +101,13 @@ class NullServer (ServerInterface):
def check_global_request(self, kind, msg):
self._global_request = kind
return False
def check_channel_x11_request(self, channel, single_connection, auth_protocol, auth_cookie, screen_number):
self._x11_single_connection = single_connection
self._x11_auth_protocol = auth_protocol
self._x11_auth_cookie = auth_cookie
self._x11_screen_number = screen_number
return True
class TransportTest (unittest.TestCase):
@ -118,6 +125,20 @@ class TransportTest (unittest.TestCase):
self.socks.close()
self.sockc.close()
def setup_test_server(self):
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
self.server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, self.server)
self.tc.connect(hostkey=public_host_key)
self.tc.auth_password(username='slowdive', password='pygmalion')
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
def test_1_security_options(self):
o = self.tc.get_security_options()
self.assertEquals(type(o), SecurityOptions)
@ -342,19 +363,7 @@ class TransportTest (unittest.TestCase):
"""
verify that exec_command() does something reasonable.
"""
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.ultra_debug = True
self.tc.connect(hostkey=public_host_key)
self.tc.auth_password(username='slowdive', password='pygmalion')
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
self.setup_test_server()
chan = self.tc.open_session()
schan = self.ts.accept(1.0)
@ -396,20 +405,7 @@ class TransportTest (unittest.TestCase):
"""
verify that invoke_shell() does something reasonable.
"""
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.ultra_debug = True
self.tc.connect(hostkey=public_host_key)
self.tc.auth_password(username='slowdive', password='pygmalion')
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
self.setup_test_server()
chan = self.tc.open_session()
chan.invoke_shell()
schan = self.ts.accept(1.0)
@ -423,20 +419,7 @@ class TransportTest (unittest.TestCase):
"""
verify that ChannelException is thrown for a bad open-channel request.
"""
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.ultra_debug = True
self.tc.connect(hostkey=public_host_key)
self.tc.auth_password(username='slowdive', password='pygmalion')
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
self.setup_test_server()
try:
chan = self.tc.open_channel('bogus')
self.fail('expected exception')
@ -447,19 +430,7 @@ class TransportTest (unittest.TestCase):
"""
verify that get_exit_status() works.
"""
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.ultra_debug = True
self.tc.connect(hostkey=public_host_key)
self.tc.auth_password(username='slowdive', password='pygmalion')
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
self.setup_test_server()
chan = self.tc.open_session()
schan = self.ts.accept(1.0)
@ -481,20 +452,7 @@ class TransportTest (unittest.TestCase):
"""
verify that select() on a channel works.
"""
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.assert_(not event.isSet())
self.ts.start_server(event, server)
self.tc.ultra_debug = True
self.tc.connect(hostkey=public_host_key)
self.tc.auth_password(username='slowdive', password='pygmalion')
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
self.setup_test_server()
chan = self.tc.open_session()
chan.invoke_shell()
schan = self.ts.accept(1.0)
@ -549,20 +507,8 @@ class TransportTest (unittest.TestCase):
"""
verify that a transport can correctly renegotiate mid-stream.
"""
host_key = RSAKey.from_private_key_file('tests/test_rsa.key')
public_host_key = RSAKey(data=str(host_key))
self.ts.add_server_key(host_key)
event = threading.Event()
server = NullServer()
self.ts.start_server(event, server)
self.tc.connect(hostkey=public_host_key,
username='slowdive', password='pygmalion')
event.wait(1.0)
self.assert_(event.isSet())
self.assert_(self.ts.is_active())
self.setup_test_server()
self.tc.packetizer.REKEY_BYTES = 16384
chan = self.tc.open_session()
chan.exec_command('yes')
schan = self.ts.accept(1.0)
@ -612,3 +558,32 @@ class TransportTest (unittest.TestCase):
chan.close()
schan.close()
def test_I_x11(self):
"""
verify that an x11 port can be requested and opened.
"""
self.setup_test_server()
chan = self.tc.open_session()
chan.exec_command('yes')
schan = self.ts.accept(1.0)
self.assertEquals(None, getattr(self.server, '_x11_screen_number', None))
cookie = chan.request_x11(0, single_connection=True)
self.assertEquals(0, self.server._x11_screen_number)
self.assertEquals('MIT-MAGIC-COOKIE-1', self.server._x11_auth_protocol)
self.assertEquals(cookie, self.server._x11_auth_cookie)
self.assertEquals(True, self.server._x11_single_connection)
x11_server = self.ts.open_x11_channel(('localhost', 6093))
x11_client = self.tc.accept()
x11_server.send('hello')
self.assertEquals('hello', x11_client.recv(5))
x11_server.close()
x11_client.close()
chan.close()
schan.close()