Fix message sending

Create constants for byte messages, implement asbytes so many methods can take Message and key objects directly and split get_string into get_text and get_binary. Also, change int handling to use mpint with a flag whenever the int is greater than 32 bits.
This commit is contained in:
Scott Maxwell 2013-10-30 17:09:34 -07:00
parent 339d73cc13
commit 0e4ce3762a
27 changed files with 388 additions and 306 deletions

View File

@ -37,8 +37,11 @@ from paramiko.channel import Channel
from paramiko.common import * from paramiko.common import *
from paramiko.util import retry_on_signal from paramiko.util import retry_on_signal
SSH2_AGENTC_REQUEST_IDENTITIES, SSH2_AGENT_IDENTITIES_ANSWER, \ cSSH2_AGENTC_REQUEST_IDENTITIES = byte_chr(11)
SSH2_AGENTC_SIGN_REQUEST, SSH2_AGENT_SIGN_RESPONSE = range(11, 15) SSH2_AGENT_IDENTITIES_ANSWER = 12
cSSH2_AGENTC_SIGN_REQUEST = byte_chr(13)
SSH2_AGENT_SIGN_RESPONSE = 14
class AgentSSH(object): class AgentSSH(object):
""" """
@ -68,12 +71,12 @@ class AgentSSH(object):
def _connect(self, conn): def _connect(self, conn):
self._conn = conn self._conn = conn
ptype, result = self._send_message(chr(SSH2_AGENTC_REQUEST_IDENTITIES)) ptype, result = self._send_message(cSSH2_AGENTC_REQUEST_IDENTITIES)
if ptype != SSH2_AGENT_IDENTITIES_ANSWER: if ptype != SSH2_AGENT_IDENTITIES_ANSWER:
raise SSHException('could not get keys from ssh-agent') raise SSHException('could not get keys from ssh-agent')
keys = [] keys = []
for i in range(result.get_int()): for i in range(result.get_int()):
keys.append(AgentKey(self, result.get_string())) keys.append(AgentKey(self, result.get_binary()))
result.get_string() result.get_string()
self._keys = tuple(keys) self._keys = tuple(keys)
@ -83,7 +86,7 @@ class AgentSSH(object):
self._keys = () self._keys = ()
def _send_message(self, msg): def _send_message(self, msg):
msg = str(msg) msg = asbytes(msg)
self._conn.send(struct.pack('>I', len(msg)) + msg) self._conn.send(struct.pack('>I', len(msg)) + msg)
l = self._read_all(4) l = self._read_all(4)
msg = Message(self._read_all(struct.unpack('>I', l)[0])) msg = Message(self._read_all(struct.unpack('>I', l)[0]))
@ -360,21 +363,24 @@ class AgentKey(PKey):
def __init__(self, agent, blob): def __init__(self, agent, blob):
self.agent = agent self.agent = agent
self.blob = blob self.blob = blob
self.name = Message(blob).get_string() self.name = Message(blob).get_text()
def asbytes(self):
return self.blob
def __str__(self): def __str__(self):
return self.blob return self.asbytes()
def get_name(self): def get_name(self):
return self.name return self.name
def sign_ssh_data(self, rng, data): def sign_ssh_data(self, rng, data):
msg = Message() msg = Message()
msg.add_byte(chr(SSH2_AGENTC_SIGN_REQUEST)) msg.add_byte(cSSH2_AGENTC_SIGN_REQUEST)
msg.add_string(self.blob) msg.add_string(self.blob)
msg.add_string(data) msg.add_string(data)
msg.add_int(0) msg.add_int(0)
ptype, result = self.agent._send_message(msg) ptype, result = self.agent._send_message(msg)
if ptype != SSH2_AGENT_SIGN_RESPONSE: if ptype != SSH2_AGENT_SIGN_RESPONSE:
raise SSHException('key cannot be used for signing') raise SSHException('key cannot be used for signing')
return result.get_string() return result.get_binary()

View File

@ -119,13 +119,13 @@ class AuthHandler (object):
def _request_auth(self): def _request_auth(self):
m = Message() m = Message()
m.add_byte(chr(MSG_SERVICE_REQUEST)) m.add_byte(cMSG_SERVICE_REQUEST)
m.add_string('ssh-userauth') m.add_string('ssh-userauth')
self.transport._send_message(m) self.transport._send_message(m)
def _disconnect_service_not_available(self): def _disconnect_service_not_available(self):
m = Message() m = Message()
m.add_byte(chr(MSG_DISCONNECT)) m.add_byte(cMSG_DISCONNECT)
m.add_int(DISCONNECT_SERVICE_NOT_AVAILABLE) m.add_int(DISCONNECT_SERVICE_NOT_AVAILABLE)
m.add_string('Service not available') m.add_string('Service not available')
m.add_string('en') m.add_string('en')
@ -134,7 +134,7 @@ class AuthHandler (object):
def _disconnect_no_more_auth(self): def _disconnect_no_more_auth(self):
m = Message() m = Message()
m.add_byte(chr(MSG_DISCONNECT)) m.add_byte(cMSG_DISCONNECT)
m.add_int(DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE) m.add_int(DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE)
m.add_string('No more auth methods available') m.add_string('No more auth methods available')
m.add_string('en') m.add_string('en')
@ -144,14 +144,14 @@ class AuthHandler (object):
def _get_session_blob(self, key, service, username): def _get_session_blob(self, key, service, username):
m = Message() m = Message()
m.add_string(self.transport.session_id) m.add_string(self.transport.session_id)
m.add_byte(chr(MSG_USERAUTH_REQUEST)) m.add_byte(cMSG_USERAUTH_REQUEST)
m.add_string(username) m.add_string(username)
m.add_string(service) m.add_string(service)
m.add_string('publickey') m.add_string('publickey')
m.add_boolean(1) m.add_boolean(1)
m.add_string(key.get_name()) m.add_string(key.get_name())
m.add_string(str(key)) m.add_string(key)
return str(m) return m.asbytes()
def wait_for_response(self, event): def wait_for_response(self, event):
while True: while True:
@ -175,11 +175,11 @@ class AuthHandler (object):
return [] return []
def _parse_service_request(self, m): def _parse_service_request(self, m):
service = m.get_string() service = m.get_text()
if self.transport.server_mode and (service == 'ssh-userauth'): if self.transport.server_mode and (service == 'ssh-userauth'):
# accepted # accepted
m = Message() m = Message()
m.add_byte(chr(MSG_SERVICE_ACCEPT)) m.add_byte(cMSG_SERVICE_ACCEPT)
m.add_string(service) m.add_string(service)
self.transport._send_message(m) self.transport._send_message(m)
return return
@ -187,27 +187,25 @@ class AuthHandler (object):
self._disconnect_service_not_available() self._disconnect_service_not_available()
def _parse_service_accept(self, m): def _parse_service_accept(self, m):
service = m.get_string() service = m.get_text()
if service == 'ssh-userauth': if service == 'ssh-userauth':
self.transport._log(DEBUG, 'userauth is OK') self.transport._log(DEBUG, 'userauth is OK')
m = Message() m = Message()
m.add_byte(chr(MSG_USERAUTH_REQUEST)) m.add_byte(cMSG_USERAUTH_REQUEST)
m.add_string(self.username) m.add_string(self.username)
m.add_string('ssh-connection') m.add_string('ssh-connection')
m.add_string(self.auth_method) m.add_string(self.auth_method)
if self.auth_method == 'password': if self.auth_method == 'password':
m.add_boolean(False) m.add_boolean(False)
password = self.password password = bytestring(self.password)
if isinstance(password, unicode):
password = password.encode('UTF-8')
m.add_string(password) m.add_string(password)
elif self.auth_method == 'publickey': elif self.auth_method == 'publickey':
m.add_boolean(True) m.add_boolean(True)
m.add_string(self.private_key.get_name()) m.add_string(self.private_key.get_name())
m.add_string(str(self.private_key)) m.add_string(self.private_key)
blob = self._get_session_blob(self.private_key, 'ssh-connection', self.username) blob = self._get_session_blob(self.private_key, 'ssh-connection', self.username)
sig = self.private_key.sign_ssh_data(self.transport.rng, blob) sig = self.private_key.sign_ssh_data(self.transport.rng, blob)
m.add_string(str(sig)) m.add_string(sig)
elif self.auth_method == 'keyboard-interactive': elif self.auth_method == 'keyboard-interactive':
m.add_string('') m.add_string('')
m.add_string(self.submethods) m.add_string(self.submethods)
@ -224,11 +222,11 @@ class AuthHandler (object):
m = Message() m = Message()
if result == AUTH_SUCCESSFUL: if result == AUTH_SUCCESSFUL:
self.transport._log(INFO, 'Auth granted (%s).' % method) self.transport._log(INFO, 'Auth granted (%s).' % method)
m.add_byte(chr(MSG_USERAUTH_SUCCESS)) m.add_byte(cMSG_USERAUTH_SUCCESS)
self.authenticated = True self.authenticated = True
else: else:
self.transport._log(INFO, 'Auth rejected (%s).' % method) self.transport._log(INFO, 'Auth rejected (%s).' % method)
m.add_byte(chr(MSG_USERAUTH_FAILURE)) m.add_byte(cMSG_USERAUTH_FAILURE)
m.add_string(self.transport.server_object.get_allowed_auths(username)) m.add_string(self.transport.server_object.get_allowed_auths(username))
if result == AUTH_PARTIALLY_SUCCESSFUL: if result == AUTH_PARTIALLY_SUCCESSFUL:
m.add_boolean(1) m.add_boolean(1)
@ -244,7 +242,7 @@ class AuthHandler (object):
def _interactive_query(self, q): def _interactive_query(self, q):
# make interactive query instead of response # make interactive query instead of response
m = Message() m = Message()
m.add_byte(chr(MSG_USERAUTH_INFO_REQUEST)) m.add_byte(cMSG_USERAUTH_INFO_REQUEST)
m.add_string(q.name) m.add_string(q.name)
m.add_string(q.instructions) m.add_string(q.instructions)
m.add_string('') m.add_string('')
@ -258,7 +256,7 @@ class AuthHandler (object):
if not self.transport.server_mode: if not self.transport.server_mode:
# er, uh... what? # er, uh... what?
m = Message() m = Message()
m.add_byte(chr(MSG_USERAUTH_FAILURE)) m.add_byte(cMSG_USERAUTH_FAILURE)
m.add_string('none') m.add_string('none')
m.add_boolean(0) m.add_boolean(0)
self.transport._send_message(m) self.transport._send_message(m)
@ -266,9 +264,9 @@ class AuthHandler (object):
if self.authenticated: if self.authenticated:
# ignore # ignore
return return
username = m.get_string() username = m.get_text()
service = m.get_string() service = m.get_text()
method = m.get_string() method = m.get_text()
self.transport._log(DEBUG, 'Auth request (type=%s) service=%s, username=%s' % (method, service, username)) self.transport._log(DEBUG, 'Auth request (type=%s) service=%s, username=%s' % (method, service, username))
if service != 'ssh-connection': if service != 'ssh-connection':
self._disconnect_service_not_available() self._disconnect_service_not_available()
@ -283,7 +281,7 @@ class AuthHandler (object):
result = self.transport.server_object.check_auth_none(username) result = self.transport.server_object.check_auth_none(username)
elif method == 'password': elif method == 'password':
changereq = m.get_boolean() changereq = m.get_boolean()
password = m.get_string() password = m.get_binary()
try: try:
password = password.decode('UTF-8') password = password.decode('UTF-8')
except UnicodeError: except UnicodeError:
@ -294,7 +292,7 @@ class AuthHandler (object):
# always treated as failure, since we don't support changing passwords, but collect # always treated as failure, since we don't support changing passwords, but collect
# the list of valid auth types from the callback anyway # the list of valid auth types from the callback anyway
self.transport._log(DEBUG, 'Auth request to change passwords (rejected)') self.transport._log(DEBUG, 'Auth request to change passwords (rejected)')
newpassword = m.get_string() newpassword = m.get_binary()
try: try:
newpassword = newpassword.decode('UTF-8', 'replace') newpassword = newpassword.decode('UTF-8', 'replace')
except UnicodeError: except UnicodeError:
@ -304,8 +302,8 @@ class AuthHandler (object):
result = self.transport.server_object.check_auth_password(username, password) result = self.transport.server_object.check_auth_password(username, password)
elif method == 'publickey': elif method == 'publickey':
sig_attached = m.get_boolean() sig_attached = m.get_boolean()
keytype = m.get_string() keytype = m.get_text()
keyblob = m.get_string() keyblob = m.get_binary()
try: try:
key = self.transport._key_info[keytype](Message(keyblob)) key = self.transport._key_info[keytype](Message(keyblob))
except SSHException: except SSHException:
@ -326,12 +324,12 @@ class AuthHandler (object):
# client wants to know if this key is acceptable, before it # client wants to know if this key is acceptable, before it
# signs anything... send special "ok" message # signs anything... send special "ok" message
m = Message() m = Message()
m.add_byte(chr(MSG_USERAUTH_PK_OK)) m.add_byte(cMSG_USERAUTH_PK_OK)
m.add_string(keytype) m.add_string(keytype)
m.add_string(keyblob) m.add_string(keyblob)
self.transport._send_message(m) self.transport._send_message(m)
return return
sig = Message(m.get_string()) sig = Message(m.get_binary())
blob = self._get_session_blob(key, service, username) blob = self._get_session_blob(key, service, username)
if not key.verify_ssh_sig(blob, sig): if not key.verify_ssh_sig(blob, sig):
self.transport._log(INFO, 'Auth rejected: invalid signature') self.transport._log(INFO, 'Auth rejected: invalid signature')
@ -383,17 +381,17 @@ class AuthHandler (object):
def _parse_userauth_info_request(self, m): def _parse_userauth_info_request(self, m):
if self.auth_method != 'keyboard-interactive': if self.auth_method != 'keyboard-interactive':
raise SSHException('Illegal info request from server') raise SSHException('Illegal info request from server')
title = m.get_string() title = m.get_text()
instructions = m.get_string() instructions = m.get_text()
m.get_string() # lang m.get_binary() # lang
prompts = m.get_int() prompts = m.get_int()
prompt_list = [] prompt_list = []
for i in range(prompts): for i in range(prompts):
prompt_list.append((m.get_string(), m.get_boolean())) prompt_list.append((m.get_text(), m.get_boolean()))
response_list = self.interactive_handler(title, instructions, prompt_list) response_list = self.interactive_handler(title, instructions, prompt_list)
m = Message() m = Message()
m.add_byte(chr(MSG_USERAUTH_INFO_RESPONSE)) m.add_byte(cMSG_USERAUTH_INFO_RESPONSE)
m.add_int(len(response_list)) m.add_int(len(response_list))
for r in response_list: for r in response_list:
m.add_string(r) m.add_string(r)
@ -405,14 +403,14 @@ class AuthHandler (object):
n = m.get_int() n = m.get_int()
responses = [] responses = []
for i in range(n): for i in range(n):
responses.append(m.get_string()) responses.append(m.get_text())
result = self.transport.server_object.check_auth_interactive_response(responses) result = self.transport.server_object.check_auth_interactive_response(responses)
if isinstance(type(result), InteractiveQuery): if isinstance(type(result), InteractiveQuery):
# make interactive query instead of response # make interactive query instead of response
self._interactive_query(result) self._interactive_query(result)
return return
self._send_auth_result(self.auth_username, 'keyboard-interactive', result) self._send_auth_result(self.auth_username, 'keyboard-interactive', result)
_handler_table = { _handler_table = {
MSG_SERVICE_REQUEST: _parse_service_request, MSG_SERVICE_REQUEST: _parse_service_request,

View File

@ -30,13 +30,16 @@ class BER(object):
Robey's tiny little attempt at a BER decoder. Robey's tiny little attempt at a BER decoder.
""" """
def __init__(self, content=''): def __init__(self, content=bytes()):
self.content = content self.content = b(content)
self.idx = 0 self.idx = 0
def __str__(self): def asbytes(self):
return self.content return self.content
def __str__(self):
return self.asbytes()
def __repr__(self): def __repr__(self):
return 'BER(\'' + repr(self.content) + '\')' return 'BER(\'' + repr(self.content) + '\')'
@ -126,5 +129,5 @@ class BER(object):
b = BER() b = BER()
for item in data: for item in data:
b.encode(item) b.encode(item)
return str(b) return b.asbytes()
encode_sequence = staticmethod(encode_sequence) encode_sequence = staticmethod(encode_sequence)

View File

@ -148,7 +148,7 @@ class Channel (object):
if self.closed or self.eof_received or self.eof_sent or not self.active: if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open') raise SSHException('Channel is not open')
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_string('pty-req') m.add_string('pty-req')
m.add_boolean(True) m.add_boolean(True)
@ -181,7 +181,7 @@ class Channel (object):
if self.closed or self.eof_received or self.eof_sent or not self.active: if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open') raise SSHException('Channel is not open')
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_string('shell') m.add_string('shell')
m.add_boolean(1) m.add_boolean(1)
@ -208,7 +208,7 @@ class Channel (object):
if self.closed or self.eof_received or self.eof_sent or not self.active: if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open') raise SSHException('Channel is not open')
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_string('exec') m.add_string('exec')
m.add_boolean(True) m.add_boolean(True)
@ -235,7 +235,7 @@ class Channel (object):
if self.closed or self.eof_received or self.eof_sent or not self.active: if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open') raise SSHException('Channel is not open')
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_string('subsystem') m.add_string('subsystem')
m.add_boolean(True) m.add_boolean(True)
@ -264,7 +264,7 @@ class Channel (object):
if self.closed or self.eof_received or self.eof_sent or not self.active: if self.closed or self.eof_received or self.eof_sent or not self.active:
raise SSHException('Channel is not open') raise SSHException('Channel is not open')
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_string('window-change') m.add_string('window-change')
m.add_boolean(False) m.add_boolean(False)
@ -319,7 +319,7 @@ class Channel (object):
# in many cases, the channel will not still be open here. # in many cases, the channel will not still be open here.
# that's fine. # that's fine.
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_string('exit-status') m.add_string('exit-status')
m.add_boolean(False) m.add_boolean(False)
@ -375,7 +375,7 @@ class Channel (object):
auth_cookie = binascii.hexlify(self.transport.rng.read(16)) auth_cookie = binascii.hexlify(self.transport.rng.read(16))
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_string('x11-req') m.add_string('x11-req')
m.add_boolean(True) m.add_boolean(True)
@ -406,7 +406,7 @@ class Channel (object):
raise SSHException('Channel is not open') raise SSHException('Channel is not open')
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_REQUEST)) m.add_byte(cMSG_CHANNEL_REQUEST)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_string('auth-agent-req@openssh.com') m.add_string('auth-agent-req@openssh.com')
m.add_boolean(False) m.add_boolean(False)
@ -491,7 +491,7 @@ class Channel (object):
self._feed(data) self._feed(data)
return old return old
### socket API ### socket API
@ -622,7 +622,7 @@ class Channel (object):
# no need to hold the channel lock when sending this # no need to hold the channel lock when sending this
if ack > 0: if ack > 0:
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST)) m.add_byte(cMSG_CHANNEL_WINDOW_ADJUST)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_int(ack) m.add_int(ack)
self.transport._send_user_message(m) self.transport._send_user_message(m)
@ -672,7 +672,7 @@ class Channel (object):
# no need to hold the channel lock when sending this # no need to hold the channel lock when sending this
if ack > 0: if ack > 0:
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST)) m.add_byte(cMSG_CHANNEL_WINDOW_ADJUST)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_int(ack) m.add_int(ack)
self.transport._send_user_message(m) self.transport._send_user_message(m)
@ -724,7 +724,7 @@ class Channel (object):
# eof or similar # eof or similar
return 0 return 0
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_DATA)) m.add_byte(cMSG_CHANNEL_DATA)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_string(s[:size]) m.add_string(s[:size])
finally: finally:
@ -761,7 +761,7 @@ class Channel (object):
# eof or similar # eof or similar
return 0 return 0
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_EXTENDED_DATA)) m.add_byte(cMSG_CHANNEL_EXTENDED_DATA)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
m.add_int(1) m.add_int(1)
m.add_string(s[:size]) m.add_string(s[:size])
@ -973,12 +973,12 @@ class Channel (object):
# passed from _feed_extended # passed from _feed_extended
s = m s = m
else: else:
s = m.get_string() s = m.get_binary()
self.in_buffer.feed(s) self.in_buffer.feed(s)
def _feed_extended(self, m): def _feed_extended(self, m):
code = m.get_int() code = m.get_int()
s = m.get_string() s = m.get_text()
if code != 1: if code != 1:
self._log(ERROR, 'unknown extended_data type %d; discarding' % code) self._log(ERROR, 'unknown extended_data type %d; discarding' % code)
return return
@ -999,7 +999,7 @@ class Channel (object):
self.lock.release() self.lock.release()
def _handle_request(self, m): def _handle_request(self, m):
key = m.get_string() key = m.get_text()
want_reply = m.get_boolean() want_reply = m.get_boolean()
server = self.transport.server_object server = self.transport.server_object
ok = False ok = False
@ -1035,13 +1035,13 @@ class Channel (object):
else: else:
ok = server.check_channel_env_request(self, name, value) ok = server.check_channel_env_request(self, name, value)
elif key == 'exec': elif key == 'exec':
cmd = m.get_string() cmd = m.get_text()
if server is None: if server is None:
ok = False ok = False
else: else:
ok = server.check_channel_exec_request(self, cmd) ok = server.check_channel_exec_request(self, cmd)
elif key == 'subsystem': elif key == 'subsystem':
name = m.get_string() name = m.get_text()
if server is None: if server is None:
ok = False ok = False
else: else:
@ -1058,8 +1058,8 @@ class Channel (object):
pixelheight) pixelheight)
elif key == 'x11-req': elif key == 'x11-req':
single_connection = m.get_boolean() single_connection = m.get_boolean()
auth_proto = m.get_string() auth_proto = m.get_text()
auth_cookie = m.get_string() auth_cookie = m.get_text()
screen_number = m.get_int() screen_number = m.get_int()
if server is None: if server is None:
ok = False ok = False
@ -1077,9 +1077,9 @@ class Channel (object):
if want_reply: if want_reply:
m = Message() m = Message()
if ok: if ok:
m.add_byte(chr(MSG_CHANNEL_SUCCESS)) m.add_byte(cMSG_CHANNEL_SUCCESS)
else: else:
m.add_byte(chr(MSG_CHANNEL_FAILURE)) m.add_byte(cMSG_CHANNEL_FAILURE)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
self.transport._send_user_message(m) self.transport._send_user_message(m)
@ -1145,7 +1145,7 @@ class Channel (object):
if self.eof_sent: if self.eof_sent:
return None return None
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_EOF)) m.add_byte(cMSG_CHANNEL_EOF)
m.add_int(self.remote_chanid) m.add_int(self.remote_chanid)
self.eof_sent = True self.eof_sent = True
self._log(DEBUG, 'EOF sent (%s)', self._name) self._log(DEBUG, 'EOF sent (%s)', self._name)
@ -1157,7 +1157,7 @@ class Channel (object):
return None, None return None, None
m1 = self._send_eof() m1 = self._send_eof()
m2 = Message() m2 = Message()
m2.add_byte(chr(MSG_CHANNEL_CLOSE)) m2.add_byte(cMSG_CHANNEL_CLOSE)
m2.add_int(self.remote_chanid) m2.add_int(self.remote_chanid)
self._set_closed() self._set_closed()
# can't unlink from the Transport yet -- the remote side may still # can't unlink from the Transport yet -- the remote side may still

View File

@ -56,7 +56,7 @@ class DSSKey (PKey):
else: else:
if msg is None: if msg is None:
raise SSHException('Key object may not be empty') raise SSHException('Key object may not be empty')
if msg.get_string() != 'ssh-dss': if msg.get_text() != 'ssh-dss':
raise SSHException('Invalid key') raise SSHException('Invalid key')
self.p = msg.get_mpint() self.p = msg.get_mpint()
self.q = msg.get_mpint() self.q = msg.get_mpint()
@ -64,14 +64,17 @@ class DSSKey (PKey):
self.y = msg.get_mpint() self.y = msg.get_mpint()
self.size = util.bit_length(self.p) self.size = util.bit_length(self.p)
def __str__(self): def asbytes(self):
m = Message() m = Message()
m.add_string('ssh-dss') m.add_string('ssh-dss')
m.add_mpint(self.p) m.add_mpint(self.p)
m.add_mpint(self.q) m.add_mpint(self.q)
m.add_mpint(self.g) m.add_mpint(self.g)
m.add_mpint(self.y) m.add_mpint(self.y)
return str(m) return m.asbytes()
def __str__(self):
return self.asbytes()
def __hash__(self): def __hash__(self):
h = hash(self.get_name()) h = hash(self.get_name())
@ -114,14 +117,14 @@ class DSSKey (PKey):
return m return m
def verify_ssh_sig(self, data, msg): def verify_ssh_sig(self, data, msg):
if len(str(msg)) == 40: if len(msg.asbytes()) == 40:
# spies.com bug: signature has no header # spies.com bug: signature has no header
sig = str(msg) sig = msg.asbytes()
else: else:
kind = msg.get_string() kind = msg.get_text()
if kind != 'ssh-dss': if kind != 'ssh-dss':
return 0 return 0
sig = msg.get_string() sig = msg.get_binary()
# pull out (r, s) which are NOT encoded as mpints # pull out (r, s) which are NOT encoded as mpints
sigR = util.inflate_long(sig[:20], 1) sigR = util.inflate_long(sig[:20], 1)
@ -140,7 +143,7 @@ class DSSKey (PKey):
b.encode(keylist) b.encode(keylist)
except BERException: except BERException:
raise SSHException('Unable to create ber encoding of key') raise SSHException('Unable to create ber encoding of key')
return str(b) return b.asbytes()
def write_private_key_file(self, filename, password=None): def write_private_key_file(self, filename, password=None):
self._write_private_key_file('DSA', filename, self._encode_key(), password) self._write_private_key_file('DSA', filename, self._encode_key(), password)

View File

@ -56,30 +56,33 @@ class ECDSAKey (PKey):
else: else:
if msg is None: if msg is None:
raise SSHException('Key object may not be empty') raise SSHException('Key object may not be empty')
if msg.get_string() != 'ecdsa-sha2-nistp256': if msg.get_text() != 'ecdsa-sha2-nistp256':
raise SSHException('Invalid key') raise SSHException('Invalid key')
curvename = msg.get_string() curvename = msg.get_text()
if curvename != 'nistp256': if curvename != 'nistp256':
raise SSHException("Can't handle curve of type %s" % curvename) raise SSHException("Can't handle curve of type %s" % curvename)
pointinfo = msg.get_string() pointinfo = msg.get_binary()
if pointinfo[0] != "\x04": if pointinfo[0] != four_byte:
raise SSHException('Point compression is being used: %s'% raise SSHException('Point compression is being used: %s' %
binascii.hexlify(pointinfo)) binascii.hexlify(pointinfo))
self.verifying_key = VerifyingKey.from_string(pointinfo[1:], self.verifying_key = VerifyingKey.from_string(pointinfo[1:],
curve=curves.NIST256p) curve=curves.NIST256p)
self.size = 256 self.size = 256
def __str__(self): def asbytes(self):
key = self.verifying_key key = self.verifying_key
m = Message() m = Message()
m.add_string('ecdsa-sha2-nistp256') m.add_string('ecdsa-sha2-nistp256')
m.add_string('nistp256') m.add_string('nistp256')
point_str = "\x04" + key.to_string() point_str = four_byte + key.to_string()
m.add_string(point_str) m.add_string(point_str)
return str(m) return m.asbytes()
def __str__(self):
return self.asbytes()
def __hash__(self): def __hash__(self):
h = hash(self.get_name()) h = hash(self.get_name())
@ -106,9 +109,9 @@ class ECDSAKey (PKey):
return m return m
def verify_ssh_sig(self, data, msg): def verify_ssh_sig(self, data, msg):
if msg.get_string() != 'ecdsa-sha2-nistp256': if msg.get_text() != 'ecdsa-sha2-nistp256':
return False return False
sig = msg.get_string() sig = msg.get_binary()
# verify the signature by SHA'ing the data and encrypting it # verify the signature by SHA'ing the data and encrypting it
# using the public key. # using the public key.
@ -161,7 +164,7 @@ class ECDSAKey (PKey):
s, padding = der.remove_sequence(data) s, padding = der.remove_sequence(data)
if padding: if padding:
if padding not in self.ALLOWED_PADDINGS: if padding not in self.ALLOWED_PADDINGS:
raise ValueError, "weird padding: %s" % (binascii.hexlify(empty)) raise ValueError("weird padding: %s" % (binascii.hexlify(empty)))
data = data[:-len(padding)] data = data[:-len(padding)]
key = SigningKey.from_der(data) key = SigningKey.from_der(data)
self.signing_key = key self.signing_key = key
@ -172,7 +175,7 @@ class ECDSAKey (PKey):
msg = Message() msg = Message()
msg.add_mpint(r) msg.add_mpint(r)
msg.add_mpint(s) msg.add_mpint(s)
return str(msg) return msg.asbytes()
def _sigdecode(self, sig, order): def _sigdecode(self, sig, order):
msg = Message(sig) msg = Message(sig)

View File

@ -289,7 +289,7 @@ class HostKeys (MutableMapping):
host_key = k.get(key.get_name(), None) host_key = k.get(key.get_name(), None)
if host_key is None: if host_key is None:
return False return False
return str(host_key) == str(key) return host_key.asbytes() == key.asbytes()
def clear(self): def clear(self):
""" """

View File

@ -33,6 +33,8 @@ from paramiko.ssh_exception import SSHException
_MSG_KEXDH_GEX_REQUEST_OLD, _MSG_KEXDH_GEX_GROUP, _MSG_KEXDH_GEX_INIT, \ _MSG_KEXDH_GEX_REQUEST_OLD, _MSG_KEXDH_GEX_GROUP, _MSG_KEXDH_GEX_INIT, \
_MSG_KEXDH_GEX_REPLY, _MSG_KEXDH_GEX_REQUEST = range(30, 35) _MSG_KEXDH_GEX_REPLY, _MSG_KEXDH_GEX_REQUEST = range(30, 35)
c_MSG_KEXDH_GEX_REQUEST_OLD, c_MSG_KEXDH_GEX_GROUP, c_MSG_KEXDH_GEX_INIT, \
c_MSG_KEXDH_GEX_REPLY, c_MSG_KEXDH_GEX_REQUEST = [byte_chr(c) for c in range(30, 35)]
class KexGex (object): class KexGex (object):
@ -62,11 +64,11 @@ class KexGex (object):
m = Message() m = Message()
if _test_old_style: if _test_old_style:
# only used for unit tests: we shouldn't ever send this # only used for unit tests: we shouldn't ever send this
m.add_byte(chr(_MSG_KEXDH_GEX_REQUEST_OLD)) m.add_byte(c_MSG_KEXDH_GEX_REQUEST_OLD)
m.add_int(self.preferred_bits) m.add_int(self.preferred_bits)
self.old_style = True self.old_style = True
else: else:
m.add_byte(chr(_MSG_KEXDH_GEX_REQUEST)) m.add_byte(c_MSG_KEXDH_GEX_REQUEST)
m.add_int(self.min_bits) m.add_int(self.min_bits)
m.add_int(self.preferred_bits) m.add_int(self.preferred_bits)
m.add_int(self.max_bits) m.add_int(self.max_bits)
@ -135,7 +137,7 @@ class KexGex (object):
self.transport._log(DEBUG, 'Picking p (%d <= %d <= %d bits)' % (minbits, preferredbits, maxbits)) self.transport._log(DEBUG, 'Picking p (%d <= %d <= %d bits)' % (minbits, preferredbits, maxbits))
self.g, self.p = pack.get_modulus(minbits, preferredbits, maxbits) self.g, self.p = pack.get_modulus(minbits, preferredbits, maxbits)
m = Message() m = Message()
m.add_byte(chr(_MSG_KEXDH_GEX_GROUP)) m.add_byte(c_MSG_KEXDH_GEX_GROUP)
m.add_mpint(self.p) m.add_mpint(self.p)
m.add_mpint(self.g) m.add_mpint(self.g)
self.transport._send_message(m) self.transport._send_message(m)
@ -156,7 +158,7 @@ class KexGex (object):
self.transport._log(DEBUG, 'Picking p (~ %d bits)' % (self.preferred_bits,)) self.transport._log(DEBUG, 'Picking p (~ %d bits)' % (self.preferred_bits,))
self.g, self.p = pack.get_modulus(self.min_bits, self.preferred_bits, self.max_bits) self.g, self.p = pack.get_modulus(self.min_bits, self.preferred_bits, self.max_bits)
m = Message() m = Message()
m.add_byte(chr(_MSG_KEXDH_GEX_GROUP)) m.add_byte(c_MSG_KEXDH_GEX_GROUP)
m.add_mpint(self.p) m.add_mpint(self.p)
m.add_mpint(self.g) m.add_mpint(self.g)
self.transport._send_message(m) self.transport._send_message(m)
@ -175,7 +177,7 @@ class KexGex (object):
# now compute e = g^x mod p # now compute e = g^x mod p
self.e = pow(self.g, self.x, self.p) self.e = pow(self.g, self.x, self.p)
m = Message() m = Message()
m.add_byte(chr(_MSG_KEXDH_GEX_INIT)) m.add_byte(c_MSG_KEXDH_GEX_INIT)
m.add_mpint(self.e) m.add_mpint(self.e)
self.transport._send_message(m) self.transport._send_message(m)
self.transport._expect_packet(_MSG_KEXDH_GEX_REPLY) self.transport._expect_packet(_MSG_KEXDH_GEX_REPLY)
@ -187,7 +189,7 @@ class KexGex (object):
self._generate_x() self._generate_x()
self.f = pow(self.g, self.x, self.p) self.f = pow(self.g, self.x, self.p)
K = pow(self.e, self.x, self.p) K = pow(self.e, self.x, self.p)
key = str(self.transport.get_server_key()) key = self.transport.get_server_key().asbytes()
# okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) # okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K)
hm = Message() hm = Message()
hm.add(self.transport.remote_version, self.transport.local_version, hm.add(self.transport.remote_version, self.transport.local_version,
@ -203,16 +205,16 @@ class KexGex (object):
hm.add_mpint(self.e) hm.add_mpint(self.e)
hm.add_mpint(self.f) hm.add_mpint(self.f)
hm.add_mpint(K) hm.add_mpint(K)
H = SHA.new(str(hm)).digest() H = SHA.new(hm.asbytes()).digest()
self.transport._set_K_H(K, H) self.transport._set_K_H(K, H)
# sign it # sign it
sig = self.transport.get_server_key().sign_ssh_data(self.transport.rng, H) sig = self.transport.get_server_key().sign_ssh_data(self.transport.rng, H)
# send reply # send reply
m = Message() m = Message()
m.add_byte(chr(_MSG_KEXDH_GEX_REPLY)) m.add_byte(c_MSG_KEXDH_GEX_REPLY)
m.add_string(key) m.add_string(key)
m.add_mpint(self.f) m.add_mpint(self.f)
m.add_string(str(sig)) m.add_string(sig)
self.transport._send_message(m) self.transport._send_message(m)
self.transport._activate_outbound() self.transport._activate_outbound()
@ -238,6 +240,6 @@ class KexGex (object):
hm.add_mpint(self.e) hm.add_mpint(self.e)
hm.add_mpint(self.f) hm.add_mpint(self.f)
hm.add_mpint(K) hm.add_mpint(K)
self.transport._set_K_H(K, SHA.new(str(hm)).digest()) self.transport._set_K_H(K, SHA.new(hm.asbytes()).digest())
self.transport._verify_key(host_key, sig) self.transport._verify_key(host_key, sig)
self.transport._activate_outbound() self.transport._activate_outbound()

View File

@ -56,7 +56,7 @@ class KexGroup1(object):
# compute e = g^x mod p (where g=2), and send it # compute e = g^x mod p (where g=2), and send it
self.e = pow(G, self.x, P) self.e = pow(G, self.x, P)
m = Message() m = Message()
m.add_byte(chr(_MSG_KEXDH_INIT)) m.add_byte(c_MSG_KEXDH_INIT)
m.add_mpint(self.e) m.add_mpint(self.e)
self.transport._send_message(m) self.transport._send_message(m)
self.transport._expect_packet(_MSG_KEXDH_REPLY) self.transport._expect_packet(_MSG_KEXDH_REPLY)
@ -67,7 +67,7 @@ class KexGroup1(object):
elif not self.transport.server_mode and (ptype == _MSG_KEXDH_REPLY): elif not self.transport.server_mode and (ptype == _MSG_KEXDH_REPLY):
return self._parse_kexdh_reply(m) return self._parse_kexdh_reply(m)
raise SSHException('KexGroup1 asked to handle packet type %d' % ptype) raise SSHException('KexGroup1 asked to handle packet type %d' % ptype)
### internals... ### internals...
@ -92,7 +92,7 @@ class KexGroup1(object):
self.f = m.get_mpint() self.f = m.get_mpint()
if (self.f < 1) or (self.f > P - 1): if (self.f < 1) or (self.f > P - 1):
raise SSHException('Server kex "f" is out of range') raise SSHException('Server kex "f" is out of range')
sig = m.get_string() sig = m.get_binary()
K = pow(self.f, self.x, P) K = pow(self.f, self.x, P)
# okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || e || f || K) # okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || e || f || K)
hm = Message() hm = Message()
@ -102,7 +102,7 @@ class KexGroup1(object):
hm.add_mpint(self.e) hm.add_mpint(self.e)
hm.add_mpint(self.f) hm.add_mpint(self.f)
hm.add_mpint(K) hm.add_mpint(K)
self.transport._set_K_H(K, SHA.new(str(hm)).digest()) self.transport._set_K_H(K, SHA.new(hm.asbytes()).digest())
self.transport._verify_key(host_key, sig) self.transport._verify_key(host_key, sig)
self.transport._activate_outbound() self.transport._activate_outbound()
@ -112,7 +112,7 @@ class KexGroup1(object):
if (self.e < 1) or (self.e > P - 1): if (self.e < 1) or (self.e > P - 1):
raise SSHException('Client kex "e" is out of range') raise SSHException('Client kex "e" is out of range')
K = pow(self.e, self.x, P) K = pow(self.e, self.x, P)
key = str(self.transport.get_server_key()) key = self.transport.get_server_key().asbytes()
# okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || e || f || K) # okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || e || f || K)
hm = Message() hm = Message()
hm.add(self.transport.remote_version, self.transport.local_version, hm.add(self.transport.remote_version, self.transport.local_version,
@ -121,15 +121,15 @@ class KexGroup1(object):
hm.add_mpint(self.e) hm.add_mpint(self.e)
hm.add_mpint(self.f) hm.add_mpint(self.f)
hm.add_mpint(K) hm.add_mpint(K)
H = SHA.new(str(hm)).digest() H = SHA.new(hm.asbytes()).digest()
self.transport._set_K_H(K, H) self.transport._set_K_H(K, H)
# sign it # sign it
sig = self.transport.get_server_key().sign_ssh_data(self.transport.rng, H) sig = self.transport.get_server_key().sign_ssh_data(self.transport.rng, H)
# send reply # send reply
m = Message() m = Message()
m.add_byte(chr(_MSG_KEXDH_REPLY)) m.add_byte(c_MSG_KEXDH_REPLY)
m.add_string(key) m.add_string(key)
m.add_mpint(self.f) m.add_mpint(self.f)
m.add_string(str(sig)) m.add_string(sig)
self.transport._send_message(m) self.transport._send_message(m)
self.transport._activate_outbound() self.transport._activate_outbound()

View File

@ -37,6 +37,8 @@ class Message (object):
paramiko doesn't support yet. paramiko doesn't support yet.
""" """
big_int = long(0xff000000)
def __init__(self, content=None): def __init__(self, content=None):
""" """
Create a new SSH2 Message. Create a new SSH2 Message.
@ -46,18 +48,12 @@ class Message (object):
@type content: string @type content: string
""" """
if content != None: if content != None:
self.packet = cStringIO.StringIO(content) self.packet = BytesIO(content)
else: else:
self.packet = cStringIO.StringIO() self.packet = BytesIO()
def __str__(self): def __str__(self):
""" return self.asbytes()
Return the byte stream content of this Message, as a string.
@return: the contents of this Message.
@rtype: string
"""
return self.packet.getvalue()
def __repr__(self): def __repr__(self):
""" """
@ -67,6 +63,15 @@ class Message (object):
""" """
return 'paramiko.Message(' + repr(self.packet.getvalue()) + ')' return 'paramiko.Message(' + repr(self.packet.getvalue()) + ')'
def asbytes(self):
"""
Return the byte stream content of this Message, as bytes.
@return: the contents of this Message.
@rtype: bytes
"""
return self.packet.getvalue()
def rewind(self): def rewind(self):
""" """
Rewind the message to the beginning as if no items had been parsed Rewind the message to the beginning as if no items had been parsed
@ -112,7 +117,7 @@ class Message (object):
b = self.packet.read(n) b = self.packet.read(n)
max_pad_size = 1<<20 # Limit padding to 1 MB max_pad_size = 1<<20 # Limit padding to 1 MB
if len(b) < n and n < max_pad_size: if len(b) < n and n < max_pad_size:
return b + '\x00' * (n - len(b)) return b + zero_byte * (n - len(b))
return b return b
def get_byte(self): def get_byte(self):
@ -134,12 +139,25 @@ class Message (object):
@rtype: bool @rtype: bool
""" """
b = self.get_bytes(1) b = self.get_bytes(1)
return b != '\x00' return b != zero_byte
def get_int(self): def get_int(self):
""" """
Fetch an int from the stream. Fetch an int from the stream.
@return: a 32-bit unsigned integer.
@rtype: int
"""
byte = self.get_bytes(1)
if byte == max_byte:
return util.inflate_long(self.get_binary())
byte += self.get_bytes(3)
return struct.unpack('>I', byte)[0]
def get_size(self):
"""
Fetch an int from the stream.
@return: a 32-bit unsigned integer. @return: a 32-bit unsigned integer.
@rtype: int @rtype: int
""" """
@ -152,7 +170,7 @@ class Message (object):
@return: a 64-bit unsigned integer. @return: a 64-bit unsigned integer.
@rtype: long @rtype: long
""" """
return struct.unpack('>Q', self.get_bytes(8))[0] return self.get_int()
def get_mpint(self): def get_mpint(self):
""" """
@ -161,7 +179,7 @@ class Message (object):
@return: an arbitrary-length integer. @return: an arbitrary-length integer.
@rtype: long @rtype: long
""" """
return util.inflate_long(self.get_string()) return util.inflate_long(self.get_binary())
def get_string(self): def get_string(self):
""" """
@ -172,7 +190,30 @@ class Message (object):
@return: a string. @return: a string.
@rtype: string @rtype: string
""" """
return self.get_bytes(self.get_int()) return self.get_bytes(self.get_size())
def get_text(self):
"""
Fetch a string from the stream. This could be a byte string and may
contain unprintable characters. (It's not unheard of for a string to
contain another byte-stream Message.)
@return: a string.
@rtype: string
"""
return u(self.get_bytes(self.get_size()))
#return self.get_bytes(self.get_size())
def get_binary(self):
"""
Fetch a string from the stream. This could be a byte string and may
contain unprintable characters. (It's not unheard of for a string to
contain another byte-stream Message.)
@return: a string.
@rtype: string
"""
return self.get_bytes(self.get_size())
def get_list(self): def get_list(self):
""" """
@ -182,7 +223,7 @@ class Message (object):
@return: a list of strings. @return: a list of strings.
@rtype: list of strings @rtype: list of strings
""" """
return self.get_string().split(',') return self.get_text().split(',')
def add_bytes(self, b): def add_bytes(self, b):
""" """
@ -212,12 +253,12 @@ class Message (object):
@type b: bool @type b: bool
""" """
if b: if b:
self.add_byte('\x01') self.packet.write(one_byte)
else: else:
self.add_byte('\x00') self.packet.write(zero_byte)
return self return self
def add_int(self, n): def add_size(self, n):
""" """
Add an integer to the stream. Add an integer to the stream.
@ -227,6 +268,20 @@ class Message (object):
self.packet.write(struct.pack('>I', n)) self.packet.write(struct.pack('>I', n))
return self return self
def add_int(self, n):
"""
Add an integer to the stream.
@param n: integer to add
@type n: int
"""
if n >= Message.big_int:
self.packet.write(max_byte)
self.add_string(util.deflate_long(n))
else:
self.packet.write(struct.pack('>I', n))
return self
def add_int64(self, n): def add_int64(self, n):
""" """
Add a 64-bit int to the stream. Add a 64-bit int to the stream.
@ -234,8 +289,7 @@ class Message (object):
@param n: long int to add @param n: long int to add
@type n: long @type n: long
""" """
self.packet.write(struct.pack('>Q', n)) return self.add_int(n)
return self
def add_mpint(self, z): def add_mpint(self, z):
""" """
@ -255,7 +309,8 @@ class Message (object):
@param s: string to add @param s: string to add
@type s: str @type s: str
""" """
self.add_int(len(s)) s = asbytes(s)
self.add_size(len(s))
self.packet.write(s) self.packet.write(s)
return self return self
@ -272,21 +327,14 @@ class Message (object):
return self return self
def _add(self, i): def _add(self, i):
if type(i) is str: if type(i) is bool:
return self.add_string(i)
elif type(i) is int:
return self.add_int(i)
elif type(i) is long:
if i > 0xffffffffL:
return self.add_mpint(i)
else:
return self.add_int(i)
elif type(i) is bool:
return self.add_boolean(i) return self.add_boolean(i)
elif isinstance(i, integer_types):
return self.add_int(i)
elif type(i) is list: elif type(i) is list:
return self.add_list(i) return self.add_list(i)
else: else:
raise Exception('Unknown type') return self.add_string(i)
def add(self, *seq): def add(self, *seq):
""" """

View File

@ -63,7 +63,7 @@ class PKey (object):
""" """
pass pass
def __str__(self): def asbytes(self):
""" """
Return a string of an SSH L{Message} made up of the public part(s) of Return a string of an SSH L{Message} made up of the public part(s) of
this key. This string is suitable for passing to L{__init__} to this key. This string is suitable for passing to L{__init__} to
@ -72,7 +72,10 @@ class PKey (object):
@return: string representation of an SSH key message. @return: string representation of an SSH key message.
@rtype: str @rtype: str
""" """
return '' return bytes()
def __str__(self):
return self.asbytes()
def __cmp__(self, other): def __cmp__(self, other):
""" """
@ -90,7 +93,10 @@ class PKey (object):
ho = hash(other) ho = hash(other)
if hs != ho: if hs != ho:
return cmp(hs, ho) return cmp(hs, ho)
return cmp(str(self), str(other)) return cmp(self.asbytes(), other.asbytes())
def __eq__(self, other):
return hash(self) == hash(other)
def get_name(self): def get_name(self):
""" """
@ -131,7 +137,7 @@ class PKey (object):
format. format.
@rtype: str @rtype: str
""" """
return MD5.new(str(self)).digest() return MD5.new(self.asbytes()).digest()
def get_base64(self): def get_base64(self):
""" """
@ -142,7 +148,7 @@ class PKey (object):
@return: a base64 string containing the public part of the key. @return: a base64 string containing the public part of the key.
@rtype: str @rtype: str
""" """
return base64.encodestring(str(self)).replace('\n', '') return base64.encodestring(self.asbytes()).replace('\n', '')
def sign_ssh_data(self, rng, data): def sign_ssh_data(self, rng, data):
""" """
@ -156,7 +162,7 @@ class PKey (object):
@return: an SSH signature message. @return: an SSH signature message.
@rtype: L{Message} @rtype: L{Message}
""" """
return '' return bytes()
def verify_ssh_sig(self, data, msg): def verify_ssh_sig(self, data, msg):
""" """
@ -303,7 +309,7 @@ class PKey (object):
end += 1 end += 1
# if we trudged to the end of the file, just try to cope. # if we trudged to the end of the file, just try to cope.
try: try:
data = base64.decodestring(''.join(lines[start:end])) data = base64.decodestring(b(''.join(lines[start:end])))
except base64.binascii.Error: except base64.binascii.Error:
raise SSHException('base64 decoding error: ' + str(sys.exc_info()[1])) raise SSHException('base64 decoding error: ' + str(sys.exc_info()[1]))
if 'proc-type' not in headers: if 'proc-type' not in headers:
@ -356,7 +362,7 @@ class PKey (object):
f.write('-----BEGIN %s PRIVATE KEY-----\n' % tag) f.write('-----BEGIN %s PRIVATE KEY-----\n' % tag)
if password is not None: if password is not None:
# since we only support one cipher here, use it # since we only support one cipher here, use it
cipher_name = self._CIPHER_TABLE.keys()[0] cipher_name = list(self._CIPHER_TABLE.keys())[0]
cipher = self._CIPHER_TABLE[cipher_name]['cipher'] cipher = self._CIPHER_TABLE[cipher_name]['cipher']
keysize = self._CIPHER_TABLE[cipher_name]['keysize'] keysize = self._CIPHER_TABLE[cipher_name]['keysize']
blocksize = self._CIPHER_TABLE[cipher_name]['blocksize'] blocksize = self._CIPHER_TABLE[cipher_name]['blocksize']

View File

@ -57,18 +57,21 @@ class RSAKey (PKey):
else: else:
if msg is None: if msg is None:
raise SSHException('Key object may not be empty') raise SSHException('Key object may not be empty')
if msg.get_string() != 'ssh-rsa': if msg.get_text() != 'ssh-rsa':
raise SSHException('Invalid key') raise SSHException('Invalid key')
self.e = msg.get_mpint() self.e = msg.get_mpint()
self.n = msg.get_mpint() self.n = msg.get_mpint()
self.size = util.bit_length(self.n) self.size = util.bit_length(self.n)
def __str__(self): def asbytes(self):
m = Message() m = Message()
m.add_string('ssh-rsa') m.add_string('ssh-rsa')
m.add_mpint(self.e) m.add_mpint(self.e)
m.add_mpint(self.n) m.add_mpint(self.n)
return str(m) return m.asbytes()
def __str__(self):
return self.asbytes()
def __hash__(self): def __hash__(self):
h = hash(self.get_name()) h = hash(self.get_name())
@ -95,9 +98,9 @@ class RSAKey (PKey):
return m return m
def verify_ssh_sig(self, data, msg): def verify_ssh_sig(self, data, msg):
if msg.get_string() != 'ssh-rsa': if msg.get_text() != 'ssh-rsa':
return False return False
sig = util.inflate_long(msg.get_string(), True) sig = util.inflate_long(msg.get_binary(), True)
# verify the signature by SHA'ing the data and encrypting it using the # verify the signature by SHA'ing the data and encrypting it using the
# public key. some wackiness ensues where we "pkcs1imify" the 20-byte # public key. some wackiness ensues where we "pkcs1imify" the 20-byte
# hash into a string as long as the RSA key. # hash into a string as long as the RSA key.
@ -116,7 +119,7 @@ class RSAKey (PKey):
b.encode(keylist) b.encode(keylist)
except BERException: except BERException:
raise SSHException('Unable to create ber encoding of key') raise SSHException('Unable to create ber encoding of key')
return str(b) return b.asbytes()
def write_private_key_file(self, filename, password=None): def write_private_key_file(self, filename, password=None):
self._write_private_key_file('RSA', filename, self._encode_key(), password) self._write_private_key_file('RSA', filename, self._encode_key(), password)

View File

@ -86,7 +86,7 @@ CMD_NAMES = {
CMD_ATTRS: 'attrs', CMD_ATTRS: 'attrs',
CMD_EXTENDED: 'extended', CMD_EXTENDED: 'extended',
CMD_EXTENDED_REPLY: 'extended_reply' CMD_EXTENDED_REPLY: 'extended_reply'
} }
class SFTPError (Exception): class SFTPError (Exception):
@ -125,7 +125,7 @@ class BaseSFTP (object):
msg = Message() msg = Message()
msg.add_int(_VERSION) msg.add_int(_VERSION)
msg.add(*extension_pairs) msg.add(*extension_pairs)
self._send_packet(CMD_VERSION, str(msg)) self._send_packet(CMD_VERSION, msg)
return version return version
def _log(self, level, msg, *args): def _log(self, level, msg, *args):
@ -167,6 +167,7 @@ class BaseSFTP (object):
def _send_packet(self, t, packet): def _send_packet(self, t, packet):
#self._log(DEBUG2, 'write: %s (len=%d)' % (CMD_NAMES.get(t, '0x%02x' % t), len(packet))) #self._log(DEBUG2, 'write: %s (len=%d)' % (CMD_NAMES.get(t, '0x%02x' % t), len(packet)))
out = struct.pack('>I', len(packet) + 1) + chr(t) + packet out = struct.pack('>I', len(packet) + 1) + chr(t) + packet
packet = asbytes(packet)
if self.ultra_debug: if self.ultra_debug:
self._log(DEBUG, util.format_binary(out, 'OUT: ')) self._log(DEBUG, util.format_binary(out, 'OUT: '))
self._write_all(out) self._write_all(out)

View File

@ -173,7 +173,7 @@ class SFTPClient (BaseSFTP):
t, msg = self._request(CMD_OPENDIR, path) t, msg = self._request(CMD_OPENDIR, path)
if t != CMD_HANDLE: if t != CMD_HANDLE:
raise SFTPError('Expected handle') raise SFTPError('Expected handle')
handle = msg.get_string() handle = msg.get_binary()
filelist = [] filelist = []
while True: while True:
try: try:
@ -245,7 +245,7 @@ class SFTPClient (BaseSFTP):
t, msg = self._request(CMD_OPEN, filename, imode, attrblock) t, msg = self._request(CMD_OPEN, filename, imode, attrblock)
if t != CMD_HANDLE: if t != CMD_HANDLE:
raise SFTPError('Expected handle') raise SFTPError('Expected handle')
handle = msg.get_string() handle = msg.get_binary()
self._log(DEBUG, 'open(%r, %r) -> %s' % (filename, mode, hexlify(handle))) self._log(DEBUG, 'open(%r, %r) -> %s' % (filename, mode, hexlify(handle)))
return SFTPFile(self, handle, mode, bufsize) return SFTPFile(self, handle, mode, bufsize)
@ -369,8 +369,7 @@ class SFTPClient (BaseSFTP):
""" """
dest = self._adjust_cwd(dest) dest = self._adjust_cwd(dest)
self._log(DEBUG, 'symlink(%r, %r)' % (source, dest)) self._log(DEBUG, 'symlink(%r, %r)' % (source, dest))
if type(source) is unicode: source = bytestring(source)
source = source.encode('utf-8')
self._request(CMD_SYMLINK, source, dest) self._request(CMD_SYMLINK, source, dest)
def chmod(self, path, mode): def chmod(self, path, mode):
@ -610,7 +609,7 @@ class SFTPClient (BaseSFTP):
@since: 1.4 @since: 1.4
""" """
file_size = os.stat(localpath).st_size file_size = os.stat(localpath).st_size
fl = file(localpath, 'rb') fl = open(localpath, 'rb')
try: try:
return self.putfo(fl, remotepath, os.stat(localpath).st_size, callback, confirm) return self.putfo(fl, remotepath, os.stat(localpath).st_size, callback, confirm)
finally: finally:
@ -636,7 +635,7 @@ class SFTPClient (BaseSFTP):
@since: 1.4 @since: 1.4
""" """
fr = self.file(remotepath, 'rb') fr = self.open(remotepath, 'rb')
file_size = self.stat(remotepath).st_size file_size = self.stat(remotepath).st_size
fr.prefetch() fr.prefetch()
try: try:
@ -671,7 +670,7 @@ class SFTPClient (BaseSFTP):
@since: 1.4 @since: 1.4
""" """
file_size = self.stat(remotepath).st_size file_size = self.stat(remotepath).st_size
fl = file(localpath, 'wb') fl = open(localpath, 'wb')
try: try:
size = self.getfo(remotepath, fl, callback) size = self.getfo(remotepath, fl, callback)
finally: finally:
@ -707,7 +706,7 @@ class SFTPClient (BaseSFTP):
raise Exception('unknown type for %r type %r' % (item, type(item))) raise Exception('unknown type for %r type %r' % (item, type(item)))
num = self.request_number num = self.request_number
self._expecting[num] = fileobj self._expecting[num] = fileobj
self._send_packet(t, str(msg)) self._send_packet(t, msg)
self.request_number += 1 self.request_number += 1
finally: finally:
self._lock.release() self._lock.release()
@ -752,7 +751,7 @@ class SFTPClient (BaseSFTP):
Raises EOFError or IOError on error status; otherwise does nothing. Raises EOFError or IOError on error status; otherwise does nothing.
""" """
code = msg.get_int() code = msg.get_int()
text = msg.get_string() text = msg.get_text()
if code == SFTP_OK: if code == SFTP_OK:
return return
elif code == SFTP_EOF: elif code == SFTP_EOF:
@ -770,8 +769,7 @@ class SFTPClient (BaseSFTP):
Return an adjusted path if we're emulating a "current working Return an adjusted path if we're emulating a "current working
directory" for the server. directory" for the server.
""" """
if type(path) is unicode: path = bytestring(path)
path = path.encode('utf-8')
if self._cwd is None: if self._cwd is None:
return path return path
if (len(path) > 0) and (path[0] == '/'): if (len(path) > 0) and (path[0] == '/'):

View File

@ -348,8 +348,8 @@ class SFTPFile (BufferedFile):
""" """
t, msg = self.sftp._request(CMD_EXTENDED, 'check-file', self.handle, t, msg = self.sftp._request(CMD_EXTENDED, 'check-file', self.handle,
hash_algorithm, long(offset), long(length), block_size) hash_algorithm, long(offset), long(length), block_size)
ext = msg.get_string() ext = msg.get_text()
alg = msg.get_string() alg = msg.get_text()
data = msg.get_remainder() data = msg.get_remainder()
return data return data

View File

@ -189,7 +189,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
item._pack(msg) item._pack(msg)
else: else:
raise Exception('unknown type for ' + repr(item) + ' type ' + repr(type(item))) raise Exception('unknown type for ' + repr(item) + ' type ' + repr(type(item)))
self._send_packet(t, str(msg)) self._send_packet(t, msg)
def _send_handle_response(self, request_number, handle, folder=False): def _send_handle_response(self, request_number, handle, folder=False):
if not issubclass(type(handle), SFTPHandle): if not issubclass(type(handle), SFTPHandle):
@ -236,14 +236,14 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
msg.add_string(attr.filename) msg.add_string(attr.filename)
msg.add_string(str(attr)) msg.add_string(str(attr))
attr._pack(msg) attr._pack(msg)
self._send_packet(CMD_NAME, str(msg)) self._send_packet(CMD_NAME, msg)
def _check_file(self, request_number, msg): def _check_file(self, request_number, msg):
# this extension actually comes from v6 protocol, but since it's an # this extension actually comes from v6 protocol, but since it's an
# extension, i feel like we can reasonably support it backported. # extension, i feel like we can reasonably support it backported.
# it's very useful for verifying uploaded files or checking for # it's very useful for verifying uploaded files or checking for
# rsync-like differences between local and remote files. # rsync-like differences between local and remote files.
handle = msg.get_string() handle = msg.get_binary()
alg_list = msg.get_list() alg_list = msg.get_list()
start = msg.get_int64() start = msg.get_int64()
length = msg.get_int64() length = msg.get_int64()
@ -295,7 +295,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
msg.add_string('check-file') msg.add_string('check-file')
msg.add_string(algname) msg.add_string(algname)
msg.add_bytes(sum_out) msg.add_bytes(sum_out)
self._send_packet(CMD_EXTENDED_REPLY, str(msg)) self._send_packet(CMD_EXTENDED_REPLY, msg)
def _convert_pflags(self, pflags): def _convert_pflags(self, pflags):
"convert SFTP-style open() flags to python's os.open() flags" "convert SFTP-style open() flags to python's os.open() flags"
@ -318,12 +318,12 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
def _process(self, t, request_number, msg): def _process(self, t, request_number, msg):
self._log(DEBUG, 'Request: %s' % CMD_NAMES[t]) self._log(DEBUG, 'Request: %s' % CMD_NAMES[t])
if t == CMD_OPEN: if t == CMD_OPEN:
path = msg.get_string() path = msg.get_text()
flags = self._convert_pflags(msg.get_int()) flags = self._convert_pflags(msg.get_int())
attr = SFTPAttributes._from_msg(msg) attr = SFTPAttributes._from_msg(msg)
self._send_handle_response(request_number, self.server.open(path, flags, attr)) self._send_handle_response(request_number, self.server.open(path, flags, attr))
elif t == CMD_CLOSE: elif t == CMD_CLOSE:
handle = msg.get_string() handle = msg.get_binary()
if handle in self.folder_table: if handle in self.folder_table:
del self.folder_table[handle] del self.folder_table[handle]
self._send_status(request_number, SFTP_OK) self._send_status(request_number, SFTP_OK)
@ -335,7 +335,7 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
return return
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
elif t == CMD_READ: elif t == CMD_READ:
handle = msg.get_string() handle = msg.get_binary()
offset = msg.get_int64() offset = msg.get_int64()
length = msg.get_int() length = msg.get_int()
if handle not in self.file_table: if handle not in self.file_table:
@ -350,54 +350,54 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
else: else:
self._send_status(request_number, data) self._send_status(request_number, data)
elif t == CMD_WRITE: elif t == CMD_WRITE:
handle = msg.get_string() handle = msg.get_binary()
offset = msg.get_int64() offset = msg.get_int64()
data = msg.get_string() data = msg.get_binary()
if handle not in self.file_table: if handle not in self.file_table:
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return return
self._send_status(request_number, self.file_table[handle].write(offset, data)) self._send_status(request_number, self.file_table[handle].write(offset, data))
elif t == CMD_REMOVE: elif t == CMD_REMOVE:
path = msg.get_string() path = msg.get_text()
self._send_status(request_number, self.server.remove(path)) self._send_status(request_number, self.server.remove(path))
elif t == CMD_RENAME: elif t == CMD_RENAME:
oldpath = msg.get_string() oldpath = msg.get_text()
newpath = msg.get_string() newpath = msg.get_text()
self._send_status(request_number, self.server.rename(oldpath, newpath)) self._send_status(request_number, self.server.rename(oldpath, newpath))
elif t == CMD_MKDIR: elif t == CMD_MKDIR:
path = msg.get_string() path = msg.get_text()
attr = SFTPAttributes._from_msg(msg) attr = SFTPAttributes._from_msg(msg)
self._send_status(request_number, self.server.mkdir(path, attr)) self._send_status(request_number, self.server.mkdir(path, attr))
elif t == CMD_RMDIR: elif t == CMD_RMDIR:
path = msg.get_string() path = msg.get_text()
self._send_status(request_number, self.server.rmdir(path)) self._send_status(request_number, self.server.rmdir(path))
elif t == CMD_OPENDIR: elif t == CMD_OPENDIR:
path = msg.get_string() path = msg.get_text()
self._open_folder(request_number, path) self._open_folder(request_number, path)
return return
elif t == CMD_READDIR: elif t == CMD_READDIR:
handle = msg.get_string() handle = msg.get_binary()
if handle not in self.folder_table: if handle not in self.folder_table:
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return return
folder = self.folder_table[handle] folder = self.folder_table[handle]
self._read_folder(request_number, folder) self._read_folder(request_number, folder)
elif t == CMD_STAT: elif t == CMD_STAT:
path = msg.get_string() path = msg.get_text()
resp = self.server.stat(path) resp = self.server.stat(path)
if issubclass(type(resp), SFTPAttributes): if issubclass(type(resp), SFTPAttributes):
self._response(request_number, CMD_ATTRS, resp) self._response(request_number, CMD_ATTRS, resp)
else: else:
self._send_status(request_number, resp) self._send_status(request_number, resp)
elif t == CMD_LSTAT: elif t == CMD_LSTAT:
path = msg.get_string() path = msg.get_text()
resp = self.server.lstat(path) resp = self.server.lstat(path)
if issubclass(type(resp), SFTPAttributes): if issubclass(type(resp), SFTPAttributes):
self._response(request_number, CMD_ATTRS, resp) self._response(request_number, CMD_ATTRS, resp)
else: else:
self._send_status(request_number, resp) self._send_status(request_number, resp)
elif t == CMD_FSTAT: elif t == CMD_FSTAT:
handle = msg.get_string() handle = msg.get_binary()
if handle not in self.file_table: if handle not in self.file_table:
self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return return
@ -407,18 +407,18 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
else: else:
self._send_status(request_number, resp) self._send_status(request_number, resp)
elif t == CMD_SETSTAT: elif t == CMD_SETSTAT:
path = msg.get_string() path = msg.get_text()
attr = SFTPAttributes._from_msg(msg) attr = SFTPAttributes._from_msg(msg)
self._send_status(request_number, self.server.chattr(path, attr)) self._send_status(request_number, self.server.chattr(path, attr))
elif t == CMD_FSETSTAT: elif t == CMD_FSETSTAT:
handle = msg.get_string() handle = msg.get_binary()
attr = SFTPAttributes._from_msg(msg) attr = SFTPAttributes._from_msg(msg)
if handle not in self.file_table: if handle not in self.file_table:
self._response(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') self._response(request_number, SFTP_BAD_MESSAGE, 'Invalid handle')
return return
self._send_status(request_number, self.file_table[handle].chattr(attr)) self._send_status(request_number, self.file_table[handle].chattr(attr))
elif t == CMD_READLINK: elif t == CMD_READLINK:
path = msg.get_string() path = msg.get_text()
resp = self.server.readlink(path) resp = self.server.readlink(path)
if type(resp) is str: if type(resp) is str:
self._response(request_number, CMD_NAME, 1, resp, '', SFTPAttributes()) self._response(request_number, CMD_NAME, 1, resp, '', SFTPAttributes())
@ -426,15 +426,15 @@ class SFTPServer (BaseSFTP, SubsystemHandler):
self._send_status(request_number, resp) self._send_status(request_number, resp)
elif t == CMD_SYMLINK: elif t == CMD_SYMLINK:
# the sftp 2 draft is incorrect here! path always follows target_path # the sftp 2 draft is incorrect here! path always follows target_path
target_path = msg.get_string() target_path = msg.get_text()
path = msg.get_string() path = msg.get_text()
self._send_status(request_number, self.server.symlink(target_path, path)) self._send_status(request_number, self.server.symlink(target_path, path))
elif t == CMD_REALPATH: elif t == CMD_REALPATH:
path = msg.get_string() path = msg.get_text()
rpath = self.server.canonicalize(path) rpath = self.server.canonicalize(path)
self._response(request_number, CMD_NAME, 1, rpath, '', SFTPAttributes()) self._response(request_number, CMD_NAME, 1, rpath, '', SFTPAttributes())
elif t == CMD_EXTENDED: elif t == CMD_EXTENDED:
tag = msg.get_string() tag = msg.get_text()
if tag == 'check-file': if tag == 'check-file':
self._check_file(request_number, msg) self._check_file(request_number, msg)
else: else:

View File

@ -112,8 +112,8 @@ class SecurityOptions (object):
x = tuple(x) x = tuple(x)
if type(x) is not tuple: if type(x) is not tuple:
raise TypeError('expected tuple or list') raise TypeError('expected tuple or list')
possible = getattr(self._transport, orig).keys() possible = list(getattr(self._transport, orig).keys())
forbidden = filter(lambda n: n not in possible, x) forbidden = [n for n in x if n not in possible]
if len(forbidden) > 0: if len(forbidden) > 0:
raise ValueError('unknown cipher') raise ValueError('unknown cipher')
setattr(self._transport, name, x) setattr(self._transport, name, x)
@ -276,7 +276,7 @@ class Transport (threading.Thread):
@param sock: a socket or socket-like object to create the session over. @param sock: a socket or socket-like object to create the session over.
@type sock: socket @type sock: socket
""" """
if isinstance(sock, (str, unicode)): if isinstance(sock, string_types):
# convert "host:port" into (host, port) # convert "host:port" into (host, port)
hl = sock.split(':', 1) hl = sock.split(':', 1)
if len(hl) == 1: if len(hl) == 1:
@ -735,7 +735,7 @@ class Transport (threading.Thread):
try: try:
chanid = self._next_channel() chanid = self._next_channel()
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_OPEN)) m.add_byte(cMSG_CHANNEL_OPEN)
m.add_string(kind) m.add_string(kind)
m.add_int(chanid) m.add_int(chanid)
m.add_int(self.window_size) m.add_int(self.window_size)
@ -861,7 +861,7 @@ class Transport (threading.Thread):
@type byte_count: int @type byte_count: int
""" """
m = Message() m = Message()
m.add_byte(chr(MSG_IGNORE)) m.add_byte(cMSG_IGNORE)
if byte_count is None: if byte_count is None:
byte_count = (byte_ord(rng.read(1)) % 32) + 10 byte_count = (byte_ord(rng.read(1)) % 32) + 10
m.add_bytes(rng.read(byte_count)) m.add_bytes(rng.read(byte_count))
@ -927,7 +927,7 @@ class Transport (threading.Thread):
if wait: if wait:
self.completion_event = threading.Event() self.completion_event = threading.Event()
m = Message() m = Message()
m.add_byte(chr(MSG_GLOBAL_REQUEST)) m.add_byte(cMSG_GLOBAL_REQUEST)
m.add_string(kind) m.add_string(kind)
m.add_boolean(wait) m.add_boolean(wait)
if data is not None: if data is not None:
@ -1013,10 +1013,10 @@ class Transport (threading.Thread):
# check host key if we were given one # check host key if we were given one
if (hostkey is not None): if (hostkey is not None):
key = self.get_remote_server_key() key = self.get_remote_server_key()
if (key.get_name() != hostkey.get_name()) or (str(key) != str(hostkey)): if (key.get_name() != hostkey.get_name()) or (key.asbytes() != hostkey.asbytes()):
self._log(DEBUG, 'Bad host key from server') self._log(DEBUG, 'Bad host key from server')
self._log(DEBUG, 'Expected: %s: %s' % (hostkey.get_name(), repr(str(hostkey)))) self._log(DEBUG, 'Expected: %s: %s' % (hostkey.get_name(), repr(hostkey.asbytes())))
self._log(DEBUG, 'Got : %s: %s' % (key.get_name(), repr(str(key)))) self._log(DEBUG, 'Got : %s: %s' % (key.get_name(), repr(key.asbytes())))
raise SSHException('Bad host key from server') raise SSHException('Bad host key from server')
self._log(DEBUG, 'Host key verified (%s)' % hostkey.get_name()) self._log(DEBUG, 'Host key verified (%s)' % hostkey.get_name())
@ -1476,15 +1476,15 @@ class Transport (threading.Thread):
m = Message() m = Message()
m.add_mpint(self.K) m.add_mpint(self.K)
m.add_bytes(self.H) m.add_bytes(self.H)
m.add_byte(id) m.add_byte(b(id))
m.add_bytes(self.session_id) m.add_bytes(self.session_id)
out = sofar = SHA.new(str(m)).digest() out = sofar = SHA.new(m.asbytes()).digest()
while len(out) < nbytes: while len(out) < nbytes:
m = Message() m = Message()
m.add_mpint(self.K) m.add_mpint(self.K)
m.add_bytes(self.H) m.add_bytes(self.H)
m.add_bytes(sofar) m.add_bytes(sofar)
digest = SHA.new(str(m)).digest() digest = SHA.new(m.asbytes()).digest()
out += digest out += digest
sofar += digest sofar += digest
return out[:nbytes] return out[:nbytes]
@ -1606,7 +1606,7 @@ class Transport (threading.Thread):
else: else:
self._log(WARNING, 'Oops, unhandled type %d' % ptype) self._log(WARNING, 'Oops, unhandled type %d' % ptype)
msg = Message() msg = Message()
msg.add_byte(chr(MSG_UNIMPLEMENTED)) msg.add_byte(cMSG_UNIMPLEMENTED)
msg.add_int(m.seqno) msg.add_int(m.seqno)
self._send_message(msg) self._send_message(msg)
except SSHException: except SSHException:
@ -1633,7 +1633,7 @@ class Transport (threading.Thread):
self._log(ERROR, util.tb_strings()) self._log(ERROR, util.tb_strings())
self.saved_exception = e self.saved_exception = e
_active_threads.remove(self) _active_threads.remove(self)
for chan in self._channels.values(): for chan in list(self._channels.values()):
chan._unlink() chan._unlink()
if self.active: if self.active:
self.active = False self.active = False
@ -1642,7 +1642,7 @@ class Transport (threading.Thread):
self.completion_event.set() self.completion_event.set()
if self.auth_handler is not None: if self.auth_handler is not None:
self.auth_handler.abort() self.auth_handler.abort()
for event in self.channel_events.values(): for event in list(self.channel_events.values()):
event.set() event.set()
try: try:
self.lock.acquire() self.lock.acquire()
@ -1731,13 +1731,13 @@ class Transport (threading.Thread):
pkex = list(self.get_security_options().kex) pkex = list(self.get_security_options().kex)
pkex.remove('diffie-hellman-group-exchange-sha1') pkex.remove('diffie-hellman-group-exchange-sha1')
self.get_security_options().kex = pkex self.get_security_options().kex = pkex
available_server_keys = filter(self.server_key_dict.keys().__contains__, available_server_keys = list(filter(list(self.server_key_dict.keys()).__contains__,
self._preferred_keys) self._preferred_keys))
else: else:
available_server_keys = self._preferred_keys available_server_keys = self._preferred_keys
m = Message() m = Message()
m.add_byte(chr(MSG_KEXINIT)) m.add_byte(cMSG_KEXINIT)
m.add_bytes(rng.read(16)) m.add_bytes(rng.read(16))
m.add_list(self._preferred_kex) m.add_list(self._preferred_kex)
m.add_list(available_server_keys) m.add_list(available_server_keys)
@ -1752,7 +1752,7 @@ class Transport (threading.Thread):
m.add_boolean(False) m.add_boolean(False)
m.add_int(0) m.add_int(0)
# save a copy for later (needed to compute a hash) # save a copy for later (needed to compute a hash)
self.local_kex_init = str(m) self.local_kex_init = m.asbytes()
self._send_message(m) self._send_message(m)
def _parse_kex_init(self, m): def _parse_kex_init(self, m):
@ -1850,7 +1850,7 @@ class Transport (threading.Thread):
# actually some extra bytes (one NUL byte in openssh's case) added to # actually some extra bytes (one NUL byte in openssh's case) added to
# the end of the packet but not parsed. turns out we need to throw # the end of the packet but not parsed. turns out we need to throw
# away those bytes because they aren't part of the hash. # away those bytes because they aren't part of the hash.
self.remote_kex_init = chr(MSG_KEXINIT) + m.get_so_far() self.remote_kex_init = cMSG_KEXINIT + m.get_so_far()
def _activate_inbound(self): def _activate_inbound(self):
"switch on newly negotiated encryption parameters for inbound traffic" "switch on newly negotiated encryption parameters for inbound traffic"
@ -1879,7 +1879,7 @@ class Transport (threading.Thread):
def _activate_outbound(self): def _activate_outbound(self):
"switch on newly negotiated encryption parameters for outbound traffic" "switch on newly negotiated encryption parameters for outbound traffic"
m = Message() m = Message()
m.add_byte(chr(MSG_NEWKEYS)) m.add_byte(cMSG_NEWKEYS)
self._send_message(m) self._send_message(m)
block_size = self._cipher_info[self.local_cipher]['block-size'] block_size = self._cipher_info[self.local_cipher]['block-size']
if self.server_mode: if self.server_mode:
@ -1952,20 +1952,20 @@ class Transport (threading.Thread):
self._log(INFO, 'Disconnect (code %d): %s' % (code, desc)) self._log(INFO, 'Disconnect (code %d): %s' % (code, desc))
def _parse_global_request(self, m): def _parse_global_request(self, m):
kind = m.get_string() kind = m.get_text()
self._log(DEBUG, 'Received global request "%s"' % kind) self._log(DEBUG, 'Received global request "%s"' % kind)
want_reply = m.get_boolean() want_reply = m.get_boolean()
if not self.server_mode: if not self.server_mode:
self._log(DEBUG, 'Rejecting "%s" global request from server.' % kind) self._log(DEBUG, 'Rejecting "%s" global request from server.' % kind)
ok = False ok = False
elif kind == 'tcpip-forward': elif kind == 'tcpip-forward':
address = m.get_string() address = m.get_text()
port = m.get_int() port = m.get_int()
ok = self.server_object.check_port_forward_request(address, port) ok = self.server_object.check_port_forward_request(address, port)
if ok != False: if ok != False:
ok = (ok,) ok = (ok,)
elif kind == 'cancel-tcpip-forward': elif kind == 'cancel-tcpip-forward':
address = m.get_string() address = m.get_text()
port = m.get_int() port = m.get_int()
self.server_object.cancel_port_forward_request(address, port) self.server_object.cancel_port_forward_request(address, port)
ok = True ok = True
@ -1978,10 +1978,10 @@ class Transport (threading.Thread):
if want_reply: if want_reply:
msg = Message() msg = Message()
if ok: if ok:
msg.add_byte(chr(MSG_REQUEST_SUCCESS)) msg.add_byte(cMSG_REQUEST_SUCCESS)
msg.add(*extra) msg.add(*extra)
else: else:
msg.add_byte(chr(MSG_REQUEST_FAILURE)) msg.add_byte(cMSG_REQUEST_FAILURE)
self._send_message(msg) self._send_message(msg)
def _parse_request_success(self, m): def _parse_request_success(self, m):
@ -2019,8 +2019,8 @@ class Transport (threading.Thread):
def _parse_channel_open_failure(self, m): def _parse_channel_open_failure(self, m):
chanid = m.get_int() chanid = m.get_int()
reason = m.get_int() reason = m.get_int()
reason_str = m.get_string() reason_str = m.get_text()
lang = m.get_string() lang = m.get_text()
reason_text = CONNECTION_FAILED_CODE.get(reason, '(unknown code)') reason_text = CONNECTION_FAILED_CODE.get(reason, '(unknown code)')
self._log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text)) self._log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text))
self.lock.acquire() self.lock.acquire()
@ -2036,7 +2036,7 @@ class Transport (threading.Thread):
return return
def _parse_channel_open(self, m): def _parse_channel_open(self, m):
kind = m.get_string() kind = m.get_text()
chanid = m.get_int() chanid = m.get_int()
initial_window_size = m.get_int() initial_window_size = m.get_int()
max_packet_size = m.get_int() max_packet_size = m.get_int()
@ -2049,7 +2049,7 @@ class Transport (threading.Thread):
finally: finally:
self.lock.release() self.lock.release()
elif (kind == 'x11') and (self._x11_handler is not None): elif (kind == 'x11') and (self._x11_handler is not None):
origin_addr = m.get_string() origin_addr = m.get_text()
origin_port = m.get_int() origin_port = m.get_int()
self._log(DEBUG, 'Incoming x11 connection from %s:%d' % (origin_addr, origin_port)) self._log(DEBUG, 'Incoming x11 connection from %s:%d' % (origin_addr, origin_port))
self.lock.acquire() self.lock.acquire()
@ -2058,9 +2058,9 @@ class Transport (threading.Thread):
finally: finally:
self.lock.release() self.lock.release()
elif (kind == 'forwarded-tcpip') and (self._tcp_handler is not None): elif (kind == 'forwarded-tcpip') and (self._tcp_handler is not None):
server_addr = m.get_string() server_addr = m.get_text()
server_port = m.get_int() server_port = m.get_int()
origin_addr = m.get_string() origin_addr = m.get_text()
origin_port = m.get_int() origin_port = m.get_int()
self._log(DEBUG, 'Incoming tcp forwarded connection from %s:%d' % (origin_addr, origin_port)) self._log(DEBUG, 'Incoming tcp forwarded connection from %s:%d' % (origin_addr, origin_port))
self.lock.acquire() self.lock.acquire()
@ -2080,9 +2080,9 @@ class Transport (threading.Thread):
self.lock.release() self.lock.release()
if kind == 'direct-tcpip': if kind == 'direct-tcpip':
# handle direct-tcpip requests comming from the client # handle direct-tcpip requests comming from the client
dest_addr = m.get_string() dest_addr = m.get_text()
dest_port = m.get_int() dest_port = m.get_int()
origin_addr = m.get_string() origin_addr = m.get_text()
origin_port = m.get_int() origin_port = m.get_int()
reason = self.server_object.check_channel_direct_tcpip_request( reason = self.server_object.check_channel_direct_tcpip_request(
my_chanid, (origin_addr, origin_port), my_chanid, (origin_addr, origin_port),
@ -2094,7 +2094,7 @@ class Transport (threading.Thread):
reject = True reject = True
if reject: if reject:
msg = Message() msg = Message()
msg.add_byte(chr(MSG_CHANNEL_OPEN_FAILURE)) msg.add_byte(cMSG_CHANNEL_OPEN_FAILURE)
msg.add_int(chanid) msg.add_int(chanid)
msg.add_int(reason) msg.add_int(reason)
msg.add_string('') msg.add_string('')
@ -2113,7 +2113,7 @@ class Transport (threading.Thread):
finally: finally:
self.lock.release() self.lock.release()
m = Message() m = Message()
m.add_byte(chr(MSG_CHANNEL_OPEN_SUCCESS)) m.add_byte(cMSG_CHANNEL_OPEN_SUCCESS)
m.add_int(chanid) m.add_int(chanid)
m.add_int(my_chanid) m.add_int(my_chanid)
m.add_int(self.window_size) m.add_int(self.window_size)

View File

@ -111,8 +111,8 @@ class AuthTest (unittest.TestCase):
self.sockc.close() self.sockc.close()
def start_server(self): def start_server(self):
self.public_host_key = RSAKey(data=str(host_key))
host_key = RSAKey.from_private_key_file(test_path('test_rsa.key')) host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
self.public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key) self.ts.add_server_key(host_key)
self.event = threading.Event() self.event = threading.Event()
self.server = NullServer() self.server = NullServer()

View File

@ -86,8 +86,8 @@ class SSHClientTest (unittest.TestCase):
""" """
verify that the SSHClient stuff works too. verify that the SSHClient stuff works too.
""" """
public_host_key = paramiko.RSAKey(data=str(host_key))
host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key')) host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient() self.tc = paramiko.SSHClient()
self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key) self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key)
@ -119,8 +119,8 @@ class SSHClientTest (unittest.TestCase):
""" """
verify that SSHClient works with a DSA key. verify that SSHClient works with a DSA key.
""" """
public_host_key = paramiko.RSAKey(data=str(host_key))
host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key')) host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient() self.tc = paramiko.SSHClient()
self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key) self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key)
@ -152,8 +152,8 @@ class SSHClientTest (unittest.TestCase):
""" """
verify that SSHClient accepts and tries multiple key files. verify that SSHClient accepts and tries multiple key files.
""" """
public_host_key = paramiko.RSAKey(data=str(host_key))
host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key')) host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient() self.tc = paramiko.SSHClient()
self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key) self.tc.get_host_keys().add('[%s]:%d' % (self.addr, self.port), 'ssh-rsa', public_host_key)
@ -169,8 +169,8 @@ class SSHClientTest (unittest.TestCase):
""" """
verify that SSHClient's AutoAddPolicy works. verify that SSHClient's AutoAddPolicy works.
""" """
public_host_key = paramiko.RSAKey(data=str(host_key))
host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key')) host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient() self.tc = paramiko.SSHClient()
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())
@ -190,8 +190,8 @@ class SSHClientTest (unittest.TestCase):
verify that when an SSHClient is collected, its transport (and the verify that when an SSHClient is collected, its transport (and the
transport's packetizer) is closed. transport's packetizer) is closed.
""" """
public_host_key = paramiko.RSAKey(data=str(host_key))
host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key')) host_key = paramiko.RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = paramiko.RSAKey(data=host_key.asbytes())
self.tc = paramiko.SSHClient() self.tc = paramiko.SSHClient()
self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.tc.set_missing_host_key_policy(paramiko.AutoAddPolicy())

View File

@ -65,8 +65,8 @@ class HostKeysTest (unittest.TestCase):
def test_1_load(self): def test_1_load(self):
hostdict = paramiko.HostKeys('hostfile.temp') hostdict = paramiko.HostKeys('hostfile.temp')
self.assertEquals(2, len(hostdict)) self.assertEquals(2, len(hostdict))
self.assertEquals(1, len(hostdict.values()[0])) self.assertEquals(1, len(list(hostdict.values())[0]))
self.assertEquals(1, len(hostdict.values()[1])) self.assertEquals(1, len(list(hostdict.values())[1]))
fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper()
self.assertEquals('E6684DB30E109B67B70FF1DC5C7F1363', fp) self.assertEquals('E6684DB30E109B67B70FF1DC5C7F1363', fp)
@ -75,7 +75,7 @@ class HostKeysTest (unittest.TestCase):
hh = '|1|BMsIC6cUIP2zBuXR3t2LRcJYjzM=|hpkJMysjTk/+zzUUzxQEa2ieq6c=' hh = '|1|BMsIC6cUIP2zBuXR3t2LRcJYjzM=|hpkJMysjTk/+zzUUzxQEa2ieq6c='
key = paramiko.RSAKey(data=base64.decodestring(keyblob)) key = paramiko.RSAKey(data=base64.decodestring(keyblob))
hostdict.add(hh, 'ssh-rsa', key) hostdict.add(hh, 'ssh-rsa', key)
self.assertEquals(3, len(hostdict)) self.assertEquals(3, len(list(hostdict)))
x = hostdict['foo.example.com'] x = hostdict['foo.example.com']
fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper() fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper()
self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp) self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp)
@ -85,8 +85,8 @@ class HostKeysTest (unittest.TestCase):
hostdict = paramiko.HostKeys('hostfile.temp') hostdict = paramiko.HostKeys('hostfile.temp')
self.assert_('secure.example.com' in hostdict) self.assert_('secure.example.com' in hostdict)
self.assert_('not.example.com' not in hostdict) self.assert_('not.example.com' not in hostdict)
self.assert_(hostdict.has_key('secure.example.com')) self.assert_('secure.example.com' in hostdict)
self.assert_(not hostdict.has_key('not.example.com')) self.assert_('not.example.com' not in hostdict)
x = hostdict.get('secure.example.com', None) x = hostdict.get('secure.example.com', None)
self.assert_(x is not None) self.assert_(x is not None)
fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper() fp = hexlify(x['ssh-rsa'].get_fingerprint()).upper()
@ -108,9 +108,9 @@ class HostKeysTest (unittest.TestCase):
hostdict['fake.example.com']['ssh-rsa'] = key hostdict['fake.example.com']['ssh-rsa'] = key
self.assertEquals(3, len(hostdict)) self.assertEquals(3, len(hostdict))
self.assertEquals(2, len(hostdict.values()[0])) self.assertEquals(2, len(list(hostdict.values())[0]))
self.assertEquals(1, len(hostdict.values()[1])) self.assertEquals(1, len(list(hostdict.values())[1]))
self.assertEquals(1, len(hostdict.values()[2])) self.assertEquals(1, len(list(hostdict.values())[2]))
fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper() fp = hexlify(hostdict['secure.example.com']['ssh-rsa'].get_fingerprint()).upper()
self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp) self.assertEquals('7EC91BB336CB6D810B124B1353C32396', fp)
fp = hexlify(hostdict['secure.example.com']['ssh-dss'].get_fingerprint()).upper() fp = hexlify(hostdict['secure.example.com']['ssh-dss'].get_fingerprint()).upper()

View File

@ -37,6 +37,8 @@ class FakeRng (object):
class FakeKey (object): class FakeKey (object):
def __str__(self): def __str__(self):
return 'fake-key' return 'fake-key'
def asbytes(self):
return b('fake-key')
def sign_ssh_data(self, rng, H): def sign_ssh_data(self, rng, H):
return 'fake-sig' return 'fake-sig'
@ -90,7 +92,7 @@ class KexTest (unittest.TestCase):
kex = KexGroup1(transport) kex = KexGroup1(transport)
kex.start_kex() kex.start_kex()
x = '1E000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' x = '1E000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
self.assertEquals(x, hexlify(str(transport._message)).upper()) self.assertEquals(x, hexlify(transport._message.asbytes()).upper())
self.assertEquals((paramiko.kex_group1._MSG_KEXDH_REPLY,), transport._expect) self.assertEquals((paramiko.kex_group1._MSG_KEXDH_REPLY,), transport._expect)
# fake "reply" # fake "reply"
@ -121,7 +123,7 @@ class KexTest (unittest.TestCase):
x = '1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' x = '1F0000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
self.assertEquals(self.K, transport._K) self.assertEquals(self.K, transport._K)
self.assertEquals(H, hexlify(transport._H).upper()) self.assertEquals(H, hexlify(transport._H).upper())
self.assertEquals(x, hexlify(str(transport._message)).upper()) self.assertEquals(x, hexlify(transport._message.asbytes()).upper())
self.assert_(transport._activated) self.assert_(transport._activated)
def test_3_gex_client(self): def test_3_gex_client(self):
@ -130,7 +132,7 @@ class KexTest (unittest.TestCase):
kex = KexGex(transport) kex = KexGex(transport)
kex.start_kex() kex.start_kex()
x = '22000004000000080000002000' x = '22000004000000080000002000'
self.assertEquals(x, hexlify(str(transport._message)).upper()) self.assertEquals(x, hexlify(transport._message.asbytes()).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect) self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect)
msg = Message() msg = Message()
@ -139,7 +141,7 @@ class KexTest (unittest.TestCase):
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg)
x = '20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' x = '20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
self.assertEquals(x, hexlify(str(transport._message)).upper()) self.assertEquals(x, hexlify(transport._message.asbytes()).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect) self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect)
msg = Message() msg = Message()
@ -160,7 +162,7 @@ class KexTest (unittest.TestCase):
kex = KexGex(transport) kex = KexGex(transport)
kex.start_kex(_test_old_style=True) kex.start_kex(_test_old_style=True)
x = '1E00000800' x = '1E00000800'
self.assertEquals(x, hexlify(str(transport._message)).upper()) self.assertEquals(x, hexlify(transport._message.asbytes()).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect) self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_GROUP,), transport._expect)
msg = Message() msg = Message()
@ -169,7 +171,7 @@ class KexTest (unittest.TestCase):
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_GROUP, msg)
x = '20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4' x = '20000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D4'
self.assertEquals(x, hexlify(str(transport._message)).upper()) self.assertEquals(x, hexlify(transport._message.asbytes()).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect) self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_REPLY,), transport._expect)
msg = Message() msg = Message()
@ -198,19 +200,19 @@ class KexTest (unittest.TestCase):
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST, msg)
x = '1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102' x = '1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102'
self.assertEquals(x, hexlify(str(transport._message)).upper()) self.assertEquals(x, hexlify(transport._message.asbytes()).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect) self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect)
msg = Message() msg = Message()
msg.add_mpint(12345) msg.add_mpint(12345)
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg)
K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581L K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581
H = 'CE754197C21BF3452863B4F44D0B3951F12516EF' H = 'CE754197C21BF3452863B4F44D0B3951F12516EF'
x = '210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' x = '210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
self.assertEquals(K, transport._K) self.assertEquals(K, transport._K)
self.assertEquals(H, hexlify(transport._H).upper()) self.assertEquals(H, hexlify(transport._H).upper())
self.assertEquals(x, hexlify(str(transport._message)).upper()) self.assertEquals(x, hexlify(transport._message.asbytes()).upper())
self.assert_(transport._activated) self.assert_(transport._activated)
def test_6_gex_server_with_old_client(self): def test_6_gex_server_with_old_client(self):
@ -225,17 +227,17 @@ class KexTest (unittest.TestCase):
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_REQUEST_OLD, msg)
x = '1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102' x = '1F0000008100FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF0000000102'
self.assertEquals(x, hexlify(str(transport._message)).upper()) self.assertEquals(x, hexlify(transport._message.asbytes()).upper())
self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect) self.assertEquals((paramiko.kex_gex._MSG_KEXDH_GEX_INIT,), transport._expect)
msg = Message() msg = Message()
msg.add_mpint(12345) msg.add_mpint(12345)
msg.rewind() msg.rewind()
kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg) kex.parse_next(paramiko.kex_gex._MSG_KEXDH_GEX_INIT, msg)
K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581L K = 67592995013596137876033460028393339951879041140378510871612128162185209509220726296697886624612526735888348020498716482757677848959420073720160491114319163078862905400020959196386947926388406687288901564192071077389283980347784184487280885335302632305026248574716290537036069329724382811853044654824945750581
H = 'B41A06B2E59043CEFC1AE16EC31F1E2D12EC455B' H = 'B41A06B2E59043CEFC1AE16EC31F1E2D12EC455B'
x = '210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967' x = '210000000866616B652D6B6579000000807E2DDB1743F3487D6545F04F1C8476092FB912B013626AB5BCEB764257D88BBA64243B9F348DF7B41B8C814A995E00299913503456983FFB9178D3CD79EB6D55522418A8ABF65375872E55938AB99A84A0B5FC8A1ECC66A7C3766E7E0F80B7CE2C9225FC2DD683F4764244B72963BBB383F529DCF0C5D17740B8A2ADBE9208D40000000866616B652D736967'
self.assertEquals(K, transport._K) self.assertEquals(K, transport._K)
self.assertEquals(H, hexlify(transport._H).upper()) self.assertEquals(H, hexlify(transport._H).upper())
self.assertEquals(x, hexlify(str(transport._message)).upper()) self.assertEquals(x, hexlify(transport._message.asbytes()).upper())
self.assert_(transport._activated) self.assert_(transport._activated)

View File

@ -27,10 +27,16 @@ from paramiko.common import *
class MessageTest (unittest.TestCase): class MessageTest (unittest.TestCase):
__a = '\x00\x00\x00\x17\x07\x60\xe0\x90\x00\x00\x00\x01q\x00\x00\x00\x05hello\x00\x00\x03\xe8' + ('x' * 1000) if PY3:
__b = '\x01\x00\xf3\x00\x3f\x00\x00\x00\x10huey,dewey,louie' __a = b'\x00\x00\x00\x17\x07\x60\xe0\x90\x00\x00\x00\x01q\x00\x00\x00\x05hello\x00\x00\x03\xe8' + (b'x' * 1000)
__c = '\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x01\x11\x00\x00\x00\x07\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x06\x9a\x1b\x2c\x3d\x4e\xf7' __b = b'\x01\x00\xf3\x00\x3f\x00\x00\x00\x10huey,dewey,louie'
__d = '\x00\x00\x00\x05\x00\x00\x00\x05\x11\x22\x33\x44\x55\x01\x00\x00\x00\x03cat\x00\x00\x00\x03a,b' __c = b'\x00\x00\x00\x05\xff\x00\x00\x00\x07\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x01\x11\x00\x00\x00\x07\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x06\x9a\x1b\x2c\x3d\x4e\xf7'
__d = b'\x00\x00\x00\x05\xff\x00\x00\x00\x05\x11\x22\x33\x44\x55\xff\x00\x00\x00\x0a\x00\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03cat\x00\x00\x00\x03a,b'
else:
__a = '\x00\x00\x00\x17\x07\x60\xe0\x90\x00\x00\x00\x01q\x00\x00\x00\x05hello\x00\x00\x03\xe8' + ('x' * 1000)
__b = '\x01\x00\xf3\x00\x3f\x00\x00\x00\x10huey,dewey,louie'
__c = '\x00\x00\x00\x05\xff\x00\x00\x00\x07\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x01\x11\x00\x00\x00\x07\x00\xf5\xe4\xd3\xc2\xb1\x09\x00\x00\x00\x06\x9a\x1b\x2c\x3d\x4e\xf7'
__d = '\x00\x00\x00\x05\xff\x00\x00\x00\x05\x11\x22\x33\x44\x55\xff\x00\x00\x00\x0a\x00\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03cat\x00\x00\x00\x03a,b'
def test_1_encode(self): def test_1_encode(self):
msg = Message() msg = Message()
@ -39,63 +45,65 @@ class MessageTest (unittest.TestCase):
msg.add_string('q') msg.add_string('q')
msg.add_string('hello') msg.add_string('hello')
msg.add_string('x' * 1000) msg.add_string('x' * 1000)
self.assertEquals(str(msg), self.__a) self.assertEquals(msg.asbytes(), self.__a)
msg = Message() msg = Message()
msg.add_boolean(True) msg.add_boolean(True)
msg.add_boolean(False) msg.add_boolean(False)
msg.add_byte('\xf3') msg.add_byte(byte_chr(0xf3))
msg.add_bytes('\x00\x3f')
msg.add_bytes(zero_byte + byte_chr(0x3f))
msg.add_list(['huey', 'dewey', 'louie']) msg.add_list(['huey', 'dewey', 'louie'])
self.assertEquals(str(msg), self.__b) self.assertEquals(msg.asbytes(), self.__b)
msg = Message() msg = Message()
msg.add_int64(5) msg.add_int64(5)
msg.add_int64(0xf5e4d3c2b109L) msg.add_int64(0xf5e4d3c2b109)
msg.add_mpint(17) msg.add_mpint(17)
msg.add_mpint(0xf5e4d3c2b109L) msg.add_mpint(0xf5e4d3c2b109)
msg.add_mpint(-0x65e4d3c2b109L) msg.add_mpint(-0x65e4d3c2b109)
self.assertEquals(str(msg), self.__c) self.assertEquals(msg.asbytes(), self.__c)
def test_2_decode(self): def test_2_decode(self):
msg = Message(self.__a) msg = Message(self.__a)
self.assertEquals(msg.get_int(), 23) self.assertEquals(msg.get_int(), 23)
self.assertEquals(msg.get_int(), 123789456) self.assertEquals(msg.get_int(), 123789456)
self.assertEquals(msg.get_string(), 'q') self.assertEquals(msg.get_text(), 'q')
self.assertEquals(msg.get_string(), 'hello') self.assertEquals(msg.get_text(), 'hello')
self.assertEquals(msg.get_string(), 'x' * 1000) self.assertEquals(msg.get_text(), 'x' * 1000)
msg = Message(self.__b) msg = Message(self.__b)
self.assertEquals(msg.get_boolean(), True) self.assertEquals(msg.get_boolean(), True)
self.assertEquals(msg.get_boolean(), False) self.assertEquals(msg.get_boolean(), False)
self.assertEquals(msg.get_byte(), '\xf3') self.assertEquals(msg.get_byte(), byte_chr(0xf3))
self.assertEquals(msg.get_bytes(2), '\x00\x3f') self.assertEquals(msg.get_bytes(2), zero_byte + byte_chr(0x3f))
self.assertEquals(msg.get_list(), ['huey', 'dewey', 'louie']) self.assertEquals(msg.get_list(), ['huey', 'dewey', 'louie'])
msg = Message(self.__c) msg = Message(self.__c)
self.assertEquals(msg.get_int64(), 5) self.assertEquals(msg.get_int64(), 5)
self.assertEquals(msg.get_int64(), 0xf5e4d3c2b109L) self.assertEquals(msg.get_int64(), 0xf5e4d3c2b109)
self.assertEquals(msg.get_mpint(), 17) self.assertEquals(msg.get_mpint(), 17)
self.assertEquals(msg.get_mpint(), 0xf5e4d3c2b109L) self.assertEquals(msg.get_mpint(), 0xf5e4d3c2b109)
self.assertEquals(msg.get_mpint(), -0x65e4d3c2b109L) self.assertEquals(msg.get_mpint(), -0x65e4d3c2b109)
def test_3_add(self): def test_3_add(self):
msg = Message() msg = Message()
msg.add(5) msg.add(5)
msg.add(0x1122334455L) msg.add(0x1122334455)
msg.add(0xf00000000000000000)
msg.add(True) msg.add(True)
msg.add('cat') msg.add('cat')
msg.add(['a', 'b']) msg.add(['a', 'b'])
self.assertEquals(str(msg), self.__d) self.assertEquals(msg.asbytes(), self.__d)
def test_4_misc(self): def test_4_misc(self):
msg = Message(self.__d) msg = Message(self.__d)
self.assertEquals(msg.get_int(), 5) self.assertEquals(msg.get_int(), 5)
self.assertEquals(msg.get_mpint(), 0x1122334455L) self.assertEquals(msg.get_int(), 0x1122334455)
self.assertEquals(msg.get_so_far(), self.__d[:13]) self.assertEquals(msg.get_int(), 0xf00000000000000000)
self.assertEquals(msg.get_remainder(), self.__d[13:]) self.assertEquals(msg.get_so_far(), self.__d[:29])
self.assertEquals(msg.get_remainder(), self.__d[29:])
msg.rewind() msg.rewind()
self.assertEquals(msg.get_int(), 5) self.assertEquals(msg.get_int(), 5)
self.assertEquals(msg.get_so_far(), self.__d[:4]) self.assertEquals(msg.get_so_far(), self.__d[:4])
self.assertEquals(msg.get_remainder(), self.__d[4:]) self.assertEquals(msg.get_remainder(), self.__d[4:])

View File

@ -42,7 +42,7 @@ class PacketizerTest (unittest.TestCase):
# message has to be at least 16 bytes long, so we'll have at least one # message has to be at least 16 bytes long, so we'll have at least one
# block of data encrypted that contains zero random padding bytes # block of data encrypted that contains zero random padding bytes
m = Message() m = Message()
m.add_byte(chr(100)) m.add_byte(byte_chr(100))
m.add_int(100) m.add_int(100)
m.add_int(1) m.add_int(1)
m.add_int(900) m.add_int(900)

View File

@ -144,7 +144,7 @@ class KeyTest (unittest.TestCase):
# verify that the private & public keys compare equal # verify that the private & public keys compare equal
key = RSAKey.from_private_key_file('tests/test_rsa.key') key = RSAKey.from_private_key_file('tests/test_rsa.key')
self.assertEquals(key, key) self.assertEquals(key, key)
pub = RSAKey(data=str(key)) pub = RSAKey(data=key.asbytes())
self.assert_(key.can_sign()) self.assert_(key.can_sign())
self.assert_(not pub.can_sign()) self.assert_(not pub.can_sign())
self.assertEquals(key, pub) self.assertEquals(key, pub)
@ -153,7 +153,7 @@ class KeyTest (unittest.TestCase):
# verify that the private & public keys compare equal # verify that the private & public keys compare equal
key = DSSKey.from_private_key_file('tests/test_dss.key') key = DSSKey.from_private_key_file('tests/test_dss.key')
self.assertEquals(key, key) self.assertEquals(key, key)
pub = DSSKey(data=str(key)) pub = DSSKey(data=key.asbytes())
self.assert_(key.can_sign()) self.assert_(key.can_sign())
self.assert_(not pub.can_sign()) self.assert_(not pub.can_sign())
self.assertEquals(key, pub) self.assertEquals(key, pub)
@ -164,11 +164,11 @@ class KeyTest (unittest.TestCase):
msg = key.sign_ssh_data(rng, 'ice weasels') msg = key.sign_ssh_data(rng, 'ice weasels')
self.assert_(type(msg) is Message) self.assert_(type(msg) is Message)
msg.rewind() msg.rewind()
self.assertEquals('ssh-rsa', msg.get_string()) self.assertEquals('ssh-rsa', msg.get_text())
sig = ''.join([chr(int(x, 16)) for x in SIGNED_RSA.split(':')]) sig = bytes().join([byte_chr(int(x, 16)) for x in SIGNED_RSA.split(':')])
self.assertEquals(sig, msg.get_string()) self.assertEquals(sig, msg.get_binary())
msg.rewind() msg.rewind()
pub = RSAKey(data=str(key)) pub = RSAKey(data=key.asbytes())
self.assert_(pub.verify_ssh_sig('ice weasels', msg)) self.assert_(pub.verify_ssh_sig('ice weasels', msg))
def test_9_sign_dss(self): def test_9_sign_dss(self):
@ -177,13 +177,13 @@ class KeyTest (unittest.TestCase):
msg = key.sign_ssh_data(rng, 'ice weasels') msg = key.sign_ssh_data(rng, 'ice weasels')
self.assert_(type(msg) is Message) self.assert_(type(msg) is Message)
msg.rewind() msg.rewind()
self.assertEquals('ssh-dss', msg.get_string()) self.assertEquals('ssh-dss', msg.get_text())
# can't do the same test as we do for RSA, because DSS signatures # can't do the same test as we do for RSA, because DSS signatures
# are usually different each time. but we can test verification # are usually different each time. but we can test verification
# anyway so it's ok. # anyway so it's ok.
self.assertEquals(40, len(msg.get_string())) self.assertEquals(40, len(msg.get_binary()))
msg.rewind() msg.rewind()
pub = DSSKey(data=str(key)) pub = DSSKey(data=key.asbytes())
self.assert_(pub.verify_ssh_sig('ice weasels', msg)) self.assert_(pub.verify_ssh_sig('ice weasels', msg))
def test_A_generate_rsa(self): def test_A_generate_rsa(self):
@ -227,7 +227,7 @@ class KeyTest (unittest.TestCase):
# verify that the private & public keys compare equal # verify that the private & public keys compare equal
key = ECDSAKey.from_private_key_file('tests/test_ecdsa.key') key = ECDSAKey.from_private_key_file('tests/test_ecdsa.key')
self.assertEquals(key, key) self.assertEquals(key, key)
pub = ECDSAKey(data=str(key)) pub = ECDSAKey(data=key.asbytes())
self.assert_(key.can_sign()) self.assert_(key.can_sign())
self.assert_(not pub.can_sign()) self.assert_(not pub.can_sign())
self.assertEquals(key, pub) self.assertEquals(key, pub)
@ -238,12 +238,12 @@ class KeyTest (unittest.TestCase):
msg = key.sign_ssh_data(rng, 'ice weasels') msg = key.sign_ssh_data(rng, 'ice weasels')
self.assert_(type(msg) is Message) self.assert_(type(msg) is Message)
msg.rewind() msg.rewind()
self.assertEquals('ecdsa-sha2-nistp256', msg.get_string()) self.assertEquals('ecdsa-sha2-nistp256', msg.get_text())
# ECDSA signatures, like DSS signatures, tend to be different # ECDSA signatures, like DSS signatures, tend to be different
# each time, so we can't compare against a "known correct" # each time, so we can't compare against a "known correct"
# signature. # signature.
# Even the length of the signature can change. # Even the length of the signature can change.
msg.rewind() msg.rewind()
pub = ECDSAKey(data=str(key)) pub = ECDSAKey(data=key.asbytes())
self.assert_(pub.verify_ssh_sig('ice weasels', msg)) self.assert_(pub.verify_ssh_sig('ice weasels', msg))

View File

@ -645,7 +645,8 @@ class SFTPTest (unittest.TestCase):
try: try:
sftp.rename(FOLDER + '/something', FOLDER + u'/\u00fcnic\u00f8de') sftp.rename(FOLDER + '/something', FOLDER + u'/\u00fcnic\u00f8de')
sftp.open(FOLDER + '/\xc3\xbcnic\xc3\xb8\x64\x65', 'r') sftp.open(FOLDER + '/\xc3\xbcnic\xc3\xb8\x64\x65', 'r')
except Exception, e: except Exception:
e = sys.exc_info()[1]
self.fail('exception ' + e) self.fail('exception ' + e)
sftp.unlink(FOLDER + '/\xc3\xbcnic\xc3\xb8\x64\x65') sftp.unlink(FOLDER + '/\xc3\xbcnic\xc3\xb8\x64\x65')

View File

@ -73,7 +73,7 @@ class BigSFTPTest (unittest.TestCase):
# now make sure every file is there, by creating a list of filenmes # now make sure every file is there, by creating a list of filenmes
# and reading them in random order. # and reading them in random order.
numlist = range(numfiles) numlist = list(range(numfiles))
while len(numlist) > 0: while len(numlist) > 0:
r = numlist[random.randint(0, len(numlist) - 1)] r = numlist[random.randint(0, len(numlist) - 1)]
f = sftp.open('%s/file%d.txt' % (FOLDER, r)) f = sftp.open('%s/file%d.txt' % (FOLDER, r))

View File

@ -121,8 +121,8 @@ class TransportTest(ParamikoTest):
self.sockc.close() self.sockc.close()
def setup_test_server(self, client_options=None, server_options=None): def setup_test_server(self, client_options=None, server_options=None):
public_host_key = RSAKey(data=str(host_key))
host_key = RSAKey.from_private_key_file(test_path('test_rsa.key')) host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key) self.ts.add_server_key(host_key)
if client_options is not None: if client_options is not None:
@ -171,8 +171,8 @@ class TransportTest(ParamikoTest):
loopback sockets. this is hardly "simple" but it's simpler than the loopback sockets. this is hardly "simple" but it's simpler than the
later tests. :) later tests. :)
""" """
public_host_key = RSAKey(data=str(host_key))
host_key = RSAKey.from_private_key_file(test_path('test_rsa.key')) host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key) self.ts.add_server_key(host_key)
event = threading.Event() event = threading.Event()
server = NullServer() server = NullServer()
@ -196,8 +196,8 @@ class TransportTest(ParamikoTest):
""" """
verify that a long banner doesn't mess up the handshake. verify that a long banner doesn't mess up the handshake.
""" """
public_host_key = RSAKey(data=str(host_key))
host_key = RSAKey.from_private_key_file(test_path('test_rsa.key')) host_key = RSAKey.from_private_key_file(test_path('test_rsa.key'))
public_host_key = RSAKey(data=host_key.asbytes())
self.ts.add_server_key(host_key) self.ts.add_server_key(host_key)
event = threading.Event() event = threading.Event()
server = NullServer() server = NullServer()
@ -708,7 +708,7 @@ class TransportTest(ParamikoTest):
# Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it # Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it
# before responding to the incoming MSG_KEXINIT. # before responding to the incoming MSG_KEXINIT.
m2 = Message() m2 = Message()
m2.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST)) m2.add_byte(cMSG_CHANNEL_WINDOW_ADJUST)
m2.add_int(chan.remote_chanid) m2.add_int(chan.remote_chanid)
m2.add_int(1) # bytes to add m2.add_int(1) # bytes to add
self._send_message(m2) self._send_message(m2)