bug 191657:
clean up usage of the channel map by making a special object to hold the
weak value dict.
This commit is contained in:
Robey Pointer 2008-03-23 01:21:10 -07:00
parent e5a1b4bf56
commit 9a6ffec93f
1 changed files with 64 additions and 28 deletions

View File

@ -140,6 +140,51 @@ class SecurityOptions (object):
"Compression algorithms") "Compression algorithms")
class ChannelMap (object):
def __init__(self):
# (id -> Channel)
self._map = weakref.WeakValueDictionary()
self._lock = threading.Lock()
def put(self, chanid, chan):
self._lock.acquire()
try:
self._map[chanid] = chan
finally:
self._lock.release()
def get(self, chanid):
self._lock.acquire()
try:
return self._map.get(chanid, None)
finally:
self._lock.release()
def delete(self, chanid):
self._lock.acquire()
try:
try:
del self._map[chanid]
except KeyError:
pass
finally:
self._lock.release()
def values(self):
self._lock.acquire()
try:
return self._map.values()
finally:
self._lock.release()
def __len__(self):
self._lock.acquire()
try:
return len(self._map)
finally:
self._lock.release()
class Transport (threading.Thread): class Transport (threading.Thread):
""" """
An SSH Transport attaches to a stream (usually a socket), negotiates an An SSH Transport attaches to a stream (usually a socket), negotiates an
@ -271,7 +316,7 @@ class Transport (threading.Thread):
self.lock = threading.Lock() # synchronization (always higher level than write_lock) self.lock = threading.Lock() # synchronization (always higher level than write_lock)
# tracking open channels # tracking open channels
self.channels = weakref.WeakValueDictionary() # (id -> Channel) self._channels = ChannelMap()
self.channel_events = { } # (id -> Event) self.channel_events = { } # (id -> Event)
self.channels_seen = { } # (id -> True) self.channels_seen = { } # (id -> True)
self._channel_counter = 1 self._channel_counter = 1
@ -313,10 +358,7 @@ class Transport (threading.Thread):
out += ' (cipher %s, %d bits)' % (self.local_cipher, out += ' (cipher %s, %d bits)' % (self.local_cipher,
self._cipher_info[self.local_cipher]['key-size'] * 8) self._cipher_info[self.local_cipher]['key-size'] * 8)
if self.is_authenticated(): if self.is_authenticated():
if len(self.channels) == 1: out += ' (active; %d open channel(s))' % len(self._channels)
out += ' (active; 1 open channel)'
else:
out += ' (active; %d open channels)' % len(self.channels)
elif self.initial_kex_done: elif self.initial_kex_done:
out += ' (connected; awaiting auth)' out += ' (connected; awaiting auth)'
else: else:
@ -550,7 +592,7 @@ class Transport (threading.Thread):
self.active = False self.active = False
self.packetizer.close() self.packetizer.close()
self.join() self.join()
for chan in self.channels.values(): for chan in self._channels.values():
chan._unlink() chan._unlink()
def get_remote_server_key(self): def get_remote_server_key(self):
@ -667,7 +709,8 @@ class Transport (threading.Thread):
elif kind == 'x11': elif kind == 'x11':
m.add_string(src_addr[0]) m.add_string(src_addr[0])
m.add_int(src_addr[1]) m.add_int(src_addr[1])
self.channels[chanid] = chan = Channel(chanid) chan = Channel(chanid)
self._channels.put(chanid, chan)
self.channel_events[chanid] = event = threading.Event() self.channel_events[chanid] = event = threading.Event()
self.channels_seen[chanid] = True self.channels_seen[chanid] = True
chan._set_transport(self) chan._set_transport(self)
@ -684,12 +727,9 @@ class Transport (threading.Thread):
raise e raise e
if event.isSet(): if event.isSet():
break break
self.lock.acquire() chan = self._channels.get(chanid)
try: if chan is not None:
if chanid in self.channels: return chan
return chan
finally:
self.lock.release()
e = self.get_exception() e = self.get_exception()
if e is None: if e is None:
e = SSHException('Unable to open channel.') e = SSHException('Unable to open channel.')
@ -1334,7 +1374,7 @@ class Transport (threading.Thread):
def _next_channel(self): def _next_channel(self):
"you are holding the lock" "you are holding the lock"
chanid = self._channel_counter chanid = self._channel_counter
while chanid in self.channels: while self._channels.get(chanid) is not None:
self._channel_counter = (self._channel_counter + 1) & 0xffffff self._channel_counter = (self._channel_counter + 1) & 0xffffff
chanid = self._channel_counter chanid = self._channel_counter
self._channel_counter = (self._channel_counter + 1) & 0xffffff self._channel_counter = (self._channel_counter + 1) & 0xffffff
@ -1342,12 +1382,7 @@ class Transport (threading.Thread):
def _unlink_channel(self, chanid): def _unlink_channel(self, chanid):
"used by a Channel to remove itself from the active channel list" "used by a Channel to remove itself from the active channel list"
try: self._channels.delete(chanid)
self.lock.acquire()
if chanid in self.channels:
del self.channels[chanid]
finally:
self.lock.release()
def _send_message(self, data): def _send_message(self, data):
self.packetizer.send_message(data) self.packetizer.send_message(data)
@ -1478,8 +1513,9 @@ class Transport (threading.Thread):
self._handler_table[ptype](self, m) self._handler_table[ptype](self, m)
elif ptype in self._channel_handler_table: elif ptype in self._channel_handler_table:
chanid = m.get_int() chanid = m.get_int()
if chanid in self.channels: chan = self._channels.get(chanid)
self._channel_handler_table[ptype](self.channels[chanid], m) if chan is not None:
self._channel_handler_table[ptype](chan, m)
elif chanid in self.channels_seen: elif chanid in self.channels_seen:
self._log(DEBUG, 'Ignoring message for dead channel %d' % chanid) self._log(DEBUG, 'Ignoring message for dead channel %d' % chanid)
else: else:
@ -1514,7 +1550,7 @@ class Transport (threading.Thread):
self._log(ERROR, util.tb_strings()) self._log(ERROR, util.tb_strings())
self.saved_exception = e self.saved_exception = e
_active_threads.remove(self) _active_threads.remove(self)
for chan in self.channels.values(): for chan in self._channels.values():
chan._unlink() chan._unlink()
if self.active: if self.active:
self.active = False self.active = False
@ -1872,12 +1908,12 @@ class Transport (threading.Thread):
server_chanid = m.get_int() server_chanid = m.get_int()
server_window_size = m.get_int() server_window_size = m.get_int()
server_max_packet_size = m.get_int() server_max_packet_size = m.get_int()
if chanid not in self.channels: chan = self._channels.get(chanid)
if chan is None:
self._log(WARNING, 'Success for unrequested channel! [??]') self._log(WARNING, 'Success for unrequested channel! [??]')
return return
self.lock.acquire() self.lock.acquire()
try: try:
chan = self.channels[chanid]
chan._set_remote_channel(server_chanid, server_window_size, server_max_packet_size) chan._set_remote_channel(server_chanid, server_window_size, server_max_packet_size)
self._log(INFO, 'Secsh channel %d opened.' % chanid) self._log(INFO, 'Secsh channel %d opened.' % chanid)
if chanid in self.channel_events: if chanid in self.channel_events:
@ -1898,7 +1934,7 @@ class Transport (threading.Thread):
try: try:
self.saved_exception = ChannelException(reason, reason_text) self.saved_exception = ChannelException(reason, reason_text)
if chanid in self.channel_events: if chanid in self.channel_events:
del self.channels[chanid] self._channels.delete(chanid)
if chanid in self.channel_events: if chanid in self.channel_events:
self.channel_events[chanid].set() self.channel_events[chanid].set()
del self.channel_events[chanid] del self.channel_events[chanid]
@ -1967,9 +2003,9 @@ class Transport (threading.Thread):
return return
chan = Channel(my_chanid) chan = Channel(my_chanid)
self.lock.acquire()
try: try:
self.lock.acquire() self._channels.put(my_chanid, chan)
self.channels[my_chanid] = chan
self.channels_seen[my_chanid] = True self.channels_seen[my_chanid] = True
chan._set_transport(self) chan._set_transport(self)
chan._set_window(self.window_size, self.max_packet_size) chan._set_window(self.window_size, self.max_packet_size)