diff --git a/paramiko/__init__.py b/paramiko/__init__.py index e9ed375..6e3190e 100644 --- a/paramiko/__init__.py +++ b/paramiko/__init__.py @@ -72,7 +72,7 @@ from transport import randpool, SecurityOptions, Transport from auth_handler import AuthHandler from channel import Channel, ChannelFile from ssh_exception import SSHException, PasswordRequiredException, BadAuthenticationType -from server import ServerInterface, SubsystemHandler +from server import ServerInterface, SubsystemHandler, InteractiveQuery from rsakey import RSAKey from dsskey import DSSKey from sftp import SFTPError, BaseSFTP diff --git a/paramiko/auth_handler.py b/paramiko/auth_handler.py index 953a9cc..f2ea0d8 100644 --- a/paramiko/auth_handler.py +++ b/paramiko/auth_handler.py @@ -29,6 +29,7 @@ from common import * import util from message import Message from ssh_exception import SSHException, BadAuthenticationType, PartialAuthentication +from server import InteractiveQuery class AuthHandler (object): @@ -211,6 +212,39 @@ class AuthHandler (object): else: self.transport._log(DEBUG, 'Service request "%s" accepted (?)' % service) + def _send_auth_result(self, username, method, result): + # okay, send result + m = Message() + if result == AUTH_SUCCESSFUL: + self.transport._log(INFO, 'Auth granted (%s).' % method) + m.add_byte(chr(MSG_USERAUTH_SUCCESS)) + self.authenticated = True + else: + self.transport._log(INFO, 'Auth rejected (%s).' % method) + m.add_byte(chr(MSG_USERAUTH_FAILURE)) + m.add_string(self.transport.server_object.get_allowed_auths(username)) + if result == AUTH_PARTIALLY_SUCCESSFUL: + m.add_boolean(1) + else: + m.add_boolean(0) + self.auth_fail_count += 1 + self.transport._send_message(m) + if self.auth_fail_count >= 10: + self._disconnect_no_more_auth() + + def _interactive_query(self, q): + # make interactive query instead of response + m = Message() + m.add_byte(chr(MSG_USERAUTH_INFO_REQUEST)) + m.add_string(q.name) + m.add_string(q.instructions) + m.add_string('') + m.add_int(len(q.prompts)) + for p in q.prompts: + m.add_string(p[0]) + m.add_boolean(p[1]) + self.transport._send_message(m) + def _parse_userauth_request(self, m): if not self.transport.server_mode: # er, uh... what? @@ -282,26 +316,18 @@ class AuthHandler (object): if not key.verify_ssh_sig(blob, sig): self.transport._log(INFO, 'Auth rejected: invalid signature') result = AUTH_FAILED + elif method == 'keyboard-interactive': + lang = m.get_string() + submethods = m.get_string() + result = self.transport.server_object.check_auth_interactive(username, submethods) + if isinstance(result, InteractiveQuery): + # make interactive query instead of response + self._interactive_query(result) + return else: result = self.transport.server_object.check_auth_none(username) # okay, send result - m = Message() - if result == AUTH_SUCCESSFUL: - self.transport._log(INFO, 'Auth granted (%s).' % method) - m.add_byte(chr(MSG_USERAUTH_SUCCESS)) - self.authenticated = True - else: - self.transport._log(INFO, 'Auth rejected (%s).' % method) - m.add_byte(chr(MSG_USERAUTH_FAILURE)) - m.add_string(self.transport.server_object.get_allowed_auths(username)) - if result == AUTH_PARTIALLY_SUCCESSFUL: - m.add_boolean(1) - else: - m.add_boolean(0) - self.auth_fail_count += 1 - self.transport._send_message(m) - if self.auth_fail_count >= 10: - self._disconnect_no_more_auth() + self._send_auth_result(username, method, result) def _parse_userauth_success(self, m): self.transport._log(INFO, 'Authentication successful!') @@ -351,6 +377,21 @@ class AuthHandler (object): for r in response_list: m.add_string(r) self.transport._send_message(m) + + def _parse_userauth_info_response(self, m): + if not self.transport.server_mode: + raise SSHException('Illegal info response from server') + n = m.get_int() + responses = [] + for i in range(n): + responses.append(m.get_string()) + result = self.transport.server_object.check_auth_interactive_response(responses) + if isinstance(type(result), InteractiveQuery): + # make interactive query instead of response + self._interactive_query(result) + return + self._send_auth_result(self.auth_username, 'keyboard-interactive', result) + _handler_table = { MSG_SERVICE_REQUEST: _parse_service_request, @@ -360,6 +401,7 @@ class AuthHandler (object): MSG_USERAUTH_FAILURE: _parse_userauth_failure, MSG_USERAUTH_BANNER: _parse_userauth_banner, MSG_USERAUTH_INFO_REQUEST: _parse_userauth_info_request, + MSG_USERAUTH_INFO_RESPONSE: _parse_userauth_info_response, } diff --git a/paramiko/server.py b/paramiko/server.py index 151dd12..859ae6b 100644 --- a/paramiko/server.py +++ b/paramiko/server.py @@ -26,6 +26,47 @@ import threading from common import * import util + +class InteractiveQuery (object): + """ + A query (set of prompts) for a user during interactive authentication. + """ + + def __init__(self, name='', instructions='', *prompts): + """ + Create a new interactive query to send to the client. The name and + instructions are optional, but are generally displayed to the end + user. A list of prompts may be included, or they may be added via + the L{add_prompt} method. + + @param name: name of this query + @type name: str + @param instructions: user instructions (usually short) about this query + @type instructions: str + """ + self.name = name + self.instructions = instructions + self.prompts = [] + for x in prompts: + if (type(x) is str) or (type(x) is unicode): + self.add_prompt(x) + else: + self.add_prompt(x[0], x[1]) + + def add_prompt(self, prompt, echo=True): + """ + Add a prompt to this query. The prompt should be a (reasonably short) + string. Multiple prompts can be added to the same query. + + @param prompt: the user prompt + @type prompt: str + @param echo: C{True} (default) if the user's response should be echoed; + C{False} if not (for a password or similar) + @type echo: bool + """ + self.prompts.append((prompt, echo)) + + class ServerInterface (object): """ This class defines an interface for controlling the behavior of paramiko @@ -154,7 +195,7 @@ class ServerInterface (object): Return L{AUTH_FAILED} if the key is not accepted, L{AUTH_SUCCESSFUL} if the key is accepted and completes the authentication, or L{AUTH_PARTIALLY_SUCCESSFUL} if your - authentication is stateful, and this key is accepted for + authentication is stateful, and this password is accepted for authentication, but more authentication is required. (In this latter case, L{get_allowed_auths} will be called to report to the client what options it has for continuing the authentication.) @@ -165,17 +206,74 @@ class ServerInterface (object): The default implementation always returns L{AUTH_FAILED}. - @param username: the username of the authenticating client. + @param username: the username of the authenticating client @type username: str - @param key: the key object provided by the client. + @param key: the key object provided by the client @type key: L{PKey } @return: L{AUTH_FAILED} if the client can't authenticate with this key; L{AUTH_SUCCESSFUL} if it can; L{AUTH_PARTIALLY_SUCCESSFUL} if it can authenticate with - this key but must continue with authentication. + this key but must continue with authentication @rtype: int """ return AUTH_FAILED + + def check_auth_interactive(self, username, submethods): + """ + Begin an interactive authentication challenge, if supported. You + should override this method in server mode if you want to support the + C{"keyboard-interactive"} auth type, which requires you to send a + series of questions for the client to answer. + + Return L{AUTH_FAILED} if this auth method isn't supported. Otherwise, + you should return an L{InteractiveQuery} object containing the prompts + and instructions for the user. The response will be sent via a call + to L{check_auth_interactive_response}. + + The default implementation always returns L{AUTH_FAILED}. + + @param username: the username of the authenticating client + @type username: str + @param submethods: a comma-separated list of methods preferred by the + client (usually empty) + @type submethods: str + @return: L{AUTH_FAILED} if this auth method isn't supported; otherwise + an object containing queries for the user + @rtype: int or L{InteractiveQuery} + """ + return AUTH_FAILED + + def check_auth_interactive_response(self, responses): + """ + Continue or finish an interactive authentication challenge, if + supported. You should override this method in server mode if you want + to support the C{"keyboard-interactive"} auth type. + + Return L{AUTH_FAILED} if the responses are not accepted, + L{AUTH_SUCCESSFUL} if the responses are accepted and complete + the authentication, or L{AUTH_PARTIALLY_SUCCESSFUL} if your + authentication is stateful, and this set of responses is accepted for + authentication, but more authentication is required. (In this latter + case, L{get_allowed_auths} will be called to report to the client what + options it has for continuing the authentication.) + + If you wish to continue interactive authentication with more questions, + you may return an L{InteractiveQuery} object, which should cause the + client to respond with more answers, calling this method again. This + cycle can continue indefinitely. + + The default implementation always returns L{AUTH_FAILED}. + + @param responses: list of responses from the client + @type responses: list(str) + @return: L{AUTH_FAILED} if the authentication fails; + L{AUTH_SUCCESSFUL} if it succeeds; + L{AUTH_PARTIALLY_SUCCESSFUL} if the interactive auth is + successful, but authentication must continue; otherwise an object + containing queries for the user + @rtype: int or L{InteractiveQuery} + """ + return AUTH_FAILED def check_global_request(self, kind, msg): """ diff --git a/tests/test_transport.py b/tests/test_transport.py index ecea6e1..81254b4 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -23,7 +23,7 @@ Some unit tests for the ssh2 protocol in Transport. import sys, time, threading, unittest import select from paramiko import Transport, SecurityOptions, ServerInterface, RSAKey, DSSKey, \ - SSHException, BadAuthenticationType, util + SSHException, BadAuthenticationType, InteractiveQuery, util from paramiko import AUTH_FAILED, AUTH_PARTIALLY_SUCCESSFUL, AUTH_SUCCESSFUL from paramiko import OPEN_SUCCEEDED from loop import LoopSocket @@ -44,6 +44,8 @@ class NullServer (ServerInterface): return 'publickey' else: return 'password' + if username == 'commie': + return 'keyboard-interactive' return 'publickey' def check_auth_password(self, username, password): @@ -66,6 +68,18 @@ class NullServer (ServerInterface): return AUTH_PARTIALLY_SUCCESSFUL return AUTH_FAILED + def check_auth_interactive(self, username, submethods): + if username == 'commie': + self.username = username + return InteractiveQuery('password', 'Please enter a password.', ('Password', False)) + return AUTH_FAILED + + def check_auth_interactive_response(self, responses): + if self.username == 'commie': + if (len(responses) == 1) and (responses[0] == 'cat'): + return AUTH_SUCCESSFUL + return AUTH_FAILED + def check_channel_request(self, kind, chanid): return OPEN_SUCCEEDED @@ -270,7 +284,54 @@ class TransportTest (unittest.TestCase): self.assert_(event.isSet()) self.assert_(self.ts.is_active()) - def test_9_exec_command(self): + def test_9_interactive_auth(self): + """ + verify keyboard-interactive auth works. + """ + host_key = RSAKey.from_private_key_file('tests/test_rsa.key') + public_host_key = RSAKey(data=str(host_key)) + self.ts.add_server_key(host_key) + event = threading.Event() + server = NullServer() + self.assert_(not event.isSet()) + self.ts.start_server(event, server) + self.tc.ultra_debug = True + self.tc.connect(hostkey=public_host_key) + + def handler(title, instructions, prompts): + self.got_title = title + self.got_instructions = instructions + self.got_prompts = prompts + return ['cat'] + remain = self.tc.auth_interactive('commie', handler) + self.assertEquals(self.got_title, 'password') + self.assertEquals(self.got_prompts, [('Password', False)]) + self.assertEquals([], remain) + event.wait(1.0) + self.assert_(event.isSet()) + self.assert_(self.ts.is_active()) + + def test_A_interactive_auth_fallback(self): + """ + verify that a password auth attempt will fallback to "interactive" + if password auth isn't supported but interactive is. + """ + host_key = RSAKey.from_private_key_file('tests/test_rsa.key') + public_host_key = RSAKey(data=str(host_key)) + self.ts.add_server_key(host_key) + event = threading.Event() + server = NullServer() + self.assert_(not event.isSet()) + self.ts.start_server(event, server) + self.tc.ultra_debug = True + self.tc.connect(hostkey=public_host_key) + remain = self.tc.auth_password('commie', 'cat') + self.assertEquals([], remain) + event.wait(1.0) + self.assert_(event.isSet()) + self.assert_(self.ts.is_active()) + + def test_B_exec_command(self): """ verify that exec_command() does something reasonable. """ @@ -320,7 +381,7 @@ class TransportTest (unittest.TestCase): self.assertEquals('This is on stderr.\n', f.readline()) self.assertEquals('', f.readline()) - def test_A_invoke_shell(self): + def test_C_invoke_shell(self): """ verify that invoke_shell() does something reasonable. """ @@ -347,7 +408,7 @@ class TransportTest (unittest.TestCase): chan.close() self.assertEquals('', f.readline()) - def test_B_exit_status(self): + def test_D_exit_status(self): """ verify that get_exit_status() works. """ @@ -381,7 +442,7 @@ class TransportTest (unittest.TestCase): self.assertEquals(23, chan.recv_exit_status()) chan.close() - def test_C_select(self): + def test_E_select(self): """ verify that select() on a channel works. """