clean up use of expected_packet and make it accept a tuple of packet types
This commit is contained in:
Robey Pointer 2006-07-23 16:55:48 -07:00
parent 55a52a09cc
commit 4737e44e40
1 changed files with 9 additions and 9 deletions

View File

@ -267,7 +267,7 @@ class Transport (threading.Thread):
self.initial_kex_done = False self.initial_kex_done = False
self.in_kex = False self.in_kex = False
self.authenticated = False self.authenticated = False
self.expected_packet = 0 self._expected_packet = tuple()
self.lock = threading.Lock() # synchronization (always higher level than write_lock) self.lock = threading.Lock() # synchronization (always higher level than write_lock)
# tracking open channels # tracking open channels
@ -1275,9 +1275,9 @@ class Transport (threading.Thread):
if self.session_id == None: if self.session_id == None:
self.session_id = h self.session_id = h
def _expect_packet(self, type): def _expect_packet(self, *ptypes):
"used by a kex object to register the next packet type it expects to see" "used by a kex object to register the next packet type it expects to see"
self.expected_packet = type self._expected_packet = tuple(ptypes)
def _verify_key(self, host_key, sig): def _verify_key(self, host_key, sig):
key = self._key_info[self.host_key_type](Message(host_key)) key = self._key_info[self.host_key_type](Message(host_key))
@ -1326,7 +1326,7 @@ class Transport (threading.Thread):
self.packetizer.write_all(self.local_version + '\r\n') self.packetizer.write_all(self.local_version + '\r\n')
self._check_banner() self._check_banner()
self._send_kex_init() self._send_kex_init()
self.expected_packet = MSG_KEXINIT self._expect_packet(MSG_KEXINIT)
while self.active: while self.active:
if self.packetizer.need_rekey() and not self.in_kex: if self.packetizer.need_rekey() and not self.in_kex:
@ -1345,10 +1345,10 @@ class Transport (threading.Thread):
elif ptype == MSG_DEBUG: elif ptype == MSG_DEBUG:
self._parse_debug(m) self._parse_debug(m)
continue continue
if self.expected_packet != 0: if len(self._expected_packet) > 0:
if ptype != self.expected_packet: if ptype not in self._expected_packet:
raise SSHException('Expecting packet %d, got %d' % (self.expected_packet, ptype)) raise SSHException('Expecting packet from %r, got %d' % (self._expected_packet, ptype))
self.expected_packet = 0 self._expected_packet = tuple()
if (ptype >= 30) and (ptype <= 39): if (ptype >= 30) and (ptype <= 39):
self.kex_engine.parse_next(ptype, m) self.kex_engine.parse_next(ptype, m)
continue continue
@ -1651,7 +1651,7 @@ class Transport (threading.Thread):
if not self.packetizer.need_rekey(): if not self.packetizer.need_rekey():
self.in_kex = False self.in_kex = False
# we always expect to receive NEWKEYS now # we always expect to receive NEWKEYS now
self.expected_packet = MSG_NEWKEYS self._expect_packet(MSG_NEWKEYS)
def _auth_trigger(self): def _auth_trigger(self):
self.authenticated = True self.authenticated = True