[project @ Arch-1:robey@lag.net--2003-public%secsh--dev--1.0--patch-5]

big chunk of work which makes server code 95% done
fixed auth check methods to return just a result (failed, succeeded,
partially succeeded) and always use get_allowed_auths to determine the
list of allowed auth methods to return.

channel's internal API changed a bit to allow for client-side vs.
server-side channels.  we now honor the "want-reply" bit from channel
requests.  in server mode (for now), we automatically allow pty-req
and shell requests without doing anything.

ChannelFile was fixed up a bit to support universal newlines.  readline
got rewritten: the old way used the "greedy" read call from ChannelFile,
which won't work if the socket doesn't have that much data buffered and
ready.  now it uses recv directly, and tracks the different newlines.

demo-server.py now answers to a single shell request (like a CLI ssh
tool will make) and does a very simple demo pretending to be a BBS.

transport: fixed a bug with parsing the remote side's banner.  channel
requests are passed to another method in server mode, to determine if
we should allow it.  new allowed channels are added to an accept queue,
and a new method 'accept' (with timeout) will block until the next
incoming channel is ready.
This commit is contained in:
Robey Pointer 2003-11-09 21:14:21 +00:00
parent 79fecc4564
commit 5a48714394
5 changed files with 227 additions and 69 deletions

View File

@ -10,11 +10,13 @@ from logging import DEBUG, INFO, WARNING, ERROR, CRITICAL
DISCONNECT_SERVICE_NOT_AVAILABLE, DISCONNECT_AUTH_CANCELLED_BY_USER, \ DISCONNECT_SERVICE_NOT_AVAILABLE, DISCONNECT_AUTH_CANCELLED_BY_USER, \
DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE = 7, 13, 14 DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE = 7, 13, 14
AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED = range(3)
class Transport(BaseTransport): class Transport(BaseTransport):
"BaseTransport with the auth framework hooked up" "BaseTransport with the auth framework hooked up"
AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED = range(3)
def __init__(self, sock): def __init__(self, sock):
BaseTransport.__init__(self, sock) BaseTransport.__init__(self, sock)
self.auth_event = None self.auth_event = None
@ -111,21 +113,21 @@ class Transport(BaseTransport):
else: else:
self.log(DEBUG, 'Service request "%s" accepted (?)' % service) self.log(DEBUG, 'Service request "%s" accepted (?)' % service)
def get_allowed_auths(self): def get_allowed_auths(self, username):
"override me!" "override me!"
return 'password' return 'password'
def check_auth_none(self, username): def check_auth_none(self, username):
"override me! return tuple of (int, string) ==> (auth status, list of acceptable auth methods)" "override me! return int ==> auth status"
return (AUTH_FAILED, self.get_allowed_auths()) return self.AUTH_FAILED
def check_auth_password(self, username, password): def check_auth_password(self, username, password):
"override me! return tuple of (int, string) ==> (auth status, list of acceptable auth methods)" "override me! return int ==> auth status"
return (AUTH_FAILED, self.get_allowed_auths()) return self.AUTH_FAILED
def check_auth_publickey(self, username, key): def check_auth_publickey(self, username, key):
"override me! return tuple of (int, string) ==> (auth status, list of acceptable auth methods)" "override me! return int ==> auth status"
return (AUTH_FAILED, self.get_allowed_auths()) return self.AUTH_FAILED
def parse_userauth_request(self, m): def parse_userauth_request(self, m):
if not self.server_mode: if not self.server_mode:
@ -142,11 +144,12 @@ class Transport(BaseTransport):
username = m.get_string() username = m.get_string()
service = m.get_string() service = m.get_string()
method = m.get_string() method = m.get_string()
self.log(DEBUG, 'Auth request (type=%s) service=%s, username=%s' % (method, service, username))
if service != 'ssh-connection': if service != 'ssh-connection':
self.disconnect_service_not_available() self.disconnect_service_not_available()
return return
if (self.auth_username is not None) and (self.auth_username != username): if (self.auth_username is not None) and (self.auth_username != username):
# trying to change username in mid-flight! self.log(DEBUG, 'Auth rejected because the client attempted to change username in mid-flight')
self.disconnect_no_more_auth() self.disconnect_no_more_auth()
return return
if method == 'none': if method == 'none':
@ -157,27 +160,27 @@ class Transport(BaseTransport):
if changereq: if changereq:
# always treated as failure, since we don't support changing passwords, but collect # always treated as failure, since we don't support changing passwords, but collect
# the list of valid auth types from the callback anyway # the list of valid auth types from the callback anyway
self.log(DEBUG, 'Auth request to change passwords (rejected)')
newpassword = m.get_string().decode('UTF-8') newpassword = m.get_string().decode('UTF-8')
result = self.check_auth_password(username, password) result = self.AUTH_FAILED
result = (AUTH_FAILED, result[1])
else: else:
result = self.check_auth_password(username, password) result = self.check_auth_password(username, password)
elif method == 'publickey': elif method == 'publickey':
# FIXME # FIXME
result = self.check_auth_none(username) result = self.check_auth_none(username)
result = (AUTH_FAILED, result[1])
else: else:
result = self.check_auth_none(username) result = self.check_auth_none(username)
result = (AUTH_FAILED, result[1])
# okay, send result # okay, send result
m = Message() m = Message()
if result[0] == AUTH_SUCCESSFUL: if result == self.AUTH_SUCCESSFUL:
m.add_byte(chr(MSG_USERAUTH_SUCCESSFUL)) self.log(DEBUG, 'Auth granted.')
m.add_byte(chr(MSG_USERAUTH_SUCCESS))
self.auth_complete = 1 self.auth_complete = 1
else: else:
self.log(DEBUG, 'Auth rejected.')
m.add_byte(chr(MSG_USERAUTH_FAILURE)) m.add_byte(chr(MSG_USERAUTH_FAILURE))
m.add_string(result[1]) m.add_string(self.get_allowed_auths(username))
if result[0] == AUTH_PARTIALLY_SUCCESSFUL: if result == self.AUTH_PARTIALLY_SUCCESSFUL:
m.add_boolean(1) m.add_boolean(1)
else: else:
m.add_boolean(0) m.add_boolean(0)

View File

@ -1,7 +1,7 @@
from message import Message from message import Message
from secsh import SSHException from secsh import SSHException
from transport import MSG_CHANNEL_REQUEST, MSG_CHANNEL_CLOSE, MSG_CHANNEL_WINDOW_ADJUST, MSG_CHANNEL_DATA, \ from transport import MSG_CHANNEL_REQUEST, MSG_CHANNEL_CLOSE, MSG_CHANNEL_WINDOW_ADJUST, MSG_CHANNEL_DATA, \
MSG_CHANNEL_EOF MSG_CHANNEL_EOF, MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE
import time, threading, logging, socket, os import time, threading, logging, socket, os
from logging import DEBUG from logging import DEBUG
@ -18,9 +18,9 @@ class Channel(object):
Abstraction for a secsh channel. Abstraction for a secsh channel.
""" """
def __init__(self, chanid, transport): def __init__(self, chanid):
self.chanid = chanid self.chanid = chanid
self.transport = transport self.transport = None
self.active = 0 self.active = 0
self.eof_received = 0 self.eof_received = 0
self.eof_sent = 0 self.eof_sent = 0
@ -50,6 +50,9 @@ class Channel(object):
out += '>' out += '>'
return out return out
def set_transport(self, transport):
self.transport = transport
def log(self, level, msg): def log(self, level, msg):
self.logger.log(level, msg) self.logger.log(level, msg)
@ -60,8 +63,8 @@ class Channel(object):
self.in_window_threshold = window_size // 10 self.in_window_threshold = window_size // 10
self.in_window_sofar = 0 self.in_window_sofar = 0
def set_server_channel(self, chanid, window_size, max_packet_size): def set_remote_channel(self, chanid, window_size, max_packet_size):
self.server_chanid = chanid self.remote_chanid = chanid
self.out_window_size = window_size self.out_window_size = window_size
self.out_max_packet_size = max_packet_size self.out_max_packet_size = max_packet_size
self.active = 1 self.active = 1
@ -99,14 +102,29 @@ class Channel(object):
def handle_request(self, m): def handle_request(self, m):
key = m.get_string() key = m.get_string()
want_reply = m.get_boolean()
ok = False
if key == 'exit-status': if key == 'exit-status':
self.exit_status = m.get_int() self.exit_status = m.get_int()
return ok = True
elif key == 'xon-xoff': elif key == 'xon-xoff':
# ignore # ignore
return ok = True
elif (key == 'pty-req') or (key == 'shell'):
if self.transport.server_mode:
# humor them
ok = True
else: else:
self.log(DEBUG, 'Unhandled channel request "%s"' % key) self.log(DEBUG, 'Unhandled channel request "%s"' % key)
ok = False
if want_reply:
m = Message()
if ok:
m.add_byte(chr(MSG_CHANNEL_SUCCESS))
else:
m.add_byte(chr(MSG_CHANNEL_FAILURE))
m.add_int(self.remote_chanid)
self.transport.send_message(m)
def handle_eof(self, m): def handle_eof(self, m):
self.eof_received = 1 self.eof_received = 1
@ -140,7 +158,7 @@ class Channel(object):
raise SSHException('Channel is not open') raise SSHException('Channel is not open')
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_byte(chr(MSG_CHANNEL_REQUEST))
m.add_int(self.server_chanid) m.add_int(self.remote_chanid)
m.add_string('pty-req') m.add_string('pty-req')
m.add_boolean(0) m.add_boolean(0)
m.add_string(term) m.add_string(term)
@ -156,7 +174,7 @@ class Channel(object):
raise SSHException('Channel is not open') raise SSHException('Channel is not open')
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_byte(chr(MSG_CHANNEL_REQUEST))
m.add_int(self.server_chanid) m.add_int(self.remote_chanid)
m.add_string('shell') m.add_string('shell')
m.add_boolean(1) m.add_boolean(1)
self.transport.send_message(m) self.transport.send_message(m)
@ -166,7 +184,7 @@ class Channel(object):
raise SSHException('Channel is not open') raise SSHException('Channel is not open')
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_byte(chr(MSG_CHANNEL_REQUEST))
m.add_int(self.server_chanid) m.add_int(self.remote_chanid)
m.add_string('exec') m.add_string('exec')
m.add_boolean(1) m.add_boolean(1)
m.add_string(command) m.add_string(command)
@ -177,7 +195,7 @@ class Channel(object):
raise SSHException('Channel is not open') raise SSHException('Channel is not open')
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_byte(chr(MSG_CHANNEL_REQUEST))
m.add_int(self.server_chanid) m.add_int(self.remote_chanid)
m.add_string('subsystem') m.add_string('subsystem')
m.add_boolean(1) m.add_boolean(1)
m.add_string(subsystem) m.add_string(subsystem)
@ -188,7 +206,7 @@ class Channel(object):
raise SSHException('Channel is not open') raise SSHException('Channel is not open')
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_byte(chr(MSG_CHANNEL_REQUEST))
m.add_int(self.server_chanid) m.add_int(self.remote_chanid)
m.add_string('window-change') m.add_string('window-change')
m.add_boolean(0) m.add_boolean(0)
m.add_int(width) m.add_int(width)
@ -211,7 +229,7 @@ class Channel(object):
return return
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_EOF)) m.add_byte(chr(MSG_CHANNEL_EOF))
m.add_int(self.server_chanid) m.add_int(self.remote_chanid)
self.transport.send_message(m) self.transport.send_message(m)
self.eof_sent = 1 self.eof_sent = 1
self.log(DEBUG, 'EOF sent') self.log(DEBUG, 'EOF sent')
@ -238,7 +256,7 @@ class Channel(object):
self.send_eof() self.send_eof()
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_CLOSE)) m.add_byte(chr(MSG_CHANNEL_CLOSE))
m.add_int(self.server_chanid) m.add_int(self.remote_chanid)
self.transport.send_message(m) self.transport.send_message(m)
self.closed = 1 self.closed = 1
self.transport.unlink_channel(self.chanid) self.transport.unlink_channel(self.chanid)
@ -316,7 +334,7 @@ class Channel(object):
size = self.out_max_packet_size size = self.out_max_packet_size
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_DATA)) m.add_byte(chr(MSG_CHANNEL_DATA))
m.add_int(self.server_chanid) m.add_int(self.remote_chanid)
m.add_string(s[:size]) m.add_string(s[:size])
self.transport.send_message(m) self.transport.send_message(m)
self.out_window_size -= size self.out_window_size -= size
@ -469,7 +487,7 @@ class Channel(object):
self.log(DEBUG, 'addwindow send %d' % self.in_window_sofar) self.log(DEBUG, 'addwindow send %d' % self.in_window_sofar)
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST)) m.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST))
m.add_int(self.server_chanid) m.add_int(self.remote_chanid)
m.add_int(self.in_window_sofar) m.add_int(self.in_window_sofar)
self.transport.send_message(m) self.transport.send_message(m)
self.in_window_sofar = 0 self.in_window_sofar = 0
@ -490,7 +508,7 @@ class ChannelFile(object):
def __init__(self, channel, mode = "r", buf_size = -1): def __init__(self, channel, mode = "r", buf_size = -1):
self.channel = channel self.channel = channel
self.mode = mode self.mode = mode
if buf_size < 0: if buf_size <= 0:
self.buf_size = 1024 self.buf_size = 1024
self.line_buffered = 0 self.line_buffered = 0
elif buf_size == 1: elif buf_size == 1:
@ -503,10 +521,12 @@ class ChannelFile(object):
self.rbuffer = "" self.rbuffer = ""
self.readable = ("r" in mode) self.readable = ("r" in mode)
self.writable = ("w" in mode) or ("+" in mode) or ("a" in mode) self.writable = ("w" in mode) or ("+" in mode) or ("a" in mode)
self.universal_newlines = ('U' in mode)
self.binary = ("b" in mode) self.binary = ("b" in mode)
if not self.binary: self.at_trailing_cr = False
raise NotImplementedError("text mode not supported") self.name = '<file from ' + repr(self.channel) + '>'
self.softspace = 0 self.newlines = None
self.softspace = False
def __iter__(self): def __iter__(self):
return self return self
@ -570,23 +590,56 @@ class ChannelFile(object):
self.rbuffer[size:] self.rbuffer[size:]
return result return result
def readline(self, size = None): def readline(self, size=None):
line = "" line = self.rbuffer
while "\n" not in line: while 1:
if size >= 0: if self.at_trailing_cr and (len(line) > 0):
new_data = self.read(size - len(line)) if line[0] == '\n':
line = line[1:]
self.at_trailing_cr = False
if self.universal_newlines:
if ('\n' in line) or ('\r' in line):
break
else: else:
new_data = self.read(64) if '\n' in line:
break
if size >= 0:
if len(line) >= size:
# truncate line and return
self.rbuffer = line[size:]
line = line[:size]
return line
n = size - len(line)
else:
n = 64
new_data = self.channel.recv(n)
if not new_data: if not new_data:
break self.rbuffer = ''
return line
line += new_data line += new_data
newline_pos = line.find("\n") # find the newline
if newline_pos >= 0: pos = line.find('\n')
self.rbuffer = line[newline_pos+1:] + self.rbuffer if self.universal_newlines:
return line[:newline_pos+1] rpos = line.find('\r')
elif len(line) > size: if (rpos >= 0) and ((rpos < pos) or (pos < 0)):
self.rbuffer = line[size:] + self.rbuffer pos = rpos
return line[:size] xpos = pos + 1
if (line[pos] == '\r') and (xpos < len(line)) and (line[xpos] == '\n'):
xpos += 1
self.rbuffer = line[xpos:]
lf = line[pos:xpos]
line = line[:xpos]
if (len(self.rbuffer) == 0) and (lf == '\r'):
# we could read the line up to a '\r' and there could still be a
# '\n' following that we read next time. note that and eat it.
self.at_trailing_cr = True
# silliness about tracking what kinds of newlines we've seen
if self.newlines is None:
self.newlines = lf
elif (type(self.newlines) is str) and (self.newlines != lf):
self.newlines = (self.newlines, lf)
elif lf not in self.newlines:
self.newlines += (lf,)
return line return line
def readlines(self, sizehint = None): def readlines(self, sizehint = None):

View File

@ -1,6 +1,6 @@
#!/usr/bin/python #!/usr/bin/python
import sys, os, socket, threading, logging, traceback import sys, os, socket, threading, logging, traceback, time
import secsh import secsh
# setup logging # setup logging
@ -15,6 +15,19 @@ if len(l.handlers) == 0:
host_key = secsh.RSAKey() host_key = secsh.RSAKey()
host_key.read_private_key_file('demo-host-key') host_key.read_private_key_file('demo-host-key')
class ServerTransport(secsh.Transport):
def check_channel_request(self, kind, chanid):
if kind == 'session':
return secsh.Channel(chanid)
return self.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
def check_auth_password(self, username, password):
if (username == 'robey') and (password == 'foo'):
return self.AUTH_SUCCESSFUL
return self.AUTH_FAILED
# now connect # now connect
try: try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@ -35,7 +48,7 @@ except Exception, e:
try: try:
event = threading.Event() event = threading.Event()
t = secsh.Transport(client) t = ServerTransport(client)
t.add_server_key(host_key) t.add_server_key(host_key)
t.ultra_debug = 1 t.ultra_debug = 1
t.start_server(event) t.start_server(event)
@ -45,6 +58,18 @@ try:
print '*** SSH negotiation failed.' print '*** SSH negotiation failed.'
sys.exit(1) sys.exit(1)
# print repr(t) # print repr(t)
chan = t.accept()
time.sleep(2)
chan.send('\r\n\r\nWelcome to my dorky little BBS!\r\n\r\n')
chan.send('We are on fire all the time! Hooray! Candy corn for everyone!\r\n')
chan.send('Happy birthday to Robot Dave!\r\n\r\n')
chan.send('Username: ')
f = chan.makefile('rU')
username = f.readline().strip('\r\n')
chan.send('\r\nI don\'t like you, ' + username + '.\r\n')
chan.close()
except Exception, e: except Exception, e:
print '*** Caught exception: ' + str(e.__class__) + ': ' + str(e) print '*** Caught exception: ' + str(e.__class__) + ': ' + str(e)
traceback.print_exc() traceback.print_exc()

View File

@ -76,7 +76,7 @@ try:
# print repr(t) # print repr(t)
keys = load_host_keys() keys = load_host_keys()
keytype, hostkey = t.get_host_key() keytype, hostkey = t.get_remote_server_key()
if not keys.has_key(hostname): if not keys.has_key(hostname):
print '*** WARNING: Unknown host key!' print '*** WARNING: Unknown host key!'
elif not keys[hostname].has_key(keytype): elif not keys[hostname].has_key(keytype):

View File

@ -11,12 +11,11 @@ MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_SUCCESS, MSG_CHANNEL_OPEN_FAILURE, \
MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST, \ MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST, \
MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101) MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101)
import sys, os, string, threading, socket, logging, struct import sys, os, string, threading, socket, logging, struct
from message import Message from message import Message
from channel import Channel from channel import Channel
from secsh import SSHException from secsh import SSHException
from util import format_binary, safe_string, inflate_long, deflate_long from util import format_binary, safe_string, inflate_long, deflate_long, tb_strings
from rsakey import RSAKey from rsakey import RSAKey
from dsskey import DSSKey from dsskey import DSSKey
from kex_group1 import KexGroup1 from kex_group1 import KexGroup1
@ -105,6 +104,9 @@ class BaseTransport(threading.Thread):
REKEY_PACKETS = pow(2, 30) REKEY_PACKETS = pow(2, 30)
REKEY_BYTES = pow(2, 30) REKEY_BYTES = pow(2, 30)
OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, OPEN_FAILED_CONNECT_FAILED, OPEN_FAILED_UNKNOWN_CHANNEL_TYPE, \
OPEN_FAILED_RESOURCE_SHORTAGE = range(1, 5)
def __init__(self, sock): def __init__(self, sock):
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.randpool = randpool self.randpool = randpool
@ -143,6 +145,8 @@ class BaseTransport(threading.Thread):
# server mode: # server mode:
self.server_mode = 0 self.server_mode = 0
self.server_key_dict = { } self.server_key_dict = { }
self.server_accepts = [ ]
self.server_accept_cv = threading.Condition(self.lock)
def start_client(self, event=None): def start_client(self, event=None):
self.completion_event = event self.completion_event = event
@ -196,7 +200,7 @@ class BaseTransport(threading.Thread):
for chan in self.channels.values(): for chan in self.channels.values():
chan.unlink() chan.unlink()
def get_host_key(self): def get_remote_server_key(self):
'returns (type, key) where type is like "ssh-rsa" and key is an opaque string' 'returns (type, key) where type is like "ssh-rsa" and key is an opaque string'
if (not self.active) or (not self.initial_kex_done): if (not self.active) or (not self.initial_kex_done):
raise SSHException('No existing session') raise SSHException('No existing session')
@ -225,8 +229,9 @@ class BaseTransport(threading.Thread):
m.add_int(chanid) m.add_int(chanid)
m.add_int(self.window_size) m.add_int(self.window_size)
m.add_int(self.max_packet_size) m.add_int(self.max_packet_size)
self.channels[chanid] = chan = Channel(chanid, self) self.channels[chanid] = chan = Channel(chanid)
self.channel_events[chanid] = event = threading.Event() self.channel_events[chanid] = event = threading.Event()
chan.set_transport(self)
chan.set_window(self.window_size, self.max_packet_size) chan.set_window(self.window_size, self.max_packet_size)
self.send_message(m) self.send_message(m)
finally: finally:
@ -445,10 +450,12 @@ class BaseTransport(threading.Thread):
self.send_message(msg) self.send_message(msg)
except SSHException, e: except SSHException, e:
self.log(DEBUG, 'Exception: ' + str(e)) self.log(DEBUG, 'Exception: ' + str(e))
self.log(DEBUG, tb_strings())
except EOFError, e: except EOFError, e:
self.log(DEBUG, 'EOF') self.log(DEBUG, 'EOF')
except Exception, e: except Exception, e:
self.log(DEBUG, 'Unknown exception: ' + str(e)) self.log(DEBUG, 'Unknown exception: ' + str(e))
self.log(DEBUG, tb_strings())
if self.active: if self.active:
self.active = 0 self.active = 0
if self.completion_event != None: if self.completion_event != None:
@ -503,7 +510,11 @@ class BaseTransport(threading.Thread):
comment = buffer[i+1:] comment = buffer[i+1:]
buffer = buffer[:i] buffer = buffer[:i]
# parse out version string and make sure it matches # parse out version string and make sure it matches
_unused, version, client = string.split(buffer, '-') segs = buffer.split('-', 2)
if len(segs) < 3:
raise SSHException('Invalid SSH banner')
version = segs[1]
client = segs[2]
if version != '1.99' and version != '2.0': if version != '1.99' and version != '2.0':
raise SSHException('Incompatible version (%s instead of 2.0)' % (version,)) raise SSHException('Incompatible version (%s instead of 2.0)' % (version,))
self.log(INFO, 'Connected (version %s, client %s)' % (version, client)) self.log(INFO, 'Connected (version %s, client %s)' % (version, client))
@ -681,6 +692,7 @@ class BaseTransport(threading.Thread):
code = m.get_int() code = m.get_int()
desc = m.get_string() desc = m.get_string()
self.log(INFO, 'Disconnect (code %d): %s' % (code, desc)) self.log(INFO, 'Disconnect (code %d): %s' % (code, desc))
def parse_channel_open_success(self, m): def parse_channel_open_success(self, m):
chanid = m.get_int() chanid = m.get_int()
server_chanid = m.get_int() server_chanid = m.get_int()
@ -692,7 +704,7 @@ class BaseTransport(threading.Thread):
try: try:
self.lock.acquire() self.lock.acquire()
chan = self.channels[chanid] chan = self.channels[chanid]
chan.set_server_channel(server_chanid, server_window_size, server_max_packet_size) chan.set_remote_channel(server_chanid, server_window_size, server_max_packet_size)
self.log(INFO, 'Secsh channel %d opened.' % chanid) self.log(INFO, 'Secsh channel %d opened.' % chanid)
if self.channel_events.has_key(chanid): if self.channel_events.has_key(chanid):
self.channel_events[chanid].set() self.channel_events[chanid].set()
@ -719,20 +731,85 @@ class BaseTransport(threading.Thread):
self.channel_events[chanid].set() self.channel_events[chanid].set()
del self.channel_events[chanid] del self.channel_events[chanid]
finally: finally:
self.lock_release() self.lock.release()
return return
def check_channel_request(self, kind, chanid):
"override me! return object descended from Channel to allow, or None to reject"
return None
def parse_channel_open(self, m): def parse_channel_open(self, m):
kind = m.get_string() kind = m.get_string()
self.log(DEBUG, 'Rejecting "%s" channel request from server.' % kind)
chanid = m.get_int() chanid = m.get_int()
msg = Message() initial_window_size = m.get_int()
msg.add_byte(chr(MSG_CHANNEL_OPEN_FAILURE)) max_packet_size = m.get_int()
msg.add_int(chanid) reject = False
msg.add_int(1) if not self.server_mode:
msg.add_string('Client connections are not allowed.') self.log(DEBUG, 'Rejecting "%s" channel request from server.' % kind)
msg.add_string('en') reject = True
self.send_message(msg) reason = self.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
else:
try:
self.lock.acquire()
my_chanid = self.channel_counter
self.channel_counter += 1
finally:
self.lock.release()
chan = self.check_channel_request(kind, my_chanid)
if (chan is None) or (type(chan) is int):
self.log(DEBUG, 'Rejecting "%s" channel request from client.' % kind)
reject = True
if type(chan) is int:
reason = chan
else:
reason = self.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
if reject:
msg = Message()
msg.add_byte(chr(MSG_CHANNEL_OPEN_FAILURE))
msg.add_int(chanid)
msg.add_int(reason)
msg.add_string('')
msg.add_string('en')
self.send_message(msg)
return
try:
self.lock.acquire()
self.channels[my_chanid] = chan
chan.set_transport(self)
chan.set_window(self.window_size, self.max_packet_size)
chan.set_remote_channel(chanid, initial_window_size, max_packet_size)
finally:
self.lock.release()
m = Message()
m.add_byte(chr(MSG_CHANNEL_OPEN_SUCCESS))
m.add_int(chanid)
m.add_int(my_chanid)
m.add_int(self.window_size)
m.add_int(self.max_packet_size)
self.send_message(m)
self.log(INFO, 'Secsh channel %d opened.' % my_chanid)
try:
self.lock.acquire()
self.server_accepts.append(chan)
self.server_accept_cv.notify()
finally:
self.lock.release()
def accept(self, timeout=None):
try:
self.lock.acquire()
if len(self.server_accepts) > 0:
chan = self.server_accepts.pop(0)
else:
self.server_accept_cv.wait(timeout)
if len(self.server_accepts) > 0:
chan = self.server_accepts.pop(0)
else:
# timeout
chan = None
finally:
self.lock.release()
return chan
def parse_debug(self, m): def parse_debug(self, m):
always_display = m.get_boolean() always_display = m.get_boolean()