Fix bytes/str type in more places

This commit is contained in:
Scott Maxwell 2013-10-31 10:01:21 -07:00
parent e4e1dc2002
commit fcf56ff9f8
13 changed files with 147 additions and 106 deletions

View File

@ -25,6 +25,7 @@ read operations are blocking and can have a timeout set.
import array import array
import threading import threading
import time import time
from paramiko.common import *
class PipeTimeout (IOError): class PipeTimeout (IOError):
@ -121,7 +122,7 @@ class BufferedPipe (object):
@raise PipeTimeout: if a timeout was specified and no data was ready @raise PipeTimeout: if a timeout was specified and no data was ready
before that timeout before that timeout
""" """
out = '' out = bytes()
self._lock.acquire() self._lock.acquire()
try: try:
if len(self._buffer) == 0: if len(self._buffer) == 0:

View File

@ -63,7 +63,7 @@ class ECDSAKey (PKey):
raise SSHException("Can't handle curve of type %s" % curvename) raise SSHException("Can't handle curve of type %s" % curvename)
pointinfo = msg.get_binary() pointinfo = msg.get_binary()
if pointinfo[0] != four_byte: if pointinfo[0:1] != four_byte:
raise SSHException('Point compression is being used: %s' % raise SSHException('Point compression is being used: %s' %
binascii.hexlify(pointinfo)) binascii.hexlify(pointinfo))
self.verifying_key = VerifyingKey.from_string(pointinfo[1:], self.verifying_key = VerifyingKey.from_string(pointinfo[1:],
@ -157,6 +157,11 @@ class ECDSAKey (PKey):
data = self._read_private_key('EC', file_obj, password) data = self._read_private_key('EC', file_obj, password)
self._decode_key(data) self._decode_key(data)
if PY3:
ALLOWED_PADDINGS = [b'\x01', b'\x02\x02', b'\x03\x03\x03', b'\x04\x04\x04\x04',
b'\x05\x05\x05\x05\x05', b'\x06\x06\x06\x06\x06\x06',
b'\x07\x07\x07\x07\x07\x07\x07']
else:
ALLOWED_PADDINGS = ['\x01', '\x02\x02', '\x03\x03\x03', '\x04\x04\x04\x04', ALLOWED_PADDINGS = ['\x01', '\x02\x02', '\x03\x03\x03', '\x04\x04\x04\x04',
'\x05\x05\x05\x05\x05', '\x06\x06\x06\x06\x06\x06', '\x05\x05\x05\x05\x05', '\x06\x06\x06\x06\x06\x06',
'\x07\x07\x07\x07\x07\x07\x07'] '\x07\x07\x07\x07\x07\x07\x07']
@ -164,7 +169,7 @@ class ECDSAKey (PKey):
s, padding = der.remove_sequence(data) s, padding = der.remove_sequence(data)
if padding: if padding:
if padding not in self.ALLOWED_PADDINGS: if padding not in self.ALLOWED_PADDINGS:
raise ValueError("weird padding: %s" % (binascii.hexlify(empty))) raise ValueError("weird padding: %s" % (binascii.hexlify(data)))
data = data[:-len(padding)] data = data[:-len(padding)]
key = SigningKey.from_der(data) key = SigningKey.from_der(data)
self.signing_key = key self.signing_key = key

View File

@ -81,6 +81,7 @@ class HostKeyEntry:
# Decide what kind of key we're looking at and create an object # Decide what kind of key we're looking at and create an object
# to hold it accordingly. # to hold it accordingly.
try: try:
key = b(key)
if keytype == 'ssh-rsa': if keytype == 'ssh-rsa':
key = RSAKey(data=base64.decodestring(key)) key = RSAKey(data=base64.decodestring(key))
elif keytype == 'ssh-dss': elif keytype == 'ssh-dss':
@ -361,9 +362,9 @@ class HostKeys (MutableMapping):
else: else:
if salt.startswith('|1|'): if salt.startswith('|1|'):
salt = salt.split('|')[2] salt = salt.split('|')[2]
salt = base64.decodestring(salt) salt = base64.decodestring(b(salt))
assert len(salt) == SHA.digest_size assert len(salt) == SHA.digest_size
hmac = HMAC.HMAC(salt, hostname, SHA).digest() hmac = HMAC.HMAC(salt, b(hostname), SHA).digest()
hostkey = '|1|%s|%s' % (base64.encodestring(salt), base64.encodestring(hmac)) hostkey = '|1|%s|%s' % (base64.encodestring(salt), base64.encodestring(hmac))
return hostkey.replace('\n', '') return hostkey.replace('\n', '')
hash_host = staticmethod(hash_host) hash_host = staticmethod(hash_host)

View File

@ -87,8 +87,8 @@ class Packetizer (object):
self.__sdctr_out = False self.__sdctr_out = False
self.__mac_engine_out = None self.__mac_engine_out = None
self.__mac_engine_in = None self.__mac_engine_in = None
self.__mac_key_out = '' self.__mac_key_out = bytes()
self.__mac_key_in = '' self.__mac_key_in = bytes()
self.__compress_engine_out = None self.__compress_engine_out = None
self.__compress_engine_in = None self.__compress_engine_in = None
self.__sequence_number_out = long_zero self.__sequence_number_out = long_zero

View File

@ -148,7 +148,7 @@ class PKey (object):
@return: a base64 string containing the public part of the key. @return: a base64 string containing the public part of the key.
@rtype: str @rtype: str
""" """
return base64.encodestring(self.asbytes()).replace('\n', '') return u(base64.encodestring(self.asbytes())).replace('\n', '')
def sign_ssh_data(self, rng, data): def sign_ssh_data(self, rng, data):
""" """
@ -378,7 +378,7 @@ class PKey (object):
f.write('Proc-Type: 4,ENCRYPTED\n') f.write('Proc-Type: 4,ENCRYPTED\n')
f.write('DEK-Info: %s,%s\n' % (cipher_name, hexlify(salt).upper())) f.write('DEK-Info: %s,%s\n' % (cipher_name, hexlify(salt).upper()))
f.write('\n') f.write('\n')
s = base64.encodestring(data) s = u(base64.encodestring(data))
# re-wrap to 64-char lines # re-wrap to 64-char lines
s = ''.join(s.split('\n')) s = ''.join(s.split('\n'))
s = '\n'.join([s[i : i+64] for i in range(0, len(s), 64)]) s = '\n'.join([s[i : i+64] for i in range(0, len(s), 64)])

View File

@ -152,15 +152,26 @@ class RSAKey (PKey):
### internals... ### internals...
if PY3:
def _pkcs1imify(self, data): def _pkcs1imify(self, data):
""" """
turn a 20-byte SHA1 hash into a blob of data as large as the key's N, turn a 20-byte SHA1 hash into a blob of data as large as the key's N,
using PKCS1's \"emsa-pkcs1-v1_5\" encoding. totally bizarre. using PKCS1's \"emsa-pkcs1-v1_5\" encoding. totally bizarre.
""" """
SHA1_DIGESTINFO = '\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14' SHA1_DIGESTINFO = b'\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14'
size = len(util.deflate_long(self.n, 0)) size = len(util.deflate_long(self.n, 0))
filler = '\xff' * (size - len(SHA1_DIGESTINFO) - len(data) - 3) filler = b'\xff' * (size - len(SHA1_DIGESTINFO) - len(data) - 3)
return '\x00\x01' + filler + '\x00' + SHA1_DIGESTINFO + data return b'\x00\x01' + filler + b'\x00' + SHA1_DIGESTINFO + data
else:
def _pkcs1imify(self, data):
"""
turn a 20-byte SHA1 hash into a blob of data as large as the key's N,
using PKCS1's \"emsa-pkcs1-v1_5\" encoding. totally bizarre.
"""
SHA1_DIGESTINFO = b('\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14')
size = len(util.deflate_long(self.n, 0))
filler = b('\xff') * (size - len(SHA1_DIGESTINFO) - len(data) - 3)
return b('\x00\x01') + filler + b('\x00') + SHA1_DIGESTINFO + b(data)
def _from_private_key_file(self, filename, password): def _from_private_key_file(self, filename, password):
data = self._read_private_key_file('RSA', filename, password) data = self._read_private_key_file('RSA', filename, password)

View File

@ -151,7 +151,7 @@ def generate_key_bytes(hashclass, salt, key, nbytes):
hashing function (like C{MD5} or C{SHA}). hashing function (like C{MD5} or C{SHA}).
@type hashclass: L{Crypto.Hash} @type hashclass: L{Crypto.Hash}
@param salt: data to salt the hash with. @param salt: data to salt the hash with.
@type salt: string @type salt: byte string
@param key: human-entered password or passphrase. @param key: human-entered password or passphrase.
@type key: string @type key: string
@param nbytes: number of bytes to generate. @param nbytes: number of bytes to generate.
@ -159,15 +159,15 @@ def generate_key_bytes(hashclass, salt, key, nbytes):
@return: key data @return: key data
@rtype: string @rtype: string
""" """
keydata = '' keydata = bytes()
digest = '' digest = bytes()
if len(salt) > 8: if len(salt) > 8:
salt = salt[:8] salt = salt[:8]
while nbytes > 0: while nbytes > 0:
hash_obj = hashclass.new() hash_obj = hashclass.new()
if len(digest) > 0: if len(digest) > 0:
hash_obj.update(digest) hash_obj.update(digest)
hash_obj.update(key) hash_obj.update(b(key))
hash_obj.update(salt) hash_obj.update(salt)
digest = hash_obj.digest() digest = hash_obj.digest()
size = min(nbytes, len(digest)) size = min(nbytes, len(digest))

View File

@ -25,6 +25,7 @@ import time
import unittest import unittest
from paramiko.buffered_pipe import BufferedPipe, PipeTimeout from paramiko.buffered_pipe import BufferedPipe, PipeTimeout
from paramiko import pipe from paramiko import pipe
from paramiko.py3compat import b
from tests.util import ParamikoTest from tests.util import ParamikoTest
@ -48,35 +49,35 @@ class BufferedPipeTest(ParamikoTest):
p.feed('hello.') p.feed('hello.')
self.assert_(p.read_ready()) self.assert_(p.read_ready())
data = p.read(6) data = p.read(6)
self.assertEquals('hello.', data) self.assertEquals(b('hello.'), data)
p.feed('plus/minus') p.feed('plus/minus')
self.assertEquals('plu', p.read(3)) self.assertEquals(b('plu'), p.read(3))
self.assertEquals('s/m', p.read(3)) self.assertEquals(b('s/m'), p.read(3))
self.assertEquals('inus', p.read(4)) self.assertEquals(b('inus'), p.read(4))
p.close() p.close()
self.assert_(not p.read_ready()) self.assert_(not p.read_ready())
self.assertEquals('', p.read(1)) self.assertEquals(b(''), p.read(1))
def test_2_delay(self): def test_2_delay(self):
p = BufferedPipe() p = BufferedPipe()
self.assert_(not p.read_ready()) self.assert_(not p.read_ready())
threading.Thread(target=delay_thread, args=(p,)).start() threading.Thread(target=delay_thread, args=(p,)).start()
self.assertEquals('a', p.read(1, 0.1)) self.assertEquals(b('a'), p.read(1, 0.1))
try: try:
p.read(1, 0.1) p.read(1, 0.1)
self.assert_(False) self.assert_(False)
except PipeTimeout: except PipeTimeout:
pass pass
self.assertEquals('b', p.read(1, 1.0)) self.assertEquals(b('b'), p.read(1, 1.0))
self.assertEquals('', p.read(1)) self.assertEquals(b(''), p.read(1))
def test_3_close_while_reading(self): def test_3_close_while_reading(self):
p = BufferedPipe() p = BufferedPipe()
threading.Thread(target=close_thread, args=(p,)).start() threading.Thread(target=close_thread, args=(p,)).start()
data = p.read(1, 1.0) data = p.read(1, 1.0)
self.assertEquals('', data) self.assertEquals(b(''), data)
def test_4_or_pipe(self): def test_4_or_pipe(self):
p = pipe.make_pipe() p = pipe.make_pipe()

View File

@ -25,6 +25,7 @@ from binascii import hexlify
import os import os
import unittest import unittest
import paramiko import paramiko
from paramiko.py3compat import b
test_hosts_file = """\ test_hosts_file = """\
@ -36,12 +37,12 @@ BGQ3GQ/Fc7SX6gkpXkwcZryoi4kNFhHu5LvHcZPdxXV1D+uTMfGS1eyd2Yz/DoNWXNAl8TI0cAsW\
5ymME3bQ4J/k1IKxCtz/bAlAqFgKoc+EolMziDYqWIATtW0rYTJvzGAzTmMj80/QpsFH+Pc2M= 5ymME3bQ4J/k1IKxCtz/bAlAqFgKoc+EolMziDYqWIATtW0rYTJvzGAzTmMj80/QpsFH+Pc2M=
""" """
keyblob = """\ keyblob = b("""\
AAAAB3NzaC1yc2EAAAABIwAAAIEA8bP1ZA7DCZDB9J0s50l31MBGQ3GQ/Fc7SX6gkpXkwcZryoi4k\ AAAAB3NzaC1yc2EAAAABIwAAAIEA8bP1ZA7DCZDB9J0s50l31MBGQ3GQ/Fc7SX6gkpXkwcZryoi4k\
NFhHu5LvHcZPdxXV1D+uTMfGS1eyd2Yz/DoNWXNAl8TI0cAsW5ymME3bQ4J/k1IKxCtz/bAlAqFgK\ NFhHu5LvHcZPdxXV1D+uTMfGS1eyd2Yz/DoNWXNAl8TI0cAsW5ymME3bQ4J/k1IKxCtz/bAlAqFgK\
oc+EolMziDYqWIATtW0rYTJvzGAzTmMj80/QpsFH+Pc2M=""" oc+EolMziDYqWIATtW0rYTJvzGAzTmMj80/QpsFH+Pc2M=""")
keyblob_dss = """\ keyblob_dss = b("""\
AAAAB3NzaC1kc3MAAACBAOeBpgNnfRzr/twmAQRu2XwWAp3CFtrVnug6s6fgwj/oLjYbVtjAy6pl/\ AAAAB3NzaC1kc3MAAACBAOeBpgNnfRzr/twmAQRu2XwWAp3CFtrVnug6s6fgwj/oLjYbVtjAy6pl/\
h0EKCWx2rf1IetyNsTxWrniA9I6HeDj65X1FyDkg6g8tvCnaNB8Xp/UUhuzHuGsMIipRxBxw9LF60\ h0EKCWx2rf1IetyNsTxWrniA9I6HeDj65X1FyDkg6g8tvCnaNB8Xp/UUhuzHuGsMIipRxBxw9LF60\
8EqZcj1E3ytktoW5B5OcjrkEoz3xG7C+rpIjYvAAAAFQDwz4UnmsGiSNu5iqjn3uTzwUpshwAAAIE\ 8EqZcj1E3ytktoW5B5OcjrkEoz3xG7C+rpIjYvAAAAFQDwz4UnmsGiSNu5iqjn3uTzwUpshwAAAIE\
@ -49,7 +50,7 @@ AkxfFeY8P2wZpDjX0MimZl5wkoFQDL25cPzGBuB4OnB8NoUk/yjAHIIpEShw8V+LzouMK5CTJQo5+\
Ngw3qIch/WgRmMHy4kBq1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg\ Ngw3qIch/WgRmMHy4kBq1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg\
4Ok10+XFDxlqZo8Y+wAAACARmR7CCPjodxASvRbIyzaVpZoJ/Z6x7dAumV+ysrV1BVYd0lYukmnjO\ 4Ok10+XFDxlqZo8Y+wAAACARmR7CCPjodxASvRbIyzaVpZoJ/Z6x7dAumV+ysrV1BVYd0lYukmnjO\
1kKBWApqpH1ve9XDQYN8zgxM4b16L21kpoWQnZtXrY3GZ4/it9kUgyB7+NwacIBlXa8cMDL7Q/69o\ 1kKBWApqpH1ve9XDQYN8zgxM4b16L21kpoWQnZtXrY3GZ4/it9kUgyB7+NwacIBlXa8cMDL7Q/69o\
0d54U0X/NeX5QxuYR6OMJlrkQB7oiW/P/1mwjQgE=""" 0d54U0X/NeX5QxuYR6OMJlrkQB7oiW/P/1mwjQgE=""")
class HostKeysTest (unittest.TestCase): class HostKeysTest (unittest.TestCase):
@ -68,7 +69,7 @@ class HostKeysTest (unittest.TestCase):
self.assertEquals(1, len(list(hostdict.values())[0])) self.assertEquals(1, len(list(hostdict.values())[0]))
self.assertEquals(1, len(list(hostdict.values())[1])) self.assertEquals(1, len(list(hostdict.values())[1]))
fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper()
self.assertEquals('E6684DB30E109B67B70FF1DC5C7F1363', fp) self.assertEquals(b('E6684DB30E109B67B70FF1DC5C7F1363'), fp)
def test_2_add(self): def test_2_add(self):
hostdict = paramiko.HostKeys('hostfile.temp') hostdict = paramiko.HostKeys('hostfile.temp')
@ -78,7 +79,7 @@ class HostKeysTest (unittest.TestCase):
self.assertEquals(3, len(list(hostdict))) self.assertEquals(3, len(list(hostdict)))
x = hostdict['foo.example.com'] x = hostdict['foo.example.com']
fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper() fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper()
self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp) self.assertEquals(b('7EC91BB336CB6D810B124B1353C32396'), fp)
self.assert_(hostdict.check('foo.example.com', key)) self.assert_(hostdict.check('foo.example.com', key))
def test_3_dict(self): def test_3_dict(self):
@ -90,7 +91,7 @@ class HostKeysTest (unittest.TestCase):
x = hostdict.get('secure.example.com', None) x = hostdict.get('secure.example.com', None)
self.assert_(x is not None) self.assert_(x is not None)
fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper() fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper()
self.assertEquals('E6684DB30E109B67B70FF1DC5C7F1363', fp) self.assertEquals(b('E6684DB30E109B67B70FF1DC5C7F1363'), fp)
i = 0 i = 0
for key in hostdict: for key in hostdict:
i += 1 i += 1
@ -112,6 +113,6 @@ class HostKeysTest (unittest.TestCase):
self.assertEquals(1, len(list(hostdict.values())[1])) self.assertEquals(1, len(list(hostdict.values())[1]))
self.assertEquals(1, len(list(hostdict.values())[2])) self.assertEquals(1, len(list(hostdict.values())[2]))
fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper()
self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp) self.assertEquals(b('7EC91BB336CB6D810B124B1353C32396'), fp)
fp = hexlify(hostdict['secure.example.com']['ssh-dss'].get_fingerprint()).upper() fp = hexlify(hostdict['secure.example.com']['ssh-dss'].get_fingerprint()).upper()
self.assertEquals('4478F0B9A23CC5182009FF755BC1D26C', fp) self.assertEquals(b('4478F0B9A23CC5182009FF755BC1D26C'), fp)

View File

@ -40,7 +40,7 @@ class FakeKey (object):
def asbytes(self): def asbytes(self):
return b('fake-key') return b('fake-key')
def sign_ssh_data(self, rng, H): def sign_ssh_data(self, rng, H):
return 'fake-sig' return b('fake-sig')
class FakeModulusPack (object): class FakeModulusPack (object):
@ -91,7 +91,7 @@ class KexTest (unittest.TestCase):
transport.server_mode = False transport.server_mode = False
kex = KexGroup1(transport) kex = KexGroup1(transport)
kex.start_kex() kex.start_kex()
x = '1E000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' x = b('1E000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4')
self.assertEquals(x, hexlify(transport._message.asbytes()).upper()) self.assertEquals(x, hexlify(transport._message.asbytes()).upper())
self.assertEquals((paramiko.kex_group1._MSG_KEXDH_REPLY,), transport._expect) self.assertEquals((paramiko.kex_group1._MSG_KEXDH_REPLY,), transport._expect)
@ -102,10 +102,10 @@ class KexTest (unittest.TestCase):
msg.add_string('fake-sig') msg.add_string('fake-sig')
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_group1._MSG_KEXDH_REPLY, msg) kex.parse_next(paramiko.kex_group1._MSG_KEXDH_REPLY, msg)
H = '03079780F3D3AD0B3C6DB30C8D21685F367A86D2' H = b('03079780F3D3AD0B3C6DB30C8D21685F367A86D2')
self.assertEquals(self.K, transport._K) self.assertEquals(self.K, transport._K)
self.assertEquals(H, hexlify(transport._H).upper()) self.assertEquals(H, hexlify(transport._H).upper())
self.assertEquals(('fake-host-key', 'fake-sig'), transport._verify) self.assertEquals((b('fake-host-key'), b('fake-sig')), transport._verify)
self.assert_(transport._activated) self.assert_(transport._activated)
def test_2_group1_server(self): def test_2_group1_server(self):
@ -119,8 +119,8 @@ class KexTest (unittest.TestCase):
msg.add_mpint(69) msg.add_mpint(69)
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_group1._MSG_KEXDH_INIT, msg) kex.parse_next(paramiko.kex_group1._MSG_KEXDH_INIT, msg)
H = 'B16BF34DD10945EDE84E9C1EF24A14BFDC843389' H = b('B16BF34DD10945EDE84E9C1EF24A14BFDC843389')
x = '1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' x = b('1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967')
self.assertEquals(self.K, transport._K) self.assertEquals(self.K, transport._K)
self.assertEquals(H, hexlify(transport._H).upper()) self.assertEquals(H, hexlify(transport._H).upper())
self.assertEquals(x, hexlify(transport._message.asbytes()).upper()) self.assertEquals(x, hexlify(transport._message.asbytes()).upper())
@ -131,7 +131,7 @@ class KexTest (unittest.TestCase):
transport.server_mode = False transport.server_mode = False
kex = KexGex(transport) kex = KexGex(transport)
kex.start_kex() kex.start_kex()
x = '22000004000000080000002000' x = b('22000004000000080000002000')
self.assertEquals(x, hexlify(transport._message.asbytes()).upper()) self.assertEquals(x, hexlify(transport._message.asbytes()).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect) self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect)
@ -140,7 +140,7 @@ class KexTest (unittest.TestCase):
msg.add_mpint(FakeModulusPack.G) msg.add_mpint(FakeModulusPack.G)
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg)
x = '20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' x = b('20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4')
self.assertEquals(x, hexlify(transport._message.asbytes()).upper()) self.assertEquals(x, hexlify(transport._message.asbytes()).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect) self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect)
@ -150,10 +150,10 @@ class KexTest (unittest.TestCase):
msg.add_string('fake-sig') msg.add_string('fake-sig')
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg)
H = 'A265563F2FA87F1A89BF007EE90D58BE2E4A4BD0' H = b('A265563F2FA87F1A89BF007EE90D58BE2E4A4BD0')
self.assertEquals(self.K, transport._K) self.assertEquals(self.K, transport._K)
self.assertEquals(H, hexlify(transport._H).upper()) self.assertEquals(H, hexlify(transport._H).upper())
self.assertEquals(('fake-host-key', 'fake-sig'), transport._verify) self.assertEquals((b('fake-host-key'), b('fake-sig')), transport._verify)
self.assert_(transport._activated) self.assert_(transport._activated)
def test_4_gex_old_client(self): def test_4_gex_old_client(self):
@ -161,7 +161,7 @@ class KexTest (unittest.TestCase):
transport.server_mode = False transport.server_mode = False
kex = KexGex(transport) kex = KexGex(transport)
kex.start_kex(_test_old_style=True) kex.start_kex(_test_old_style=True)
x = '1E00000800' x = b('1E00000800')
self.assertEquals(x, hexlify(transport._message.asbytes()).upper()) self.assertEquals(x, hexlify(transport._message.asbytes()).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect) self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect)
@ -170,7 +170,7 @@ class KexTest (unittest.TestCase):
msg.add_mpint(FakeModulusPack.G) msg.add_mpint(FakeModulusPack.G)
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg)
x = '20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' x = b('20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4')
self.assertEquals(x, hexlify(transport._message.asbytes()).upper()) self.assertEquals(x, hexlify(transport._message.asbytes()).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect) self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect)
@ -180,10 +180,10 @@ class KexTest (unittest.TestCase):
msg.add_string('fake-sig') msg.add_string('fake-sig')
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REPLY, msg)
H = '807F87B269EF7AC5EC7E75676808776A27D5864C' H = b('807F87B269EF7AC5EC7E75676808776A27D5864C')
self.assertEquals(self.K, transport._K) self.assertEquals(self.K, transport._K)
self.assertEquals(H, hexlify(transport._H).upper()) self.assertEquals(H, hexlify(transport._H).upper())
self.assertEquals(('fake-host-key', 'fake-sig'), transport._verify) self.assertEquals((b('fake-host-key'), b('fake-sig')), transport._verify)
self.assert_(transport._activated) self.assert_(transport._activated)
def test_5_gex_server(self): def test_5_gex_server(self):
@ -199,7 +199,7 @@ class KexTest (unittest.TestCase):
msg.add_int(4096) msg.add_int(4096)
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, msg)
x = '1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102' x = b('1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102')
self.assertEquals(x, hexlify(transport._message.asbytes()).upper()) self.assertEquals(x, hexlify(transport._message.asbytes()).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect) self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect)
@ -208,8 +208,8 @@ class KexTest (unittest.TestCase):
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg)
K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581 K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581
H = 'CE754197C21BF3452863B4F44D0B3951F12516EF' H = b('CE754197C21BF3452863B4F44D0B3951F12516EF')
x = '210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' x = b('210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967')
self.assertEquals(K, transport._K) self.assertEquals(K, transport._K)
self.assertEquals(H, hexlify(transport._H).upper()) self.assertEquals(H, hexlify(transport._H).upper())
self.assertEquals(x, hexlify(transport._message.asbytes()).upper()) self.assertEquals(x, hexlify(transport._message.asbytes()).upper())
@ -226,7 +226,7 @@ class KexTest (unittest.TestCase):
msg.add_int(2048) msg.add_int(2048)
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, msg)
x = '1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102' x = b('1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102')
self.assertEquals(x, hexlify(transport._message.asbytes()).upper()) self.assertEquals(x, hexlify(transport._message.asbytes()).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect) self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect)
@ -235,8 +235,8 @@ class KexTest (unittest.TestCase):
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg)
K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581 K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581
H = 'B41A06B2E59043CEFC1AE16EC31F1E2D12EC455B' H = b('B41A06B2E59043CEFC1AE16EC31F1E2D12EC455B')
x = '210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' x = b('210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967')
self.assertEquals(K, transport._K) self.assertEquals(K, transport._K)
self.assertEquals(H, hexlify(transport._H).upper()) self.assertEquals(H, hexlify(transport._H).upper())
self.assertEquals(x, hexlify(transport._message.asbytes()).upper()) self.assertEquals(x, hexlify(transport._message.asbytes()).upper())

View File

@ -25,7 +25,15 @@ from tests.loop import LoopSocket
from Crypto.Cipher import AES from Crypto.Cipher import AES
from Crypto.Hash import SHA, HMAC from Crypto.Hash import SHA, HMAC
from paramiko import Message, Packetizer, util from paramiko import Message, Packetizer, util
from paramiko.py3compat import byte_chr from paramiko.common import *
if PY3:
x55 = b'\x55'
x1f = b'\x1f'
else:
x55 = '\x55'
x1f = '\x1f'
class PacketizerTest (unittest.TestCase): class PacketizerTest (unittest.TestCase):
@ -36,8 +44,8 @@ class PacketizerTest (unittest.TestCase):
p = Packetizer(wsock) p = Packetizer(wsock)
p.set_log(util.get_logger('paramiko.transport')) p.set_log(util.get_logger('paramiko.transport'))
p.set_hexdump(True) p.set_hexdump(True)
cipher = AES.new('\x00' * 16, AES.MODE_CBC, '\x55' * 16) cipher = AES.new(zero_byte * 16, AES.MODE_CBC, x55 * 16)
p.set_outbound_cipher(cipher, 16, SHA, 12, '\x1f' * 20) p.set_outbound_cipher(cipher, 16, SHA, 12, x1f * 20)
# message has to be at least 16 bytes long, so we'll have at least one # message has to be at least 16 bytes long, so we'll have at least one
# block of data encrypted that contains zero random padding bytes # block of data encrypted that contains zero random padding bytes
@ -50,6 +58,9 @@ class PacketizerTest (unittest.TestCase):
data = rsock.recv(100) data = rsock.recv(100)
# 32 + 12 bytes of MAC = 44 # 32 + 12 bytes of MAC = 44
self.assertEquals(44, len(data)) self.assertEquals(44, len(data))
if PY3:
self.assertEquals(b'\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0', data[:16])
else:
self.assertEquals('\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0', data[:16]) self.assertEquals('\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0', data[:16])
def test_2_read (self): def test_2_read (self):
@ -59,9 +70,12 @@ class PacketizerTest (unittest.TestCase):
p = Packetizer(rsock) p = Packetizer(rsock)
p.set_log(util.get_logger('paramiko.transport')) p.set_log(util.get_logger('paramiko.transport'))
p.set_hexdump(True) p.set_hexdump(True)
cipher = AES.new('\x00' * 16, AES.MODE_CBC, '\x55' * 16) cipher = AES.new(zero_byte * 16, AES.MODE_CBC, x55 * 16)
p.set_inbound_cipher(cipher, 16, SHA, 12, '\x1f' * 20) p.set_inbound_cipher(cipher, 16, SHA, 12, x1f * 20)
if PY3:
wsock.send(b'C\x91\x97\xbd[P\xac%\x87\xc2\xc4k\xc7\xe98\xc0' + \
b'\x90\xd2\x16V\rqsa8|L=\xfb\x97}\xe2n\x03\xb1\xa0\xc2\x1c\xd6AAL\xb4Y')
else:
wsock.send('C\x91\x97\xbd[P\xac%\x87\xc2\xc4k\xc7\xe98\xc0' + \ wsock.send('C\x91\x97\xbd[P\xac%\x87\xc2\xc4k\xc7\xe98\xc0' + \
'\x90\xd2\x16V\rqsa8|L=\xfb\x97}\xe2n\x03\xb1\xa0\xc2\x1c\xd6AAL\xb4Y') '\x90\xd2\x16V\rqsa8|L=\xfb\x97}\xe2n\x03\xb1\xa0\xc2\x1c\xd6AAL\xb4Y')
cmd, m = p.read_message() cmd, m = p.read_message()

View File

@ -23,7 +23,8 @@ Some unit tests for public/private key objects.
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
import unittest import unittest
from paramiko import RSAKey, DSSKey, ECDSAKey, Message, util from paramiko import RSAKey, DSSKey, ECDSAKey, Message, util
from paramiko.common import rng, StringIO, byte_chr from paramiko.common import rng, StringIO, byte_chr, b, PY3
from tests.util import test_path
# from openssh's ssh-keygen # from openssh's ssh-keygen
PUB_RSA = 'ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEA049W6geFpmsljTwfvI1UmKWWJPNFI74+vNKTk4dmzkQY2yAMs6FhlvhlI8ysU4oj71ZsRYMecHbBbxdN79+JRFVYTKaLqjwGENeTd+yv4q+V2PvZv3fLnzApI3l7EJCqhWwJUHJ1jAkZzqDx0tyOL4uoZpww3nmE0kb3y21tH4c=' PUB_RSA = 'ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEA049W6geFpmsljTwfvI1UmKWWJPNFI74+vNKTk4dmzkQY2yAMs6FhlvhlI8ysU4oj71ZsRYMecHbBbxdN79+JRFVYTKaLqjwGENeTd+yv4q+V2PvZv3fLnzApI3l7EJCqhWwJUHJ1jAkZzqDx0tyOL4uoZpww3nmE0kb3y21tH4c='
@ -76,6 +77,12 @@ ADRvOqQ5R98Sxst765CAqXmRtz8vwoD96g==
-----END EC PRIVATE KEY----- -----END EC PRIVATE KEY-----
""" """
if PY3:
x1234 = b'\x01\x02\x03\x04'
else:
x1234 = '\x01\x02\x03\x04'
class KeyTest (unittest.TestCase): class KeyTest (unittest.TestCase):
def setUp(self): def setUp(self):
@ -86,14 +93,14 @@ class KeyTest (unittest.TestCase):
def test_1_generate_key_bytes(self): def test_1_generate_key_bytes(self):
from Crypto.Hash import MD5 from Crypto.Hash import MD5
key = util.generate_key_bytes(MD5, '\x01\x02\x03\x04', 'happy birthday', 30) key = util.generate_key_bytes(MD5, x1234, 'happy birthday', 30)
exp = unhexlify('61E1F272F4C1C4561586BD322498C0E924672780F47BB37DDA7D54019E64') exp = unhexlify('61E1F272F4C1C4561586BD322498C0E924672780F47BB37DDA7D54019E64')
self.assertEquals(exp, key) self.assertEquals(exp, key)
def test_2_load_rsa(self): def test_2_load_rsa(self):
key = RSAKey.from_private_key_file('tests/test_rsa.key') key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
self.assertEquals('ssh-rsa', key.get_name()) self.assertEquals('ssh-rsa', key.get_name())
exp_rsa = FINGER_RSA.split()[1].replace(':', '') exp_rsa = b(FINGER_RSA.split()[1].replace(':', ''))
my_rsa = hexlify(key.get_fingerprint()) my_rsa = hexlify(key.get_fingerprint())
self.assertEquals(exp_rsa, my_rsa) self.assertEquals(exp_rsa, my_rsa)
self.assertEquals(PUB_RSA.split()[1], key.get_base64()) self.assertEquals(PUB_RSA.split()[1], key.get_base64())
@ -107,18 +114,18 @@ class KeyTest (unittest.TestCase):
self.assertEquals(key, key2) self.assertEquals(key, key2)
def test_3_load_rsa_password(self): def test_3_load_rsa_password(self):
key = RSAKey.from_private_key_file('tests/test_rsa_password.key', 'television') key = RSAKey.from_private_key_file(test_path('test_rsa_password.key'), 'television')
self.assertEquals('ssh-rsa', key.get_name()) self.assertEquals('ssh-rsa', key.get_name())
exp_rsa = FINGER_RSA.split()[1].replace(':', '') exp_rsa = b(FINGER_RSA.split()[1].replace(':', ''))
my_rsa = hexlify(key.get_fingerprint()) my_rsa = hexlify(key.get_fingerprint())
self.assertEquals(exp_rsa, my_rsa) self.assertEquals(exp_rsa, my_rsa)
self.assertEquals(PUB_RSA.split()[1], key.get_base64()) self.assertEquals(PUB_RSA.split()[1], key.get_base64())
self.assertEquals(1024, key.get_bits()) self.assertEquals(1024, key.get_bits())
def test_4_load_dss(self): def test_4_load_dss(self):
key = DSSKey.from_private_key_file('tests/test_dss.key') key = DSSKey.from_private_key_file(test_path('test_dss.key'))
self.assertEquals('ssh-dss', key.get_name()) self.assertEquals('ssh-dss', key.get_name())
exp_dss = FINGER_DSS.split()[1].replace(':', '') exp_dss = b(FINGER_DSS.split()[1].replace(':', ''))
my_dss = hexlify(key.get_fingerprint()) my_dss = hexlify(key.get_fingerprint())
self.assertEquals(exp_dss, my_dss) self.assertEquals(exp_dss, my_dss)
self.assertEquals(PUB_DSS.split()[1], key.get_base64()) self.assertEquals(PUB_DSS.split()[1], key.get_base64())
@ -132,9 +139,9 @@ class KeyTest (unittest.TestCase):
self.assertEquals(key, key2) self.assertEquals(key, key2)
def test_5_load_dss_password(self): def test_5_load_dss_password(self):
key = DSSKey.from_private_key_file('tests/test_dss_password.key', 'television') key = DSSKey.from_private_key_file(test_path('test_dss_password.key'), 'television')
self.assertEquals('ssh-dss', key.get_name()) self.assertEquals('ssh-dss', key.get_name())
exp_dss = FINGER_DSS.split()[1].replace(':', '') exp_dss = b(FINGER_DSS.split()[1].replace(':', ''))
my_dss = hexlify(key.get_fingerprint()) my_dss = hexlify(key.get_fingerprint())
self.assertEquals(exp_dss, my_dss) self.assertEquals(exp_dss, my_dss)
self.assertEquals(PUB_DSS.split()[1], key.get_base64()) self.assertEquals(PUB_DSS.split()[1], key.get_base64())
@ -142,7 +149,7 @@ class KeyTest (unittest.TestCase):
def test_6_compare_rsa(self): def test_6_compare_rsa(self):
# verify that the private & public keys compare equal # verify that the private & public keys compare equal
key = RSAKey.from_private_key_file('tests/test_rsa.key') key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
self.assertEquals(key, key) self.assertEquals(key, key)
pub = RSAKey(data=key.asbytes()) pub = RSAKey(data=key.asbytes())
self.assert_(key.can_sign()) self.assert_(key.can_sign())
@ -151,7 +158,7 @@ class KeyTest (unittest.TestCase):
def test_7_compare_dss(self): def test_7_compare_dss(self):
# verify that the private & public keys compare equal # verify that the private & public keys compare equal
key = DSSKey.from_private_key_file('tests/test_dss.key') key = DSSKey.from_private_key_file(test_path('test_dss.key'))
self.assertEquals(key, key) self.assertEquals(key, key)
pub = DSSKey(data=key.asbytes()) pub = DSSKey(data=key.asbytes())
self.assert_(key.can_sign()) self.assert_(key.can_sign())
@ -160,8 +167,8 @@ class KeyTest (unittest.TestCase):
def test_8_sign_rsa(self): def test_8_sign_rsa(self):
# verify that the rsa private key can sign and verify # verify that the rsa private key can sign and verify
key = RSAKey.from_private_key_file('tests/test_rsa.key') key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
msg = key.sign_ssh_data(rng, 'ice weasels') msg = key.sign_ssh_data(rng, b('ice weasels'))
self.assert_(type(msg) is Message) self.assert_(type(msg) is Message)
msg.rewind() msg.rewind()
self.assertEquals('ssh-rsa', msg.get_text()) self.assertEquals('ssh-rsa', msg.get_text())
@ -169,12 +176,12 @@ class KeyTest (unittest.TestCase):
self.assertEquals(sig, msg.get_binary()) self.assertEquals(sig, msg.get_binary())
msg.rewind() msg.rewind()
pub = RSAKey(data=key.asbytes()) pub = RSAKey(data=key.asbytes())
self.assert_(pub.verify_ssh_sig('ice weasels', msg)) self.assert_(pub.verify_ssh_sig(b('ice weasels'), msg))
def test_9_sign_dss(self): def test_9_sign_dss(self):
# verify that the dss private key can sign and verify # verify that the dss private key can sign and verify
key = DSSKey.from_private_key_file('tests/test_dss.key') key = DSSKey.from_private_key_file(test_path('test_dss.key'))
msg = key.sign_ssh_data(rng, 'ice weasels') msg = key.sign_ssh_data(rng, b('ice weasels'))
self.assert_(type(msg) is Message) self.assert_(type(msg) is Message)
msg.rewind() msg.rewind()
self.assertEquals('ssh-dss', msg.get_text()) self.assertEquals('ssh-dss', msg.get_text())
@ -184,24 +191,24 @@ class KeyTest (unittest.TestCase):
self.assertEquals(40, len(msg.get_binary())) self.assertEquals(40, len(msg.get_binary()))
msg.rewind() msg.rewind()
pub = DSSKey(data=key.asbytes()) pub = DSSKey(data=key.asbytes())
self.assert_(pub.verify_ssh_sig('ice weasels', msg)) self.assert_(pub.verify_ssh_sig(b('ice weasels'), msg))
def test_A_generate_rsa(self): def test_A_generate_rsa(self):
key = RSAKey.generate(1024) key = RSAKey.generate(1024)
msg = key.sign_ssh_data(rng, 'jerri blank') msg = key.sign_ssh_data(rng, b('jerri blank'))
msg.rewind() msg.rewind()
self.assert_(key.verify_ssh_sig('jerri blank', msg)) self.assert_(key.verify_ssh_sig(b('jerri blank'), msg))
def test_B_generate_dss(self): def test_B_generate_dss(self):
key = DSSKey.generate(1024) key = DSSKey.generate(1024)
msg = key.sign_ssh_data(rng, 'jerri blank') msg = key.sign_ssh_data(rng, b('jerri blank'))
msg.rewind() msg.rewind()
self.assert_(key.verify_ssh_sig('jerri blank', msg)) self.assert_(key.verify_ssh_sig(b('jerri blank'), msg))
def test_10_load_ecdsa(self): def test_10_load_ecdsa(self):
key = ECDSAKey.from_private_key_file('tests/test_ecdsa.key') key = ECDSAKey.from_private_key_file(test_path('test_ecdsa.key'))
self.assertEquals('ecdsa-sha2-nistp256', key.get_name()) self.assertEquals('ecdsa-sha2-nistp256', key.get_name())
exp_ecdsa = FINGER_ECDSA.split()[1].replace(':', '') exp_ecdsa = b(FINGER_ECDSA.split()[1].replace(':', ''))
my_ecdsa = hexlify(key.get_fingerprint()) my_ecdsa = hexlify(key.get_fingerprint())
self.assertEquals(exp_ecdsa, my_ecdsa) self.assertEquals(exp_ecdsa, my_ecdsa)
self.assertEquals(PUB_ECDSA.split()[1], key.get_base64()) self.assertEquals(PUB_ECDSA.split()[1], key.get_base64())
@ -215,9 +222,9 @@ class KeyTest (unittest.TestCase):
self.assertEquals(key, key2) self.assertEquals(key, key2)
def test_11_load_ecdsa_password(self): def test_11_load_ecdsa_password(self):
key = ECDSAKey.from_private_key_file('tests/test_ecdsa_password.key', 'television') key = ECDSAKey.from_private_key_file(test_path('test_ecdsa_password.key'), b('television'))
self.assertEquals('ecdsa-sha2-nistp256', key.get_name()) self.assertEquals('ecdsa-sha2-nistp256', key.get_name())
exp_ecdsa = FINGER_ECDSA.split()[1].replace(':', '') exp_ecdsa = b(FINGER_ECDSA.split()[1].replace(':', ''))
my_ecdsa = hexlify(key.get_fingerprint()) my_ecdsa = hexlify(key.get_fingerprint())
self.assertEquals(exp_ecdsa, my_ecdsa) self.assertEquals(exp_ecdsa, my_ecdsa)
self.assertEquals(PUB_ECDSA.split()[1], key.get_base64()) self.assertEquals(PUB_ECDSA.split()[1], key.get_base64())
@ -225,7 +232,7 @@ class KeyTest (unittest.TestCase):
def test_12_compare_ecdsa(self): def test_12_compare_ecdsa(self):
# verify that the private & public keys compare equal # verify that the private & public keys compare equal
key = ECDSAKey.from_private_key_file('tests/test_ecdsa.key') key = ECDSAKey.from_private_key_file(test_path('test_ecdsa.key'))
self.assertEquals(key, key) self.assertEquals(key, key)
pub = ECDSAKey(data=key.asbytes()) pub = ECDSAKey(data=key.asbytes())
self.assert_(key.can_sign()) self.assert_(key.can_sign())
@ -234,8 +241,8 @@ class KeyTest (unittest.TestCase):
def test_13_sign_ecdsa(self): def test_13_sign_ecdsa(self):
# verify that the rsa private key can sign and verify # verify that the rsa private key can sign and verify
key = ECDSAKey.from_private_key_file('tests/test_ecdsa.key') key = ECDSAKey.from_private_key_file(test_path('test_ecdsa.key'))
msg = key.sign_ssh_data(rng, 'ice weasels') msg = key.sign_ssh_data(rng, b('ice weasels'))
self.assert_(type(msg) is Message) self.assert_(type(msg) is Message)
msg.rewind() msg.rewind()
self.assertEquals('ecdsa-sha2-nistp256', msg.get_text()) self.assertEquals('ecdsa-sha2-nistp256', msg.get_text())
@ -246,4 +253,4 @@ class KeyTest (unittest.TestCase):
msg.rewind() msg.rewind()
pub = ECDSAKey(data=key.asbytes()) pub = ECDSAKey(data=key.asbytes())
self.assert_(pub.verify_ssh_sig('ice weasels', msg)) self.assert_(pub.verify_ssh_sig(b('ice weasels'), msg))

View File

@ -27,7 +27,7 @@ import unittest
from Crypto.Hash import SHA from Crypto.Hash import SHA
import paramiko.util import paramiko.util
from paramiko.util import lookup_ssh_host_config as host_config from paramiko.util import lookup_ssh_host_config as host_config
from paramiko.py3compat import StringIO, byte_ord from paramiko.py3compat import StringIO, byte_ord, b
from tests.util import ParamikoTest from tests.util import ParamikoTest
@ -137,7 +137,7 @@ class UtilTest(ParamikoTest):
) )
def test_4_generate_key_bytes(self): def test_4_generate_key_bytes(self):
x = paramiko.util.generate_key_bytes(SHA, 'ABCDEFGH', 'This is my secret passphrase.', 64) x = paramiko.util.generate_key_bytes(SHA, b('ABCDEFGH'), 'This is my secret passphrase.', 64)
hex = ''.join(['%02x' % byte_ord(c) for c in x]) hex = ''.join(['%02x' % byte_ord(c) for c in x])
self.assertEquals(hex, '9110e2f6793b69363e58173e9436b13a5a4b339005741d5c680e505f57d871347b4239f14fb5c46e857d5e100424873ba849ac699cea98d729e57b3e84378e8b') self.assertEquals(hex, '9110e2f6793b69363e58173e9436b13a5a4b339005741d5c680e505f57d871347b4239f14fb5c46e857d5e100424873ba849ac699cea98d729e57b3e84378e8b')
@ -151,7 +151,7 @@ class UtilTest(ParamikoTest):
self.assertEquals(1, len(list(hostdict.values())[0])) self.assertEquals(1, len(list(hostdict.values())[0]))
self.assertEquals(1, len(list(hostdict.values())[1])) self.assertEquals(1, len(list(hostdict.values())[1]))
fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper()
self.assertEquals('E6684DB30E109B67B70FF1DC5C7F1363', fp) self.assertEquals(b('E6684DB30E109B67B70FF1DC5C7F1363'), fp)
finally: finally:
os.unlink('hostfile.temp') os.unlink('hostfile.temp')