diff --git a/auth_transport.py b/auth_transport.py index 1a06326..78ce8d7 100644 --- a/auth_transport.py +++ b/auth_transport.py @@ -10,11 +10,13 @@ from logging import DEBUG, INFO, WARNING, ERROR, CRITICAL DISCONNECT_SERVICE_NOT_AVAILABLE, DISCONNECT_AUTH_CANCELLED_BY_USER, \ DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE = 7, 13, 14 -AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED = range(3) class Transport(BaseTransport): "BaseTransport with the auth framework hooked up" + + AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED = range(3) + def __init__(self, sock): BaseTransport.__init__(self, sock) self.auth_event = None @@ -111,21 +113,21 @@ class Transport(BaseTransport): else: self.log(DEBUG, 'Service request "%s" accepted (?)' % service) - def get_allowed_auths(self): + def get_allowed_auths(self, username): "override me!" return 'password' def check_auth_none(self, username): - "override me! return tuple of (int, string) ==> (auth status, list of acceptable auth methods)" - return (AUTH_FAILED, self.get_allowed_auths()) + "override me! return int ==> auth status" + return self.AUTH_FAILED def check_auth_password(self, username, password): - "override me! return tuple of (int, string) ==> (auth status, list of acceptable auth methods)" - return (AUTH_FAILED, self.get_allowed_auths()) + "override me! return int ==> auth status" + return self.AUTH_FAILED def check_auth_publickey(self, username, key): - "override me! return tuple of (int, string) ==> (auth status, list of acceptable auth methods)" - return (AUTH_FAILED, self.get_allowed_auths()) + "override me! return int ==> auth status" + return self.AUTH_FAILED def parse_userauth_request(self, m): if not self.server_mode: @@ -142,11 +144,12 @@ class Transport(BaseTransport): username = m.get_string() service = m.get_string() method = m.get_string() + self.log(DEBUG, 'Auth request (type=%s) service=%s, username=%s' % (method, service, username)) if service != 'ssh-connection': self.disconnect_service_not_available() return if (self.auth_username is not None) and (self.auth_username != username): - # trying to change username in mid-flight! + self.log(DEBUG, 'Auth rejected because the client attempted to change username in mid-flight') self.disconnect_no_more_auth() return if method == 'none': @@ -157,27 +160,27 @@ class Transport(BaseTransport): if changereq: # always treated as failure, since we don't support changing passwords, but collect # the list of valid auth types from the callback anyway + self.log(DEBUG, 'Auth request to change passwords (rejected)') newpassword = m.get_string().decode('UTF-8') - result = self.check_auth_password(username, password) - result = (AUTH_FAILED, result[1]) + result = self.AUTH_FAILED else: result = self.check_auth_password(username, password) elif method == 'publickey': # FIXME result = self.check_auth_none(username) - result = (AUTH_FAILED, result[1]) else: result = self.check_auth_none(username) - result = (AUTH_FAILED, result[1]) # okay, send result m = Message() - if result[0] == AUTH_SUCCESSFUL: - m.add_byte(chr(MSG_USERAUTH_SUCCESSFUL)) + if result == self.AUTH_SUCCESSFUL: + self.log(DEBUG, 'Auth granted.') + m.add_byte(chr(MSG_USERAUTH_SUCCESS)) self.auth_complete = 1 else: + self.log(DEBUG, 'Auth rejected.') m.add_byte(chr(MSG_USERAUTH_FAILURE)) - m.add_string(result[1]) - if result[0] == AUTH_PARTIALLY_SUCCESSFUL: + m.add_string(self.get_allowed_auths(username)) + if result == self.AUTH_PARTIALLY_SUCCESSFUL: m.add_boolean(1) else: m.add_boolean(0) diff --git a/channel.py b/channel.py index 275c0a2..8f53d37 100644 --- a/channel.py +++ b/channel.py @@ -1,7 +1,7 @@ from message import Message from secsh import SSHException from transport import MSG_CHANNEL_REQUEST, MSG_CHANNEL_CLOSE, MSG_CHANNEL_WINDOW_ADJUST, MSG_CHANNEL_DATA, \ - MSG_CHANNEL_EOF + MSG_CHANNEL_EOF, MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE import time, threading, logging, socket, os from logging import DEBUG @@ -18,9 +18,9 @@ class Channel(object): Abstraction for a secsh channel. """ - def __init__(self, chanid, transport): + def __init__(self, chanid): self.chanid = chanid - self.transport = transport + self.transport = None self.active = 0 self.eof_received = 0 self.eof_sent = 0 @@ -50,6 +50,9 @@ class Channel(object): out += '>' return out + def set_transport(self, transport): + self.transport = transport + def log(self, level, msg): self.logger.log(level, msg) @@ -60,8 +63,8 @@ class Channel(object): self.in_window_threshold = window_size // 10 self.in_window_sofar = 0 - def set_server_channel(self, chanid, window_size, max_packet_size): - self.server_chanid = chanid + def set_remote_channel(self, chanid, window_size, max_packet_size): + self.remote_chanid = chanid self.out_window_size = window_size self.out_max_packet_size = max_packet_size self.active = 1 @@ -99,14 +102,29 @@ class Channel(object): def handle_request(self, m): key = m.get_string() + want_reply = m.get_boolean() + ok = False if key == 'exit-status': self.exit_status = m.get_int() - return + ok = True elif key == 'xon-xoff': # ignore - return + ok = True + elif (key == 'pty-req') or (key == 'shell'): + if self.transport.server_mode: + # humor them + ok = True else: self.log(DEBUG, 'Unhandled channel request "%s"' % key) + ok = False + if want_reply: + m = Message() + if ok: + m.add_byte(chr(MSG_CHANNEL_SUCCESS)) + else: + m.add_byte(chr(MSG_CHANNEL_FAILURE)) + m.add_int(self.remote_chanid) + self.transport.send_message(m) def handle_eof(self, m): self.eof_received = 1 @@ -140,7 +158,7 @@ class Channel(object): raise SSHException('Channel is not open') m = Message() m.add_byte(chr(MSG_CHANNEL_REQUEST)) - m.add_int(self.server_chanid) + m.add_int(self.remote_chanid) m.add_string('pty-req') m.add_boolean(0) m.add_string(term) @@ -156,7 +174,7 @@ class Channel(object): raise SSHException('Channel is not open') m = Message() m.add_byte(chr(MSG_CHANNEL_REQUEST)) - m.add_int(self.server_chanid) + m.add_int(self.remote_chanid) m.add_string('shell') m.add_boolean(1) self.transport.send_message(m) @@ -166,7 +184,7 @@ class Channel(object): raise SSHException('Channel is not open') m = Message() m.add_byte(chr(MSG_CHANNEL_REQUEST)) - m.add_int(self.server_chanid) + m.add_int(self.remote_chanid) m.add_string('exec') m.add_boolean(1) m.add_string(command) @@ -177,7 +195,7 @@ class Channel(object): raise SSHException('Channel is not open') m = Message() m.add_byte(chr(MSG_CHANNEL_REQUEST)) - m.add_int(self.server_chanid) + m.add_int(self.remote_chanid) m.add_string('subsystem') m.add_boolean(1) m.add_string(subsystem) @@ -188,7 +206,7 @@ class Channel(object): raise SSHException('Channel is not open') m = Message() m.add_byte(chr(MSG_CHANNEL_REQUEST)) - m.add_int(self.server_chanid) + m.add_int(self.remote_chanid) m.add_string('window-change') m.add_boolean(0) m.add_int(width) @@ -211,7 +229,7 @@ class Channel(object): return m = Message() m.add_byte(chr(MSG_CHANNEL_EOF)) - m.add_int(self.server_chanid) + m.add_int(self.remote_chanid) self.transport.send_message(m) self.eof_sent = 1 self.log(DEBUG, 'EOF sent') @@ -238,7 +256,7 @@ class Channel(object): self.send_eof() m = Message() m.add_byte(chr(MSG_CHANNEL_CLOSE)) - m.add_int(self.server_chanid) + m.add_int(self.remote_chanid) self.transport.send_message(m) self.closed = 1 self.transport.unlink_channel(self.chanid) @@ -316,7 +334,7 @@ class Channel(object): size = self.out_max_packet_size m = Message() m.add_byte(chr(MSG_CHANNEL_DATA)) - m.add_int(self.server_chanid) + m.add_int(self.remote_chanid) m.add_string(s[:size]) self.transport.send_message(m) self.out_window_size -= size @@ -469,7 +487,7 @@ class Channel(object): self.log(DEBUG, 'addwindow send %d' % self.in_window_sofar) m = Message() m.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST)) - m.add_int(self.server_chanid) + m.add_int(self.remote_chanid) m.add_int(self.in_window_sofar) self.transport.send_message(m) self.in_window_sofar = 0 @@ -490,7 +508,7 @@ class ChannelFile(object): def __init__(self, channel, mode = "r", buf_size = -1): self.channel = channel self.mode = mode - if buf_size < 0: + if buf_size <= 0: self.buf_size = 1024 self.line_buffered = 0 elif buf_size == 1: @@ -503,10 +521,12 @@ class ChannelFile(object): self.rbuffer = "" self.readable = ("r" in mode) self.writable = ("w" in mode) or ("+" in mode) or ("a" in mode) + self.universal_newlines = ('U' in mode) self.binary = ("b" in mode) - if not self.binary: - raise NotImplementedError("text mode not supported") - self.softspace = 0 + self.at_trailing_cr = False + self.name = '' + self.newlines = None + self.softspace = False def __iter__(self): return self @@ -570,23 +590,56 @@ class ChannelFile(object): self.rbuffer[size:] return result - def readline(self, size = None): - line = "" - while "\n" not in line: - if size >= 0: - new_data = self.read(size - len(line)) + def readline(self, size=None): + line = self.rbuffer + while 1: + if self.at_trailing_cr and (len(line) > 0): + if line[0] == '\n': + line = line[1:] + self.at_trailing_cr = False + if self.universal_newlines: + if ('\n' in line) or ('\r' in line): + break else: - new_data = self.read(64) + if '\n' in line: + break + if size >= 0: + if len(line) >= size: + # truncate line and return + self.rbuffer = line[size:] + line = line[:size] + return line + n = size - len(line) + else: + n = 64 + new_data = self.channel.recv(n) if not new_data: - break + self.rbuffer = '' + return line line += new_data - newline_pos = line.find("\n") - if newline_pos >= 0: - self.rbuffer = line[newline_pos+1:] + self.rbuffer - return line[:newline_pos+1] - elif len(line) > size: - self.rbuffer = line[size:] + self.rbuffer - return line[:size] + # find the newline + pos = line.find('\n') + if self.universal_newlines: + rpos = line.find('\r') + if (rpos >= 0) and ((rpos < pos) or (pos < 0)): + pos = rpos + xpos = pos + 1 + if (line[pos] == '\r') and (xpos < len(line)) and (line[xpos] == '\n'): + xpos += 1 + self.rbuffer = line[xpos:] + lf = line[pos:xpos] + line = line[:xpos] + if (len(self.rbuffer) == 0) and (lf == '\r'): + # we could read the line up to a '\r' and there could still be a + # '\n' following that we read next time. note that and eat it. + self.at_trailing_cr = True + # silliness about tracking what kinds of newlines we've seen + if self.newlines is None: + self.newlines = lf + elif (type(self.newlines) is str) and (self.newlines != lf): + self.newlines = (self.newlines, lf) + elif lf not in self.newlines: + self.newlines += (lf,) return line def readlines(self, sizehint = None): diff --git a/demo-server.py b/demo-server.py index 1db0223..b0f8326 100755 --- a/demo-server.py +++ b/demo-server.py @@ -1,6 +1,6 @@ #!/usr/bin/python -import sys, os, socket, threading, logging, traceback +import sys, os, socket, threading, logging, traceback, time import secsh # setup logging @@ -15,6 +15,19 @@ if len(l.handlers) == 0: host_key = secsh.RSAKey() host_key.read_private_key_file('demo-host-key') + +class ServerTransport(secsh.Transport): + def check_channel_request(self, kind, chanid): + if kind == 'session': + return secsh.Channel(chanid) + return self.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED + + def check_auth_password(self, username, password): + if (username == 'robey') and (password == 'foo'): + return self.AUTH_SUCCESSFUL + return self.AUTH_FAILED + + # now connect try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -35,7 +48,7 @@ except Exception, e: try: event = threading.Event() - t = secsh.Transport(client) + t = ServerTransport(client) t.add_server_key(host_key) t.ultra_debug = 1 t.start_server(event) @@ -45,6 +58,18 @@ try: print '*** SSH negotiation failed.' sys.exit(1) # print repr(t) + + chan = t.accept() + time.sleep(2) + chan.send('\r\n\r\nWelcome to my dorky little BBS!\r\n\r\n') + chan.send('We are on fire all the time! Hooray! Candy corn for everyone!\r\n') + chan.send('Happy birthday to Robot Dave!\r\n\r\n') + chan.send('Username: ') + f = chan.makefile('rU') + username = f.readline().strip('\r\n') + chan.send('\r\nI don\'t like you, ' + username + '.\r\n') + chan.close() + except Exception, e: print '*** Caught exception: ' + str(e.__class__) + ': ' + str(e) traceback.print_exc() diff --git a/demo.py b/demo.py index fc707e4..069077d 100755 --- a/demo.py +++ b/demo.py @@ -76,7 +76,7 @@ try: # print repr(t) keys = load_host_keys() - keytype, hostkey = t.get_host_key() + keytype, hostkey = t.get_remote_server_key() if not keys.has_key(hostname): print '*** WARNING: Unknown host key!' elif not keys[hostname].has_key(keytype): diff --git a/transport.py b/transport.py index 2020b27..c5ff252 100644 --- a/transport.py +++ b/transport.py @@ -11,12 +11,11 @@ MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_SUCCESS, MSG_CHANNEL_OPEN_FAILURE, \ MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST, \ MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101) - import sys, os, string, threading, socket, logging, struct from message import Message from channel import Channel from secsh import SSHException -from util import format_binary, safe_string, inflate_long, deflate_long +from util import format_binary, safe_string, inflate_long, deflate_long, tb_strings from rsakey import RSAKey from dsskey import DSSKey from kex_group1 import KexGroup1 @@ -105,6 +104,9 @@ class BaseTransport(threading.Thread): REKEY_PACKETS = pow(2, 30) REKEY_BYTES = pow(2, 30) + OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, OPEN_FAILED_CONNECT_FAILED, OPEN_FAILED_UNKNOWN_CHANNEL_TYPE, \ + OPEN_FAILED_RESOURCE_SHORTAGE = range(1, 5) + def __init__(self, sock): threading.Thread.__init__(self) self.randpool = randpool @@ -143,6 +145,8 @@ class BaseTransport(threading.Thread): # server mode: self.server_mode = 0 self.server_key_dict = { } + self.server_accepts = [ ] + self.server_accept_cv = threading.Condition(self.lock) def start_client(self, event=None): self.completion_event = event @@ -196,7 +200,7 @@ class BaseTransport(threading.Thread): for chan in self.channels.values(): chan.unlink() - def get_host_key(self): + def get_remote_server_key(self): 'returns (type, key) where type is like "ssh-rsa" and key is an opaque string' if (not self.active) or (not self.initial_kex_done): raise SSHException('No existing session') @@ -225,8 +229,9 @@ class BaseTransport(threading.Thread): m.add_int(chanid) m.add_int(self.window_size) m.add_int(self.max_packet_size) - self.channels[chanid] = chan = Channel(chanid, self) + self.channels[chanid] = chan = Channel(chanid) self.channel_events[chanid] = event = threading.Event() + chan.set_transport(self) chan.set_window(self.window_size, self.max_packet_size) self.send_message(m) finally: @@ -445,10 +450,12 @@ class BaseTransport(threading.Thread): self.send_message(msg) except SSHException, e: self.log(DEBUG, 'Exception: ' + str(e)) + self.log(DEBUG, tb_strings()) except EOFError, e: self.log(DEBUG, 'EOF') except Exception, e: self.log(DEBUG, 'Unknown exception: ' + str(e)) + self.log(DEBUG, tb_strings()) if self.active: self.active = 0 if self.completion_event != None: @@ -503,7 +510,11 @@ class BaseTransport(threading.Thread): comment = buffer[i+1:] buffer = buffer[:i] # parse out version string and make sure it matches - _unused, version, client = string.split(buffer, '-') + segs = buffer.split('-', 2) + if len(segs) < 3: + raise SSHException('Invalid SSH banner') + version = segs[1] + client = segs[2] if version != '1.99' and version != '2.0': raise SSHException('Incompatible version (%s instead of 2.0)' % (version,)) self.log(INFO, 'Connected (version %s, client %s)' % (version, client)) @@ -681,6 +692,7 @@ class BaseTransport(threading.Thread): code = m.get_int() desc = m.get_string() self.log(INFO, 'Disconnect (code %d): %s' % (code, desc)) + def parse_channel_open_success(self, m): chanid = m.get_int() server_chanid = m.get_int() @@ -692,7 +704,7 @@ class BaseTransport(threading.Thread): try: self.lock.acquire() chan = self.channels[chanid] - chan.set_server_channel(server_chanid, server_window_size, server_max_packet_size) + chan.set_remote_channel(server_chanid, server_window_size, server_max_packet_size) self.log(INFO, 'Secsh channel %d opened.' % chanid) if self.channel_events.has_key(chanid): self.channel_events[chanid].set() @@ -719,20 +731,85 @@ class BaseTransport(threading.Thread): self.channel_events[chanid].set() del self.channel_events[chanid] finally: - self.lock_release() + self.lock.release() return + def check_channel_request(self, kind, chanid): + "override me! return object descended from Channel to allow, or None to reject" + return None + def parse_channel_open(self, m): kind = m.get_string() - self.log(DEBUG, 'Rejecting "%s" channel request from server.' % kind) chanid = m.get_int() - msg = Message() - msg.add_byte(chr(MSG_CHANNEL_OPEN_FAILURE)) - msg.add_int(chanid) - msg.add_int(1) - msg.add_string('Client connections are not allowed.') - msg.add_string('en') - self.send_message(msg) + initial_window_size = m.get_int() + max_packet_size = m.get_int() + reject = False + if not self.server_mode: + self.log(DEBUG, 'Rejecting "%s" channel request from server.' % kind) + reject = True + reason = self.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED + else: + try: + self.lock.acquire() + my_chanid = self.channel_counter + self.channel_counter += 1 + finally: + self.lock.release() + chan = self.check_channel_request(kind, my_chanid) + if (chan is None) or (type(chan) is int): + self.log(DEBUG, 'Rejecting "%s" channel request from client.' % kind) + reject = True + if type(chan) is int: + reason = chan + else: + reason = self.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED + if reject: + msg = Message() + msg.add_byte(chr(MSG_CHANNEL_OPEN_FAILURE)) + msg.add_int(chanid) + msg.add_int(reason) + msg.add_string('') + msg.add_string('en') + self.send_message(msg) + return + try: + self.lock.acquire() + self.channels[my_chanid] = chan + chan.set_transport(self) + chan.set_window(self.window_size, self.max_packet_size) + chan.set_remote_channel(chanid, initial_window_size, max_packet_size) + finally: + self.lock.release() + m = Message() + m.add_byte(chr(MSG_CHANNEL_OPEN_SUCCESS)) + m.add_int(chanid) + m.add_int(my_chanid) + m.add_int(self.window_size) + m.add_int(self.max_packet_size) + self.send_message(m) + self.log(INFO, 'Secsh channel %d opened.' % my_chanid) + try: + self.lock.acquire() + self.server_accepts.append(chan) + self.server_accept_cv.notify() + finally: + self.lock.release() + + def accept(self, timeout=None): + try: + self.lock.acquire() + if len(self.server_accepts) > 0: + chan = self.server_accepts.pop(0) + else: + self.server_accept_cv.wait(timeout) + if len(self.server_accepts) > 0: + chan = self.server_accepts.pop(0) + else: + # timeout + chan = None + finally: + self.lock.release() + return chan def parse_debug(self, m): always_display = m.get_boolean()