bug 191657: clean up usage of the channel map by making a special object to hold the weak value dict.
This commit is contained in:
parent
e5a1b4bf56
commit
9a6ffec93f
|
@ -140,6 +140,51 @@ class SecurityOptions (object):
|
|||
"Compression algorithms")
|
||||
|
||||
|
||||
class ChannelMap (object):
|
||||
def __init__(self):
|
||||
# (id -> Channel)
|
||||
self._map = weakref.WeakValueDictionary()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def put(self, chanid, chan):
|
||||
self._lock.acquire()
|
||||
try:
|
||||
self._map[chanid] = chan
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get(self, chanid):
|
||||
self._lock.acquire()
|
||||
try:
|
||||
return self._map.get(chanid, None)
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def delete(self, chanid):
|
||||
self._lock.acquire()
|
||||
try:
|
||||
try:
|
||||
del self._map[chanid]
|
||||
except KeyError:
|
||||
pass
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def values(self):
|
||||
self._lock.acquire()
|
||||
try:
|
||||
return self._map.values()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def __len__(self):
|
||||
self._lock.acquire()
|
||||
try:
|
||||
return len(self._map)
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
|
||||
class Transport (threading.Thread):
|
||||
"""
|
||||
An SSH Transport attaches to a stream (usually a socket), negotiates an
|
||||
|
@ -271,7 +316,7 @@ class Transport (threading.Thread):
|
|||
self.lock = threading.Lock() # synchronization (always higher level than write_lock)
|
||||
|
||||
# tracking open channels
|
||||
self.channels = weakref.WeakValueDictionary() # (id -> Channel)
|
||||
self._channels = ChannelMap()
|
||||
self.channel_events = { } # (id -> Event)
|
||||
self.channels_seen = { } # (id -> True)
|
||||
self._channel_counter = 1
|
||||
|
@ -313,10 +358,7 @@ class Transport (threading.Thread):
|
|||
out += ' (cipher %s, %d bits)' % (self.local_cipher,
|
||||
self._cipher_info[self.local_cipher]['key-size'] * 8)
|
||||
if self.is_authenticated():
|
||||
if len(self.channels) == 1:
|
||||
out += ' (active; 1 open channel)'
|
||||
else:
|
||||
out += ' (active; %d open channels)' % len(self.channels)
|
||||
out += ' (active; %d open channel(s))' % len(self._channels)
|
||||
elif self.initial_kex_done:
|
||||
out += ' (connected; awaiting auth)'
|
||||
else:
|
||||
|
@ -550,7 +592,7 @@ class Transport (threading.Thread):
|
|||
self.active = False
|
||||
self.packetizer.close()
|
||||
self.join()
|
||||
for chan in self.channels.values():
|
||||
for chan in self._channels.values():
|
||||
chan._unlink()
|
||||
|
||||
def get_remote_server_key(self):
|
||||
|
@ -667,7 +709,8 @@ class Transport (threading.Thread):
|
|||
elif kind == 'x11':
|
||||
m.add_string(src_addr[0])
|
||||
m.add_int(src_addr[1])
|
||||
self.channels[chanid] = chan = Channel(chanid)
|
||||
chan = Channel(chanid)
|
||||
self._channels.put(chanid, chan)
|
||||
self.channel_events[chanid] = event = threading.Event()
|
||||
self.channels_seen[chanid] = True
|
||||
chan._set_transport(self)
|
||||
|
@ -684,12 +727,9 @@ class Transport (threading.Thread):
|
|||
raise e
|
||||
if event.isSet():
|
||||
break
|
||||
self.lock.acquire()
|
||||
try:
|
||||
if chanid in self.channels:
|
||||
return chan
|
||||
finally:
|
||||
self.lock.release()
|
||||
chan = self._channels.get(chanid)
|
||||
if chan is not None:
|
||||
return chan
|
||||
e = self.get_exception()
|
||||
if e is None:
|
||||
e = SSHException('Unable to open channel.')
|
||||
|
@ -1334,7 +1374,7 @@ class Transport (threading.Thread):
|
|||
def _next_channel(self):
|
||||
"you are holding the lock"
|
||||
chanid = self._channel_counter
|
||||
while chanid in self.channels:
|
||||
while self._channels.get(chanid) is not None:
|
||||
self._channel_counter = (self._channel_counter + 1) & 0xffffff
|
||||
chanid = self._channel_counter
|
||||
self._channel_counter = (self._channel_counter + 1) & 0xffffff
|
||||
|
@ -1342,12 +1382,7 @@ class Transport (threading.Thread):
|
|||
|
||||
def _unlink_channel(self, chanid):
|
||||
"used by a Channel to remove itself from the active channel list"
|
||||
try:
|
||||
self.lock.acquire()
|
||||
if chanid in self.channels:
|
||||
del self.channels[chanid]
|
||||
finally:
|
||||
self.lock.release()
|
||||
self._channels.delete(chanid)
|
||||
|
||||
def _send_message(self, data):
|
||||
self.packetizer.send_message(data)
|
||||
|
@ -1478,8 +1513,9 @@ class Transport (threading.Thread):
|
|||
self._handler_table[ptype](self, m)
|
||||
elif ptype in self._channel_handler_table:
|
||||
chanid = m.get_int()
|
||||
if chanid in self.channels:
|
||||
self._channel_handler_table[ptype](self.channels[chanid], m)
|
||||
chan = self._channels.get(chanid)
|
||||
if chan is not None:
|
||||
self._channel_handler_table[ptype](chan, m)
|
||||
elif chanid in self.channels_seen:
|
||||
self._log(DEBUG, 'Ignoring message for dead channel %d' % chanid)
|
||||
else:
|
||||
|
@ -1514,7 +1550,7 @@ class Transport (threading.Thread):
|
|||
self._log(ERROR, util.tb_strings())
|
||||
self.saved_exception = e
|
||||
_active_threads.remove(self)
|
||||
for chan in self.channels.values():
|
||||
for chan in self._channels.values():
|
||||
chan._unlink()
|
||||
if self.active:
|
||||
self.active = False
|
||||
|
@ -1872,12 +1908,12 @@ class Transport (threading.Thread):
|
|||
server_chanid = m.get_int()
|
||||
server_window_size = m.get_int()
|
||||
server_max_packet_size = m.get_int()
|
||||
if chanid not in self.channels:
|
||||
chan = self._channels.get(chanid)
|
||||
if chan is None:
|
||||
self._log(WARNING, 'Success for unrequested channel! [??]')
|
||||
return
|
||||
self.lock.acquire()
|
||||
try:
|
||||
chan = self.channels[chanid]
|
||||
chan._set_remote_channel(server_chanid, server_window_size, server_max_packet_size)
|
||||
self._log(INFO, 'Secsh channel %d opened.' % chanid)
|
||||
if chanid in self.channel_events:
|
||||
|
@ -1898,7 +1934,7 @@ class Transport (threading.Thread):
|
|||
try:
|
||||
self.saved_exception = ChannelException(reason, reason_text)
|
||||
if chanid in self.channel_events:
|
||||
del self.channels[chanid]
|
||||
self._channels.delete(chanid)
|
||||
if chanid in self.channel_events:
|
||||
self.channel_events[chanid].set()
|
||||
del self.channel_events[chanid]
|
||||
|
@ -1967,9 +2003,9 @@ class Transport (threading.Thread):
|
|||
return
|
||||
|
||||
chan = Channel(my_chanid)
|
||||
self.lock.acquire()
|
||||
try:
|
||||
self.lock.acquire()
|
||||
self.channels[my_chanid] = chan
|
||||
self._channels.put(my_chanid, chan)
|
||||
self.channels_seen[my_chanid] = True
|
||||
chan._set_transport(self)
|
||||
chan._set_window(self.window_size, self.max_packet_size)
|
||||
|
|
Loading…
Reference in New Issue