diff --git a/paramiko/transport.py b/paramiko/transport.py index 4ceb4a0..af4c307 100644 --- a/paramiko/transport.py +++ b/paramiko/transport.py @@ -140,6 +140,51 @@ class SecurityOptions (object): "Compression algorithms") +class ChannelMap (object): + def __init__(self): + # (id -> Channel) + self._map = weakref.WeakValueDictionary() + self._lock = threading.Lock() + + def put(self, chanid, chan): + self._lock.acquire() + try: + self._map[chanid] = chan + finally: + self._lock.release() + + def get(self, chanid): + self._lock.acquire() + try: + return self._map.get(chanid, None) + finally: + self._lock.release() + + def delete(self, chanid): + self._lock.acquire() + try: + try: + del self._map[chanid] + except KeyError: + pass + finally: + self._lock.release() + + def values(self): + self._lock.acquire() + try: + return self._map.values() + finally: + self._lock.release() + + def __len__(self): + self._lock.acquire() + try: + return len(self._map) + finally: + self._lock.release() + + class Transport (threading.Thread): """ An SSH Transport attaches to a stream (usually a socket), negotiates an @@ -271,7 +316,7 @@ class Transport (threading.Thread): self.lock = threading.Lock() # synchronization (always higher level than write_lock) # tracking open channels - self.channels = weakref.WeakValueDictionary() # (id -> Channel) + self._channels = ChannelMap() self.channel_events = { } # (id -> Event) self.channels_seen = { } # (id -> True) self._channel_counter = 1 @@ -313,10 +358,7 @@ class Transport (threading.Thread): out += ' (cipher %s, %d bits)' % (self.local_cipher, self._cipher_info[self.local_cipher]['key-size'] * 8) if self.is_authenticated(): - if len(self.channels) == 1: - out += ' (active; 1 open channel)' - else: - out += ' (active; %d open channels)' % len(self.channels) + out += ' (active; %d open channel(s))' % len(self._channels) elif self.initial_kex_done: out += ' (connected; awaiting auth)' else: @@ -550,7 +592,7 @@ class Transport (threading.Thread): self.active = False self.packetizer.close() self.join() - for chan in self.channels.values(): + for chan in self._channels.values(): chan._unlink() def get_remote_server_key(self): @@ -667,7 +709,8 @@ class Transport (threading.Thread): elif kind == 'x11': m.add_string(src_addr[0]) m.add_int(src_addr[1]) - self.channels[chanid] = chan = Channel(chanid) + chan = Channel(chanid) + self._channels.put(chanid, chan) self.channel_events[chanid] = event = threading.Event() self.channels_seen[chanid] = True chan._set_transport(self) @@ -684,12 +727,9 @@ class Transport (threading.Thread): raise e if event.isSet(): break - self.lock.acquire() - try: - if chanid in self.channels: - return chan - finally: - self.lock.release() + chan = self._channels.get(chanid) + if chan is not None: + return chan e = self.get_exception() if e is None: e = SSHException('Unable to open channel.') @@ -1334,7 +1374,7 @@ class Transport (threading.Thread): def _next_channel(self): "you are holding the lock" chanid = self._channel_counter - while chanid in self.channels: + while self._channels.get(chanid) is not None: self._channel_counter = (self._channel_counter + 1) & 0xffffff chanid = self._channel_counter self._channel_counter = (self._channel_counter + 1) & 0xffffff @@ -1342,12 +1382,7 @@ class Transport (threading.Thread): def _unlink_channel(self, chanid): "used by a Channel to remove itself from the active channel list" - try: - self.lock.acquire() - if chanid in self.channels: - del self.channels[chanid] - finally: - self.lock.release() + self._channels.delete(chanid) def _send_message(self, data): self.packetizer.send_message(data) @@ -1478,8 +1513,9 @@ class Transport (threading.Thread): self._handler_table[ptype](self, m) elif ptype in self._channel_handler_table: chanid = m.get_int() - if chanid in self.channels: - self._channel_handler_table[ptype](self.channels[chanid], m) + chan = self._channels.get(chanid) + if chan is not None: + self._channel_handler_table[ptype](chan, m) elif chanid in self.channels_seen: self._log(DEBUG, 'Ignoring message for dead channel %d' % chanid) else: @@ -1514,7 +1550,7 @@ class Transport (threading.Thread): self._log(ERROR, util.tb_strings()) self.saved_exception = e _active_threads.remove(self) - for chan in self.channels.values(): + for chan in self._channels.values(): chan._unlink() if self.active: self.active = False @@ -1872,12 +1908,12 @@ class Transport (threading.Thread): server_chanid = m.get_int() server_window_size = m.get_int() server_max_packet_size = m.get_int() - if chanid not in self.channels: + chan = self._channels.get(chanid) + if chan is None: self._log(WARNING, 'Success for unrequested channel! [??]') return self.lock.acquire() try: - chan = self.channels[chanid] chan._set_remote_channel(server_chanid, server_window_size, server_max_packet_size) self._log(INFO, 'Secsh channel %d opened.' % chanid) if chanid in self.channel_events: @@ -1898,7 +1934,7 @@ class Transport (threading.Thread): try: self.saved_exception = ChannelException(reason, reason_text) if chanid in self.channel_events: - del self.channels[chanid] + self._channels.delete(chanid) if chanid in self.channel_events: self.channel_events[chanid].set() del self.channel_events[chanid] @@ -1967,9 +2003,9 @@ class Transport (threading.Thread): return chan = Channel(my_chanid) + self.lock.acquire() try: - self.lock.acquire() - self.channels[my_chanid] = chan + self._channels.put(my_chanid, chan) self.channels_seen[my_chanid] = True chan._set_transport(self) chan._set_window(self.window_size, self.max_packet_size)