[project @ Arch-1:robey@lag.net--2003-public%secsh--dev--1.0--patch-40]
add dss key generation too, and fix some bugs added the ability to generate dss keys and write private dss key files, similar to rsa. in the process, fixed a couple of bugs with ber encoding and writing password-encrypted key files. the key has to be padded to the iblock size of the cipher -- it's very difficult to determine how the others do this, so i just add random bytes to the end. fixed the simple demo to use Transport's (host, port) constructor for simplicity, and fixed a bug where the standard demo's DSS login wouldn't work. also, move the common logfile setup crap into util so all the demos can just call that one.
This commit is contained in:
parent
70faf02f3e
commit
c6d5ba9c52
20
demo.py
20
demo.py
|
@ -31,13 +31,7 @@ def load_host_keys():
|
|||
##### main demo
|
||||
|
||||
# setup logging
|
||||
l = logging.getLogger("paramiko")
|
||||
l.setLevel(logging.DEBUG)
|
||||
if len(l.handlers) == 0:
|
||||
f = open('demo.log', 'w')
|
||||
lh = logging.StreamHandler(f)
|
||||
lh.setFormatter(logging.Formatter('%(levelname)-.3s [%(asctime)s] %(name)s: %(message)s', '%Y%m%d:%H%M%S'))
|
||||
l.addHandler(lh)
|
||||
paramiko.util.log_to_file('demo.log')
|
||||
|
||||
|
||||
username = ''
|
||||
|
@ -103,29 +97,27 @@ try:
|
|||
auth = default_auth
|
||||
|
||||
if auth == 'r':
|
||||
key = paramiko.RSAKey()
|
||||
default_path = os.environ['HOME'] + '/.ssh/id_rsa'
|
||||
path = raw_input('RSA key [%s]: ' % default_path)
|
||||
if len(path) == 0:
|
||||
path = default_path
|
||||
try:
|
||||
key.read_private_key_file(path)
|
||||
key = paramiko.RSAKey.from_private_key_file(path)
|
||||
except paramiko.PasswordRequiredException:
|
||||
password = getpass.getpass('RSA key password: ')
|
||||
key.read_private_key_file(path, password)
|
||||
key = paramiko.RSAKey.from_private_key_file(path, password)
|
||||
t.auth_publickey(username, key, event)
|
||||
elif auth == 'd':
|
||||
key = paramiko.DSSKey()
|
||||
default_path = os.environ['HOME'] + '/.ssh/id_dsa'
|
||||
path = raw_input('DSS key [%s]: ' % default_path)
|
||||
if len(path) == 0:
|
||||
path = default_path
|
||||
try:
|
||||
key.read_private_key_file(path)
|
||||
key = paramiko.DSSKey.from_private_key_file(path)
|
||||
except paramiko.PasswordRequiredException:
|
||||
password = getpass.getpass('DSS key password: ')
|
||||
key.read_private_key_file(path, password)
|
||||
t.auth_key(username, key, event)
|
||||
key = paramiko.DSSKey.from_private_key_file(path, password)
|
||||
t.auth_publickey(username, key, event)
|
||||
else:
|
||||
pw = getpass.getpass('Password for %s@%s: ' % (username, hostname))
|
||||
t.auth_password(username, pw, event)
|
||||
|
|
|
@ -29,13 +29,7 @@ def load_host_keys():
|
|||
|
||||
|
||||
# setup logging
|
||||
l = logging.getLogger("paramiko")
|
||||
l.setLevel(logging.DEBUG)
|
||||
if len(l.handlers) == 0:
|
||||
f = open('demo.log', 'w')
|
||||
lh = logging.StreamHandler(f)
|
||||
lh.setFormatter(logging.Formatter('%(levelname)-.3s [%(asctime)s] %(name)s: %(message)s', '%Y%m%d:%H%M%S'))
|
||||
l.addHandler(lh)
|
||||
paramiko.util.log_to_file('demo.log')
|
||||
|
||||
# get hostname
|
||||
username = ''
|
||||
|
@ -73,19 +67,9 @@ if hkeys.has_key(hostname):
|
|||
print 'Using host key of type %s' % hostkeytype
|
||||
|
||||
|
||||
# now connect
|
||||
# now, connect and use paramiko Transport to negotiate SSH2 across the connection
|
||||
try:
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.connect((hostname, port))
|
||||
except Exception, e:
|
||||
print '*** Connect failed: ' + str(e)
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# finally, use paramiko Transport to negotiate SSH2 across the connection
|
||||
try:
|
||||
t = paramiko.Transport(sock)
|
||||
t = paramiko.Transport((hostname, port))
|
||||
t.connect(username=username, password=password, hostkeytype=hostkeytype, hostkey=hostkey)
|
||||
chan = t.open_session()
|
||||
chan.get_pty()
|
||||
|
|
|
@ -67,7 +67,7 @@ class BER(object):
|
|||
t = size & 0x7f
|
||||
if self.idx + t > len(self.content):
|
||||
return None
|
||||
size = self.inflate_long(self.content[self.idx : self.idx + t], True)
|
||||
size = util.inflate_long(self.content[self.idx : self.idx + t], True)
|
||||
self.idx += t
|
||||
if self.idx + size > len(self.content):
|
||||
# can't fit
|
||||
|
@ -116,7 +116,7 @@ class BER(object):
|
|||
elif type(x) is str:
|
||||
self.encode_tlv(4, x)
|
||||
elif (type(x) is list) or (type(x) is tuple):
|
||||
self.encode_tlv(30, self.encode_sequence(x))
|
||||
self.encode_tlv(0x30, self.encode_sequence(x))
|
||||
else:
|
||||
raise BERException('Unknown type for encoding: %s' % repr(type(x)))
|
||||
|
||||
|
|
|
@ -22,14 +22,15 @@
|
|||
L{DSSKey}
|
||||
"""
|
||||
|
||||
from ssh_exception import SSHException
|
||||
from message import Message
|
||||
from util import inflate_long, deflate_long
|
||||
from Crypto.PublicKey import DSA
|
||||
from Crypto.Hash import SHA
|
||||
|
||||
from common import *
|
||||
import util
|
||||
from ssh_exception import SSHException
|
||||
from message import Message
|
||||
from ber import BER, BERException
|
||||
from pkey import PKey
|
||||
from ssh_exception import SSHException
|
||||
|
||||
class DSSKey (PKey):
|
||||
"""
|
||||
|
@ -38,7 +39,7 @@ class DSSKey (PKey):
|
|||
"""
|
||||
|
||||
def __init__(self, msg=None, data=None):
|
||||
self.valid = 0
|
||||
self.valid = False
|
||||
if (msg is None) and (data is not None):
|
||||
msg = Message(data)
|
||||
if (msg is None) or (msg.get_string() != 'ssh-dss'):
|
||||
|
@ -47,8 +48,8 @@ class DSSKey (PKey):
|
|||
self.q = msg.get_mpint()
|
||||
self.g = msg.get_mpint()
|
||||
self.y = msg.get_mpint()
|
||||
self.size = len(deflate_long(self.p, 0))
|
||||
self.valid = 1
|
||||
self.size = len(util.deflate_long(self.p, 0))
|
||||
self.valid = True
|
||||
|
||||
def __str__(self):
|
||||
if not self.valid:
|
||||
|
@ -77,15 +78,15 @@ class DSSKey (PKey):
|
|||
hash = SHA.new(data).digest()
|
||||
dss = DSA.construct((long(self.y), long(self.g), long(self.p), long(self.q), long(self.x)))
|
||||
# generate a suitable k
|
||||
qsize = len(deflate_long(self.q, 0))
|
||||
qsize = len(util.deflate_long(self.q, 0))
|
||||
while 1:
|
||||
k = inflate_long(randpool.get_bytes(qsize), 1)
|
||||
k = util.inflate_long(randpool.get_bytes(qsize), 1)
|
||||
if (k > 2) and (k < self.q):
|
||||
break
|
||||
r, s = dss.sign(inflate_long(hash, 1), k)
|
||||
r, s = dss.sign(util.inflate_long(hash, 1), k)
|
||||
m = Message()
|
||||
m.add_string('ssh-dss')
|
||||
m.add_string(deflate_long(r, 0) + deflate_long(s, 0))
|
||||
m.add_string(util.deflate_long(r, 0) + util.deflate_long(s, 0))
|
||||
return m
|
||||
|
||||
def verify_ssh_sig(self, data, msg):
|
||||
|
@ -101,9 +102,9 @@ class DSSKey (PKey):
|
|||
sig = msg.get_string()
|
||||
|
||||
# pull out (r, s) which are NOT encoded as mpints
|
||||
sigR = inflate_long(sig[:20], 1)
|
||||
sigS = inflate_long(sig[20:], 1)
|
||||
sigM = inflate_long(SHA.new(data).digest(), 1)
|
||||
sigR = util.inflate_long(sig[:20], 1)
|
||||
sigS = util.inflate_long(sig[20:], 1)
|
||||
sigM = util.inflate_long(SHA.new(data).digest(), 1)
|
||||
|
||||
dss = DSA.construct((long(self.y), long(self.g), long(self.p), long(self.q)))
|
||||
return dss.verify(sigM, (sigR, sigS))
|
||||
|
@ -111,12 +112,12 @@ class DSSKey (PKey):
|
|||
def read_private_key_file(self, filename, password=None):
|
||||
# private key file contains:
|
||||
# DSAPrivateKey = { version = 0, p, q, g, y, x }
|
||||
self.valid = 0
|
||||
self.valid = False
|
||||
data = self._read_private_key_file('DSA', filename, password)
|
||||
try:
|
||||
keylist = BER(data).decode()
|
||||
except BERException:
|
||||
raise SSHException('Unable to parse key file')
|
||||
except BERException, x:
|
||||
raise SSHException('Unable to parse key file: ' + str(x))
|
||||
if (type(keylist) is not list) or (len(keylist) < 6) or (keylist[0] != 0):
|
||||
raise SSHException('not a valid DSA private key file (bad ber encoding)')
|
||||
self.p = keylist[1]
|
||||
|
@ -124,5 +125,40 @@ class DSSKey (PKey):
|
|||
self.g = keylist[3]
|
||||
self.y = keylist[4]
|
||||
self.x = keylist[5]
|
||||
self.size = len(deflate_long(self.p, 0))
|
||||
self.valid = 1
|
||||
self.size = len(util.deflate_long(self.p, 0))
|
||||
self.valid = True
|
||||
|
||||
def write_private_key_file(self, filename, password=None):
|
||||
if not self.valid:
|
||||
raise SSHException('Invalid key')
|
||||
keylist = [ 0, self.p, self.q, self.g, self.y, self.x ]
|
||||
try:
|
||||
b = BER()
|
||||
b.encode(keylist)
|
||||
except BERException:
|
||||
raise SSHException('Unable to create ber encoding of key')
|
||||
self._write_private_key_file('DSA', filename, str(b), password)
|
||||
|
||||
def generate(bits=1024, progress_func=None):
|
||||
"""
|
||||
Generate a new private DSS key. This factory function can be used to
|
||||
generate a new host key or authentication key.
|
||||
|
||||
@param bits: number of bits the generated key should be.
|
||||
@type bites: int
|
||||
@param progress_func: an optional function to call at key points in
|
||||
key generation (used by L{pyCrypto.PublicKey}).
|
||||
@type progress_func: function
|
||||
@return: new private key
|
||||
@rtype: L{DSSKey}
|
||||
"""
|
||||
dsa = DSA.generate(bits, randpool.get_bytes, progress_func)
|
||||
key = DSSKey()
|
||||
key.p = dsa.p
|
||||
key.q = dsa.q
|
||||
key.g = dsa.g
|
||||
key.y = dsa.y
|
||||
key.x = dsa.x
|
||||
key.valid = True
|
||||
return key
|
||||
generate = staticmethod(generate)
|
||||
|
|
|
@ -40,7 +40,7 @@ class PKey (object):
|
|||
|
||||
# known encryption types for private key files:
|
||||
_CIPHER_TABLE = {
|
||||
'DES-EDE3-CBC': { 'cipher': DES3, 'keysize': 24, 'mode': DES3.MODE_CBC }
|
||||
'DES-EDE3-CBC': { 'cipher': DES3, 'keysize': 24, 'blocksize': 8, 'mode': DES3.MODE_CBC }
|
||||
}
|
||||
|
||||
|
||||
|
@ -307,9 +307,13 @@ class PKey (object):
|
|||
cipher_name = self._CIPHER_TABLE.keys()[0]
|
||||
cipher = self._CIPHER_TABLE[cipher_name]['cipher']
|
||||
keysize = self._CIPHER_TABLE[cipher_name]['keysize']
|
||||
blocksize = self._CIPHER_TABLE[cipher_name]['blocksize']
|
||||
mode = self._CIPHER_TABLE[cipher_name]['mode']
|
||||
salt = randpool.get_bytes(8)
|
||||
key = util.generate_key_bytes(MD5, salt, password, keysize)
|
||||
if len(data) % blocksize != 0:
|
||||
n = blocksize - len(data) % blocksize
|
||||
data += randpool.get_bytes(n)
|
||||
data = cipher.new(key, mode, salt).encrypt(data)
|
||||
f.write('Proc-Type: 4,ENCRYPTED\n')
|
||||
f.write('DEK-Info: %s,%s\n' % (cipher_name, util.hexify(salt)))
|
||||
|
|
|
@ -22,8 +22,6 @@
|
|||
L{RSAKey}
|
||||
"""
|
||||
|
||||
import base64
|
||||
|
||||
from Crypto.PublicKey import RSA
|
||||
from Crypto.Hash import SHA, MD5
|
||||
from Crypto.Cipher import DES3
|
||||
|
@ -133,6 +131,18 @@ class RSAKey (PKey):
|
|||
self._write_private_key_file('RSA', filename, str(b), password)
|
||||
|
||||
def generate(bits, progress_func=None):
|
||||
"""
|
||||
Generate a new private RSA key. This factory function can be used to
|
||||
generate a new host key or authentication key.
|
||||
|
||||
@param bits: number of bits the generated key should be.
|
||||
@type bites: int
|
||||
@param progress_func: an optional function to call at key points in
|
||||
key generation (used by L{pyCrypto.PublicKey}).
|
||||
@type progress_func: function
|
||||
@return: new private key
|
||||
@rtype: L{RSAKey}
|
||||
"""
|
||||
rsa = RSA.generate(bits, randpool.get_bytes, progress_func)
|
||||
key = RSAKey()
|
||||
key.n = rsa.n
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
Useful functions used by the rest of paramiko.
|
||||
"""
|
||||
|
||||
import sys, struct, traceback
|
||||
import sys, struct, traceback, logging
|
||||
|
||||
def inflate_long(s, always_positive=False):
|
||||
"turns a normalized byte string into a long-int (adapted from Crypto.Util.number)"
|
||||
|
@ -173,3 +173,15 @@ def mod_inverse(x, m):
|
|||
if u2 < 0:
|
||||
u2 += m
|
||||
return u2
|
||||
|
||||
def log_to_file(filename, level=logging.DEBUG):
|
||||
"send paramiko logs to a logfile, if they're not already going somewhere"
|
||||
l = logging.getLogger("paramiko")
|
||||
if len(l.handlers) > 0:
|
||||
return
|
||||
l.setLevel(level)
|
||||
f = open(filename, 'w')
|
||||
lh = logging.StreamHandler(f)
|
||||
lh.setFormatter(logging.Formatter('%(levelname)-.3s [%(asctime)s] %(name)s: %(message)s',
|
||||
'%Y%m%d-%H:%M:%S'))
|
||||
l.addHandler(lh)
|
||||
|
|
Loading…
Reference in New Issue