diff --git a/paramiko/dsskey.py b/paramiko/dsskey.py index a004597..2677e4a 100644 --- a/paramiko/dsskey.py +++ b/paramiko/dsskey.py @@ -36,14 +36,16 @@ class DSSKey (PKey): Representation of a DSS key which can be used to sign an verify SSH2 data. """ - - p = None - q = None - g = None - y = None - x = None - def __init__(self, msg=None, data=None, filename=None, password=None, vals=None): + def __init__(self, msg=None, data=None, filename=None, password=None, vals=None, file_obj=None): + self.p = None + self.q = None + self.g = None + self.y = None + self.x = None + if file_obj is not None: + self._from_private_key(file_obj, password) + return if filename is not None: self._from_private_key_file(filename, password) return @@ -171,9 +173,16 @@ class DSSKey (PKey): def _from_private_key_file(self, filename, password): + data = self._read_private_key_file('DSA', filename, password) + self._decode_key(data) + + def _from_private_key(self, file_obj, password): + data = self._read_private_key('DSA', file_obj, password) + self._decode_key(data) + + def _decode_key(self, data): # private key file contains: # DSAPrivateKey = { version = 0, p, q, g, y, x } - data = self._read_private_key_file('DSA', filename, password) try: keylist = BER(data).decode() except BERException, x: diff --git a/paramiko/pkey.py b/paramiko/pkey.py index 053869e..e1aef88 100644 --- a/paramiko/pkey.py +++ b/paramiko/pkey.py @@ -197,6 +197,30 @@ class PKey (object): return key from_private_key_file = classmethod(from_private_key_file) + def from_private_key(cls, file_obj, password=None): + """ + Create a key object by reading a private key from a file (or file-like) + object. If the private key is encrypted and C{password} is not C{None}, + the given password will be used to decrypt the key (otherwise + L{PasswordRequiredException} is thrown). + + @param file_obj: the file to read from + @type file_obj: file + @param password: an optional password to use to decrypt the key, if it's + encrypted + @type password: str + @return: a new key object based on the given private key + @rtype: L{PKey} + + @raise IOError: if there was an error reading the key + @raise PasswordRequiredException: if the private key file is encrypted, + and C{password} is C{None} + @raise SSHException: if the key file is invalid + """ + key = cls(file_obj=file_obj, password=password) + return key + from_private_key = classmethod(from_private_key) + def write_private_key_file(self, filename, password=None): """ Write private key contents into a file. If the password is not @@ -251,8 +275,12 @@ class PKey (object): @raise SSHException: if the key file is invalid. """ f = open(filename, 'r') - lines = f.readlines() + data = self._read_private_key(tag, f, password) f.close() + return data + + def _read_private_key(self, tag, f, password=None): + lines = f.readlines() start = 0 while (start < len(lines)) and (lines[start].strip() != '-----BEGIN ' + tag + ' PRIVATE KEY-----'): start += 1 diff --git a/paramiko/rsakey.py b/paramiko/rsakey.py index 414815c..8bd925d 100644 --- a/paramiko/rsakey.py +++ b/paramiko/rsakey.py @@ -38,13 +38,15 @@ class RSAKey (PKey): data. """ - n = None - e = None - d = None - p = None - q = None - - def __init__(self, msg=None, data=None, filename=None, password=None, vals=None): + def __init__(self, msg=None, data=None, filename=None, password=None, vals=None, file_obj=None): + self.n = None + self.e = None + self.d = None + self.p = None + self.q = None + if file_obj is not None: + self._from_private_key(file_obj, password) + return if filename is not None: self._from_private_key_file(filename, password) return @@ -159,9 +161,16 @@ class RSAKey (PKey): return '\x00\x01' + filler + '\x00' + SHA1_DIGESTINFO + data def _from_private_key_file(self, filename, password): + data = self._read_private_key_file('RSA', filename, password) + self._decode_key(data) + + def _from_private_key(self, file_obj, password): + data = self._read_private_key('RSA', file_obj, password) + self._decode_key(data) + + def _decode_key(self, data): # private key file contains: # RSAPrivateKey = { version = 0, n, e, d, p, q, d mod p-1, d mod q-1, q**-1 mod p } - data = self._read_private_key_file('RSA', filename, password) try: keylist = BER(data).decode() except BERException: diff --git a/tests/test_pkey.py b/tests/test_pkey.py index d4b826a..d5419cd 100644 --- a/tests/test_pkey.py +++ b/tests/test_pkey.py @@ -91,6 +91,9 @@ class KeyTest (unittest.TestCase): s = StringIO.StringIO() key.write_private_key(s) self.assertEquals(RSA_PRIVATE_OUT, s.getvalue()) + s.seek(0) + key2 = RSAKey.from_private_key(s) + self.assertEquals(key, key2) def test_3_load_rsa_password(self): key = RSAKey.from_private_key_file('tests/test_rsa_password.key', 'television') @@ -113,6 +116,9 @@ class KeyTest (unittest.TestCase): s = StringIO.StringIO() key.write_private_key(s) self.assertEquals(DSS_PRIVATE_OUT, s.getvalue()) + s.seek(0) + key2 = DSSKey.from_private_key(s) + self.assertEquals(key, key2) def test_5_load_dss_password(self): key = DSSKey.from_private_key_file('tests/test_dss_password.key', 'television')