summary refs log tree commit diff
path: root/tests/handlers/test_oidc.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_oidc.py')
-rw-r--r--tests/handlers/test_oidc.py185
1 files changed, 104 insertions, 81 deletions
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 464e569ac8..c54f1c5797 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -19,8 +19,9 @@ from mock import ANY, Mock, patch
 
 import pymacaroons
 
-from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
+from synapse.handlers.oidc_handler import OidcError
 from synapse.handlers.sso import MappingException
+from synapse.server import HomeServer
 from synapse.types import UserID
 
 from tests.test_utils import FakeResponse, simple_async_mock
@@ -55,11 +56,14 @@ COOKIE_NAME = b"oidc_session"
 COOKIE_PATH = "/_synapse/oidc"
 
 
-class TestMappingProvider(OidcMappingProvider):
+class TestMappingProvider:
     @staticmethod
     def parse_config(config):
         return
 
+    def __init__(self, config):
+        pass
+
     def get_remote_user_id(self, userinfo):
         return userinfo["sub"]
 
@@ -360,6 +364,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
          - when the userinfo fetching fails
          - when the code exchange fails
         """
+
+        # ensure that we are correctly testing the fallback when "get_extra_attributes"
+        # is not implemented.
+        mapping_provider = self.handler._user_mapping_provider
+        with self.assertRaises(AttributeError):
+            _ = mapping_provider.get_extra_attributes
+
         token = {
             "type": "bearer",
             "id_token": "id_token",
@@ -389,14 +400,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
             client_redirect_url=client_redirect_url,
             ui_auth_session_id=None,
         )
-        request = self._build_callback_request(
+        request = _build_callback_request(
             code, state, session, user_agent=user_agent, ip_address=ip_address
         )
 
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            expected_user_id, request, client_redirect_url, {},
+            expected_user_id, request, client_redirect_url, None,
         )
         self.handler._exchange_code.assert_called_once_with(code)
         self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
@@ -427,7 +438,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            expected_user_id, request, client_redirect_url, {},
+            expected_user_id, request, client_redirect_url, None,
         )
         self.handler._exchange_code.assert_called_once_with(code)
         self.handler._parse_id_token.assert_not_called()
@@ -597,7 +608,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             client_redirect_url=client_redirect_url,
             ui_auth_session_id=None,
         )
-        request = self._build_callback_request("code", state, session)
+        request = _build_callback_request("code", state, session)
 
         self.get_success(self.handler.handle_oidc_callback(request))
 
@@ -614,9 +625,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test_user",
             "username": "test_user",
         }
-        self._make_callback_with_userinfo(userinfo)
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", ANY, ANY, {}
+            "@test_user:test", ANY, ANY, None,
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -625,9 +636,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": 1234,
             "username": "test_user_2",
         }
-        self._make_callback_with_userinfo(userinfo)
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user_2:test", ANY, ANY, {}
+            "@test_user_2:test", ANY, ANY, None,
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -638,7 +649,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user3.to_string(), password_hash=None)
         )
         userinfo = {"sub": "test3", "username": "test_user_3"}
-        self._make_callback_with_userinfo(userinfo)
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_not_called()
         self.assertRenderedError(
             "mapping_error",
@@ -662,16 +673,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test",
             "username": "test_user",
         }
-        self._make_callback_with_userinfo(userinfo)
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, {},
+            user.to_string(), ANY, ANY, None,
         )
         auth_handler.complete_sso_login.reset_mock()
 
         # Subsequent calls should map to the same mxid.
-        self._make_callback_with_userinfo(userinfo)
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, {},
+            user.to_string(), ANY, ANY, None,
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -684,9 +695,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test1",
             "username": "test_user",
         }
-        self._make_callback_with_userinfo(userinfo)
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, {},
+            user.to_string(), ANY, ANY, None,
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -705,7 +716,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test2",
             "username": "TEST_USER_2",
         }
-        self._make_callback_with_userinfo(userinfo)
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_not_called()
         args = self.assertRenderedError("mapping_error")
         self.assertTrue(
@@ -720,14 +731,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user2.to_string(), password_hash=None)
         )
 
-        self._make_callback_with_userinfo(userinfo)
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@TEST_USER_2:test", ANY, ANY, {},
+            "@TEST_USER_2:test", ANY, ANY, None,
         )
 
     def test_map_userinfo_to_invalid_localpart(self):
         """If the mapping provider generates an invalid localpart it should be rejected."""
-        self._make_callback_with_userinfo({"sub": "test2", "username": "föö"})
+        self.get_success(
+            _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
+        )
         self.assertRenderedError("mapping_error", "localpart is invalid: föö")
 
     @override_config(
@@ -752,11 +765,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test",
             "username": "test_user",
         }
-        self._make_callback_with_userinfo(userinfo)
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
 
         # test_user is already taken, so test_user1 gets registered instead.
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user1:test", ANY, ANY, {},
+            "@test_user1:test", ANY, ANY, None,
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -774,70 +787,80 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "tester",
             "username": "tester",
         }
-        self._make_callback_with_userinfo(userinfo)
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_not_called()
         self.assertRenderedError(
             "mapping_error", "Unable to generate a Matrix ID from the SSO response"
         )
 
-    def _make_callback_with_userinfo(
-        self, userinfo: dict, client_redirect_url: str = "http://client/redirect"
-    ) -> None:
-        self.handler._exchange_code = simple_async_mock(return_value={})
-        self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
-        self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
 
-        state = "state"
-        session = self.handler._generate_oidc_session_token(
-            state=state,
-            nonce="nonce",
-            client_redirect_url=client_redirect_url,
-            ui_auth_session_id=None,
-        )
-        request = self._build_callback_request("code", state, session)
+async def _make_callback_with_userinfo(
+    hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
+) -> None:
+    """Mock up an OIDC callback with the given userinfo dict
 
-        self.get_success(self.handler.handle_oidc_callback(request))
+    We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
+    and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
 
-    def _build_callback_request(
-        self,
-        code: str,
-        state: str,
-        session: str,
-        user_agent: str = "Browser",
-        ip_address: str = "10.0.0.1",
-    ):
-        """Builds a fake SynapseRequest to mock the browser callback
-
-        Returns a Mock object which looks like the SynapseRequest we get from a browser
-        after SSO (before we return to the client)
-
-        Args:
-            code: the authorization code which would have been returned by the OIDC
-               provider
-            state: the "state" param which would have been passed around in the
-               query param. Should be the same as was embedded in the session in
-               _build_oidc_session.
-            session: the "session" which would have been passed around in the cookie.
-            user_agent: the user-agent to present
-            ip_address: the IP address to pretend the request came from
-        """
-        request = Mock(
-            spec=[
-                "args",
-                "getCookie",
-                "addCookie",
-                "requestHeaders",
-                "getClientIP",
-                "get_user_agent",
-            ]
-        )
+    Args:
+        hs: the HomeServer impl to send the callback to.
+        userinfo: the OIDC userinfo dict
+        client_redirect_url: the URL to redirect to on success.
+    """
+    handler = hs.get_oidc_handler()
+    handler._exchange_code = simple_async_mock(return_value={})
+    handler._parse_id_token = simple_async_mock(return_value=userinfo)
+    handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
 
-        request.getCookie.return_value = session
-        request.args = {}
-        request.args[b"code"] = [code.encode("utf-8")]
-        request.args[b"state"] = [state.encode("utf-8")]
-        request.getClientIP.return_value = ip_address
-        request.get_user_agent.return_value = user_agent
-        return request
+    state = "state"
+    session = handler._generate_oidc_session_token(
+        state=state,
+        nonce="nonce",
+        client_redirect_url=client_redirect_url,
+        ui_auth_session_id=None,
+    )
+    request = _build_callback_request("code", state, session)
+
+    await handler.handle_oidc_callback(request)
+
+
+def _build_callback_request(
+    code: str,
+    state: str,
+    session: str,
+    user_agent: str = "Browser",
+    ip_address: str = "10.0.0.1",
+):
+    """Builds a fake SynapseRequest to mock the browser callback
+
+    Returns a Mock object which looks like the SynapseRequest we get from a browser
+    after SSO (before we return to the client)
+
+    Args:
+        code: the authorization code which would have been returned by the OIDC
+           provider
+        state: the "state" param which would have been passed around in the
+           query param. Should be the same as was embedded in the session in
+           _build_oidc_session.
+        session: the "session" which would have been passed around in the cookie.
+        user_agent: the user-agent to present
+        ip_address: the IP address to pretend the request came from
+    """
+    request = Mock(
+        spec=[
+            "args",
+            "getCookie",
+            "addCookie",
+            "requestHeaders",
+            "getClientIP",
+            "get_user_agent",
+        ]
+    )
+
+    request.getCookie.return_value = session
+    request.args = {}
+    request.args[b"code"] = [code.encode("utf-8")]
+    request.args[b"state"] = [state.encode("utf-8")]
+    request.getClientIP.return_value = ip_address
+    request.get_user_agent.return_value = user_agent
+    return request