diff --git a/msal/__main__.py b/msal/__main__.py index 4107fd89..a28801eb 100644 --- a/msal/__main__.py +++ b/msal/__main__.py @@ -339,6 +339,8 @@ def _main(): logging.error("Invalid input: %s", e) except KeyboardInterrupt: # Useful for bailing out a stuck interactive flow print("Aborted") + except Exception as e: + logging.error("Error: %s", e) if __name__ == "__main__": _main() diff --git a/msal/application.py b/msal/application.py index 57e40980..4fb5c289 100644 --- a/msal/application.py +++ b/msal/application.py @@ -961,7 +961,7 @@ def initiate_auth_code_flow( :param str response_mode: OPTIONAL. Specifies the method with which response parameters should be returned. - The default value is equivalent to ``query``, which is still secure enough in MSAL Python + The default value is equivalent to ``query``, which was still secure enough in MSAL Python (because MSAL Python does not transfer tokens via query parameter in the first place). For even better security, we recommend using the value ``form_post``. In "form_post" mode, response parameters @@ -973,6 +973,11 @@ def initiate_auth_code_flow( `here ` and `here ` + .. note:: + You should configure your web framework to accept form_post responses instead of query responses. + While this parameter still works, it will be removed in a future version. + Using query-based response modes is less secure and should be avoided. + :return: The auth code flow. It is a dict in this form:: @@ -991,6 +996,9 @@ def initiate_auth_code_flow( 3. and then relay this dict and subsequent auth response to :func:`~acquire_token_by_auth_code_flow()`. """ + # Note to maintainers: Do not emit warning for the use of response_mode here, + # because response_mode=form_post is still the recommended usage for MSAL Python 1.x. + # App developers making the right call shall not be disturbed by unactionable warnings. client = _ClientWithCcsRoutingInfo( {"authorization_endpoint": self.authority.authorization_endpoint}, self.client_id, diff --git a/msal/oauth2cli/authcode.py b/msal/oauth2cli/authcode.py index ba266223..4b2466cf 100644 --- a/msal/oauth2cli/authcode.py +++ b/msal/oauth2cli/authcode.py @@ -5,6 +5,7 @@ It optionally opens a browser window to guide a human user to manually login. After obtaining an auth code, the web server will automatically shut down. """ +from collections import defaultdict import logging import os import socket @@ -109,29 +110,49 @@ def _printify(text): class _AuthCodeHandler(BaseHTTPRequestHandler): def do_GET(self): - # For flexibility, we choose to not check self.path matching redirect_uri - #assert self.path.startswith('/THE_PATH_REGISTERED_BY_THE_APP') qs = parse_qs(urlparse(self.path).query) - if qs.get('code') or qs.get("error"): # So, it is an auth response - auth_response = _qs2kv(qs) - logger.debug("Got auth response: %s", auth_response) - if self.server.auth_state and self.server.auth_state != auth_response.get("state"): - # OAuth2 successful and error responses contain state when it was used - # https://www.rfc-editor.org/rfc/rfc6749#section-4.2.2.1 - self._send_full_response("State mismatch") # Possibly an attack - else: - template = (self.server.success_template - if "code" in qs else self.server.error_template) - if _is_html(template.template): - safe_data = _escape(auth_response) # Foiling an XSS attack - else: - safe_data = auth_response - self._send_full_response(template.safe_substitute(**safe_data)) - self.server.auth_response = auth_response # Set it now, after the response is likely sent + if qs: + # GET request with auth code or error - reject for security (form_post only) + self._send_full_response( + "response_mode=query is not supported for authentication responses. " + "This application operates in response_mode=form_post mode only.", + is_ok=False) else: + # Other GET requests - show welcome page self._send_full_response(self.server.welcome_page) # NOTE: Don't do self.server.shutdown() here. It'll halt the server. + def do_POST(self): # Handle form_post response where auth code is in body + # For flexibility, we choose to not check self.path matching redirect_uri + #assert self.path.startswith('/THE_PATH_REGISTERED_BY_THE_APP') + content_length = int(self.headers.get('Content-Length', 0)) + post_data = self.rfile.read(content_length).decode('utf-8') + qs = parse_qs(post_data) + if qs.get('code') or qs.get('error'): # So, it is an auth response + self._process_auth_response(_qs2kv(qs)) + else: + self._send_full_response("Invalid POST request", is_ok=False) + # NOTE: Don't do self.server.shutdown() here. It'll halt the server. + + def _process_auth_response(self, auth_response): + """Process the auth response from either GET or POST request.""" + logger.debug("Got auth response: %s", auth_response) + if self.server.auth_state and self.server.auth_state != auth_response.get("state"): + # OAuth2 successful and error responses contain state when it was used + # https://www.rfc-editor.org/rfc/rfc6749#section-4.2.2.1 + self._send_full_response( # Possibly an attack + "State mismatch. Waiting for next response... or you may abort.", is_ok=False) + else: + template = (self.server.success_template + if "code" in auth_response else self.server.error_template) + if _is_html(template.template): + safe_data = _escape(auth_response) # Foiling an XSS attack + else: + safe_data = auth_response + filled_data = defaultdict(str, safe_data) # So that missing keys will be empty string + self._send_full_response(template.safe_substitute(**filled_data)) + self.server.auth_response = auth_response # Set it now, after the response is likely sent + def _send_full_response(self, body, is_ok=True): self.send_response(200 if is_ok else 400) content_type = 'text/html' if _is_html(body) else 'text/plain' @@ -215,6 +236,7 @@ def get_auth_response(self, timeout=None, **kwargs): :param str auth_uri: If provided, this function will try to open a local browser. + Starting from 2026, the built-in http server will require response_mode=form_post. :param int timeout: In seconds. None means wait indefinitely. :param str state: You may provide the state you used in auth_uri, @@ -287,8 +309,20 @@ def _get_auth_response(self, result, auth_uri=None, timeout=None, state=None, welcome_uri = "http://localhost:{p}".format(p=self.get_port()) abort_uri = "{loc}?error=abort".format(loc=welcome_uri) logger.debug("Abort by visit %s", abort_uri) - self._server.welcome_page = Template(welcome_template or "").safe_substitute( - auth_uri=auth_uri, abort_uri=abort_uri) + + if auth_uri: + # Note to maintainers: + # Do not enforce response_mode=form_post by secretly hardcoding it here. + # Just validate it here, so we won't surprise caller by changing their auth_uri behind the scene. + params = parse_qs(urlparse(auth_uri).query) + assert params.get('response_mode', [None])[0] == 'form_post', ( + "The built-in http server supports HTTP POST only. " + "The auth_uri must be built with response_mode=form_post") + + self._server.welcome_page = Template( + welcome_template or + "Sign In, or Abort" + ).safe_substitute(auth_uri=auth_uri, abort_uri=abort_uri) if auth_uri: # Now attempt to open a local browser to visit it _uri = welcome_uri if welcome_template else auth_uri logger.info("Open a browser on this device to visit: %s" % _uri) @@ -317,8 +351,11 @@ def _get_auth_response(self, result, auth_uri=None, timeout=None, state=None, auth_uri_callback(_uri) self._server.success_template = Template(success_template or - "Authentication completed. You can close this window now.") + "Authentication complete. You can return to the application. Please close this browser tab.") self._server.error_template = Template(error_template or + # Do NOT invent new placeholders in this template. Just use standard keys defined in OAuth2 RFC. + # Otherwise there is no obvious canonical way for caller to know what placeholders are supported. + # Besides, we have been using these standard keys for years. Changing now would break backward compatibility. "Authentication failed. $error: $error_description. ($error_uri)") self._server.timeout = timeout # Otherwise its handle_timeout() won't work @@ -370,8 +407,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): ) print(json.dumps(receiver.get_auth_response( auth_uri=flow["auth_uri"], - welcome_template= - "Sign In, or Abort<tag>foo</tag>", - requests.get("http://localhost:{}?error=foo".format( - receiver.get_port())).text, - "Unsafe data in HTML should be escaped", + "<script>alert('xss');</script>", + requests.post( + "http://localhost:{}".format(receiver.get_port()), + data={"error": ""}, + ).text, ))] receiver.get_auth_response( # Starts server and hang until timeout timeout=3, error_template="$error", ) + def test_get_request_with_auth_code_is_rejected(self): + """Test that GET request with auth code is rejected for security""" + with AuthCodeReceiver() as receiver: + test_state = "test_state_67890" + receiver._scheduled_actions = [( + 1, + lambda: self.assertEqual(400, requests.get( + "http://localhost:{}".format(receiver.get_port()), params={ + "code": "test_auth_code_12345", + "state": test_state + } + ).status_code) + )] + result = receiver.get_auth_response(timeout=3, state=test_state) + self.assertIsNone(result, "Should not receive auth response via GET") + + def test_post_request_with_auth_code(self): + """Test that POST request with auth code is handled correctly (form_post response mode)""" + with AuthCodeReceiver() as receiver: + test_code = "test_auth_code_12345" + test_state = "test_state_67890" + receiver._scheduled_actions = [( + 1, + lambda: requests.post( + "http://localhost:{}".format(receiver.get_port()), + data={"code": test_code, "state": test_state}, + ) + )] + result = receiver.get_auth_response(timeout=3, state=test_state) + self.assertIsNotNone(result, "Should receive auth response via POST") + self.assertEqual(result.get("code"), test_code) + self.assertEqual(result.get("state"), test_state) + + def test_post_request_with_error(self): + """Test that POST request with error is handled correctly""" + with AuthCodeReceiver() as receiver: + test_error = "access_denied" + test_error_description = "User denied access" + receiver._scheduled_actions = [( + 1, + lambda: requests.post( + "http://localhost:{}".format(receiver.get_port()), + data={"error": test_error, "error_description": test_error_description}, + ) + )] + result = receiver.get_auth_response(timeout=3) + self.assertIsNotNone(result, "Should receive auth response via POST") + self.assertEqual(result.get("error"), test_error) + self.assertEqual(result.get("error_description"), test_error_description) + + def test_post_request_state_mismatch(self): + """Test that POST request with mismatched state is rejected""" + with AuthCodeReceiver() as receiver: + receiver._scheduled_actions = [( + 1, + lambda: requests.post( + "http://localhost:{}".format(receiver.get_port()), + data={"code": "test_code", "state": "wrong_state"}, + ) + )] + result = receiver.get_auth_response(timeout=3, state="expected_state") + self.assertIsNone(result, "Should not receive auth response due to state mismatch") diff --git a/tests/test_client_obtain_token_by_browser.py b/tests/test_client_obtain_token_by_browser.py new file mode 100644 index 00000000..465d139f --- /dev/null +++ b/tests/test_client_obtain_token_by_browser.py @@ -0,0 +1,117 @@ +"""Integration tests for form_post response mode in authorization code flow""" +import json +import unittest +try: + from urllib.parse import urlparse, parse_qs +except ImportError: + from urlparse import urlparse, parse_qs + +from msal.oauth2cli import Client +from msal.oauth2cli.authcode import AuthCodeReceiver +import requests + + +class _BrowserlessAuthCodeReceiver(AuthCodeReceiver): + """Subclass that bypass browser opening behavior for testing. + + :param scheduled_action: + Optional callable with signature: (*, state: str, port: int) -> Callable[[], Any] + It receives state and port as keyword arguments, and should return a callable + that will be executed as the scheduled action. + + This allows the test cases to define mocked http requests with the correct state and port. + """ + def __init__(self, *args, scheduled_action=None, **kwargs): + super(_BrowserlessAuthCodeReceiver, self).__init__(*args, **kwargs) + self.scheduled_action = scheduled_action + + def get_auth_response(self, **kwargs): + """Override to strip auth_uri, preventing browser launch, and optionally inject scheduled actions.""" + kwargs.pop('auth_uri', None) # Remove auth_uri to skip browser behavior + kwargs.pop('auth_uri_callback', None) # Also remove callback + + if self.scheduled_action: + state = kwargs.get('state') + port = self.get_port() + self._scheduled_actions = [( + 1, + self.scheduled_action(state=state, port=port) + )] + + return super(_BrowserlessAuthCodeReceiver, self).get_auth_response(**kwargs) + + +class TestResponseModeIntegration(unittest.TestCase): + """Integration test for response_mode with end-to-end authentication flow""" + fake_access_token = "fake_token_xyz" + + def setUp(self): + # Mock http_client that returns fake token + class MockResponse: + def __init__(self): + self.status_code = 200 + self.text = json.dumps({ + "access_token": TestResponseModeIntegration.fake_access_token, + "token_type": "Bearer", + "expires_in": 3600, + }) + + class MockHttpClient: + def post(self, url, **kwargs): + return MockResponse() + + self.client = Client( + { + "authorization_endpoint": "https://example.com/authorize", + "token_endpoint": "https://example.com/token", + }, + "test_client_id", + http_client=MockHttpClient(), + ) + + def test_initiate_auth_code_flow_with_non_form_post_response_mode_should_warn(self): + """Test that initiating auth code flow warns for non-form_post response modes""" + for mode in ['query', 'fragment', None]: + with self.assertWarns(UserWarning) as cm: + flow = self.client.initiate_auth_code_flow( + scope=["openid", "profile"], + redirect_uri="http://localhost:8080", + response_mode=mode, + ) + self.assertIn( + "form_post", str(cm.warning).lower(), "Warning should mention form_post requirement") + if mode is not None: + # Verify response_mode in the auth_uri (if it was explicitly set) + params = parse_qs(urlparse(flow["auth_uri"]).query) + self.assertIn( + "response_mode", params, "response_mode should be in auth_uri when explicitly set") + self.assertEqual( + params.get("response_mode", [None])[0], mode, + f"response_mode should be set as requested") + + def test_http_post_should_work_with_obtain_token_by_browser(self): + def action_builder(*, state, port): + """Build an action that sends the auth response via HTTP POST""" + return lambda: requests.post( + "http://localhost:{}".format(port), + data={"code": "auth_code_from_server", "state": state}, + ) + + with _BrowserlessAuthCodeReceiver(scheduled_action=action_builder) as receiver: + try: + # This will use form_post internally (as required by obtain_token_by_browser) + result = self.client.obtain_token_by_browser( + redirect_uri="http://localhost:{}".format(receiver.get_port()), + auth_code_receiver=receiver, + scope=["openid", "profile"], + timeout=3, + ) + self.assertIsNotNone(result, "obtain_token_by_browser should return result") + self.assertIn("access_token", result, "Result should contain access_token") + self.assertEqual(result["access_token"], self.fake_access_token) + except Exception as e: + self.fail(f"obtain_token_by_browser failed: {e}") + + +if __name__ == '__main__': + unittest.main()