[project @ Arch-1:robey@lag.net--2003-public%secsh--dev--1.0--patch-40]

add dss key generation too, and fix some bugs
added the ability to generate dss keys and write private dss key files,
similar to rsa.  in the process, fixed a couple of bugs with ber encoding
and writing password-encrypted key files.  the key has to be padded to the
iblock size of the cipher -- it's very difficult to determine how the others
do this, so i just add random bytes to the end.

fixed the simple demo to use Transport's (host, port) constructor for
simplicity, and fixed a bug where the standard demo's DSS login wouldn't
work.

also, move the common logfile setup crap into util so all the demos can just
call that one.
This commit is contained in:
Robey Pointer 2004-04-05 19:36:40 +00:00
parent 70faf02f3e
commit c6d5ba9c52
7 changed files with 96 additions and 58 deletions

20
demo.py
View File

@ -31,13 +31,7 @@ def load_host_keys():
##### main demo
# setup logging
l = logging.getLogger("paramiko")
l.setLevel(logging.DEBUG)
if len(l.handlers) == 0:
f = open('demo.log', 'w')
lh = logging.StreamHandler(f)
lh.setFormatter(logging.Formatter('%(levelname)-.3s [%(asctime)s] %(name)s: %(message)s', '%Y%m%d:%H%M%S'))
l.addHandler(lh)
paramiko.util.log_to_file('demo.log')
username = ''
@ -103,29 +97,27 @@ try:
auth = default_auth
if auth == 'r':
key = paramiko.RSAKey()
default_path = os.environ['HOME'] + '/.ssh/id_rsa'
path = raw_input('RSA key [%s]: ' % default_path)
if len(path) == 0:
path = default_path
try:
key.read_private_key_file(path)
key = paramiko.RSAKey.from_private_key_file(path)
except paramiko.PasswordRequiredException:
password = getpass.getpass('RSA key password: ')
key.read_private_key_file(path, password)
key = paramiko.RSAKey.from_private_key_file(path, password)
t.auth_publickey(username, key, event)
elif auth == 'd':
key = paramiko.DSSKey()
default_path = os.environ['HOME'] + '/.ssh/id_dsa'
path = raw_input('DSS key [%s]: ' % default_path)
if len(path) == 0:
path = default_path
try:
key.read_private_key_file(path)
key = paramiko.DSSKey.from_private_key_file(path)
except paramiko.PasswordRequiredException:
password = getpass.getpass('DSS key password: ')
key.read_private_key_file(path, password)
t.auth_key(username, key, event)
key = paramiko.DSSKey.from_private_key_file(path, password)
t.auth_publickey(username, key, event)
else:
pw = getpass.getpass('Password for %s@%s: ' % (username, hostname))
t.auth_password(username, pw, event)

View File

@ -29,13 +29,7 @@ def load_host_keys():
# setup logging
l = logging.getLogger("paramiko")
l.setLevel(logging.DEBUG)
if len(l.handlers) == 0:
f = open('demo.log', 'w')
lh = logging.StreamHandler(f)
lh.setFormatter(logging.Formatter('%(levelname)-.3s [%(asctime)s] %(name)s: %(message)s', '%Y%m%d:%H%M%S'))
l.addHandler(lh)
paramiko.util.log_to_file('demo.log')
# get hostname
username = ''
@ -73,19 +67,9 @@ if hkeys.has_key(hostname):
print 'Using host key of type %s' % hostkeytype
# now connect
# now, connect and use paramiko Transport to negotiate SSH2 across the connection
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((hostname, port))
except Exception, e:
print '*** Connect failed: ' + str(e)
traceback.print_exc()
sys.exit(1)
# finally, use paramiko Transport to negotiate SSH2 across the connection
try:
t = paramiko.Transport(sock)
t = paramiko.Transport((hostname, port))
t.connect(username=username, password=password, hostkeytype=hostkeytype, hostkey=hostkey)
chan = t.open_session()
chan.get_pty()

View File

@ -67,7 +67,7 @@ class BER(object):
t = size & 0x7f
if self.idx + t > len(self.content):
return None
size = self.inflate_long(self.content[self.idx : self.idx + t], True)
size = util.inflate_long(self.content[self.idx : self.idx + t], True)
self.idx += t
if self.idx + size > len(self.content):
# can't fit
@ -116,7 +116,7 @@ class BER(object):
elif type(x) is str:
self.encode_tlv(4, x)
elif (type(x) is list) or (type(x) is tuple):
self.encode_tlv(30, self.encode_sequence(x))
self.encode_tlv(0x30, self.encode_sequence(x))
else:
raise BERException('Unknown type for encoding: %s' % repr(type(x)))

View File

@ -22,14 +22,15 @@
L{DSSKey}
"""
from ssh_exception import SSHException
from message import Message
from util import inflate_long, deflate_long
from Crypto.PublicKey import DSA
from Crypto.Hash import SHA
from common import *
import util
from ssh_exception import SSHException
from message import Message
from ber import BER, BERException
from pkey import PKey
from ssh_exception import SSHException
class DSSKey (PKey):
"""
@ -38,7 +39,7 @@ class DSSKey (PKey):
"""
def __init__(self, msg=None, data=None):
self.valid = 0
self.valid = False
if (msg is None) and (data is not None):
msg = Message(data)
if (msg is None) or (msg.get_string() != 'ssh-dss'):
@ -47,8 +48,8 @@ class DSSKey (PKey):
self.q = msg.get_mpint()
self.g = msg.get_mpint()
self.y = msg.get_mpint()
self.size = len(deflate_long(self.p, 0))
self.valid = 1
self.size = len(util.deflate_long(self.p, 0))
self.valid = True
def __str__(self):
if not self.valid:
@ -77,15 +78,15 @@ class DSSKey (PKey):
hash = SHA.new(data).digest()
dss = DSA.construct((long(self.y), long(self.g), long(self.p), long(self.q), long(self.x)))
# generate a suitable k
qsize = len(deflate_long(self.q, 0))
qsize = len(util.deflate_long(self.q, 0))
while 1:
k = inflate_long(randpool.get_bytes(qsize), 1)
k = util.inflate_long(randpool.get_bytes(qsize), 1)
if (k > 2) and (k < self.q):
break
r, s = dss.sign(inflate_long(hash, 1), k)
r, s = dss.sign(util.inflate_long(hash, 1), k)
m = Message()
m.add_string('ssh-dss')
m.add_string(deflate_long(r, 0) + deflate_long(s, 0))
m.add_string(util.deflate_long(r, 0) + util.deflate_long(s, 0))
return m
def verify_ssh_sig(self, data, msg):
@ -101,9 +102,9 @@ class DSSKey (PKey):
sig = msg.get_string()
# pull out (r, s) which are NOT encoded as mpints
sigR = inflate_long(sig[:20], 1)
sigS = inflate_long(sig[20:], 1)
sigM = inflate_long(SHA.new(data).digest(), 1)
sigR = util.inflate_long(sig[:20], 1)
sigS = util.inflate_long(sig[20:], 1)
sigM = util.inflate_long(SHA.new(data).digest(), 1)
dss = DSA.construct((long(self.y), long(self.g), long(self.p), long(self.q)))
return dss.verify(sigM, (sigR, sigS))
@ -111,12 +112,12 @@ class DSSKey (PKey):
def read_private_key_file(self, filename, password=None):
# private key file contains:
# DSAPrivateKey = { version = 0, p, q, g, y, x }
self.valid = 0
self.valid = False
data = self._read_private_key_file('DSA', filename, password)
try:
keylist = BER(data).decode()
except BERException:
raise SSHException('Unable to parse key file')
except BERException, x:
raise SSHException('Unable to parse key file: ' + str(x))
if (type(keylist) is not list) or (len(keylist) < 6) or (keylist[0] != 0):
raise SSHException('not a valid DSA private key file (bad ber encoding)')
self.p = keylist[1]
@ -124,5 +125,40 @@ class DSSKey (PKey):
self.g = keylist[3]
self.y = keylist[4]
self.x = keylist[5]
self.size = len(deflate_long(self.p, 0))
self.valid = 1
self.size = len(util.deflate_long(self.p, 0))
self.valid = True
def write_private_key_file(self, filename, password=None):
if not self.valid:
raise SSHException('Invalid key')
keylist = [ 0, self.p, self.q, self.g, self.y, self.x ]
try:
b = BER()
b.encode(keylist)
except BERException:
raise SSHException('Unable to create ber encoding of key')
self._write_private_key_file('DSA', filename, str(b), password)
def generate(bits=1024, progress_func=None):
"""
Generate a new private DSS key. This factory function can be used to
generate a new host key or authentication key.
@param bits: number of bits the generated key should be.
@type bites: int
@param progress_func: an optional function to call at key points in
key generation (used by L{pyCrypto.PublicKey}).
@type progress_func: function
@return: new private key
@rtype: L{DSSKey}
"""
dsa = DSA.generate(bits, randpool.get_bytes, progress_func)
key = DSSKey()
key.p = dsa.p
key.q = dsa.q
key.g = dsa.g
key.y = dsa.y
key.x = dsa.x
key.valid = True
return key
generate = staticmethod(generate)

View File

@ -40,7 +40,7 @@ class PKey (object):
# known encryption types for private key files:
_CIPHER_TABLE = {
'DES-EDE3-CBC': { 'cipher': DES3, 'keysize': 24, 'mode': DES3.MODE_CBC }
'DES-EDE3-CBC': { 'cipher': DES3, 'keysize': 24, 'blocksize': 8, 'mode': DES3.MODE_CBC }
}
@ -307,9 +307,13 @@ class PKey (object):
cipher_name = self._CIPHER_TABLE.keys()[0]
cipher = self._CIPHER_TABLE[cipher_name]['cipher']
keysize = self._CIPHER_TABLE[cipher_name]['keysize']
blocksize = self._CIPHER_TABLE[cipher_name]['blocksize']
mode = self._CIPHER_TABLE[cipher_name]['mode']
salt = randpool.get_bytes(8)
key = util.generate_key_bytes(MD5, salt, password, keysize)
if len(data) % blocksize != 0:
n = blocksize - len(data) % blocksize
data += randpool.get_bytes(n)
data = cipher.new(key, mode, salt).encrypt(data)
f.write('Proc-Type: 4,ENCRYPTED\n')
f.write('DEK-Info: %s,%s\n' % (cipher_name, util.hexify(salt)))

View File

@ -22,8 +22,6 @@
L{RSAKey}
"""
import base64
from Crypto.PublicKey import RSA
from Crypto.Hash import SHA, MD5
from Crypto.Cipher import DES3
@ -133,6 +131,18 @@ class RSAKey (PKey):
self._write_private_key_file('RSA', filename, str(b), password)
def generate(bits, progress_func=None):
"""
Generate a new private RSA key. This factory function can be used to
generate a new host key or authentication key.
@param bits: number of bits the generated key should be.
@type bites: int
@param progress_func: an optional function to call at key points in
key generation (used by L{pyCrypto.PublicKey}).
@type progress_func: function
@return: new private key
@rtype: L{RSAKey}
"""
rsa = RSA.generate(bits, randpool.get_bytes, progress_func)
key = RSAKey()
key.n = rsa.n

View File

@ -22,7 +22,7 @@
Useful functions used by the rest of paramiko.
"""
import sys, struct, traceback
import sys, struct, traceback, logging
def inflate_long(s, always_positive=False):
"turns a normalized byte string into a long-int (adapted from Crypto.Util.number)"
@ -173,3 +173,15 @@ def mod_inverse(x, m):
if u2 < 0:
u2 += m
return u2
def log_to_file(filename, level=logging.DEBUG):
"send paramiko logs to a logfile, if they're not already going somewhere"
l = logging.getLogger("paramiko")
if len(l.handlers) > 0:
return
l.setLevel(level)
f = open(filename, 'w')
lh = logging.StreamHandler(f)
lh.setFormatter(logging.Formatter('%(levelname)-.3s [%(asctime)s] %(name)s: %(message)s',
'%Y%m%d-%H:%M:%S'))
l.addHandler(lh)