clean up use of expected_packet and make it accept a tuple of packet types
This commit is contained in:
parent
55a52a09cc
commit
4737e44e40
|
@ -267,7 +267,7 @@ class Transport (threading.Thread):
|
|||
self.initial_kex_done = False
|
||||
self.in_kex = False
|
||||
self.authenticated = False
|
||||
self.expected_packet = 0
|
||||
self._expected_packet = tuple()
|
||||
self.lock = threading.Lock() # synchronization (always higher level than write_lock)
|
||||
|
||||
# tracking open channels
|
||||
|
@ -1275,9 +1275,9 @@ class Transport (threading.Thread):
|
|||
if self.session_id == None:
|
||||
self.session_id = h
|
||||
|
||||
def _expect_packet(self, type):
|
||||
def _expect_packet(self, *ptypes):
|
||||
"used by a kex object to register the next packet type it expects to see"
|
||||
self.expected_packet = type
|
||||
self._expected_packet = tuple(ptypes)
|
||||
|
||||
def _verify_key(self, host_key, sig):
|
||||
key = self._key_info[self.host_key_type](Message(host_key))
|
||||
|
@ -1326,7 +1326,7 @@ class Transport (threading.Thread):
|
|||
self.packetizer.write_all(self.local_version + '\r\n')
|
||||
self._check_banner()
|
||||
self._send_kex_init()
|
||||
self.expected_packet = MSG_KEXINIT
|
||||
self._expect_packet(MSG_KEXINIT)
|
||||
|
||||
while self.active:
|
||||
if self.packetizer.need_rekey() and not self.in_kex:
|
||||
|
@ -1345,10 +1345,10 @@ class Transport (threading.Thread):
|
|||
elif ptype == MSG_DEBUG:
|
||||
self._parse_debug(m)
|
||||
continue
|
||||
if self.expected_packet != 0:
|
||||
if ptype != self.expected_packet:
|
||||
raise SSHException('Expecting packet %d, got %d' % (self.expected_packet, ptype))
|
||||
self.expected_packet = 0
|
||||
if len(self._expected_packet) > 0:
|
||||
if ptype not in self._expected_packet:
|
||||
raise SSHException('Expecting packet from %r, got %d' % (self._expected_packet, ptype))
|
||||
self._expected_packet = tuple()
|
||||
if (ptype >= 30) and (ptype <= 39):
|
||||
self.kex_engine.parse_next(ptype, m)
|
||||
continue
|
||||
|
@ -1651,7 +1651,7 @@ class Transport (threading.Thread):
|
|||
if not self.packetizer.need_rekey():
|
||||
self.in_kex = False
|
||||
# we always expect to receive NEWKEYS now
|
||||
self.expected_packet = MSG_NEWKEYS
|
||||
self._expect_packet(MSG_NEWKEYS)
|
||||
|
||||
def _auth_trigger(self):
|
||||
self.authenticated = True
|
||||
|
|
Loading…
Reference in New Issue