clean up use of expected_packet and make it accept a tuple of packet types
This commit is contained in:
parent
55a52a09cc
commit
4737e44e40
|
@ -267,7 +267,7 @@ class Transport (threading.Thread):
|
||||||
self.initial_kex_done = False
|
self.initial_kex_done = False
|
||||||
self.in_kex = False
|
self.in_kex = False
|
||||||
self.authenticated = False
|
self.authenticated = False
|
||||||
self.expected_packet = 0
|
self._expected_packet = tuple()
|
||||||
self.lock = threading.Lock() # synchronization (always higher level than write_lock)
|
self.lock = threading.Lock() # synchronization (always higher level than write_lock)
|
||||||
|
|
||||||
# tracking open channels
|
# tracking open channels
|
||||||
|
@ -1275,9 +1275,9 @@ class Transport (threading.Thread):
|
||||||
if self.session_id == None:
|
if self.session_id == None:
|
||||||
self.session_id = h
|
self.session_id = h
|
||||||
|
|
||||||
def _expect_packet(self, type):
|
def _expect_packet(self, *ptypes):
|
||||||
"used by a kex object to register the next packet type it expects to see"
|
"used by a kex object to register the next packet type it expects to see"
|
||||||
self.expected_packet = type
|
self._expected_packet = tuple(ptypes)
|
||||||
|
|
||||||
def _verify_key(self, host_key, sig):
|
def _verify_key(self, host_key, sig):
|
||||||
key = self._key_info[self.host_key_type](Message(host_key))
|
key = self._key_info[self.host_key_type](Message(host_key))
|
||||||
|
@ -1326,7 +1326,7 @@ class Transport (threading.Thread):
|
||||||
self.packetizer.write_all(self.local_version + '\r\n')
|
self.packetizer.write_all(self.local_version + '\r\n')
|
||||||
self._check_banner()
|
self._check_banner()
|
||||||
self._send_kex_init()
|
self._send_kex_init()
|
||||||
self.expected_packet = MSG_KEXINIT
|
self._expect_packet(MSG_KEXINIT)
|
||||||
|
|
||||||
while self.active:
|
while self.active:
|
||||||
if self.packetizer.need_rekey() and not self.in_kex:
|
if self.packetizer.need_rekey() and not self.in_kex:
|
||||||
|
@ -1345,10 +1345,10 @@ class Transport (threading.Thread):
|
||||||
elif ptype == MSG_DEBUG:
|
elif ptype == MSG_DEBUG:
|
||||||
self._parse_debug(m)
|
self._parse_debug(m)
|
||||||
continue
|
continue
|
||||||
if self.expected_packet != 0:
|
if len(self._expected_packet) > 0:
|
||||||
if ptype != self.expected_packet:
|
if ptype not in self._expected_packet:
|
||||||
raise SSHException('Expecting packet %d, got %d' % (self.expected_packet, ptype))
|
raise SSHException('Expecting packet from %r, got %d' % (self._expected_packet, ptype))
|
||||||
self.expected_packet = 0
|
self._expected_packet = tuple()
|
||||||
if (ptype >= 30) and (ptype <= 39):
|
if (ptype >= 30) and (ptype <= 39):
|
||||||
self.kex_engine.parse_next(ptype, m)
|
self.kex_engine.parse_next(ptype, m)
|
||||||
continue
|
continue
|
||||||
|
@ -1651,7 +1651,7 @@ class Transport (threading.Thread):
|
||||||
if not self.packetizer.need_rekey():
|
if not self.packetizer.need_rekey():
|
||||||
self.in_kex = False
|
self.in_kex = False
|
||||||
# we always expect to receive NEWKEYS now
|
# we always expect to receive NEWKEYS now
|
||||||
self.expected_packet = MSG_NEWKEYS
|
self._expect_packet(MSG_NEWKEYS)
|
||||||
|
|
||||||
def _auth_trigger(self):
|
def _auth_trigger(self):
|
||||||
self.authenticated = True
|
self.authenticated = True
|
||||||
|
|
Loading…
Reference in New Issue