[project @ Arch-1:robey@lag.net--2005-master-shake%paramiko--dev--1--patch-25]

cool optimization from john rochester: use cStringIO in Message (and also fix some unit test bugs revealed by the change)
This commit is contained in:
Robey Pointer 2005-07-07 01:10:57 +00:00
parent 0b093e49b4
commit e3ed1616d1
5 changed files with 52 additions and 60 deletions

View File

@ -20,7 +20,7 @@
Implementation of an SSH2 "message". Implementation of an SSH2 "message".
""" """
import string, types, struct import struct, cStringIO
import util import util
@ -31,16 +31,18 @@ class Message (object):
as I{long}s). This class builds or breaks down such a byte stream. as I{long}s). This class builds or breaks down such a byte stream.
""" """
def __init__(self, content=''): def __init__(self, content=None):
""" """
Create a new SSH2 Message. Create a new SSH2 Message.
@param content: the byte stream to use as the Message content (usually @param content: the byte stream to use as the Message content (passed
passed in only when decomposing a Message). in only when decomposing a Message).
@type content: string @type content: string
""" """
self.packet = content if content != None:
self.idx = 0 self.packet = cStringIO.StringIO(content)
else:
self.packet = cStringIO.StringIO()
def __str__(self): def __str__(self):
""" """
@ -49,7 +51,7 @@ class Message (object):
@return: the contents of this Message. @return: the contents of this Message.
@rtype: string @rtype: string
""" """
return self.packet return self.packet.getvalue()
def __repr__(self): def __repr__(self):
""" """
@ -57,14 +59,14 @@ class Message (object):
@rtype: string @rtype: string
""" """
return 'paramiko.Message(' + repr(self.packet) + ')' return 'paramiko.Message(' + repr(self.packet.getvalue()) + ')'
def rewind(self): def rewind(self):
""" """
Rewind the message to the beginning as if no items had been parsed Rewind the message to the beginning as if no items had been parsed
out of it yet. out of it yet.
""" """
self.idx = 0 self.packet.seek(0)
def get_remainder(self): def get_remainder(self):
""" """
@ -74,7 +76,10 @@ class Message (object):
@return: a string of the bytes not parsed yet. @return: a string of the bytes not parsed yet.
@rtype: string @rtype: string
""" """
return self.packet[self.idx:] position = self.packet.tell()
remainder = self.packet.read()
self.packet.seek(position)
return remainder
def get_so_far(self): def get_so_far(self):
""" """
@ -85,7 +90,9 @@ class Message (object):
@return: a string of the bytes parsed so far. @return: a string of the bytes parsed so far.
@rtype: string @rtype: string
""" """
return self.packet[:self.idx] position = self.packet.tell()
self.rewind()
return self.packet.read(position)
def get_bytes(self, n): def get_bytes(self, n):
""" """
@ -96,10 +103,9 @@ class Message (object):
of C{n} zero bytes, if there aren't C{n} bytes remaining. of C{n} zero bytes, if there aren't C{n} bytes remaining.
@rtype: string @rtype: string
""" """
if self.idx + n > len(self.packet): b = self.packet.read(n)
if len(b) < n:
return '\x00'*n return '\x00'*n
b = self.packet[self.idx:self.idx+n]
self.idx = self.idx + n
return b return b
def get_byte(self): def get_byte(self):
@ -130,13 +136,7 @@ class Message (object):
@return: a 32-bit unsigned integer. @return: a 32-bit unsigned integer.
@rtype: int @rtype: int
""" """
x = self.packet return struct.unpack('>I', self.get_bytes(4))[0]
i = self.idx
if i + 4 > len(x):
return 0
n = struct.unpack('>I', x[i:i+4])[0]
self.idx = i+4
return n
def get_int64(self): def get_int64(self):
""" """
@ -145,13 +145,7 @@ class Message (object):
@return: a 64-bit unsigned integer. @return: a 64-bit unsigned integer.
@rtype: long @rtype: long
""" """
x = self.packet return struct.unpack('>Q', self.get_bytes(8))[0]
i = self.idx
if i + 8 > len(x):
return 0L
n = struct.unpack('>Q', x[i:i+8])[0]
self.idx += 8
return n
def get_mpint(self): def get_mpint(self):
""" """
@ -171,12 +165,7 @@ class Message (object):
@return: a string. @return: a string.
@rtype: string @rtype: string
""" """
l = self.get_int() return self.get_bytes(self.get_int())
if self.idx + l > len(self.packet):
return ''
str = self.packet[self.idx:self.idx+l]
self.idx = self.idx + l
return str
def get_list(self): def get_list(self):
""" """
@ -186,16 +175,14 @@ class Message (object):
@return: a list of strings. @return: a list of strings.
@rtype: list of strings @rtype: list of strings
""" """
str = self.get_string() return self.get_string().split(',')
l = string.split(str, ',')
return l
def add_bytes(self, b): def add_bytes(self, b):
self.packet = self.packet + b self.packet.write(b)
return self return self
def add_byte(self, b): def add_byte(self, b):
self.packet = self.packet + b self.packet.write(b)
return self return self
def add_boolean(self, b): def add_boolean(self, b):
@ -206,7 +193,7 @@ class Message (object):
return self return self
def add_int(self, n): def add_int(self, n):
self.packet = self.packet + struct.pack('>I', n) self.packet.write(struct.pack('>I', n))
return self return self
def add_int64(self, n): def add_int64(self, n):
@ -216,7 +203,7 @@ class Message (object):
@param n: long int to add. @param n: long int to add.
@type n: long @type n: long
""" """
self.packet = self.packet + struct.pack('>Q', n) self.packet.write(struct.pack('>Q', n))
return self return self
def add_mpint(self, z): def add_mpint(self, z):
@ -226,13 +213,11 @@ class Message (object):
def add_string(self, s): def add_string(self, s):
self.add_int(len(s)) self.add_int(len(s))
self.packet = self.packet + s self.packet.write(s)
return self return self
def add_list(self, l): def add_list(self, l):
out = string.join(l, ',') self.add_string(','.join(l))
self.add_int(len(out))
self.packet = self.packet + out
return self return self
def _add(self, i): def _add(self, i):

View File

@ -112,7 +112,6 @@ class SFTPAttributes (object):
count = msg.get_int() count = msg.get_int()
for i in range(count): for i in range(count):
self.attr[msg.get_string()] = msg.get_string() self.attr[msg.get_string()] = msg.get_string()
return msg.get_remainder()
def _pack(self, msg): def _pack(self, msg):
self._flags = 0 self._flags = 0

View File

@ -639,8 +639,7 @@ class BaseTransport (threading.Thread):
m.add_string(kind) m.add_string(kind)
m.add_boolean(wait) m.add_boolean(wait)
if data is not None: if data is not None:
for item in data: m.add(*data)
m.add(item)
self._log(DEBUG, 'Sending global request "%s"' % kind) self._log(DEBUG, 'Sending global request "%s"' % kind)
self._send_user_message(m) self._send_user_message(m)
if not wait: if not wait:
@ -1085,16 +1084,16 @@ class BaseTransport (threading.Thread):
m = Message() m = Message()
m.add_byte(chr(MSG_KEXINIT)) m.add_byte(chr(MSG_KEXINIT))
m.add_bytes(randpool.get_bytes(16)) m.add_bytes(randpool.get_bytes(16))
m.add(','.join(self._preferred_kex)) m.add_list(self._preferred_kex)
m.add(','.join(available_server_keys)) m.add_list(available_server_keys)
m.add(','.join(self._preferred_ciphers)) m.add_list(self._preferred_ciphers)
m.add(','.join(self._preferred_ciphers)) m.add_list(self._preferred_ciphers)
m.add(','.join(self._preferred_macs)) m.add_list(self._preferred_macs)
m.add(','.join(self._preferred_macs)) m.add_list(self._preferred_macs)
m.add('none') m.add_string('none')
m.add('none') m.add_string('none')
m.add('') m.add_string('')
m.add('') m.add_string('')
m.add_boolean(False) m.add_boolean(False)
m.add_int(0) m.add_int(0)
# save a copy for later (needed to compute a hash) # save a copy for later (needed to compute a hash)
@ -1274,8 +1273,7 @@ class BaseTransport (threading.Thread):
msg = Message() msg = Message()
if ok: if ok:
msg.add_byte(chr(MSG_REQUEST_SUCCESS)) msg.add_byte(chr(MSG_REQUEST_SUCCESS))
for item in extra: msg.add(*extra)
msg.add(item)
else: else:
msg.add_byte(chr(MSG_REQUEST_FAILURE)) msg.add_byte(chr(MSG_REQUEST_FAILURE))
self._send_message(msg) self._send_message(msg)

View File

@ -97,6 +97,7 @@ class KexTest (unittest.TestCase):
msg.add_string('fake-host-key') msg.add_string('fake-host-key')
msg.add_mpint(69) msg.add_mpint(69)
msg.add_string('fake-sig') msg.add_string('fake-sig')
msg.rewind()
kex.parse_next(paramiko.kex_group1._MSG_KEXDH_REPLY, msg) kex.parse_next(paramiko.kex_group1._MSG_KEXDH_REPLY, msg)
H = '03079780F3D3AD0B3C6DB30C8D21685F367A86D2' H = '03079780F3D3AD0B3C6DB30C8D21685F367A86D2'
self.assertEquals(self.K, transport._K) self.assertEquals(self.K, transport._K)
@ -113,6 +114,7 @@ class KexTest (unittest.TestCase):
msg = Message() msg = Message()
msg.add_mpint(69) msg.add_mpint(69)
msg.rewind()
kex.parse_next(paramiko.kex_group1._MSG_KEXDH_INIT, msg) kex.parse_next(paramiko.kex_group1._MSG_KEXDH_INIT, msg)
H = 'B16BF34DD10945EDE84E9C1EF24A14BFDC843389' H = 'B16BF34DD10945EDE84E9C1EF24A14BFDC843389'
x = '1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' x = '1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
@ -133,6 +135,7 @@ class KexTest (unittest.TestCase):
msg = Message() msg = Message()
msg.add_mpint(FakeModulusPack.P) msg.add_mpint(FakeModulusPack.P)
msg.add_mpint(FakeModulusPack.G) msg.add_mpint(FakeModulusPack.G)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg)
x = '20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' x = '20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
self.assertEquals(x, paramiko.util.hexify(str(transport._message))) self.assertEquals(x, paramiko.util.hexify(str(transport._message)))
@ -142,6 +145,7 @@ class KexTest (unittest.TestCase):
msg.add_string('fake-host-key') msg.add_string('fake-host-key')
msg.add_mpint(69) msg.add_mpint(69)
msg.add_string('fake-sig') msg.add_string('fake-sig')
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg)
H = 'A265563F2FA87F1A89BF007EE90D58BE2E4A4BD0' H = 'A265563F2FA87F1A89BF007EE90D58BE2E4A4BD0'
self.assertEquals(self.K, transport._K) self.assertEquals(self.K, transport._K)
@ -160,6 +164,7 @@ class KexTest (unittest.TestCase):
msg.add_int(1024) msg.add_int(1024)
msg.add_int(2048) msg.add_int(2048)
msg.add_int(4096) msg.add_int(4096)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, msg)
x = '1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102' x = '1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102'
self.assertEquals(x, paramiko.util.hexify(str(transport._message))) self.assertEquals(x, paramiko.util.hexify(str(transport._message)))
@ -167,6 +172,7 @@ class KexTest (unittest.TestCase):
msg = Message() msg = Message()
msg.add_mpint(12345) msg.add_mpint(12345)
msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg)
K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581L K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581L
H = 'CE754197C21BF3452863B4F44D0B3951F12516EF' H = 'CE754197C21BF3452863B4F44D0B3951F12516EF'

View File

@ -104,6 +104,7 @@ class KeyTest (unittest.TestCase):
key = RSAKey.from_private_key_file('tests/test_rsa.key') key = RSAKey.from_private_key_file('tests/test_rsa.key')
msg = key.sign_ssh_data(randpool, 'ice weasels') msg = key.sign_ssh_data(randpool, 'ice weasels')
self.assert_(type(msg) is Message) self.assert_(type(msg) is Message)
msg.rewind()
self.assertEquals('ssh-rsa', msg.get_string()) self.assertEquals('ssh-rsa', msg.get_string())
sig = ''.join([chr(int(x, 16)) for x in SIGNED_RSA.split(':')]) sig = ''.join([chr(int(x, 16)) for x in SIGNED_RSA.split(':')])
self.assertEquals(sig, msg.get_string()) self.assertEquals(sig, msg.get_string())
@ -116,6 +117,7 @@ class KeyTest (unittest.TestCase):
key = DSSKey.from_private_key_file('tests/test_dss.key') key = DSSKey.from_private_key_file('tests/test_dss.key')
msg = key.sign_ssh_data(randpool, 'ice weasels') msg = key.sign_ssh_data(randpool, 'ice weasels')
self.assert_(type(msg) is Message) self.assert_(type(msg) is Message)
msg.rewind()
self.assertEquals('ssh-dss', msg.get_string()) self.assertEquals('ssh-dss', msg.get_string())
# can't do the same test as we do for RSA, because DSS signatures # can't do the same test as we do for RSA, because DSS signatures
# are usually different each time. but we can test verification # are usually different each time. but we can test verification
@ -128,9 +130,11 @@ class KeyTest (unittest.TestCase):
def test_A_generate_rsa(self): def test_A_generate_rsa(self):
key = RSAKey.generate(1024) key = RSAKey.generate(1024)
msg = key.sign_ssh_data(randpool, 'jerri blank') msg = key.sign_ssh_data(randpool, 'jerri blank')
msg.rewind()
self.assert_(key.verify_ssh_sig('jerri blank', msg)) self.assert_(key.verify_ssh_sig('jerri blank', msg))
def test_B_generate_dss(self): def test_B_generate_dss(self):
key = DSSKey.generate(1024) key = DSSKey.generate(1024)
msg = key.sign_ssh_data(randpool, 'jerri blank') msg = key.sign_ssh_data(randpool, 'jerri blank')
msg.rewind()
self.assert_(key.verify_ssh_sig('jerri blank', msg)) self.assert_(key.verify_ssh_sig('jerri blank', msg))