diff --git a/paramiko/dsskey.py b/paramiko/dsskey.py index b3914b4..a004597 100644 --- a/paramiko/dsskey.py +++ b/paramiko/dsskey.py @@ -129,7 +129,7 @@ class DSSKey (PKey): dss = DSA.construct((long(self.y), long(self.g), long(self.p), long(self.q))) return dss.verify(sigM, (sigR, sigS)) - def write_private_key_file(self, filename, password=None): + def _encode_key(self): if self.x is None: raise SSHException('Not enough key information') keylist = [ 0, self.p, self.q, self.g, self.y, self.x ] @@ -138,7 +138,13 @@ class DSSKey (PKey): b.encode(keylist) except BERException: raise SSHException('Unable to create ber encoding of key') - self._write_private_key_file('DSA', filename, str(b), password) + return str(b) + + def write_private_key_file(self, filename, password=None): + self._write_private_key_file('DSA', filename, self._encode_key(), password) + + def write_private_key(self, file_obj, password=None): + self._write_private_key('DSA', file_obj, self._encode_key(), password) def generate(bits=1024, progress_func=None): """ diff --git a/paramiko/pkey.py b/paramiko/pkey.py index 0d00f7e..053869e 100644 --- a/paramiko/pkey.py +++ b/paramiko/pkey.py @@ -180,18 +180,18 @@ class PKey (object): exist in all subclasses of PKey (such as L{RSAKey} or L{DSSKey}), but is useless on the abstract PKey class. - @param filename: name of the file to read. + @param filename: name of the file to read @type filename: str @param password: an optional password to use to decrypt the key file, if it's encrypted @type password: str - @return: a new key object based on the given private key. + @return: a new key object based on the given private key @rtype: L{PKey} - @raise IOError: if there was an error reading the file. + @raise IOError: if there was an error reading the file @raise PasswordRequiredException: if the private key file is - encrypted, and C{password} is C{None}. - @raise SSHException: if the key file is invalid. + encrypted, and C{password} is C{None} + @raise SSHException: if the key file is invalid """ key = cls(filename=filename, password=password) return key @@ -202,13 +202,28 @@ class PKey (object): Write private key contents into a file. If the password is not C{None}, the key is encrypted before writing. - @param filename: name of the file to write. + @param filename: name of the file to write @type filename: str - @param password: an optional password to use to encrypt the key file. + @param password: an optional password to use to encrypt the key file @type password: str - @raise IOError: if there was an error writing the file. - @raise SSHException: if the key is invalid. + @raise IOError: if there was an error writing the file + @raise SSHException: if the key is invalid + """ + raise Exception('Not implemented in PKey') + + def write_private_key(self, file_obj, password=None): + """ + Write private key contents into a file (or file-like) object. If the + password is not C{None}, the key is encrypted before writing. + + @param file_obj: the file object to write into + @type file_obj: file + @param password: an optional password to use to encrypt the key + @type password: str + + @raise IOError: if there was an error writing to the file + @raise SSHException: if the key is invalid """ raise Exception('Not implemented in PKey') @@ -304,6 +319,10 @@ class PKey (object): f = open(filename, 'w', 0600) # grrr... the mode doesn't always take hold os.chmod(filename, 0600) + self._write_private_key(tag, f, data, password) + f.close() + + def _write_private_key(self, tag, f, data, password=None): f.write('-----BEGIN %s PRIVATE KEY-----\n' % tag) if password is not None: # since we only support one cipher here, use it @@ -330,4 +349,3 @@ class PKey (object): f.write(s) f.write('\n') f.write('-----END %s PRIVATE KEY-----\n' % tag) - f.close() diff --git a/paramiko/rsakey.py b/paramiko/rsakey.py index e146d4c..414815c 100644 --- a/paramiko/rsakey.py +++ b/paramiko/rsakey.py @@ -103,7 +103,7 @@ class RSAKey (PKey): rsa = RSA.construct((long(self.n), long(self.e))) return rsa.verify(hash_obj, (sig,)) - def write_private_key_file(self, filename, password=None): + def _encode_key(self): if (self.p is None) or (self.q is None): raise SSHException('Not enough key info to write private key file') keylist = [ 0, self.n, self.e, self.d, self.p, self.q, @@ -114,7 +114,13 @@ class RSAKey (PKey): b.encode(keylist) except BERException: raise SSHException('Unable to create ber encoding of key') - self._write_private_key_file('RSA', filename, str(b), password) + return str(b) + + def write_private_key_file(self, filename, password=None): + self._write_private_key_file('RSA', filename, self._encode_key(), password) + + def write_private_key(self, file_obj, password=None): + self._write_private_key('RSA', file_obj, self._encode_key(), password) def generate(bits, progress_func=None): """ diff --git a/tests/test_pkey.py b/tests/test_pkey.py index e56edb1..d4b826a 100644 --- a/tests/test_pkey.py +++ b/tests/test_pkey.py @@ -20,6 +20,7 @@ Some unit tests for public/private key objects. """ +import StringIO import unittest from paramiko import RSAKey, DSSKey, Message, util, randpool @@ -30,6 +31,39 @@ FINGER_RSA = '1024 60:73:38:44:cb:51:86:65:7f:de:da:a2:2b:5a:57:d5' FINGER_DSS = '1024 44:78:f0:b9:a2:3c:c5:18:20:09:ff:75:5b:c1:d2:6c' SIGNED_RSA = '20:d7:8a:31:21:cb:f7:92:12:f2:a4:89:37:f5:78:af:e6:16:b6:25:b9:97:3d:a2:cd:5f:ca:20:21:73:4c:ad:34:73:8f:20:77:28:e2:94:15:08:d8:91:40:7a:85:83:bf:18:37:95:dc:54:1a:9b:88:29:6c:73:ca:38:b4:04:f1:56:b9:f2:42:9d:52:1b:29:29:b4:4f:fd:c9:2d:af:47:d2:40:76:30:f3:63:45:0c:d9:1d:43:86:0f:1c:70:e2:93:12:34:f3:ac:c5:0a:2f:14:50:66:59:f1:88:ee:c1:4a:e9:d1:9c:4e:46:f0:0e:47:6f:38:74:f1:44:a8' +RSA_PRIVATE_OUT = """\ +-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKCAIEA049W6geFpmsljTwfvI1UmKWWJPNFI74+vNKTk4dmzkQY2yAM +s6FhlvhlI8ysU4oj71ZsRYMecHbBbxdN79+JRFVYTKaLqjwGENeTd+yv4q+V2PvZ +v3fLnzApI3l7EJCqhWwJUHJ1jAkZzqDx0tyOL4uoZpww3nmE0kb3y21tH4cCASMC +ggCAEiI6plhqipt4P05L3PYr0pHZq2VPEbE4k9eI/gRKo/c1VJxY3DJnc1cenKsk +trQRtW3OxCEufqsX5PNec6VyKkW+Ox6beJjMKm4KF8ZDpKi9Nw6MdX3P6Gele9D9 ++ieyhVFljrnAqcXsgChTBOYlL2imqCs3qRGAJ3cMBIAx3VsCQQD3pIFVYW398kE0 +n0e1icEpkbDRV4c5iZVhu8xKy2yyfy6f6lClSb2+Ub9uns7F3+b5v0pYSHbE9+/r +OpRq83AfAkEA2rMZlr8SnMXgnyka2LuggA9QgMYy18hyao1dUxySubNDa9N+q2QR +mwDisTUgRFHKIlDHoQmzPbXAmYZX1YlDmQJBAPCRLS5epV0XOAc7pL762OaNhzHC +veAfQKgVhKBt105PqaKpGyQ5AXcNlWQlPeTK4GBTbMrKDPna6RBkyrEJvV8CQBK+ +5O+p+kfztCrmRCE0p1tvBuZ3Y3GU1ptrM+KNa6mEZN1bRV8l1Z+SXJLYqv6Kquz/ +nBUeFq2Em3rfoSDugiMCQDyG3cxD5dKX3IgkhLyBWls/FLDk4x/DQ+NUTu0F1Cu6 +JJye+5ARLkL0EweMXf0tmIYfWItDLsWB0fKg/56h0js= +-----END RSA PRIVATE KEY----- +""" + +DSS_PRIVATE_OUT = """\ +-----BEGIN DSA PRIVATE KEY----- +MIIBvgIBAAKCAIEA54GmA2d9HOv+3CYBBG7ZfBYCncIW2tWe6Dqzp+DCP+guNhtW +2MDLqmX+HQQoJbHat/Uh63I2xPFaueID0jod4OPrlfUXIOSDqDy28Kdo0Hxen9RS +G7Me4awwiKlHEHHD0sXrTwSplyPUTfK2S2hbkHk5yOuQSjPfEbsL6ukiNi8CFQDw +z4UnmsGiSNu5iqjn3uTzwUpshwKCAIEAkxfFeY8P2wZpDjX0MimZl5wkoFQDL25c +PzGBuB4OnB8NoUk/yjAHIIpEShw8V+LzouMK5CTJQo5+Ngw3qIch/WgRmMHy4kBq +1SsXMjQCte1So6HBMvBPIW5SiMTmjCfZZiw4AYHK+B/JaOwaG9yRg2Ejg4Ok10+X +FDxlqZo8Y+wCggCARmR7CCPjodxASvRbIyzaVpZoJ/Z6x7dAumV+ysrV1BVYd0lY +ukmnjO1kKBWApqpH1ve9XDQYN8zgxM4b16L21kpoWQnZtXrY3GZ4/it9kUgyB7+N +wacIBlXa8cMDL7Q/69o0d54U0X/NeX5QxuYR6OMJlrkQB7oiW/P/1mwjQgECFGI9 +QPSch9pT9XHqn+1rZ4bK+QGA +-----END DSA PRIVATE KEY----- +""" + class KeyTest (unittest.TestCase): @@ -54,6 +88,10 @@ class KeyTest (unittest.TestCase): self.assertEquals(PUB_RSA.split()[1], key.get_base64()) self.assertEquals(1024, key.get_bits()) + s = StringIO.StringIO() + key.write_private_key(s) + self.assertEquals(RSA_PRIVATE_OUT, s.getvalue()) + def test_3_load_rsa_password(self): key = RSAKey.from_private_key_file('tests/test_rsa_password.key', 'television') self.assertEquals('ssh-rsa', key.get_name()) @@ -72,6 +110,10 @@ class KeyTest (unittest.TestCase): self.assertEquals(PUB_DSS.split()[1], key.get_base64()) self.assertEquals(1024, key.get_bits()) + s = StringIO.StringIO() + key.write_private_key(s) + self.assertEquals(DSS_PRIVATE_OUT, s.getvalue()) + def test_5_load_dss_password(self): key = DSSKey.from_private_key_file('tests/test_dss_password.key', 'television') self.assertEquals('ssh-dss', key.get_name())