[project @ Arch-1:robey@lag.net--2005-master-shake%paramiko--dev--1--patch-25]
cool optimization from john rochester: use cStringIO in Message (and also fix some unit test bugs revealed by the change)
This commit is contained in:
parent
0b093e49b4
commit
e3ed1616d1
|
@ -20,7 +20,7 @@
|
|||
Implementation of an SSH2 "message".
|
||||
"""
|
||||
|
||||
import string, types, struct
|
||||
import struct, cStringIO
|
||||
import util
|
||||
|
||||
|
||||
|
@ -31,16 +31,18 @@ class Message (object):
|
|||
as I{long}s). This class builds or breaks down such a byte stream.
|
||||
"""
|
||||
|
||||
def __init__(self, content=''):
|
||||
def __init__(self, content=None):
|
||||
"""
|
||||
Create a new SSH2 Message.
|
||||
|
||||
@param content: the byte stream to use as the Message content (usually
|
||||
passed in only when decomposing a Message).
|
||||
@param content: the byte stream to use as the Message content (passed
|
||||
in only when decomposing a Message).
|
||||
@type content: string
|
||||
"""
|
||||
self.packet = content
|
||||
self.idx = 0
|
||||
if content != None:
|
||||
self.packet = cStringIO.StringIO(content)
|
||||
else:
|
||||
self.packet = cStringIO.StringIO()
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
|
@ -49,7 +51,7 @@ class Message (object):
|
|||
@return: the contents of this Message.
|
||||
@rtype: string
|
||||
"""
|
||||
return self.packet
|
||||
return self.packet.getvalue()
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
|
@ -57,14 +59,14 @@ class Message (object):
|
|||
|
||||
@rtype: string
|
||||
"""
|
||||
return 'paramiko.Message(' + repr(self.packet) + ')'
|
||||
return 'paramiko.Message(' + repr(self.packet.getvalue()) + ')'
|
||||
|
||||
def rewind(self):
|
||||
"""
|
||||
Rewind the message to the beginning as if no items had been parsed
|
||||
out of it yet.
|
||||
"""
|
||||
self.idx = 0
|
||||
self.packet.seek(0)
|
||||
|
||||
def get_remainder(self):
|
||||
"""
|
||||
|
@ -74,7 +76,10 @@ class Message (object):
|
|||
@return: a string of the bytes not parsed yet.
|
||||
@rtype: string
|
||||
"""
|
||||
return self.packet[self.idx:]
|
||||
position = self.packet.tell()
|
||||
remainder = self.packet.read()
|
||||
self.packet.seek(position)
|
||||
return remainder
|
||||
|
||||
def get_so_far(self):
|
||||
"""
|
||||
|
@ -85,7 +90,9 @@ class Message (object):
|
|||
@return: a string of the bytes parsed so far.
|
||||
@rtype: string
|
||||
"""
|
||||
return self.packet[:self.idx]
|
||||
position = self.packet.tell()
|
||||
self.rewind()
|
||||
return self.packet.read(position)
|
||||
|
||||
def get_bytes(self, n):
|
||||
"""
|
||||
|
@ -96,10 +103,9 @@ class Message (object):
|
|||
of C{n} zero bytes, if there aren't C{n} bytes remaining.
|
||||
@rtype: string
|
||||
"""
|
||||
if self.idx + n > len(self.packet):
|
||||
b = self.packet.read(n)
|
||||
if len(b) < n:
|
||||
return '\x00'*n
|
||||
b = self.packet[self.idx:self.idx+n]
|
||||
self.idx = self.idx + n
|
||||
return b
|
||||
|
||||
def get_byte(self):
|
||||
|
@ -130,13 +136,7 @@ class Message (object):
|
|||
@return: a 32-bit unsigned integer.
|
||||
@rtype: int
|
||||
"""
|
||||
x = self.packet
|
||||
i = self.idx
|
||||
if i + 4 > len(x):
|
||||
return 0
|
||||
n = struct.unpack('>I', x[i:i+4])[0]
|
||||
self.idx = i+4
|
||||
return n
|
||||
return struct.unpack('>I', self.get_bytes(4))[0]
|
||||
|
||||
def get_int64(self):
|
||||
"""
|
||||
|
@ -145,13 +145,7 @@ class Message (object):
|
|||
@return: a 64-bit unsigned integer.
|
||||
@rtype: long
|
||||
"""
|
||||
x = self.packet
|
||||
i = self.idx
|
||||
if i + 8 > len(x):
|
||||
return 0L
|
||||
n = struct.unpack('>Q', x[i:i+8])[0]
|
||||
self.idx += 8
|
||||
return n
|
||||
return struct.unpack('>Q', self.get_bytes(8))[0]
|
||||
|
||||
def get_mpint(self):
|
||||
"""
|
||||
|
@ -171,12 +165,7 @@ class Message (object):
|
|||
@return: a string.
|
||||
@rtype: string
|
||||
"""
|
||||
l = self.get_int()
|
||||
if self.idx + l > len(self.packet):
|
||||
return ''
|
||||
str = self.packet[self.idx:self.idx+l]
|
||||
self.idx = self.idx + l
|
||||
return str
|
||||
return self.get_bytes(self.get_int())
|
||||
|
||||
def get_list(self):
|
||||
"""
|
||||
|
@ -186,16 +175,14 @@ class Message (object):
|
|||
@return: a list of strings.
|
||||
@rtype: list of strings
|
||||
"""
|
||||
str = self.get_string()
|
||||
l = string.split(str, ',')
|
||||
return l
|
||||
return self.get_string().split(',')
|
||||
|
||||
def add_bytes(self, b):
|
||||
self.packet = self.packet + b
|
||||
self.packet.write(b)
|
||||
return self
|
||||
|
||||
def add_byte(self, b):
|
||||
self.packet = self.packet + b
|
||||
self.packet.write(b)
|
||||
return self
|
||||
|
||||
def add_boolean(self, b):
|
||||
|
@ -206,7 +193,7 @@ class Message (object):
|
|||
return self
|
||||
|
||||
def add_int(self, n):
|
||||
self.packet = self.packet + struct.pack('>I', n)
|
||||
self.packet.write(struct.pack('>I', n))
|
||||
return self
|
||||
|
||||
def add_int64(self, n):
|
||||
|
@ -216,7 +203,7 @@ class Message (object):
|
|||
@param n: long int to add.
|
||||
@type n: long
|
||||
"""
|
||||
self.packet = self.packet + struct.pack('>Q', n)
|
||||
self.packet.write(struct.pack('>Q', n))
|
||||
return self
|
||||
|
||||
def add_mpint(self, z):
|
||||
|
@ -226,13 +213,11 @@ class Message (object):
|
|||
|
||||
def add_string(self, s):
|
||||
self.add_int(len(s))
|
||||
self.packet = self.packet + s
|
||||
self.packet.write(s)
|
||||
return self
|
||||
|
||||
def add_list(self, l):
|
||||
out = string.join(l, ',')
|
||||
self.add_int(len(out))
|
||||
self.packet = self.packet + out
|
||||
self.add_string(','.join(l))
|
||||
return self
|
||||
|
||||
def _add(self, i):
|
||||
|
|
|
@ -112,7 +112,6 @@ class SFTPAttributes (object):
|
|||
count = msg.get_int()
|
||||
for i in range(count):
|
||||
self.attr[msg.get_string()] = msg.get_string()
|
||||
return msg.get_remainder()
|
||||
|
||||
def _pack(self, msg):
|
||||
self._flags = 0
|
||||
|
|
|
@ -639,8 +639,7 @@ class BaseTransport (threading.Thread):
|
|||
m.add_string(kind)
|
||||
m.add_boolean(wait)
|
||||
if data is not None:
|
||||
for item in data:
|
||||
m.add(item)
|
||||
m.add(*data)
|
||||
self._log(DEBUG, 'Sending global request "%s"' % kind)
|
||||
self._send_user_message(m)
|
||||
if not wait:
|
||||
|
@ -1085,16 +1084,16 @@ class BaseTransport (threading.Thread):
|
|||
m = Message()
|
||||
m.add_byte(chr(MSG_KEXINIT))
|
||||
m.add_bytes(randpool.get_bytes(16))
|
||||
m.add(','.join(self._preferred_kex))
|
||||
m.add(','.join(available_server_keys))
|
||||
m.add(','.join(self._preferred_ciphers))
|
||||
m.add(','.join(self._preferred_ciphers))
|
||||
m.add(','.join(self._preferred_macs))
|
||||
m.add(','.join(self._preferred_macs))
|
||||
m.add('none')
|
||||
m.add('none')
|
||||
m.add('')
|
||||
m.add('')
|
||||
m.add_list(self._preferred_kex)
|
||||
m.add_list(available_server_keys)
|
||||
m.add_list(self._preferred_ciphers)
|
||||
m.add_list(self._preferred_ciphers)
|
||||
m.add_list(self._preferred_macs)
|
||||
m.add_list(self._preferred_macs)
|
||||
m.add_string('none')
|
||||
m.add_string('none')
|
||||
m.add_string('')
|
||||
m.add_string('')
|
||||
m.add_boolean(False)
|
||||
m.add_int(0)
|
||||
# save a copy for later (needed to compute a hash)
|
||||
|
@ -1274,8 +1273,7 @@ class BaseTransport (threading.Thread):
|
|||
msg = Message()
|
||||
if ok:
|
||||
msg.add_byte(chr(MSG_REQUEST_SUCCESS))
|
||||
for item in extra:
|
||||
msg.add(item)
|
||||
msg.add(*extra)
|
||||
else:
|
||||
msg.add_byte(chr(MSG_REQUEST_FAILURE))
|
||||
self._send_message(msg)
|
||||
|
|
|
@ -97,6 +97,7 @@ class KexTest (unittest.TestCase):
|
|||
msg.add_string('fake-host-key')
|
||||
msg.add_mpint(69)
|
||||
msg.add_string('fake-sig')
|
||||
msg.rewind()
|
||||
kex.parse_next(paramiko.kex_group1._MSG_KEXDH_REPLY, msg)
|
||||
H = '03079780F3D3AD0B3C6DB30C8D21685F367A86D2'
|
||||
self.assertEquals(self.K, transport._K)
|
||||
|
@ -113,6 +114,7 @@ class KexTest (unittest.TestCase):
|
|||
|
||||
msg = Message()
|
||||
msg.add_mpint(69)
|
||||
msg.rewind()
|
||||
kex.parse_next(paramiko.kex_group1._MSG_KEXDH_INIT, msg)
|
||||
H = 'B16BF34DD10945EDE84E9C1EF24A14BFDC843389'
|
||||
x = '1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
|
||||
|
@ -133,6 +135,7 @@ class KexTest (unittest.TestCase):
|
|||
msg = Message()
|
||||
msg.add_mpint(FakeModulusPack.P)
|
||||
msg.add_mpint(FakeModulusPack.G)
|
||||
msg.rewind()
|
||||
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg)
|
||||
x = '20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
|
||||
self.assertEquals(x, paramiko.util.hexify(str(transport._message)))
|
||||
|
@ -142,6 +145,7 @@ class KexTest (unittest.TestCase):
|
|||
msg.add_string('fake-host-key')
|
||||
msg.add_mpint(69)
|
||||
msg.add_string('fake-sig')
|
||||
msg.rewind()
|
||||
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg)
|
||||
H = 'A265563F2FA87F1A89BF007EE90D58BE2E4A4BD0'
|
||||
self.assertEquals(self.K, transport._K)
|
||||
|
@ -160,6 +164,7 @@ class KexTest (unittest.TestCase):
|
|||
msg.add_int(1024)
|
||||
msg.add_int(2048)
|
||||
msg.add_int(4096)
|
||||
msg.rewind()
|
||||
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, msg)
|
||||
x = '1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102'
|
||||
self.assertEquals(x, paramiko.util.hexify(str(transport._message)))
|
||||
|
@ -167,6 +172,7 @@ class KexTest (unittest.TestCase):
|
|||
|
||||
msg = Message()
|
||||
msg.add_mpint(12345)
|
||||
msg.rewind()
|
||||
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg)
|
||||
K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581L
|
||||
H = 'CE754197C21BF3452863B4F44D0B3951F12516EF'
|
||||
|
|
|
@ -104,6 +104,7 @@ class KeyTest (unittest.TestCase):
|
|||
key = RSAKey.from_private_key_file('tests/test_rsa.key')
|
||||
msg = key.sign_ssh_data(randpool, 'ice weasels')
|
||||
self.assert_(type(msg) is Message)
|
||||
msg.rewind()
|
||||
self.assertEquals('ssh-rsa', msg.get_string())
|
||||
sig = ''.join([chr(int(x, 16)) for x in SIGNED_RSA.split(':')])
|
||||
self.assertEquals(sig, msg.get_string())
|
||||
|
@ -116,6 +117,7 @@ class KeyTest (unittest.TestCase):
|
|||
key = DSSKey.from_private_key_file('tests/test_dss.key')
|
||||
msg = key.sign_ssh_data(randpool, 'ice weasels')
|
||||
self.assert_(type(msg) is Message)
|
||||
msg.rewind()
|
||||
self.assertEquals('ssh-dss', msg.get_string())
|
||||
# can't do the same test as we do for RSA, because DSS signatures
|
||||
# are usually different each time. but we can test verification
|
||||
|
@ -128,9 +130,11 @@ class KeyTest (unittest.TestCase):
|
|||
def test_A_generate_rsa(self):
|
||||
key = RSAKey.generate(1024)
|
||||
msg = key.sign_ssh_data(randpool, 'jerri blank')
|
||||
msg.rewind()
|
||||
self.assert_(key.verify_ssh_sig('jerri blank', msg))
|
||||
|
||||
def test_B_generate_dss(self):
|
||||
key = DSSKey.generate(1024)
|
||||
msg = key.sign_ssh_data(randpool, 'jerri blank')
|
||||
msg.rewind()
|
||||
self.assert_(key.verify_ssh_sig('jerri blank', msg))
|
||||
|
|
Loading…
Reference in New Issue