summary refs log tree commit diff
path: root/tests/handlers/test_oidc.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_oidc.py')
-rw-r--r--tests/handlers/test_oidc.py340
1 files changed, 171 insertions, 169 deletions
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index a308c46da9..c54f1c5797 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -15,32 +15,18 @@
 import json
 from urllib.parse import parse_qs, urlparse
 
-from mock import Mock, patch
+from mock import ANY, Mock, patch
 
-import attr
 import pymacaroons
 
-from twisted.python.failure import Failure
-from twisted.web._newclient import ResponseDone
-
-from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
+from synapse.handlers.oidc_handler import OidcError
 from synapse.handlers.sso import MappingException
+from synapse.server import HomeServer
 from synapse.types import UserID
 
+from tests.test_utils import FakeResponse, simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
-
-@attr.s
-class FakeResponse:
-    code = attr.ib()
-    body = attr.ib()
-    phrase = attr.ib()
-
-    def deliverBody(self, protocol):
-        protocol.dataReceived(self.body)
-        protocol.connectionLost(Failure(ResponseDone()))
-
-
 # These are a few constants that are used as config parameters in the tests.
 ISSUER = "https://issuer/"
 CLIENT_ID = "test-client-id"
@@ -70,11 +56,14 @@ COOKIE_NAME = b"oidc_session"
 COOKIE_PATH = "/_synapse/oidc"
 
 
-class TestMappingProvider(OidcMappingProvider):
+class TestMappingProvider:
     @staticmethod
     def parse_config(config):
         return
 
+    def __init__(self, config):
+        pass
+
     def get_remote_user_id(self, userinfo):
         return userinfo["sub"]
 
@@ -97,16 +86,6 @@ class TestMappingProviderFailures(TestMappingProvider):
         }
 
 
-def simple_async_mock(return_value=None, raises=None):
-    # AsyncMock is not available in python3.5, this mimics part of its behaviour
-    async def cb(*args, **kwargs):
-        if raises:
-            raise raises
-        return return_value
-
-    return Mock(side_effect=cb)
-
-
 async def get_json(url):
     # Mock get_json calls to handle jwks & oidc discovery endpoints
     if url == WELL_KNOWN:
@@ -175,6 +154,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             self.assertEqual(args[2], error_description)
         # Reset the render_error mock
         self.render_error.reset_mock()
+        return args
 
     def test_config(self):
         """Basic config correctly sets up the callback URL and client auth correctly."""
@@ -384,31 +364,29 @@ class OidcHandlerTestCase(HomeserverTestCase):
          - when the userinfo fetching fails
          - when the code exchange fails
         """
+
+        # ensure that we are correctly testing the fallback when "get_extra_attributes"
+        # is not implemented.
+        mapping_provider = self.handler._user_mapping_provider
+        with self.assertRaises(AttributeError):
+            _ = mapping_provider.get_extra_attributes
+
         token = {
             "type": "bearer",
             "id_token": "id_token",
             "access_token": "access_token",
         }
+        username = "bar"
         userinfo = {
             "sub": "foo",
-            "preferred_username": "bar",
+            "username": username,
         }
-        user_id = "@foo:domain.org"
+        expected_user_id = "@%s:%s" % (username, self.hs.hostname)
         self.handler._exchange_code = simple_async_mock(return_value=token)
         self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
         self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
-        self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
-        self.handler._auth_handler.complete_sso_login = simple_async_mock()
-        request = Mock(
-            spec=[
-                "args",
-                "getCookie",
-                "addCookie",
-                "requestHeaders",
-                "getClientIP",
-                "get_user_agent",
-            ]
-        )
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
 
         code = "code"
         state = "state"
@@ -416,64 +394,54 @@ class OidcHandlerTestCase(HomeserverTestCase):
         client_redirect_url = "http://client/redirect"
         user_agent = "Browser"
         ip_address = "10.0.0.1"
-        request.getCookie.return_value = self.handler._generate_oidc_session_token(
+        session = self.handler._generate_oidc_session_token(
             state=state,
             nonce=nonce,
             client_redirect_url=client_redirect_url,
             ui_auth_session_id=None,
         )
-
-        request.args = {}
-        request.args[b"code"] = [code.encode("utf-8")]
-        request.args[b"state"] = [state.encode("utf-8")]
-
-        request.getClientIP.return_value = ip_address
-        request.get_user_agent.return_value = user_agent
+        request = _build_callback_request(
+            code, state, session, user_agent=user_agent, ip_address=ip_address
+        )
 
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        self.handler._auth_handler.complete_sso_login.assert_called_once_with(
-            user_id, request, client_redirect_url, {},
+        auth_handler.complete_sso_login.assert_called_once_with(
+            expected_user_id, request, client_redirect_url, None,
         )
         self.handler._exchange_code.assert_called_once_with(code)
         self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
-        self.handler._map_userinfo_to_user.assert_called_once_with(
-            userinfo, token, user_agent, ip_address
-        )
         self.handler._fetch_userinfo.assert_not_called()
         self.render_error.assert_not_called()
 
         # Handle mapping errors
-        self.handler._map_userinfo_to_user = simple_async_mock(
-            raises=MappingException()
-        )
-        self.get_success(self.handler.handle_oidc_callback(request))
-        self.assertRenderedError("mapping_error")
-        self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+        with patch.object(
+            self.handler,
+            "_remote_id_from_userinfo",
+            new=Mock(side_effect=MappingException()),
+        ):
+            self.get_success(self.handler.handle_oidc_callback(request))
+            self.assertRenderedError("mapping_error")
 
         # Handle ID token errors
         self.handler._parse_id_token = simple_async_mock(raises=Exception())
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_token")
 
-        self.handler._auth_handler.complete_sso_login.reset_mock()
+        auth_handler.complete_sso_login.reset_mock()
         self.handler._exchange_code.reset_mock()
         self.handler._parse_id_token.reset_mock()
-        self.handler._map_userinfo_to_user.reset_mock()
         self.handler._fetch_userinfo.reset_mock()
 
         # With userinfo fetching
         self.handler._scopes = []  # do not ask the "openid" scope
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        self.handler._auth_handler.complete_sso_login.assert_called_once_with(
-            user_id, request, client_redirect_url, {},
+        auth_handler.complete_sso_login.assert_called_once_with(
+            expected_user_id, request, client_redirect_url, None,
         )
         self.handler._exchange_code.assert_called_once_with(code)
         self.handler._parse_id_token.assert_not_called()
-        self.handler._map_userinfo_to_user.assert_called_once_with(
-            userinfo, token, user_agent, ip_address
-        )
         self.handler._fetch_userinfo.assert_called_once_with(token)
         self.render_error.assert_not_called()
 
@@ -624,72 +592,55 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         userinfo = {
             "sub": "foo",
+            "username": "foo",
             "phone": "1234567",
         }
-        user_id = "@foo:domain.org"
         self.handler._exchange_code = simple_async_mock(return_value=token)
         self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
-        self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
-        self.handler._auth_handler.complete_sso_login = simple_async_mock()
-        request = Mock(
-            spec=[
-                "args",
-                "getCookie",
-                "addCookie",
-                "requestHeaders",
-                "getClientIP",
-                "get_user_agent",
-            ]
-        )
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
 
         state = "state"
         client_redirect_url = "http://client/redirect"
-        request.getCookie.return_value = self.handler._generate_oidc_session_token(
+        session = self.handler._generate_oidc_session_token(
             state=state,
             nonce="nonce",
             client_redirect_url=client_redirect_url,
             ui_auth_session_id=None,
         )
-
-        request.args = {}
-        request.args[b"code"] = [b"code"]
-        request.args[b"state"] = [state.encode("utf-8")]
-
-        request.getClientIP.return_value = "10.0.0.1"
-        request.get_user_agent.return_value = "Browser"
+        request = _build_callback_request("code", state, session)
 
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        self.handler._auth_handler.complete_sso_login.assert_called_once_with(
-            user_id, request, client_redirect_url, {"phone": "1234567"},
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@foo:test", request, client_redirect_url, {"phone": "1234567"},
         )
 
     def test_map_userinfo_to_user(self):
         """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
         userinfo = {
             "sub": "test_user",
             "username": "test_user",
         }
-        # The token doesn't matter with the default user mapping provider.
-        token = {}
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", ANY, ANY, None,
         )
-        self.assertEqual(mxid, "@test_user:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Some providers return an integer ID.
         userinfo = {
             "sub": 1234,
             "username": "test_user_2",
         }
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user_2:test", ANY, ANY, None,
         )
-        self.assertEqual(mxid, "@test_user_2:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Test if the mxid is already taken
         store = self.hs.get_datastore()
@@ -698,14 +649,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user3.to_string(), password_hash=None)
         )
         userinfo = {"sub": "test3", "username": "test_user_3"}
-        e = self.get_failure(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
-        )
-        self.assertEqual(
-            str(e.value), "Mapping provider does not support de-duplicating Matrix IDs",
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_not_called()
+        self.assertRenderedError(
+            "mapping_error",
+            "Mapping provider does not support de-duplicating Matrix IDs",
         )
 
     @override_config({"oidc_config": {"allow_existing_users": True}})
@@ -717,26 +665,26 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user.to_string(), password_hash=None)
         )
 
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
         # Map a user via SSO.
         userinfo = {
             "sub": "test",
             "username": "test_user",
         }
-        token = {}
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            user.to_string(), ANY, ANY, None,
         )
-        self.assertEqual(mxid, "@test_user:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Subsequent calls should map to the same mxid.
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            user.to_string(), ANY, ANY, None,
         )
-        self.assertEqual(mxid, "@test_user:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Note that a second SSO user can be mapped to the same Matrix ID. (This
         # requires a unique sub, but something that maps to the same matrix ID,
@@ -747,13 +695,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test1",
             "username": "test_user",
         }
-        token = {}
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            user.to_string(), ANY, ANY, None,
         )
-        self.assertEqual(mxid, "@test_user:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Register some non-exact matching cases.
         user2 = UserID.from_string("@TEST_user_2:test")
@@ -770,14 +716,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test2",
             "username": "TEST_USER_2",
         }
-        e = self.get_failure(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
-        )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_not_called()
+        args = self.assertRenderedError("mapping_error")
         self.assertTrue(
-            str(e.value).startswith(
+            args[2].startswith(
                 "Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:"
             )
         )
@@ -788,28 +731,17 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user2.to_string(), password_hash=None)
         )
 
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@TEST_USER_2:test", ANY, ANY, None,
         )
-        self.assertEqual(mxid, "@TEST_USER_2:test")
 
     def test_map_userinfo_to_invalid_localpart(self):
         """If the mapping provider generates an invalid localpart it should be rejected."""
-        userinfo = {
-            "sub": "test2",
-            "username": "föö",
-        }
-        token = {}
-
-        e = self.get_failure(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
+        self.get_success(
+            _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
         )
-        self.assertEqual(str(e.value), "localpart is invalid: föö")
+        self.assertRenderedError("mapping_error", "localpart is invalid: föö")
 
     @override_config(
         {
@@ -822,6 +754,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_map_userinfo_to_user_retries(self):
         """The mapping provider can retry generating an MXID if the MXID is already in use."""
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
         store = self.hs.get_datastore()
         self.get_success(
             store.register_user(user_id="@test_user:test", password_hash=None)
@@ -830,14 +765,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test",
             "username": "test_user",
         }
-        token = {}
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
-        )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+
         # test_user is already taken, so test_user1 gets registered instead.
-        self.assertEqual(mxid, "@test_user1:test")
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user1:test", ANY, ANY, None,
+        )
+        auth_handler.complete_sso_login.reset_mock()
 
         # Register all of the potential mxids for a particular OIDC username.
         self.get_success(
@@ -853,12 +787,80 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "tester",
             "username": "tester",
         }
-        e = self.get_failure(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
-        )
-        self.assertEqual(
-            str(e.value), "Unable to generate a Matrix ID from the SSO response"
-        )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_not_called()
+        self.assertRenderedError(
+            "mapping_error", "Unable to generate a Matrix ID from the SSO response"
+        )
+
+
+async def _make_callback_with_userinfo(
+    hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
+) -> None:
+    """Mock up an OIDC callback with the given userinfo dict
+
+    We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
+    and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
+
+    Args:
+        hs: the HomeServer impl to send the callback to.
+        userinfo: the OIDC userinfo dict
+        client_redirect_url: the URL to redirect to on success.
+    """
+    handler = hs.get_oidc_handler()
+    handler._exchange_code = simple_async_mock(return_value={})
+    handler._parse_id_token = simple_async_mock(return_value=userinfo)
+    handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
+
+    state = "state"
+    session = handler._generate_oidc_session_token(
+        state=state,
+        nonce="nonce",
+        client_redirect_url=client_redirect_url,
+        ui_auth_session_id=None,
+    )
+    request = _build_callback_request("code", state, session)
+
+    await handler.handle_oidc_callback(request)
+
+
+def _build_callback_request(
+    code: str,
+    state: str,
+    session: str,
+    user_agent: str = "Browser",
+    ip_address: str = "10.0.0.1",
+):
+    """Builds a fake SynapseRequest to mock the browser callback
+
+    Returns a Mock object which looks like the SynapseRequest we get from a browser
+    after SSO (before we return to the client)
+
+    Args:
+        code: the authorization code which would have been returned by the OIDC
+           provider
+        state: the "state" param which would have been passed around in the
+           query param. Should be the same as was embedded in the session in
+           _build_oidc_session.
+        session: the "session" which would have been passed around in the cookie.
+        user_agent: the user-agent to present
+        ip_address: the IP address to pretend the request came from
+    """
+    request = Mock(
+        spec=[
+            "args",
+            "getCookie",
+            "addCookie",
+            "requestHeaders",
+            "getClientIP",
+            "get_user_agent",
+        ]
+    )
+
+    request.getCookie.return_value = session
+    request.args = {}
+    request.args[b"code"] = [code.encode("utf-8")]
+    request.args[b"state"] = [state.encode("utf-8")]
+    request.getClientIP.return_value = ip_address
+    request.get_user_agent.return_value = user_agent
+    return request