add PKey.from_private_key to read from a file object
This commit is contained in:
Robey Pointer 2006-06-26 23:41:06 -07:00
parent d81758f1ff
commit 4fa4fdee4b
4 changed files with 69 additions and 17 deletions

View File

@ -36,14 +36,16 @@ class DSSKey (PKey):
Representation of a DSS key which can be used to sign an verify SSH2
data.
"""
p = None
q = None
g = None
y = None
x = None
def __init__(self, msg=None, data=None, filename=None, password=None, vals=None):
def __init__(self, msg=None, data=None, filename=None, password=None, vals=None, file_obj=None):
self.p = None
self.q = None
self.g = None
self.y = None
self.x = None
if file_obj is not None:
self._from_private_key(file_obj, password)
return
if filename is not None:
self._from_private_key_file(filename, password)
return
@ -171,9 +173,16 @@ class DSSKey (PKey):
def _from_private_key_file(self, filename, password):
data = self._read_private_key_file('DSA', filename, password)
self._decode_key(data)
def _from_private_key(self, file_obj, password):
data = self._read_private_key('DSA', file_obj, password)
self._decode_key(data)
def _decode_key(self, data):
# private key file contains:
# DSAPrivateKey = { version = 0, p, q, g, y, x }
data = self._read_private_key_file('DSA', filename, password)
try:
keylist = BER(data).decode()
except BERException, x:

View File

@ -197,6 +197,30 @@ class PKey (object):
return key
from_private_key_file = classmethod(from_private_key_file)
def from_private_key(cls, file_obj, password=None):
"""
Create a key object by reading a private key from a file (or file-like)
object. If the private key is encrypted and C{password} is not C{None},
the given password will be used to decrypt the key (otherwise
L{PasswordRequiredException} is thrown).
@param file_obj: the file to read from
@type file_obj: file
@param password: an optional password to use to decrypt the key, if it's
encrypted
@type password: str
@return: a new key object based on the given private key
@rtype: L{PKey}
@raise IOError: if there was an error reading the key
@raise PasswordRequiredException: if the private key file is encrypted,
and C{password} is C{None}
@raise SSHException: if the key file is invalid
"""
key = cls(file_obj=file_obj, password=password)
return key
from_private_key = classmethod(from_private_key)
def write_private_key_file(self, filename, password=None):
"""
Write private key contents into a file. If the password is not
@ -251,8 +275,12 @@ class PKey (object):
@raise SSHException: if the key file is invalid.
"""
f = open(filename, 'r')
lines = f.readlines()
data = self._read_private_key(tag, f, password)
f.close()
return data
def _read_private_key(self, tag, f, password=None):
lines = f.readlines()
start = 0
while (start < len(lines)) and (lines[start].strip() != '-----BEGIN ' + tag + ' PRIVATE KEY-----'):
start += 1

View File

@ -38,13 +38,15 @@ class RSAKey (PKey):
data.
"""
n = None
e = None
d = None
p = None
q = None
def __init__(self, msg=None, data=None, filename=None, password=None, vals=None):
def __init__(self, msg=None, data=None, filename=None, password=None, vals=None, file_obj=None):
self.n = None
self.e = None
self.d = None
self.p = None
self.q = None
if file_obj is not None:
self._from_private_key(file_obj, password)
return
if filename is not None:
self._from_private_key_file(filename, password)
return
@ -159,9 +161,16 @@ class RSAKey (PKey):
return '\x00\x01' + filler + '\x00' + SHA1_DIGESTINFO + data
def _from_private_key_file(self, filename, password):
data = self._read_private_key_file('RSA', filename, password)
self._decode_key(data)
def _from_private_key(self, file_obj, password):
data = self._read_private_key('RSA', file_obj, password)
self._decode_key(data)
def _decode_key(self, data):
# private key file contains:
# RSAPrivateKey = { version = 0, n, e, d, p, q, d mod p-1, d mod q-1, q**-1 mod p }
data = self._read_private_key_file('RSA', filename, password)
try:
keylist = BER(data).decode()
except BERException:

View File

@ -91,6 +91,9 @@ class KeyTest (unittest.TestCase):
s = StringIO.StringIO()
key.write_private_key(s)
self.assertEquals(RSA_PRIVATE_OUT, s.getvalue())
s.seek(0)
key2 = RSAKey.from_private_key(s)
self.assertEquals(key, key2)
def test_3_load_rsa_password(self):
key = RSAKey.from_private_key_file('tests/test_rsa_password.key', 'television')
@ -113,6 +116,9 @@ class KeyTest (unittest.TestCase):
s = StringIO.StringIO()
key.write_private_key(s)
self.assertEquals(DSS_PRIVATE_OUT, s.getvalue())
s.seek(0)
key2 = DSSKey.from_private_key(s)
self.assertEquals(key, key2)
def test_5_load_dss_password(self):
key = DSSKey.from_private_key_file('tests/test_dss_password.key', 'television')