summary refs log tree commit diff
path: root/tests/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers')
-rw-r--r--tests/handlers/test_oidc.py151
1 files changed, 83 insertions, 68 deletions
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index bd24375018..c54f1c5797 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -21,6 +21,7 @@ import pymacaroons
 
 from synapse.handlers.oidc_handler import OidcError
 from synapse.handlers.sso import MappingException
+from synapse.server import HomeServer
 from synapse.types import UserID
 
 from tests.test_utils import FakeResponse, simple_async_mock
@@ -399,7 +400,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             client_redirect_url=client_redirect_url,
             ui_auth_session_id=None,
         )
-        request = self._build_callback_request(
+        request = _build_callback_request(
             code, state, session, user_agent=user_agent, ip_address=ip_address
         )
 
@@ -607,7 +608,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             client_redirect_url=client_redirect_url,
             ui_auth_session_id=None,
         )
-        request = self._build_callback_request("code", state, session)
+        request = _build_callback_request("code", state, session)
 
         self.get_success(self.handler.handle_oidc_callback(request))
 
@@ -624,7 +625,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test_user",
             "username": "test_user",
         }
-        self._make_callback_with_userinfo(userinfo)
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
             "@test_user:test", ANY, ANY, None,
         )
@@ -635,7 +636,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": 1234,
             "username": "test_user_2",
         }
-        self._make_callback_with_userinfo(userinfo)
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
             "@test_user_2:test", ANY, ANY, None,
         )
@@ -648,7 +649,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user3.to_string(), password_hash=None)
         )
         userinfo = {"sub": "test3", "username": "test_user_3"}
-        self._make_callback_with_userinfo(userinfo)
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_not_called()
         self.assertRenderedError(
             "mapping_error",
@@ -672,14 +673,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test",
             "username": "test_user",
         }
-        self._make_callback_with_userinfo(userinfo)
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
             user.to_string(), ANY, ANY, None,
         )
         auth_handler.complete_sso_login.reset_mock()
 
         # Subsequent calls should map to the same mxid.
-        self._make_callback_with_userinfo(userinfo)
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
             user.to_string(), ANY, ANY, None,
         )
@@ -694,7 +695,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test1",
             "username": "test_user",
         }
-        self._make_callback_with_userinfo(userinfo)
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
             user.to_string(), ANY, ANY, None,
         )
@@ -715,7 +716,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test2",
             "username": "TEST_USER_2",
         }
-        self._make_callback_with_userinfo(userinfo)
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_not_called()
         args = self.assertRenderedError("mapping_error")
         self.assertTrue(
@@ -730,14 +731,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user2.to_string(), password_hash=None)
         )
 
-        self._make_callback_with_userinfo(userinfo)
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
             "@TEST_USER_2:test", ANY, ANY, None,
         )
 
     def test_map_userinfo_to_invalid_localpart(self):
         """If the mapping provider generates an invalid localpart it should be rejected."""
-        self._make_callback_with_userinfo({"sub": "test2", "username": "föö"})
+        self.get_success(
+            _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
+        )
         self.assertRenderedError("mapping_error", "localpart is invalid: föö")
 
     @override_config(
@@ -762,7 +765,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test",
             "username": "test_user",
         }
-        self._make_callback_with_userinfo(userinfo)
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
 
         # test_user is already taken, so test_user1 gets registered instead.
         auth_handler.complete_sso_login.assert_called_once_with(
@@ -784,68 +787,80 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "tester",
             "username": "tester",
         }
-        self._make_callback_with_userinfo(userinfo)
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_not_called()
         self.assertRenderedError(
             "mapping_error", "Unable to generate a Matrix ID from the SSO response"
         )
 
-    def _make_callback_with_userinfo(
-        self, userinfo: dict, client_redirect_url: str = "http://client/redirect"
-    ) -> None:
-        self.handler._exchange_code = simple_async_mock(return_value={})
-        self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
-        self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
 
-        state = "state"
-        session = self.handler._generate_oidc_session_token(
-            state=state,
-            nonce="nonce",
-            client_redirect_url=client_redirect_url,
-            ui_auth_session_id=None,
-        )
-        request = self._build_callback_request("code", state, session)
+async def _make_callback_with_userinfo(
+    hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
+) -> None:
+    """Mock up an OIDC callback with the given userinfo dict
 
-        self.get_success(self.handler.handle_oidc_callback(request))
+    We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
+    and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
 
-    def _build_callback_request(
-        self,
-        code: str,
-        state: str,
-        session: str,
-        user_agent: str = "Browser",
-        ip_address: str = "10.0.0.1",
-    ):
-        """Builds a fake SynapseRequest to mock the browser callback
-
-        Returns a Mock object which looks like the SynapseRequest we get from a browser
-        after SSO (before we return to the client)
-
-        Args:
-            code: the authorization code which would have been returned by the OIDC
-               provider
-            state: the "state" param which would have been passed around in the
-               query param. Should be the same as was embedded in the session in
-               _build_oidc_session.
-            session: the "session" which would have been passed around in the cookie.
-            user_agent: the user-agent to present
-            ip_address: the IP address to pretend the request came from
-        """
-        request = Mock(
-            spec=[
-                "args",
-                "getCookie",
-                "addCookie",
-                "requestHeaders",
-                "getClientIP",
-                "get_user_agent",
-            ]
-        )
+    Args:
+        hs: the HomeServer impl to send the callback to.
+        userinfo: the OIDC userinfo dict
+        client_redirect_url: the URL to redirect to on success.
+    """
+    handler = hs.get_oidc_handler()
+    handler._exchange_code = simple_async_mock(return_value={})
+    handler._parse_id_token = simple_async_mock(return_value=userinfo)
+    handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
 
-        request.getCookie.return_value = session
-        request.args = {}
-        request.args[b"code"] = [code.encode("utf-8")]
-        request.args[b"state"] = [state.encode("utf-8")]
-        request.getClientIP.return_value = ip_address
-        request.get_user_agent.return_value = user_agent
-        return request
+    state = "state"
+    session = handler._generate_oidc_session_token(
+        state=state,
+        nonce="nonce",
+        client_redirect_url=client_redirect_url,
+        ui_auth_session_id=None,
+    )
+    request = _build_callback_request("code", state, session)
+
+    await handler.handle_oidc_callback(request)
+
+
+def _build_callback_request(
+    code: str,
+    state: str,
+    session: str,
+    user_agent: str = "Browser",
+    ip_address: str = "10.0.0.1",
+):
+    """Builds a fake SynapseRequest to mock the browser callback
+
+    Returns a Mock object which looks like the SynapseRequest we get from a browser
+    after SSO (before we return to the client)
+
+    Args:
+        code: the authorization code which would have been returned by the OIDC
+           provider
+        state: the "state" param which would have been passed around in the
+           query param. Should be the same as was embedded in the session in
+           _build_oidc_session.
+        session: the "session" which would have been passed around in the cookie.
+        user_agent: the user-agent to present
+        ip_address: the IP address to pretend the request came from
+    """
+    request = Mock(
+        spec=[
+            "args",
+            "getCookie",
+            "addCookie",
+            "requestHeaders",
+            "getClientIP",
+            "get_user_agent",
+        ]
+    )
+
+    request.getCookie.return_value = session
+    request.args = {}
+    request.args[b"code"] = [code.encode("utf-8")]
+    request.args[b"state"] = [state.encode("utf-8")]
+    request.getClientIP.return_value = ip_address
+    request.get_user_agent.return_value = user_agent
+    return request