bug 191657: clean up usage of the channel map by making a special object to hold the weak value dict.
This commit is contained in:
parent
e5a1b4bf56
commit
9a6ffec93f
|
@ -140,6 +140,51 @@ class SecurityOptions (object):
|
||||||
"Compression algorithms")
|
"Compression algorithms")
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelMap (object):
|
||||||
|
def __init__(self):
|
||||||
|
# (id -> Channel)
|
||||||
|
self._map = weakref.WeakValueDictionary()
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
def put(self, chanid, chan):
|
||||||
|
self._lock.acquire()
|
||||||
|
try:
|
||||||
|
self._map[chanid] = chan
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def get(self, chanid):
|
||||||
|
self._lock.acquire()
|
||||||
|
try:
|
||||||
|
return self._map.get(chanid, None)
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def delete(self, chanid):
|
||||||
|
self._lock.acquire()
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
del self._map[chanid]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def values(self):
|
||||||
|
self._lock.acquire()
|
||||||
|
try:
|
||||||
|
return self._map.values()
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
self._lock.acquire()
|
||||||
|
try:
|
||||||
|
return len(self._map)
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
|
||||||
class Transport (threading.Thread):
|
class Transport (threading.Thread):
|
||||||
"""
|
"""
|
||||||
An SSH Transport attaches to a stream (usually a socket), negotiates an
|
An SSH Transport attaches to a stream (usually a socket), negotiates an
|
||||||
|
@ -271,7 +316,7 @@ class Transport (threading.Thread):
|
||||||
self.lock = threading.Lock() # synchronization (always higher level than write_lock)
|
self.lock = threading.Lock() # synchronization (always higher level than write_lock)
|
||||||
|
|
||||||
# tracking open channels
|
# tracking open channels
|
||||||
self.channels = weakref.WeakValueDictionary() # (id -> Channel)
|
self._channels = ChannelMap()
|
||||||
self.channel_events = { } # (id -> Event)
|
self.channel_events = { } # (id -> Event)
|
||||||
self.channels_seen = { } # (id -> True)
|
self.channels_seen = { } # (id -> True)
|
||||||
self._channel_counter = 1
|
self._channel_counter = 1
|
||||||
|
@ -313,10 +358,7 @@ class Transport (threading.Thread):
|
||||||
out += ' (cipher %s, %d bits)' % (self.local_cipher,
|
out += ' (cipher %s, %d bits)' % (self.local_cipher,
|
||||||
self._cipher_info[self.local_cipher]['key-size'] * 8)
|
self._cipher_info[self.local_cipher]['key-size'] * 8)
|
||||||
if self.is_authenticated():
|
if self.is_authenticated():
|
||||||
if len(self.channels) == 1:
|
out += ' (active; %d open channel(s))' % len(self._channels)
|
||||||
out += ' (active; 1 open channel)'
|
|
||||||
else:
|
|
||||||
out += ' (active; %d open channels)' % len(self.channels)
|
|
||||||
elif self.initial_kex_done:
|
elif self.initial_kex_done:
|
||||||
out += ' (connected; awaiting auth)'
|
out += ' (connected; awaiting auth)'
|
||||||
else:
|
else:
|
||||||
|
@ -550,7 +592,7 @@ class Transport (threading.Thread):
|
||||||
self.active = False
|
self.active = False
|
||||||
self.packetizer.close()
|
self.packetizer.close()
|
||||||
self.join()
|
self.join()
|
||||||
for chan in self.channels.values():
|
for chan in self._channels.values():
|
||||||
chan._unlink()
|
chan._unlink()
|
||||||
|
|
||||||
def get_remote_server_key(self):
|
def get_remote_server_key(self):
|
||||||
|
@ -667,7 +709,8 @@ class Transport (threading.Thread):
|
||||||
elif kind == 'x11':
|
elif kind == 'x11':
|
||||||
m.add_string(src_addr[0])
|
m.add_string(src_addr[0])
|
||||||
m.add_int(src_addr[1])
|
m.add_int(src_addr[1])
|
||||||
self.channels[chanid] = chan = Channel(chanid)
|
chan = Channel(chanid)
|
||||||
|
self._channels.put(chanid, chan)
|
||||||
self.channel_events[chanid] = event = threading.Event()
|
self.channel_events[chanid] = event = threading.Event()
|
||||||
self.channels_seen[chanid] = True
|
self.channels_seen[chanid] = True
|
||||||
chan._set_transport(self)
|
chan._set_transport(self)
|
||||||
|
@ -684,12 +727,9 @@ class Transport (threading.Thread):
|
||||||
raise e
|
raise e
|
||||||
if event.isSet():
|
if event.isSet():
|
||||||
break
|
break
|
||||||
self.lock.acquire()
|
chan = self._channels.get(chanid)
|
||||||
try:
|
if chan is not None:
|
||||||
if chanid in self.channels:
|
return chan
|
||||||
return chan
|
|
||||||
finally:
|
|
||||||
self.lock.release()
|
|
||||||
e = self.get_exception()
|
e = self.get_exception()
|
||||||
if e is None:
|
if e is None:
|
||||||
e = SSHException('Unable to open channel.')
|
e = SSHException('Unable to open channel.')
|
||||||
|
@ -1334,7 +1374,7 @@ class Transport (threading.Thread):
|
||||||
def _next_channel(self):
|
def _next_channel(self):
|
||||||
"you are holding the lock"
|
"you are holding the lock"
|
||||||
chanid = self._channel_counter
|
chanid = self._channel_counter
|
||||||
while chanid in self.channels:
|
while self._channels.get(chanid) is not None:
|
||||||
self._channel_counter = (self._channel_counter + 1) & 0xffffff
|
self._channel_counter = (self._channel_counter + 1) & 0xffffff
|
||||||
chanid = self._channel_counter
|
chanid = self._channel_counter
|
||||||
self._channel_counter = (self._channel_counter + 1) & 0xffffff
|
self._channel_counter = (self._channel_counter + 1) & 0xffffff
|
||||||
|
@ -1342,12 +1382,7 @@ class Transport (threading.Thread):
|
||||||
|
|
||||||
def _unlink_channel(self, chanid):
|
def _unlink_channel(self, chanid):
|
||||||
"used by a Channel to remove itself from the active channel list"
|
"used by a Channel to remove itself from the active channel list"
|
||||||
try:
|
self._channels.delete(chanid)
|
||||||
self.lock.acquire()
|
|
||||||
if chanid in self.channels:
|
|
||||||
del self.channels[chanid]
|
|
||||||
finally:
|
|
||||||
self.lock.release()
|
|
||||||
|
|
||||||
def _send_message(self, data):
|
def _send_message(self, data):
|
||||||
self.packetizer.send_message(data)
|
self.packetizer.send_message(data)
|
||||||
|
@ -1478,8 +1513,9 @@ class Transport (threading.Thread):
|
||||||
self._handler_table[ptype](self, m)
|
self._handler_table[ptype](self, m)
|
||||||
elif ptype in self._channel_handler_table:
|
elif ptype in self._channel_handler_table:
|
||||||
chanid = m.get_int()
|
chanid = m.get_int()
|
||||||
if chanid in self.channels:
|
chan = self._channels.get(chanid)
|
||||||
self._channel_handler_table[ptype](self.channels[chanid], m)
|
if chan is not None:
|
||||||
|
self._channel_handler_table[ptype](chan, m)
|
||||||
elif chanid in self.channels_seen:
|
elif chanid in self.channels_seen:
|
||||||
self._log(DEBUG, 'Ignoring message for dead channel %d' % chanid)
|
self._log(DEBUG, 'Ignoring message for dead channel %d' % chanid)
|
||||||
else:
|
else:
|
||||||
|
@ -1514,7 +1550,7 @@ class Transport (threading.Thread):
|
||||||
self._log(ERROR, util.tb_strings())
|
self._log(ERROR, util.tb_strings())
|
||||||
self.saved_exception = e
|
self.saved_exception = e
|
||||||
_active_threads.remove(self)
|
_active_threads.remove(self)
|
||||||
for chan in self.channels.values():
|
for chan in self._channels.values():
|
||||||
chan._unlink()
|
chan._unlink()
|
||||||
if self.active:
|
if self.active:
|
||||||
self.active = False
|
self.active = False
|
||||||
|
@ -1872,12 +1908,12 @@ class Transport (threading.Thread):
|
||||||
server_chanid = m.get_int()
|
server_chanid = m.get_int()
|
||||||
server_window_size = m.get_int()
|
server_window_size = m.get_int()
|
||||||
server_max_packet_size = m.get_int()
|
server_max_packet_size = m.get_int()
|
||||||
if chanid not in self.channels:
|
chan = self._channels.get(chanid)
|
||||||
|
if chan is None:
|
||||||
self._log(WARNING, 'Success for unrequested channel! [??]')
|
self._log(WARNING, 'Success for unrequested channel! [??]')
|
||||||
return
|
return
|
||||||
self.lock.acquire()
|
self.lock.acquire()
|
||||||
try:
|
try:
|
||||||
chan = self.channels[chanid]
|
|
||||||
chan._set_remote_channel(server_chanid, server_window_size, server_max_packet_size)
|
chan._set_remote_channel(server_chanid, server_window_size, server_max_packet_size)
|
||||||
self._log(INFO, 'Secsh channel %d opened.' % chanid)
|
self._log(INFO, 'Secsh channel %d opened.' % chanid)
|
||||||
if chanid in self.channel_events:
|
if chanid in self.channel_events:
|
||||||
|
@ -1898,7 +1934,7 @@ class Transport (threading.Thread):
|
||||||
try:
|
try:
|
||||||
self.saved_exception = ChannelException(reason, reason_text)
|
self.saved_exception = ChannelException(reason, reason_text)
|
||||||
if chanid in self.channel_events:
|
if chanid in self.channel_events:
|
||||||
del self.channels[chanid]
|
self._channels.delete(chanid)
|
||||||
if chanid in self.channel_events:
|
if chanid in self.channel_events:
|
||||||
self.channel_events[chanid].set()
|
self.channel_events[chanid].set()
|
||||||
del self.channel_events[chanid]
|
del self.channel_events[chanid]
|
||||||
|
@ -1967,9 +2003,9 @@ class Transport (threading.Thread):
|
||||||
return
|
return
|
||||||
|
|
||||||
chan = Channel(my_chanid)
|
chan = Channel(my_chanid)
|
||||||
|
self.lock.acquire()
|
||||||
try:
|
try:
|
||||||
self.lock.acquire()
|
self._channels.put(my_chanid, chan)
|
||||||
self.channels[my_chanid] = chan
|
|
||||||
self.channels_seen[my_chanid] = True
|
self.channels_seen[my_chanid] = True
|
||||||
chan._set_transport(self)
|
chan._set_transport(self)
|
||||||
chan._set_window(self.window_size, self.max_packet_size)
|
chan._set_window(self.window_size, self.max_packet_size)
|
||||||
|
|
Loading…
Reference in New Issue