[project @ Arch-1:robey@lag.net--2003-public%secsh--dev--1.0--patch-44]
add get_username() method for remembering who you auth'd as add get_username() method for remembering who you auth'd as. also, fix these bugs: * "continue" auth response counted as a failure (in server mode). * try to import 'logging' in py22 before falling back to the fake logger, in case they have a backported version of 'logger' * raise the right exception when told to read a private key from a file that isn't a private key file * tell channels to close when the transport dies
This commit is contained in:
parent
68c8a9b2e6
commit
1af6360007
|
@ -48,6 +48,7 @@ class Transport (BaseTransport):
|
|||
|
||||
def __init__(self, sock):
|
||||
BaseTransport.__init__(self, sock)
|
||||
self.username = None
|
||||
self.authenticated = False
|
||||
self.auth_event = None
|
||||
# for server mode:
|
||||
|
@ -79,9 +80,21 @@ class Transport (BaseTransport):
|
|||
|
||||
@return: True if the session is still open and has been authenticated successfully;
|
||||
False if authentication failed and/or the session is closed.
|
||||
@rtype: bool
|
||||
"""
|
||||
return self.authenticated and self.active
|
||||
|
||||
def get_username(self):
|
||||
"""
|
||||
Return the username this connection is authenticated for. If the
|
||||
session is not authenticated (or authentication failed), this method
|
||||
returns C{None}.
|
||||
|
||||
@return: username that was authenticated, or C{None}.
|
||||
@rtype: string
|
||||
"""
|
||||
return self.username
|
||||
|
||||
def auth_publickey(self, username, key, event):
|
||||
"""
|
||||
Authenticate to the server using a private key. The key is used to
|
||||
|
@ -334,6 +347,8 @@ class Transport (BaseTransport):
|
|||
self._log(DEBUG, 'Auth rejected because the client attempted to change username in mid-flight')
|
||||
self._disconnect_no_more_auth()
|
||||
return
|
||||
self.auth_username = username
|
||||
|
||||
if method == 'none':
|
||||
result = self.check_auth_none(username)
|
||||
elif method == 'password':
|
||||
|
@ -411,6 +426,8 @@ class Transport (BaseTransport):
|
|||
pass
|
||||
self._log(INFO, 'Authentication failed.')
|
||||
self.authenticated = False
|
||||
# FIXME: i don't think we need to close() necessarily here
|
||||
self.username = None
|
||||
self.close()
|
||||
if self.auth_event != None:
|
||||
self.auth_event.set()
|
||||
|
|
|
@ -70,7 +70,7 @@ class Channel (object):
|
|||
self.eof_sent = 0
|
||||
self.in_buffer = ''
|
||||
self.timeout = None
|
||||
self.closed = 0
|
||||
self.closed = False
|
||||
self.lock = threading.Lock()
|
||||
self.in_buffer_cv = threading.Condition(self.lock)
|
||||
self.out_buffer_cv = threading.Condition(self.lock)
|
||||
|
@ -302,7 +302,7 @@ class Channel (object):
|
|||
m.add_byte(chr(MSG_CHANNEL_CLOSE))
|
||||
m.add_int(self.remote_chanid)
|
||||
self.transport._send_message(m)
|
||||
self.closed = 1
|
||||
self._set_closed()
|
||||
self.transport._unlink_channel(self.chanid)
|
||||
finally:
|
||||
self.lock.release()
|
||||
|
@ -723,6 +723,11 @@ class Channel (object):
|
|||
def _log(self, level, msg):
|
||||
self.logger.log(level, msg)
|
||||
|
||||
def _set_closed(self):
|
||||
self.closed = True
|
||||
self.in_buffer_cv.notifyAll()
|
||||
self.out_buffer_cv.notifyAll()
|
||||
|
||||
def _send_eof(self):
|
||||
if self.eof_sent:
|
||||
return
|
||||
|
@ -815,7 +820,7 @@ class Channel (object):
|
|||
def _unlink(self):
|
||||
if self.closed or not self.active:
|
||||
return
|
||||
self.closed = 1
|
||||
self._set_closed()
|
||||
self.transport._unlink_channel(self.chanid)
|
||||
|
||||
def _check_add_window(self, n):
|
||||
|
|
|
@ -62,6 +62,9 @@ randpool.randomize()
|
|||
|
||||
import sys
|
||||
if sys.version_info < (2, 3):
|
||||
try:
|
||||
import logging
|
||||
except:
|
||||
import logging22 as logging
|
||||
import select
|
||||
PY22 = True
|
||||
|
|
|
@ -237,7 +237,7 @@ class PKey (object):
|
|||
lines = f.readlines()
|
||||
f.close()
|
||||
start = 0
|
||||
while (lines[start].strip() != '-----BEGIN ' + tag + ' PRIVATE KEY-----') and (start < len(lines)):
|
||||
while (start < len(lines)) and (lines[start].strip() != '-----BEGIN ' + tag + ' PRIVATE KEY-----'):
|
||||
start += 1
|
||||
if start >= len(lines):
|
||||
raise SSHException('not a valid ' + tag + ' private key file')
|
||||
|
|
|
@ -636,10 +636,11 @@ class BaseTransport (threading.Thread):
|
|||
|
||||
# check host key if we were given one
|
||||
if (hostkeytype is not None) and (hostkey is not None):
|
||||
type, key = self.get_remote_server_key()
|
||||
if (type != hostkeytype) or (key != hostkey):
|
||||
print repr(type) + ' - ' + repr(hostkeytype)
|
||||
print repr(key) + ' - ' + repr(hostkey)
|
||||
keytype, key = self.get_remote_server_key()
|
||||
if (keytype != hostkeytype) or (key != hostkey):
|
||||
self._log(DEBUG, 'Bad host key from server')
|
||||
self._log(DEBUG, 'Expected: %s: %s' % (repr(hostkeytype), repr(hostkey)))
|
||||
self._log(DEBUG, 'Got : %s: %s' % (repr(keytype), repr(key)))
|
||||
raise SSHException('Bad host key from server')
|
||||
self._log(DEBUG, 'Host key verified (%s)' % hostkeytype)
|
||||
|
||||
|
@ -920,6 +921,8 @@ class BaseTransport (threading.Thread):
|
|||
self._log(DEBUG, util.tb_strings())
|
||||
self.saved_exception = e
|
||||
_active_threads.remove(self)
|
||||
for chan in self.channels.values():
|
||||
chan._unlink()
|
||||
if self.active:
|
||||
self.active = False
|
||||
if self.completion_event != None:
|
||||
|
|
Loading…
Reference in New Issue