diff --git a/paramiko/auth_transport.py b/paramiko/auth_transport.py index b8468f0..1bc0f7d 100644 --- a/paramiko/auth_transport.py +++ b/paramiko/auth_transport.py @@ -48,6 +48,7 @@ class Transport (BaseTransport): def __init__(self, sock): BaseTransport.__init__(self, sock) + self.username = None self.authenticated = False self.auth_event = None # for server mode: @@ -79,9 +80,21 @@ class Transport (BaseTransport): @return: True if the session is still open and has been authenticated successfully; False if authentication failed and/or the session is closed. + @rtype: bool """ return self.authenticated and self.active + def get_username(self): + """ + Return the username this connection is authenticated for. If the + session is not authenticated (or authentication failed), this method + returns C{None}. + + @return: username that was authenticated, or C{None}. + @rtype: string + """ + return self.username + def auth_publickey(self, username, key, event): """ Authenticate to the server using a private key. The key is used to @@ -334,6 +347,8 @@ class Transport (BaseTransport): self._log(DEBUG, 'Auth rejected because the client attempted to change username in mid-flight') self._disconnect_no_more_auth() return + self.auth_username = username + if method == 'none': result = self.check_auth_none(username) elif method == 'password': @@ -390,7 +405,7 @@ class Transport (BaseTransport): m.add_boolean(1) else: m.add_boolean(0) - self.auth_fail_count += 1 + self.auth_fail_count += 1 self._send_message(m) if self.auth_fail_count >= 10: self._disconnect_no_more_auth() @@ -411,6 +426,8 @@ class Transport (BaseTransport): pass self._log(INFO, 'Authentication failed.') self.authenticated = False + # FIXME: i don't think we need to close() necessarily here + self.username = None self.close() if self.auth_event != None: self.auth_event.set() diff --git a/paramiko/channel.py b/paramiko/channel.py index 44672d9..2a7bd20 100644 --- a/paramiko/channel.py +++ b/paramiko/channel.py @@ -70,7 +70,7 @@ class Channel (object): self.eof_sent = 0 self.in_buffer = '' self.timeout = None - self.closed = 0 + self.closed = False self.lock = threading.Lock() self.in_buffer_cv = threading.Condition(self.lock) self.out_buffer_cv = threading.Condition(self.lock) @@ -302,7 +302,7 @@ class Channel (object): m.add_byte(chr(MSG_CHANNEL_CLOSE)) m.add_int(self.remote_chanid) self.transport._send_message(m) - self.closed = 1 + self._set_closed() self.transport._unlink_channel(self.chanid) finally: self.lock.release() @@ -723,6 +723,11 @@ class Channel (object): def _log(self, level, msg): self.logger.log(level, msg) + def _set_closed(self): + self.closed = True + self.in_buffer_cv.notifyAll() + self.out_buffer_cv.notifyAll() + def _send_eof(self): if self.eof_sent: return @@ -815,7 +820,7 @@ class Channel (object): def _unlink(self): if self.closed or not self.active: return - self.closed = 1 + self._set_closed() self.transport._unlink_channel(self.chanid) def _check_add_window(self, n): diff --git a/paramiko/common.py b/paramiko/common.py index e5aaaf8..88e8b1a 100644 --- a/paramiko/common.py +++ b/paramiko/common.py @@ -62,7 +62,10 @@ randpool.randomize() import sys if sys.version_info < (2, 3): - import logging22 as logging + try: + import logging + except: + import logging22 as logging import select PY22 = True else: diff --git a/paramiko/pkey.py b/paramiko/pkey.py index 7b8afcb..fa3cccc 100644 --- a/paramiko/pkey.py +++ b/paramiko/pkey.py @@ -237,7 +237,7 @@ class PKey (object): lines = f.readlines() f.close() start = 0 - while (lines[start].strip() != '-----BEGIN ' + tag + ' PRIVATE KEY-----') and (start < len(lines)): + while (start < len(lines)) and (lines[start].strip() != '-----BEGIN ' + tag + ' PRIVATE KEY-----'): start += 1 if start >= len(lines): raise SSHException('not a valid ' + tag + ' private key file') diff --git a/paramiko/transport.py b/paramiko/transport.py index bc3338c..29f59e8 100644 --- a/paramiko/transport.py +++ b/paramiko/transport.py @@ -636,10 +636,11 @@ class BaseTransport (threading.Thread): # check host key if we were given one if (hostkeytype is not None) and (hostkey is not None): - type, key = self.get_remote_server_key() - if (type != hostkeytype) or (key != hostkey): - print repr(type) + ' - ' + repr(hostkeytype) - print repr(key) + ' - ' + repr(hostkey) + keytype, key = self.get_remote_server_key() + if (keytype != hostkeytype) or (key != hostkey): + self._log(DEBUG, 'Bad host key from server') + self._log(DEBUG, 'Expected: %s: %s' % (repr(hostkeytype), repr(hostkey))) + self._log(DEBUG, 'Got : %s: %s' % (repr(keytype), repr(key))) raise SSHException('Bad host key from server') self._log(DEBUG, 'Host key verified (%s)' % hostkeytype) @@ -920,6 +921,8 @@ class BaseTransport (threading.Thread): self._log(DEBUG, util.tb_strings()) self.saved_exception = e _active_threads.remove(self) + for chan in self.channels.values(): + chan._unlink() if self.active: self.active = False if self.completion_event != None: