diff --git a/paramiko/transport.py b/paramiko/transport.py index 05d3087..df87151 100644 --- a/paramiko/transport.py +++ b/paramiko/transport.py @@ -267,7 +267,7 @@ class Transport (threading.Thread): self.initial_kex_done = False self.in_kex = False self.authenticated = False - self.expected_packet = 0 + self._expected_packet = tuple() self.lock = threading.Lock() # synchronization (always higher level than write_lock) # tracking open channels @@ -1275,9 +1275,9 @@ class Transport (threading.Thread): if self.session_id == None: self.session_id = h - def _expect_packet(self, type): + def _expect_packet(self, *ptypes): "used by a kex object to register the next packet type it expects to see" - self.expected_packet = type + self._expected_packet = tuple(ptypes) def _verify_key(self, host_key, sig): key = self._key_info[self.host_key_type](Message(host_key)) @@ -1326,7 +1326,7 @@ class Transport (threading.Thread): self.packetizer.write_all(self.local_version + '\r\n') self._check_banner() self._send_kex_init() - self.expected_packet = MSG_KEXINIT + self._expect_packet(MSG_KEXINIT) while self.active: if self.packetizer.need_rekey() and not self.in_kex: @@ -1345,10 +1345,10 @@ class Transport (threading.Thread): elif ptype == MSG_DEBUG: self._parse_debug(m) continue - if self.expected_packet != 0: - if ptype != self.expected_packet: - raise SSHException('Expecting packet %d, got %d' % (self.expected_packet, ptype)) - self.expected_packet = 0 + if len(self._expected_packet) > 0: + if ptype not in self._expected_packet: + raise SSHException('Expecting packet from %r, got %d' % (self._expected_packet, ptype)) + self._expected_packet = tuple() if (ptype >= 30) and (ptype <= 39): self.kex_engine.parse_next(ptype, m) continue @@ -1651,7 +1651,7 @@ class Transport (threading.Thread): if not self.packetizer.need_rekey(): self.in_kex = False # we always expect to receive NEWKEYS now - self.expected_packet = MSG_NEWKEYS + self._expect_packet(MSG_NEWKEYS) def _auth_trigger(self): self.authenticated = True