[project @ Arch-1:robey@lag.net--2003-public%secsh--dev--1.0--patch-44]

add get_username() method for remembering who you auth'd as
add get_username() method for remembering who you auth'd as.  also, fix these
bugs:
* "continue" auth response counted as a failure (in server mode).
* try to import 'logging' in py22 before falling back to the fake logger,
  in case they have a backported version of 'logger'
* raise the right exception when told to read a private key from a file that
  isn't a private key file
* tell channels to close when the transport dies
This commit is contained in:
Robey Pointer 2004-04-07 06:07:29 +00:00
parent 68c8a9b2e6
commit 1af6360007
5 changed files with 38 additions and 10 deletions

View File

@ -48,6 +48,7 @@ class Transport (BaseTransport):
def __init__(self, sock): def __init__(self, sock):
BaseTransport.__init__(self, sock) BaseTransport.__init__(self, sock)
self.username = None
self.authenticated = False self.authenticated = False
self.auth_event = None self.auth_event = None
# for server mode: # for server mode:
@ -79,9 +80,21 @@ class Transport (BaseTransport):
@return: True if the session is still open and has been authenticated successfully; @return: True if the session is still open and has been authenticated successfully;
False if authentication failed and/or the session is closed. False if authentication failed and/or the session is closed.
@rtype: bool
""" """
return self.authenticated and self.active return self.authenticated and self.active
def get_username(self):
"""
Return the username this connection is authenticated for. If the
session is not authenticated (or authentication failed), this method
returns C{None}.
@return: username that was authenticated, or C{None}.
@rtype: string
"""
return self.username
def auth_publickey(self, username, key, event): def auth_publickey(self, username, key, event):
""" """
Authenticate to the server using a private key. The key is used to Authenticate to the server using a private key. The key is used to
@ -334,6 +347,8 @@ class Transport (BaseTransport):
self._log(DEBUG, 'Auth rejected because the client attempted to change username in mid-flight') self._log(DEBUG, 'Auth rejected because the client attempted to change username in mid-flight')
self._disconnect_no_more_auth() self._disconnect_no_more_auth()
return return
self.auth_username = username
if method == 'none': if method == 'none':
result = self.check_auth_none(username) result = self.check_auth_none(username)
elif method == 'password': elif method == 'password':
@ -390,7 +405,7 @@ class Transport (BaseTransport):
m.add_boolean(1) m.add_boolean(1)
else: else:
m.add_boolean(0) m.add_boolean(0)
self.auth_fail_count += 1 self.auth_fail_count += 1
self._send_message(m) self._send_message(m)
if self.auth_fail_count >= 10: if self.auth_fail_count >= 10:
self._disconnect_no_more_auth() self._disconnect_no_more_auth()
@ -411,6 +426,8 @@ class Transport (BaseTransport):
pass pass
self._log(INFO, 'Authentication failed.') self._log(INFO, 'Authentication failed.')
self.authenticated = False self.authenticated = False
# FIXME: i don't think we need to close() necessarily here
self.username = None
self.close() self.close()
if self.auth_event != None: if self.auth_event != None:
self.auth_event.set() self.auth_event.set()

View File

@ -70,7 +70,7 @@ class Channel (object):
self.eof_sent = 0 self.eof_sent = 0
self.in_buffer = '' self.in_buffer = ''
self.timeout = None self.timeout = None
self.closed = 0 self.closed = False
self.lock = threading.Lock() self.lock = threading.Lock()
self.in_buffer_cv = threading.Condition(self.lock) self.in_buffer_cv = threading.Condition(self.lock)
self.out_buffer_cv = threading.Condition(self.lock) self.out_buffer_cv = threading.Condition(self.lock)
@ -302,7 +302,7 @@ class Channel (object):
m.add_byte(chr(MSG_CHANNEL_CLOSE)) m.add_byte(chr(MSG_CHANNEL_CLOSE))
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
self.transport._send_message(m) self.transport._send_message(m)
self.closed = 1 self._set_closed()
self.transport._unlink_channel(self.chanid) self.transport._unlink_channel(self.chanid)
finally: finally:
self.lock.release() self.lock.release()
@ -723,6 +723,11 @@ class Channel (object):
def _log(self, level, msg): def _log(self, level, msg):
self.logger.log(level, msg) self.logger.log(level, msg)
def _set_closed(self):
self.closed = True
self.in_buffer_cv.notifyAll()
self.out_buffer_cv.notifyAll()
def _send_eof(self): def _send_eof(self):
if self.eof_sent: if self.eof_sent:
return return
@ -815,7 +820,7 @@ class Channel (object):
def _unlink(self): def _unlink(self):
if self.closed or not self.active: if self.closed or not self.active:
return return
self.closed = 1 self._set_closed()
self.transport._unlink_channel(self.chanid) self.transport._unlink_channel(self.chanid)
def _check_add_window(self, n): def _check_add_window(self, n):

View File

@ -62,7 +62,10 @@ randpool.randomize()
import sys import sys
if sys.version_info < (2, 3): if sys.version_info < (2, 3):
import logging22 as logging try:
import logging
except:
import logging22 as logging
import select import select
PY22 = True PY22 = True
else: else:

View File

@ -237,7 +237,7 @@ class PKey (object):
lines = f.readlines() lines = f.readlines()
f.close() f.close()
start = 0 start = 0
while (lines[start].strip() != '-----BEGIN ' + tag + ' PRIVATE KEY-----') and (start < len(lines)): while (start < len(lines)) and (lines[start].strip() != '-----BEGIN ' + tag + ' PRIVATE KEY-----'):
start += 1 start += 1
if start >= len(lines): if start >= len(lines):
raise SSHException('not a valid ' + tag + ' private key file') raise SSHException('not a valid ' + tag + ' private key file')

View File

@ -636,10 +636,11 @@ class BaseTransport (threading.Thread):
# check host key if we were given one # check host key if we were given one
if (hostkeytype is not None) and (hostkey is not None): if (hostkeytype is not None) and (hostkey is not None):
type, key = self.get_remote_server_key() keytype, key = self.get_remote_server_key()
if (type != hostkeytype) or (key != hostkey): if (keytype != hostkeytype) or (key != hostkey):
print repr(type) + ' - ' + repr(hostkeytype) self._log(DEBUG, 'Bad host key from server')
print repr(key) + ' - ' + repr(hostkey) self._log(DEBUG, 'Expected: %s: %s' % (repr(hostkeytype), repr(hostkey)))
self._log(DEBUG, 'Got : %s: %s' % (repr(keytype), repr(key)))
raise SSHException('Bad host key from server') raise SSHException('Bad host key from server')
self._log(DEBUG, 'Host key verified (%s)' % hostkeytype) self._log(DEBUG, 'Host key verified (%s)' % hostkeytype)
@ -920,6 +921,8 @@ class BaseTransport (threading.Thread):
self._log(DEBUG, util.tb_strings()) self._log(DEBUG, util.tb_strings())
self.saved_exception = e self.saved_exception = e
_active_threads.remove(self) _active_threads.remove(self)
for chan in self.channels.values():
chan._unlink()
if self.active: if self.active:
self.active = False self.active = False
if self.completion_event != None: if self.completion_event != None: